diff --git a/Dockerfile b/Dockerfile
index f829356..9b8f3cc 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,12 +3,6 @@ FROM python:slim-buster
 
 RUN apt-get update && apt-get -y install build-essential libpcre2-dev python-dev git iptables libnetfilter-queue-dev
 
-WORKDIR /tmp/
-RUN git clone https://github.com/gpfei/python-pcre2.git
-WORKDIR /tmp/python-pcre2/
-RUN python3 setup.py install
-WORKDIR /
-
 RUN mkdir /execute
 WORKDIR /execute
 
diff --git a/backend/app.py b/backend/app.py
index 23b9801..e8b99fe 100644
--- a/backend/app.py
+++ b/backend/app.py
@@ -37,6 +37,7 @@ def JWT_SECRET(): return conf.get("secret")
 
 @app.on_event("startup")
 async def startup_event():
     db.init()
+    firewall.init_updater()
     if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32))
     await firewall.reload()
diff --git a/backend/proxy.py b/backend/proxy.py
index afecce4..4320664 100755
--- a/backend/proxy.py
+++ b/backend/proxy.py
@@ -1,9 +1,11 @@
-from typing import List, Set
+import multiprocessing
+from threading import Thread
+from typing import List
 from netfilterqueue import NetfilterQueue
+from multiprocessing import Manager, Process
 from scapy.all import IP, TCP, UDP
 from subprocess import Popen, PIPE
-import os, pcre2, traceback
-from kthread import KThread
+import os, traceback, pcre, re
 
 QUEUE_BASE_NUM = 1000
 
@@ -146,24 +148,21 @@ class FiregexFilterManager:
 
     def get(self) -> List[FiregexFilter]:
         res = []
-        balanced_mode = pcre2.PCRE2(b"NFQUEUE balance ([0-9]+):([0-9]+)")
-        num_mode = pcre2.PCRE2(b"NFQUEUE num ([0-9]+)")
-        port_selected = pcre2.PCRE2(b"[sd]pt:([0-9]+)")
         for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
             for filter in IPTables.list_filters(filter_type):
                 queue_num = None
-                balanced = balanced_mode.search(filter["details"].encode())
-                numbered = num_mode.search(filter["details"].encode())
-                port = port_selected.search(filter["details"].encode())
-                if balanced: queue_num = (int(balanced.group(1).decode()), int(balanced.group(2).decode()))
-                if numbered: queue_num = (int(numbered.group(1).decode()), int(numbered.group(1).decode()))
+                balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", filter["details"])
+                numbered = re.findall(r"NFQUEUE num ([0-9]+)", filter["details"])
+                port = re.findall(r"[sd]pt:([0-9]+)", filter["details"])
+                if balanced: queue_num = (int(balanced[0]), int(balanced[1]))
+                if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
                 if queue_num and port:
                     res.append(FiregexFilter(
                         type=filter_type,
                         number=filter["id"],
                         queue=queue_num,
                         proto=filter["prot"],
-                        port=port.group(1).decode()
+                        port=int(port[0])
                     ))
         return res
@@ -204,24 +203,31 @@ class Filter:
     def compile(self):
         if isinstance(self.regex, str): self.regex = self.regex.encode()
         if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
-        return pcre2.PCRE2(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
+        return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
 
     def check(self, data):
         return True if self.compiled_regex.search(data) else False
+
+    def inc_block(self):
+        print("INC", self.blocked)
+        self.blocked+=1
 
 class Proxy:
-    def __init__(self, public_port = 0, callback_blocked_update=None, filters=None):
+    def __init__(self, port, filters=None):
         self.manager = FiregexFilterManager()
-        self.port = public_port
-        self.filters: Set[Filter] = set(filters) if filters else set([])
-        self.use_filters = True
-        self.callback_blocked_update = callback_blocked_update
-        self.threads = []
-        self.queue_list = []
+        self.port = port
+        self.filters = Manager().list(filters) if filters else Manager().list([])
+        self.process = None
+
+    def set_filters(self, filters):
+        elements_to_pop = len(self.filters)
+        for ele in filters:
+            self.filters.append(ele)
+        for _ in range(elements_to_pop):
+            self.filters.pop(0)
 
-    def start(self):
+    def _starter(self):
         self.manager.delete_by_port(self.port)
-
         def regex_filter(pkt, data, by_client):
             packet = bytes(data[TCP if TCP in data else UDP].payload)
             try:
@@ -229,32 +235,31 @@ class Proxy:
                     if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c):
                         match = filter.check(packet)
                         if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
-                            filter.blocked+=1
-                            self.callback_blocked_update(filter)
+                            filter.inc_block()
                             pkt.drop()
                             return
