This commit is contained in:
Your Name
2025-12-08 01:41:08 +03:00
parent 16f96aa6f6
commit 9af3023a37
49 changed files with 4609 additions and 4020 deletions

View File

@@ -5,6 +5,7 @@ import asyncio
import traceback
from fastapi import HTTPException
import time
import json
from utils import run_func
from utils import DEBUG
from utils import nicenessify
@@ -35,11 +36,12 @@ class FiregexInterceptor:
self.last_time_exception = 0
self.outstrem_function = None
self.expection_function = None
self.traffic_function = None
self.outstrem_task: asyncio.Task
self.outstrem_buffer = ""
@classmethod
async def start(cls, srv: Service, outstream_func=None, exception_func=None):
async def start(cls, srv: Service, outstream_func=None, exception_func=None, traffic_func=None):
self = cls()
self.srv = srv
self.filter_map_lock = asyncio.Lock()
@@ -47,6 +49,7 @@ class FiregexInterceptor:
self.sock_conn_lock = asyncio.Lock()
self.outstrem_function = outstream_func
self.expection_function = exception_func
self.traffic_function = traffic_func
if not self.sock_conn_lock.locked():
await self.sock_conn_lock.acquire()
self.sock_path = f"/tmp/firegex_nfproxy_{srv.id}.sock"
@@ -83,6 +86,16 @@ class FiregexInterceptor:
self.outstrem_buffer = self.outstrem_buffer[-OUTSTREAM_BUFFER_SIZE:]+"\n"
if self.outstrem_function:
await run_func(self.outstrem_function, self.srv.id, out_data)
# Parse JSON traffic events (if binary emits them)
if self.traffic_function:
for line in out_data.splitlines():
if line.startswith("{"): # JSON event from binary
try:
event = json.loads(line)
if "ts" in event and "verdict" in event: # Basic validation
await run_func(self.traffic_function, self.srv.id, event)
except (json.JSONDecodeError, KeyError):
pass # Ignore malformed JSON, keep backward compat with raw logs
async def _start_binary(self):
proxy_binary_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../cpproxy"))

View File

@@ -1,4 +1,5 @@
import asyncio
from collections import deque
from modules.nfproxy.firegex import FiregexInterceptor
from modules.nfproxy.nftables import FiregexTables, FiregexFilter
from modules.nfproxy.models import Service, PyFilter
@@ -12,7 +13,7 @@ class STATUS:
nft = FiregexTables()
class ServiceManager:
def __init__(self, srv: Service, db, outstream_func=None, exception_func=None):
def __init__(self, srv: Service, db, outstream_func=None, exception_func=None, traffic_func=None):
self.srv = srv
self.db = db
self.status = STATUS.STOP
@@ -21,11 +22,17 @@ class ServiceManager:
self.interceptor = None
self.outstream_function = outstream_func
self.last_exception_time = 0
self.traffic_events = deque(maxlen=500) # Ring buffer for traffic viewer
async def excep_internal_handler(srv, exc_time):
self.last_exception_time = exc_time
if exception_func:
await run_func(exception_func, srv, exc_time)
self.exception_function = excep_internal_handler
async def traffic_internal_handler(srv, event):
self.traffic_events.append(event)
if traffic_func:
await run_func(traffic_func, srv, event)
self.traffic_function = traffic_internal_handler
async def _update_filters_from_db(self):
pyfilters = [
@@ -69,7 +76,7 @@ class ServiceManager:
async def start(self):
if not self.interceptor:
nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(self.srv, outstream_func=self.outstream_function, exception_func=self.exception_function)
self.interceptor = await FiregexInterceptor.start(self.srv, outstream_func=self.outstream_function, exception_func=self.exception_function, traffic_func=self.traffic_function)
await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE)
@@ -87,14 +94,24 @@ class ServiceManager:
async def update_filters(self):
async with self.lock:
await self._update_filters_from_db()
def get_traffic_events(self, limit: int = 500):
"""Return recent traffic events from ring buffer"""
events_list = list(self.traffic_events)
return events_list[-limit:] if limit < len(events_list) else events_list
def clear_traffic_events(self):
"""Clear traffic event history"""
self.traffic_events.clear()
class FirewallManager:
def __init__(self, db:SQLite, outstream_func=None, exception_func=None):
def __init__(self, db:SQLite, outstream_func=None, exception_func=None, traffic_func=None):
self.db = db
self.service_table: dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
self.outstream_function = outstream_func
self.exception_function = exception_func
self.traffic_function = traffic_func
async def close(self):
for key in list(self.service_table.keys()):
@@ -116,7 +133,7 @@ class FirewallManager:
srv = Service.from_dict(srv)
if srv.id in self.service_table:
continue
self.service_table[srv.id] = ServiceManager(srv, self.db, outstream_func=self.outstream_function, exception_func=self.exception_function)
self.service_table[srv.id] = ServiceManager(srv, self.db, outstream_func=self.outstream_function, exception_func=self.exception_function, traffic_func=self.traffic_function)
await self.service_table[srv.id].next(srv.status)
def get(self,srv_id) -> ServiceManager:

