From Threads to Multiprocess

This commit is contained in:
DomySh
2022-07-08 11:35:12 +02:00
parent 6e317defbf
commit 74e3e832b3
5 changed files with 76 additions and 62 deletions

View File

@@ -3,12 +3,6 @@ FROM python:slim-buster
RUN apt-get update && apt-get -y install build-essential libpcre2-dev python-dev git iptables libnetfilter-queue-dev RUN apt-get update && apt-get -y install build-essential libpcre2-dev python-dev git iptables libnetfilter-queue-dev
WORKDIR /tmp/
RUN git clone https://github.com/gpfei/python-pcre2.git
WORKDIR /tmp/python-pcre2/
RUN python3 setup.py install
WORKDIR /
RUN mkdir /execute RUN mkdir /execute
WORKDIR /execute WORKDIR /execute

View File

@@ -37,6 +37,7 @@ def JWT_SECRET(): return conf.get("secret")
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
db.init() db.init()
firewall.init_updater()
if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32)) if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32))
await firewall.reload() await firewall.reload()

View File

@@ -1,9 +1,11 @@
from typing import List, Set import multiprocessing
from threading import Thread
from typing import List
from netfilterqueue import NetfilterQueue from netfilterqueue import NetfilterQueue
from multiprocessing import Manager, Process
from scapy.all import IP, TCP, UDP from scapy.all import IP, TCP, UDP
from subprocess import Popen, PIPE from subprocess import Popen, PIPE
import os, pcre2, traceback import os, traceback, pcre, re
from kthread import KThread
QUEUE_BASE_NUM = 1000 QUEUE_BASE_NUM = 1000
@@ -146,24 +148,21 @@ class FiregexFilterManager:
def get(self) -> List[FiregexFilter]: def get(self) -> List[FiregexFilter]:
res = [] res = []
balanced_mode = pcre2.PCRE2(b"NFQUEUE balance ([0-9]+):([0-9]+)")
num_mode = pcre2.PCRE2(b"NFQUEUE num ([0-9]+)")
port_selected = pcre2.PCRE2(b"[sd]pt:([0-9]+)")
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
for filter in IPTables.list_filters(filter_type): for filter in IPTables.list_filters(filter_type):
queue_num = None queue_num = None
balanced = balanced_mode.search(filter["details"].encode()) balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", filter["details"])
numbered = num_mode.search(filter["details"].encode()) numbered = re.findall(r"NFQUEUE num ([0-9]+)", filter["details"])
port = port_selected.search(filter["details"].encode()) port = re.findall(r"[sd]pt:([0-9]+)", filter["details"])
if balanced: queue_num = (int(balanced.group(1).decode()), int(balanced.group(2).decode())) if balanced: queue_num = (int(balanced[0]), int(balanced[1]))
if numbered: queue_num = (int(numbered.group(1).decode()), int(numbered.group(1).decode())) if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
if queue_num and port: if queue_num and port:
res.append(FiregexFilter( res.append(FiregexFilter(
type=filter_type, type=filter_type,
number=filter["id"], number=filter["id"],
queue=queue_num, queue=queue_num,
proto=filter["prot"], proto=filter["prot"],
port=port.group(1).decode() port=int(port[0])
)) ))
return res return res
@@ -204,24 +203,31 @@ class Filter:
def compile(self): def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode() if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether") if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
return pcre2.PCRE2(self.regex if self.is_case_sensitive else b"(?i)"+self.regex) return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data): def check(self, data):
return True if self.compiled_regex.search(data) else False return True if self.compiled_regex.search(data) else False
def inc_block(self):
print("INC", self.blocked)
self.blocked+=1
class Proxy: class Proxy:
def __init__(self, public_port = 0, callback_blocked_update=None, filters=None): def __init__(self, port, filters=None):
self.manager = FiregexFilterManager() self.manager = FiregexFilterManager()
self.port = public_port self.port = port
self.filters: Set[Filter] = set(filters) if filters else set([]) self.filters = Manager().list(filters) if filters else Manager().list([])
self.use_filters = True self.process = None
self.callback_blocked_update = callback_blocked_update
self.threads = []
self.queue_list = []
def start(self): def set_filters(self, filters):
elements_to_pop = len(self.filters)
for ele in filters:
self.filters.append(ele)
for _ in range(elements_to_pop):
self.filters.pop(0)
def _starter(self):
self.manager.delete_by_port(self.port) self.manager.delete_by_port(self.port)
def regex_filter(pkt, data, by_client): def regex_filter(pkt, data, by_client):
packet = bytes(data[TCP if TCP in data else UDP].payload) packet = bytes(data[TCP if TCP in data else UDP].payload)
try: try:
@@ -229,32 +235,31 @@ class Proxy:
if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c): if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c):
match = filter.check(packet) match = filter.check(packet)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match): if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1 filter.inc_block()
self.callback_blocked_update(filter)
pkt.drop() pkt.drop()
return return
except IndexError: except IndexError: pass
pass
pkt.accept() pkt.accept()
queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
threads = []
for ele in queue_list:
threads.append(Thread(target=ele.run))
threads[-1].daemon = True
threads[-1].start()
for ele in threads: ele.join()
for ele in queue_list: ele.unbind()
self.queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter) def start(self):
for ele in self.queue_list: self.process = Process(target=self._starter)
self.threads.append(KThread(target=ele.run)) self.process.start()
self.threads[-1].daemon = True
self.threads[-1].start()
def stop(self): def stop(self):
self.manager.delete_by_port(self.port) self.manager.delete_by_port(self.port)
for ele in self.threads: if self.process:
ele.kill() self.process.kill()
if ele.is_alive(): self.process = None
print("Not killed succesffully") #TODO
self.threads = []
for ele in self.queue_list:
ele.unbind()
self.queue_list = []
def restart(self): def restart(self):
self.stop() self.stop()
self.start() self.start()

