changed datahandler max size management

This commit is contained in:
Domingo Dirutigliano
2025-03-03 21:15:49 +01:00
parent 072745cc06
commit 832c6e1530
7 changed files with 135 additions and 291 deletions

View File

@@ -1,5 +1,5 @@
import functools import functools
from firegex.nfproxy.models import RawPacket, TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream, TCPStreams from firegex.nfproxy.models import RawPacket, TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream
from firegex.nfproxy.internals.models import Action, FullStreamAction from firegex.nfproxy.internals.models import Action, FullStreamAction
ACCEPT = Action.ACCEPT ACCEPT = Action.ACCEPT
@@ -35,5 +35,5 @@ def clear_pyfilter_registry():
__all__ = [ __all__ = [
"ACCEPT", "DROP", "REJECT", "UNSTABLE_MANGLE" "ACCEPT", "DROP", "REJECT", "UNSTABLE_MANGLE"
"Action", "FullStreamAction", "pyfilter", "Action", "FullStreamAction", "pyfilter",
"RawPacket", "TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream", "TCPStreams" "RawPacket", "TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream"
] ]

View File

@@ -3,7 +3,7 @@ from firegex.nfproxy.internals.models import Action, FullStreamAction
from firegex.nfproxy.internals.models import FilterHandler, PacketHandlerResult from firegex.nfproxy.internals.models import FilterHandler, PacketHandlerResult
import functools import functools
from firegex.nfproxy.internals.data import DataStreamCtx from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import NotReadyToRun from firegex.nfproxy.internals.exceptions import NotReadyToRun, StreamFullReject, DropPacket, RejectConnection, StreamFullDrop
from firegex.nfproxy.internals.data import RawPacket from firegex.nfproxy.internals.data import RawPacket
def context_call(glob, func, *args, **kargs): def context_call(glob, func, *args, **kargs):
@@ -76,32 +76,8 @@ def handle_packet(glob: dict) -> None:
cache_call[RawPacket] = pkt_info cache_call[RawPacket] = pkt_info
final_result = Action.ACCEPT final_result = Action.ACCEPT
data_size = len(pkt_info.data)
result = PacketHandlerResult(glob) result = PacketHandlerResult(glob)
if internal_data.stream_size+data_size > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
internal_data.stream = []
internal_data.stream_size = 0
for func in internal_data.flush_action_set:
func()
case FullStreamAction.ACCEPT:
result.action = Action.ACCEPT
return result.set_result()
case FullStreamAction.REJECT:
result.action = Action.REJECT
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
case FullStreamAction.REJECT:
result.action = Action.DROP
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
internal_data.stream.append(pkt_info)
internal_data.stream_size += data_size
func_name = None func_name = None
mangled_packet = None mangled_packet = None
for filter in internal_data.filter_call_info: for filter in internal_data.filter_call_info:
@@ -115,6 +91,22 @@ def handle_packet(glob: dict) -> None:
cache_call[data_type] = None cache_call[data_type] = None
skip_call = True skip_call = True
break break
except StreamFullDrop:
result.action = Action.DROP
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
except StreamFullReject:
result.action = Action.REJECT
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
except DropPacket:
result.action = Action.DROP
result.matched_by = filter.name
return result.set_result()
except RejectConnection:
result.action = Action.REJECT
result.matched_by = filter.name
return result.set_result()
final_params.append(cache_call[data_type]) final_params.append(cache_call[data_type])
if skip_call: if skip_call:
continue continue

View File

