oauth2 on fastapi

This commit is contained in:
nik012003
2022-06-28 21:49:03 +02:00
parent 4971281f5a
commit e0a881abdb
14 changed files with 159 additions and 189 deletions

View File

@@ -1,10 +1,15 @@
from base64 import b64decode
import sqlite3, uvicorn, sys, bcrypt, secrets, re, os, asyncio, httpx, urllib, websockets
from fastapi import FastAPI, Request, HTTPException, WebSocket
from starlette.middleware.sessions import SessionMiddleware
from datetime import datetime, timedelta
import sqlite3, uvicorn, sys, secrets, re, os, asyncio, httpx, urllib, websockets
from tabnanny import check
from typing import Union
from fastapi import FastAPI, Request, HTTPException, WebSocket, Depends
from pydantic import BaseModel
from fastapi.responses import FileResponse, StreamingResponse
from utils import SQLite, KeyValueStorage, gen_internal_port, ProxyManager, from_name_get_id, STATUS
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER"
DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG"
@@ -15,6 +20,14 @@ db = SQLite('db/firegex.db')
conf = KeyValueStorage(db)
firewall = ProxyManager(db)
JWT_ALGORITHM="HS256"
JWT_SECRET = secrets.token_hex(32)
APP_STATUS = "init"
REACT_BUILD_DIR = "../frontend/build/" if not ON_DOCKER else "frontend/"
REACT_HTML_PATH = os.path.join(REACT_BUILD_DIR,"index.html")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/login", auto_error=False)
crypto = CryptContext(schemes=["bcrypt"], deprecated="auto")
app = FastAPI(debug=DEBUG)
@app.on_event("shutdown")
@@ -56,28 +69,38 @@ async def startup_event():
await firewall.reload()
app.add_middleware(SessionMiddleware, secret_key=os.urandom(32))
SESSION_TOKEN = secrets.token_hex(8)
APP_STATUS = "init"
REACT_BUILD_DIR = "../frontend/build/" if not ON_DOCKER else "frontend/"
REACT_HTML_PATH = os.path.join(REACT_BUILD_DIR,"index.html")
def create_access_token(data: dict):
global JWT_SECRET
to_encode = data.copy()
encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
return encoded_jwt
async def check_login(token: str = Depends(oauth2_scheme)):
global JWT_SECRET
if not token:
return False
try:
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
logged_in: bool = payload.get("logged_in")
except JWTError:
return False
return logged_in
def is_loggined(request: Request):
global SESSION_TOKEN
return request.session.get("token", "") == SESSION_TOKEN
def login_check(request: Request):
if is_loggined(request): return True
raise HTTPException(status_code=401, detail="Invalid login session!")
async def is_loggined(auth: bool = Depends(check_login)):
if not auth:
raise HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return True
@app.get("/api/status")
async def get_status(request: Request):
async def get_status(auth: bool = Depends(check_login)):
global APP_STATUS
return {
"status":APP_STATUS,
"loggined": is_loggined(request)
"loggined": auth
}
class PasswordForm(BaseModel):
@@ -88,58 +111,51 @@ class PasswordChangeForm(BaseModel):
expire: bool
@app.post("/api/login")
async def login_api(request: Request, form: PasswordForm):
global APP_STATUS, SESSION_TOKEN
async def login_api(form: OAuth2PasswordRequestForm = Depends()):
global APP_STATUS, JWT_SECRET
if APP_STATUS != "run": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
await asyncio.sleep(0.3) # No bruteforce :)
if crypto.verify(form.password, conf.get("password")):
print("access granted, good job")
return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"}
raise HTTPException(406,"Wrong password!")
if bcrypt.checkpw(form.password.encode(), conf.get("password").encode()):
request.session["token"] = SESSION_TOKEN
return { "status":"ok" }
return {"status":"Wrong password!"}
@app.get("/api/logout")
async def logout(request: Request):
request.session["token"] = False
return { "status":"ok" }
@app.post('/api/change-password')
async def change_password(request: Request, form: PasswordChangeForm):
login_check(request)
global APP_STATUS, SESSION_TOKEN
async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_loggined)):
global APP_STATUS, JWT_SECRET
if APP_STATUS != "run": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
if form.expire:
SESSION_TOKEN = secrets.token_hex(8)
request.session["token"] = SESSION_TOKEN
JWT_SECRET = secrets.token_hex(32)
hash_psw = bcrypt.hashpw(form.password.encode(), bcrypt.gensalt())
conf.put("password",hash_psw.decode())
return {"status":"ok"}
hash_psw = crypto.hash(form.password)
conf.put("password",hash_psw)
return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
@app.post('/api/set-password')
async def set_password(request: Request, form: PasswordForm):
global APP_STATUS, SESSION_TOKEN
async def set_password(form: PasswordForm):
global APP_STATUS, JWT_SECRET
if APP_STATUS != "init": raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
hash_psw = bcrypt.hashpw(form.password.encode(), bcrypt.gensalt())
conf.put("password",hash_psw.decode())
hash_psw = crypto.hash(form.password)
conf.put("password",hash_psw)
APP_STATUS = "run"
request.session["token"] = SESSION_TOKEN
return {"status":"ok"}
return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
@app.get('/api/general-stats')
async def get_general_stats(request: Request):
login_check(request)
async def get_general_stats(auth: bool = Depends(is_loggined)):
return db.query("""
SELECT
(SELECT COALESCE(SUM(blocked_packets),0) FROM regexes) closed,
@@ -148,8 +164,8 @@ async def get_general_stats(request: Request):
""")[0]
@app.get('/api/services')
async def get_services(request: Request):
login_check(request)
async def get_services(auth: bool = Depends(is_loggined)):
return db.query("""
SELECT
s.service_id `id`,
@@ -164,8 +180,8 @@ async def get_services(request: Request):
""")
@app.get('/api/service/{service_id}')
async def get_service(request: Request, service_id: str):
login_check(request)
async def get_service(service_id: str, auth: bool = Depends(is_loggined)):
res = db.query("""
SELECT
s.service_id `id`,
@@ -182,26 +198,26 @@ async def get_service(request: Request, service_id: str):
return res[0]
@app.get('/api/service/{service_id}/stop')
async def get_service_stop(request: Request, service_id: str):
login_check(request)
async def get_service_stop(service_id: str, auth: bool = Depends(is_loggined)):
await firewall.get(service_id).next(STATUS.STOP)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/pause')
async def get_service_pause(request: Request, service_id: str):
login_check(request)
async def get_service_pause(service_id: str, auth: bool = Depends(is_loggined)):
await firewall.get(service_id).next(STATUS.PAUSE)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/start')
async def get_service_start(request: Request, service_id: str):
login_check(request)
async def get_service_start(service_id: str, auth: bool = Depends(is_loggined)):
await firewall.get(service_id).next(STATUS.ACTIVE)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/delete')
async def get_service_delete(request: Request, service_id: str):
login_check(request)
async def get_service_delete(service_id: str, auth: bool = Depends(is_loggined)):
db.query('DELETE FROM services WHERE service_id = ?;', service_id)
db.query('DELETE FROM regexes WHERE service_id = ?;', service_id)
await firewall.remove(service_id)
@@ -209,16 +225,16 @@ async def get_service_delete(request: Request, service_id: str):
@app.get('/api/service/{service_id}/regen-port')
async def get_regen_port(request: Request, service_id: str):
login_check(request)
async def get_regen_port(service_id: str, auth: bool = Depends(is_loggined)):
db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id)
await firewall.get(service_id).update_port()
return {'status': 'ok'}
@app.get('/api/service/{service_id}/regexes')
async def get_service_regexes(request: Request, service_id: str):
login_check(request)
async def get_service_regexes(service_id: str, auth: bool = Depends(is_loggined)):
return db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
@@ -227,8 +243,8 @@ async def get_service_regexes(request: Request, service_id: str):
""", service_id)
@app.get('/api/regex/{regex_id}')
async def get_regex_id(request: Request, regex_id: int):
login_check(request)
async def get_regex_id(regex_id: int, auth: bool = Depends(is_loggined)):
res = db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
@@ -239,8 +255,8 @@ async def get_regex_id(request: Request, regex_id: int):
return res[0]
@app.get('/api/regex/{regex_id}/delete')
async def get_regex_delete(request: Request, regex_id: int):
login_check(request)
async def get_regex_delete(regex_id: int, auth: bool = Depends(is_loggined)):
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
@@ -257,8 +273,8 @@ class RegexAddForm(BaseModel):
is_case_sensitive: bool
@app.post('/api/regexes/add')
async def post_regexes_add(request: Request, form: RegexAddForm):
login_check(request)
async def post_regexes_add(form: RegexAddForm, auth: bool = Depends(is_loggined)):
try:
re.compile(b64decode(form.regex))
except Exception:
@@ -277,8 +293,8 @@ class ServiceAddForm(BaseModel):
port: int
@app.post('/api/services/add')
async def post_services_add(request: Request, form: ServiceAddForm):
login_check(request)
async def post_services_add(form: ServiceAddForm, auth: bool = Depends(is_loggined)):
serv_id = from_name_get_id(form.name)
try:
db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)",
@@ -323,7 +339,7 @@ if DEBUG:
await asyncio.gather(fwd_task, rev_task)
@app.get("/{full_path:path}")
async def catch_all(request: Request, full_path:str):
async def catch_all(full_path:str):
if DEBUG:
try:
return await frontend_debug_proxy(full_path)
@@ -339,5 +355,6 @@ if __name__ == '__main__':
host="0.0.0.0",
port=int(os.getenv("PORT","4444")),
reload=DEBUG,
access_log=DEBUG
access_log=DEBUG,
workers=2
)

View File

@@ -1,6 +1,4 @@
import subprocess, re, os, asyncio
#c++ -o proxy proxy.cpp
import re, os, asyncio
class Filter:
def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None):

View File

@@ -1,5 +1,5 @@
fastapi[all]
httpx
uvicorn[standard]
bcrypt
kthread
passlib[bcrypt]
python-jose[cryptography]

View File

@@ -182,14 +182,17 @@ class ServiceManager:
def __proxy_starter(self,to):
async def func():
while True:
if check_port_is_open(self.proxy.public_port):
self._set_status(to)
await self.proxy.start(in_pause=(to==STATUS.PAUSE))
self._set_status(STATUS.STOP)
return
else:
await asyncio.sleep(.5)
try:
while True:
if check_port_is_open(self.proxy.public_port):
self._set_status(to)
await self.proxy.start(in_pause=(to==STATUS.PAUSE))
self._set_status(STATUS.STOP)
return
else:
await asyncio.sleep(.5)
except Exception:
await self.proxy.stop()
self.starter = asyncio.create_task(func())
class ProxyManager: