Flask -> FastAPI

This commit is contained in:
DomySh
2022-06-28 13:26:06 +02:00
parent d781801bac
commit 3f2e8bb2f8
9 changed files with 307 additions and 435 deletions

View File

@@ -1,351 +1,297 @@
from base64 import b64decode
import sqlite3, subprocess, sys, threading, bcrypt, secrets, time, re
from flask import Flask, jsonify, request, abort, session
from functools import wraps
from flask_cors import CORS
from jsonschema import validate
import sqlite3, uvicorn, sys, bcrypt, secrets, re, os, asyncio, httpx, urllib, websockets
from fastapi import FastAPI, Request, HTTPException, WebSocket
from starlette.middleware.sessions import SessionMiddleware
from pydantic import BaseModel
from fastapi.responses import FileResponse, StreamingResponse
from utils import SQLite, KeyValueStorage, gen_internal_port, ProxyManager, from_name_get_id, STATUS
ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER"
DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG"
# DB init
db = SQLite('firegex')
db.connect()
conf = KeyValueStorage(db)
firewall = ProxyManager(db)
try:
import uwsgi
IN_UWSGI = True
except ImportError:
IN_UWSGI = False
app = FastAPI(debug=DEBUG)
app = Flask(__name__)
app.add_middleware(SessionMiddleware, secret_key=os.urandom(32))
SESSION_TOKEN = secrets.token_hex(8)
APP_STATUS = "init"
REACT_BUILD_DIR = "../frontend/build/" if not ON_DOCKER else "../frontend/"
REACT_HTML_PATH = os.path.join(REACT_BUILD_DIR,"index.html")
DEBUG = not ((len(sys.argv) > 1 and sys.argv[1] == "UWSGI") or IN_UWSGI)
if not conf.get("password") is None:
APP_STATUS = "run"
def is_loggined():
if DEBUG: return True
return True if session.get("loggined") else False
def is_loggined(request: Request):
return request.session.get("token", "") == SESSION_TOKEN
def login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if is_loggined() or DEBUG:
return f(*args, **kwargs)
else:
return abort(401)
return decorated_function
def login_check(request: Request):
if is_loggined(request): return True
raise HTTPException(status_code=401, detail="Invalid login session!")
@app.before_first_request
def before_first_request():
firewall.reload()
app.config['SECRET_KEY'] = secrets.token_hex(32)
if DEBUG:
app.config["STATUS"] = "run"
elif conf.get("password") is None:
app.config["STATUS"] = "init"
else:
app.config["STATUS"] = "run"
@app.route("/api/status")
def get_status():
if DEBUG:
return {
"status":app.config["STATUS"],
"loggined": is_loggined(),
"debug":True
}
else:
return {
"status":app.config["STATUS"],
"loggined": is_loggined()
}
@app.route("/api/login", methods = ['POST'])
def login():
if DEBUG: return { "status":"ok" }
if app.config["STATUS"] != "run": return abort(404)
req = request.get_json(force = True)
try:
validate(
instance=req,
schema={
"type" : "object",
"properties" : {
"password" : {"type" : "string"}
},
})
except Exception:
return abort(400)
if req["password"] == "":
return {"status":"Cannot insert an empty password!"}
time.sleep(.3) # No bruteforce :)
if bcrypt.checkpw(req["password"].encode(), conf.get("password").encode()):
session["loggined"] = True
return { "status":"ok" }
return {"status":"Wrong password!"}
@app.route("/api/logout")
def logout():
if DEBUG: return { "status":"ok" }
session["loggined"] = False
return { "status":"ok" }
@app.route('/api/change-password', methods = ['POST'])
@login_required
def change_password():
if DEBUG: return { "status":"ok" }
if app.config["STATUS"] != "run": return abort(404)
req = request.get_json(force = True)
try:
validate(
instance=req,
schema={
"type" : "object",
"properties" : {
"password" : {"type" : "string"},
"expire": {"type" : "boolean"},
},
})
except Exception:
return abort(400)
if req["password"] == "":
return {"status":"Cannot insert an empty password!"}
if req["expire"]:
app.config['SECRET_KEY'] = secrets.token_hex(32)
session["loggined"] = True
hash_psw = bcrypt.hashpw(req["password"].encode(), bcrypt.gensalt())
conf.put("password",hash_psw.decode())
return {"status":"ok"}
@app.route('/api/set-password', methods = ['POST'])
def set_password():
if DEBUG: return { "status":"ok" }
if app.config["STATUS"] != "init": return abort(404)
req = request.get_json(force = True)
try:
validate(
instance=req,
schema={
"type" : "object",
"properties" : {
"password" : {"type" : "string"}
},
})
except Exception:
return abort(400)
if not "password" in req or not isinstance(req["password"],str):
return abort(400)
if req["password"] == "":
return {"status":"Cannot insert an empty password!"}
hash_psw = bcrypt.hashpw(req["password"].encode(), bcrypt.gensalt())
conf.put("password",hash_psw.decode())
app.config["STATUS"] = "run"
session["loggined"] = True
return {"status":"ok"}
@app.route('/api/general-stats')
@login_required
def get_general_stats():
n_packets = db.query("SELECT SUM(blocked_packets) FROM regexes;")[0][0]
return {
'services': db.query("SELECT COUNT (*) FROM services;")[0][0],
'regexes': db.query("SELECT COUNT (*) FROM regexes;")[0][0],
'closed': n_packets if n_packets else 0
@app.get("/api/status")
async def get_status(request: Request):
global APP_STATUS
return {
"status":APP_STATUS,
"loggined": is_loggined(request)
}
@app.route('/api/services')
@login_required
def get_services():
res = []
for i in db.query('SELECT * FROM services;'):
n_regex = db.query('SELECT COUNT (*) FROM regexes WHERE service_id = ?;', (i[1],))[0][0]
n_packets = db.query('SELECT SUM(blocked_packets) FROM regexes WHERE service_id = ?;', (i[1],))[0][0]
class PasswordForm(BaseModel):
password: str
res.append({
'id': i[1],
'status': i[0],
'public_port': i[3],
'internal_port': i[2],
'n_regex': n_regex,
'n_packets': n_packets if n_packets else 0,
'name': i[4]
})
class PasswordChangeForm(BaseModel):
password: str
expire: bool
return jsonify(res)
@app.post("/api/login")
async def login_api(request: Request, form: PasswordForm):
global APP_STATUS
if APP_STATUS != "run": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
await asyncio.sleep(0.3) # No bruteforce :)
@app.route('/api/service/<serv>')
@login_required
def get_service(serv):
q = db.query('SELECT * FROM services WHERE service_id = ?;', (serv,))
if len(q) != 0:
n_regex = db.query('SELECT COUNT (*) FROM regexes WHERE service_id = ?;', (serv,))[0][0]
n_packets = db.query('SELECT SUM(blocked_packets) FROM regexes WHERE service_id = ?;', (serv,))[0][0]
return {
'id': q[0][1],
'status': q[0][0],
'public_port': q[0][3],
'internal_port': q[0][2],
'n_packets': n_packets if n_packets else 0,
'n_regex': n_regex,
'name': q[0][4]
}
else:
return abort(404)
@app.route('/api/service/<serv>/stop')
@login_required
def get_service_stop(serv):
firewall.change_status(serv,STATUS.STOP)
return {'status': 'ok'}
@app.route('/api/service/<serv>/pause')
@login_required
def get_service_pause(serv):
firewall.change_status(serv,STATUS.PAUSE)
return {'status': 'ok'}
@app.route('/api/service/<serv>/start')
@login_required
def get_service_start(serv):
firewall.change_status(serv,STATUS.ACTIVE)
return {'status': 'ok'}
@app.route('/api/service/<serv>/delete')
@login_required
def get_service_delete(serv):
db.query('DELETE FROM services WHERE service_id = ?;', (serv,))
db.query('DELETE FROM regexes WHERE service_id = ?;', (serv,))
firewall.fire_update(serv)
return {'status': 'ok'}
@app.route('/api/service/<serv>/regen-port')
@login_required
def get_regen_port(serv):
db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', (gen_internal_port(db), serv))
firewall.fire_update(serv)
return {'status': 'ok'}
@app.route('/api/service/<serv>/regexes')
@login_required
def get_service_regexes(serv):
return jsonify([
{
'id': row[5],
'service_id': row[2],
'regex': row[0],
'is_blacklist': True if row[3] == "1" else False,
'is_case_sensitive' : True if row[6] == "1" else False,
'mode': row[1],
'n_packets': row[4],
} for row in db.query('SELECT * FROM regexes WHERE service_id = ?;', (serv,))
])
@app.route('/api/regex/<int:regex_id>')
@login_required
def get_regex_id(regex_id):
q = db.query('SELECT * FROM regexes WHERE regex_id = ?;', (regex_id,))
if len(q) != 0:
return {
'id': regex_id,
'service_id': q[0][2],
'regex': q[0][0],
'is_blacklist': True if q[0][3] == "1" else False,
'is_case_sensitive' : True if q[0][7] == "1" else False,
'mode': q[0][1],
'n_packets': q[0][4],
}
else:
return abort(404)
@app.route('/api/regex/<int:regex_id>/delete')
@login_required
def get_regex_delete(regex_id):
q = db.query('SELECT * FROM regexes WHERE regex_id = ?;', (regex_id,))
if bcrypt.checkpw(form.password.encode(), conf.get("password").encode()):
request.session["token"] = SESSION_TOKEN
return { "status":"ok" }
if len(q) != 0:
db.query('DELETE FROM regexes WHERE regex_id = ?;', (regex_id,))
firewall.fire_update(q[0][2])
return {"status":"Wrong password!"}
@app.get("/api/logout")
async def logout(request: Request):
request.session["token"] = False
return { "status":"ok" }
@app.post('/api/change-password')
async def change_password(request: Request, form: PasswordChangeForm):
login_check(request)
global APP_STATUS
if APP_STATUS != "run": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
if form.expire:
SESSION_TOKEN = secrets.token_hex(8)
request.session["token"] = SESSION_TOKEN
hash_psw = bcrypt.hashpw(form.password.encode(), bcrypt.gensalt())
conf.put("password",hash_psw.decode())
return {"status":"ok"}
@app.post('/api/set-password')
async def set_password(request: Request, form: PasswordForm):
global APP_STATUS
if APP_STATUS != "init": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
hash_psw = bcrypt.hashpw(form.password.encode(), bcrypt.gensalt())
conf.put("password",hash_psw.decode())
APP_STATUS = "run"
request.session["token"] = SESSION_TOKEN
return {"status":"ok"}
@app.get('/api/general-stats')
async def get_general_stats(request: Request):
login_check(request)
return db.query("""
SELECT
(SELECT COALESCE(SUM(blocked_packets),0) FROM regexes) closed,
(SELECT COUNT(*) FROM regexes) regexes,
(SELECT COUNT(*) FROM services) services
""")[0]
@app.get('/api/services')
async def get_services(request: Request):
login_check(request)
return db.query("""
SELECT
s.service_id `id`,
s.status status,
s.public_port public_port,
s.internal_port internal_port,
s.name name,
COUNT(*) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id
GROUP BY s.service_id;
""")
@app.get('/api/service/{service_id}')
async def get_service(request: Request, service_id: str):
login_check(request)
res = db.query("""
SELECT
s.service_id `id`,
s.status status,
s.public_port public_port,
s.internal_port internal_port,
s.name name,
COUNT(*) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ?
GROUP BY s.service_id;
""", service_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
return res[0]
@app.get('/api/service/{service_id}/stop')
async def get_service_stop(request: Request, service_id: str):
login_check(request)
firewall.change_status(service_id,STATUS.STOP)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/pause')
async def get_service_pause(request: Request, service_id: str):
login_check(request)
firewall.change_status(service_id,STATUS.PAUSE)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/start')
async def get_service_start(request: Request, service_id: str):
login_check(request)
firewall.change_status(service_id,STATUS.ACTIVE)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/delete')
async def get_service_delete(request: Request, service_id: str):
login_check(request)
db.query('DELETE FROM services WHERE service_id = ?;', service_id)
db.query('DELETE FROM regexes WHERE service_id = ?;', service_id)
firewall.fire_update(service_id)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/regen-port')
async def get_regen_port(request: Request, service_id: str):
login_check(request)
db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id)
firewall.fire_update(service_id)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/regexes')
async def get_service_regexes(request: Request, service_id: str):
login_check(request)
return db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
blocked_packets n_packets, is_case_sensitive
FROM regexes WHERE service_id = ?;
""", service_id)
@app.get('/api/regex/{regex_id}')
async def get_regex_id(request: Request, regex_id: int):
login_check(request)
res = db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
blocked_packets n_packets, is_case_sensitive
FROM regexes WHERE `id` = ?;
""", regex_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This regex does not exists!")
return res[0]
@app.get('/api/regex/{regex_id}/delete')
async def get_regex_delete(request: Request, regex_id: int):
login_check(request)
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id)
firewall.fire_update(res[0]["service_id"])
return {'status': 'ok'}
@app.route('/api/regexes/add', methods = ['POST'])
@login_required
def post_regexes_add():
req = request.get_json(force = True)
class RegexAddForm(BaseModel):
service_id: str
regex: str
mode: str
is_blacklist: bool
is_case_sensitive: bool
@app.post('/api/regexes/add')
async def post_regexes_add(request: Request, form: RegexAddForm):
login_check(request)
try:
validate(
instance=req,
schema={
"type" : "object",
"properties" : {
"service_id" : {"type" : "string"},
"regex" : {"type" : "string"},
"is_blacklist" : {"type" : "boolean"},
"mode" : {"type" : "string"},
"is_case_sensitive" : {"type" : "boolean"}
},
})
except Exception:
return abort(400)
try:
re.compile(b64decode(req["regex"]))
re.compile(b64decode(form.regex))
except Exception:
return {"status":"Invalid regex"}
try:
db.query("INSERT INTO regexes (service_id, regex, is_blacklist, mode, is_case_sensitive ) VALUES (?, ?, ?, ?, ?);",
(req['service_id'], req['regex'], req['is_blacklist'], req['mode'], req['is_case_sensitive']))
form.service_id, form.regex, form.is_blacklist, form.mode, form.is_case_sensitive)
except sqlite3.IntegrityError:
return {'status': 'An identical regex already exists'}
firewall.fire_update(req['service_id'])
firewall.fire_update(form.service_id)
return {'status': 'ok'}
class ServiceAddForm(BaseModel):
name: str
port: int
@app.route('/api/services/add', methods = ['POST'])
@login_required
def post_services_add():
req = request.get_json(force = True)
try:
validate(
instance=req,
schema={
"type" : "object",
"properties" : {
"name" : {"type" : "string"},
"port" : {"type" : "number"}
},
})
except Exception:
return abort(400)
serv_id = from_name_get_id(req['name'])
@app.post('/api/services/add')
async def post_services_add(request: Request, form: ServiceAddForm):
login_check(request)
serv_id = from_name_get_id(form.name)
try:
db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)",
(req['name'], serv_id, gen_internal_port(db), req['port'], 'stop'))
form.name, serv_id, gen_internal_port(db), form.port, 'stop')
firewall.reload()
except sqlite3.IntegrityError:
return {'status': 'Name or/and port of the service has been already assigned to another service'}
return {'status': 'ok'}
async def frontend_debug_proxy(path):
httpc = httpx.AsyncClient()
req = httpc.build_request("GET",urllib.parse.urljoin(f"http://0.0.0.0:{os.getenv('F_PORT','3000')}", path))
resp = await httpc.send(req, stream=True)
return StreamingResponse(resp.aiter_bytes(),status_code=resp.status_code)
async def react_deploy(path):
file_request = os.path.join(REACT_BUILD_DIR, path)
if not os.path.isfile(file_request):
return FileResponse(REACT_HTML_PATH, media_type='text/html')
else:
return FileResponse(file_request)
if DEBUG:
CORS(app, resources={r"/api/*": {"origins": "*"}}, supports_credentials=True )
async def forward_websocket(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol):
while True:
data = await ws_a.receive_bytes()
await ws_b.send(data)
async def reverse_websocket(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol):
while True:
data = await ws_b.recv()
await ws_a.send_text(data)
@app.websocket("/ws")
async def websocket_debug_proxy(ws: WebSocket):
await ws.accept()
async with websockets.connect(f"ws://0.0.0.0:{os.getenv('F_PORT','3000')}/ws") as ws_b_client:
fwd_task = asyncio.create_task(forward_websocket(ws, ws_b_client))
rev_task = asyncio.create_task(reverse_websocket(ws, ws_b_client))
await asyncio.gather(fwd_task, rev_task)
@app.get("/{full_path:path}")
async def catch_all(request: Request, full_path:str):
if DEBUG:
try:
return await frontend_debug_proxy(full_path)
except Exception:
return {"details":"Frontend not started at "+f"http://0.0.0.0:{os.getenv('F_PORT','3000')}"}
else: return await react_deploy(full_path)
if __name__ == '__main__':
db.check_integrity({
@@ -360,10 +306,10 @@ if __name__ == '__main__':
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'VARCHAR(1) NOT NULL',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL async defAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'VARCHAR(1) NOT NULL',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'keys_values': {
@@ -372,7 +318,13 @@ if __name__ == '__main__':
},
})
db.query("CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);")
if DEBUG:
app.run(host="0.0.0.0", port=8080 ,debug=True)
else:
subprocess.run(["uwsgi","--socket","./uwsgi.sock","--master","--module","app:app", "--enable-threads"])
firewall.reload()
# os.environ {PORT = Backend Port (Main Port), F_PORT = Frontend Port}
uvicorn.run(
"app:app",
host="0.0.0.0",
port=int(os.getenv("PORT","4444")),
reload=DEBUG,
access_log=DEBUG,
)