Socket io implementation
This commit is contained in:
@@ -8,6 +8,7 @@ from utils import *
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from fastapi_socketio import SocketManager
|
||||
|
||||
ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER"
|
||||
DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG"
|
||||
@@ -30,14 +31,22 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/login", auto_error=False)
|
||||
crypto = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
app = FastAPI(debug=DEBUG, redoc_url=None)
|
||||
sio = SocketManager(app, "/sock", socketio_path="")
|
||||
|
||||
def APP_STATUS(): return "init" if conf.get("password") is None else "run"
|
||||
def JWT_SECRET(): return conf.get("secret")
|
||||
|
||||
async def refresh_frontend():
|
||||
await sio.emit("update","Refresh")
|
||||
|
||||
@sio.on("update")
|
||||
async def updater(): pass
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
db.init()
|
||||
await firewall.init()
|
||||
await firewall.init(refresh_frontend)
|
||||
await refresh_frontend()
|
||||
if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32))
|
||||
|
||||
@app.on_event("shutdown")
|
||||
@@ -117,6 +126,7 @@ async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_logg
|
||||
|
||||
hash_psw = crypto.hash(form.password)
|
||||
conf.put("password",hash_psw)
|
||||
await refresh_frontend()
|
||||
return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
|
||||
|
||||
|
||||
@@ -128,6 +138,7 @@ async def set_password(form: PasswordForm):
|
||||
return {"status":"Cannot insert an empty password!"}
|
||||
hash_psw = crypto.hash(form.password)
|
||||
conf.put("password",hash_psw)
|
||||
await refresh_frontend()
|
||||
return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
|
||||
|
||||
class GeneralStatModel(BaseModel):
|
||||
@@ -189,12 +200,14 @@ class StatusMessageModel(BaseModel):
|
||||
async def service_stop(service_port: int, auth: bool = Depends(is_loggined)):
|
||||
"""Request the stop of a specific service"""
|
||||
await firewall.get(service_port).next(STATUS.STOP)
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
@app.get('/api/service/{service_port}/start', response_model=StatusMessageModel)
|
||||
async def service_start(service_port: int, auth: bool = Depends(is_loggined)):
|
||||
"""Request the start of a specific service"""
|
||||
await firewall.get(service_port).next(STATUS.ACTIVE)
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
@app.get('/api/service/{service_port}/delete', response_model=StatusMessageModel)
|
||||
@@ -203,6 +216,7 @@ async def service_delete(service_port: int, auth: bool = Depends(is_loggined)):
|
||||
db.query('DELETE FROM services WHERE port = ?;', service_port)
|
||||
db.query('DELETE FROM regexes WHERE service_port = ?;', service_port)
|
||||
await firewall.remove(service_port)
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
class RenameForm(BaseModel):
|
||||
@@ -213,6 +227,7 @@ async def service_rename(service_port: int, form: RenameForm, auth: bool = Depen
|
||||
"""Request to change the name of a specific service"""
|
||||
if not form.name: return {'status': 'The name cannot be empty!'}
|
||||
db.query('UPDATE services SET name=? WHERE port = ?;', form.name, service_port)
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
class RegexModel(BaseModel):
|
||||
@@ -254,6 +269,7 @@ async def regex_delete(regex_id: int, auth: bool = Depends(is_loggined)):
|
||||
if len(res) != 0:
|
||||
db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id)
|
||||
await firewall.get(res[0]["service_port"]).update_filters()
|
||||
await refresh_frontend()
|
||||
|
||||
return {'status': 'ok'}
|
||||
|
||||
@@ -264,6 +280,7 @@ async def regex_enable(regex_id: int, auth: bool = Depends(is_loggined)):
|
||||
if len(res) != 0:
|
||||
db.query('UPDATE regexes SET active=1 WHERE regex_id = ?;', regex_id)
|
||||
await firewall.get(res[0]["service_port"]).update_filters()
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
@app.get('/api/regex/{regex_id}/disable', response_model=StatusMessageModel)
|
||||
@@ -273,6 +290,7 @@ async def regex_disable(regex_id: int, auth: bool = Depends(is_loggined)):
|
||||
if len(res) != 0:
|
||||
db.query('UPDATE regexes SET active=0 WHERE regex_id = ?;', regex_id)
|
||||
await firewall.get(res[0]["service_port"]).update_filters()
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
class RegexAddForm(BaseModel):
|
||||
@@ -297,6 +315,7 @@ async def add_new_regex(form: RegexAddForm, auth: bool = Depends(is_loggined)):
|
||||
return {'status': 'An identical regex already exists'}
|
||||
|
||||
await firewall.get(form.service_port).update_filters()
|
||||
await refresh_frontend()
|
||||
return {'status': 'ok'}
|
||||
|
||||
class ServiceAddForm(BaseModel):
|
||||
@@ -312,6 +331,7 @@ async def add_new_service(form: ServiceAddForm, auth: bool = Depends(is_loggined
|
||||
except sqlite3.IntegrityError:
|
||||
return {'status': 'Name or/and ports of the service has been already assigned to another service'}
|
||||
await firewall.reload()
|
||||
await refresh_frontend()
|
||||
|
||||
return {'status': 'ok'}
|
||||
|
||||
|
||||
@@ -5,4 +5,5 @@ passlib[bcrypt]
|
||||
python-jose[cryptography]
|
||||
NetfilterQueue
|
||||
scapy
|
||||
python-pcre
|
||||
python-pcre
|
||||
fastapi-socketio
|
||||
@@ -177,9 +177,9 @@ class ProxyManager:
|
||||
self.lock = asyncio.Lock()
|
||||
self.updater_task = None
|
||||
|
||||
def init_updater(self):
|
||||
def init_updater(self, callback = None):
|
||||
if not self.updater_task:
|
||||
self.updater_task = asyncio.create_task(self._stats_updater())
|
||||
self.updater_task = asyncio.create_task(self._stats_updater(callback))
|
||||
|
||||
def close_updater(self):
|
||||
if self.updater_task: self.updater_task.cancel()
|
||||
@@ -196,11 +196,11 @@ class ProxyManager:
|
||||
await self.proxy_table[port].next(STATUS.STOP)
|
||||
del self.proxy_table[port]
|
||||
|
||||
async def init(self):
|
||||
async def init(self, callback = None):
|
||||
self.init_updater(callback)
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
self.init_updater()
|
||||
async with self.lock:
|
||||
for srv in self.db.query('SELECT port, status FROM services;'):
|
||||
srv_port, req_status = srv["port"], srv["status"]
|
||||
@@ -210,7 +210,7 @@ class ProxyManager:
|
||||
self.proxy_table[srv_port] = ServiceManager(srv_port,self.db)
|
||||
await self.proxy_table[srv_port].next(req_status)
|
||||
|
||||
async def _stats_updater(self):
|
||||
async def _stats_updater(self, callback):
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
@@ -218,7 +218,10 @@ class ProxyManager:
|
||||
self.proxy_table[key].update_stats()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(1)
|
||||
if callback:
|
||||
if asyncio.iscoroutinefunction(callback): await callback()
|
||||
else: callback()
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
self.updater_task = None
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user