From 47496287d5582f6d4560e0a46e27cbe79aa6e01d Mon Sep 17 00:00:00 2001 From: Domingo Dirutigliano Date: Tue, 4 Mar 2025 16:19:30 +0100 Subject: [PATCH] asyncio simulator + fix close --- fgex-lib/firegex/nfproxy/proxysim/__init__.py | 266 ++++++++++-------- fgex-lib/requirements.txt | 2 +- 2 files changed, 157 insertions(+), 111 deletions(-) diff --git a/fgex-lib/firegex/nfproxy/proxysim/__init__.py b/fgex-lib/firegex/nfproxy/proxysim/__init__.py index fb7b6fc..31d1d57 100644 --- a/fgex-lib/firegex/nfproxy/proxysim/__init__.py +++ b/fgex-lib/firegex/nfproxy/proxysim/__init__.py @@ -1,14 +1,13 @@ import socket import os -import threading -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler from firegex.nfproxy.internals import get_filter_names import traceback from multiprocessing import Process from firegex.nfproxy import ACCEPT, DROP, REJECT, UNSTABLE_MANGLE from rich.markup import escape from rich import print +import asyncio +from watchfiles import awatch, Change fake_ip_header = b"FAKE:IP:TCP:HEADERS:" fake_ip_header_len = len(fake_ip_header) @@ -37,47 +36,49 @@ def load_level_str(level:str): def log_print(module:str, *args, level:str = LogLevels.INFO, **kwargs): return print(f"{load_level_str(level)}[deep_pink4 bold]\\[nfproxy][/][medium_orchid3 bold]\\[{escape(module)}][/]", *args, **kwargs) -class ProxyFilterHandler(FileSystemEventHandler): - - def __init__(self, reload_action): - super().__init__() - self.__reload_action = reload_action - - def on_modified(self, event): - if self.__reload_action is not None: - self.__reload_action() - return super().on_modified(event) - - def on_deleted(self, event): - if self.__reload_action is not None: - self.__reload_action() - return super().on_deleted(event) - - -def _forward_and_filter(filter_ctx:dict, source:socket.socket, destination:socket.socket, is_input:bool, is_ipv6:bool, is_tcp:bool, has_to_filter:bool = True): - """Forward data from source to destination.""" +async def watch_filter_file(filter_file: str, reload_action): + abs_path = os.path.abspath(filter_file) + directory = os.path.dirname(abs_path) + # Immediately call the reload action on startup. + if reload_action is not None: + reload_action() + log_print("observer", f"Listening for changes on {escape(abs_path)}") + try: + # Monitor the directory; set recursive=False since we only care about the specific file. + async for changes in awatch(directory, recursive=False): + # Process events and filter for our file. + for change in changes: + event, path = change + if os.path.abspath(path) == abs_path: + # Optionally, you can check the event type: + if event in {Change.modified, Change.deleted}: + if reload_action is not None: + reload_action() + except asyncio.CancelledError: + log_print("observer", "Watcher cancelled, stopping.") + +async def forward_and_filter(filter_ctx: dict, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + is_input: bool, + is_ipv6: bool, + is_tcp: bool, + has_to_filter: bool = True): + """Asynchronously forward data from reader to writer applying filters.""" try: - def forward(data:bytes): - try: - destination.sendall(data) - except OSError: - return - def stop_filter_action(data:bytes): - nonlocal has_to_filter - has_to_filter = False - forward(data) while True: try: - data = source.recv(4096) - except OSError: - return + data = await reader.read(4096) + except Exception: + break if not data: break + if has_to_filter: filter_ctx["__firegex_packet_info"] = { "data": data, "l4_size": len(data), - "raw_packet": fake_ip_header+data, + "raw_packet": fake_ip_header + data, "is_input": is_input, "is_ipv6": is_ipv6, "is_tcp": is_tcp @@ -85,103 +86,151 @@ def _forward_and_filter(filter_ctx:dict, source:socket.socket, destination:socke try: exec("firegex.nfproxy.internals.handle_packet(globals())", filter_ctx, filter_ctx) except Exception as e: - log_print("packet-handling", f"Error while executing filter: {escape(str(e))}, forwarding normally from now", level=LogLevels.ERROR) + log_print("packet-handling", + f"Error while executing filter: {escape(str(e))}, forwarding normally from now", + level=LogLevels.ERROR) traceback.print_exc() - stop_filter_action(data) + # Stop filtering and forward the packet as is. + has_to_filter = False + writer.write(data) + await writer.drain() continue finally: - if "__firegex_packet_info" in filter_ctx.keys(): - del filter_ctx["__firegex_packet_info"] - - result = filter_ctx.get("__firegex_pyfilter_result", None) - - if result is not None: - del filter_ctx["__firegex_pyfilter_result"] - + filter_ctx.pop("__firegex_packet_info", None) + + result = filter_ctx.pop("__firegex_pyfilter_result", None) if result is None or not isinstance(result, dict): log_print("filter-parsing", "No result found", level=LogLevels.ERROR) - stop_filter_action(data) + has_to_filter = False + writer.write(data) + await writer.drain() continue - action = result.get("action", None) - + + action = result.get("action") if action is None or not isinstance(action, int): log_print("filter-parsing", "No action found", level=LogLevels.ERROR) - stop_filter_action(data) + has_to_filter = False + writer.write(data) + await writer.drain() continue - + if action == ACCEPT.value: - forward(data) + writer.write(data) + await writer.drain() continue - - filter_name = result.get("matched_by", None) + + filter_name = result.get("matched_by") if filter_name is None or not isinstance(filter_name, str): log_print("filter-parsing", "No matched_by found", level=LogLevels.ERROR) - stop_filter_action(data) + has_to_filter = False + writer.write(data) + await writer.drain() continue - + if action == DROP.value: log_print("drop-action", "Dropping packet can't be simulated, so the connection will be rejected", level=LogLevels.WARNING) action = REJECT.value - + if action == REJECT.value: log_print("reject-action", f"Rejecting connection caused by {escape(filter_name)} pyfilter") - source.close() - destination.close() + writer.close() + await writer.wait_closed() return + elif action == UNSTABLE_MANGLE.value: - mangled_packet = result.get("mangled_packet", None) + mangled_packet = result.get("mangled_packet") if mangled_packet is None or not isinstance(mangled_packet, bytes): log_print("filter-parsing", "No mangled_packet found", level=LogLevels.ERROR) - stop_filter_action(data) + has_to_filter = False + writer.write(data) + await writer.drain() continue log_print("mangle", f"Mangling packet caused by {escape(filter_name)} pyfilter") - log_print("mangle", "In the real execution mangling is not so stable as the simulation does, l4_data can be different by data", level=LogLevels.WARNING) - forward(mangled_packet[fake_ip_header_len:]) + log_print("mangle", + "In the real execution mangling is not so stable as the simulation does, l4_data can be different by data", + level=LogLevels.WARNING) + writer.write(mangled_packet[fake_ip_header_len:]) + await writer.drain() continue else: log_print("filter-parsing", f"Invalid action {action} found", level=LogLevels.ERROR) - stop_filter_action(data) + has_to_filter = False + writer.write(data) + await writer.drain() continue - forward(data) + else: + writer.write(data) + await writer.drain() + except Exception as exc: + log_print("forward_and_filter", f"Exception occurred: {escape(str(exc))}", level=LogLevels.ERROR) finally: - source.close() - destination.close() + writer.close() + try: + await writer.wait_closed() + except Exception: + pass -def _execute_proxy(filter_code:str, target_ip:str, target_port:int, local_ip:str = "127.0.0.1", local_port:int = 7474, ipv6:bool = False): - - addr_family = socket.AF_INET6 if ipv6 else socket.AF_INET - server = socket.socket(addr_family, socket.SOCK_STREAM) - server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server.bind((local_ip, local_port)) - server.listen(5) - - log_print("listener", f"TCP proxy listening on {escape(local_ip)}:{local_port} and forwarding to -> {escape(target_ip)}:{target_port}") +async def handle_connection( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter, filter_code: str, + target_ip: str, target_port: int, ipv6: bool): + """Handle a new incoming connection and create a remote connection.""" + addr = writer.get_extra_info('peername') + log_print("listener", f"Accepted connection from {escape(addr[0])}:{addr[1]}") try: - while True: - client_socket, addr = server.accept() - log_print("listener", f"Accepted connection from {escape(addr[0])}:{addr[1]}") - try: - remote_socket = socket.socket(addr_family, socket.SOCK_STREAM) - remote_socket.connect((target_ip, target_port)) - except Exception as e: - log_print("listener", f"Could not connect to remote {escape(target_ip)}:{target_port}: {escape(str(e))}", level=LogLevels.ERROR) - client_socket.close() - continue - try: - filter_ctx = {} - exec(filter_code, filter_ctx, filter_ctx) - # Start two threads to forward data in both directions. - threading.Thread(target=_forward_and_filter, args=(filter_ctx, client_socket, remote_socket, True, ipv6, True, True)).start() - threading.Thread(target=_forward_and_filter, args=(filter_ctx, remote_socket, client_socket, False, ipv6, True, True)).start() - except Exception as e: - log_print("listener", f"Error while compiling filter context: {escape(str(e))}, forwarding normally", level=LogLevels.ERROR) - traceback.print_exc() - threading.Thread(target=_forward_and_filter, args=(filter_ctx, client_socket, remote_socket, True, ipv6, True, False)).start() - threading.Thread(target=_forward_and_filter, args=(filter_ctx, remote_socket, client_socket, False, ipv6, True, False)).start() - except KeyboardInterrupt: - log_print("listener", "Proxy stopped by user") + remote_reader, remote_writer = await asyncio.open_connection( + target_ip, target_port, + family=socket.AF_INET6 if ipv6 else socket.AF_INET) + except Exception as e: + log_print("listener", + f"Could not connect to remote {escape(target_ip)}:{target_port}: {escape(str(e))}", + level=LogLevels.ERROR) + writer.close() + await writer.wait_closed() + return + + try: + filter_ctx = {} + exec(filter_code, filter_ctx, filter_ctx) + except Exception as e: + log_print("listener", + f"Error while compiling filter context: {escape(str(e))}, forwarding normally", + level=LogLevels.ERROR) + traceback.print_exc() + filter_ctx = {} + # Create asynchronous tasks for bidirectional forwarding. + task1 = asyncio.create_task(forward_and_filter(filter_ctx, reader, remote_writer, True, ipv6, True, True)) + task2 = asyncio.create_task(forward_and_filter(filter_ctx, remote_reader, writer, False, ipv6, True, True)) + try: + await asyncio.gather(task1, task2) + except (KeyboardInterrupt, asyncio.CancelledError): + task1.cancel() + task2.cancel() + await asyncio.gather(task1, task2) finally: - server.close() + remote_writer.close() + await remote_writer.wait_closed() + +async def _execute_proxy( + filter_code: str, + target_ip: str, target_port: int, + local_ip: str = "127.0.0.1", local_port: int = 7474, + ipv6: bool = False +): + """Start the asyncio-based TCP proxy server.""" + addr_family = socket.AF_INET6 if ipv6 else socket.AF_INET + server = await asyncio.start_server( + lambda r, w: handle_connection(r, w, filter_code, target_ip, target_port, ipv6), + local_ip, local_port, family=addr_family) + log_print("listener", f"TCP proxy listening on {escape(local_ip)}:{local_port} and forwarding to -> {escape(target_ip)}:{target_port}") + async with server: + await server.serve_forever() + + +def _proxy_asyncio_runner(filter_code: str, target_ip: str, target_port: int, local_ip: str, local_port: int, ipv6: bool): + try: + return asyncio.run(_execute_proxy(filter_code, target_ip, target_port, local_ip, local_port, ipv6)) + except KeyboardInterrupt: + log_print("listener", "Proxy server stopped", level=LogLevels.WARNING) def _build_filter(filepath:str, proto:str): if os.path.isfile(filepath) is False: @@ -219,7 +268,7 @@ def run_proxy_simulation(filter_file:str, proto:str, target_ip:str, target_port: def reload_proxy_proc(): nonlocal proxy_process if proxy_process is not None: - proxy_process.terminate() + proxy_process.kill() proxy_process.join() proxy_process = None @@ -230,19 +279,16 @@ def run_proxy_simulation(filter_file:str, proto:str, target_ip:str, target_port: log_print("reloader", f"Failed to build filter {escape(filter_file)}!", level=LogLevels.ERROR) traceback.print_exc() if compiled_filter is not None: - proxy_process = Process(target=_execute_proxy, args=(compiled_filter, target_ip, target_port, local_ip, local_port, ipv6)) + proxy_process = Process(target=_proxy_asyncio_runner, args=(compiled_filter, target_ip, target_port, local_ip, local_port, ipv6)) proxy_process.start() - - observer = Observer() - handler = ProxyFilterHandler(reload_proxy_proc) - observer.schedule(handler, os.path.abspath(filter_file), recursive=False) - observer.start() - reload_proxy_proc() - log_print("observer", f"Listening for changes on {escape(os.path.abspath(filter_file))}") try: - observer.join() + asyncio.run(watch_filter_file(filter_file, reload_proxy_proc)) except KeyboardInterrupt: - observer.stop() + pass + finally: + if proxy_process is not None: + proxy_process.kill() + proxy_process.join() diff --git a/fgex-lib/requirements.txt b/fgex-lib/requirements.txt index ae67857..17654e5 100644 --- a/fgex-lib/requirements.txt +++ b/fgex-lib/requirements.txt @@ -1,6 +1,6 @@ typer==0.15.1 pydantic>=2 typing-extensions>=4.7.1 -watchdog>=6.0.0 +watchfiles fgex pyllhttp