Code refactoring and adding port-hijacking backup commit

This commit is contained in:
DomySh
2022-08-11 15:16:23 +00:00
parent f4fe3d3ab5
commit e6b4ddd4a0
6 changed files with 108 additions and 117 deletions

View File

@@ -1,5 +1,5 @@
from typing import Dict, List, Set from typing import Dict, List, Set
from modules.nfregex.nftables import FiregexFilter, FiregexTables from modules.nfregex.nftables import FiregexTables
from utils import ip_parse, run_func from utils import ip_parse, run_func
from modules.nfregex.models import Service, Regex from modules.nfregex.models import Service, Regex
import re, os, asyncio import re, os, asyncio
@@ -54,7 +54,7 @@ class RegexFilter:
class FiregexInterceptor: class FiregexInterceptor:
def __init__(self): def __init__(self):
self.filter:FiregexFilter self.srv:Service
self.filter_map_lock:asyncio.Lock self.filter_map_lock:asyncio.Lock
self.filter_map: Dict[str, RegexFilter] self.filter_map: Dict[str, RegexFilter]
self.regex_filters: Set[RegexFilter] self.regex_filters: Set[RegexFilter]
@@ -63,16 +63,14 @@ class FiregexInterceptor:
self.update_task: asyncio.Task self.update_task: asyncio.Task
@classmethod @classmethod
async def start(cls, filter: FiregexFilter): async def start(cls, srv: Service):
self = cls() self = cls()
self.filter = filter self.srv = srv
self.filter_map_lock = asyncio.Lock() self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock() self.update_config_lock = asyncio.Lock()
input_range, output_range = await self._start_binary() input_range, output_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_blocked()) self.update_task = asyncio.create_task(self.update_blocked())
if not filter in nft.get(): nft.add(self.srv, input_range, output_range)
nft.add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
nft.add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
return self return self
async def _start_binary(self): async def _start_binary(self):

View File

@@ -9,38 +9,40 @@ class STATUS:
STOP = "stop" STOP = "stop"
ACTIVE = "active" ACTIVE = "active"
nft = FiregexTables()
class FirewallManager: class FirewallManager:
def __init__(self, db:SQLite): def __init__(self, db:SQLite):
self.db = db self.db = db
self.proxy_table: Dict[str, ServiceManager] = {} self.service_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
async def close(self): async def close(self):
for key in list(self.proxy_table.keys()): for key in list(self.service_table.keys()):
await self.remove(key) await self.remove(key)
async def remove(self,srv_id): async def remove(self,srv_id):
async with self.lock: async with self.lock:
if srv_id in self.proxy_table: if srv_id in self.service_table:
await self.proxy_table[srv_id].next(STATUS.STOP) await self.service_table[srv_id].next(STATUS.STOP)
del self.proxy_table[srv_id] del self.service_table[srv_id]
async def init(self): async def init(self):
FiregexTables().init() nft.init()
await self.reload() await self.reload()
async def reload(self): async def reload(self):
async with self.lock: async with self.lock:
for srv in self.db.query('SELECT * FROM services;'): for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv) srv = Service.from_dict(srv)
if srv.id in self.proxy_table: if srv.id in self.service_table:
continue continue
self.proxy_table[srv.id] = ServiceManager(srv, self.db) self.service_table[srv.id] = ServiceManager(srv, self.db)
await self.proxy_table[srv.id].next(srv.status) await self.service_table[srv.id].next(srv.status)
def get(self,srv_id): def get(self,srv_id):
if srv_id in self.proxy_table: if srv_id in self.service_table:
return self.proxy_table[srv_id] return self.service_table[srv_id]
else: else:
raise ServiceNotFoundException() raise ServiceNotFoundException()
@@ -95,13 +97,13 @@ class ServiceManager:
async def start(self): async def start(self):
if not self.interceptor: if not self.interceptor:
FiregexTables().delete(self.srv) nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv)) self.interceptor = await FiregexInterceptor.start(self.srv)
await self._update_filters_from_db() await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE) self._set_status(STATUS.ACTIVE)
async def stop(self): async def stop(self):
FiregexTables().delete(self.srv) nft.delete(self.srv)
if self.interceptor: if self.interceptor:
await self.interceptor.stop() await self.interceptor.stop()
self.interceptor = None self.interceptor = None

