removed fastapi_socketio + general improves

This commit is contained in:
Domingo Dirutigliano
2025-02-12 01:16:10 +01:00
parent f3ba6dc716
commit 2fb77a348f
10 changed files with 76 additions and 157 deletions

1
.gitignore vendored
View File

@@ -29,6 +29,7 @@
/firegex-compose-tmp-file.yml /firegex-compose-tmp-file.yml
/firegex.py /firegex.py
/tests/benchmark.csv /tests/benchmark.csv
/backend/modules/nfproxy/socks/
# misc # misc
**/.DS_Store **/.DS_Store
**/.env.local **/.env.local

View File

@@ -27,7 +27,7 @@ RUN pip3 install --no-cache-dir --break-system-packages -r /execute/requirements
COPY ./backend/binsrc /execute/binsrc COPY ./backend/binsrc /execute/binsrc
RUN g++ binsrc/nfqueue.cpp -o modules/cppqueue -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libhs libmnl) RUN g++ binsrc/nfqueue.cpp -o modules/cppqueue -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libhs libmnl)
#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppnfproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl) #RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl)
COPY ./backend/ /execute/ COPY ./backend/ /execute/
COPY --from=frontend /app/dist/ ./frontend/ COPY --from=frontend /app/dist/ ./frontend/

View File

