fixed managment of http message queue on nfproxy

This commit is contained in:
Domingo Dirutigliano
2025-05-17 13:58:13 +02:00
parent d6a4fc1953
commit 5c46a9bf34

View File

@@ -1,7 +1,12 @@
import pyllhttp import pyllhttp
from firegex.nfproxy.internals.exceptions import NotReadyToRun from firegex.nfproxy.internals.exceptions import NotReadyToRun
from firegex.nfproxy.internals.data import DataStreamCtx from firegex.nfproxy.internals.data import DataStreamCtx
from firegex.nfproxy.internals.exceptions import StreamFullDrop, StreamFullReject, RejectConnection, DropPacket from firegex.nfproxy.internals.exceptions import (
StreamFullDrop,
StreamFullReject,
RejectConnection,
DropPacket,
)
from firegex.nfproxy.internals.models import FullStreamAction, ExceptionAction from firegex.nfproxy.internals.models import FullStreamAction, ExceptionAction
from dataclasses import dataclass, field from dataclasses import dataclass, field
from collections import deque from collections import deque
@@ -15,12 +20,16 @@ from websockets.frames import Frame
from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.extensions.permessage_deflate import PerMessageDeflate
from pyllhttp import PAUSED_H2_UPGRADE, PAUSED_UPGRADE from pyllhttp import PAUSED_H2_UPGRADE, PAUSED_UPGRADE
@dataclass @dataclass
class InternalHTTPMessage: class InternalHTTPMessage:
"""Internal class to handle HTTP messages""" """Internal class to handle HTTP messages"""
url: str | None = field(default=None) url: str | None = field(default=None)
headers: dict[str, str] = field(default_factory=dict) headers: dict[str, str] = field(default_factory=dict)
lheaders: dict[str, str] = field(default_factory=dict) # lowercase copy of the headers lheaders: dict[str, str] = field(
default_factory=dict
) # lowercase copy of the headers
body: bytes | None = field(default=None) body: bytes | None = field(default=None)
body_decoded: bool = field(default=False) body_decoded: bool = field(default=False)
headers_complete: bool = field(default=False) headers_complete: bool = field(default=False)
@@ -40,9 +49,11 @@ class InternalHTTPMessage:
upgrading_to_h2: bool = field(default=False) upgrading_to_h2: bool = field(default=False)
upgrading_to_ws: bool = field(default=False) upgrading_to_ws: bool = field(default=False)
@dataclass @dataclass
class InternalHttpBuffer: class InternalHttpBuffer:
"""Internal class to handle HTTP messages""" """Internal class to handle HTTP messages"""
_url_buffer: bytes = field(default_factory=bytes) _url_buffer: bytes = field(default_factory=bytes)
_raw_header_fields: dict[str, str | list[str]] = field(default_factory=dict) _raw_header_fields: dict[str, str | list[str]] = field(default_factory=dict)
_header_fields: dict[str, str] = field(default_factory=dict) _header_fields: dict[str, str] = field(default_factory=dict)
@@ -52,8 +63,8 @@ class InternalHttpBuffer:
_current_header_value: bytes = field(default_factory=bytes) _current_header_value: bytes = field(default_factory=bytes)
_ws_packet_stream: bytes = field(default_factory=bytes) _ws_packet_stream: bytes = field(default_factory=bytes)
class InternalCallbackHandler():
class InternalCallbackHandler:
buffers = InternalHttpBuffer() buffers = InternalHttpBuffer()
msg = InternalHTTPMessage() msg = InternalHTTPMessage()
save_body = True save_body = True
@@ -102,7 +113,10 @@ class InternalCallbackHandler():
def on_header_value_complete(self): def on_header_value_complete(self):
if self.buffers._current_header_field: if self.buffers._current_header_field:
k, v = self.buffers._current_header_field.decode(errors="ignore"), self.buffers._current_header_value.decode(errors="ignore") k, v = (
self.buffers._current_header_field.decode(errors="ignore"),
self.buffers._current_header_value.decode(errors="ignore"),
)
old_value = self.buffers._raw_header_fields.get(k, None) old_value = self.buffers._raw_header_fields.get(k, None)
# raw headers are stored as thay were, considering to check changes between headers encoding # raw headers are stored as thay were, considering to check changes between headers encoding
@@ -116,7 +130,9 @@ class InternalCallbackHandler():
# Decoding headers normally # Decoding headers normally
kl = k.lower() kl = k.lower()
if kl in self.buffers._header_fields: if kl in self.buffers._header_fields:
self.buffers._header_fields[kl] += f", {v}" # Should be considered as a single list separated by commas as said in the RFC self.buffers._header_fields[kl] += (
f", {v}" # Should be considered as a single list separated by commas as said in the RFC
)
else: else:
self.buffers._header_fields[kl] = v self.buffers._header_fields[kl] = v
@@ -170,7 +186,9 @@ class InternalCallbackHandler():
print(f"Error decompressing brotli: {e}: skipping", flush=True) print(f"Error decompressing brotli: {e}: skipping", flush=True)
decode_success = False decode_success = False
break break
elif enc == "gzip" or enc == "x-gzip": #https://datatracker.ietf.org/doc/html/rfc2616#section-3.5 elif (
enc == "gzip" or enc == "x-gzip"
): # https://datatracker.ietf.org/doc/html/rfc2616#section-3.5
try: try:
if "gzip" in self.content_encoding.lower(): if "gzip" in self.content_encoding.lower():
with gzip.GzipFile(fileobj=io.BytesIO(decoding_body)) as f: with gzip.GzipFile(fileobj=io.BytesIO(decoding_body)) as f:
@@ -259,9 +277,14 @@ class InternalCallbackHandler():
self.buffers._ws_packet_stream += data self.buffers._ws_packet_stream += data
while True: while True:
try: try:
new_frame, self.buffers._ws_packet_stream = self._parse_websocket_frame(self.buffers._ws_packet_stream) new_frame, self.buffers._ws_packet_stream = (
self._parse_websocket_frame(self.buffers._ws_packet_stream)
)
except Exception: except Exception:
print("[WARNING] Websocket parsing failed, passing data to stream...", flush=True) print(
"[WARNING] Websocket parsing failed, passing data to stream...",
flush=True,
)
traceback.print_exc() traceback.print_exc()
self._ws_raised_error = True self._ws_raised_error = True
self.msg.stream += self.buffers._ws_packet_stream self.msg.stream += self.buffers._ws_packet_stream
@@ -293,8 +316,11 @@ class InternalCallbackHandler():
if self._is_input(): if self._is_input():
self._ws_extentions = [] # Fallback to no options self._ws_extentions = [] # Fallback to no options
else: else:
self._ws_extentions = self._parse_websocket_ext() # Extentions used are choosen by the server response self._ws_extentions = (
self._parse_websocket_ext()
) # Extentions used are choosen by the server response
read_buffering = bytearray() read_buffering = bytearray()
def read_exact(n: int): def read_exact(n: int):
nonlocal read_buffering nonlocal read_buffering
buffer = bytearray(read_buffering) buffer = bytearray(read_buffering)
@@ -307,7 +333,9 @@ class InternalCallbackHandler():
read_buffering = buffer[n:] read_buffering = buffer[n:]
return new_data return new_data
parsing = Frame.parse(read_exact, extensions=self._ws_extentions, mask=self._is_input()) parsing = Frame.parse(
read_exact, extensions=self._ws_extentions, mask=self._is_input()
)
parsing.send(None) parsing.send(None)
try: try:
parsing.send(bytearray(data)) parsing.send(bytearray(data))
@@ -337,6 +365,11 @@ class InternalCallbackHandler():
def pop_message(self): def pop_message(self):
return self.messages.popleft() return self.messages.popleft()
def pop_all_messages(self):
tmp = self.messages
self.messages = deque()
return tmp
def __repr__(self): def __repr__(self):
return f"<InternalCallbackHandler msg={self.msg} buffers={self.buffers} save_body={self.save_body} raised_error={self.raised_error} has_begun={self.has_begun} messages={self.messages}>" return f"<InternalCallbackHandler msg={self.msg} buffers={self.buffers} save_body={self.save_body} raised_error={self.raised_error} has_begun={self.has_begun} messages={self.messages}>"
@@ -349,6 +382,7 @@ class InternalHttpRequest(InternalCallbackHandler, pyllhttp.Request):
def _is_input(self): def _is_input(self):
return True return True
class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response): class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response):
def __init__(self): def __init__(self):
super(InternalCallbackHandler, self).__init__() super(InternalCallbackHandler, self).__init__()
@@ -357,10 +391,15 @@ class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response):
def _is_input(self): def _is_input(self):
return False return False
class InternalBasicHttpMetaClass: class InternalBasicHttpMetaClass:
"""Internal class to handle HTTP requests and responses""" """Internal class to handle HTTP requests and responses"""
def __init__(self, parser: InternalHttpRequest|InternalHttpResponse, msg: InternalHTTPMessage): def __init__(
self,
parser: InternalHttpRequest | InternalHttpResponse,
msg: InternalHTTPMessage,
):
self._parser = parser self._parser = parser
self.raised_error = False self.raised_error = False
self._message: InternalHTTPMessage | None = msg self._message: InternalHTTPMessage | None = msg
@@ -463,10 +502,17 @@ class InternalBasicHttpMetaClass:
@classmethod @classmethod
def _fetch_packet(cls, internal_data: DataStreamCtx): def _fetch_packet(cls, internal_data: DataStreamCtx):
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False: if (
internal_data.current_pkt is None
or internal_data.current_pkt.is_tcp is False
):
raise NotReadyToRun() raise NotReadyToRun()
ParserType = InternalHttpRequest if internal_data.current_pkt.is_input else InternalHttpResponse ParserType = (
InternalHttpRequest
if internal_data.current_pkt.is_input
else InternalHttpResponse
)
parser_key = f"{cls._parser_class()}_{'in' if internal_data.current_pkt.is_input else 'out'}" parser_key = f"{cls._parser_class()}_{'in' if internal_data.current_pkt.is_input else 'out'}"
parser = internal_data.data_handler_context.get(parser_key, None) parser = internal_data.data_handler_context.get(parser_key, None)
@@ -474,17 +520,25 @@ class InternalBasicHttpMetaClass:
parser: InternalHttpRequest | InternalHttpResponse = ParserType() parser: InternalHttpRequest | InternalHttpResponse = ParserType()
internal_data.data_handler_context[parser_key] = parser internal_data.data_handler_context[parser_key] = parser
if not internal_data.call_mem.get(cls._parser_class(), False): #Need to parse HTTP if not internal_data.call_mem.get(
cls._parser_class(), False
): # Need to parse HTTP
internal_data.call_mem[cls._parser_class()] = True internal_data.call_mem[cls._parser_class()] = True
parser.pop_all_messages() # Delete content on message deque
# Setting websocket options if needed to the client parser # Setting websocket options if needed to the client parser
if internal_data.current_pkt.is_input: if internal_data.current_pkt.is_input:
ext_opt = internal_data.data_handler_context.get(f"{cls._parser_class()}_ws_options_client") ext_opt = internal_data.data_handler_context.get(
f"{cls._parser_class()}_ws_options_client"
)
if ext_opt is not None and parser._ws_extentions != ext_opt: if ext_opt is not None and parser._ws_extentions != ext_opt:
parser._ws_extentions = ext_opt parser._ws_extentions = ext_opt
# Memory size managment # Memory size managment
if parser.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size: if (
parser.total_size + len(internal_data.current_pkt.data)
> internal_data.stream_max_size
):
match internal_data.full_stream_action: match internal_data.full_stream_action:
case FullStreamAction.FLUSH: case FullStreamAction.FLUSH:
# Deleting parser and re-creating it # Deleting parser and re-creating it
@@ -494,7 +548,10 @@ class InternalBasicHttpMetaClass:
parser.msg.total_size -= len(parser.msg.body) parser.msg.total_size -= len(parser.msg.body)
parser.msg.body = b"" parser.msg.body = b""
print("[WARNING] Flushing stream", flush=True) print("[WARNING] Flushing stream", flush=True)
if parser.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size: if (
parser.total_size + len(internal_data.current_pkt.data)
> internal_data.stream_max_size
):
parser.reset_data() parser.reset_data()
case FullStreamAction.REJECT: case FullStreamAction.REJECT:
raise StreamFullReject() raise StreamFullReject()
@@ -503,7 +560,9 @@ class InternalBasicHttpMetaClass:
case FullStreamAction.ACCEPT: case FullStreamAction.ACCEPT:
raise NotReadyToRun() raise NotReadyToRun()
internal_data.call_mem["headers_were_set"] = parser.msg.headers_complete #This information is usefull for building the real object internal_data.call_mem["headers_were_set"] = (
parser.msg.headers_complete
) # This information is usefull for building the real object
try: try:
parser.parse_data(internal_data.current_pkt.data) parser.parse_data(internal_data.current_pkt.data)
@@ -521,9 +580,13 @@ class InternalBasicHttpMetaClass:
if parser.should_upgrade and not internal_data.current_pkt.is_input: if parser.should_upgrade and not internal_data.current_pkt.is_input:
# Creating ws_option for the client # Creating ws_option for the client
if not internal_data.data_handler_context.get(f"{cls._parser_class()}_ws_options_client"): if not internal_data.data_handler_context.get(
f"{cls._parser_class()}_ws_options_client"
):
ext = parser._parse_websocket_ext() ext = parser._parse_websocket_ext()
internal_data.data_handler_context[f"{cls._parser_class()}_ws_options_client"] = ext internal_data.data_handler_context[
f"{cls._parser_class()}_ws_options_client"
] = ext
# Once the parsers has been triggered, we can return the object if needed # Once the parsers has been triggered, we can return the object if needed
if not cls._before_fetch_callable_checks(internal_data): if not cls._before_fetch_callable_checks(internal_data):
@@ -534,13 +597,22 @@ class InternalBasicHttpMetaClass:
messages_tosend.append(parser.pop_message()) messages_tosend.append(parser.pop_message())
if len(messages_tosend) > 0: if len(messages_tosend) > 0:
internal_data.call_mem["headers_were_set"] = False # New messages completed so the current message headers were not set in this case internal_data.call_mem["headers_were_set"] = (
False # New messages completed so the current message headers were not set in this case
)
if not internal_data.call_mem["headers_were_set"] and parser.msg.headers_complete: if (
messages_tosend.append(parser.msg) # Also the current message needs to be sent due to complete headers not internal_data.call_mem["headers_were_set"]
and parser.msg.headers_complete
):
messages_tosend.append(
parser.msg
) # Also the current message needs to be sent due to complete headers
if parser._packet_to_stream(): if parser._packet_to_stream():
messages_tosend.append(parser.msg) # Also the current message needs to beacase a stream is going on messages_tosend.append(
parser.msg
) # Also the current message needs to beacase a stream is going on
messages_to_call = len(messages_tosend) messages_to_call = len(messages_tosend)
@@ -551,6 +623,7 @@ class InternalBasicHttpMetaClass:
return [cls(parser, ele) for ele in messages_tosend] return [cls(parser, ele) for ele in messages_tosend]
class HttpRequest(InternalBasicHttpMetaClass): class HttpRequest(InternalBasicHttpMetaClass):
""" """
HTTP Request handler HTTP Request handler
@@ -573,6 +646,7 @@ class HttpRequest(InternalBasicHttpMetaClass):
def __repr__(self): def __repr__(self):
return f"<HttpRequest method={self.method} url={self.url} headers={self.headers} body=[{0 if not self.body else len(self.body)} bytes] http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} content_length={self.content_length} stream={self.stream} ws_stream={self.ws_stream}>" return f"<HttpRequest method={self.method} url={self.url} headers={self.headers} body=[{0 if not self.body else len(self.body)} bytes] http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} content_length={self.content_length} stream={self.stream} ws_stream={self.ws_stream}>"
class HttpResponse(InternalBasicHttpMetaClass): class HttpResponse(InternalBasicHttpMetaClass):
""" """
HTTP Response handler HTTP Response handler
@@ -595,6 +669,7 @@ class HttpResponse(InternalBasicHttpMetaClass):
def __repr__(self): def __repr__(self):
return f"<HttpResponse status_code={self.status_code} url={self.url} headers={self.headers} body=[{0 if not self.body else len(self.body)} bytes] http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} content_length={self.content_length} stream={self.stream} ws_stream={self.ws_stream}>" return f"<HttpResponse status_code={self.status_code} url={self.url} headers={self.headers} body=[{0 if not self.body else len(self.body)} bytes] http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} content_length={self.content_length} stream={self.stream} ws_stream={self.ws_stream}>"
class HttpRequestHeader(HttpRequest): class HttpRequestHeader(HttpRequest):
""" """
HTTP Request Header handler HTTP Request Header handler
@@ -608,6 +683,7 @@ class HttpRequestHeader(HttpRequest):
def _parser_class() -> str: def _parser_class() -> str:
return "header_http" return "header_http"
class HttpResponseHeader(HttpResponse): class HttpResponseHeader(HttpResponse):
""" """
HTTP Response Header handler HTTP Response Header handler