Fix crashes and unexpected behaviours

This commit is contained in:
Domingo Dirutigliano
2025-03-03 23:55:24 +01:00
parent 832c6e1530
commit fde3ee57a5
7 changed files with 47 additions and 67 deletions

View File

@@ -73,7 +73,6 @@ struct pyfilter_ctx {
} }
~pyfilter_ctx(){ ~pyfilter_ctx(){
cerr << "[info] [pyfilter_ctx] Cleaning pyfilter_ctx" << endl;
Py_DECREF(glob); Py_DECREF(glob);
Py_DECREF(py_handle_packet); Py_DECREF(py_handle_packet);
PyGC_Collect(); PyGC_Collect();
@@ -120,14 +119,8 @@ struct pyfilter_ctx {
// Set packet info to the global context // Set packet info to the global context
set_item_to_glob("__firegex_packet_info", packet_info); set_item_to_glob("__firegex_packet_info", packet_info);
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Calling python with a data of " << data.size() << endl;
#endif
PyObject * result = PyEval_EvalCode(py_handle_packet, glob, glob); PyObject * result = PyEval_EvalCode(py_handle_packet, glob, glob);
PyGC_Collect(); PyGC_Collect();
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] End of python call" << endl;
#endif
del_item_from_glob("__firegex_packet_info"); del_item_from_glob("__firegex_packet_info");
if (PyErr_Occurred()){ if (PyErr_Occurred()){

View File

@@ -6,6 +6,7 @@ import traceback
from fastapi import HTTPException from fastapi import HTTPException
import time import time
from utils import run_func from utils import run_func
from utils import DEBUG
nft = FiregexTables() nft = FiregexTables()
@@ -62,20 +63,25 @@ class FiregexInterceptor:
async def _stream_handler(self): async def _stream_handler(self):
while True: while True:
try: try:
line = (await self.process.stdout.readuntil()).decode(errors="ignore") out_data = (await self.process.stdout.read(1024*10)).decode(errors="ignore")
print(line, end="") if DEBUG:
print(out_data, end="")
except asyncio.exceptions.LimitOverrunError:
self.outstrem_buffer = ""
continue
except Exception as e: except Exception as e:
self.ack_arrived = False self.ack_arrived = False
self.ack_status = False self.ack_status = False
self.ack_fail_what = "Can't read from nfq client" self.ack_fail_what = "Can't read from nfq client"
self.ack_lock.release() self.ack_lock.release()
await self.stop() await self.stop()
traceback.print_exc() # Python can't print it alone? nope it's python... wasted 1 day :)
raise HTTPException(status_code=500, detail="Can't read from nfq client") from e raise HTTPException(status_code=500, detail="Can't read from nfq client") from e
self.outstrem_buffer+=line self.outstrem_buffer+=out_data
if len(self.outstrem_buffer) > OUTSTREAM_BUFFER_SIZE: if len(self.outstrem_buffer) > OUTSTREAM_BUFFER_SIZE:
self.outstrem_buffer = self.outstrem_buffer[-OUTSTREAM_BUFFER_SIZE:]+"\n" self.outstrem_buffer = self.outstrem_buffer[-OUTSTREAM_BUFFER_SIZE:]+"\n"
if self.outstrem_function: if self.outstrem_function:
await run_func(self.outstrem_function, self.srv.id, line) await run_func(self.outstrem_function, self.srv.id, out_data)
async def _start_binary(self): async def _start_binary(self):
proxy_binary_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../cpproxy")) proxy_binary_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../cpproxy"))

View File

@@ -68,12 +68,9 @@ def get_filter_names(code:str, proto:str) -> list[str]:
def handle_packet(glob: dict) -> None: def handle_packet(glob: dict) -> None:
internal_data = DataStreamCtx(glob) internal_data = DataStreamCtx(glob)
print("I'm here", flush=True)
cache_call = {} # Cache of the data handler calls
pkt_info = RawPacket._fetch_packet(internal_data) cache_call = {} # Cache of the data handler calls
internal_data.current_pkt = pkt_info cache_call[RawPacket] = internal_data.current_pkt
cache_call[RawPacket] = pkt_info
final_result = Action.ACCEPT final_result = Action.ACCEPT
result = PacketHandlerResult(glob) result = PacketHandlerResult(glob)
@@ -108,8 +105,10 @@ def handle_packet(glob: dict) -> None:
result.matched_by = filter.name result.matched_by = filter.name
return result.set_result() return result.set_result()
final_params.append(cache_call[data_type]) final_params.append(cache_call[data_type])
if skip_call: if skip_call:
continue continue
res = context_call(glob, filter.func, *final_params) res = context_call(glob, filter.func, *final_params)
if res is None: if res is None:
@@ -117,7 +116,7 @@ def handle_packet(glob: dict) -> None:
if not isinstance(res, Action): if not isinstance(res, Action):
raise Exception(f"Invalid return type {type(res)} for function {filter.name}") raise Exception(f"Invalid return type {type(res)} for function {filter.name}")
if res == Action.MANGLE: if res == Action.MANGLE:
mangled_packet = pkt_info.raw_packet mangled_packet = internal_data.current_pkt.raw_packet
if res != Action.ACCEPT: if res != Action.ACCEPT:
func_name = filter.name func_name = filter.name
final_result = res final_result = res
@@ -131,7 +130,7 @@ def handle_packet(glob: dict) -> None:
def compile(glob:dict) -> None: def compile(glob:dict) -> None:
internal_data = DataStreamCtx(glob) internal_data = DataStreamCtx(glob, init_pkt=False)
glob["print"] = functools.partial(print, flush = True) glob["print"] = functools.partial(print, flush = True)

View File

@@ -75,8 +75,7 @@ class RawPacket:
self.__l4_size = len(v)-self.raw_packet_header_len self.__l4_size = len(v)-self.raw_packet_header_len
@classmethod @classmethod
def _fetch_packet(cls, internal_data): def _fetch_packet(cls, internal_data:"DataStreamCtx"):
from firegex.nfproxy.internals.data import DataStreamCtx
if not isinstance(internal_data, DataStreamCtx): if not isinstance(internal_data, DataStreamCtx):
if isinstance(internal_data, dict): if isinstance(internal_data, dict):
internal_data = DataStreamCtx(internal_data) internal_data = DataStreamCtx(internal_data)
@@ -93,11 +92,12 @@ class RawPacket:
class DataStreamCtx: class DataStreamCtx:
def __init__(self, glob: dict): def __init__(self, glob: dict, init_pkt: bool = True):
if "__firegex_pyfilter_ctx" not in glob.keys(): if "__firegex_pyfilter_ctx" not in glob.keys():
glob["__firegex_pyfilter_ctx"] = {} glob["__firegex_pyfilter_ctx"] = {}
self.__data = glob["__firegex_pyfilter_ctx"] self.__data = glob["__firegex_pyfilter_ctx"]
self.filter_glob = glob self.filter_glob = glob
self.current_pkt = RawPacket._fetch_packet(self) if init_pkt else None
@property @property
def filter_call_info(self) -> list[FilterHandler]: def filter_call_info(self) -> list[FilterHandler]:
@@ -128,14 +128,6 @@ class DataStreamCtx:
@full_stream_action.setter @full_stream_action.setter
def full_stream_action(self, v: FullStreamAction): def full_stream_action(self, v: FullStreamAction):
self.__data["full_stream_action"] = v self.__data["full_stream_action"] = v
@property
def current_pkt(self) -> RawPacket:
return self.__data.get("current_pkt", None)
@current_pkt.setter
def current_pkt(self, v: RawPacket):
self.__data["current_pkt"] = v
@property @property
def data_handler_context(self) -> dict: def data_handler_context(self) -> dict:
@@ -146,16 +138,4 @@ class DataStreamCtx:
@data_handler_context.setter @data_handler_context.setter
def data_handler_context(self, v: dict): def data_handler_context(self, v: dict):
self.__data["data_handler_context"] = v self.__data["data_handler_context"] = v
@property
def save_http_data_in_streams(self) -> bool:
if "save_http_data_in_streams" not in self.__data.keys():
self.__data["save_http_data_in_streams"] = False
return self.__data.get("save_http_data_in_streams")
@save_http_data_in_streams.setter
def save_http_data_in_streams(self, v: bool):
self.__data["save_http_data_in_streams"] = v

View File

@@ -12,4 +12,4 @@ class RejectConnection(Exception):
"raise this exception if you want to reject the connection" "raise this exception if you want to reject the connection"
class StreamFullReject(Exception): class StreamFullReject(Exception):
"raise this exception if you want to reject the connection due to full stream" "raise this exception if you want to reject the connection due to full stream"

View File

@@ -54,8 +54,8 @@ class InternalCallbackHandler():
self.headers_complete = True self.headers_complete = True
self.headers = self._header_fields self.headers = self._header_fields
self._header_fields = {} self._header_fields = {}
self._current_header_field = None self._current_header_field = b""
self._current_header_value = None self._current_header_value = b""
def on_body(self, body: bytes): def on_body(self, body: bytes):
if self._save_body: if self._save_body:
@@ -98,14 +98,14 @@ class InternalCallbackHandler():
class InternalHttpRequest(InternalCallbackHandler, pyllhttp.Request): class InternalHttpRequest(InternalCallbackHandler, pyllhttp.Request):
def __init__(self): def __init__(self):
super(pyllhttp.Request, self).__init__()
super(InternalCallbackHandler, self).__init__() super(InternalCallbackHandler, self).__init__()
super(pyllhttp.Request, self).__init__()
class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response): class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response):
def __init__(self): def __init__(self):
super(pyllhttp.Response, self).__init__()
super(InternalCallbackHandler, self).__init__() super(InternalCallbackHandler, self).__init__()
super(pyllhttp.Response, self).__init__()
class InternalBasicHttpMetaClass: class InternalBasicHttpMetaClass:
def __init__(self): def __init__(self):
@@ -162,9 +162,12 @@ class InternalBasicHttpMetaClass:
def method(self) -> str|None: def method(self) -> str|None:
return self._parser.method_parsed return self._parser.method_parsed
def _packet_to_stream(self, internal_data: DataStreamCtx):
return self.should_upgrade and self._parser._save_body
def _fetch_current_packet(self, internal_data: DataStreamCtx): def _fetch_current_packet(self, internal_data: DataStreamCtx):
# TODO: if an error is triggered should I reject the connection? if self._packet_to_stream(internal_data): # This is a websocket upgrade!
if internal_data.save_http_data_in_streams: # This is a websocket upgrade! self._parser.total_size += len(internal_data.current_pkt.data)
self.stream += internal_data.current_pkt.data self.stream += internal_data.current_pkt.data
else: else:
try: try:
@@ -173,20 +176,21 @@ class InternalBasicHttpMetaClass:
self._parser.on_message_complete() self._parser.on_message_complete()
except Exception as e: except Exception as e:
self.raised_error = True self.raised_error = True
print(f"Error parsing HTTP packet: {e} {internal_data.current_pkt}", self, flush=True)
raise e raise e
#It's called the first time if the headers are complete, and second time with body complete #It's called the first time if the headers are complete, and second time with body complete
def _callable_checks(self, internal_data: DataStreamCtx): def _after_fetch_callable_checks(self, internal_data: DataStreamCtx):
if self._parser.headers_complete and not self._headers_were_set: if self._parser.headers_complete and not self._headers_were_set:
self._headers_were_set = True self._headers_were_set = True
return True return True
return self._parser.message_complete or internal_data.save_http_data_in_streams return self._parser.message_complete or self.should_upgrade
def _before_fetch_callable_checks(self, internal_data: DataStreamCtx): def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
return True return True
def _trigger_remove_data(self, internal_data: DataStreamCtx): def _trigger_remove_data(self, internal_data: DataStreamCtx):
return self.message_complete return self.message_complete and not self.should_upgrade
@classmethod @classmethod
def _fetch_packet(cls, internal_data: DataStreamCtx): def _fetch_packet(cls, internal_data: DataStreamCtx):
@@ -216,12 +220,9 @@ class InternalBasicHttpMetaClass:
datahandler._fetch_current_packet(internal_data) datahandler._fetch_current_packet(internal_data)
if not datahandler._callable_checks(internal_data): if not datahandler._after_fetch_callable_checks(internal_data):
raise NotReadyToRun() raise NotReadyToRun()
if datahandler.should_upgrade:
internal_data.save_http_data_in_streams = True
if datahandler._trigger_remove_data(internal_data): if datahandler._trigger_remove_data(internal_data):
if internal_data.data_handler_context.get(cls): if internal_data.data_handler_context.get(cls):
del internal_data.data_handler_context[cls] del internal_data.data_handler_context[cls]
@@ -266,7 +267,10 @@ class HttpRequestHeader(HttpRequest):
super().__init__() super().__init__()
self._parser._save_body = False self._parser._save_body = False
def _callable_checks(self, internal_data: DataStreamCtx): def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
return internal_data.current_pkt.is_input and not self._headers_were_set
def _after_fetch_callable_checks(self, internal_data: DataStreamCtx):
if self._parser.headers_complete and not self._headers_were_set: if self._parser.headers_complete and not self._headers_were_set:
self._headers_were_set = True self._headers_were_set = True
return True return True
@@ -277,9 +281,11 @@ class HttpResponseHeader(HttpResponse):
super().__init__() super().__init__()
self._parser._save_body = False self._parser._save_body = False
def _callable_checks(self, internal_data: DataStreamCtx): def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
return not internal_data.current_pkt.is_input and not self._headers_were_set
def _after_fetch_callable_checks(self, internal_data: DataStreamCtx):
if self._parser.headers_complete and not self._headers_were_set: if self._parser.headers_complete and not self._headers_were_set:
self._headers_were_set = True self._headers_were_set = True
return True return True
return False return False

View File

@@ -7,13 +7,9 @@ class InternalTCPStream:
data: bytes, data: bytes,
is_ipv6: bool, is_ipv6: bool,
): ):
self.__data = bytes(data) self.data = bytes(data)
self.__is_ipv6 = bool(is_ipv6) self.__is_ipv6 = bool(is_ipv6)
self.__total_stream_size = len(data) self.__total_stream_size = len(data)
@property
def data(self) -> bool:
return self.__data
@property @property
def is_ipv6(self) -> bool: def is_ipv6(self) -> bool:
@@ -24,7 +20,7 @@ class InternalTCPStream:
return self.__total_stream_size return self.__total_stream_size
def _push_new_data(self, data: bytes): def _push_new_data(self, data: bytes):
self.__data += data self.data += data
self.__total_stream_size += len(data) self.__total_stream_size += len(data)
@classmethod @classmethod