View File

@@ -2,9 +2,9 @@ from typing import List
from modules.nfregex.models import Service from modules.nfregex.models import Service
from utils import ip_parse, ip_family, NFTableManager from utils import ip_parse, ip_family, NFTableManager
class FiregexFilter(): class FiregexFilter:
def __init__(self, proto:str, port:int, ip_int:str, target:str=None, id=None): def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int):
self.id = int(id) if id else None self.id = id
self.target = target self.target = target
self.proto = proto self.proto = proto
self.port = int(port) self.port = int(port)
@@ -13,6 +13,8 @@ class FiregexFilter():
def __eq__(self, o: object) -> bool: def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter): if isinstance(o, FiregexFilter):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
elif isinstance(o, Service):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False return False
class FiregexTables(NFTableManager): class FiregexTables(NFTableManager):
@@ -47,6 +49,10 @@ class FiregexTables(NFTableManager):
]) ])
def add(self, srv:Service, queue_range_input, queue_range_output): def add(self, srv:Service, queue_range_input, queue_range_output):
for ele in self.get():
if ele.__eq__(srv): return
ip_int = ip_parse(srv.ip_int) ip_int = ip_parse(srv.ip_int)
ip_addr = str(ip_int).split("/")[0] ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1]) ip_addr_cidr = int(str(ip_int).split("/")[1])
@@ -97,6 +103,11 @@ class FiregexTables(NFTableManager):
def delete(self, srv:Service): def delete(self, srv:Service):
for filter in self.get(): for filter in self.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int): if filter.__eq__(srv):
self.cmd({"delete":{"rule": {"handle": filter.id, "table": self.table_name, "chain": filter.target, "family": "inet"}}}) self.cmd({ "delete":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": filter.target,
"handle": filter.id
}}})

View File

