push: code changes
This commit is contained in:
@@ -9,12 +9,13 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from utils.sqlite import SQLite
|
||||
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager
|
||||
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager, NORELOAD
|
||||
from utils.loader import frontend_deploy, load_routers
|
||||
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import socketio
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
|
||||
# DB init
|
||||
db = SQLite('db/firegex.db')
|
||||
@@ -52,7 +53,6 @@ if DEBUG:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
utils.socketio = socketio.AsyncServer(
|
||||
async_mode="asgi",
|
||||
cors_allowed_origins=[],
|
||||
@@ -69,9 +69,6 @@ def set_psw(psw: str):
|
||||
hash_psw = crypto.hash(psw)
|
||||
db.put("password",hash_psw)
|
||||
|
||||
@utils.socketio.on("update")
|
||||
async def updater(): pass
|
||||
|
||||
def create_access_token(data: dict):
|
||||
to_encode = data.copy()
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET(), algorithm=JWT_ALGORITHM)
|
||||
@@ -90,6 +87,28 @@ async def check_login(token: str = Depends(oauth2_scheme)):
|
||||
return False
|
||||
return logged_in
|
||||
|
||||
@utils.socketio.on("connect")
|
||||
async def sio_connect(sid, environ, auth):
|
||||
if not auth or not await check_login(auth.get("token")):
|
||||
raise ConnectionRefusedError("Unauthorized")
|
||||
utils.sid_list.add(sid)
|
||||
|
||||
@utils.socketio.on("disconnect")
|
||||
async def sio_disconnect(sid):
|
||||
try:
|
||||
utils.sid_list.remove(sid)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
async def disconnect_all():
|
||||
while True:
|
||||
if len(utils.sid_list) == 0:
|
||||
break
|
||||
await utils.socketio.disconnect(utils.sid_list.pop())
|
||||
|
||||
@utils.socketio.on("update")
|
||||
async def updater(): pass
|
||||
|
||||
async def is_loggined(auth: bool = Depends(check_login)):
|
||||
if not auth:
|
||||
raise HTTPException(
|
||||
@@ -122,6 +141,7 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()):
|
||||
return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"}
|
||||
raise HTTPException(406,"Wrong password!")
|
||||
|
||||
|
||||
@app.post('/api/set-password', response_model=ChangePasswordModel)
|
||||
async def set_password(form: PasswordForm):
|
||||
"""Set the password of firegex"""
|
||||
@@ -143,6 +163,7 @@ async def change_password(form: PasswordChangeForm):
|
||||
return {"status":"Cannot insert an empty password!"}
|
||||
if form.expire:
|
||||
db.put("secret", secrets.token_hex(32))
|
||||
await disconnect_all()
|
||||
|
||||
set_psw(form.password)
|
||||
await refresh_frontend()
|
||||
@@ -200,7 +221,7 @@ if __name__ == '__main__':
|
||||
"app:app",
|
||||
host="::" if DEBUG else None,
|
||||
port=FIREGEX_PORT,
|
||||
reload=DEBUG,
|
||||
reload=DEBUG and not NORELOAD,
|
||||
access_log=True,
|
||||
workers=1, # Firewall module can't be replicated in multiple workers
|
||||
# Later the firewall module will be moved to a separate process
|
||||
|
||||
Reference in New Issue
Block a user