From Threads to Multiprocess

This commit is contained in:
DomySh
2022-07-08 11:35:12 +02:00
parent 6e317defbf
commit 74e3e832b3
5 changed files with 76 additions and 62 deletions

View File

@@ -3,12 +3,6 @@ FROM python:slim-buster
RUN apt-get update && apt-get -y install build-essential libpcre2-dev python-dev git iptables libnetfilter-queue-dev
WORKDIR /tmp/
RUN git clone https://github.com/gpfei/python-pcre2.git
WORKDIR /tmp/python-pcre2/
RUN python3 setup.py install
WORKDIR /
RUN mkdir /execute
WORKDIR /execute

View File

@@ -37,6 +37,7 @@ def JWT_SECRET(): return conf.get("secret")
@app.on_event("startup")
async def startup_event():
db.init()
firewall.init_updater()
if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32))
await firewall.reload()

View File

@@ -1,9 +1,11 @@
from typing import List, Set
import multiprocessing
from threading import Thread
from typing import List
from netfilterqueue import NetfilterQueue
from multiprocessing import Manager, Process
from scapy.all import IP, TCP, UDP
from subprocess import Popen, PIPE
import os, pcre2, traceback
from kthread import KThread
import os, traceback, pcre, re
QUEUE_BASE_NUM = 1000
@@ -146,24 +148,21 @@ class FiregexFilterManager:
def get(self) -> List[FiregexFilter]:
res = []
balanced_mode = pcre2.PCRE2(b"NFQUEUE balance ([0-9]+):([0-9]+)")
num_mode = pcre2.PCRE2(b"NFQUEUE num ([0-9]+)")
port_selected = pcre2.PCRE2(b"[sd]pt:([0-9]+)")
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
for filter in IPTables.list_filters(filter_type):
queue_num = None
balanced = balanced_mode.search(filter["details"].encode())
numbered = num_mode.search(filter["details"].encode())
port = port_selected.search(filter["details"].encode())
if balanced: queue_num = (int(balanced.group(1).decode()), int(balanced.group(2).decode()))
if numbered: queue_num = (int(numbered.group(1).decode()), int(numbered.group(1).decode()))
balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", filter["details"])
numbered = re.findall(r"NFQUEUE num ([0-9]+)", filter["details"])
port = re.findall(r"[sd]pt:([0-9]+)", filter["details"])
if balanced: queue_num = (int(balanced[0]), int(balanced[1]))
if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
if queue_num and port:
res.append(FiregexFilter(
type=filter_type,
number=filter["id"],
queue=queue_num,
proto=filter["prot"],
port=port.group(1).decode()
port=int(port[0])
))
return res
@@ -204,24 +203,31 @@ class Filter:
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
return pcre2.PCRE2(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data):
return True if self.compiled_regex.search(data) else False
def inc_block(self):
print("INC", self.blocked)
self.blocked+=1
class Proxy:
def __init__(self, public_port = 0, callback_blocked_update=None, filters=None):
def __init__(self, port, filters=None):
self.manager = FiregexFilterManager()
self.port = public_port
self.filters: Set[Filter] = set(filters) if filters else set([])
self.use_filters = True
self.callback_blocked_update = callback_blocked_update
self.threads = []
self.queue_list = []
self.port = port
self.filters = Manager().list(filters) if filters else Manager().list([])
self.process = None
def set_filters(self, filters):
elements_to_pop = len(self.filters)
for ele in filters:
self.filters.append(ele)
for _ in range(elements_to_pop):
self.filters.pop(0)
def start(self):
def _starter(self):
self.manager.delete_by_port(self.port)
def regex_filter(pkt, data, by_client):
packet = bytes(data[TCP if TCP in data else UDP].payload)
try:
@@ -229,32 +235,31 @@ class Proxy:
if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c):
match = filter.check(packet)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
self.callback_blocked_update(filter)
filter.inc_block()
pkt.drop()
return
except IndexError:
pass
except IndexError: pass
pkt.accept()
queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
threads = []
for ele in queue_list:
threads.append(Thread(target=ele.run))
threads[-1].daemon = True
threads[-1].start()
for ele in threads: ele.join()
for ele in queue_list: ele.unbind()
self.queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
for ele in self.queue_list:
self.threads.append(KThread(target=ele.run))
self.threads[-1].daemon = True
self.threads[-1].start()
def start(self):
self.process = Process(target=self._starter)
self.process.start()
def stop(self):
self.manager.delete_by_port(self.port)
for ele in self.threads:
ele.kill()
if ele.is_alive():
print("Not killed succesffully") #TODO
self.threads = []
for ele in self.queue_list:
ele.unbind()
self.queue_list = []
if self.process:
self.process.kill()
self.process = None
def restart(self):
self.stop()
self.start()
self.start()

View File

@@ -4,5 +4,5 @@ uvicorn[standard]
passlib[bcrypt]
python-jose[cryptography]
NetfilterQueue
kthread
scapy
scapy
python-pcre

View File

@@ -1,4 +1,5 @@
import threading
import traceback
from typing import Dict
from proxy import Filter, Proxy
import os, sqlite3, socket, asyncio
from base64 import b64decode
@@ -10,7 +11,6 @@ class SQLite():
self.conn = None
self.cur = None
self.db_name = db_name
self.lock = threading.Lock()
def connect(self) -> None:
try:
@@ -27,8 +27,7 @@ class SQLite():
self.conn.row_factory = dict_factory
def disconnect(self) -> None:
with self.lock:
self.conn.close()
self.conn.close()
def create_schema(self, tables = {}) -> None:
cur = self.conn.cursor()
@@ -39,9 +38,8 @@ class SQLite():
def query(self, query, *values):
cur = self.conn.cursor()
try:
with self.lock:
cur.execute(query, values)
return cur.fetchall()
cur.execute(query, values)
return cur.fetchall()
finally:
cur.close()
try: self.conn.commit()
@@ -100,10 +98,7 @@ class ServiceManager:
def __init__(self, port, db):
self.port = port
self.db = db
self.proxy = Proxy(
callback_blocked_update=self._stats_updater,
public_port=port
)
self.proxy = Proxy(port)
self.status = STATUS.STOP
self.filters = {}
self._update_filters_from_db()
@@ -139,7 +134,7 @@ class ServiceManager:
blocked_packets=filter_info["n_packets"],
code=f
)
self.proxy.filters = list(self.filters.values())
self.proxy.set_filters(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE port = ?;", status, self.port)
@@ -161,7 +156,12 @@ class ServiceManager:
def _stats_updater(self,filter:Filter):
print(filter, filter.blocked, filter.code)
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
def update_stats(self):
for ele in self.proxy.filters:
self._stats_updater(ele)
def _set_status(self,status):
self.status = status
@@ -174,8 +174,11 @@ class ServiceManager:
class ProxyManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table = {}
self.proxy_table: Dict[ServiceManager] = {}
self.lock = asyncio.Lock()
def init_updater(self):
asyncio.create_task(self._stats_updater())
async def close(self):
for key in list(self.proxy_table.keys()):
@@ -197,6 +200,17 @@ class ProxyManager:
self.proxy_table[srv_port] = ServiceManager(srv_port,self.db)
await self.proxy_table[srv_port].next(req_status)
async def _stats_updater(self):
while True:
print("ALIVE!")
try:
for key in list(self.proxy_table.keys()):
self.proxy_table[key].update_stats()
except Exception:
traceback.print_exc()
await asyncio.sleep(1)
def get(self,port):
if port in self.proxy_table:
return self.proxy_table[port]