@@ -8,13 +8,13 @@ from fastapi import FastAPI, HTTPException, Depends, APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from fastapi_socketio import SocketManager
from utils.sqlite import SQLite from utils.sqlite import SQLite
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager
from utils.loader import frontend_deploy, load_routers from utils.loader import frontend_deploy, load_routers
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import socketio
# DB init # DB init
db = SQLite('db/firegex.db') db = SQLite('db/firegex.db')
@@ -42,7 +42,14 @@ app = FastAPI(
title="Firegex API", title="Firegex API",
version=API_VERSION, version=API_VERSION,
) )
utils.socketio = SocketManager(app, "/sock", socketio_path="") utils.socketio = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins=[],
transports=["websocket"]
)
sio_app = socketio.ASGIApp(utils.socketio, socketio_path="/sock/socket.io", other_asgi_app=app)
app.mount("/sock", sio_app)
if DEBUG: if DEBUG:
app.add_middleware( app.add_middleware(

View File

@@ -1,97 +1,61 @@
from modules.nfregex.nftables import FiregexTables from modules.nfproxy.nftables import FiregexTables
from utils import run_func from utils import run_func
from modules.nfregex.models import Service, Regex from modules.nfproxy.models import Service, PyFilter
import re
import os import os
import asyncio import asyncio
import traceback import socket
from utils import DEBUG import shutil
from fastapi import HTTPException
#TODO copied file, review
nft = FiregexTables() nft = FiregexTables()
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None,
update_func = None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
if input_mode == output_mode:
input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.update_func = update_func
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex, update_func = None):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
update_func = update_func
)
def compile(self):
if isinstance(self.regex, str):
self.regex = self.regex.encode()
if not isinstance(self.regex, bytes):
raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if it's invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.input_mode:
yield case_sensitive + "C" + self.regex.hex()
if self.output_mode:
yield case_sensitive + "S" + self.regex.hex()
async def update(self):
if self.update_func:
await run_func(self.update_func, self)
class FiregexInterceptor: class FiregexInterceptor:
def __init__(self): def __init__(self):
self.srv:Service self.srv:Service
self.filter_map_lock:asyncio.Lock self._stats_updater_cb:callable
self.filter_map: dict[str, RegexFilter]
self.regex_filters: set[RegexFilter]
self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process self.process:asyncio.subprocess.Process
self.update_task: asyncio.Task self.base_dir = os.path.join(
self.ack_arrived = False os.path.dirname(os.path.abspath(__file__)),
self.ack_status = None "socks", self.srv.id
self.ack_fail_what = "" )
self.ack_lock = asyncio.Lock() self.n_threads = int(os.getenv("NTHREADS","1"))
self.connection_socket = os.path.join(self.base_dir, "connection.sock")
self.vedict_sockets = [os.path.join(self.base_dir, f"vedict{i}.sock") for i in range(self.n_threads)]
self.socks = []
def add_sock(self, path):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sock.bind(path)
self.socks.append(sock)
return sock
async def _call_stats_updater_callback(self, filter: PyFilter):
if self._stats_updater_cb:
await run_func(self._stats_updater_cb(filter))
@classmethod @classmethod
async def start(cls, srv: Service): async def start(cls, srv: Service, stats_updater_cb:callable):
self = cls() self = cls()
self.srv = srv self.srv = srv
self.filter_map_lock = asyncio.Lock() self._stats_updater_cb = stats_updater_cb
self.update_config_lock = asyncio.Lock() os.makedirs(self.base_dir, exist_ok=True)
self.add_sock(self.connection_socket)
for path in self.vedict_sockets:
self.add_sock(path)
queue_range = await self._start_binary() queue_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_blocked()) # TODO starts python workers
nft.add(self.srv, queue_range) nft.add(self.srv, queue_range)
if not self.ack_lock.locked():
await self.ack_lock.acquire()
return self return self
async def _start_binary(self): async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppqueue") proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppproxy")
self.process = await asyncio.create_subprocess_exec( self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, proxy_binary_path,
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
env={"MATCH_MODE": "stream" if self.srv.proto == "tcp" else "block", "NTHREADS": os.getenv("NTHREADS","1")},
) )
self.process.stdin.write(self.base_dir.encode().hex().encode()+b" 3\n")
await self.process.stdin.drain()
line_fut = self.process.stdout.readuntil() line_fut = self.process.stdout.readuntil()
try: try:
line_fut = await asyncio.wait_for(line_fut, timeout=3) line_fut = await asyncio.wait_for(line_fut, timeout=3)
@@ -106,68 +70,14 @@ class FiregexInterceptor:
self.process.kill() self.process.kill()
raise Exception("Invalid binary output") raise Exception("Invalid binary output")
async def update_blocked(self):
try:
while True:
line = (await self.process.stdout.readuntil()).decode()
if DEBUG:
print(line)
if line.startswith("BLOCKED "):
regex_id = line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
await self.filter_map[regex_id].update()
if line.startswith("ACK "):
self.ack_arrived = True
self.ack_status = line.split()[1].upper() == "OK"
if not self.ack_status:
self.ack_fail_what = " ".join(line.split()[2:])
self.ack_lock.release()
except asyncio.CancelledError:
pass
except asyncio.IncompleteReadError:
pass
except Exception:
traceback.print_exc()
async def stop(self): async def stop(self):
self.update_task.cancel()
if self.process and self.process.returncode is None: if self.process and self.process.returncode is None:
self.process.kill() self.process.kill()
for sock in self.socks:
async def _update_config(self, filters_codes): sock.close()
async with self.update_config_lock: shutil.rmtree(self.base_dir)
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
try:
async with asyncio.timeout(3):
await self.ack_lock.acquire()
except TimeoutError:
pass
if not self.ack_arrived or not self.ack_status:
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
async def reload(self, filters:list[RegexFilter]):
async with self.filter_map_lock:
self.filter_map = self.compile_filters(filters)
filters_codes = self.get_filter_codes()
await self._update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def compile_filters(self, filters:list[RegexFilter]):
res = {}
for filter_obj in filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception:
pass
return res
async def reload(self, filters:list[PyFilter]):
# filters are the functions to use in the workers (other functions are disabled or not flagged as filters)
# TODO update filters in python workers (prob for new filters added) (reading from file????)
pass

View File