@@ -1,28 +1,27 @@
from ast import Delete
import asyncio import asyncio
from typing import Dict from typing import Dict
from modules.porthijack.nftables import FiregexTables, FiregexFilter from modules.porthijack.nftables import FiregexTables, FiregexFilter
from modules.porthijack.models import Service from modules.porthijack.models import Service
from utils.sqlite import SQLite from utils.sqlite import SQLite
class STATUS: nft = FiregexTables()
STOP = "stop"
ACTIVE = "active"
class FirewallManager: class FirewallManager:
def __init__(self, db:SQLite): def __init__(self, db:SQLite):
self.db = db self.db = db
self.proxy_table: Dict[str, ServiceManager] = {} self.service_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
async def close(self): async def close(self):
for key in list(self.proxy_table.keys()): for key in list(self.service_table.keys()):
await self.remove(key) await self.remove(key)
async def remove(self,srv_id): async def remove(self,srv_id):
async with self.lock: async with self.lock:
if srv_id in self.proxy_table: if srv_id in self.service_table:
await self.proxy_table[srv_id].next(STATUS.STOP) await self.service_table[srv_id].disable()
del self.proxy_table[srv_id] del self.service_table[srv_id]
async def init(self): async def init(self):
FiregexTables().init() FiregexTables().init()
@@ -32,14 +31,15 @@ class FirewallManager:
async with self.lock: async with self.lock:
for srv in self.db.query('SELECT * FROM services;'): for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv) srv = Service.from_dict(srv)
if srv.id in self.proxy_table: if srv.service_id in self.service_table:
continue continue
self.proxy_table[srv.id] = ServiceManager(srv, self.db) self.service_table[srv.service_id] = ServiceManager(srv, self.db)
await self.proxy_table[srv.id].next(srv.status) if srv.active:
await self.service_table[srv.service_id].enable()
def get(self,srv_id): def get(self,srv_id):
if srv_id in self.proxy_table: if srv_id in self.service_table:
return self.proxy_table[srv_id] return self.service_table[srv_id]
else: else:
raise ServiceNotFoundException() raise ServiceNotFoundException()
@@ -49,66 +49,33 @@ class ServiceManager:
def __init__(self, srv: Service, db): def __init__(self, srv: Service, db):
self.srv = srv self.srv = srv
self.db = db self.db = db
self.status = STATUS.STOP self.active = False
self.filters: Dict[int, FiregexFilter] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.interceptor = None
async def _update_filters_from_db(self): async def enable(self,to):
regexes = [ if (self.status != to):
Regex.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
#add new filters
for f in new_filters:
if not f in old_filters:
filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
if self.interceptor: await self.interceptor.reload(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id)
async def next(self,to):
async with self.lock: async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
await self.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
await self.restart() await self.restart()
def _stats_updater(self,filter:RegexFilter): async def disable(self,to):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id) if (self.status != to):
async with self.lock:
await self.stop()
self._set_status(to)
def _set_status(self,status): def _set_status(self,active):
self.status = status self.active = active
self.__update_status_db(status) self.db.query("UPDATE services SET active = ? WHERE service_id = ?;", active, self.srv.service_id)
async def start(self): async def start(self):
if not self.interceptor: if not self.active:
FiregexTables().delete_by_srv(self.srv) nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int)) nft.add(self.srv)
await self._update_filters_from_db() self._set_status(True)
self._set_status(STATUS.ACTIVE)
async def stop(self): async def stop(self):
FiregexTables().delete_by_srv(self.srv) nft.delete(self.srv)
if self.interceptor:
await self.interceptor.stop()
self.interceptor = None
async def restart(self): async def restart(self):
await self.stop() await self.stop()
await self.start() await self.start()
async def update_filters(self):
async with self.lock:
await self._update_filters_from_db()

View File

