import pyllhttp from firegex.nfproxy.internals.exceptions import NotReadyToRun from firegex.nfproxy.internals.data import DataStreamCtx from firegex.nfproxy.internals.exceptions import StreamFullDrop, StreamFullReject, RejectConnection, DropPacket from firegex.nfproxy.internals.models import FullStreamAction, ExceptionAction from dataclasses import dataclass, field from collections import deque from typing import Type from zstd import ZSTD_uncompress import gzip import io import zlib import brotli from websockets.frames import Frame from websockets.extensions.permessage_deflate import PerMessageDeflate from pyllhttp import PAUSED_H2_UPGRADE, PAUSED_UPGRADE @dataclass class InternalHTTPMessage: """Internal class to handle HTTP messages""" url: str|None = field(default=None) headers: dict[str, str] = field(default_factory=dict) lheaders: dict[str, str] = field(default_factory=dict) # lowercase copy of the headers body: bytes|None = field(default=None) body_decoded: bool = field(default=False) headers_complete: bool = field(default=False) message_complete: bool = field(default=False) status: str|None = field(default=None) total_size: int = field(default=0) user_agent: str = field(default_factory=str) content_encoding: str = field(default=str) content_type: str = field(default=str) keep_alive: bool = field(default=False) should_upgrade: bool = field(default=False) http_version: str = field(default=str) method: str = field(default=str) content_length: int = field(default=0) stream: bytes = field(default_factory=bytes) ws_stream: list[Frame] = field(default_factory=list) # Decoded websocket stream upgrading_to_h2: bool = field(default=False) upgrading_to_ws: bool = field(default=False) @dataclass class InternalHttpBuffer: """Internal class to handle HTTP messages""" _url_buffer: bytes = field(default_factory=bytes) _raw_header_fields: dict[str, str|list[str]] = field(default_factory=dict) _header_fields: dict[str, str] = field(default_factory=dict) _body_buffer: bytes = field(default_factory=bytes) _status_buffer: bytes = field(default_factory=bytes) _current_header_field: bytes = field(default_factory=bytes) _current_header_value: bytes = field(default_factory=bytes) _ws_packet_stream: bytes = field(default_factory=bytes) class InternalCallbackHandler(): buffers = InternalHttpBuffer() msg = InternalHTTPMessage() save_body = True raised_error = False has_begun = False messages: deque[InternalHTTPMessage] = deque() _ws_extentions = None _ws_raised_error = False def reset_data(self): self.msg = InternalHTTPMessage() self.buffers = InternalHttpBuffer() self.messages.clear() def on_message_begin(self): self.buffers = InternalHttpBuffer() self.msg = InternalHTTPMessage() self.has_begun = True def on_url(self, url): self.buffers._url_buffer += url self.msg.total_size += len(url) def on_url_complete(self): self.msg.url = self.buffers._url_buffer.decode(errors="ignore") self.buffers._url_buffer = b"" def on_status(self, status: bytes): self.msg.total_size += len(status) self.buffers._status_buffer += status def on_status_complete(self): self.msg.status = self.buffers._status_buffer.decode(errors="ignore") self.buffers._status_buffer = b"" def on_header_field(self, field): self.msg.total_size += len(field) self.buffers._current_header_field += field def on_header_field_complete(self): pass # Nothing to do def on_header_value(self, value): self.msg.total_size += len(value) self.buffers._current_header_value += value def on_header_value_complete(self): if self.buffers._current_header_field: k, v = self.buffers._current_header_field.decode(errors="ignore"), self.buffers._current_header_value.decode(errors="ignore") old_value = self.buffers._raw_header_fields.get(k, None) # raw headers are stored as thay were, considering to check changes between headers encoding if isinstance(old_value, list): old_value.append(v) elif isinstance(old_value, str): self.buffers._raw_header_fields[k] = [old_value, v] else: self.buffers._raw_header_fields[k] = v # Decoding headers normally kl = k.lower() if kl in self.buffers._header_fields: self.buffers._header_fields[kl] += f", {v}" # Should be considered as a single list separated by commas as said in the RFC else: self.buffers._header_fields[kl] = v self.buffers._current_header_field = b"" self.buffers._current_header_value = b"" def on_headers_complete(self): self.msg.headers = self.buffers._raw_header_fields self.msg.lheaders = self.buffers._header_fields self.buffers._raw_header_fields = {} self.buffers._current_header_field = b"" self.buffers._current_header_value = b"" self.msg.headers_complete = True self.msg.method = self.method_parsed self.msg.content_length = self.content_length_parsed self.msg.should_upgrade = self.should_upgrade self.msg.keep_alive = self.keep_alive self.msg.http_version = self.http_version self.msg.content_type = self.content_type self.msg.content_encoding = self.content_encoding self.msg.user_agent = self.user_agent def on_body(self, body: bytes): if self.save_body: self.msg.total_size += len(body) self.buffers._body_buffer += body def on_message_complete(self): self.msg.body = self.buffers._body_buffer self.msg.should_upgrade = self.should_upgrade self.buffers._body_buffer = b"" encodings = [ele.strip() for ele in self.content_encoding.lower().split(",")] decode_success = True decoding_body = self.msg.body for enc in reversed(encodings): if not enc: continue if enc == "deflate": try: decompress = zlib.decompressobj(-zlib.MAX_WBITS) decoding_body = decompress.decompress(decoding_body) decoding_body += decompress.flush() except Exception as e: print(f"Error decompressing deflate: {e}: skipping", flush=True) decode_success = False break elif enc == "br": try: decoding_body = brotli.decompress(decoding_body) except Exception as e: print(f"Error decompressing brotli: {e}: skipping", flush=True) decode_success = False break elif enc == "gzip" or enc == "x-gzip": #https://datatracker.ietf.org/doc/html/rfc2616#section-3.5 try: if "gzip" in self.content_encoding.lower(): with gzip.GzipFile(fileobj=io.BytesIO(decoding_body)) as f: decoding_body = f.read() except Exception as e: print(f"Error decompressing gzip: {e}: skipping", flush=True) decode_success = False break elif enc == "zstd": try: decoding_body = ZSTD_uncompress(decoding_body) except Exception as e: print(f"Error decompressing zstd: {e}: skipping", flush=True) decode_success = False break elif enc == "identity": pass # No need to do anything https://datatracker.ietf.org/doc/html/rfc2616#section-3.5 (it's possible to be found also if it should't be used) else: decode_success = False break if decode_success: self.msg.body = decoding_body self.msg.body_decoded = True self.msg.message_complete = True self.has_begun = False if not self._packet_to_stream(): self.messages.append(self.msg) @property def user_agent(self) -> str: return self.msg.lheaders.get("user-agent", "") @property def content_encoding(self) -> str: return self.msg.lheaders.get("content-encoding", "") @property def content_type(self) -> str: return self.msg.lheaders.get("content-type", "") @property def keep_alive(self) -> bool: return self.should_keep_alive @property def should_upgrade(self) -> bool: return self.is_upgrading @property def http_version(self) -> str: if self.major and self.minor: return f"{self.major}.{self.minor}" else: return "" @property def method_parsed(self) -> str: return self.method @property def total_size(self) -> int: """Total size used by the parser""" tot = self.msg.total_size for msg in self.messages: tot += msg.total_size return tot @property def content_length_parsed(self) -> int: return self.content_length def _is_input(self) -> bool: raise NotImplementedError() def _packet_to_stream(self): return self.should_upgrade and self.save_body def _stream_parser(self, data: bytes): if self.msg.upgrading_to_ws: if self._ws_raised_error: self.msg.stream += data self.msg.total_size += len(data) return self.buffers._ws_packet_stream += data while True: try: new_frame, self.buffers._ws_packet_stream = self._parse_websocket_frame(self.buffers._ws_packet_stream) except Exception as e: self._ws_raised_error = True self.msg.stream += self.buffers._ws_packet_stream self.buffers._ws_packet_stream = b"" self.msg.total_size += len(data) return if new_frame is None: break self.msg.ws_stream.append(new_frame) self.msg.total_size += len(new_frame.data) if self.msg.upgrading_to_h2: self.msg.total_size += len(data) self.msg.stream += data def _parse_websocket_ext(self): ext_ws = [] req_ext = [] for ele in self.msg.lheaders.get("sec-websocket-extensions", "").split(","): for xt in ele.split(";"): req_ext.append(xt.strip().lower()) for ele in req_ext: if ele == "permessage-deflate": ext_ws.append(PerMessageDeflate(False, False, 15, 15)) return ext_ws def _parse_websocket_frame(self, data: bytes) -> tuple[Frame|None, bytes]: # mask = is_input if self._ws_extentions is None: self._ws_extentions = self._parse_websocket_ext() read_buffering = bytearray() def read_exact(n: int): nonlocal read_buffering buffer = bytearray(read_buffering) while len(buffer) < n: data = yield if data is None: raise RuntimeError("Should not send None to this generator") buffer.extend(data) new_data = bytes(buffer[:n]) read_buffering = buffer[n:] return new_data parsing = Frame.parse(read_exact, extensions=self._ws_extentions, mask=self._is_input()) parsing.send(None) try: parsing.send(bytearray(data)) except StopIteration as e: return e.value, read_buffering return None, read_buffering def parse_data(self, data: bytes): if self._packet_to_stream(): # This is a websocket upgrade! self._stream_parser(data) else: try: reason, consumed = self.execute(data) if reason == PAUSED_UPGRADE: self.msg.upgrading_to_ws = True self.msg.message_complete = True self._stream_parser(data[consumed:]) elif reason == PAUSED_H2_UPGRADE: self.msg.upgrading_to_h2 = True self.msg.message_complete = True self._stream_parser(data[consumed:]) except Exception as e: self.raised_error = True raise e def pop_message(self): return self.messages.popleft() def __repr__(self): return f"" class InternalHttpRequest(InternalCallbackHandler, pyllhttp.Request): def __init__(self): super(InternalCallbackHandler, self).__init__() super(pyllhttp.Request, self).__init__() def _is_input(self): return True class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response): def __init__(self): super(InternalCallbackHandler, self).__init__() super(pyllhttp.Response, self).__init__() def _is_input(self): return False class InternalBasicHttpMetaClass: """Internal class to handle HTTP requests and responses""" def __init__(self, parser: InternalHttpRequest|InternalHttpResponse, msg: InternalHTTPMessage): self._parser = parser self.raised_error = False self._message: InternalHTTPMessage|None = msg self._contructor_hook() def _contructor_hook(self): pass @property def total_size(self) -> int: """Total size of the stream""" return self._parser.total_size @property def url(self) -> str|None: """URL of the message""" return self._message.url @property def headers(self) -> dict[str, str]: """Headers of the message""" return self._message.headers @property def user_agent(self) -> str: """User agent of the message""" return self._message.user_agent @property def content_encoding(self) -> str: """Content encoding of the message""" return self._message.content_encoding @property def body(self) -> bytes: """Body of the message""" return self._message.body @property def headers_complete(self) -> bool: """If the headers are complete""" return self._message.headers_complete @property def message_complete(self) -> bool: """If the message is complete""" return self._message.message_complete @property def http_version(self) -> str: """HTTP version of the message""" return self._message.http_version @property def keep_alive(self) -> bool: """If the message should keep alive""" return self._message.keep_alive @property def should_upgrade(self) -> bool: """If the message should upgrade""" return self._parser.should_upgrade @property def content_length(self) -> int|None: """Content length of the message""" return self._message.content_length @property def upgrading_to_h2(self) -> bool: """If the message is upgrading to HTTP/2""" return self._message.upgrading_to_h2 @property def upgrading_to_ws(self) -> bool: """If the message is upgrading to Websocket""" return self._message.upgrading_to_ws @property def ws_stream(self) -> list[Frame]: """Websocket stream""" return self._message.ws_stream @property def stream(self) -> bytes: """Stream of the message""" return self._message.stream def get_header(self, header: str, default=None) -> str: """Get a header from the message without caring about the case""" return self._message.lheaders.get(header.lower(), default) @staticmethod def _associated_parser_class() -> Type[InternalHttpRequest]|Type[InternalHttpResponse]: raise NotImplementedError() @staticmethod def _before_fetch_callable_checks(internal_data: DataStreamCtx): return True @classmethod def _fetch_packet(cls, internal_data: DataStreamCtx): if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False: raise NotReadyToRun() ParserType = cls._associated_parser_class() parser = internal_data.data_handler_context.get(cls, None) if parser is None or parser.raised_error: parser: InternalHttpRequest|InternalHttpResponse = ParserType() internal_data.data_handler_context[cls] = parser if not cls._before_fetch_callable_checks(internal_data): raise NotReadyToRun() # Memory size managment if parser.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size: match internal_data.full_stream_action: case FullStreamAction.FLUSH: # Deleting parser and re-creating it parser.messages.clear() parser.msg.total_size -= len(parser.msg.stream) parser.msg.stream = b"" parser.msg.total_size -= len(parser.msg.body) parser.msg.body = b"" print("[WARNING] Flushing stream", flush=True) if parser.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size: parser.reset_data() case FullStreamAction.REJECT: raise StreamFullReject() case FullStreamAction.DROP: raise StreamFullDrop() case FullStreamAction.ACCEPT: raise NotReadyToRun() headers_were_set = parser.msg.headers_complete try: parser.parse_data(internal_data.current_pkt.data) except Exception as e: match internal_data.invalid_encoding_action: case ExceptionAction.REJECT: raise RejectConnection() case ExceptionAction.DROP: raise DropPacket() case ExceptionAction.NOACTION: raise e case ExceptionAction.ACCEPT: raise NotReadyToRun() messages_tosend:list[InternalHTTPMessage] = [] for i in range(len(parser.messages)): messages_tosend.append(parser.pop_message()) if len(messages_tosend) > 0: headers_were_set = False # New messages completed so the current message headers were not set in this case if not headers_were_set and parser.msg.headers_complete: messages_tosend.append(parser.msg) # Also the current message needs to be sent due to complete headers if parser._packet_to_stream(): messages_tosend.append(parser.msg) # Also the current message needs to beacase a stream is going on messages_to_call = len(messages_tosend) if messages_to_call == 0: raise NotReadyToRun() elif messages_to_call == 1: return cls(parser, messages_tosend[0]) return [cls(parser, ele) for ele in messages_tosend] class HttpRequest(InternalBasicHttpMetaClass): """ HTTP Request handler This data handler will be called twice, first with the headers complete, and second with the body complete """ @staticmethod def _associated_parser_class() -> Type[InternalHttpRequest]: return InternalHttpRequest @staticmethod def _before_fetch_callable_checks(internal_data: DataStreamCtx): return internal_data.current_pkt.is_input @property def method(self) -> bytes: """Method of the request""" return self._parser.msg.method def __repr__(self): return f"" class HttpResponse(InternalBasicHttpMetaClass): """ HTTP Response handler This data handler will be called twice, first with the headers complete, and second with the body complete """ @staticmethod def _associated_parser_class() -> Type[InternalHttpResponse]: return InternalHttpResponse @staticmethod def _before_fetch_callable_checks(internal_data: DataStreamCtx): return not internal_data.current_pkt.is_input @property def status_code(self) -> int: """Status code of the response""" return self._parser.msg.status def __repr__(self): return f"" class HttpRequestHeader(HttpRequest): """ HTTP Request Header handler This data handler will be called only once, the headers are complete, the body will be empty and not buffered """ def _contructor_hook(self): self._parser.save_body = False class HttpResponseHeader(HttpResponse): """ HTTP Response Header handler This data handler will be called only once, the headers are complete, the body will be empty and not buffered """ def _contructor_hook(self): self._parser.save_body = False