@@ -1,18 +1,15 @@
import asyncio import asyncio
from modules.nfregex.firegex import FiregexInterceptor, RegexFilter from modules.nfproxy.firegex import FiregexInterceptor
from modules.nfregex.nftables import FiregexTables, FiregexFilter from modules.nfproxy.nftables import FiregexTables, FiregexFilter
from modules.nfregex.models import Regex, Service from modules.nfproxy.models import Service, PyFilter
from utils.sqlite import SQLite from utils.sqlite import SQLite
#TODO copied file, review
class STATUS: class STATUS:
STOP = "stop" STOP = "stop"
ACTIVE = "active" ACTIVE = "active"
nft = FiregexTables() nft = FiregexTables()
class ServiceManager: class ServiceManager:
def __init__(self, srv: Service, db): def __init__(self, srv: Service, db):
self.srv = srv self.srv = srv
@@ -23,13 +20,13 @@ class ServiceManager:
self.interceptor = None self.interceptor = None
async def _update_filters_from_db(self): async def _update_filters_from_db(self):
regexes = [ pyfilters = [
Regex.from_dict(ele) for ele in PyFilter.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id) self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id)
] ]
#Filter check #Filter check
old_filters = set(self.filters.keys()) old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes]) new_filters = set([f.id for f in pyfilters])
#remove old filters #remove old filters
for f in old_filters: for f in old_filters:
if f not in new_filters: if f not in new_filters:
@@ -37,8 +34,7 @@ class ServiceManager:
#add new filters #add new filters
for f in new_filters: for f in new_filters:
if f not in old_filters: if f not in old_filters:
filter = [ele for ele in regexes if ele.id == f][0] self.filters[f] = [ele for ele in pyfilters if ele.id == f][0]
self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
if self.interceptor: if self.interceptor:
await self.interceptor.reload(self.filters.values()) await self.interceptor.reload(self.filters.values())
@@ -54,8 +50,8 @@ class ServiceManager:
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
await self.restart() await self.restart()
def _stats_updater(self,filter:RegexFilter): def _stats_updater(self,filter:PyFilter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id) self.db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE filter_id = ?;", filter.blocked_packets, filter.edited_packets, filter.id)
def _set_status(self,status): def _set_status(self,status):
self.status = status self.status = status
@@ -64,7 +60,7 @@ class ServiceManager:
async def start(self): async def start(self):
if not self.interceptor: if not self.interceptor:
nft.delete(self.srv) nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(self.srv) self.interceptor = await FiregexInterceptor.start(self.srv, self._stats_updater)
await self._update_filters_from_db() await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE) self._set_status(STATUS.ACTIVE)
@@ -119,3 +115,5 @@ class FirewallManager:
class ServiceNotFoundException(Exception): class ServiceNotFoundException(Exception):
pass pass

View File

@@ -1,3 +1,4 @@
class Service: class Service:
def __init__(self, service_id: str, status: str, port: int, name: str, proto: str, ip_int: str, **other): def __init__(self, service_id: str, status: str, port: int, name: str, proto: str, ip_int: str, **other):
self.id = service_id self.id = service_id
@@ -14,7 +15,7 @@ class Service:
class PyFilter: class PyFilter:
def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other): def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other):
self.filter_id = filter_id self.id = filter_id
self.name = name self.name = name
self.blocked_packets = blocked_packets self.blocked_packets = blocked_packets
self.edited_packets = edited_packets self.edited_packets = edited_packets

View File

@@ -1,4 +1,4 @@
from modules.nfregex.models import Service from modules.nfproxy.models import Service
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
class FiregexFilter: class FiregexFilter:
@@ -48,10 +48,12 @@ class FiregexTables(NFTableManager):
def add(self, srv:Service, queue_range): def add(self, srv:Service, queue_range):
for ele in self.get(): for ele in self.get():
if ele.__eq__(srv): return if ele.__eq__(srv):
return
init, end = queue_range init, end = queue_range
if init > end: init, end = end, init if init > end:
init, end = end, init
self.cmd( self.cmd(
{ "insert":{ "rule": { { "insert":{ "rule": {
"family": "inet", "family": "inet",

View File

@@ -4,5 +4,5 @@ uvicorn[standard]
passlib[bcrypt] passlib[bcrypt]
psutil psutil
python-jose[cryptography] python-jose[cryptography]
fastapi-socketio python-socketio
#git+https://salsa.debian.org/pkg-netfilter-team/pkg-nftables#egg=nftables&subdirectory=py #git+https://salsa.debian.org/pkg-netfilter-team/pkg-nftables#egg=nftables&subdirectory=py

View File

@@ -5,13 +5,13 @@ import socket
import psutil import psutil
import sys import sys
import nftables import nftables
from fastapi_socketio import SocketManager from socketio import AsyncServer
from fastapi import Path from fastapi import Path
from typing import Annotated from typing import Annotated
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
socketio:SocketManager = None socketio:AsyncServer = None
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers") ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")

View File

@@ -14,7 +14,7 @@ import { Firewall } from './pages/Firewall';
import { useQueryClient } from '@tanstack/react-query'; import { useQueryClient } from '@tanstack/react-query';
const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket", "polling"], path:"/sock" }):io({transports: ["websocket", "polling"], path:"/sock"}); const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket"], path:"/sock/socket.io" }):io({transports: ["websocket"], path:"/sock/socket.io"});
function App() { function App() {