@@ -1,18 +1,21 @@
from typing import List from typing import List
from modules.porthijack.models import Service
from utils import ip_parse, ip_family, NFTableManager from utils import ip_parse, ip_family, NFTableManager
class FiregexFilter(): class FiregexHijackRule():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None): def __init__(self, proto:str, public_port:int,proxy_port:int, ip_int:str, target:str, id:int):
self.id = int(id) if id else None self.id = id
self.queue = queue
self.target = target self.target = target
self.proto = proto self.proto = proto
self.port = int(port) self.public_port = public_port
self.proxy_port = proxy_port
self.ip_int = str(ip_int) self.ip_int = str(ip_int)
def __eq__(self, o: object) -> bool: def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter): if isinstance(o, FiregexHijackRule):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
elif isinstance(o, Service):
return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False return False
class FiregexTables(NFTableManager): class FiregexTables(NFTableManager):
@@ -46,8 +49,12 @@ class FiregexTables(NFTableManager):
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.postrouting_porthijack}}} {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.postrouting_porthijack}}}
]) ])
def add(self, ip_int, proto, public_port, proxy_port): def add(self, srv:Service):
ip_int = ip_parse(ip_int)
for ele in self.get():
if ele.__eq__(srv): return
ip_int = ip_parse(srv.ip_int)
ip_addr = str(ip_int).split("/")[0] ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1]) ip_addr_cidr = int(str(ip_int).split("/")[1])
self.cmd({ "insert":{ "rule": { self.cmd({ "insert":{ "rule": {
@@ -56,8 +63,8 @@ class FiregexTables(NFTableManager):
"chain": self.prerouting_porthijack, "chain": self.prerouting_porthijack,
"expr": [ "expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
{'match': {'left': { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(public_port)}}, {'match': {'left': { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.public_port)}},
{'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'dport'}}, 'value': int(proxy_port)}} {'mangle': {'key': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'value': int(srv.proxy_port)}}
] ]
}}}) }}})
self.cmd({ "insert":{ "rule": { self.cmd({ "insert":{ "rule": {
@@ -66,30 +73,36 @@ class FiregexTables(NFTableManager):
"chain": self.postrouting_porthijack, "chain": self.postrouting_porthijack,
"expr": [ "expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
{'match': {'left': { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(proxy_port)}}, {'match': {'left': { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.proxy_port)}},
{'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'sport'}}, 'value': int(public_port)}} {'mangle': {'key': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'value': int(srv.public_port)}}
] ]
}}}) }}})
def get(self) -> List[FiregexFilter]:
def get(self) -> List[FiregexHijackRule]:
res = [] res = []
for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]): for filter in self.list_rules(tables=[self.table_name], chains=[self.prerouting_porthijack,self.postrouting_porthijack]):
queue_str = filter["expr"][2]["queue"]["num"]
queue = None
if isinstance(queue_str,dict): queue = int(queue_str["range"][0]), int(queue_str["range"][1])
else: queue = int(queue_str), int(queue_str)
ip_int = None ip_int = None
if isinstance(filter["expr"][0]["match"]["right"],str): if isinstance(filter["expr"][0]["match"]["right"],str):
ip_int = str(ip_parse(filter["expr"][0]["match"]["right"])) ip_int = str(ip_parse(filter["expr"][0]["match"]["right"]))
else: else:
ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}' ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}'
res.append(FiregexFilter( res.append(FiregexHijackRule(
target=filter["chain"], target=filter["chain"],
id=int(filter["handle"]), id=int(filter["handle"]),
queue=queue,
proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"], proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"],
port=filter["expr"][1]["match"]["right"], public_port=filter["expr"][1]["match"]["right"] if filter["target"] == self.prerouting_porthijack else filter["expr"][2]["mangle"]["value"],
proxy_port=filter["expr"][1]["match"]["right"] if filter["target"] == self.postrouting_porthijack else filter["expr"][2]["mangle"]["value"],
ip_int=ip_int ip_int=ip_int
)) ))
return res return res
def delete(self, srv:Service):
for filter in self.get():
if filter.__eq__(srv):
self.cmd({ "delete":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": filter.target,
"handle": filter.id
}}})

View File

@@ -7,7 +7,7 @@ from utils.sqlite import SQLite
from utils import ip_parse, refactor_name, refresh_frontend from utils import ip_parse, refactor_name, refresh_frontend
from utils.models import ResetRequest, StatusMessageModel from utils.models import ResetRequest, StatusMessageModel
from modules.porthijack.nftables import FiregexTables from modules.porthijack.nftables import FiregexTables
from modules.porthijack.firewall import STATUS, FirewallManager from modules.porthijack.firewall import FirewallManager
class ServiceModel(BaseModel): class ServiceModel(BaseModel):
service_id: str service_id: str
@@ -107,14 +107,14 @@ async def get_service_by_id(service_id: str, ):
@app.get('/service/{service_id}/stop', response_model=StatusMessageModel) @app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
async def service_stop(service_id: str, ): async def service_stop(service_id: str, ):
"""Request the stop of a specific service""" """Request the stop of a specific service"""
await firewall.get(service_id).next(STATUS.STOP) await firewall.get(service_id).disable()
await refresh_frontend() await refresh_frontend()
return {'status': 'ok'} return {'status': 'ok'}
@app.get('/service/{service_id}/start', response_model=StatusMessageModel) @app.get('/service/{service_id}/start', response_model=StatusMessageModel)
async def service_start(service_id: str, ): async def service_start(service_id: str, ):
"""Request the start of a specific service""" """Request the start of a specific service"""
await firewall.get(service_id).next(STATUS.ACTIVE) await firewall.get(service_id).enable()
await refresh_frontend() await refresh_frontend()
return {'status': 'ok'} return {'status': 'ok'}