From 1e94c26fd665c915464c50282d5d53fc63ab8211 Mon Sep 17 00:00:00 2001 From: DomySh Date: Tue, 12 Jul 2022 20:18:54 +0200 Subject: [PATCH] Refactor code pt 1 (not tested) --- backend/.vscode/settings.json | 5 + backend/app.py | 21 +-- backend/modules/__init__.py | 2 + backend/modules/firegex.py | 170 +++++++++++++++++++++ backend/modules/firewall.py | 196 +++++++++++++++++++++++++ backend/modules/iptables.py | 82 +++++++++++ backend/modules/sqlite.py | 130 +++++++++++++++++ backend/proxy.py | 268 ---------------------------------- backend/utils.py | 257 +------------------------------- 9 files changed, 597 insertions(+), 534 deletions(-) create mode 100644 backend/.vscode/settings.json create mode 100644 backend/modules/__init__.py create mode 100644 backend/modules/firegex.py create mode 100644 backend/modules/firewall.py create mode 100644 backend/modules/iptables.py create mode 100644 backend/modules/sqlite.py delete mode 100755 backend/proxy.py diff --git a/backend/.vscode/settings.json b/backend/.vscode/settings.json new file mode 100644 index 0000000..a1aef15 --- /dev/null +++ b/backend/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python.linting.pylintEnabled": false, + "python.linting.mypyEnabled": true, + "python.linting.enabled": true +} \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index 07d5d22..eaea55e 100644 --- a/backend/app.py +++ b/backend/app.py @@ -5,12 +5,14 @@ from typing import List, Union from fastapi import FastAPI, HTTPException, WebSocket, Depends from pydantic import BaseModel, BaseSettings from fastapi.responses import FileResponse, StreamingResponse -from utils import * from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt from passlib.context import CryptContext from fastapi_socketio import SocketManager from ipaddress import ip_interface +from modules import SQLite, FirewallManager +from modules.firewall import STATUS +from utils import refactor_name, gen_service_id ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER" DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG" @@ -18,8 +20,7 @@ DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG" # DB init if not os.path.exists("db"): os.mkdir("db") db = SQLite('db/firegex.db') -conf = KeyValueStorage(db) -firewall = ProxyManager(db) +firewall = FirewallManager(db) class Settings(BaseSettings): JWT_ALGORITHM: str = "HS256" @@ -35,8 +36,8 @@ crypto = CryptContext(schemes=["bcrypt"], deprecated="auto") app = FastAPI(debug=DEBUG, redoc_url=None) sio = SocketManager(app, "/sock", socketio_path="") -def APP_STATUS(): return "init" if conf.get("password") is None else "run" -def JWT_SECRET(): return conf.get("secret") +def APP_STATUS(): return "init" if db.get("password") is None else "run" +def JWT_SECRET(): return db.get("secret") async def refresh_frontend(): await sio.emit("update","Refresh") @@ -49,7 +50,7 @@ async def startup_event(): db.init() await firewall.init(refresh_frontend) await refresh_frontend() - if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32)) + if not JWT_SECRET(): db.put("secret", secrets.token_hex(32)) @app.on_event("shutdown") async def shutdown_event(): @@ -108,7 +109,7 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()): if form.password == "": return {"status":"Cannot insert an empty password!"} await asyncio.sleep(0.3) # No bruteforce :) - if crypto.verify(form.password, conf.get("password")): + if crypto.verify(form.password, db.get("password")): return {"access_token": 
create_access_token({"logged_in": True}), "token_type": "bearer"} raise HTTPException(406,"Wrong password!") @@ -124,10 +125,10 @@ async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_logg if form.password == "": return {"status":"Cannot insert an empty password!"} if form.expire: - conf.put("secret", secrets.token_hex(32)) + db.put("secret", secrets.token_hex(32)) hash_psw = crypto.hash(form.password) - conf.put("password",hash_psw) + db.put("password",hash_psw) await refresh_frontend() return {"status":"ok", "access_token": create_access_token({"logged_in": True})} @@ -139,7 +140,7 @@ async def set_password(form: PasswordForm): if form.password == "": return {"status":"Cannot insert an empty password!"} hash_psw = crypto.hash(form.password) - conf.put("password",hash_psw) + db.put("password",hash_psw) await refresh_frontend() return {"status":"ok", "access_token": create_access_token({"logged_in": True})} diff --git a/backend/modules/__init__.py b/backend/modules/__init__.py new file mode 100644 index 0000000..14d33e5 --- /dev/null +++ b/backend/modules/__init__.py @@ -0,0 +1,2 @@ +from .firewall import FirewallManager +from .sqlite import SQLite \ No newline at end of file diff --git a/backend/modules/firegex.py b/backend/modules/firegex.py new file mode 100644 index 0000000..2880008 --- /dev/null +++ b/backend/modules/firegex.py @@ -0,0 +1,170 @@ +from typing import List +from pypacker import interceptor +from pypacker.layer3 import ip, ip6 +from pypacker.layer4 import tcp, udp +from ipaddress import ip_interface +from modules.iptables import IPTables +import os, traceback + +from modules.sqlite import Service + +class FilterTypes: + INPUT = "FIREGEX-INPUT" + OUTPUT = "FIREGEX-OUTPUT" + +QUEUE_BASE_NUM = 1000 + +class FiregexFilter(): + def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None): + self.target = target + self.id = int(id) if id else None + self.queue = queue + self.proto = proto + self.port = int(port) + self.ip_int = str(ip_int) + self.func = func + + def __eq__(self, o: object) -> bool: + if isinstance(o, FiregexFilter): + return self.port == o.port and self.proto == o.proto and ip_interface(self.ip_int) == ip_interface(o.ip_int) + return False + + def ipv6(self): + return ip_interface(self.ip_int).version == 6 + + def ipv4(self): + return ip_interface(self.ip_int).version == 4 + + def input_func(self): + def none(pkt): return True + def wrap(pkt): return self.func(pkt, True) + return wrap if self.func else none + + def output_func(self): + def none(pkt): return True + def wrap(pkt): return self.func(pkt, False) + return wrap if self.func else none + +class FiregexTables(IPTables): + + def __init__(self, ipv6=False): + super().__init__(ipv6, "mangle") + self.create_chain(FilterTypes.INPUT) + self.add_chain_to_input(FilterTypes.INPUT) + self.create_chain(FilterTypes.OUTPUT) + self.add_chain_to_output(FilterTypes.OUTPUT) + + def target_in_chain(self, chain, target): + for filter in self.list()[chain]: + if filter.target == target: + return True + return False + + def add_chain_to_input(self, chain): + if not self.target_in_chain("PREROUTING", str(chain)): + self.insert_rule("PREROUTING", str(chain)) + + def add_chain_to_output(self, chain): + if not self.target_in_chain("POSTROUTING", str(chain)): + self.insert_rule("POSTROUTING", str(chain)) + + def add_output(self, queue_range, proto = None, port = None, ip_int = None): + init, end = queue_range + if init > end: init, end = end, init + 
self.append_rule(FilterTypes.OUTPUT,"NFQUEUE" + * (["-p", str(proto)] if proto else []), + * (["-s", str(ip_int)] if ip_int else []), + * (["--sport", str(port)] if port else []), + * (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]), + "--queue-bypass" + ) + + def add_input(self, queue_range, proto = None, port = None, ip_int = None): + init, end = queue_range + if init > end: init, end = end, init + self.append_rule(FilterTypes.INPUT, "NFQUEUE", + * (["-p", str(proto)] if proto else []), + * (["-d", str(ip_int)] if ip_int else []), + * (["--dport", str(port)] if port else []), + * (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]), + "--queue-bypass" + ) + + def get(self) -> List[FiregexFilter]: + res = [] + for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: + for filter in self.list()[filter_type]: + port = filter.sport() if filter_type == FilterTypes.OUTPUT else filter.dport() + queue = filter.nfqueue() + if queue and port: + res.append(FiregexFilter( + target=filter_type, + id=filter.id, + queue=queue, + proto=filter.prot, + port=port, + ip_int=filter.source if filter_type == FilterTypes.OUTPUT else filter.destination + )) + return res + + def add(self, filter:FiregexFilter): + if filter in self.get(): return None + return FiregexInterceptor( iptables=self, filter=filter, n_threads=int(os.getenv("N_THREADS_NFQUEUE","1"))) + + def delete_all(self): + for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: + self.flush_chain(filter_type) + + def delete_by_srv(self, srv:Service): + for filter in self.get(): + if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int): + self.delete_rule(filter.target, filter.id) + +class FiregexInterceptor: + def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1): + self.filter = filter + self.ipv6 = self.filter.ipv6() + self.itor_input, codes = self._start_queue(filter.input_func(), n_threads) + iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) + self.itor_output, codes = self._start_queue(filter.output_func(), n_threads) + iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) + + def _start_queue(self,func,n_threads): + def func_wrap(ll_data, ll_proto_id, data, ctx, *args): + pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data) + try: + data = None + if not pkt_parsed[tcp.TCP] is None: + data = pkt_parsed[tcp.TCP].body_bytes + if not pkt_parsed[tcp.TCP] is None: + data = pkt_parsed[udp.UDP].body_bytes + if data: + if func(data): + return data, interceptor.NF_ACCEPT + elif pkt_parsed[tcp.TCP]: + pkt_parsed[tcp.TCP].flags &= 0x00 + pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK + pkt_parsed[tcp.TCP].body_bytes = b"" + return pkt_parsed.bin(), interceptor.NF_ACCEPT + else: return b"", interceptor.NF_DROP + else: return data, interceptor.NF_ACCEPT + except Exception: + traceback.print_exc() + return data, interceptor.NF_ACCEPT + + ictor = interceptor.Interceptor() + starts = QUEUE_BASE_NUM + while True: + if starts >= 65536: + raise Exception("Netfilter queue is full!") + queue_ids = list(range(starts,starts+n_threads)) + try: + ictor.start(func_wrap, queue_ids=queue_ids) + break + except interceptor.UnableToBindException as e: + starts = e.queue_id + 1 + return ictor, (starts, starts+n_threads-1) + + def stop(self): + self.itor_input.stop() + 
diff --git a/backend/modules/firewall.py b/backend/modules/firewall.py
new file mode 100644
index 0000000..8a0e376
--- /dev/null
+++ b/backend/modules/firewall.py
@@ -0,0 +1,196 @@
+import traceback, asyncio, pcre
+from typing import Dict
+from modules.firegex import FiregexFilter, FiregexTables
+from modules.sqlite import Regex, SQLite, Service
+
+class STATUS:
+    STOP = "stop"
+    ACTIVE = "active"
+
+class FirewallManager:
+    def __init__(self, db:SQLite):
+        self.db = db
+        self.proxy_table: Dict[str, ServiceManager] = {}
+        self.lock = asyncio.Lock()
+        self.updater_task = None
+
+    def init_updater(self, callback = None):
+        if not self.updater_task:
+            self.updater_task = asyncio.create_task(self._stats_updater(callback))
+
+    def close_updater(self):
+        if self.updater_task: self.updater_task.cancel()
+
+    async def close(self):
+        self.close_updater()
+        for key in list(self.proxy_table.keys()):
+            await self.remove(key)
+
+    async def remove(self,srv_id):
+        async with self.lock:
+            if srv_id in self.proxy_table:
+                await self.proxy_table[srv_id].next(STATUS.STOP)
+                del self.proxy_table[srv_id]
+
+    async def init(self, callback = None):
+        self.init_updater(callback)
+        await self.reload()
+
+    async def reload(self):
+        async with self.lock:
+            for srv in self.db.query('SELECT * FROM services;'):
+                srv = Service.from_dict(srv)
+                if srv.id in self.proxy_table:
+                    continue
+                self.proxy_table[srv.id] = ServiceManager(srv, self.db)
+                await self.proxy_table[srv.id].next(srv.status)
+
+    async def _stats_updater(self, callback):
+        try:
+            while True:
+                try:
+                    for key in list(self.proxy_table.keys()):
+                        self.proxy_table[key].update_stats()
+                except Exception:
+                    traceback.print_exc()
+                if callback:
+                    if asyncio.iscoroutinefunction(callback): await callback()
+                    else: callback()
+                await asyncio.sleep(5)
+        except asyncio.CancelledError:
+            self.updater_task = None
+            return
+
+    def get(self,srv_id):
+        if srv_id in self.proxy_table:
+            return self.proxy_table[srv_id]
+        else:
+            raise ServiceNotFoundException()
+
+class ServiceNotFoundException(Exception): pass
+
+class RegexFilter:
+    def __init__(
+        self, regex,
+        is_case_sensitive=True,
+        is_blacklist=True,
+        input_mode=False,
+        output_mode=False,
+        blocked_packets=0,
+        id=None
+    ):
+        self.regex = regex
+        self.is_case_sensitive = is_case_sensitive
+        self.is_blacklist = is_blacklist
+        if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
+        self.input_mode = input_mode
+        self.output_mode = output_mode
+        self.blocked = blocked_packets
+        self.id = id
+        self.compiled_regex = self.compile()
+
+    @classmethod
+    def from_regex(cls, regex:Regex):
+        return cls(
+            id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
+            is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
+            input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"]
+        )
+
+    def compile(self):
+        if isinstance(self.regex, str): self.regex = self.regex.encode()
+        if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Parameter")
+        return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
+
+    def check(self, data):
+        return True if self.compiled_regex.search(data) else False
+
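Illustrative behaviour of RegexFilter as defined above (not part of the patch; the pattern and payloads are invented):

f = RegexFilter(regex=b"USER [0-9]+", is_case_sensitive=False, is_blacklist=True,
                input_mode=True, output_mode=False)
f.check(b"user 1337 logged in")   # True  -> as a blacklist rule this packet gets blocked
f.check(b"welcome back")          # False -> packet is allowed through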
+class ServiceManager:
+    def __init__(self, srv: Service, db):
+        self.srv = srv
+        self.db = db
+        self.iptables = FiregexTables(self.srv.ipv6)
+        self.status = STATUS.STOP
+        self.filters: Dict[int, RegexFilter] = {}
+        self._update_filters_from_db()
+        self.lock = asyncio.Lock()
+        self.interceptor = None
+
+    # TODO: I don't like this method very much
+    def _update_filters_from_db(self):
+        regexes = [
+            Regex.from_dict(ele) for ele in
+            self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
+        ]
+        # Filter check
+        old_filters = set(self.filters.keys())
+        new_filters = set([f.id for f in regexes])
+
+        # remove old filters
+        for f in old_filters:
+            if not f in new_filters:
+                del self.filters[f]
+
+        # add new filters
+        for f in new_filters:
+            if not f in old_filters:
+                filter = [ele for ele in regexes if ele.id == f][0]
+                self.filters[f] = RegexFilter.from_regex(filter)
+
+    def __update_status_db(self, status):
+        self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id)
+
+    async def next(self,to):
+        async with self.lock:
+            return self._next(to)
+
+    def _next(self, to):
+        # ACTIVE -> STOP
+        if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
+            self.stop()
+            self._set_status(to)
+        # STOP -> ACTIVE
+        elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
+            self.restart()
+
+    def _stats_updater(self,filter:RegexFilter):
+        self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
+
+    def update_stats(self):
+        for ele in self.filters.values():
+            self._stats_updater(ele)
+
+    def _set_status(self,status):
+        self.status = status
+        self.__update_status_db(status)
+
+    def start(self):
+        if not self.interceptor:
+            self.iptables.delete_by_srv(self.srv)
+            # Single packet callback that applies every active regex to the matching direction.
+            def regex_filter(pkt, by_client):
+                try:
+                    for filter in self.filters.values():
+                        if (by_client and filter.input_mode) or (not by_client and filter.output_mode):
+                            match = filter.check(pkt)
+                            if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
+                                filter.blocked+=1
+                                return False
+                except IndexError: pass
+                return True
+            self.interceptor = self.iptables.add(FiregexFilter(
+                proto=self.srv.proto, port=self.srv.port, ip_int=self.srv.ip_int, func=regex_filter
+            ))
+        self._set_status(STATUS.ACTIVE)
+
+    def stop(self):
+        self.iptables.delete_by_srv(self.srv)
+        if self.interceptor:
+            self.interceptor.stop()
+            self.interceptor = None
+
+    def restart(self):
+        self.stop()
+        self.start()
+
+    async def update_filters(self):
+        async with self.lock:
+            self._update_filters_from_db()
\ No newline at end of file
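app.py now wires these pieces together roughly as follows (a minimal sketch, assuming a service row already exists in the database; the service id is a placeholder):

import asyncio
from modules import SQLite, FirewallManager
from modules.firewall import STATUS

async def main():
    db = SQLite("db/firegex.db")
    db.init()
    firewall = FirewallManager(db)
    await firewall.init()               # loads services from the DB and starts the stats updater
    srv = firewall.get("<service_id>")  # ServiceManager; raises ServiceNotFoundException if unknown
    await srv.next(STATUS.ACTIVE)       # install the NFQUEUE filters for that service
    await firewall.close()              # stop everything on shutdown

asyncio.run(main())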
diff --git a/backend/modules/iptables.py b/backend/modules/iptables.py
new file mode 100644
index 0000000..edeea5b
--- /dev/null
+++ b/backend/modules/iptables.py
@@ -0,0 +1,82 @@
+import os, re
+from subprocess import PIPE, Popen
+from typing import Dict, List, Tuple, Union
+
+class Rule():
+    def __init__(self, id, target, prot, opt, source, destination, details):
+        self.id = id
+        self.target = target
+        self.prot = prot
+        self.opt = opt
+        self.source = source
+        self.destination = destination
+        self.details = details
+
+    def dport(self) -> Union[int, None]:
+        port = re.findall(r"dpt:([0-9]+)", self.details)
+        return int(port[0]) if port else None
+
+    def sport(self) -> Union[int, None]:
+        port = re.findall(r"spt:([0-9]+)", self.details)
+        return int(port[0]) if port else None
+
+    def nfqueue(self) -> Union[Tuple[int,int], None]:
+        balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", self.details)
+        numbered = re.findall(r"NFQUEUE num ([0-9]+)", self.details)
+        queue_num = None
+        if balanced: queue_num = (int(balanced[0][0]), int(balanced[0][1]))
+        if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
+        return queue_num
+
+class IPTables:
+
+    def __init__(self, ipv6=False, table="filter"):
+        self.ipv6 = ipv6
+        self.table = table
+
+    def command(self, params) -> Tuple[bytes, bytes]:
+        params = ["-t", self.table] + params
+        if os.geteuid() != 0:
+            exit("You need to have root privileges to run this script.\nPlease try again, this time using 'sudo'. Exiting.")
+        return Popen(["ip6tables"]+params if self.ipv6 else ["iptables"]+params, stdout=PIPE, stderr=PIPE).communicate()
+
+    def list(self) -> Dict[str, List[Rule]]:
+        stdout, strerr = self.command(["-L", "--line-number", "-n"])
+        lines = stdout.decode().split("\n")
+        res: Dict[str, List[Rule]] = {}
+        chain_name = ""
+        for line in lines:
+            if line.startswith("Chain"):
+                chain_name = line.split()[1]
+                res[chain_name] = []
+            elif line.strip() and line.split()[0].isnumeric(): # skip the blank separator lines in iptables output
+                parsed = re.findall(r"([^ ]*)[ ]{,10}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]+([^ ]*)[ ]+(.*)", line)
+                if len(parsed) > 0:
+                    parsed = parsed[0]
+                    res[chain_name].append(Rule(
+                        id=parsed[0].strip(),
+                        target=parsed[1].strip(),
+                        prot=parsed[2].strip(),
+                        opt=parsed[3].strip(),
+                        source=parsed[4].strip(),
+                        destination=parsed[5].strip(),
+                        details=" ".join(parsed[6:]).strip() if len(parsed) >= 7 else ""
+                    ))
+        return res
+
+    def delete_rule(self, chain, id) -> None:
+        self.command(["-D", str(chain), str(id)])
+
+    def create_chain(self, name) -> None:
+        self.command(["-N", str(name)])
+
+    def flush_chain(self, name) -> None:
+        self.command(["-F", str(name)])
+
+    def insert_rule(self, chain, rule, *args, rulenum=1) -> None:
+        self.command(["-I", str(chain), str(rulenum), "-j", str(rule), *args])
+
+    def append_rule(self, chain, rule, *args) -> None:
+        self.command(["-A", str(chain), "-j", str(rule), *args])
+
diff --git a/backend/modules/sqlite.py b/backend/modules/sqlite.py
new file mode 100644
index 0000000..80d690f
--- /dev/null
+++ b/backend/modules/sqlite.py
@@ -0,0 +1,130 @@
+from typing import Union
+import json, sqlite3, os
+from hashlib import md5
+
+class SQLite():
+    def __init__(self, db_name: str) -> None:
+        self.conn: Union[None, sqlite3.Connection] = None
+        self.cur = None
+        self.db_name = db_name
+        self.schema = {
+            'services': {
+                'service_id': 'VARCHAR(100) PRIMARY KEY',
+                'status': 'VARCHAR(100) NOT NULL',
+                'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
+                'name': 'VARCHAR(100) NOT NULL UNIQUE',
+                'ipv6': 'BOOLEAN NOT NULL CHECK (ipv6 IN (0, 1)) DEFAULT 0',
+                'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))',
+                'ip_int': 'VARCHAR(100) NOT NULL',
+            },
+            'regexes': {
+                'regex': 'TEXT NOT NULL',
+                'mode': 'VARCHAR(1) NOT NULL',
+                'service_id': 'VARCHAR(100) NOT NULL',
+                'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
+                'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
+                'regex_id': 'INTEGER PRIMARY KEY',
+                'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
+                'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1)) DEFAULT 1',
+                'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
+            },
+            'QUERY':[
+                "CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (ipv6, port, ip_int, proto);",
+                "CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
+            ]
+        }
+        self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest()
+
+    def connect(self) -> None:
+        try:
+            self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
+        except Exception:
+            with open(self.db_name, 'x'): pass
+            self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
+ def dict_factory(cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + self.conn.row_factory = dict_factory + + def disconnect(self) -> None: + if self.conn: self.conn.close() + + def create_schema(self, tables = {}) -> None: + if self.conn: + cur = self.conn.cursor() + cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);") + for t in tables: + if t == "QUERY": continue + cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2])) + if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]] + cur.close() + + def query(self, query, *values): + cur = self.conn.cursor() + try: + cur.execute(query, values) + return cur.fetchall() + finally: + cur.close() + try: self.conn.commit() + except Exception: pass + + def delete(self): + self.disconnect() + os.remove(self.db_name) + + def init(self): + self.connect() + try: + if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct") + except Exception: + self.delete() + self.connect() + self.create_schema(self.schema) + self.put('DB_VERSION', self.DB_VER) + + def get(self, key): + q = self.query('SELECT value FROM keys_values WHERE key = ?', key) + if len(q) == 0: + return None + else: + return q[0]["value"] + + def put(self, key, value): + if self.get(key) is None: + self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value)) + else: + self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key) + + +class Service: + def __init__(self, id: str, status: str, port: int, name: str, ipv6: bool, proto: str, ip_int: str): + self.id = id + self.status = status + self.port = port + self.name = name + self.ipv6 = ipv6 + self.proto = proto + self.ip_int = ip_int + + @classmethod + def from_dict(cls, var: dict): + return cls(id=var["service_id"], status=var["status"], port=var["port"], name=var["name"], ipv6=var["ipv6"], proto=var["proto"], ip_int=var["ip_int"]) + + +class Regex: + def __init__(self, id: int, regex: str, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool): + self.regex = regex + self.mode = mode + self.service_id = service_id + self.is_blacklist = is_blacklist + self.blocked_packets = blocked_packets + self.id = id + self.is_case_sensitive = is_case_sensitive + self.active = active + + @classmethod + def from_dict(cls, var: dict): + return cls(id=var["regex_id"], regex=var["regex"], mode=var["mode"], service_id=var["service_id"], is_blacklist=var["is_blacklist"], blocked_packets=var["blocked_packets"], is_case_sensitive=var["is_case_sensitive"], active=var["active"]) \ No newline at end of file diff --git a/backend/proxy.py b/backend/proxy.py deleted file mode 100755 index 9a42d6f..0000000 --- a/backend/proxy.py +++ /dev/null @@ -1,268 +0,0 @@ -from typing import List -from pypacker import interceptor -from pypacker.layer3 import ip, ip6 -from pypacker.layer4 import tcp, udp -from subprocess import Popen, PIPE -import os, traceback, pcre, re -from ipaddress import ip_interface - -QUEUE_BASE_NUM = 1000 - -class FilterTypes: - INPUT = "FIREGEX-INPUT" - OUTPUT = "FIREGEX-OUTPUT" - -class ProtoTypes: - TCP = "tcp" - UDP = "udp" - -class IPTables: - - def __init__(self, ipv6=False, table="mangle"): - self.ipv6 = ipv6 - self.table = table - - def command(self, params): - params = ["-t", self.table] + params - if os.geteuid() != 0: 
- exit("You need to have root privileges to run this script.\nPlease try again, this time using 'sudo'. Exiting.") - return Popen(["ip6tables"]+params if self.ipv6 else ["iptables"]+params, stdout=PIPE, stderr=PIPE).communicate() - - def list_filters(self, param): - stdout, strerr = self.command(["-L", str(param), "--line-number", "-n"]) - output = [re.findall(r"([^ ]*)[ ]{,10}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]+([^ ]*)[ ]+(.*)", ele) for ele in stdout.decode().split("\n")] - return [{ - "id": ele[0][0].strip(), - "target": ele[0][1].strip(), - "prot": ele[0][2].strip(), - "opt": ele[0][3].strip(), - "source": ele[0][4].strip(), - "destination": ele[0][5].strip(), - "details": " ".join(ele[0][6:]).strip() if len(ele[0]) >= 7 else "", - } for ele in output if len(ele) > 0 and ele[0][0].isnumeric()] - - def delete_command(self, param, id): - self.command(["-D", str(param), str(id)]) - - def create_chain(self, name): - self.command(["-N", str(name)]) - - def flush_chain(self, name): - self.command(["-F", str(name)]) - - def add_chain_to_input(self, name): - if not self.find_if_filter_exists("PREROUTING", str(name)): - self.command(["-I", "PREROUTING", "-j", str(name)]) - - def add_chain_to_output(self, name): - if not self.find_if_filter_exists("POSTROUTING", str(name)): - self.command(["-I", "POSTROUTING", "-j", str(name)]) - - def find_if_filter_exists(self, type, target): - for filter in self.list_filters(type): - if filter["target"] == target: - return True - return False - - def add_s_to_c(self, queue_range, proto = None, port = None, ip_int = None): - init, end = queue_range - if init > end: init, end = end, init - self.command(["-A", FilterTypes.OUTPUT, - * (["-p", str(proto)] if proto else []), - * (["-s", str(ip_int)] if ip_int else []), - * (["--sport", str(port)] if port else []), - "-j", "NFQUEUE", - * (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]), - "--queue-bypass" - ]) - - def add_c_to_s(self, queue_range, proto = None, port = None, ip_int = None): - init, end = queue_range - if init > end: init, end = end, init - self.command(["-A", FilterTypes.INPUT, - * (["-p", str(proto)] if proto else []), - * (["-d", str(ip_int)] if ip_int else []), - * (["--dport", str(port)] if port else []), - "-j", "NFQUEUE", - * (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]), - "--queue-bypass" - ]) - -class FiregexFilter(): - def __init__(self, type, number, queue, proto, port, ipv6, ip_int): - self.type = type - self.id = int(number) - self.queue = queue - self.proto = proto - self.port = int(port) - self.iptable = IPTables(ipv6) - self.ip_int = str(ip_int) - - def __repr__(self) -> str: - return f"" - - def delete(self): - self.iptable.delete_command(self.type, self.id) - -class Interceptor: - def __init__(self, iptables, ip_int, c_to_s, s_to_c, proto, ipv6, port, n_threads): - self.proto = proto - self.ipv6 = ipv6 - self.itor_c_to_s, codes = self._start_queue(c_to_s, n_threads) - iptables.add_c_to_s(queue_range=codes, proto=proto, port=port, ip_int=ip_int) - self.itor_s_to_c, codes = self._start_queue(s_to_c, n_threads) - iptables.add_s_to_c(queue_range=codes, proto=proto, port=port, ip_int=ip_int) - - def _start_queue(self,func,n_threads): - def func_wrap(ll_data, ll_proto_id, data, ctx, *args): - pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data) - try: - level4 = None - if self.proto == ProtoTypes.TCP: level4 = pkt_parsed[tcp.TCP].body_bytes - elif self.proto == ProtoTypes.UDP: level4 = 
pkt_parsed[udp.UDP].body_bytes - if level4: - if func(level4): - return data, interceptor.NF_ACCEPT - elif self.proto == ProtoTypes.TCP: - pkt_parsed[tcp.TCP].flags &= 0x00 - pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK - pkt_parsed[tcp.TCP].body_bytes = b"" - return pkt_parsed.bin(), interceptor.NF_ACCEPT - else: return b"", interceptor.NF_DROP - else: return pkt_parsed.bin(), interceptor.NF_ACCEPT - except Exception: - traceback.print_exc() - return pkt_parsed.bin(), interceptor.NF_ACCEPT - - ictor = interceptor.Interceptor() - starts = QUEUE_BASE_NUM - while True: - if starts >= 65536: - raise Exception("Netfilter queue is full!") - queue_ids = list(range(starts,starts+n_threads)) - try: - ictor.start(func_wrap, queue_ids=queue_ids) - break - except interceptor.UnableToBindException as e: - starts = e.queue_id + 1 - return ictor, (starts, starts+n_threads-1) - - def stop(self): - self.itor_c_to_s.stop() - self.itor_s_to_c.stop() - -class FiregexFilterManager: - - def __init__(self, srv): - self.ipv6 = srv["ipv6"] - self.iptables = IPTables(self.ipv6) - self.iptables.create_chain(FilterTypes.INPUT) - self.iptables.create_chain(FilterTypes.OUTPUT) - self.iptables.add_chain_to_input(FilterTypes.INPUT) - self.iptables.add_chain_to_output(FilterTypes.OUTPUT) - - def get(self) -> List[FiregexFilter]: - res = [] - for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: - for filter in self.iptables.list_filters(filter_type): - queue_num = None - balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", filter["details"]) - numbered = re.findall(r"NFQUEUE num ([0-9]+)", filter["details"]) - port = re.findall(r"[sd]pt:([0-9]+)", filter["details"]) - if balanced: queue_num = (int(balanced[0][0]), int(balanced[0][1])) - if numbered: queue_num = (int(numbered[0]), int(numbered[0])) - if queue_num and port: - res.append(FiregexFilter( - type=filter_type, - number=filter["id"], - queue=queue_num, - proto=filter["prot"], - port=int(port[0]), - ipv6=self.ipv6, - ip_int=filter["source"] if filter_type == FilterTypes.OUTPUT else filter["destination"] - )) - return res - - def add(self, proto, port, ip_int, func): - for ele in self.get(): - if int(port) == ele.port and proto == ele.proto and ip_interface(ip_int) == ip_interface(ele.ip_int): - return None - - def c_to_s(pkt): return func(pkt, True) - def s_to_c(pkt): return func(pkt, False) - - itor = Interceptor( iptables=self.iptables, ip_int=ip_int, - c_to_s=c_to_s, s_to_c=s_to_c, - proto=proto, ipv6=self.ipv6, port=port, - n_threads=int(os.getenv("N_THREADS_NFQUEUE","1"))) - return itor - - def delete_all(self): - for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: - self.iptables.flush_chain(filter_type) - - def delete_by_srv(self, srv): - for filter in self.get(): - if filter.port == int(srv["port"]) and filter.proto == srv["proto"] and ip_interface(filter.ip_int) == ip_interface(srv["ip_int"]): - filter.delete() - -class Filter: - def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None): - self.regex = regex - self.is_case_sensitive = is_case_sensitive - self.is_blacklist = is_blacklist - if c_to_s == s_to_c: c_to_s = s_to_c = True # (False, False) == (True, True) - self.c_to_s = c_to_s - self.s_to_c = s_to_c - self.blocked = blocked_packets - self.code = code - self.compiled_regex = self.compile() - - def compile(self): - if isinstance(self.regex, str): self.regex = self.regex.encode() - if not isinstance(self.regex, bytes): raise Exception("Invalid Regex 
Paramether") - return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex) - - def check(self, data): - return True if self.compiled_regex.search(data) else False - -class Proxy: - def __init__(self, srv, filters=None): - self.srv = srv - self.manager = FiregexFilterManager(self.srv) - self.filters: List[Filter] = filters if filters else [] - self.interceptor = None - - def set_filters(self, filters): - elements_to_pop = len(self.filters) - for ele in filters: - self.filters.append(ele) - for _ in range(elements_to_pop): - self.filters.pop(0) - - def start(self): - if not self.interceptor: - self.manager.delete_by_srv(self.srv) - def regex_filter(pkt, by_client): - try: - for filter in self.filters: - if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c): - match = filter.check(pkt) - if (filter.is_blacklist and match) or (not filter.is_blacklist and not match): - filter.blocked+=1 - return False - except IndexError: pass - return True - - self.interceptor = self.manager.add(self.srv["proto"], self.srv["port"], self.srv["ip_int"], regex_filter) - - - def stop(self): - self.manager.delete_by_srv(self.srv) - if self.interceptor: - self.interceptor.stop() - self.interceptor = None - - def restart(self): - self.stop() - self.start() \ No newline at end of file diff --git a/backend/utils.py b/backend/utils.py index 20a4a51..43959f5 100755 --- a/backend/utils.py +++ b/backend/utils.py @@ -1,262 +1,7 @@ -from hashlib import md5 -import traceback -from typing import Dict -from proxy import Filter, Proxy -import os, sqlite3, socket, asyncio, re -import secrets, json -from base64 import b64decode +import os, socket, secrets LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) -class SQLite(): - def __init__(self, db_name) -> None: - self.conn = None - self.cur = None - self.db_name = db_name - self.schema = { - 'services': { - 'service_id': 'VARCHAR(100) PRIMARY KEY', - 'status': 'VARCHAR(100) NOT NULL', - 'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)', - 'name': 'VARCHAR(100) NOT NULL UNIQUE', - 'ipv6': 'BOOLEAN NOT NULL CHECK (ipv6 IN (0, 1)) DEFAULT 0', - 'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))', - 'ip_int': 'VARCHAR(100) NOT NULL', - }, - 'regexes': { - 'regex': 'TEXT NOT NULL', - 'mode': 'VARCHAR(1) NOT NULL', - 'service_id': 'VARCHAR(100) NOT NULL', - 'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))', - 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', - 'regex_id': 'INTEGER PRIMARY KEY', - 'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))', - 'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1)) DEFAULT 1', - 'FOREIGN KEY (service_id)':'REFERENCES services (service_id)', - }, - 'keys_values': { - 'key': 'VARCHAR(100) PRIMARY KEY', - 'value': 'VARCHAR(100) NOT NULL', - }, - 'QUERY':[ - "CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (ipv6, port, ip_int, proto);", - "CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);" - ] - } - self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest() - - def connect(self) -> None: - try: - self.conn = sqlite3.connect(self.db_name, check_same_thread = False) - except Exception: - with open(self.db_name, 'x'): - pass - self.conn = sqlite3.connect(self.db_name, check_same_thread = False) - def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - self.conn.row_factory 
= dict_factory - - def disconnect(self) -> None: - if self.conn: self.conn.close() - - def create_schema(self, tables = {}) -> None: - cur = self.conn.cursor() - for t in tables: - if t == "QUERY": continue - cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2])) - if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]] - cur.close() - - def query(self, query, *values): - cur = self.conn.cursor() - try: - cur.execute(query, values) - return cur.fetchall() - finally: - cur.close() - try: self.conn.commit() - except Exception: pass - - def delete(self): - self.disconnect() - os.remove(self.db_name) - - def init(self): - self.connect() - try: - current_ver = self.query("SELECT value FROM keys_values WHERE key = 'DB_VERSION'")[0]['value'] - if current_ver != self.DB_VER: raise Exception("DB_VERSION is not correct") - except Exception: - self.delete() - self.connect() - self.create_schema(self.schema) - self.query("INSERT INTO keys_values (key, value) VALUES ('DB_VERSION', ?)", self.DB_VER) - -class KeyValueStorage: - def __init__(self, db): - self.db = db - - def get(self, key): - q = self.db.query('SELECT value FROM keys_values WHERE key = ?', key) - if len(q) == 0: - return None - else: - return q[0]["value"] - - def put(self, key, value): - if self.get(key) is None: - self.db.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value)) - else: - self.db.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key) - -class STATUS: - STOP = "stop" - ACTIVE = "active" - -class ServiceNotFoundException(Exception): pass - -class ServiceManager: - def __init__(self, srv, db): - self.srv = srv - self.db = db - self.proxy = Proxy(srv) - self.status = STATUS.STOP - self.filters = {} - self._update_filters_from_db() - self.lock = asyncio.Lock() - self.starter = None - - def _update_filters_from_db(self): - res = self.db.query(""" - SELECT - regex, mode, regex_id id, is_blacklist, - blocked_packets n_packets, is_case_sensitive - FROM regexes WHERE service_id = ? AND active=1; - """, self.srv["service_id"]) - - #Filter check - old_filters = set(self.filters.keys()) - new_filters = set([f["id"] for f in res]) - - #remove old filters - for f in old_filters: - if not f in new_filters: - del self.filters[f] - - for f in new_filters: - if not f in old_filters: - filter_info = [ele for ele in res if ele["id"] == f][0] - self.filters[f] = Filter( - is_case_sensitive=filter_info["is_case_sensitive"], - c_to_s=filter_info["mode"] in ["C","B"], - s_to_c=filter_info["mode"] in ["S","B"], - is_blacklist=filter_info["is_blacklist"], - regex=b64decode(filter_info["regex"]), - blocked_packets=filter_info["n_packets"], - code=f - ) - self.proxy.set_filters(self.filters.values()) - - def __update_status_db(self, status): - self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv["service_id"]) - - async def next(self,to): - async with self.lock: - return self._next(to) - - def _next(self, to): - if self.status != to: - # ACTIVE -> PAUSE - if (self.status, to) in [(STATUS.ACTIVE, STATUS.STOP)]: - self.proxy.stop() - self._set_status(to) - # PAUSE -> ACTIVE - elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE)]: - self.proxy.restart() - self._set_status(to) - - - def _stats_updater(self,filter:Filter): - self.db.query("UPDATE regexes SET blocked_packets = ? 
WHERE regex_id = ?;", filter.blocked, filter.code) - - def update_stats(self): - for ele in self.proxy.filters: - self._stats_updater(ele) - - def _set_status(self,status): - self.status = status - self.__update_status_db(status) - - async def update_filters(self): - async with self.lock: - self._update_filters_from_db() - -class ProxyManager: - def __init__(self, db:SQLite): - self.db = db - self.proxy_table: Dict[ServiceManager] = {} - self.lock = asyncio.Lock() - self.updater_task = None - - def init_updater(self, callback = None): - if not self.updater_task: - self.updater_task = asyncio.create_task(self._stats_updater(callback)) - - def close_updater(self): - if self.updater_task: self.updater_task.cancel() - - async def close(self): - self.close_updater() - if self.updater_task: self.updater_task.cancel() - for key in list(self.proxy_table.keys()): - await self.remove(key) - - async def remove(self,srv_id): - async with self.lock: - if srv_id in self.proxy_table: - await self.proxy_table[srv_id].next(STATUS.STOP) - del self.proxy_table[srv_id] - - async def init(self, callback = None): - self.init_updater(callback) - await self.reload() - - async def reload(self): - async with self.lock: - for srv in self.db.query('SELECT * FROM services;'): - - srv_id = srv["service_id"] - if srv_id in self.proxy_table: - continue - - self.proxy_table[srv_id] = ServiceManager(srv, self.db) - await self.proxy_table[srv_id].next(srv["status"]) - - async def _stats_updater(self, callback): - try: - while True: - try: - for key in list(self.proxy_table.keys()): - self.proxy_table[key].update_stats() - except Exception: - traceback.print_exc() - if callback: - if asyncio.iscoroutinefunction(callback): await callback() - else: callback() - await asyncio.sleep(5) - except asyncio.CancelledError: - self.updater_task = None - return - - def get(self,srv_id): - if srv_id in self.proxy_table: - return self.proxy_table[srv_id] - else: - raise ServiceNotFoundException() - def refactor_name(name:str): name = name.strip() while " " in name: name = name.replace(" "," ")
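For completeness, a short sketch of how the relocated SQLite helpers and row objects are expected to be used (illustrative only, not part of the patch; the key and value are made up):

from modules.sqlite import SQLite, Service

db = SQLite("db/firegex.db")
db.init()                                    # recreates the DB file if DB_VERSION does not match the schema hash
db.put("secret", "deadbeef")                 # upsert into the keys_values table
assert db.get("secret") == "deadbeef"        # get() returns None for missing keys

rows = db.query("SELECT * FROM services;")   # rows are dicts thanks to dict_factory
services = [Service.from_dict(r) for r in rows]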