-            except IndexError:
-                pass
+            except IndexError: pass
             pkt.accept()
+        queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
+        threads = []
+        for ele in queue_list:
+            threads.append(Thread(target=ele.run))
+            threads[-1].daemon = True
+            threads[-1].start()
+        for ele in threads: ele.join()
+        for ele in queue_list: ele.unbind()
 
-        self.queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
-        for ele in self.queue_list:
-            self.threads.append(KThread(target=ele.run))
-            self.threads[-1].daemon = True
-            self.threads[-1].start()
+    def start(self):
+        self.process = Process(target=self._starter)
+        self.process.start()
 
     def stop(self):
         self.manager.delete_by_port(self.port)
-        for ele in self.threads:
-            ele.kill()
-            if ele.is_alive():
-                print("Not killed succesffully") #TODO
-        self.threads = []
-        for ele in self.queue_list:
-            ele.unbind()
-        self.queue_list = []
-
+        if self.process:
+            self.process.kill()
+            self.process = None
 
     def restart(self):
         self.stop()
-        self.start()
\ No newline at end of file
+        self.start()
+
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 00962b5..d61d54b 100755
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -4,5 +4,5 @@ uvicorn[standard]
 passlib[bcrypt]
 python-jose[cryptography]
 NetfilterQueue
-kthread
-scapy
\ No newline at end of file
+scapy
+python-pcre
\ No newline at end of file
diff --git a/backend/utils.py b/backend/utils.py
index 892fae5..bac2034 100755
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -1,4 +1,5 @@
-import threading
+import traceback
+from typing import Dict
 from proxy import Filter, Proxy
 import os, sqlite3, socket, asyncio
 from base64 import b64decode
@@ -10,7 +11,6 @@ class SQLite():
         self.conn = None
         self.cur = None
         self.db_name = db_name
-        self.lock = threading.Lock()
 
     def connect(self) -> None:
         try:
@@ -27,8 +27,7 @@ class SQLite():
         self.conn.row_factory = dict_factory
 
     def disconnect(self) -> None:
-        with self.lock:
-            self.conn.close()
+        self.conn.close()
 
     def create_schema(self, tables = {}) -> None:
         cur = self.conn.cursor()
@@ -39,9 +38,8 @@ class SQLite():
     def query(self, query, *values):
         cur = self.conn.cursor()
         try:
-            with self.lock:
-                cur.execute(query, values)
-                return cur.fetchall()
+            cur.execute(query, values)
+            return cur.fetchall()
         finally:
             cur.close()
             try: self.conn.commit()
@@ -100,10 +98,7 @@ class ServiceManager:
     def __init__(self, port, db):
         self.port = port
         self.db = db
-        self.proxy = Proxy(
-            callback_blocked_update=self._stats_updater,
-            public_port=port
-        )
+        self.proxy = Proxy(port)
         self.status = STATUS.STOP
         self.filters = {}
         self._update_filters_from_db()
@@ -139,7 +134,7 @@ class ServiceManager:
                 blocked_packets=filter_info["n_packets"],
                 code=f
             )
-        self.proxy.filters = list(self.filters.values())
+        self.proxy.set_filters(self.filters.values())
 
     def __update_status_db(self, status):
         self.db.query("UPDATE services SET status = ? WHERE port = ?;", status, self.port)
@@ -161,7 +156,12 @@ class ServiceManager:
 
     def _stats_updater(self,filter:Filter):
+        print(filter, filter.blocked, filter.code)
         self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
+
+    def update_stats(self):
+        for ele in self.proxy.filters:
+            self._stats_updater(ele)
 
     def _set_status(self,status):
         self.status = status
@@ -174,8 +174,11 @@
 class ProxyManager:
     def __init__(self, db:SQLite):
         self.db = db
-        self.proxy_table = {}
+        self.proxy_table: Dict[ServiceManager] = {}
         self.lock = asyncio.Lock()
+
+    def init_updater(self):
+        asyncio.create_task(self._stats_updater())
 
     async def close(self):
         for key in list(self.proxy_table.keys()):
@@ -197,6 +200,17 @@
             self.proxy_table[srv_port] = ServiceManager(srv_port,self.db)
             await self.proxy_table[srv_port].next(req_status)
 
+    async def _stats_updater(self):
+        while True:
+            print("ALIVE!")
+            try:
+                for key in list(self.proxy_table.keys()):
+                    self.proxy_table[key].update_stats()
+            except Exception:
+                traceback.print_exc()
+            await asyncio.sleep(1)
+
+
     def get(self,port):
         if port in self.proxy_table:
             return self.proxy_table[port]