View File

@@ -4,5 +4,5 @@ uvicorn[standard]
passlib[bcrypt] passlib[bcrypt]
python-jose[cryptography] python-jose[cryptography]
NetfilterQueue NetfilterQueue
kthread
scapy scapy
python-pcre

View File

@@ -1,4 +1,5 @@
import threading import traceback
from typing import Dict
from proxy import Filter, Proxy from proxy import Filter, Proxy
import os, sqlite3, socket, asyncio import os, sqlite3, socket, asyncio
from base64 import b64decode from base64 import b64decode
@@ -10,7 +11,6 @@ class SQLite():
self.conn = None self.conn = None
self.cur = None self.cur = None
self.db_name = db_name self.db_name = db_name
self.lock = threading.Lock()
def connect(self) -> None: def connect(self) -> None:
try: try:
@@ -27,8 +27,7 @@ class SQLite():
self.conn.row_factory = dict_factory self.conn.row_factory = dict_factory
def disconnect(self) -> None: def disconnect(self) -> None:
with self.lock: self.conn.close()
self.conn.close()
def create_schema(self, tables = {}) -> None: def create_schema(self, tables = {}) -> None:
cur = self.conn.cursor() cur = self.conn.cursor()
@@ -39,9 +38,8 @@ class SQLite():
def query(self, query, *values): def query(self, query, *values):
cur = self.conn.cursor() cur = self.conn.cursor()
try: try:
with self.lock: cur.execute(query, values)
cur.execute(query, values) return cur.fetchall()
return cur.fetchall()
finally: finally:
cur.close() cur.close()
try: self.conn.commit() try: self.conn.commit()
@@ -100,10 +98,7 @@ class ServiceManager:
def __init__(self, port, db): def __init__(self, port, db):
self.port = port self.port = port
self.db = db self.db = db
self.proxy = Proxy( self.proxy = Proxy(port)
callback_blocked_update=self._stats_updater,
public_port=port
)
self.status = STATUS.STOP self.status = STATUS.STOP
self.filters = {} self.filters = {}
self._update_filters_from_db() self._update_filters_from_db()
@@ -139,7 +134,7 @@ class ServiceManager:
blocked_packets=filter_info["n_packets"], blocked_packets=filter_info["n_packets"],
code=f code=f
) )
self.proxy.filters = list(self.filters.values()) self.proxy.set_filters(self.filters.values())
def __update_status_db(self, status): def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE port = ?;", status, self.port) self.db.query("UPDATE services SET status = ? WHERE port = ?;", status, self.port)
@@ -161,8 +156,13 @@ class ServiceManager:
def _stats_updater(self,filter:Filter): def _stats_updater(self,filter:Filter):
print(filter, filter.blocked, filter.code)
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code) self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
def update_stats(self):
for ele in self.proxy.filters:
self._stats_updater(ele)
def _set_status(self,status): def _set_status(self,status):
self.status = status self.status = status
self.__update_status_db(status) self.__update_status_db(status)
@@ -174,9 +174,12 @@ class ServiceManager:
class ProxyManager: class ProxyManager:
def __init__(self, db:SQLite): def __init__(self, db:SQLite):
self.db = db self.db = db
self.proxy_table = {} self.proxy_table: Dict[ServiceManager] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
def init_updater(self):
asyncio.create_task(self._stats_updater())
async def close(self): async def close(self):
for key in list(self.proxy_table.keys()): for key in list(self.proxy_table.keys()):
await self.remove(key) await self.remove(key)
@@ -197,6 +200,17 @@ class ProxyManager:
self.proxy_table[srv_port] = ServiceManager(srv_port,self.db) self.proxy_table[srv_port] = ServiceManager(srv_port,self.db)
await self.proxy_table[srv_port].next(req_status) await self.proxy_table[srv_port].next(req_status)
async def _stats_updater(self):
while True:
print("ALIVE!")
try:
for key in list(self.proxy_table.keys()):
self.proxy_table[key].update_stats()
except Exception:
traceback.print_exc()
await asyncio.sleep(1)
def get(self,port): def get(self,port):
if port in self.proxy_table: if port in self.proxy_table:
return self.proxy_table[port] return self.proxy_table[port]