@@ -1,5 +1,5 @@
from firegex.nfproxy.internals.models import FilterHandler from firegex.nfproxy.internals.models import FilterHandler
from typing import Callable from firegex.nfproxy.internals.models import FullStreamAction
class RawPacket: class RawPacket:
""" """
@@ -109,26 +109,6 @@ class DataStreamCtx:
def filter_call_info(self, v: list[FilterHandler]): def filter_call_info(self, v: list[FilterHandler]):
self.__data["filter_call_info"] = v self.__data["filter_call_info"] = v
@property
def stream(self) -> list[RawPacket]:
if "stream" not in self.__data.keys():
self.__data["stream"] = []
return self.__data.get("stream")
@stream.setter
def stream(self, v: list[RawPacket]):
self.__data["stream"] = v
@property
def stream_size(self) -> int:
if "stream_size" not in self.__data.keys():
self.__data["stream_size"] = 0
return self.__data.get("stream_size")
@stream_size.setter
def stream_size(self, v: int):
self.__data["stream_size"] = v
@property @property
def stream_max_size(self) -> int: def stream_max_size(self) -> int:
if "stream_max_size" not in self.__data.keys(): if "stream_max_size" not in self.__data.keys():
@@ -140,13 +120,13 @@ class DataStreamCtx:
self.__data["stream_max_size"] = v self.__data["stream_max_size"] = v
@property @property
def full_stream_action(self) -> str: def full_stream_action(self) -> FullStreamAction:
if "full_stream_action" not in self.__data.keys(): if "full_stream_action" not in self.__data.keys():
self.__data["full_stream_action"] = "flush" self.__data["full_stream_action"] = "flush"
return self.__data.get("full_stream_action") return self.__data.get("full_stream_action")
@full_stream_action.setter @full_stream_action.setter
def full_stream_action(self, v: str): def full_stream_action(self, v: FullStreamAction):
self.__data["full_stream_action"] = v self.__data["full_stream_action"] = v
@property @property
@@ -158,14 +138,14 @@ class DataStreamCtx:
self.__data["current_pkt"] = v self.__data["current_pkt"] = v
@property @property
def http_data_objects(self) -> dict: def data_handler_context(self) -> dict:
if "http_data_objects" not in self.__data.keys(): if "data_handler_context" not in self.__data.keys():
self.__data["http_data_objects"] = {} self.__data["data_handler_context"] = {}
return self.__data.get("http_data_objects") return self.__data.get("data_handler_context")
@http_data_objects.setter @data_handler_context.setter
def http_data_objects(self, v: dict): def data_handler_context(self, v: dict):
self.__data["http_data_objects"] = v self.__data["data_handler_context"] = v
@property @property
def save_http_data_in_streams(self) -> bool: def save_http_data_in_streams(self) -> bool:
@@ -177,14 +157,5 @@ class DataStreamCtx:
def save_http_data_in_streams(self, v: bool): def save_http_data_in_streams(self, v: bool):
self.__data["save_http_data_in_streams"] = v self.__data["save_http_data_in_streams"] = v
@property
def flush_action_set(self) -> set[Callable]:
if "flush_action_set" not in self.__data.keys():
self.__data["flush_action_set"] = set()
return self.__data.get("flush_action_set")
@flush_action_set.setter
def flush_action_set(self, v: set[Callable]):
self.__data["flush_action_set"] = v

View File

@@ -1,3 +1,15 @@
class NotReadyToRun(Exception): class NotReadyToRun(Exception):
"raise this exception if the stream state is not ready to parse this object, the call will be skipped" "raise this exception if the stream state is not ready to parse this object, the call will be skipped"
class DropPacket(Exception):
"raise this exception if you want to drop the packet"
class StreamFullDrop(Exception):
"raise this exception if you want to drop the packet due to full stream"
class RejectConnection(Exception):
"raise this exception if you want to reject the connection"
class StreamFullReject(Exception):
"raise this exception if you want to reject the connection due to full stream"

View File

