From Threads to Multiprocess
This commit is contained in:
@@ -3,12 +3,6 @@ FROM python:slim-buster
|
||||
|
||||
RUN apt-get update && apt-get -y install build-essential libpcre2-dev python-dev git iptables libnetfilter-queue-dev
|
||||
|
||||
WORKDIR /tmp/
|
||||
RUN git clone https://github.com/gpfei/python-pcre2.git
|
||||
WORKDIR /tmp/python-pcre2/
|
||||
RUN python3 setup.py install
|
||||
WORKDIR /
|
||||
|
||||
RUN mkdir /execute
|
||||
WORKDIR /execute
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ def JWT_SECRET(): return conf.get("secret")
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
db.init()
|
||||
firewall.init_updater()
|
||||
if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32))
|
||||
await firewall.reload()
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import List, Set
|
||||
import multiprocessing
|
||||
from threading import Thread
|
||||
from typing import List
|
||||
from netfilterqueue import NetfilterQueue
|
||||
from multiprocessing import Manager, Process
|
||||
from scapy.all import IP, TCP, UDP
|
||||
from subprocess import Popen, PIPE
|
||||
import os, pcre2, traceback
|
||||
from kthread import KThread
|
||||
import os, traceback, pcre, re
|
||||
|
||||
QUEUE_BASE_NUM = 1000
|
||||
|
||||
@@ -146,24 +148,21 @@ class FiregexFilterManager:
|
||||
|
||||
def get(self) -> List[FiregexFilter]:
|
||||
res = []
|
||||
balanced_mode = pcre2.PCRE2(b"NFQUEUE balance ([0-9]+):([0-9]+)")
|
||||
num_mode = pcre2.PCRE2(b"NFQUEUE num ([0-9]+)")
|
||||
port_selected = pcre2.PCRE2(b"[sd]pt:([0-9]+)")
|
||||
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
|
||||
for filter in IPTables.list_filters(filter_type):
|
||||
queue_num = None
|
||||
balanced = balanced_mode.search(filter["details"].encode())
|
||||
numbered = num_mode.search(filter["details"].encode())
|
||||
port = port_selected.search(filter["details"].encode())
|
||||
if balanced: queue_num = (int(balanced.group(1).decode()), int(balanced.group(2).decode()))
|
||||
if numbered: queue_num = (int(numbered.group(1).decode()), int(numbered.group(1).decode()))
|
||||
balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", filter["details"])
|
||||
numbered = re.findall(r"NFQUEUE num ([0-9]+)", filter["details"])
|
||||
port = re.findall(r"[sd]pt:([0-9]+)", filter["details"])
|
||||
if balanced: queue_num = (int(balanced[0]), int(balanced[1]))
|
||||
if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
|
||||
if queue_num and port:
|
||||
res.append(FiregexFilter(
|
||||
type=filter_type,
|
||||
number=filter["id"],
|
||||
queue=queue_num,
|
||||
proto=filter["prot"],
|
||||
port=port.group(1).decode()
|
||||
port=int(port[0])
|
||||
))
|
||||
return res
|
||||
|
||||
@@ -204,24 +203,31 @@ class Filter:
|
||||
def compile(self):
|
||||
if isinstance(self.regex, str): self.regex = self.regex.encode()
|
||||
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
|
||||
return pcre2.PCRE2(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
|
||||
return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
|
||||
|
||||
def check(self, data):
|
||||
return True if self.compiled_regex.search(data) else False
|
||||
|
||||
def inc_block(self):
|
||||
print("INC", self.blocked)
|
||||
self.blocked+=1
|
||||
|
||||
class Proxy:
|
||||
def __init__(self, public_port = 0, callback_blocked_update=None, filters=None):
|
||||
def __init__(self, port, filters=None):
|
||||
self.manager = FiregexFilterManager()
|
||||
self.port = public_port
|
||||
self.filters: Set[Filter] = set(filters) if filters else set([])
|
||||
self.use_filters = True
|
||||
self.callback_blocked_update = callback_blocked_update
|
||||
self.threads = []
|
||||
self.queue_list = []
|
||||
self.port = port
|
||||
self.filters = Manager().list(filters) if filters else Manager().list([])
|
||||
self.process = None
|
||||
|
||||
def start(self):
|
||||
def set_filters(self, filters):
|
||||
elements_to_pop = len(self.filters)
|
||||
for ele in filters:
|
||||
self.filters.append(ele)
|
||||
for _ in range(elements_to_pop):
|
||||
self.filters.pop(0)
|
||||
|
||||
def _starter(self):
|
||||
self.manager.delete_by_port(self.port)
|
||||
|
||||
def regex_filter(pkt, data, by_client):
|
||||
packet = bytes(data[TCP if TCP in data else UDP].payload)
|
||||
try:
|
||||
@@ -229,32 +235,31 @@ class Proxy:
|
||||
if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c):
|
||||
match = filter.check(packet)
|
||||
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
|
||||
filter.blocked+=1
|
||||
self.callback_blocked_update(filter)
|
||||
filter.inc_block()
|
||||
pkt.drop()
|
||||
return
|
||||
except IndexError:
|
||||
pass
|
||||
except IndexError: pass
|
||||
pkt.accept()
|
||||
queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
|
||||
threads = []
|
||||
for ele in queue_list:
|
||||
threads.append(Thread(target=ele.run))
|
||||
threads[-1].daemon = True
|
||||
threads[-1].start()
|
||||
for ele in threads: ele.join()
|
||||
for ele in queue_list: ele.unbind()
|
||||
|
||||
self.queue_list = self.manager.add(ProtoTypes.TCP, self.port, regex_filter)
|
||||
for ele in self.queue_list:
|
||||
self.threads.append(KThread(target=ele.run))
|
||||
self.threads[-1].daemon = True
|
||||
self.threads[-1].start()
|
||||
def start(self):
|
||||
self.process = Process(target=self._starter)
|
||||
self.process.start()
|
||||
|
||||
def stop(self):
|
||||
self.manager.delete_by_port(self.port)
|
||||
for ele in self.threads:
|
||||
ele.kill()
|
||||
if ele.is_alive():
|
||||
print("Not killed succesffully") #TODO
|
||||
self.threads = []
|
||||
for ele in self.queue_list:
|
||||
ele.unbind()
|
||||
self.queue_list = []
|
||||
|
||||
if self.process:
|
||||
self.process.kill()
|
||||
self.process = None
|
||||
|
||||
def restart(self):
|
||||
self.stop()
|
||||
self.start()
|
||||
|
||||
|
||||
@@ -4,5 +4,5 @@ uvicorn[standard]
|
||||
passlib[bcrypt]
|
||||
python-jose[cryptography]
|
||||
NetfilterQueue
|
||||
kthread
|
||||
scapy
|
||||
python-pcre
|
||||
@@ -1,4 +1,5 @@
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Dict
|
||||
from proxy import Filter, Proxy
|
||||
import os, sqlite3, socket, asyncio
|
||||
from base64 import b64decode
|
||||
@@ -10,7 +11,6 @@ class SQLite():
|
||||
self.conn = None
|
||||
self.cur = None
|
||||
self.db_name = db_name
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def connect(self) -> None:
|
||||
try:
|
||||
@@ -27,8 +27,7 @@ class SQLite():
|
||||
self.conn.row_factory = dict_factory
|
||||
|
||||
def disconnect(self) -> None:
|
||||
with self.lock:
|
||||
self.conn.close()
|
||||
self.conn.close()
|
||||
|
||||
def create_schema(self, tables = {}) -> None:
|
||||
cur = self.conn.cursor()
|
||||
@@ -39,9 +38,8 @@ class SQLite():
|
||||
def query(self, query, *values):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
with self.lock:
|
||||
cur.execute(query, values)
|
||||
return cur.fetchall()
|
||||
cur.execute(query, values)
|
||||
return cur.fetchall()
|
||||
finally:
|
||||
cur.close()
|
||||
try: self.conn.commit()
|
||||
@@ -100,10 +98,7 @@ class ServiceManager:
|
||||
def __init__(self, port, db):
|
||||
self.port = port
|
||||
self.db = db
|
||||
self.proxy = Proxy(
|
||||
callback_blocked_update=self._stats_updater,
|
||||
public_port=port
|
||||
)
|
||||
self.proxy = Proxy(port)
|
||||
self.status = STATUS.STOP
|
||||
self.filters = {}
|
||||
self._update_filters_from_db()
|
||||
@@ -139,7 +134,7 @@ class ServiceManager:
|
||||
blocked_packets=filter_info["n_packets"],
|
||||
code=f
|
||||
)
|
||||
self.proxy.filters = list(self.filters.values())
|
||||
self.proxy.set_filters(self.filters.values())
|
||||
|
||||
def __update_status_db(self, status):
|
||||
self.db.query("UPDATE services SET status = ? WHERE port = ?;", status, self.port)
|
||||
@@ -161,8 +156,13 @@ class ServiceManager:
|
||||
|
||||
|
||||
def _stats_updater(self,filter:Filter):
|
||||
print(filter, filter.blocked, filter.code)
|
||||
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
|
||||
|
||||
def update_stats(self):
|
||||
for ele in self.proxy.filters:
|
||||
self._stats_updater(ele)
|
||||
|
||||
def _set_status(self,status):
|
||||
self.status = status
|
||||
self.__update_status_db(status)
|
||||
@@ -174,9 +174,12 @@ class ServiceManager:
|
||||
class ProxyManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.proxy_table = {}
|
||||
self.proxy_table: Dict[ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def init_updater(self):
|
||||
asyncio.create_task(self._stats_updater())
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.proxy_table.keys()):
|
||||
await self.remove(key)
|
||||
@@ -197,6 +200,17 @@ class ProxyManager:
|
||||
self.proxy_table[srv_port] = ServiceManager(srv_port,self.db)
|
||||
await self.proxy_table[srv_port].next(req_status)
|
||||
|
||||
async def _stats_updater(self):
|
||||
while True:
|
||||
print("ALIVE!")
|
||||
try:
|
||||
for key in list(self.proxy_table.keys()):
|
||||
self.proxy_table[key].update_stats()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
def get(self,port):
|
||||
if port in self.proxy_table:
|
||||
return self.proxy_table[port]
|
||||
|
||||
Reference in New Issue
Block a user