author DomySh <me@domysh.com> 1656456810 +0200
committer DomySh <me@domysh.com> 1656457473 +0200

Small Fixes (Stable only with 1 worker!)
This commit is contained in:
DomySh
2022-06-29 00:53:30 +02:00
parent 80a38f0d50
commit fad6ad4e68
13 changed files with 80 additions and 127 deletions

View File

@@ -1,10 +1,7 @@
from base64 import b64decode
from datetime import datetime, timedelta
import sqlite3, uvicorn, sys, secrets, re, os, asyncio, httpx, urllib, websockets
from tabnanny import check
from typing import Union
from fastapi import FastAPI, Request, HTTPException, WebSocket, Depends
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, WebSocket, Depends
from pydantic import BaseModel, BaseSettings
from fastapi.responses import FileResponse, StreamingResponse
from utils import SQLite, KeyValueStorage, gen_internal_port, ProxyManager, from_name_get_id, STATUS
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
@@ -20,16 +17,22 @@ db = SQLite('db/firegex.db')
conf = KeyValueStorage(db)
firewall = ProxyManager(db)
JWT_ALGORITHM="HS256"
JWT_SECRET = secrets.token_hex(32)
APP_STATUS = "init"
REACT_BUILD_DIR = "../frontend/build/" if not ON_DOCKER else "frontend/"
REACT_HTML_PATH = os.path.join(REACT_BUILD_DIR,"index.html")
class Settings(BaseSettings):
JWT_ALGORITHM: str = "HS256"
REACT_BUILD_DIR: str = "../frontend/build/" if not ON_DOCKER else "frontend/"
REACT_HTML_PATH: str = os.path.join(REACT_BUILD_DIR,"index.html")
settings = Settings()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/login", auto_error=False)
crypto = CryptContext(schemes=["bcrypt"], deprecated="auto")
app = FastAPI(debug=DEBUG)
def APP_STATUS(): return "init" if conf.get("password") is None else "run"
def JWT_SECRET(): return conf.get("secret")
@app.on_event("shutdown")
async def shutdown_event():
await firewall.close()
@@ -37,50 +40,21 @@ async def shutdown_event():
@app.on_event("startup")
async def startup_event():
global APP_STATUS
db.connect()
db.create_schema({
'services': {
'status': 'VARCHAR(100) NOT NULL',
'service_id': 'VARCHAR(100) PRIMARY KEY',
'internal_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE',
'public_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE',
'name': 'VARCHAR(100) NOT NULL'
},
'regexes': {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'keys_values': {
'key': 'VARCHAR(100) PRIMARY KEY',
'value': 'VARCHAR(100) NOT NULL',
},
})
db.query("CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);")
if not conf.get("password") is None:
APP_STATUS = "run"
db.init()
if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32))
await firewall.reload()
def create_access_token(data: dict):
global JWT_SECRET
to_encode = data.copy()
encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
encoded_jwt = jwt.encode(to_encode, JWT_SECRET(), algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
async def check_login(token: str = Depends(oauth2_scheme)):
global JWT_SECRET
if not token:
return False
try:
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
payload = jwt.decode(token, JWT_SECRET(), algorithms=[settings.JWT_ALGORITHM])
logged_in: bool = payload.get("logged_in")
except JWTError:
return False
@@ -97,9 +71,8 @@ async def is_loggined(auth: bool = Depends(check_login)):
@app.get("/api/status")
async def get_status(auth: bool = Depends(check_login)):
global APP_STATUS
return {
"status":APP_STATUS,
"status": APP_STATUS(),
"loggined": auth
}
@@ -112,29 +85,23 @@ class PasswordChangeForm(BaseModel):
@app.post("/api/login")
async def login_api(form: OAuth2PasswordRequestForm = Depends()):
global APP_STATUS, JWT_SECRET
if APP_STATUS != "run": raise HTTPException(status_code=400)
if APP_STATUS() != "run": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
await asyncio.sleep(0.3) # No bruteforce :)
if crypto.verify(form.password, conf.get("password")):
print("access granted, good job")
return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"}
raise HTTPException(406,"Wrong password!")
@app.post('/api/change-password')
async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_loggined)):
global APP_STATUS, JWT_SECRET
if APP_STATUS != "run": raise HTTPException(status_code=400)
if APP_STATUS() != "run": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
if form.expire:
JWT_SECRET = secrets.token_hex(32)
conf.put("secret", secrets.token_hex(32))
hash_psw = crypto.hash(form.password)
conf.put("password",hash_psw)
@@ -143,14 +110,11 @@ async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_logg
@app.post('/api/set-password')
async def set_password(form: PasswordForm):
global APP_STATUS, JWT_SECRET
if APP_STATUS != "init": raise HTTPException(status_code=400)
if APP_STATUS() != "init": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
hash_psw = crypto.hash(form.password)
conf.put("password",hash_psw)
APP_STATUS = "run"
return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
@app.get('/api/general-stats')
@@ -165,7 +129,6 @@ async def get_general_stats(auth: bool = Depends(is_loggined)):
@app.get('/api/services')
async def get_services(auth: bool = Depends(is_loggined)):
return db.query("""
SELECT
s.service_id `id`,
@@ -199,25 +162,21 @@ async def get_service(service_id: str, auth: bool = Depends(is_loggined)):
@app.get('/api/service/{service_id}/stop')
async def get_service_stop(service_id: str, auth: bool = Depends(is_loggined)):
await firewall.get(service_id).next(STATUS.STOP)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/pause')
async def get_service_pause(service_id: str, auth: bool = Depends(is_loggined)):
await firewall.get(service_id).next(STATUS.PAUSE)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/start')
async def get_service_start(service_id: str, auth: bool = Depends(is_loggined)):
await firewall.get(service_id).next(STATUS.ACTIVE)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/delete')
async def get_service_delete(service_id: str, auth: bool = Depends(is_loggined)):
db.query('DELETE FROM services WHERE service_id = ?;', service_id)
db.query('DELETE FROM regexes WHERE service_id = ?;', service_id)
await firewall.remove(service_id)
@@ -226,7 +185,6 @@ async def get_service_delete(service_id: str, auth: bool = Depends(is_loggined))
@app.get('/api/service/{service_id}/regen-port')
async def get_regen_port(service_id: str, auth: bool = Depends(is_loggined)):
db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id)
await firewall.get(service_id).update_port()
return {'status': 'ok'}
@@ -234,7 +192,6 @@ async def get_regen_port(service_id: str, auth: bool = Depends(is_loggined)):
@app.get('/api/service/{service_id}/regexes')
async def get_service_regexes(service_id: str, auth: bool = Depends(is_loggined)):
return db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
@@ -244,7 +201,6 @@ async def get_service_regexes(service_id: str, auth: bool = Depends(is_loggined)
@app.get('/api/regex/{regex_id}')
async def get_regex_id(regex_id: int, auth: bool = Depends(is_loggined)):
res = db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
@@ -256,9 +212,7 @@ async def get_regex_id(regex_id: int, auth: bool = Depends(is_loggined)):
@app.get('/api/regex/{regex_id}/delete')
async def get_regex_delete(regex_id: int, auth: bool = Depends(is_loggined)):
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
@@ -274,7 +228,6 @@ class RegexAddForm(BaseModel):
@app.post('/api/regexes/add')
async def post_regexes_add(form: RegexAddForm, auth: bool = Depends(is_loggined)):
try:
re.compile(b64decode(form.regex))
except Exception:
@@ -294,7 +247,6 @@ class ServiceAddForm(BaseModel):
@app.post('/api/services/add')
async def post_services_add(form: ServiceAddForm, auth: bool = Depends(is_loggined)):
serv_id = from_name_get_id(form.name)
try:
db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)",
@@ -312,9 +264,9 @@ async def frontend_debug_proxy(path):
return StreamingResponse(resp.aiter_bytes(),status_code=resp.status_code)
async def react_deploy(path):
file_request = os.path.join(REACT_BUILD_DIR, path)
file_request = os.path.join(settings.REACT_BUILD_DIR, path)
if not os.path.isfile(file_request):
return FileResponse(REACT_HTML_PATH, media_type='text/html')
return FileResponse(settings.REACT_HTML_PATH, media_type='text/html')
else:
return FileResponse(file_request)
@@ -356,5 +308,5 @@ if __name__ == '__main__':
port=int(os.getenv("PORT","4444")),
reload=DEBUG,
access_log=DEBUG,
workers=2
workers=1
)