removed fastapi_socketio + general improves

This commit is contained in:
Domingo Dirutigliano
2025-02-12 01:16:10 +01:00
parent f3ba6dc716
commit 2fb77a348f
10 changed files with 76 additions and 157 deletions

1
.gitignore vendored
View File

@@ -29,6 +29,7 @@
/firegex-compose-tmp-file.yml
/firegex.py
/tests/benchmark.csv
/backend/modules/nfproxy/socks/
# misc
**/.DS_Store
**/.env.local

View File

@@ -27,7 +27,7 @@ RUN pip3 install --no-cache-dir --break-system-packages -r /execute/requirements
COPY ./backend/binsrc /execute/binsrc
RUN g++ binsrc/nfqueue.cpp -o modules/cppqueue -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libhs libmnl)
#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppnfproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl)
#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl)
COPY ./backend/ /execute/
COPY --from=frontend /app/dist/ ./frontend/

View File

@@ -8,13 +8,13 @@ from fastapi import FastAPI, HTTPException, Depends, APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from passlib.context import CryptContext
from fastapi_socketio import SocketManager
from utils.sqlite import SQLite
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager
from utils.loader import frontend_deploy, load_routers
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel
from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
import socketio
# DB init
db = SQLite('db/firegex.db')
@@ -42,7 +42,14 @@ app = FastAPI(
title="Firegex API",
version=API_VERSION,
)
utils.socketio = SocketManager(app, "/sock", socketio_path="")
utils.socketio = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins=[],
transports=["websocket"]
)
sio_app = socketio.ASGIApp(utils.socketio, socketio_path="/sock/socket.io", other_asgi_app=app)
app.mount("/sock", sio_app)
if DEBUG:
app.add_middleware(

View File

@@ -1,97 +1,61 @@
from modules.nfregex.nftables import FiregexTables
from modules.nfproxy.nftables import FiregexTables
from utils import run_func
from modules.nfregex.models import Service, Regex
import re
from modules.nfproxy.models import Service, PyFilter
import os
import asyncio
import traceback
from utils import DEBUG
from fastapi import HTTPException
#TODO copied file, review
import socket
import shutil
nft = FiregexTables()
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None,
update_func = None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
if input_mode == output_mode:
input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.update_func = update_func
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex, update_func = None):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
update_func = update_func
)
def compile(self):
if isinstance(self.regex, str):
self.regex = self.regex.encode()
if not isinstance(self.regex, bytes):
raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if it's invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.input_mode:
yield case_sensitive + "C" + self.regex.hex()
if self.output_mode:
yield case_sensitive + "S" + self.regex.hex()
async def update(self):
if self.update_func:
await run_func(self.update_func, self)
class FiregexInterceptor:
def __init__(self):
self.srv:Service
self.filter_map_lock:asyncio.Lock
self.filter_map: dict[str, RegexFilter]
self.regex_filters: set[RegexFilter]
self.update_config_lock:asyncio.Lock
self._stats_updater_cb:callable
self.process:asyncio.subprocess.Process
self.update_task: asyncio.Task
self.ack_arrived = False
self.ack_status = None
self.ack_fail_what = ""
self.ack_lock = asyncio.Lock()
self.base_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"socks", self.srv.id
)
self.n_threads = int(os.getenv("NTHREADS","1"))
self.connection_socket = os.path.join(self.base_dir, "connection.sock")
self.vedict_sockets = [os.path.join(self.base_dir, f"vedict{i}.sock") for i in range(self.n_threads)]
self.socks = []
def add_sock(self, path):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sock.bind(path)
self.socks.append(sock)
return sock
async def _call_stats_updater_callback(self, filter: PyFilter):
if self._stats_updater_cb:
await run_func(self._stats_updater_cb(filter))
@classmethod
async def start(cls, srv: Service):
async def start(cls, srv: Service, stats_updater_cb:callable):
self = cls()
self.srv = srv
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
self._stats_updater_cb = stats_updater_cb
os.makedirs(self.base_dir, exist_ok=True)
self.add_sock(self.connection_socket)
for path in self.vedict_sockets:
self.add_sock(path)
queue_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_blocked())
# TODO starts python workers
nft.add(self.srv, queue_range)
if not self.ack_lock.locked():
await self.ack_lock.acquire()
return self
async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppqueue")
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppproxy")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path,
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
env={"MATCH_MODE": "stream" if self.srv.proto == "tcp" else "block", "NTHREADS": os.getenv("NTHREADS","1")},
)
self.process.stdin.write(self.base_dir.encode().hex().encode()+b" 3\n")
await self.process.stdin.drain()
line_fut = self.process.stdout.readuntil()
try:
line_fut = await asyncio.wait_for(line_fut, timeout=3)
@@ -106,68 +70,14 @@ class FiregexInterceptor:
self.process.kill()
raise Exception("Invalid binary output")
async def update_blocked(self):
try:
while True:
line = (await self.process.stdout.readuntil()).decode()
if DEBUG:
print(line)
if line.startswith("BLOCKED "):
regex_id = line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
await self.filter_map[regex_id].update()
if line.startswith("ACK "):
self.ack_arrived = True
self.ack_status = line.split()[1].upper() == "OK"
if not self.ack_status:
self.ack_fail_what = " ".join(line.split()[2:])
self.ack_lock.release()
except asyncio.CancelledError:
pass
except asyncio.IncompleteReadError:
pass
except Exception:
traceback.print_exc()
async def stop(self):
self.update_task.cancel()
if self.process and self.process.returncode is None:
self.process.kill()
async def _update_config(self, filters_codes):
async with self.update_config_lock:
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
try:
async with asyncio.timeout(3):
await self.ack_lock.acquire()
except TimeoutError:
pass
if not self.ack_arrived or not self.ack_status:
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
async def reload(self, filters:list[RegexFilter]):
async with self.filter_map_lock:
self.filter_map = self.compile_filters(filters)
filters_codes = self.get_filter_codes()
await self._update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def compile_filters(self, filters:list[RegexFilter]):
res = {}
for filter_obj in filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception:
pass
return res
for sock in self.socks:
sock.close()
shutil.rmtree(self.base_dir)
async def reload(self, filters:list[PyFilter]):
# filters are the functions to use in the workers (other functions are disabled or not flagged as filters)
# TODO update filters in python workers (prob for new filters added) (reading from file????)
pass

