changed datahandler max size management

Domingo Dirutigliano
2025-03-03 21:15:49 +01:00
parent 072745cc06
commit 832c6e1530
7 changed files with 135 additions and 291 deletions

View File

@@ -1,5 +1,5 @@
import functools
from firegex.nfproxy.models import RawPacket, TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream, TCPStreams
from firegex.nfproxy.models import RawPacket, TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream
from firegex.nfproxy.internals.models import Action, FullStreamAction
ACCEPT = Action.ACCEPT
@@ -35,5 +35,5 @@ def clear_pyfilter_registry():
__all__ = [
"ACCEPT", "DROP", "REJECT", "UNSTABLE_MANGLE"
"Action", "FullStreamAction", "pyfilter",
"RawPacket", "TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream", "TCPStreams"
"RawPacket", "TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream"
]
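For orientation, a minimal filter written against this public API might look like the sketch below. The rule itself is hypothetical; only pyfilter, the ACCEPT/DROP constants and RawPacket come from the exports above, and the parameter annotation is what selects the data model handed to the filter.

```python
from firegex.nfproxy import pyfilter, ACCEPT, DROP, RawPacket

@pyfilter
def drop_large_packets(pkt: RawPacket):
    # Hypothetical rule: drop any single packet whose payload exceeds 4 KiB.
    if len(pkt.data) > 4096:
        return DROP
    return ACCEPT
```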

View File

@@ -3,7 +3,7 @@ from firegex.nfproxy.internals.models import Action, FullStreamAction
from firegex.nfproxy.internals.models import FilterHandler, PacketHandlerResult
import functools
from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import NotReadyToRun
from firegex.nfproxy.internals.exceptions import NotReadyToRun, StreamFullReject, DropPacket, RejectConnection, StreamFullDrop
from firegex.nfproxy.internals.data import RawPacket
def context_call(glob, func, *args, **kargs):
@@ -76,32 +76,8 @@ def handle_packet(glob: dict) -> None:
cache_call[RawPacket] = pkt_info
final_result = Action.ACCEPT
data_size = len(pkt_info.data)
result = PacketHandlerResult(glob)
if internal_data.stream_size+data_size > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
internal_data.stream = []
internal_data.stream_size = 0
for func in internal_data.flush_action_set:
func()
case FullStreamAction.ACCEPT:
result.action = Action.ACCEPT
return result.set_result()
case FullStreamAction.REJECT:
result.action = Action.REJECT
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
case FullStreamAction.DROP:
result.action = Action.DROP
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
internal_data.stream.append(pkt_info)
internal_data.stream_size += data_size
func_name = None
mangled_packet = None
for filter in internal_data.filter_call_info:
@@ -115,6 +91,22 @@ def handle_packet(glob: dict) -> None:
cache_call[data_type] = None
skip_call = True
break
except StreamFullDrop:
result.action = Action.DROP
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
except StreamFullReject:
result.action = Action.REJECT
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
except DropPacket:
result.action = Action.DROP
result.matched_by = filter.name
return result.set_result()
except RejectConnection:
result.action = Action.REJECT
result.matched_by = filter.name
return result.set_result()
final_params.append(cache_call[data_type])
if skip_call:
continue
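A self-contained sketch (not the real firegex code) of how the new except clauses translate an exception raised during data-model fetching into the final verdict and the reported matched_by value:

```python
from enum import Enum

class Action(Enum):
    ACCEPT = 0
    DROP = 1
    REJECT = 2

# Simplified stand-ins for the exceptions defined in internals/exceptions.py
class StreamFullDrop(Exception): pass
class StreamFullReject(Exception): pass
class DropPacket(Exception): pass
class RejectConnection(Exception): pass

def resolve(fetch, filter_name):
    """Run a data-model fetch and translate raised exceptions into a verdict."""
    try:
        fetch()
    except StreamFullDrop:
        return Action.DROP, "@MAX_STREAM_SIZE_REACHED"
    except StreamFullReject:
        return Action.REJECT, "@MAX_STREAM_SIZE_REACHED"
    except DropPacket:
        return Action.DROP, filter_name
    except RejectConnection:
        return Action.REJECT, filter_name
    return Action.ACCEPT, None

def oversized_fetch():
    raise StreamFullReject()

print(resolve(oversized_fetch, "my_filter"))  # (<Action.REJECT: 2>, '@MAX_STREAM_SIZE_REACHED')
```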

View File

@@ -1,5 +1,5 @@
from firegex.nfproxy.internals.models import FilterHandler
from typing import Callable
from firegex.nfproxy.internals.models import FullStreamAction
class RawPacket:
"""
@@ -109,26 +109,6 @@ class DataStreamCtx:
def filter_call_info(self, v: list[FilterHandler]):
self.__data["filter_call_info"] = v
@property
def stream(self) -> list[RawPacket]:
if "stream" not in self.__data.keys():
self.__data["stream"] = []
return self.__data.get("stream")
@stream.setter
def stream(self, v: list[RawPacket]):
self.__data["stream"] = v
@property
def stream_size(self) -> int:
if "stream_size" not in self.__data.keys():
self.__data["stream_size"] = 0
return self.__data.get("stream_size")
@stream_size.setter
def stream_size(self, v: int):
self.__data["stream_size"] = v
@property
def stream_max_size(self) -> int:
if "stream_max_size" not in self.__data.keys():
@@ -140,13 +120,13 @@ class DataStreamCtx:
self.__data["stream_max_size"] = v
@property
def full_stream_action(self) -> str:
def full_stream_action(self) -> FullStreamAction:
if "full_stream_action" not in self.__data.keys():
self.__data["full_stream_action"] = "flush"
return self.__data.get("full_stream_action")
@full_stream_action.setter
def full_stream_action(self, v: str):
def full_stream_action(self, v: FullStreamAction):
self.__data["full_stream_action"] = v
@property
@@ -158,14 +138,14 @@ class DataStreamCtx:
self.__data["current_pkt"] = v
@property
def http_data_objects(self) -> dict:
if "http_data_objects" not in self.__data.keys():
self.__data["http_data_objects"] = {}
return self.__data.get("http_data_objects")
def data_handler_context(self) -> dict:
if "data_handler_context" not in self.__data.keys():
self.__data["data_handler_context"] = {}
return self.__data.get("data_handler_context")
@http_data_objects.setter
def http_data_objects(self, v: dict):
self.__data["http_data_objects"] = v
@data_handler_context.setter
def data_handler_context(self, v: dict):
self.__data["data_handler_context"] = v
@property
def save_http_data_in_streams(self) -> bool:
@@ -177,14 +157,5 @@ class DataStreamCtx:
def save_http_data_in_streams(self, v: bool):
self.__data["save_http_data_in_streams"] = v
@property
def flush_action_set(self) -> set[Callable]:
if "flush_action_set" not in self.__data.keys():
self.__data["flush_action_set"] = set()
return self.__data.get("flush_action_set")
@flush_action_set.setter
def flush_action_set(self, v: set[Callable]):
self.__data["flush_action_set"] = v

View File

@@ -1,3 +1,15 @@
class NotReadyToRun(Exception):
"raise this exception if the stream state is not ready to parse this object, the call will be skipped"
class DropPacket(Exception):
"raise this exception if you want to drop the packet"
class StreamFullDrop(Exception):
"raise this exception if you want to drop the packet due to full stream"
class RejectConnection(Exception):
"raise this exception if you want to reject the connection"
class StreamFullReject(Exception):
"raise this exception if you want to reject the connection due to full stream"

View File

@@ -1,4 +1,4 @@
from firegex.nfproxy.models.tcp import TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream, TCPStreams
from firegex.nfproxy.models.tcp import TCPInputStream, TCPOutputStream, TCPClientStream, TCPServerStream
from firegex.nfproxy.models.http import HttpRequest, HttpResponse, HttpRequestHeader, HttpResponseHeader
from firegex.nfproxy.internals.data import RawPacket
@@ -7,13 +7,11 @@ type_annotations_associations = {
RawPacket: RawPacket._fetch_packet,
TCPInputStream: TCPInputStream._fetch_packet,
TCPOutputStream: TCPOutputStream._fetch_packet,
TCPStreams: TCPStreams._fetch_packet,
},
"http": {
RawPacket: RawPacket._fetch_packet,
TCPInputStream: TCPInputStream._fetch_packet,
TCPOutputStream: TCPOutputStream._fetch_packet,
TCPStreams: TCPStreams._fetch_packet,
HttpRequest: HttpRequest._fetch_packet,
HttpResponse: HttpResponse._fetch_packet,
HttpRequestHeader: HttpRequestHeader._fetch_packet,
@@ -23,6 +21,6 @@ type_annotations_associations = {
__all__ = [
"RawPacket",
"TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream", "TCPStreams",
"TCPInputStream", "TCPOutputStream", "TCPClientStream", "TCPServerStream",
"HttpRequest", "HttpResponse", "HttpRequestHeader", "HttpResponseHeader",
]
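Since TCPStreams is removed from both proxy types, reassembled data is now consumed per direction; a hedged sketch with two directional filters (the rules themselves are hypothetical):

```python
from firegex.nfproxy import pyfilter, ACCEPT, REJECT
from firegex.nfproxy.models import TCPClientStream, TCPServerStream

@pyfilter
def check_client_data(stream: TCPClientStream):
    # Hypothetical rule on the reassembled client-to-server stream.
    return REJECT if b"' OR 1=1--" in stream.data else ACCEPT

@pyfilter
def check_server_data(stream: TCPServerStream):
    # Hypothetical rule on the reassembled server-to-client stream.
    return REJECT if b"flag{" in stream.data else ACCEPT
```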

View File

@@ -1,6 +1,8 @@
import pyllhttp
from firegex.nfproxy.internals.exceptions import NotReadyToRun
from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import StreamFullDrop, StreamFullReject
from firegex.nfproxy.internals.models import FullStreamAction
class InternalCallbackHandler():
@@ -15,14 +17,16 @@ class InternalCallbackHandler():
message_complete: bool = False
status: str|None = None
_status_buffer: bytes = b""
current_header_field = None
current_header_value = None
_current_header_field = b""
_current_header_value = b""
_save_body = True
total_size = 0
def on_message_begin(self):
self.has_begun = True
def on_url(self, url):
self.total_size += len(url)
self._url_buffer += url
def on_url_complete(self):
@@ -30,35 +34,32 @@ class InternalCallbackHandler():
self._url_buffer = None
def on_header_field(self, field):
if self.current_header_field is None:
self.current_header_field = bytearray(field)
else:
self.current_header_field += field
self.total_size += len(field)
self._current_header_field += field
def on_header_field_complete(self):
self.current_header_field = self.current_header_field
self._current_header_field = self._current_header_field
def on_header_value(self, value):
if self.current_header_value is None:
self.current_header_value = bytearray(value)
else:
self.current_header_value += value
self.total_size += len(value)
self._current_header_value += value
def on_header_value_complete(self):
if self.current_header_value is not None and self.current_header_field is not None:
self._header_fields[self.current_header_field.decode(errors="ignore")] = self.current_header_value.decode(errors="ignore")
self.current_header_field = None
self.current_header_value = None
if self._current_header_value is not None and self._current_header_field is not None:
self._header_fields[self._current_header_field.decode(errors="ignore")] = self._current_header_value.decode(errors="ignore")
self._current_header_field = b""
self._current_header_value = b""
def on_headers_complete(self):
self.headers_complete = True
self.headers = self._header_fields
self._header_fields = {}
self.current_header_field = None
self.current_header_value = None
self._current_header_field = None
self._current_header_value = None
def on_body(self, body: bytes):
if self._save_body:
self.total_size += len(body)
self._body_buffer += body
def on_message_complete(self):
@@ -67,6 +68,7 @@ class InternalCallbackHandler():
self.message_complete = True
def on_status(self, status: bytes):
self.total_size += len(status)
self._status_buffer += status
def on_status_complete(self):
@@ -112,6 +114,10 @@ class InternalBasicHttpMetaClass:
self.stream = b""
self.raised_error = False
@property
def total_size(self) -> int:
return self._parser.total_size
@property
def url(self) -> str|None:
return self._parser.url
@@ -187,14 +193,29 @@ class InternalBasicHttpMetaClass:
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
raise NotReadyToRun()
datahandler:InternalBasicHttpMetaClass = internal_data.http_data_objects.get(cls, None)
datahandler:InternalBasicHttpMetaClass = internal_data.data_handler_context.get(cls, None)
if datahandler is None or datahandler.raised_error:
datahandler = cls()
internal_data.http_data_objects[cls] = datahandler
internal_data.data_handler_context[cls] = datahandler
if not datahandler._before_fetch_callable_checks(internal_data):
raise NotReadyToRun()
# Memory size management
if datahandler.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
datahandler = cls()
internal_data.data_handler_context[cls] = datahandler
case FullStreamAction.REJECT:
raise StreamFullReject()
case FullStreamAction.DROP:
raise StreamFullDrop()
case FullStreamAction.ACCEPT:
raise NotReadyToRun()
datahandler._fetch_current_packet(internal_data)
if not datahandler._callable_checks(internal_data):
raise NotReadyToRun()
@@ -202,8 +223,8 @@ class InternalBasicHttpMetaClass:
internal_data.save_http_data_in_streams = True
if datahandler._trigger_remove_data(internal_data):
if internal_data.http_data_objects.get(cls):
del internal_data.http_data_objects[cls]
if internal_data.data_handler_context.get(cls):
del internal_data.data_handler_context[cls]
return datahandler
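For reference, a self-contained sketch of the size policy this hunk adds: when the handler's accumulated total_size plus the new payload would cross stream_max_size, the configured FullStreamAction decides whether to restart the handler, reject, drop, or silently skip. Simplified stand-ins, not the real classes:

```python
from enum import Enum

class FullStreamAction(Enum):
    FLUSH = 0
    ACCEPT = 1
    REJECT = 2
    DROP = 3

class StreamFullReject(Exception): pass
class StreamFullDrop(Exception): pass
class NotReadyToRun(Exception): pass

def apply_size_limit(total_size, new_len, max_size, action, restart_handler):
    """Mirror of the size check: enforce the limit before new data is parsed."""
    if total_size + new_len > max_size:
        match action:
            case FullStreamAction.FLUSH:
                restart_handler()            # start over with a fresh handler
            case FullStreamAction.REJECT:
                raise StreamFullReject()     # handle_packet turns this into REJECT
            case FullStreamAction.DROP:
                raise StreamFullDrop()       # handle_packet turns this into DROP
            case FullStreamAction.ACCEPT:
                raise NotReadyToRun()        # packet passes, filter call is skipped

apply_size_limit(900, 200, 1024, FullStreamAction.FLUSH,
                 restart_handler=lambda: print("handler flushed"))
```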
@@ -262,124 +283,3 @@ class HttpResponseHeader(HttpResponse):
return True
return False
"""
#TODO include this?
import codecs
# Null bytes; no need to recreate these on each call to guess_json_utf
_null = "\x00".encode("ascii") # encoding to ASCII for Python 3
_null2 = _null * 2
_null3 = _null * 3
def guess_json_utf(data):
""
:rtype: str
""
# JSON always starts with two ASCII characters, so detection is as
# easy as counting the nulls and from their location and count
# determine the encoding. Also detect a BOM, if present.
sample = data[:4]
if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
return "utf-32" # BOM included
if sample[:3] == codecs.BOM_UTF8:
return "utf-8-sig" # BOM included, MS style (discouraged)
if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
return "utf-16" # BOM included
nullcount = sample.count(_null)
if nullcount == 0:
return "utf-8"
if nullcount == 2:
if sample[::2] == _null2: # 1st and 3rd are null
return "utf-16-be"
if sample[1::2] == _null2: # 2nd and 4th are null
return "utf-16-le"
# Did not detect 2 valid UTF-16 ascii-range characters
if nullcount == 3:
if sample[:3] == _null3:
return "utf-32-be"
if sample[1:] == _null3:
return "utf-32-le"
# Did not detect a valid UTF-32 ascii-range character
return None
from http_parser.pyparser import HttpParser
import json
from urllib.parse import parse_qsl
from dataclasses import dataclass
@dataclass
class HttpMessage():
fragment: str
headers: dict
method: str
parameters: dict
path: str
query_string: str
raw_body: bytes
status_code: int
url: str
version: str
class HttpMessageParser(HttpParser):
def __init__(self, data:bytes, decompress_body=True):
super().__init__(decompress = decompress_body)
self.execute(data, len(data))
self._parameters = {}
try:
self._parse_parameters()
except Exception as e:
print("Error in parameters parsing:", data)
print("Exception:", str(e))
def get_raw_body(self):
return b"\r\n".join(self._body)
def _parse_query_string(self, raw_string):
parameters = parse_qsl(raw_string)
for key,value in parameters:
try:
key = key.decode()
value = value.decode()
except:
pass
if self._parameters.get(key):
if isinstance(self._parameters[key], list):
self._parameters[key].append(value)
else:
self._parameters[key] = [self._parameters[key], value]
else:
self._parameters[key] = value
def _parse_parameters(self):
if self._method == "POST":
body = self.get_raw_body()
if len(body) == 0:
return
content_type = self.get_headers().get("Content-Type")
if not content_type or "x-www-form-urlencoded" in content_type:
try:
self._parse_query_string(body.decode())
except:
pass
elif "json" in content_type:
self._parameters = json.loads(body)
elif self._method == "GET":
self._parse_query_string(self._query_string)
def get_parameters(self):
""returns parameters parsed from query string or body""
return self._parameters
def get_version(self):
if self._version:
return ".".join([str(x) for x in self._version])
return None
def to_message(self):
return HttpMessage(self._fragment, self._headers, self._method,
self._parameters, self._path, self._query_string,
self.get_raw_body(), self._status_code,
self._url, self.get_version()
)
"""

View File

@@ -1,61 +1,15 @@
from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import NotReadyToRun
from firegex.nfproxy.internals.exceptions import NotReadyToRun, StreamFullDrop, StreamFullReject
from firegex.nfproxy.internals.models import FullStreamAction
class TCPStreams:
"""
This datamodel will assemble the TCP streams from the input and output data.
The functions that use this data model will be called when:
- The packet is TCP
- At least 1 packet has been sent
"""
def __init__(self,
input_data: bytes,
output_data: bytes,
is_ipv6: bool,
):
self.__input_data = bytes(input_data)
self.__output_data = bytes(output_data)
self.__is_ipv6 = bool(is_ipv6)
@property
def input_data(self) -> bytes:
return self.__input_data
@property
def output_data(self) -> bytes:
return self.__output_data
@property
def is_ipv6(self) -> bool:
return self.__is_ipv6
@classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
raise NotReadyToRun()
return cls(
input_data=b"".join([ele.data for ele in internal_data.stream if ele.is_input]),
output_data=b"".join([ele.data for ele in internal_data.stream if not ele.is_input]),
is_ipv6=internal_data.current_pkt.is_ipv6,
)
class TCPInputStream:
"""
This datamodel will assemble the TCP input stream from the data sent by the client.
The functions that use this data model will be called when:
- The packet is TCP
- At least 1 packet has been sent
- A new client packet has been received
"""
class InternalTCPStream:
def __init__(self,
data: bytes,
is_ipv6: bool,
):
self.__data = bytes(data)
self.__is_ipv6 = bool(is_ipv6)
self.__total_stream_size = len(data)
@property
def data(self) -> bytes:
@@ -65,14 +19,52 @@ class TCPInputStream:
def is_ipv6(self) -> bool:
return self.__is_ipv6
@property
def total_stream_size(self) -> int:
return self.__total_stream_size
def _push_new_data(self, data: bytes):
self.__data += data
self.__total_stream_size += len(data)
@classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx, is_input:bool=False):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
raise NotReadyToRun()
if internal_data.current_pkt.is_input != is_input:
raise NotReadyToRun()
datahandler: TCPInputStream = internal_data.data_handler_context.get(cls, None)
if datahandler is None:
datahandler = cls(internal_data.current_pkt.data, internal_data.current_pkt.is_ipv6)
internal_data.data_handler_context[cls] = datahandler
else:
if datahandler.total_stream_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
datahandler = cls(internal_data.current_pkt.data, internal_data.current_pkt.is_ipv6)
internal_data.data_handler_context[cls] = datahandler
case FullStreamAction.REJECT:
raise StreamFullReject()
case FullStreamAction.DROP:
raise StreamFullDrop()
case FullStreamAction.ACCEPT:
raise NotReadyToRun()
else:
datahandler._push_new_data(internal_data.current_pkt.data)
return datahandler
class TCPInputStream(InternalTCPStream):
"""
This datamodel will assemble the TCP input stream from the data sent by the client.
The functions that use this data model will be called when:
- The packet is TCP
- At least 1 packet has been sent
- A new client packet has been received
"""
@classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False or internal_data.current_pkt.is_input is False:
raise NotReadyToRun()
return cls(
data=internal_data.current_pkt.get_related_raw_stream(),
is_ipv6=internal_data.current_pkt.is_ipv6,
)
return super()._fetch_packet(internal_data, is_input=True)
TCPClientStream = TCPInputStream
@@ -85,29 +77,8 @@ class TCPOutputStream:
- A new server packet has been sent
"""
def __init__(self,
data: bytes,
is_ipv6: bool,
):
self.__data = bytes(data)
self.__is_ipv6 = bool(is_ipv6)
@property
def data(self) -> bytes:
return self.__data
@property
def is_ipv6(self) -> bool:
return self.__is_ipv6
@classmethod
def _fetch_packet(cls, internal_data:DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False or internal_data.current_pkt.is_input is True:
raise NotReadyToRun()
return cls(
data=internal_data.current_pkt.get_related_raw_stream(),
is_ipv6=internal_data.current_pkt.is_ipv6,
)
return super()._fetch_packet(internal_data, is_input=False)
TCPServerStream = TCPOutputStream
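A hedged usage example of the new total_stream_size property exposed by the rewritten stream models; the threshold and the rule are hypothetical:

```python
from firegex.nfproxy import pyfilter, ACCEPT, REJECT
from firegex.nfproxy.models import TCPInputStream

@pyfilter
def limit_client_upload(stream: TCPInputStream):
    # Hypothetical rule: reject clients that have already sent more than 1 MiB
    # on this connection (total_stream_size tracks the reassembled input size).
    if stream.total_stream_size > 1024 * 1024:
        return REJECT
    return ACCEPT
```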