asyncio simulator + fix close
@@ -1,14 +1,13 @@
 import socket
 import os
-import threading
-from watchdog.observers import Observer
-from watchdog.events import FileSystemEventHandler
 from firegex.nfproxy.internals import get_filter_names
 import traceback
 from multiprocessing import Process
 from firegex.nfproxy import ACCEPT, DROP, REJECT, UNSTABLE_MANGLE
 from rich.markup import escape
 from rich import print
+import asyncio
+from watchfiles import awatch, Change
 
 fake_ip_header = b"FAKE:IP:TCP:HEADERS:"
 fake_ip_header_len = len(fake_ip_header)
@@ -37,47 +36,49 @@ def load_level_str(level:str):
 def log_print(module:str, *args, level:str = LogLevels.INFO, **kwargs):
     return print(f"{load_level_str(level)}[deep_pink4 bold]\\[nfproxy][/][medium_orchid3 bold]\\[{escape(module)}][/]", *args, **kwargs)
 
-class ProxyFilterHandler(FileSystemEventHandler):
-    def __init__(self, reload_action):
-        super().__init__()
-        self.__reload_action = reload_action
-    def on_modified(self, event):
-        if self.__reload_action is not None:
-            self.__reload_action()
-        return super().on_modified(event)
+async def watch_filter_file(filter_file: str, reload_action):
+    abs_path = os.path.abspath(filter_file)
+    directory = os.path.dirname(abs_path)
+    # Immediately call the reload action on startup.
+    if reload_action is not None:
+        reload_action()
+    log_print("observer", f"Listening for changes on {escape(abs_path)}")
 
-    def on_deleted(self, event):
-        if self.__reload_action is not None:
-            self.__reload_action()
-        return super().on_deleted(event)
 
 
-def _forward_and_filter(filter_ctx:dict, source:socket.socket, destination:socket.socket, is_input:bool, is_ipv6:bool, is_tcp:bool, has_to_filter:bool = True):
-    """Forward data from source to destination."""
     try:
-        def forward(data:bytes):
+        # Monitor the directory; set recursive=False since we only care about the specific file.
+        async for changes in awatch(directory, recursive=False):
+            # Process events and filter for our file.
+            for change in changes:
+                event, path = change
+                if os.path.abspath(path) == abs_path:
+                    # Optionally, you can check the event type:
+                    if event in {Change.modified, Change.deleted}:
+                        if reload_action is not None:
+                            reload_action()
+    except asyncio.CancelledError:
+        log_print("observer", "Watcher cancelled, stopping.")
 
-            try:
-                destination.sendall(data)
-            except OSError:
-                return
-        def stop_filter_action(data:bytes):
-            nonlocal has_to_filter
-            has_to_filter = False
-            forward(data)
+async def forward_and_filter(filter_ctx: dict,
+                             reader: asyncio.StreamReader,
+                             writer: asyncio.StreamWriter,
+                             is_input: bool,
+                             is_ipv6: bool,
+                             is_tcp: bool,
+                             has_to_filter: bool = True):
+    """Asynchronously forward data from reader to writer applying filters."""
+    try:
         while True:
             try:
-                data = source.recv(4096)
-            except OSError:
-                return
+                data = await reader.read(4096)
+            except Exception:
+                break
             if not data:
                 break
 
             if has_to_filter:
                 filter_ctx["__firegex_packet_info"] = {
                     "data": data,
                     "l4_size": len(data),
-                    "raw_packet": fake_ip_header+data,
+                    "raw_packet": fake_ip_header + data,
                     "is_input": is_input,
                     "is_ipv6": is_ipv6,
                     "is_tcp": is_tcp
@@ -85,38 +86,45 @@ def _forward_and_filter(filter_ctx:dict, source:socket.socket, destination:socke
                 try:
                     exec("firegex.nfproxy.internals.handle_packet(globals())", filter_ctx, filter_ctx)
                 except Exception as e:
-                    log_print("packet-handling", f"Error while executing filter: {escape(str(e))}, forwarding normally from now", level=LogLevels.ERROR)
+                    log_print("packet-handling",
+                              f"Error while executing filter: {escape(str(e))}, forwarding normally from now",
+                              level=LogLevels.ERROR)
                     traceback.print_exc()
-                    stop_filter_action(data)
+                    # Stop filtering and forward the packet as is.
+                    has_to_filter = False
+                    writer.write(data)
+                    await writer.drain()
                     continue
                 finally:
-                    if "__firegex_packet_info" in filter_ctx.keys():
-                        del filter_ctx["__firegex_packet_info"]
+                    filter_ctx.pop("__firegex_packet_info", None)
 
-                result = filter_ctx.get("__firegex_pyfilter_result", None)
 
-                if result is not None:
-                    del filter_ctx["__firegex_pyfilter_result"]
 
+                result = filter_ctx.pop("__firegex_pyfilter_result", None)
                 if result is None or not isinstance(result, dict):
                     log_print("filter-parsing", "No result found", level=LogLevels.ERROR)
-                    stop_filter_action(data)
+                    has_to_filter = False
+                    writer.write(data)
+                    await writer.drain()
                     continue
-                action = result.get("action", None)
 
+                action = result.get("action")
                 if action is None or not isinstance(action, int):
                     log_print("filter-parsing", "No action found", level=LogLevels.ERROR)
-                    stop_filter_action(data)
+                    has_to_filter = False
+                    writer.write(data)
+                    await writer.drain()
                     continue
 
                 if action == ACCEPT.value:
-                    forward(data)
+                    writer.write(data)
+                    await writer.drain()
                     continue
 
-                filter_name = result.get("matched_by", None)
+                filter_name = result.get("matched_by")
                 if filter_name is None or not isinstance(filter_name, str):
                     log_print("filter-parsing", "No matched_by found", level=LogLevels.ERROR)
-                    stop_filter_action(data)
+                    has_to_filter = False
+                    writer.write(data)
+                    await writer.drain()
                     continue
 
                 if action == DROP.value:
@@ -125,63 +133,104 @@ def _forward_and_filter(filter_ctx:dict, source:socket.socket, destination:socke
 
                 if action == REJECT.value:
                     log_print("reject-action", f"Rejecting connection caused by {escape(filter_name)} pyfilter")
-                    source.close()
-                    destination.close()
+                    writer.close()
+                    await writer.wait_closed()
                     return
 
                 elif action == UNSTABLE_MANGLE.value:
-                    mangled_packet = result.get("mangled_packet", None)
+                    mangled_packet = result.get("mangled_packet")
                     if mangled_packet is None or not isinstance(mangled_packet, bytes):
                         log_print("filter-parsing", "No mangled_packet found", level=LogLevels.ERROR)
-                        stop_filter_action(data)
+                        has_to_filter = False
+                        writer.write(data)
+                        await writer.drain()
                         continue
                     log_print("mangle", f"Mangling packet caused by {escape(filter_name)} pyfilter")
-                    log_print("mangle", "In the real execution mangling is not so stable as the simulation does, l4_data can be different by data", level=LogLevels.WARNING)
-                    forward(mangled_packet[fake_ip_header_len:])
+                    log_print("mangle",
+                              "In the real execution mangling is not so stable as the simulation does, l4_data can be different by data",
+                              level=LogLevels.WARNING)
+                    writer.write(mangled_packet[fake_ip_header_len:])
+                    await writer.drain()
                     continue
                 else:
                     log_print("filter-parsing", f"Invalid action {action} found", level=LogLevels.ERROR)
-                    stop_filter_action(data)
+                    has_to_filter = False
+                    writer.write(data)
+                    await writer.drain()
                     continue
-            forward(data)
+            else:
+                writer.write(data)
+                await writer.drain()
+    except Exception as exc:
+        log_print("forward_and_filter", f"Exception occurred: {escape(str(exc))}", level=LogLevels.ERROR)
     finally:
-        source.close()
-        destination.close()
+        writer.close()
 
-def _execute_proxy(filter_code:str, target_ip:str, target_port:int, local_ip:str = "127.0.0.1", local_port:int = 7474, ipv6:bool = False):
 
-    addr_family = socket.AF_INET6 if ipv6 else socket.AF_INET
-    server = socket.socket(addr_family, socket.SOCK_STREAM)
-    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    server.bind((local_ip, local_port))
-    server.listen(5)
 
-    log_print("listener", f"TCP proxy listening on {escape(local_ip)}:{local_port} and forwarding to -> {escape(target_ip)}:{target_port}")
-    try:
-        while True:
-            client_socket, addr = server.accept()
+        try:
+            await writer.wait_closed()
+        except Exception:
+            pass
 
+async def handle_connection(
+        reader: asyncio.StreamReader, writer: asyncio.StreamWriter, filter_code: str,
+        target_ip: str, target_port: int, ipv6: bool):
+    """Handle a new incoming connection and create a remote connection."""
+    addr = writer.get_extra_info('peername')
     log_print("listener", f"Accepted connection from {escape(addr[0])}:{addr[1]}")
     try:
-        remote_socket = socket.socket(addr_family, socket.SOCK_STREAM)
-        remote_socket.connect((target_ip, target_port))
+        remote_reader, remote_writer = await asyncio.open_connection(
+            target_ip, target_port,
+            family=socket.AF_INET6 if ipv6 else socket.AF_INET)
     except Exception as e:
-        log_print("listener", f"Could not connect to remote {escape(target_ip)}:{target_port}: {escape(str(e))}", level=LogLevels.ERROR)
-        client_socket.close()
-        continue
+        log_print("listener",
+                  f"Could not connect to remote {escape(target_ip)}:{target_port}: {escape(str(e))}",
+                  level=LogLevels.ERROR)
+        writer.close()
+        await writer.wait_closed()
+        return
 
     try:
         filter_ctx = {}
         exec(filter_code, filter_ctx, filter_ctx)
-        # Start two threads to forward data in both directions.
-        threading.Thread(target=_forward_and_filter, args=(filter_ctx, client_socket, remote_socket, True, ipv6, True, True)).start()
-        threading.Thread(target=_forward_and_filter, args=(filter_ctx, remote_socket, client_socket, False, ipv6, True, True)).start()
     except Exception as e:
-        log_print("listener", f"Error while compiling filter context: {escape(str(e))}, forwarding normally", level=LogLevels.ERROR)
+        log_print("listener",
+                  f"Error while compiling filter context: {escape(str(e))}, forwarding normally",
+                  level=LogLevels.ERROR)
         traceback.print_exc()
-        threading.Thread(target=_forward_and_filter, args=(filter_ctx, client_socket, remote_socket, True, ipv6, True, False)).start()
-        threading.Thread(target=_forward_and_filter, args=(filter_ctx, remote_socket, client_socket, False, ipv6, True, False)).start()
-    except KeyboardInterrupt:
-        log_print("listener", "Proxy stopped by user")
+        filter_ctx = {}
+    # Create asynchronous tasks for bidirectional forwarding.
+    task1 = asyncio.create_task(forward_and_filter(filter_ctx, reader, remote_writer, True, ipv6, True, True))
+    task2 = asyncio.create_task(forward_and_filter(filter_ctx, remote_reader, writer, False, ipv6, True, True))
+    try:
+        await asyncio.gather(task1, task2)
+    except (KeyboardInterrupt, asyncio.CancelledError):
+        task1.cancel()
+        task2.cancel()
+        await asyncio.gather(task1, task2)
     finally:
-        server.close()
+        remote_writer.close()
+        await remote_writer.wait_closed()
 
+async def _execute_proxy(
+    filter_code: str,
+    target_ip: str, target_port: int,
+    local_ip: str = "127.0.0.1", local_port: int = 7474,
+    ipv6: bool = False
+):
+    """Start the asyncio-based TCP proxy server."""
+    addr_family = socket.AF_INET6 if ipv6 else socket.AF_INET
+    server = await asyncio.start_server(
+        lambda r, w: handle_connection(r, w, filter_code, target_ip, target_port, ipv6),
+        local_ip, local_port, family=addr_family)
+    log_print("listener", f"TCP proxy listening on {escape(local_ip)}:{local_port} and forwarding to -> {escape(target_ip)}:{target_port}")
+    async with server:
+        await server.serve_forever()
 
 
+def _proxy_asyncio_runner(filter_code: str, target_ip: str, target_port: int, local_ip: str, local_port: int, ipv6: bool):
+    try:
+        return asyncio.run(_execute_proxy(filter_code, target_ip, target_port, local_ip, local_port, ipv6))
+    except KeyboardInterrupt:
+        log_print("listener", "Proxy server stopped", level=LogLevels.WARNING)
 
 def _build_filter(filepath:str, proto:str):
     if os.path.isfile(filepath) is False:
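For context, the packet loop above only ever reads three keys out of the dict that firegex.nfproxy.internals.handle_packet leaves in filter_ctx["__firegex_pyfilter_result"]. A hypothetical example of that shape, with the key names taken from this file and every value purely illustrative:

from firegex.nfproxy import ACCEPT, DROP, REJECT, UNSTABLE_MANGLE

fake_ip_header = b"FAKE:IP:TCP:HEADERS:"

# Hypothetical example of the result dict consumed by the loop above.
example_result = {
    "action": UNSTABLE_MANGLE.value,                     # int: one of ACCEPT/DROP/REJECT/UNSTABLE_MANGLE .value
    "matched_by": "example_pyfilter",                    # str: name of the matching pyfilter (made up here)
    "mangled_packet": fake_ip_header + b"new payload",   # bytes: only consulted when action is UNSTABLE_MANGLE
}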
@@ -219,7 +268,7 @@ def run_proxy_simulation(filter_file:str, proto:str, target_ip:str, target_port:
     def reload_proxy_proc():
         nonlocal proxy_process
         if proxy_process is not None:
-            proxy_process.terminate()
+            proxy_process.kill()
             proxy_process.join()
             proxy_process = None
 
@@ -230,19 +279,16 @@ def run_proxy_simulation(filter_file:str, proto:str, target_ip:str, target_port:
             log_print("reloader", f"Failed to build filter {escape(filter_file)}!", level=LogLevels.ERROR)
             traceback.print_exc()
         if compiled_filter is not None:
-            proxy_process = Process(target=_execute_proxy, args=(compiled_filter, target_ip, target_port, local_ip, local_port, ipv6))
+            proxy_process = Process(target=_proxy_asyncio_runner, args=(compiled_filter, target_ip, target_port, local_ip, local_port, ipv6))
             proxy_process.start()
 
 
-    observer = Observer()
-    handler = ProxyFilterHandler(reload_proxy_proc)
-    observer.schedule(handler, os.path.abspath(filter_file), recursive=False)
-    observer.start()
-    reload_proxy_proc()
-    log_print("observer", f"Listening for changes on {escape(os.path.abspath(filter_file))}")
     try:
-        observer.join()
+        asyncio.run(watch_filter_file(filter_file, reload_proxy_proc))
     except KeyboardInterrupt:
-        observer.stop()
+        pass
+    finally:
+        if proxy_process is not None:
+            proxy_process.kill()
+            proxy_process.join()
 
 
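The new handle_connection/_execute_proxy pair follows the usual asyncio pattern: accept a client with asyncio.start_server, dial the upstream with asyncio.open_connection, and pump bytes in both directions with two tasks gathered together. Stripped of the filtering logic, a minimal standalone sketch of that pattern looks like this (the host/port values and the pump helper are illustrative, not part of the commit):

import asyncio

LISTEN_HOST, LISTEN_PORT = "127.0.0.1", 7474   # illustrative
TARGET_HOST, TARGET_PORT = "127.0.0.1", 8080   # illustrative upstream


async def pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    # Copy bytes from one side to the other until EOF, then close the writer.
    try:
        while True:
            data = await reader.read(4096)
            if not data:
                break
            writer.write(data)
            await writer.drain()
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            pass


async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    # Dial the upstream target for every accepted client.
    try:
        remote_reader, remote_writer = await asyncio.open_connection(TARGET_HOST, TARGET_PORT)
    except OSError:
        writer.close()
        await writer.wait_closed()
        return
    # Two tasks, one per direction, awaited together.
    await asyncio.gather(
        pump(reader, remote_writer),   # client -> upstream
        pump(remote_reader, writer),   # upstream -> client
    )


async def main():
    server = await asyncio.start_server(handle, LISTEN_HOST, LISTEN_PORT)
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    asyncio.run(main())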
@@ -1,6 +1,6 @@
 typer==0.15.1
 pydantic>=2
 typing-extensions>=4.7.1
-watchdog>=6.0.0
+watchfiles
 fgex
 pyllhttp
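The dependency swap above (watchdog replaced by watchfiles) is what makes the awatch-based watch_filter_file possible. As a rough, self-contained illustration of how awatch drives the proxy-process reload in run_proxy_simulation, with the real filter building stubbed out and the worker body made up:

import asyncio
import os
import time
from multiprocessing import Process
from watchfiles import awatch, Change


def run_worker():
    # Stand-in for the proxy process that run_proxy_simulation() starts.
    while True:
        time.sleep(1)


async def watch_and_reload(path: str):
    abs_path = os.path.abspath(path)
    proc = None

    def reload_proc():
        # Kill the old process (if any) and start a fresh one, as the commit does.
        nonlocal proc
        if proc is not None:
            proc.kill()
            proc.join()
        proc = Process(target=run_worker)
        proc.start()

    reload_proc()  # start once at startup
    try:
        async for changes in awatch(os.path.dirname(abs_path), recursive=False):
            if any(os.path.abspath(p) == abs_path and c in {Change.modified, Change.deleted}
                   for c, p in changes):
                reload_proc()
    finally:
        if proc is not None:
            proc.kill()
            proc.join()


if __name__ == "__main__":
    asyncio.run(watch_and_reload("pyfilter.py"))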