removed fastapi_socketio + general improves
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,6 +29,7 @@
|
||||
/firegex-compose-tmp-file.yml
|
||||
/firegex.py
|
||||
/tests/benchmark.csv
|
||||
/backend/modules/nfproxy/socks/
|
||||
# misc
|
||||
**/.DS_Store
|
||||
**/.env.local
|
||||
|
||||
@@ -27,7 +27,7 @@ RUN pip3 install --no-cache-dir --break-system-packages -r /execute/requirements
|
||||
|
||||
COPY ./backend/binsrc /execute/binsrc
|
||||
RUN g++ binsrc/nfqueue.cpp -o modules/cppqueue -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libhs libmnl)
|
||||
#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppnfproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl)
|
||||
#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl)
|
||||
|
||||
COPY ./backend/ /execute/
|
||||
COPY --from=frontend /app/dist/ ./frontend/
|
||||
|
||||
@@ -8,13 +8,13 @@ from fastapi import FastAPI, HTTPException, Depends, APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from fastapi_socketio import SocketManager
|
||||
from utils.sqlite import SQLite
|
||||
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager
|
||||
from utils.loader import frontend_deploy, load_routers
|
||||
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import socketio
|
||||
|
||||
# DB init
|
||||
db = SQLite('db/firegex.db')
|
||||
@@ -42,7 +42,14 @@ app = FastAPI(
|
||||
title="Firegex API",
|
||||
version=API_VERSION,
|
||||
)
|
||||
utils.socketio = SocketManager(app, "/sock", socketio_path="")
|
||||
utils.socketio = socketio.AsyncServer(
|
||||
async_mode="asgi",
|
||||
cors_allowed_origins=[],
|
||||
transports=["websocket"]
|
||||
)
|
||||
|
||||
sio_app = socketio.ASGIApp(utils.socketio, socketio_path="/sock/socket.io", other_asgi_app=app)
|
||||
app.mount("/sock", sio_app)
|
||||
|
||||
if DEBUG:
|
||||
app.add_middleware(
|
||||
|
||||
@@ -1,97 +1,61 @@
|
||||
from modules.nfregex.nftables import FiregexTables
|
||||
from modules.nfproxy.nftables import FiregexTables
|
||||
from utils import run_func
|
||||
from modules.nfregex.models import Service, Regex
|
||||
import re
|
||||
from modules.nfproxy.models import Service, PyFilter
|
||||
import os
|
||||
import asyncio
|
||||
import traceback
|
||||
from utils import DEBUG
|
||||
from fastapi import HTTPException
|
||||
|
||||
#TODO copied file, review
|
||||
import socket
|
||||
import shutil
|
||||
|
||||
nft = FiregexTables()
|
||||
|
||||
class RegexFilter:
|
||||
def __init__(
|
||||
self, regex,
|
||||
is_case_sensitive=True,
|
||||
input_mode=False,
|
||||
output_mode=False,
|
||||
blocked_packets=0,
|
||||
id=None,
|
||||
update_func = None
|
||||
):
|
||||
self.regex = regex
|
||||
self.is_case_sensitive = is_case_sensitive
|
||||
if input_mode == output_mode:
|
||||
input_mode = output_mode = True # (False, False) == (True, True)
|
||||
self.input_mode = input_mode
|
||||
self.output_mode = output_mode
|
||||
self.blocked = blocked_packets
|
||||
self.id = id
|
||||
self.update_func = update_func
|
||||
self.compiled_regex = self.compile()
|
||||
|
||||
@classmethod
|
||||
def from_regex(cls, regex:Regex, update_func = None):
|
||||
return cls(
|
||||
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
|
||||
blocked_packets=regex.blocked_packets,
|
||||
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
|
||||
update_func = update_func
|
||||
)
|
||||
def compile(self):
|
||||
if isinstance(self.regex, str):
|
||||
self.regex = self.regex.encode()
|
||||
if not isinstance(self.regex, bytes):
|
||||
raise Exception("Invalid Regex Paramether")
|
||||
re.compile(self.regex) # raise re.error if it's invalid!
|
||||
case_sensitive = "1" if self.is_case_sensitive else "0"
|
||||
if self.input_mode:
|
||||
yield case_sensitive + "C" + self.regex.hex()
|
||||
if self.output_mode:
|
||||
yield case_sensitive + "S" + self.regex.hex()
|
||||
|
||||
async def update(self):
|
||||
if self.update_func:
|
||||
await run_func(self.update_func, self)
|
||||
|
||||
class FiregexInterceptor:
|
||||
|
||||
def __init__(self):
|
||||
self.srv:Service
|
||||
self.filter_map_lock:asyncio.Lock
|
||||
self.filter_map: dict[str, RegexFilter]
|
||||
self.regex_filters: set[RegexFilter]
|
||||
self.update_config_lock:asyncio.Lock
|
||||
self._stats_updater_cb:callable
|
||||
self.process:asyncio.subprocess.Process
|
||||
self.update_task: asyncio.Task
|
||||
self.ack_arrived = False
|
||||
self.ack_status = None
|
||||
self.ack_fail_what = ""
|
||||
self.ack_lock = asyncio.Lock()
|
||||
self.base_dir = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"socks", self.srv.id
|
||||
)
|
||||
self.n_threads = int(os.getenv("NTHREADS","1"))
|
||||
|
||||
self.connection_socket = os.path.join(self.base_dir, "connection.sock")
|
||||
self.vedict_sockets = [os.path.join(self.base_dir, f"vedict{i}.sock") for i in range(self.n_threads)]
|
||||
self.socks = []
|
||||
|
||||
def add_sock(self, path):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
||||
sock.bind(path)
|
||||
self.socks.append(sock)
|
||||
return sock
|
||||
|
||||
async def _call_stats_updater_callback(self, filter: PyFilter):
|
||||
if self._stats_updater_cb:
|
||||
await run_func(self._stats_updater_cb(filter))
|
||||
|
||||
@classmethod
|
||||
async def start(cls, srv: Service):
|
||||
async def start(cls, srv: Service, stats_updater_cb:callable):
|
||||
self = cls()
|
||||
self.srv = srv
|
||||
self.filter_map_lock = asyncio.Lock()
|
||||
self.update_config_lock = asyncio.Lock()
|
||||
self._stats_updater_cb = stats_updater_cb
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
self.add_sock(self.connection_socket)
|
||||
for path in self.vedict_sockets:
|
||||
self.add_sock(path)
|
||||
queue_range = await self._start_binary()
|
||||
self.update_task = asyncio.create_task(self.update_blocked())
|
||||
# TODO starts python workers
|
||||
nft.add(self.srv, queue_range)
|
||||
if not self.ack_lock.locked():
|
||||
await self.ack_lock.acquire()
|
||||
return self
|
||||
|
||||
async def _start_binary(self):
|
||||
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppqueue")
|
||||
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppproxy")
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
proxy_binary_path,
|
||||
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
|
||||
env={"MATCH_MODE": "stream" if self.srv.proto == "tcp" else "block", "NTHREADS": os.getenv("NTHREADS","1")},
|
||||
)
|
||||
self.process.stdin.write(self.base_dir.encode().hex().encode()+b" 3\n")
|
||||
await self.process.stdin.drain()
|
||||
line_fut = self.process.stdout.readuntil()
|
||||
try:
|
||||
line_fut = await asyncio.wait_for(line_fut, timeout=3)
|
||||
@@ -106,68 +70,14 @@ class FiregexInterceptor:
|
||||
self.process.kill()
|
||||
raise Exception("Invalid binary output")
|
||||
|
||||
async def update_blocked(self):
|
||||
try:
|
||||
while True:
|
||||
line = (await self.process.stdout.readuntil()).decode()
|
||||
if DEBUG:
|
||||
print(line)
|
||||
if line.startswith("BLOCKED "):
|
||||
regex_id = line.split()[1]
|
||||
async with self.filter_map_lock:
|
||||
if regex_id in self.filter_map:
|
||||
self.filter_map[regex_id].blocked+=1
|
||||
await self.filter_map[regex_id].update()
|
||||
if line.startswith("ACK "):
|
||||
self.ack_arrived = True
|
||||
self.ack_status = line.split()[1].upper() == "OK"
|
||||
if not self.ack_status:
|
||||
self.ack_fail_what = " ".join(line.split()[2:])
|
||||
self.ack_lock.release()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.IncompleteReadError:
|
||||
pass
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
async def stop(self):
|
||||
self.update_task.cancel()
|
||||
if self.process and self.process.returncode is None:
|
||||
self.process.kill()
|
||||
for sock in self.socks:
|
||||
sock.close()
|
||||
shutil.rmtree(self.base_dir)
|
||||
|
||||
async def _update_config(self, filters_codes):
|
||||
async with self.update_config_lock:
|
||||
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
|
||||
await self.process.stdin.drain()
|
||||
try:
|
||||
async with asyncio.timeout(3):
|
||||
await self.ack_lock.acquire()
|
||||
except TimeoutError:
|
||||
pass
|
||||
if not self.ack_arrived or not self.ack_status:
|
||||
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
|
||||
|
||||
|
||||
async def reload(self, filters:list[RegexFilter]):
|
||||
async with self.filter_map_lock:
|
||||
self.filter_map = self.compile_filters(filters)
|
||||
filters_codes = self.get_filter_codes()
|
||||
await self._update_config(filters_codes)
|
||||
|
||||
def get_filter_codes(self):
|
||||
filters_codes = list(self.filter_map.keys())
|
||||
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
|
||||
return filters_codes
|
||||
|
||||
def compile_filters(self, filters:list[RegexFilter]):
|
||||
res = {}
|
||||
for filter_obj in filters:
|
||||
try:
|
||||
raw_filters = filter_obj.compile()
|
||||
for filter in raw_filters:
|
||||
res[filter] = filter_obj
|
||||
except Exception:
|
||||
pass
|
||||
return res
|
||||
|
||||
async def reload(self, filters:list[PyFilter]):
|
||||
# filters are the functions to use in the workers (other functions are disabled or not flagged as filters)
|
||||
# TODO update filters in python workers (prob for new filters added) (reading from file????)
|
||||
pass
|
||||
@@ -1,18 +1,15 @@
|
||||
import asyncio
|
||||
from modules.nfregex.firegex import FiregexInterceptor, RegexFilter
|
||||
from modules.nfregex.nftables import FiregexTables, FiregexFilter
|
||||
from modules.nfregex.models import Regex, Service
|
||||
from modules.nfproxy.firegex import FiregexInterceptor
|
||||
from modules.nfproxy.nftables import FiregexTables, FiregexFilter
|
||||
from modules.nfproxy.models import Service, PyFilter
|
||||
from utils.sqlite import SQLite
|
||||
|
||||
#TODO copied file, review
|
||||
|
||||
class STATUS:
|
||||
STOP = "stop"
|
||||
ACTIVE = "active"
|
||||
|
||||
nft = FiregexTables()
|
||||
|
||||
|
||||
class ServiceManager:
|
||||
def __init__(self, srv: Service, db):
|
||||
self.srv = srv
|
||||
@@ -23,13 +20,13 @@ class ServiceManager:
|
||||
self.interceptor = None
|
||||
|
||||
async def _update_filters_from_db(self):
|
||||
regexes = [
|
||||
Regex.from_dict(ele) for ele in
|
||||
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
|
||||
pyfilters = [
|
||||
PyFilter.from_dict(ele) for ele in
|
||||
self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id)
|
||||
]
|
||||
#Filter check
|
||||
old_filters = set(self.filters.keys())
|
||||
new_filters = set([f.id for f in regexes])
|
||||
new_filters = set([f.id for f in pyfilters])
|
||||
#remove old filters
|
||||
for f in old_filters:
|
||||
if f not in new_filters:
|
||||
@@ -37,8 +34,7 @@ class ServiceManager:
|
||||
#add new filters
|
||||
for f in new_filters:
|
||||
if f not in old_filters:
|
||||
filter = [ele for ele in regexes if ele.id == f][0]
|
||||
self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
|
||||
self.filters[f] = [ele for ele in pyfilters if ele.id == f][0]
|
||||
if self.interceptor:
|
||||
await self.interceptor.reload(self.filters.values())
|
||||
|
||||
@@ -54,8 +50,8 @@ class ServiceManager:
|
||||
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
|
||||
await self.restart()
|
||||
|
||||
def _stats_updater(self,filter:RegexFilter):
|
||||
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
|
||||
def _stats_updater(self,filter:PyFilter):
|
||||
self.db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE filter_id = ?;", filter.blocked_packets, filter.edited_packets, filter.id)
|
||||
|
||||
def _set_status(self,status):
|
||||
self.status = status
|
||||
@@ -64,7 +60,7 @@ class ServiceManager:
|
||||
async def start(self):
|
||||
if not self.interceptor:
|
||||
nft.delete(self.srv)
|
||||
self.interceptor = await FiregexInterceptor.start(self.srv)
|
||||
self.interceptor = await FiregexInterceptor.start(self.srv, self._stats_updater)
|
||||
await self._update_filters_from_db()
|
||||
self._set_status(STATUS.ACTIVE)
|
||||
|
||||
@@ -119,3 +115,5 @@ class FirewallManager:
|
||||
|
||||
class ServiceNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
class Service:
|
||||
def __init__(self, service_id: str, status: str, port: int, name: str, proto: str, ip_int: str, **other):
|
||||
self.id = service_id
|
||||
@@ -14,7 +15,7 @@ class Service:
|
||||
|
||||
class PyFilter:
|
||||
def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other):
|
||||
self.filter_id = filter_id
|
||||
self.id = filter_id
|
||||
self.name = name
|
||||
self.blocked_packets = blocked_packets
|
||||
self.edited_packets = edited_packets
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from modules.nfregex.models import Service
|
||||
from modules.nfproxy.models import Service
|
||||
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
|
||||
|
||||
class FiregexFilter:
|
||||
@@ -48,10 +48,12 @@ class FiregexTables(NFTableManager):
|
||||
def add(self, srv:Service, queue_range):
|
||||
|
||||
for ele in self.get():
|
||||
if ele.__eq__(srv): return
|
||||
if ele.__eq__(srv):
|
||||
return
|
||||
|
||||
init, end = queue_range
|
||||
if init > end: init, end = end, init
|
||||
if init > end:
|
||||
init, end = end, init
|
||||
self.cmd(
|
||||
{ "insert":{ "rule": {
|
||||
"family": "inet",
|
||||
|
||||
@@ -4,5 +4,5 @@ uvicorn[standard]
|
||||
passlib[bcrypt]
|
||||
psutil
|
||||
python-jose[cryptography]
|
||||
fastapi-socketio
|
||||
python-socketio
|
||||
#git+https://salsa.debian.org/pkg-netfilter-team/pkg-nftables#egg=nftables&subdirectory=py
|
||||
|
||||
@@ -5,13 +5,13 @@ import socket
|
||||
import psutil
|
||||
import sys
|
||||
import nftables
|
||||
from fastapi_socketio import SocketManager
|
||||
from socketio import AsyncServer
|
||||
from fastapi import Path
|
||||
from typing import Annotated
|
||||
|
||||
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
|
||||
|
||||
socketio:SocketManager = None
|
||||
socketio:AsyncServer = None
|
||||
|
||||
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")
|
||||
|
||||
@@ -14,7 +14,7 @@ import { Firewall } from './pages/Firewall';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
|
||||
|
||||
const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket", "polling"], path:"/sock" }):io({transports: ["websocket", "polling"], path:"/sock"});
|
||||
const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket"], path:"/sock/socket.io" }):io({transports: ["websocket"], path:"/sock/socket.io"});
|
||||
|
||||
function App() {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user