@@ -1,4 +1,4 @@
from firegex.nfproxy.models.tcp import TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream, TCPStreams from firegex.nfproxy.models.tcp import TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream
from firegex.nfproxy.models.http import HttpRequest, HttpResponse, HttpRequestHeader, HttpResponseHeader from firegex.nfproxy.models.http import HttpRequest, HttpResponse, HttpRequestHeader, HttpResponseHeader
from firegex.nfproxy.internals.data import RawPacket from firegex.nfproxy.internals.data import RawPacket
@@ -7,13 +7,11 @@ type_annotations_associations = {
RawPacket: RawPacket._fetch_packet, RawPacket: RawPacket._fetch_packet,
TCPInputStream: TCPInputStream._fetch_packet, TCPInputStream: TCPInputStream._fetch_packet,
TCPOutputStream: TCPOutputStream._fetch_packet, TCPOutputStream: TCPOutputStream._fetch_packet,
TCPStreams: TCPStreams._fetch_packet,
}, },
"http": { "http": {
RawPacket: RawPacket._fetch_packet, RawPacket: RawPacket._fetch_packet,
TCPInputStream: TCPInputStream._fetch_packet, TCPInputStream: TCPInputStream._fetch_packet,
TCPOutputStream: TCPOutputStream._fetch_packet, TCPOutputStream: TCPOutputStream._fetch_packet,
TCPStreams: TCPStreams._fetch_packet,
HttpRequest: HttpRequest._fetch_packet, HttpRequest: HttpRequest._fetch_packet,
HttpResponse: HttpResponse._fetch_packet, HttpResponse: HttpResponse._fetch_packet,
HttpRequestHeader: HttpRequestHeader._fetch_packet, HttpRequestHeader: HttpRequestHeader._fetch_packet,
@@ -23,6 +21,6 @@ type_annotations_associations = {
__all__ = [ __all__ = [
"RawPacket", "RawPacket",
"TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream", "TCPStreams", "TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream",
"HttpRequest", "HttpResponse", "HttpRequestHeader", "HttpResponseHeader", "HttpRequest", "HttpResponse", "HttpRequestHeader", "HttpResponseHeader",
] ]

View File

@@ -1,6 +1,8 @@
import pyllhttp import pyllhttp
from firegex.nfproxy.internals.exceptions import NotReadyToRun from firegex.nfproxy.internals.exceptions import NotReadyToRun
from firegex.nfproxy.internals.data import DataStreamCtx from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import StreamFullDrop, StreamFullReject
from firegex.nfproxy.internals.models import FullStreamAction
class InternalCallbackHandler(): class InternalCallbackHandler():
@@ -15,14 +17,16 @@ class InternalCallbackHandler():
message_complete: bool = False message_complete: bool = False
status: str|None = None status: str|None = None
_status_buffer: bytes = b"" _status_buffer: bytes = b""
current_header_field = None _current_header_field = b""
current_header_value = None _current_header_value = b""
_save_body = True _save_body = True
total_size = 0
def on_message_begin(self): def on_message_begin(self):
self.has_begun = True self.has_begun = True
def on_url(self, url): def on_url(self, url):
self.total_size += len(url)
self._url_buffer += url self._url_buffer += url
def on_url_complete(self): def on_url_complete(self):
@@ -30,35 +34,32 @@ class InternalCallbackHandler():
self._url_buffer = None self._url_buffer = None
def on_header_field(self, field): def on_header_field(self, field):
if self.current_header_field is None: self.total_size += len(field)
self.current_header_field = bytearray(field) self._current_header_field += field
else:
self.current_header_field += field
def on_header_field_complete(self): def on_header_field_complete(self):
self.current_header_field = self.current_header_field self._current_header_field = self._current_header_field
def on_header_value(self, value): def on_header_value(self, value):
if self.current_header_value is None: self.total_size += len(value)
self.current_header_value = bytearray(value) self._current_header_value += value
else:
self.current_header_value += value
def on_header_value_complete(self): def on_header_value_complete(self):
if self.current_header_value is not None and self.current_header_field is not None: if self._current_header_value is not None and self._current_header_field is not None:
self._header_fields[self.current_header_field.decode(errors="ignore")] = self.current_header_value.decode(errors="ignore") self._header_fields[self._current_header_field.decode(errors="ignore")] = self._current_header_value.decode(errors="ignore")
self.current_header_field = None self._current_header_field = b""
self.current_header_value = None self._current_header_value = b""
def on_headers_complete(self): def on_headers_complete(self):
self.headers_complete = True self.headers_complete = True
self.headers = self._header_fields self.headers = self._header_fields
self._header_fields = {} self._header_fields = {}
self.current_header_field = None self._current_header_field = None
self.current_header_value = None self._current_header_value = None
def on_body(self, body: bytes): def on_body(self, body: bytes):
if self._save_body: if self._save_body:
self.total_size += len(body)
self._body_buffer += body self._body_buffer += body
def on_message_complete(self): def on_message_complete(self):
@@ -67,6 +68,7 @@ class InternalCallbackHandler():
self.message_complete = True self.message_complete = True
def on_status(self, status: bytes): def on_status(self, status: bytes):
self.total_size += len(status)
self._status_buffer += status self._status_buffer += status
def on_status_complete(self): def on_status_complete(self):
@@ -112,6 +114,10 @@ class InternalBasicHttpMetaClass:
self.stream = b"" self.stream = b""
self.raised_error = False self.raised_error = False
@property
def total_size(self) -> int:
return self._parser.total_size
@property @property
def url(self) -> str|None: def url(self) -> str|None:
return self._parser.url return self._parser.url
@@ -187,14 +193,29 @@ class InternalBasicHttpMetaClass:
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False: if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
raise NotReadyToRun() raise NotReadyToRun()
datahandler:InternalBasicHttpMetaClass = internal_data.http_data_objects.get(cls, None) datahandler:InternalBasicHttpMetaClass = internal_data.data_handler_context.get(cls, None)
if datahandler is None or datahandler.raised_error: if datahandler is None or datahandler.raised_error:
datahandler = cls() datahandler = cls()
internal_data.http_data_objects[cls] = datahandler internal_data.data_handler_context[cls] = datahandler
if not datahandler._before_fetch_callable_checks(internal_data): if not datahandler._before_fetch_callable_checks(internal_data):
raise NotReadyToRun() raise NotReadyToRun()
# Memory size management
if datahandler.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
datahandler = cls()
internal_data.data_handler_context[cls] = datahandler
case FullStreamAction.REJECT:
raise StreamFullReject()
case FullStreamAction.DROP:
raise StreamFullDrop()
case FullStreamAction.ACCEPT:
raise NotReadyToRun()
datahandler._fetch_current_packet(internal_data) datahandler._fetch_current_packet(internal_data)
if not datahandler._callable_checks(internal_data): if not datahandler._callable_checks(internal_data):
raise NotReadyToRun() raise NotReadyToRun()
@@ -202,8 +223,8 @@ class InternalBasicHttpMetaClass:
internal_data.save_http_data_in_streams = True internal_data.save_http_data_in_streams = True
if datahandler._trigger_remove_data(internal_data): if datahandler._trigger_remove_data(internal_data):
if internal_data.http_data_objects.get(cls): if internal_data.data_handler_context.get(cls):
del internal_data.http_data_objects[cls] del internal_data.data_handler_context[cls]
return datahandler return datahandler
@@ -262,124 +283,3 @@ class HttpResponseHeader(HttpResponse):
return True return True
return False return False
"""
#TODO include this?
import codecs
# Null bytes; no need to recreate these on each call to guess_json_utf
_null = "\x00".encode("ascii") # encoding to ASCII for Python 3
_null2 = _null * 2
_null3 = _null * 3
def guess_json_utf(data):
""
:rtype: str
""
# JSON always starts with two ASCII characters, so detection is as
# easy as counting the nulls and from their location and count
# determine the encoding. Also detect a BOM, if present.
sample = data[:4]
if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
return "utf-32" # BOM included
if sample[:3] == codecs.BOM_UTF8:
return "utf-8-sig" # BOM included, MS style (discouraged)
if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
return "utf-16" # BOM included
nullcount = sample.count(_null)
if nullcount == 0:
return "utf-8"
if nullcount == 2:
if sample[::2] == _null2: # 1st and 3rd are null
return "utf-16-be"
if sample[1::2] == _null2: # 2nd and 4th are null
return "utf-16-le"
# Did not detect 2 valid UTF-16 ascii-range characters
if nullcount == 3:
if sample[:3] == _null3:
return "utf-32-be"
if sample[1:] == _null3:
return "utf-32-le"
# Did not detect a valid UTF-32 ascii-range character
return None
from http_parser.pyparser import HttpParser
import json
from urllib.parse import parse_qsl
from dataclasses import dataclass
@dataclass
class HttpMessage():
fragment: str
headers: dict
method: str
parameters: dict
path: str
query_string: str
raw_body: bytes
status_code: int
url: str
version: str
class HttpMessageParser(HttpParser):
def __init__(self, data:bytes, decompress_body=True):
super().__init__(decompress = decompress_body)
self.execute(data, len(data))
self._parameters = {}
try:
self._parse_parameters()
except Exception as e:
print("Error in parameters parsing:", data)
print("Exception:", str(e))
def get_raw_body(self):
return b"\r\n".join(self._body)
def _parse_query_string(self, raw_string):
parameters = parse_qsl(raw_string)
for key,value in parameters:
try:
key = key.decode()
value = value.decode()
except:
pass
if self._parameters.get(key):
if isinstance(self._parameters[key], list):
self._parameters[key].append(value)
else:
self._parameters[key] = [self._parameters[key], value]
else:
self._parameters[key] = value
def _parse_parameters(self):
if self._method == "POST":
body = self.get_raw_body()
if len(body) == 0:
return
content_type = self.get_headers().get("Content-Type")
if not content_type or "x-www-form-urlencoded" in content_type:
try:
self._parse_query_string(body.decode())
except:
pass
elif "json" in content_type:
self._parameters = json.loads(body)
elif self._method == "GET":
self._parse_query_string(self._query_string)
def get_parameters(self):
""returns parameters parsed from query string or body""
return self._parameters
def get_version(self):
if self._version:
return ".".join([str(x) for x in self._version])
return None
def to_message(self):
return HttpMessage(self._fragment, self._headers, self._method,
self._parameters, self._path, self._query_string,
self.get_raw_body(), self._status_code,
self._url, self.get_version()
)
"""

View File

@@ -1,61 +1,15 @@
from firegex.nfproxy.internals.data import DataStreamCtx from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import NotReadyToRun from firegex.nfproxy.internals.exceptions import NotReadyToRun, StreamFullDrop, StreamFullReject
from firegex.nfproxy.internals.models import FullStreamAction
class TCPStreams: class InternalTCPStream:
"""
This datamodel will assemble the TCP streams from the input and output data.
The function that use this data model will be handled when:
- The packet is TCP
- At least 1 packet has been sent
"""
def __init__(self,
input_data: bytes,
output_data: bytes,
is_ipv6: bool,
):
self.__input_data = bytes(input_data)
self.__output_data = bytes(output_data)
self.__is_ipv6 = bool(is_ipv6)
@property
def input_data(self) -> bytes:
return self.__input_data
@property
def output_data(self) -> bytes:
return self.__output_data
@property
def is_ipv6(self) -> bool:
return self.__is_ipv6
@classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
raise NotReadyToRun()
return cls(
input_data=b"".join([ele.data for ele in internal_data.stream if ele.is_input]),
output_data=b"".join([ele.data for ele in internal_data.stream if not ele.is_input]),
is_ipv6=internal_data.current_pkt.is_ipv6,
)
class TCPInputStream:
"""
This datamodel will assemble the TCP input stream from the client sent data.
The function that use this data model will be handled when:
- The packet is TCP
- At least 1 packet has been sent
- A new client packet has been received
"""
def __init__(self, def __init__(self,
data: bytes, data: bytes,
is_ipv6: bool, is_ipv6: bool,
): ):
self.__data = bytes(data) self.__data = bytes(data)
self.__is_ipv6 = bool(is_ipv6) self.__is_ipv6 = bool(is_ipv6)
self.__total_stream_size = len(data)
@property @property
def data(self) -> bool: def data(self) -> bool:
@@ -65,14 +19,52 @@ class TCPInputStream:
def is_ipv6(self) -> bool: def is_ipv6(self) -> bool:
return self.__is_ipv6 return self.__is_ipv6
@property
def total_stream_size(self) -> int:
return self.__total_stream_size
def _push_new_data(self, data: bytes):
self.__data += data
self.__total_stream_size += len(data)
@classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx, is_input:bool=False):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
raise NotReadyToRun()
if internal_data.current_pkt.is_input != is_input:
raise NotReadyToRun()
datahandler: TCPInputStream = internal_data.data_handler_context.get(cls, None)
if datahandler is None:
datahandler = cls(internal_data.current_pkt.data, internal_data.current_pkt.is_ipv6)
internal_data.data_handler_context[cls] = datahandler
else:
if datahandler.total_stream_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
datahandler = cls(internal_data.current_pkt.data, internal_data.current_pkt.is_ipv6)
internal_data.data_handler_context[cls] = datahandler
case FullStreamAction.REJECT:
raise StreamFullReject()
case FullStreamAction.DROP:
raise StreamFullDrop()
case FullStreamAction.ACCEPT:
raise NotReadyToRun()
else:
datahandler._push_new_data(internal_data.current_pkt.data)
return datahandler
class TCPInputStream(InternalTCPStream):
"""
This datamodel will assemble the TCP input stream from the client sent data.
The function that use this data model will be handled when:
- The packet is TCP
- At least 1 packet has been sent
- A new client packet has been received
"""
@classmethod @classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx): def _fetch_packet(cls, internal_data:DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False or internal_data.current_pkt.is_input is False: return super()._fetch_packet(internal_data, is_input=True)
raise NotReadyToRun()
return cls(
data=internal_data.current_pkt.get_related_raw_stream(),
is_ipv6=internal_data.current_pkt.is_ipv6,
)
TCPClientStream = TCPInputStream TCPClientStream = TCPInputStream
@@ -85,29 +77,8 @@ class TCPOutputStream:
- A new server packet has been sent - A new server packet has been sent
""" """
def __init__(self,
data: bytes,
is_ipv6: bool,
):
self.__data = bytes(data)
self.__is_ipv6 = bool(is_ipv6)
@property
def data(self) -> bool:
return self.__data
@property
def is_ipv6(self) -> bool:
return self.__is_ipv6
@classmethod @classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx): def _fetch_packet(cls, internal_data:DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False or internal_data.current_pkt.is_input is True: return super()._fetch_packet(internal_data, is_input=False)
raise NotReadyToRun()
return cls(
data=internal_data.current_pkt.get_related_raw_stream(),
is_ipv6=internal_data.current_pkt.is_ipv6,
)
TCPServerStream = TCPOutputStream TCPServerStream = TCPOutputStream