View File

@@ -1,8 +1,8 @@
fastapi[all]
httpx
uvicorn[standard]
psutil
python-jose[cryptography]
python-socketio
brotli
#git+https://salsa.debian.org/pkg-netfilter-team/pkg-nftables#egg=nftables&subdirectory=py
fastapi[all]
httpx
uvicorn[standard]
psutil
python-jose[cryptography]
python-socketio
brotli
#git+https://salsa.debian.org/pkg-netfilter-team/pkg-nftables#egg=nftables&subdirectory=py

View File

@@ -113,6 +113,8 @@ async def startup():
utils.socketio.on("nfproxy-outstream-leave", leave_outstream)
utils.socketio.on("nfproxy-exception-join", join_exception)
utils.socketio.on("nfproxy-exception-leave", leave_exception)
utils.socketio.on("nfproxy-traffic-join", join_traffic)
utils.socketio.on("nfproxy-traffic-leave", leave_traffic)
async def shutdown():
db.backup()
@@ -133,7 +135,10 @@ async def outstream_func(service_id, data):
async def exception_func(service_id, timestamp):
await utils.socketio.emit(f"nfproxy-exception-{service_id}", timestamp, room=f"nfproxy-exception-{service_id}")
firewall = FirewallManager(db, outstream_func=outstream_func, exception_func=exception_func)
async def traffic_func(service_id, event):
await utils.socketio.emit(f"nfproxy-traffic-{service_id}", event, room=f"nfproxy-traffic-{service_id}")
firewall = FirewallManager(db, outstream_func=outstream_func, exception_func=exception_func, traffic_func=traffic_func)
@app.get('/services', response_model=list[ServiceModel])
async def get_service_list():
@@ -368,6 +373,28 @@ async def get_pyfilters_code(service_id: str):
except FileNotFoundError:
return ""
@app.get('/services/{service_id}/traffic')
async def get_traffic_events(service_id: str, limit: int = 500):
"""Get recent traffic events from the service ring buffer"""
if not db.query("SELECT 1 FROM services WHERE service_id = ?;", service_id):
raise HTTPException(status_code=400, detail="This service does not exists!")
try:
events = firewall.get(service_id).get_traffic_events(limit)
return {"events": events, "count": len(events)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post('/services/{service_id}/traffic/clear', response_model=StatusMessageModel)
async def clear_traffic_events(service_id: str):
"""Clear traffic event history for a service"""
if not db.query("SELECT 1 FROM services WHERE service_id = ?;", service_id):
raise HTTPException(status_code=400, detail="This service does not exists!")
try:
firewall.get(service_id).clear_traffic_events()
return {"status": "ok"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
#Socket io events
async def join_outstream(sid, data):
"""Client joins a room."""
@@ -397,3 +424,20 @@ async def leave_exception(sid, data):
if srv:
await utils.socketio.leave_room(sid, f"nfproxy-exception-{srv}")
async def join_traffic(sid, data):
"""Client joins traffic viewer room and gets initial event history."""
srv = data.get("service")
if srv:
room = f"nfproxy-traffic-{srv}"
await utils.socketio.enter_room(sid, room)
try:
events = firewall.get(srv).get_traffic_events(500)
await utils.socketio.emit("nfproxy-traffic-history", {"events": events}, room=sid)
except Exception:
pass # Service may not exist or not started
async def leave_traffic(sid, data):
"""Client leaves traffic viewer room."""
srv = data.get("service")
if srv:
await utils.socketio.leave_room(sid, f"nfproxy-traffic-{srv}")

View File

@@ -1,221 +1,221 @@
import asyncio
from ipaddress import ip_address, ip_interface
import os
import socket
import psutil
import sys
import nftables
from socketio import AsyncServer
from fastapi import Path
from typing import Annotated
from functools import wraps
from pydantic import BaseModel, ValidationError
import traceback
from utils.models import StatusMessageModel
from typing import List
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
socketio:AsyncServer = None
sid_list:set = set()
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")
ON_DOCKER = "DOCKER" in sys.argv
DEBUG = "DEBUG" in sys.argv
NORELOAD = "NORELOAD" in sys.argv
FIREGEX_PORT = int(os.getenv("PORT","4444"))
FIREGEX_HOST = os.getenv("HOST","0.0.0.0")
FIREGEX_SOCKET_DIR = os.getenv("SOCKET_DIR", None)
FIREGEX_SOCKET = os.path.join(FIREGEX_SOCKET_DIR, "firegex.sock") if FIREGEX_SOCKET_DIR else None
JWT_ALGORITHM: str = "HS256"
API_VERSION = "{{VERSION_PLACEHOLDER}}" if "{" not in "{{VERSION_PLACEHOLDER}}" else "0.0.0"
PortType = Annotated[int, Path(gt=0, lt=65536)]
async def run_func(func, *args, **kwargs):
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
async def socketio_emit(elements:list[str]):
await socketio.emit("update",elements)
def refactor_name(name:str):
name = name.strip()
while " " in name:
name = name.replace(" "," ")
return name
class SysctlManager:
def __init__(self, ctl_table):
self.old_table = {}
self.new_table = {}
if os.path.isdir("/sys_host/"):
self.old_table = dict()
self.new_table = dict(ctl_table)
for name in ctl_table.keys():
self.old_table[name] = read_sysctl(name)
def write_table(self, table) -> bool:
for name, value in table.items():
if read_sysctl(name) != value:
write_sysctl(name, value)
def set(self):
self.write_table(self.new_table)
def reset(self):
self.write_table(self.old_table)
def read_sysctl(name:str):
with open(f"/sys_host/{name}", "rt") as f:
return "1" in f.read()
def write_sysctl(name:str, value:bool):
with open(f"/sys_host/{name}", "wt") as f:
f.write("1" if value else "0")
def list_files(mypath):
from os import listdir
from os.path import isfile, join
return [f for f in listdir(mypath) if isfile(join(mypath, f))]
def ip_parse(ip:str):
return str(ip_interface(ip).network)
def is_ip_parse(ip:str):
try:
ip_parse(ip)
return True
except Exception:
return False
def addr_parse(ip:str):
return str(ip_address(ip))
def ip_family(ip:str):
return "ip6" if ip_interface(ip).version == 6 else "ip"
def get_interfaces():
def _get_interfaces():
for int_name, interfs in psutil.net_if_addrs().items():
for interf in interfs:
if interf.family in [socket.AF_INET, socket.AF_INET6]:
yield {"name": int_name, "addr":interf.address}
return list(_get_interfaces())
def nftables_int_to_json(ip_int):
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
return {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}
def nftables_json_to_int(ip_json_int):
if isinstance(ip_json_int,str):
return str(ip_parse(ip_json_int))
else:
return f'{ip_json_int["prefix"]["addr"]}/{ip_json_int["prefix"]["len"]}'
class Singleton(object):
__instance = None
def __new__(class_, *args, **kwargs):
if not isinstance(class_.__instance, class_):
class_.__instance = object.__new__(class_, *args, **kwargs)
return class_.__instance
class NFTableManager(Singleton):
table_name = "firegex"
def __init__(self, init_cmd, reset_cmd):
self.__init_cmds = init_cmd
self.__reset_cmds = reset_cmd
self.nft = nftables.Nftables()
def raw_cmd(self, *cmds):
return self.nft.json_cmd({"nftables": list(cmds)})
def cmd(self, *cmds):
code, out, err = self.raw_cmd(*cmds)
if code == 0:
return out
else:
raise Exception(err)
def init(self):
self.reset()
self.raw_cmd({"add":{"table":{"name":self.table_name,"family":"inet"}}})
self.cmd(*self.__init_cmds)
def reset(self):
self.raw_cmd(*self.__reset_cmds)
def list_rules(self, tables = None, chains = None):
for filter in [ele["rule"] for ele in self.raw_list() if "rule" in ele ]:
if tables and filter["table"] not in tables:
continue
if chains and filter["chain"] not in chains:
continue
yield filter
def raw_list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"]
def _json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json"):
res = obj.model_dump(mode=mode, exclude_unset=not unset)
if convert_keys:
for from_k, to_k in convert_keys.items():
if from_k in res:
res[to_k] = res.pop(from_k)
if exclude:
for ele in exclude:
if ele in res:
del res[ele]
return res
def json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json") -> dict:
if isinstance(obj, list):
return [_json_like(ele, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode) for ele in obj]
return _json_like(obj, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode)
def register_event(sio_server: AsyncServer, event_name: str, model: BaseModel, response_model: BaseModel|None = None):
def decorator(func):
@sio_server.on(event_name) # Automatically registers the event
@wraps(func)
async def wrapper(sid, data):
try:
# Parse and validate incoming data
parsed_data = model.model_validate(data)
except ValidationError:
return json_like(StatusMessageModel(status=f"Invalid {event_name} request"))
# Call the original function with the parsed data
result = await func(sid, parsed_data)
# If a response model is provided, validate the output
if response_model:
try:
parsed_result = response_model.model_validate(result)
except ValidationError:
traceback.print_exc()
return json_like(StatusMessageModel(status=f"SERVER ERROR: Invalid {event_name} response"))
else:
parsed_result = result
# Emit the validated result
if parsed_result:
if isinstance(parsed_result, BaseModel):
return json_like(parsed_result)
return parsed_result
return wrapper
return decorator
def nicenessify(priority:int, pid:int|None=None):
try:
pid = os.getpid() if pid is None else pid
ps = psutil.Process(pid)
if os.name == 'posix':
ps.nice(priority)
except Exception as e:
print(f"Error setting priority: {e} {traceback.format_exc()}")
pass
import asyncio
from ipaddress import ip_address, ip_interface
import os
import socket
import psutil
import sys
import nftables
from socketio import AsyncServer
from fastapi import Path
from typing import Annotated
from functools import wraps
from pydantic import BaseModel, ValidationError
import traceback
from utils.models import StatusMessageModel
from typing import List
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
socketio:AsyncServer = None
sid_list:set = set()
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")
ON_DOCKER = "DOCKER" in sys.argv
DEBUG = "DEBUG" in sys.argv
NORELOAD = "NORELOAD" in sys.argv
FIREGEX_PORT = int(os.getenv("PORT","4444"))
FIREGEX_HOST = os.getenv("HOST","0.0.0.0")
FIREGEX_SOCKET_DIR = os.getenv("SOCKET_DIR", None)
FIREGEX_SOCKET = os.path.join(FIREGEX_SOCKET_DIR, "firegex.sock") if FIREGEX_SOCKET_DIR else None
JWT_ALGORITHM: str = "HS256"
API_VERSION = "{{VERSION_PLACEHOLDER}}" if "{" not in "{{VERSION_PLACEHOLDER}}" else "0.0.0"
PortType = Annotated[int, Path(gt=0, lt=65536)]
async def run_func(func, *args, **kwargs):
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
async def socketio_emit(elements:list[str]):
await socketio.emit("update",elements)
def refactor_name(name:str):
name = name.strip()
while " " in name:
name = name.replace(" "," ")
return name
class SysctlManager:
def __init__(self, ctl_table):
self.old_table = {}
self.new_table = {}
if os.path.isdir("/sys_host/"):
self.old_table = dict()
self.new_table = dict(ctl_table)
for name in ctl_table.keys():
self.old_table[name] = read_sysctl(name)
def write_table(self, table) -> bool:
for name, value in table.items():
if read_sysctl(name) != value:
write_sysctl(name, value)
def set(self):
self.write_table(self.new_table)
def reset(self):
self.write_table(self.old_table)
def read_sysctl(name:str):
with open(f"/sys_host/{name}", "rt") as f:
return "1" in f.read()
def write_sysctl(name:str, value:bool):
with open(f"/sys_host/{name}", "wt") as f:
f.write("1" if value else "0")
def list_files(mypath):
from os import listdir
from os.path import isfile, join
return [f for f in listdir(mypath) if isfile(join(mypath, f))]
def ip_parse(ip:str):
return str(ip_interface(ip).network)
def is_ip_parse(ip:str):
try:
ip_parse(ip)
return True
except Exception:
return False
def addr_parse(ip:str):
return str(ip_address(ip))
def ip_family(ip:str):
return "ip6" if ip_interface(ip).version == 6 else "ip"
def get_interfaces():
def _get_interfaces():
for int_name, interfs in psutil.net_if_addrs().items():
for interf in interfs:
if interf.family in [socket.AF_INET, socket.AF_INET6]:
yield {"name": int_name, "addr":interf.address}
return list(_get_interfaces())
def nftables_int_to_json(ip_int):
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
return {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}
def nftables_json_to_int(ip_json_int):
if isinstance(ip_json_int,str):
return str(ip_parse(ip_json_int))
else:
return f'{ip_json_int["prefix"]["addr"]}/{ip_json_int["prefix"]["len"]}'
class Singleton(object):
__instance = None
def __new__(class_, *args, **kwargs):
if not isinstance(class_.__instance, class_):
class_.__instance = object.__new__(class_, *args, **kwargs)
return class_.__instance
class NFTableManager(Singleton):
table_name = "firegex"
def __init__(self, init_cmd, reset_cmd):
self.__init_cmds = init_cmd
self.__reset_cmds = reset_cmd
self.nft = nftables.Nftables()
def raw_cmd(self, *cmds):
return self.nft.json_cmd({"nftables": list(cmds)})
def cmd(self, *cmds):
code, out, err = self.raw_cmd(*cmds)
if code == 0:
return out
else:
raise Exception(err)
def init(self):
self.reset()
self.raw_cmd({"add":{"table":{"name":self.table_name,"family":"inet"}}})
self.cmd(*self.__init_cmds)
def reset(self):
self.raw_cmd(*self.__reset_cmds)
def list_rules(self, tables = None, chains = None):
for filter in [ele["rule"] for ele in self.raw_list() if "rule" in ele ]:
if tables and filter["table"] not in tables:
continue
if chains and filter["chain"] not in chains:
continue
yield filter
def raw_list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"]
def _json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json"):
res = obj.model_dump(mode=mode, exclude_unset=not unset)
if convert_keys:
for from_k, to_k in convert_keys.items():
if from_k in res:
res[to_k] = res.pop(from_k)
if exclude:
for ele in exclude:
if ele in res:
del res[ele]
return res
def json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json") -> dict:
if isinstance(obj, list):
return [_json_like(ele, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode) for ele in obj]
return _json_like(obj, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode)
def register_event(sio_server: AsyncServer, event_name: str, model: BaseModel, response_model: BaseModel|None = None):
def decorator(func):
@sio_server.on(event_name) # Automatically registers the event
@wraps(func)
async def wrapper(sid, data):
try:
# Parse and validate incoming data
parsed_data = model.model_validate(data)
except ValidationError:
return json_like(StatusMessageModel(status=f"Invalid {event_name} request"))
# Call the original function with the parsed data
result = await func(sid, parsed_data)
# If a response model is provided, validate the output
if response_model:
try:
parsed_result = response_model.model_validate(result)
except ValidationError:
traceback.print_exc()
return json_like(StatusMessageModel(status=f"SERVER ERROR: Invalid {event_name} response"))
else:
parsed_result = result
# Emit the validated result
if parsed_result:
if isinstance(parsed_result, BaseModel):
return json_like(parsed_result)
return parsed_result
return wrapper
return decorator
def nicenessify(priority:int, pid:int|None=None):
try:
pid = os.getpid() if pid is None else pid
ps = psutil.Process(pid)
if os.name == 'posix':
ps.nice(priority)
except Exception as e:
print(f"Error setting priority: {e} {traceback.format_exc()}")
pass