View File

@@ -1,18 +1,15 @@
import asyncio
from modules.nfregex.firegex import FiregexInterceptor, RegexFilter
from modules.nfregex.nftables import FiregexTables, FiregexFilter
from modules.nfregex.models import Regex, Service
from modules.nfproxy.firegex import FiregexInterceptor
from modules.nfproxy.nftables import FiregexTables, FiregexFilter
from modules.nfproxy.models import Service, PyFilter
from utils.sqlite import SQLite
#TODO copied file, review
class STATUS:
STOP = "stop"
ACTIVE = "active"
nft = FiregexTables()
class ServiceManager:
def __init__(self, srv: Service, db):
self.srv = srv
@@ -23,13 +20,13 @@ class ServiceManager:
self.interceptor = None
async def _update_filters_from_db(self):
regexes = [
Regex.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
pyfilters = [
PyFilter.from_dict(ele) for ele in
self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id)
]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes])
new_filters = set([f.id for f in pyfilters])
#remove old filters
for f in old_filters:
if f not in new_filters:
@@ -37,8 +34,7 @@ class ServiceManager:
#add new filters
for f in new_filters:
if f not in old_filters:
filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
self.filters[f] = [ele for ele in pyfilters if ele.id == f][0]
if self.interceptor:
await self.interceptor.reload(self.filters.values())
@@ -54,8 +50,8 @@ class ServiceManager:
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
await self.restart()
def _stats_updater(self,filter:RegexFilter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
def _stats_updater(self,filter:PyFilter):
self.db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE filter_id = ?;", filter.blocked_packets, filter.edited_packets, filter.id)
def _set_status(self,status):
self.status = status
@@ -64,7 +60,7 @@ class ServiceManager:
async def start(self):
if not self.interceptor:
nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(self.srv)
self.interceptor = await FiregexInterceptor.start(self.srv, self._stats_updater)
await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE)
@@ -119,3 +115,5 @@ class FirewallManager:
class ServiceNotFoundException(Exception):
pass

View File

@@ -1,3 +1,4 @@
class Service:
def __init__(self, service_id: str, status: str, port: int, name: str, proto: str, ip_int: str, **other):
self.id = service_id
@@ -14,7 +15,7 @@ class Service:
class PyFilter:
def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other):
self.filter_id = filter_id
self.id = filter_id
self.name = name
self.blocked_packets = blocked_packets
self.edited_packets = edited_packets

View File

@@ -1,4 +1,4 @@
from modules.nfregex.models import Service
from modules.nfproxy.models import Service
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
class FiregexFilter:
@@ -48,10 +48,12 @@ class FiregexTables(NFTableManager):
def add(self, srv:Service, queue_range):
for ele in self.get():
if ele.__eq__(srv): return
if ele.__eq__(srv):
return
init, end = queue_range
if init > end: init, end = end, init
if init > end:
init, end = end, init
self.cmd(
{ "insert":{ "rule": {
"family": "inet",

View File

@@ -4,5 +4,5 @@ uvicorn[standard]
passlib[bcrypt]
psutil
python-jose[cryptography]
fastapi-socketio
python-socketio
#git+https://salsa.debian.org/pkg-netfilter-team/pkg-nftables#egg=nftables&subdirectory=py

View File

@@ -5,13 +5,13 @@ import socket
import psutil
import sys
import nftables
from fastapi_socketio import SocketManager
from socketio import AsyncServer
from fastapi import Path
from typing import Annotated
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
socketio:SocketManager = None
socketio:AsyncServer = None
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")

View File

@@ -14,7 +14,7 @@ import { Firewall } from './pages/Firewall';
import { useQueryClient } from '@tanstack/react-query';
const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket", "polling"], path:"/sock" }):io({transports: ["websocket", "polling"], path:"/sock"});
const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket"], path:"/sock/socket.io" }):io({transports: ["websocket"], path:"/sock/socket.io"});
function App() {