From e6b4ddd4a054bdcfa5caaa7221a9017b350c7695 Mon Sep 17 00:00:00 2001 From: DomySh Date: Thu, 11 Aug 2022 15:16:23 +0000 Subject: [PATCH] Code refactoring and adding port-hijacking backup commit --- backend/modules/nfregex/firegex.py | 12 ++-- backend/modules/nfregex/firewall.py | 30 ++++---- backend/modules/nfregex/nftables.py | 23 +++++-- backend/modules/porthijack/firewall.py | 95 +++++++++----------------- backend/modules/porthijack/nftables.py | 59 +++++++++------- backend/routers/porthijack.py | 6 +- 6 files changed, 108 insertions(+), 117 deletions(-) diff --git a/backend/modules/nfregex/firegex.py b/backend/modules/nfregex/firegex.py index c7ab825..0753e0f 100644 --- a/backend/modules/nfregex/firegex.py +++ b/backend/modules/nfregex/firegex.py @@ -1,5 +1,5 @@ from typing import Dict, List, Set -from modules.nfregex.nftables import FiregexFilter, FiregexTables +from modules.nfregex.nftables import FiregexTables from utils import ip_parse, run_func from modules.nfregex.models import Service, Regex import re, os, asyncio @@ -54,7 +54,7 @@ class RegexFilter: class FiregexInterceptor: def __init__(self): - self.filter:FiregexFilter + self.srv:Service self.filter_map_lock:asyncio.Lock self.filter_map: Dict[str, RegexFilter] self.regex_filters: Set[RegexFilter] @@ -63,16 +63,14 @@ class FiregexInterceptor: self.update_task: asyncio.Task @classmethod - async def start(cls, filter: FiregexFilter): + async def start(cls, srv: Service): self = cls() - self.filter = filter + self.srv = srv self.filter_map_lock = asyncio.Lock() self.update_config_lock = asyncio.Lock() input_range, output_range = await self._start_binary() self.update_task = asyncio.create_task(self.update_blocked()) - if not filter in nft.get(): - nft.add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) - nft.add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) + nft.add(self.srv, input_range, output_range) return self async def _start_binary(self): diff --git a/backend/modules/nfregex/firewall.py b/backend/modules/nfregex/firewall.py index 56aa63e..18544f6 100644 --- a/backend/modules/nfregex/firewall.py +++ b/backend/modules/nfregex/firewall.py @@ -9,38 +9,40 @@ class STATUS: STOP = "stop" ACTIVE = "active" +nft = FiregexTables() + class FirewallManager: def __init__(self, db:SQLite): self.db = db - self.proxy_table: Dict[str, ServiceManager] = {} + self.service_table: Dict[str, ServiceManager] = {} self.lock = asyncio.Lock() async def close(self): - for key in list(self.proxy_table.keys()): + for key in list(self.service_table.keys()): await self.remove(key) async def remove(self,srv_id): async with self.lock: - if srv_id in self.proxy_table: - await self.proxy_table[srv_id].next(STATUS.STOP) - del self.proxy_table[srv_id] + if srv_id in self.service_table: + await self.service_table[srv_id].next(STATUS.STOP) + del self.service_table[srv_id] async def init(self): - FiregexTables().init() + nft.init() await self.reload() async def reload(self): async with self.lock: for srv in self.db.query('SELECT * FROM services;'): srv = Service.from_dict(srv) - if srv.id in self.proxy_table: + if srv.id in self.service_table: continue - self.proxy_table[srv.id] = ServiceManager(srv, self.db) - await self.proxy_table[srv.id].next(srv.status) + self.service_table[srv.id] = ServiceManager(srv, self.db) + await self.service_table[srv.id].next(srv.status) def get(self,srv_id): - if srv_id in self.proxy_table: - return self.proxy_table[srv_id] + if srv_id in self.service_table: + return self.service_table[srv_id] else: raise ServiceNotFoundException() @@ -95,13 +97,13 @@ class ServiceManager: async def start(self): if not self.interceptor: - FiregexTables().delete(self.srv) - self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv)) + nft.delete(self.srv) + self.interceptor = await FiregexInterceptor.start(self.srv) await self._update_filters_from_db() self._set_status(STATUS.ACTIVE) async def stop(self): - FiregexTables().delete(self.srv) + nft.delete(self.srv) if self.interceptor: await self.interceptor.stop() self.interceptor = None diff --git a/backend/modules/nfregex/nftables.py b/backend/modules/nfregex/nftables.py index 39ba765..a0a31a0 100644 --- a/backend/modules/nfregex/nftables.py +++ b/backend/modules/nfregex/nftables.py @@ -2,9 +2,9 @@ from typing import List from modules.nfregex.models import Service from utils import ip_parse, ip_family, NFTableManager -class FiregexFilter(): - def __init__(self, proto:str, port:int, ip_int:str, target:str=None, id=None): - self.id = int(id) if id else None +class FiregexFilter: + def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int): + self.id = id self.target = target self.proto = proto self.port = int(port) @@ -13,6 +13,8 @@ class FiregexFilter(): def __eq__(self, o: object) -> bool: if isinstance(o, FiregexFilter): return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) + elif isinstance(o, Service): + return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return False class FiregexTables(NFTableManager): @@ -47,10 +49,14 @@ class FiregexTables(NFTableManager): ]) def add(self, srv:Service, queue_range_input, queue_range_output): + + for ele in self.get(): + if ele.__eq__(srv): return + ip_int = ip_parse(srv.ip_int) ip_addr = str(ip_int).split("/")[0] ip_addr_cidr = int(str(ip_int).split("/")[1]) - + init, end = queue_range_output if init > end: init, end = end, init self.cmd({ "insert":{ "rule": { @@ -97,6 +103,11 @@ class FiregexTables(NFTableManager): def delete(self, srv:Service): for filter in self.get(): - if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int): - self.cmd({"delete":{"rule": {"handle": filter.id, "table": self.table_name, "chain": filter.target, "family": "inet"}}}) + if filter.__eq__(srv): + self.cmd({ "delete":{ "rule": { + "family": "inet", + "table": self.table_name, + "chain": filter.target, + "handle": filter.id + }}}) \ No newline at end of file diff --git a/backend/modules/porthijack/firewall.py b/backend/modules/porthijack/firewall.py index 13189eb..9f253bc 100644 --- a/backend/modules/porthijack/firewall.py +++ b/backend/modules/porthijack/firewall.py @@ -1,28 +1,27 @@ +from ast import Delete import asyncio from typing import Dict from modules.porthijack.nftables import FiregexTables, FiregexFilter from modules.porthijack.models import Service from utils.sqlite import SQLite -class STATUS: - STOP = "stop" - ACTIVE = "active" +nft = FiregexTables() class FirewallManager: def __init__(self, db:SQLite): self.db = db - self.proxy_table: Dict[str, ServiceManager] = {} + self.service_table: Dict[str, ServiceManager] = {} self.lock = asyncio.Lock() async def close(self): - for key in list(self.proxy_table.keys()): + for key in list(self.service_table.keys()): await self.remove(key) async def remove(self,srv_id): async with self.lock: - if srv_id in self.proxy_table: - await self.proxy_table[srv_id].next(STATUS.STOP) - del self.proxy_table[srv_id] + if srv_id in self.service_table: + await self.service_table[srv_id].disable() + del self.service_table[srv_id] async def init(self): FiregexTables().init() @@ -32,14 +31,15 @@ class FirewallManager: async with self.lock: for srv in self.db.query('SELECT * FROM services;'): srv = Service.from_dict(srv) - if srv.id in self.proxy_table: + if srv.service_id in self.service_table: continue - self.proxy_table[srv.id] = ServiceManager(srv, self.db) - await self.proxy_table[srv.id].next(srv.status) + self.service_table[srv.service_id] = ServiceManager(srv, self.db) + if srv.active: + await self.service_table[srv.service_id].enable() def get(self,srv_id): - if srv_id in self.proxy_table: - return self.proxy_table[srv_id] + if srv_id in self.service_table: + return self.service_table[srv_id] else: raise ServiceNotFoundException() @@ -49,66 +49,33 @@ class ServiceManager: def __init__(self, srv: Service, db): self.srv = srv self.db = db - self.status = STATUS.STOP - self.filters: Dict[int, FiregexFilter] = {} + self.active = False self.lock = asyncio.Lock() - self.interceptor = None - - async def _update_filters_from_db(self): - regexes = [ - Regex.from_dict(ele) for ele in - self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id) - ] - #Filter check - old_filters = set(self.filters.keys()) - new_filters = set([f.id for f in regexes]) - #remove old filters - for f in old_filters: - if not f in new_filters: - del self.filters[f] - #add new filters - for f in new_filters: - if not f in old_filters: - filter = [ele for ele in regexes if ele.id == f][0] - self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater) - if self.interceptor: await self.interceptor.reload(self.filters.values()) - - def __update_status_db(self, status): - self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id) - async def next(self,to): - async with self.lock: - if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): + async def enable(self,to): + if (self.status != to): + async with self.lock: + await self.restart() + + async def disable(self,to): + if (self.status != to): + async with self.lock: await self.stop() self._set_status(to) - # PAUSE -> ACTIVE - elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): - await self.restart() - def _stats_updater(self,filter:RegexFilter): - self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id) - - def _set_status(self,status): - self.status = status - self.__update_status_db(status) + def _set_status(self,active): + self.active = active + self.db.query("UPDATE services SET active = ? WHERE service_id = ?;", active, self.srv.service_id) async def start(self): - if not self.interceptor: - FiregexTables().delete_by_srv(self.srv) - self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int)) - await self._update_filters_from_db() - self._set_status(STATUS.ACTIVE) + if not self.active: + nft.delete(self.srv) + nft.add(self.srv) + self._set_status(True) async def stop(self): - FiregexTables().delete_by_srv(self.srv) - if self.interceptor: - await self.interceptor.stop() - self.interceptor = None + nft.delete(self.srv) async def restart(self): await self.stop() - await self.start() - - async def update_filters(self): - async with self.lock: - await self._update_filters_from_db() \ No newline at end of file + await self.start() \ No newline at end of file diff --git a/backend/modules/porthijack/nftables.py b/backend/modules/porthijack/nftables.py index 3eb57cb..13a2814 100644 --- a/backend/modules/porthijack/nftables.py +++ b/backend/modules/porthijack/nftables.py @@ -1,18 +1,21 @@ from typing import List +from modules.porthijack.models import Service from utils import ip_parse, ip_family, NFTableManager -class FiregexFilter(): - def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None): - self.id = int(id) if id else None - self.queue = queue +class FiregexHijackRule(): + def __init__(self, proto:str, public_port:int,proxy_port:int, ip_int:str, target:str, id:int): + self.id = id self.target = target self.proto = proto - self.port = int(port) + self.public_port = public_port + self.proxy_port = proxy_port self.ip_int = str(ip_int) def __eq__(self, o: object) -> bool: - if isinstance(o, FiregexFilter): - return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) + if isinstance(o, FiregexHijackRule): + return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) + elif isinstance(o, Service): + return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return False class FiregexTables(NFTableManager): @@ -46,8 +49,12 @@ class FiregexTables(NFTableManager): {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.postrouting_porthijack}}} ]) - def add(self, ip_int, proto, public_port, proxy_port): - ip_int = ip_parse(ip_int) + def add(self, srv:Service): + + for ele in self.get(): + if ele.__eq__(srv): return + + ip_int = ip_parse(srv.ip_int) ip_addr = str(ip_int).split("/")[0] ip_addr_cidr = int(str(ip_int).split("/")[1]) self.cmd({ "insert":{ "rule": { @@ -56,8 +63,8 @@ class FiregexTables(NFTableManager): "chain": self.prerouting_porthijack, "expr": [ {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, - {'match': {'left': { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(public_port)}}, - {'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'dport'}}, 'value': int(proxy_port)}} + {'match': {'left': { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.public_port)}}, + {'mangle': {'key': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'value': int(srv.proxy_port)}} ] }}}) self.cmd({ "insert":{ "rule": { @@ -66,30 +73,36 @@ class FiregexTables(NFTableManager): "chain": self.postrouting_porthijack, "expr": [ {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, - {'match': {'left': { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(proxy_port)}}, - {'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'sport'}}, 'value': int(public_port)}} + {'match': {'left': { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.proxy_port)}}, + {'mangle': {'key': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'value': int(srv.public_port)}} ] }}}) - def get(self) -> List[FiregexFilter]: + + def get(self) -> List[FiregexHijackRule]: res = [] - for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]): - queue_str = filter["expr"][2]["queue"]["num"] - queue = None - if isinstance(queue_str,dict): queue = int(queue_str["range"][0]), int(queue_str["range"][1]) - else: queue = int(queue_str), int(queue_str) + for filter in self.list_rules(tables=[self.table_name], chains=[self.prerouting_porthijack,self.postrouting_porthijack]): ip_int = None if isinstance(filter["expr"][0]["match"]["right"],str): ip_int = str(ip_parse(filter["expr"][0]["match"]["right"])) else: ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}' - res.append(FiregexFilter( + res.append(FiregexHijackRule( target=filter["chain"], id=int(filter["handle"]), - queue=queue, proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"], - port=filter["expr"][1]["match"]["right"], + public_port=filter["expr"][1]["match"]["right"] if filter["target"] == self.prerouting_porthijack else filter["expr"][2]["mangle"]["value"], + proxy_port=filter["expr"][1]["match"]["right"] if filter["target"] == self.postrouting_porthijack else filter["expr"][2]["mangle"]["value"], ip_int=ip_int )) return res - \ No newline at end of file + + def delete(self, srv:Service): + for filter in self.get(): + if filter.__eq__(srv): + self.cmd({ "delete":{ "rule": { + "family": "inet", + "table": self.table_name, + "chain": filter.target, + "handle": filter.id + }}}) \ No newline at end of file diff --git a/backend/routers/porthijack.py b/backend/routers/porthijack.py index 3a812e4..b5c7e19 100644 --- a/backend/routers/porthijack.py +++ b/backend/routers/porthijack.py @@ -7,7 +7,7 @@ from utils.sqlite import SQLite from utils import ip_parse, refactor_name, refresh_frontend from utils.models import ResetRequest, StatusMessageModel from modules.porthijack.nftables import FiregexTables -from modules.porthijack.firewall import STATUS, FirewallManager +from modules.porthijack.firewall import FirewallManager class ServiceModel(BaseModel): service_id: str @@ -107,14 +107,14 @@ async def get_service_by_id(service_id: str, ): @app.get('/service/{service_id}/stop', response_model=StatusMessageModel) async def service_stop(service_id: str, ): """Request the stop of a specific service""" - await firewall.get(service_id).next(STATUS.STOP) + await firewall.get(service_id).disable() await refresh_frontend() return {'status': 'ok'} @app.get('/service/{service_id}/start', response_model=StatusMessageModel) async def service_start(service_id: str, ): """Request the start of a specific service""" - await firewall.get(service_id).next(STATUS.ACTIVE) + await firewall.get(service_id).enable() await refresh_frontend() return {'status': 'ok'}