diff --git a/backend/app.py b/backend/app.py index c0530d1..753f350 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,10 +1,15 @@ from base64 import b64decode -import sqlite3, uvicorn, sys, bcrypt, secrets, re, os, asyncio, httpx, urllib, websockets -from fastapi import FastAPI, Request, HTTPException, WebSocket -from starlette.middleware.sessions import SessionMiddleware +from datetime import datetime, timedelta +import sqlite3, uvicorn, sys, secrets, re, os, asyncio, httpx, urllib, websockets +from tabnanny import check +from typing import Union +from fastapi import FastAPI, Request, HTTPException, WebSocket, Depends from pydantic import BaseModel from fastapi.responses import FileResponse, StreamingResponse from utils import SQLite, KeyValueStorage, gen_internal_port, ProxyManager, from_name_get_id, STATUS +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from jose import JWTError, jwt +from passlib.context import CryptContext ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER" DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG" @@ -15,6 +20,14 @@ db = SQLite('db/firegex.db') conf = KeyValueStorage(db) firewall = ProxyManager(db) +JWT_ALGORITHM="HS256" +JWT_SECRET = secrets.token_hex(32) +APP_STATUS = "init" +REACT_BUILD_DIR = "../frontend/build/" if not ON_DOCKER else "frontend/" +REACT_HTML_PATH = os.path.join(REACT_BUILD_DIR,"index.html") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/login", auto_error=False) +crypto = CryptContext(schemes=["bcrypt"], deprecated="auto") + app = FastAPI(debug=DEBUG) @app.on_event("shutdown") @@ -56,28 +69,38 @@ async def startup_event(): await firewall.reload() -app.add_middleware(SessionMiddleware, secret_key=os.urandom(32)) -SESSION_TOKEN = secrets.token_hex(8) -APP_STATUS = "init" -REACT_BUILD_DIR = "../frontend/build/" if not ON_DOCKER else "frontend/" -REACT_HTML_PATH = os.path.join(REACT_BUILD_DIR,"index.html") +def create_access_token(data: dict): + global JWT_SECRET + to_encode = data.copy() + encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) + return encoded_jwt +async def check_login(token: str = Depends(oauth2_scheme)): + global JWT_SECRET + if not token: + return False + try: + payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + logged_in: bool = payload.get("logged_in") + except JWTError: + return False + return logged_in - -def is_loggined(request: Request): - global SESSION_TOKEN - return request.session.get("token", "") == SESSION_TOKEN - -def login_check(request: Request): - if is_loggined(request): return True - raise HTTPException(status_code=401, detail="Invalid login session!") +async def is_loggined(auth: bool = Depends(check_login)): + if not auth: + raise HTTPException( + status_code=401, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return True @app.get("/api/status") -async def get_status(request: Request): +async def get_status(auth: bool = Depends(check_login)): global APP_STATUS return { "status":APP_STATUS, - "loggined": is_loggined(request) + "loggined": auth } class PasswordForm(BaseModel): @@ -88,58 +111,51 @@ class PasswordChangeForm(BaseModel): expire: bool @app.post("/api/login") -async def login_api(request: Request, form: PasswordForm): - global APP_STATUS, SESSION_TOKEN +async def login_api(form: OAuth2PasswordRequestForm = Depends()): + global APP_STATUS, JWT_SECRET + if APP_STATUS != "run": raise HTTPException(status_code=400) if form.password == "": return {"status":"Cannot insert an empty password!"} await asyncio.sleep(0.3) # No bruteforce :) + if crypto.verify(form.password, conf.get("password")): + print("access granted, good job") + return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"} + raise HTTPException(406,"Wrong password!") - if bcrypt.checkpw(form.password.encode(), conf.get("password").encode()): - request.session["token"] = SESSION_TOKEN - return { "status":"ok" } - - return {"status":"Wrong password!"} - -@app.get("/api/logout") -async def logout(request: Request): - request.session["token"] = False - return { "status":"ok" } @app.post('/api/change-password') -async def change_password(request: Request, form: PasswordChangeForm): - login_check(request) - global APP_STATUS, SESSION_TOKEN +async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_loggined)): + + global APP_STATUS, JWT_SECRET if APP_STATUS != "run": raise HTTPException(status_code=400) if form.password == "": return {"status":"Cannot insert an empty password!"} if form.expire: - SESSION_TOKEN = secrets.token_hex(8) - request.session["token"] = SESSION_TOKEN + JWT_SECRET = secrets.token_hex(32) - hash_psw = bcrypt.hashpw(form.password.encode(), bcrypt.gensalt()) - conf.put("password",hash_psw.decode()) - return {"status":"ok"} + hash_psw = crypto.hash(form.password) + conf.put("password",hash_psw) + return {"status":"ok", "access_token": create_access_token({"logged_in": True})} @app.post('/api/set-password') -async def set_password(request: Request, form: PasswordForm): - global APP_STATUS, SESSION_TOKEN +async def set_password(form: PasswordForm): + global APP_STATUS, JWT_SECRET if APP_STATUS != "init": raise HTTPException(status_code=400) if form.password == "": return {"status":"Cannot insert an empty password!"} - hash_psw = bcrypt.hashpw(form.password.encode(), bcrypt.gensalt()) - conf.put("password",hash_psw.decode()) + hash_psw = crypto.hash(form.password) + conf.put("password",hash_psw) APP_STATUS = "run" - request.session["token"] = SESSION_TOKEN - return {"status":"ok"} + return {"status":"ok", "access_token": create_access_token({"logged_in": True})} @app.get('/api/general-stats') -async def get_general_stats(request: Request): - login_check(request) +async def get_general_stats(auth: bool = Depends(is_loggined)): + return db.query(""" SELECT (SELECT COALESCE(SUM(blocked_packets),0) FROM regexes) closed, @@ -148,8 +164,8 @@ async def get_general_stats(request: Request): """)[0] @app.get('/api/services') -async def get_services(request: Request): - login_check(request) +async def get_services(auth: bool = Depends(is_loggined)): + return db.query(""" SELECT s.service_id `id`, @@ -164,8 +180,8 @@ async def get_services(request: Request): """) @app.get('/api/service/{service_id}') -async def get_service(request: Request, service_id: str): - login_check(request) +async def get_service(service_id: str, auth: bool = Depends(is_loggined)): + res = db.query(""" SELECT s.service_id `id`, @@ -182,26 +198,26 @@ async def get_service(request: Request, service_id: str): return res[0] @app.get('/api/service/{service_id}/stop') -async def get_service_stop(request: Request, service_id: str): - login_check(request) +async def get_service_stop(service_id: str, auth: bool = Depends(is_loggined)): + await firewall.get(service_id).next(STATUS.STOP) return {'status': 'ok'} @app.get('/api/service/{service_id}/pause') -async def get_service_pause(request: Request, service_id: str): - login_check(request) +async def get_service_pause(service_id: str, auth: bool = Depends(is_loggined)): + await firewall.get(service_id).next(STATUS.PAUSE) return {'status': 'ok'} @app.get('/api/service/{service_id}/start') -async def get_service_start(request: Request, service_id: str): - login_check(request) +async def get_service_start(service_id: str, auth: bool = Depends(is_loggined)): + await firewall.get(service_id).next(STATUS.ACTIVE) return {'status': 'ok'} @app.get('/api/service/{service_id}/delete') -async def get_service_delete(request: Request, service_id: str): - login_check(request) +async def get_service_delete(service_id: str, auth: bool = Depends(is_loggined)): + db.query('DELETE FROM services WHERE service_id = ?;', service_id) db.query('DELETE FROM regexes WHERE service_id = ?;', service_id) await firewall.remove(service_id) @@ -209,16 +225,16 @@ async def get_service_delete(request: Request, service_id: str): @app.get('/api/service/{service_id}/regen-port') -async def get_regen_port(request: Request, service_id: str): - login_check(request) +async def get_regen_port(service_id: str, auth: bool = Depends(is_loggined)): + db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id) await firewall.get(service_id).update_port() return {'status': 'ok'} @app.get('/api/service/{service_id}/regexes') -async def get_service_regexes(request: Request, service_id: str): - login_check(request) +async def get_service_regexes(service_id: str, auth: bool = Depends(is_loggined)): + return db.query(""" SELECT regex, mode, regex_id `id`, service_id, is_blacklist, @@ -227,8 +243,8 @@ async def get_service_regexes(request: Request, service_id: str): """, service_id) @app.get('/api/regex/{regex_id}') -async def get_regex_id(request: Request, regex_id: int): - login_check(request) +async def get_regex_id(regex_id: int, auth: bool = Depends(is_loggined)): + res = db.query(""" SELECT regex, mode, regex_id `id`, service_id, is_blacklist, @@ -239,8 +255,8 @@ async def get_regex_id(request: Request, regex_id: int): return res[0] @app.get('/api/regex/{regex_id}/delete') -async def get_regex_delete(request: Request, regex_id: int): - login_check(request) +async def get_regex_delete(regex_id: int, auth: bool = Depends(is_loggined)): + res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id) if len(res) != 0: @@ -257,8 +273,8 @@ class RegexAddForm(BaseModel): is_case_sensitive: bool @app.post('/api/regexes/add') -async def post_regexes_add(request: Request, form: RegexAddForm): - login_check(request) +async def post_regexes_add(form: RegexAddForm, auth: bool = Depends(is_loggined)): + try: re.compile(b64decode(form.regex)) except Exception: @@ -277,8 +293,8 @@ class ServiceAddForm(BaseModel): port: int @app.post('/api/services/add') -async def post_services_add(request: Request, form: ServiceAddForm): - login_check(request) +async def post_services_add(form: ServiceAddForm, auth: bool = Depends(is_loggined)): + serv_id = from_name_get_id(form.name) try: db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)", @@ -323,7 +339,7 @@ if DEBUG: await asyncio.gather(fwd_task, rev_task) @app.get("/{full_path:path}") -async def catch_all(request: Request, full_path:str): +async def catch_all(full_path:str): if DEBUG: try: return await frontend_debug_proxy(full_path) @@ -339,5 +355,6 @@ if __name__ == '__main__': host="0.0.0.0", port=int(os.getenv("PORT","4444")), reload=DEBUG, - access_log=DEBUG + access_log=DEBUG, + workers=2 ) diff --git a/backend/proxy/__init__.py b/backend/proxy/__init__.py index c2e2738..bf782d2 100755 --- a/backend/proxy/__init__.py +++ b/backend/proxy/__init__.py @@ -1,6 +1,4 @@ -import subprocess, re, os, asyncio - -#c++ -o proxy proxy.cpp +import re, os, asyncio class Filter: def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None): diff --git a/backend/requirements.txt b/backend/requirements.txt index 8feb4c9..649a99d 100755 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,5 +1,5 @@ fastapi[all] httpx uvicorn[standard] -bcrypt -kthread +passlib[bcrypt] +python-jose[cryptography] \ No newline at end of file diff --git a/backend/utils.py b/backend/utils.py index bc9c3bc..1cd6f0b 100755 --- a/backend/utils.py +++ b/backend/utils.py @@ -182,14 +182,17 @@ class ServiceManager: def __proxy_starter(self,to): async def func(): - while True: - if check_port_is_open(self.proxy.public_port): - self._set_status(to) - await self.proxy.start(in_pause=(to==STATUS.PAUSE)) - self._set_status(STATUS.STOP) - return - else: - await asyncio.sleep(.5) + try: + while True: + if check_port_is_open(self.proxy.public_port): + self._set_status(to) + await self.proxy.start(in_pause=(to==STATUS.PAUSE)) + self._set_status(STATUS.STOP) + return + else: + await asyncio.sleep(.5) + except Exception: + await self.proxy.stop() self.starter = asyncio.create_task(func()) class ProxyManager: diff --git a/frontend/build/asset-manifest.json b/frontend/build/asset-manifest.json deleted file mode 100644 index 6d1fa46..0000000 --- a/frontend/build/asset-manifest.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "files": { - "main.css": "/static/css/main.0efd334b.css", - "main.js": "/static/js/main.f153478b.js", - "index.html": "/index.html", - "main.0efd334b.css.map": "/static/css/main.0efd334b.css.map", - "main.f153478b.js.map": "/static/js/main.f153478b.js.map" - }, - "entrypoints": [ - "static/css/main.0efd334b.css", - "static/js/main.f153478b.js" - ] -} \ No newline at end of file diff --git a/frontend/build/index.html b/frontend/build/index.html deleted file mode 100644 index 398fadd..0000000 --- a/frontend/build/index.html +++ /dev/null @@ -1 +0,0 @@ -
a||125d?(a.sortIndex=c,f(t,a),null===h(r)&&a===h(t)&&(B?(E(L),L=-1):B=!0,K(H,c-d))):(a.sortIndex=e,f(r,a),A||z||(A=!0,I(J)));return a};\nexports.unstable_shouldYield=M;exports.unstable_wrapCallback=function(a){var b=y;return function(){var c=y;y=b;try{return a.apply(this,arguments)}finally{y=c}}};\n","'use strict';\n\nif (process.env.NODE_ENV === 'production') {\n module.exports = require('./cjs/scheduler.production.min.js');\n} else {\n module.exports = require('./cjs/scheduler.development.js');\n}\n","// The module cache\nvar __webpack_module_cache__ = {};\n\n// The require function\nfunction __webpack_require__(moduleId) {\n\t// Check if module is in cache\n\tvar cachedModule = __webpack_module_cache__[moduleId];\n\tif (cachedModule !== undefined) {\n\t\treturn cachedModule.exports;\n\t}\n\t// Create a new module (and put it into the cache)\n\tvar module = __webpack_module_cache__[moduleId] = {\n\t\t// no module.id needed\n\t\t// no module.loaded needed\n\t\texports: {}\n\t};\n\n\t// Execute the module function\n\t__webpack_modules__[moduleId](module, module.exports, __webpack_require__);\n\n\t// Return the exports of the module\n\treturn module.exports;\n}\n\n","// getDefaultExport function for compatibility with non-harmony modules\n__webpack_require__.n = function(module) {\n\tvar getter = module && module.__esModule ?\n\t\tfunction() { return module['default']; } :\n\t\tfunction() { return module; };\n\t__webpack_require__.d(getter, { a: getter });\n\treturn getter;\n};","// define getter functions for harmony exports\n__webpack_require__.d = function(exports, definition) {\n\tfor(var key in definition) {\n\t\tif(__webpack_require__.o(definition, key) && !__webpack_require__.o(exports, key)) {\n\t\t\tObject.defineProperty(exports, key, { enumerable: true, get: definition[key] });\n\t\t}\n\t}\n};","__webpack_require__.o = function(obj, prop) { return Object.prototype.hasOwnProperty.call(obj, prop); }","export default function _arrayLikeToArray(arr, len) {\n if (len == null || len > arr.length) len = arr.length;\n\n for (var i = 0, arr2 = new Array(len); i < len; i++) {\n arr2[i] = arr[i];\n }\n\n return arr2;\n}","import arrayLikeToArray from \"./arrayLikeToArray.js\";\nexport default function _unsupportedIterableToArray(o, minLen) {\n if (!o) return;\n if (typeof o === \"string\") return arrayLikeToArray(o, minLen);\n var n = Object.prototype.toString.call(o).slice(8, -1);\n if (n === \"Object\" && o.constructor) n = o.constructor.name;\n if (n === \"Map\" || n === \"Set\") return Array.from(o);\n if (n === \"Arguments\" || /^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(n)) return arrayLikeToArray(o, minLen);\n}","import arrayWithHoles from \"./arrayWithHoles.js\";\nimport iterableToArrayLimit from \"./iterableToArrayLimit.js\";\nimport unsupportedIterableToArray from \"./unsupportedIterableToArray.js\";\nimport nonIterableRest from \"./nonIterableRest.js\";\nexport default function _slicedToArray(arr, i) {\n return arrayWithHoles(arr) || iterableToArrayLimit(arr, i) || unsupportedIterableToArray(arr, i) || nonIterableRest();\n}","export default function _arrayWithHoles(arr) {\n if (Array.isArray(arr)) return arr;\n}","export default function _iterableToArrayLimit(arr, i) {\n var _i = arr == null ? null : typeof Symbol !== \"undefined\" && arr[Symbol.iterator] || arr[\"@@iterator\"];\n\n if (_i == null) return;\n var _arr = [];\n var _n = true;\n var _d = false;\n\n var _s, _e;\n\n try {\n for (_i = _i.call(arr); !(_n = (_s = _i.next()).done); _n = true) {\n _arr.push(_s.value);\n\n if (i && _arr.length === i) break;\n }\n } catch (err) {\n _d = true;\n _e = err;\n } finally {\n try {\n if (!_n && _i[\"return\"] != null) _i[\"return\"]();\n } finally {\n if (_d) throw _e;\n }\n }\n\n return _arr;\n}","export default function _nonIterableRest() {\n throw new TypeError(\"Invalid attempt to destructure non-iterable instance.\\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.\");\n}","export default function _extends() {\n _extends = Object.assign ? Object.assign.bind() : function (target) {\n for (var i = 1; i < arguments.length; i++) {\n var source = arguments[i];\n\n for (var key in source) {\n if (Object.prototype.hasOwnProperty.call(source, key)) {\n target[key] = source[key];\n }\n }\n }\n\n return target;\n };\n return _extends.apply(this, arguments);\n}","import * as React from \"react\";\nimport type { History, Location } from \"history\";\nimport { Action as NavigationType } from \"history\";\n\nimport type { RouteMatch } from \"./router\";\n\n/**\n * A Navigator is a \"location changer\"; it's how you get to different locations.\n *\n * Every history instance conforms to the Navigator interface, but the\n * distinction is useful primarily when it comes to the low-level