From f4fe3d3ab597bcb45876dbf1cdfe4c1cd606c595 Mon Sep 17 00:00:00 2001 From: nik012003 Date: Thu, 11 Aug 2022 16:11:32 +0200 Subject: [PATCH] Refactoring code pt.1 --- backend/modules/nfregex/firegex.py | 4 - backend/modules/nfregex/firewall.py | 8 +- backend/modules/nfregex/nftables.py | 38 ++--- backend/modules/porthijack/__init__.py | 0 backend/modules/porthijack/firewall.py | 114 +++++++++++++ backend/modules/porthijack/models.py | 13 ++ .../porthijack/nftables.py} | 81 ++++----- backend/routers/porthijack.py | 160 ++++++++++++++++++ backend/utils/__init__.py | 9 +- 9 files changed, 354 insertions(+), 73 deletions(-) create mode 100644 backend/modules/porthijack/__init__.py create mode 100644 backend/modules/porthijack/firewall.py create mode 100644 backend/modules/porthijack/models.py rename backend/{test.py => modules/porthijack/nftables.py} (61%) diff --git a/backend/modules/nfregex/firegex.py b/backend/modules/nfregex/firegex.py index 6bfa39f..c7ab825 100644 --- a/backend/modules/nfregex/firegex.py +++ b/backend/modules/nfregex/firegex.py @@ -141,7 +141,3 @@ class FiregexInterceptor: except Exception: pass return res -def delete_by_srv(srv:Service): - for filter in nft.get(): - if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int): - nft.cmd({"delete":{"rule": {"handle": filter.id, "table": nft.table_name, "chain": filter.target, "family": "inet"}}}) \ No newline at end of file diff --git a/backend/modules/nfregex/firewall.py b/backend/modules/nfregex/firewall.py index 07f103b..56aa63e 100644 --- a/backend/modules/nfregex/firewall.py +++ b/backend/modules/nfregex/firewall.py @@ -1,6 +1,6 @@ import asyncio from typing import Dict -from modules.nfregex.firegex import FiregexInterceptor, RegexFilter, delete_by_srv +from modules.nfregex.firegex import FiregexInterceptor, RegexFilter from modules.nfregex.nftables import FiregexTables, FiregexFilter from modules.nfregex.models import Regex, Service from utils.sqlite import SQLite @@ -95,13 +95,13 @@ class ServiceManager: async def start(self): if not self.interceptor: - delete_by_srv(self.srv) - self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int)) + FiregexTables().delete(self.srv) + self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv)) await self._update_filters_from_db() self._set_status(STATUS.ACTIVE) async def stop(self): - delete_by_srv(self.srv) + FiregexTables().delete(self.srv) if self.interceptor: await self.interceptor.stop() self.interceptor = None diff --git a/backend/modules/nfregex/nftables.py b/backend/modules/nfregex/nftables.py index c5b4e86..39ba765 100644 --- a/backend/modules/nfregex/nftables.py +++ b/backend/modules/nfregex/nftables.py @@ -1,10 +1,10 @@ from typing import List +from modules.nfregex.models import Service from utils import ip_parse, ip_family, NFTableManager class FiregexFilter(): - def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None): + def __init__(self, proto:str, port:int, ip_int:str, target:str=None, id=None): self.id = int(id) if id else None - self.queue = queue self.target = target self.proto = proto self.port = int(port) @@ -46,47 +46,41 @@ class FiregexTables(NFTableManager): {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.output_chain}}}, ]) - def add_output(self, queue_range, proto, port, ip_int): - init, end = queue_range - if init > end: init, end = end, init - ip_int = ip_parse(ip_int) + def add(self, srv:Service, queue_range_input, queue_range_output): + ip_int = ip_parse(srv.ip_int) ip_addr = str(ip_int).split("/")[0] ip_addr_cidr = int(str(ip_int).split("/")[1]) + + init, end = queue_range_output + if init > end: init, end = end, init self.cmd({ "insert":{ "rule": { "family": "inet", "table": self.table_name, "chain": self.output_chain, "expr": [ {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, - {'match': {"left": { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(port)}}, + {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}}, {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} ] }}}) - - def add_input(self, queue_range, proto = None, port = None, ip_int = None): - init, end = queue_range + + init, end = queue_range_input if init > end: init, end = end, init - ip_int = ip_parse(ip_int) - ip_addr = str(ip_int).split("/")[0] - ip_addr_cidr = int(str(ip_int).split("/")[1]) self.cmd({"insert":{"rule":{ "family": "inet", "table": self.table_name, "chain": self.input_chain, "expr": [ {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, - {'match': {"left": { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(port)}}, + {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}}, {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} ] }}}) + def get(self) -> List[FiregexFilter]: res = [] - for filter in [ele["rule"] for ele in self.list() if "rule" in ele and ele["rule"]["table"] == self.table_name]: - queue_str = filter["expr"][2]["queue"]["num"] - queue = None - if isinstance(queue_str,dict): queue = int(queue_str["range"][0]), int(queue_str["range"][1]) - else: queue = int(queue_str), int(queue_str) + for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]): ip_int = None if isinstance(filter["expr"][0]["match"]["right"],str): ip_int = str(ip_parse(filter["expr"][0]["match"]["right"])) @@ -95,10 +89,14 @@ class FiregexTables(NFTableManager): res.append(FiregexFilter( target=filter["chain"], id=int(filter["handle"]), - queue=queue, proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"], port=filter["expr"][1]["match"]["right"], ip_int=ip_int )) return res + + def delete(self, srv:Service): + for filter in self.get(): + if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int): + self.cmd({"delete":{"rule": {"handle": filter.id, "table": self.table_name, "chain": filter.target, "family": "inet"}}}) \ No newline at end of file diff --git a/backend/modules/porthijack/__init__.py b/backend/modules/porthijack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/modules/porthijack/firewall.py b/backend/modules/porthijack/firewall.py new file mode 100644 index 0000000..13189eb --- /dev/null +++ b/backend/modules/porthijack/firewall.py @@ -0,0 +1,114 @@ +import asyncio +from typing import Dict +from modules.porthijack.nftables import FiregexTables, FiregexFilter +from modules.porthijack.models import Service +from utils.sqlite import SQLite + +class STATUS: + STOP = "stop" + ACTIVE = "active" + +class FirewallManager: + def __init__(self, db:SQLite): + self.db = db + self.proxy_table: Dict[str, ServiceManager] = {} + self.lock = asyncio.Lock() + + async def close(self): + for key in list(self.proxy_table.keys()): + await self.remove(key) + + async def remove(self,srv_id): + async with self.lock: + if srv_id in self.proxy_table: + await self.proxy_table[srv_id].next(STATUS.STOP) + del self.proxy_table[srv_id] + + async def init(self): + FiregexTables().init() + await self.reload() + + async def reload(self): + async with self.lock: + for srv in self.db.query('SELECT * FROM services;'): + srv = Service.from_dict(srv) + if srv.id in self.proxy_table: + continue + self.proxy_table[srv.id] = ServiceManager(srv, self.db) + await self.proxy_table[srv.id].next(srv.status) + + def get(self,srv_id): + if srv_id in self.proxy_table: + return self.proxy_table[srv_id] + else: + raise ServiceNotFoundException() + +class ServiceNotFoundException(Exception): pass + +class ServiceManager: + def __init__(self, srv: Service, db): + self.srv = srv + self.db = db + self.status = STATUS.STOP + self.filters: Dict[int, FiregexFilter] = {} + self.lock = asyncio.Lock() + self.interceptor = None + + async def _update_filters_from_db(self): + regexes = [ + Regex.from_dict(ele) for ele in + self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id) + ] + #Filter check + old_filters = set(self.filters.keys()) + new_filters = set([f.id for f in regexes]) + #remove old filters + for f in old_filters: + if not f in new_filters: + del self.filters[f] + #add new filters + for f in new_filters: + if not f in old_filters: + filter = [ele for ele in regexes if ele.id == f][0] + self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater) + if self.interceptor: await self.interceptor.reload(self.filters.values()) + + def __update_status_db(self, status): + self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id) + + async def next(self,to): + async with self.lock: + if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): + await self.stop() + self._set_status(to) + # PAUSE -> ACTIVE + elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): + await self.restart() + + def _stats_updater(self,filter:RegexFilter): + self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id) + + def _set_status(self,status): + self.status = status + self.__update_status_db(status) + + async def start(self): + if not self.interceptor: + FiregexTables().delete_by_srv(self.srv) + self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int)) + await self._update_filters_from_db() + self._set_status(STATUS.ACTIVE) + + async def stop(self): + FiregexTables().delete_by_srv(self.srv) + if self.interceptor: + await self.interceptor.stop() + self.interceptor = None + + async def restart(self): + await self.stop() + await self.start() + + async def update_filters(self): + async with self.lock: + await self._update_filters_from_db() \ No newline at end of file diff --git a/backend/modules/porthijack/models.py b/backend/modules/porthijack/models.py new file mode 100644 index 0000000..4e2a1d3 --- /dev/null +++ b/backend/modules/porthijack/models.py @@ -0,0 +1,13 @@ +class Service: + def __init__(self, service_id: str, active: bool, public_port: int, proxy_port: int, name: str, proto: str, ip_int: str): + self.service_id = service_id + self.active = active + self.public_port = public_port + self.proxy_port = proxy_port + self.name = name + self.proto = proto + self.ip_int = ip_int + + @classmethod + def from_dict(cls, var: dict): + return cls(id=var["service_id"], active=var["active"], public_port=var["public_port"], proxy_port=var["proxy_port"], name=var["name"], proto=var["proto"], ip_int=var["ip_int"]) diff --git a/backend/test.py b/backend/modules/porthijack/nftables.py similarity index 61% rename from backend/test.py rename to backend/modules/porthijack/nftables.py index a0e9dc1..3eb57cb 100644 --- a/backend/test.py +++ b/backend/modules/porthijack/nftables.py @@ -1,49 +1,19 @@ +from typing import List +from utils import ip_parse, ip_family, NFTableManager -from ipaddress import ip_interface -import nftables, traceback - -def ip_parse(ip:str): - return str(ip_interface(ip).network) - -def ip_family(ip:str): - return "ip6" if ip_interface(ip).version == 6 else "ip" - -class Singleton(object): - __instance = None - def __new__(class_, *args, **kwargs): - if not isinstance(class_.__instance, class_): - class_.__instance = object.__new__(class_, *args, **kwargs) - return class_.__instance - -class NFTableManager(Singleton): - - table_name = "firegex" - - def __init__(self, init_cmd, reset_cmd): - self.__init_cmds = init_cmd - self.__reset_cmds = reset_cmd - self.nft = nftables.Nftables() - - def raw_cmd(self, *cmds): - return self.nft.json_cmd({"nftables": list(cmds)}) - - def cmd(self, *cmds): - code, out, err = self.raw_cmd(*cmds) - - if code == 0: return out - else: raise Exception(err) - - def init(self): - self.reset() - self.raw_cmd({"add":{"table":{"name":self.table_name,"family":"inet"}}}) - self.cmd(*self.__init_cmds) - - def reset(self): - self.raw_cmd(*self.__reset_cmds) - - def list(self): - return self.cmd({"list": {"ruleset": None}})["nftables"] +class FiregexFilter(): + def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None): + self.id = int(id) if id else None + self.queue = queue + self.target = target + self.proto = proto + self.port = int(port) + self.ip_int = str(ip_int) + def __eq__(self, o: object) -> bool: + if isinstance(o, FiregexFilter): + return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) + return False class FiregexTables(NFTableManager): prerouting_porthijack = "prerouting_porthijack" @@ -100,3 +70,26 @@ class FiregexTables(NFTableManager): {'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'sport'}}, 'value': int(public_port)}} ] }}}) + + def get(self) -> List[FiregexFilter]: + res = [] + for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]): + queue_str = filter["expr"][2]["queue"]["num"] + queue = None + if isinstance(queue_str,dict): queue = int(queue_str["range"][0]), int(queue_str["range"][1]) + else: queue = int(queue_str), int(queue_str) + ip_int = None + if isinstance(filter["expr"][0]["match"]["right"],str): + ip_int = str(ip_parse(filter["expr"][0]["match"]["right"])) + else: + ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}' + res.append(FiregexFilter( + target=filter["chain"], + id=int(filter["handle"]), + queue=queue, + proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"], + port=filter["expr"][1]["match"]["right"], + ip_int=ip_int + )) + return res + \ No newline at end of file diff --git a/backend/routers/porthijack.py b/backend/routers/porthijack.py index e69de29..3a812e4 100644 --- a/backend/routers/porthijack.py +++ b/backend/routers/porthijack.py @@ -0,0 +1,160 @@ +import secrets +import sqlite3 +from typing import List, Union +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from utils.sqlite import SQLite +from utils import ip_parse, refactor_name, refresh_frontend +from utils.models import ResetRequest, StatusMessageModel +from modules.porthijack.nftables import FiregexTables +from modules.porthijack.firewall import STATUS, FirewallManager + +class ServiceModel(BaseModel): + service_id: str + active: bool + public_port: int + proxy_port: int + name: str + proto: str + ip_int: str + +class RenameForm(BaseModel): + name:str + +class ServiceAddForm(BaseModel): + name: str + public_port: int + proxy_port: int + proto: str + ip_int: str + +class ServiceAddResponse(BaseModel): + status:str + service_id: Union[None,str] + +class GeneralStatModel(BaseModel): + services: int + +app = APIRouter() + +db = SQLite('db/port-hijacking.db', { + 'services': { + 'service_id': 'VARCHAR(100) PRIMARY KEY', + 'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1))', + 'public_port': 'INT NOT NULL CHECK(port > 0 and port < 65536)', + 'proxy_port': 'INT NOT NULL CHECK(port > 0 and port < 65536)', + 'name': 'VARCHAR(100) NOT NULL UNIQUE', + 'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))', + 'ip_int': 'VARCHAR(100) NOT NULL', + }, + 'QUERY':[ + "CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (public_port, ip_int, proto);", + ] +}) + +async def reset(params: ResetRequest): + if not params.delete: + db.backup() + await firewall.close() + FiregexTables().reset() + if params.delete: + db.delete() + db.init() + else: + db.restore() + await firewall.init() + + +async def startup(): + db.init() + await firewall.init() + +async def shutdown(): + db.backup() + await firewall.close() + db.disconnect() + db.restore() + +def gen_service_id(): + while True: + res = secrets.token_hex(8) + if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0: + break + return res + +firewall = FirewallManager(db) + +@app.get('/stats', response_model=GeneralStatModel) +async def get_general_stats(): + """Get firegex general status about services""" + return db.query(""" + SELECT + (SELECT COUNT(*) FROM services) services + """)[0] + +@app.get('/services', response_model=List[ServiceModel]) +async def get_service_list(): + """Get the list of existent firegex services""" + return db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_int FROM services;") + +@app.get('/service/{service_id}', response_model=ServiceModel) +async def get_service_by_id(service_id: str, ): + """Get info about a specific service using his id""" + res = db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_int FROM services WHERE service_id = ?;", service_id) + if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!") + return res[0] + +@app.get('/service/{service_id}/stop', response_model=StatusMessageModel) +async def service_stop(service_id: str, ): + """Request the stop of a specific service""" + await firewall.get(service_id).next(STATUS.STOP) + await refresh_frontend() + return {'status': 'ok'} + +@app.get('/service/{service_id}/start', response_model=StatusMessageModel) +async def service_start(service_id: str, ): + """Request the start of a specific service""" + await firewall.get(service_id).next(STATUS.ACTIVE) + await refresh_frontend() + return {'status': 'ok'} + +@app.get('/service/{service_id}/delete', response_model=StatusMessageModel) +async def service_delete(service_id: str, ): + """Request the deletion of a specific service""" + db.query('DELETE FROM services WHERE service_id = ?;', service_id) + await firewall.remove(service_id) + await refresh_frontend() + return {'status': 'ok'} + +@app.post('/service/{service_id}/rename', response_model=StatusMessageModel) +async def service_rename(service_id: str, form: RenameForm, ): + """Request to change the name of a specific service""" + form.name = refactor_name(form.name) + if not form.name: return {'status': 'The name cannot be empty!'} + try: + db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id) + except sqlite3.IntegrityError: + return {'status': 'This name is already used'} + await refresh_frontend() + return {'status': 'ok'} + + +@app.post('/services/add', response_model=ServiceAddResponse) +async def add_new_service(form: ServiceAddForm, ): + """Add a new service""" + try: + form.ip_int = ip_parse(form.ip_int) + except ValueError: + return {"status":"Invalid address"} + if form.proto not in ["tcp", "udp"]: + return {"status":"Invalid protocol"} + srv_id = None + try: + srv_id = gen_service_id() + db.query("INSERT INTO services (service_id, active, public_port, proxy_port, name, proto, ip_int) VALUES (?, ?, ?, ?, ?, ?, ?)", + srv_id, False, form.public_port, form.proxy_port , form.name, form.proto, form.ip_int) + except sqlite3.IntegrityError: + return {'status': 'This type of service already exists'} + await firewall.reload() + await refresh_frontend() + return {'status': 'ok', 'service_id': srv_id} diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py index 5551442..0a4c30d 100755 --- a/backend/utils/__init__.py +++ b/backend/utils/__init__.py @@ -81,6 +81,13 @@ class NFTableManager(Singleton): def reset(self): self.raw_cmd(*self.__reset_cmds) - def list(self): + def list_rules(self, tables = None, chains = None): + for filter in [ele["rule"] for ele in self.raw_list() if "rule" in ele ]: + if tables and filter["table"] not in tables: continue + if chains and filter["chain"] not in chains: continue + yield filter + + def raw_list(self): return self.cmd({"list": {"ruleset": None}})["nftables"] + \ No newline at end of file