Refactoring code pt.1

This commit is contained in:
nik012003
2022-08-11 16:11:32 +02:00
committed by DomySh
parent 1931536516
commit f4fe3d3ab5
9 changed files with 354 additions and 73 deletions

View File

@@ -141,7 +141,3 @@ class FiregexInterceptor:
except Exception: pass except Exception: pass
return res return res
def delete_by_srv(srv:Service):
for filter in nft.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int):
nft.cmd({"delete":{"rule": {"handle": filter.id, "table": nft.table_name, "chain": filter.target, "family": "inet"}}})

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
from typing import Dict from typing import Dict
from modules.nfregex.firegex import FiregexInterceptor, RegexFilter, delete_by_srv from modules.nfregex.firegex import FiregexInterceptor, RegexFilter
from modules.nfregex.nftables import FiregexTables, FiregexFilter from modules.nfregex.nftables import FiregexTables, FiregexFilter
from modules.nfregex.models import Regex, Service from modules.nfregex.models import Regex, Service
from utils.sqlite import SQLite from utils.sqlite import SQLite
@@ -95,13 +95,13 @@ class ServiceManager:
async def start(self): async def start(self):
if not self.interceptor: if not self.interceptor:
delete_by_srv(self.srv) FiregexTables().delete(self.srv)
self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int)) self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv))
await self._update_filters_from_db() await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE) self._set_status(STATUS.ACTIVE)
async def stop(self): async def stop(self):
delete_by_srv(self.srv) FiregexTables().delete(self.srv)
if self.interceptor: if self.interceptor:
await self.interceptor.stop() await self.interceptor.stop()
self.interceptor = None self.interceptor = None

View File

@@ -1,10 +1,10 @@
from typing import List from typing import List
from modules.nfregex.models import Service
from utils import ip_parse, ip_family, NFTableManager from utils import ip_parse, ip_family, NFTableManager
class FiregexFilter(): class FiregexFilter():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None): def __init__(self, proto:str, port:int, ip_int:str, target:str=None, id=None):
self.id = int(id) if id else None self.id = int(id) if id else None
self.queue = queue
self.target = target self.target = target
self.proto = proto self.proto = proto
self.port = int(port) self.port = int(port)
@@ -46,47 +46,41 @@ class FiregexTables(NFTableManager):
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.output_chain}}}, {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.output_chain}}},
]) ])
def add_output(self, queue_range, proto, port, ip_int): def add(self, srv:Service, queue_range_input, queue_range_output):
init, end = queue_range ip_int = ip_parse(srv.ip_int)
if init > end: init, end = end, init
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0] ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1]) ip_addr_cidr = int(str(ip_int).split("/")[1])
init, end = queue_range_output
if init > end: init, end = end, init
self.cmd({ "insert":{ "rule": { self.cmd({ "insert":{ "rule": {
"family": "inet", "family": "inet",
"table": self.table_name, "table": self.table_name,
"chain": self.output_chain, "chain": self.output_chain,
"expr": [ "expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
{'match': {"left": { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(port)}}, {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
] ]
}}}) }}})
def add_input(self, queue_range, proto = None, port = None, ip_int = None): init, end = queue_range_input
init, end = queue_range
if init > end: init, end = end, init if init > end: init, end = end, init
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
self.cmd({"insert":{"rule":{ self.cmd({"insert":{"rule":{
"family": "inet", "family": "inet",
"table": self.table_name, "table": self.table_name,
"chain": self.input_chain, "chain": self.input_chain,
"expr": [ "expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, {'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
{'match': {"left": { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(port)}}, {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
] ]
}}}) }}})
def get(self) -> List[FiregexFilter]: def get(self) -> List[FiregexFilter]:
res = [] res = []
for filter in [ele["rule"] for ele in self.list() if "rule" in ele and ele["rule"]["table"] == self.table_name]: for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]):
queue_str = filter["expr"][2]["queue"]["num"]
queue = None
if isinstance(queue_str,dict): queue = int(queue_str["range"][0]), int(queue_str["range"][1])
else: queue = int(queue_str), int(queue_str)
ip_int = None ip_int = None
if isinstance(filter["expr"][0]["match"]["right"],str): if isinstance(filter["expr"][0]["match"]["right"],str):
ip_int = str(ip_parse(filter["expr"][0]["match"]["right"])) ip_int = str(ip_parse(filter["expr"][0]["match"]["right"]))
@@ -95,10 +89,14 @@ class FiregexTables(NFTableManager):
res.append(FiregexFilter( res.append(FiregexFilter(
target=filter["chain"], target=filter["chain"],
id=int(filter["handle"]), id=int(filter["handle"]),
queue=queue,
proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"], proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"],
port=filter["expr"][1]["match"]["right"], port=filter["expr"][1]["match"]["right"],
ip_int=ip_int ip_int=ip_int
)) ))
return res return res
def delete(self, srv:Service):
for filter in self.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int):
self.cmd({"delete":{"rule": {"handle": filter.id, "table": self.table_name, "chain": filter.target, "family": "inet"}}})

View File

View File

@@ -0,0 +1,114 @@
import asyncio
from typing import Dict
from modules.porthijack.nftables import FiregexTables, FiregexFilter
from modules.porthijack.models import Service
from utils.sqlite import SQLite
class STATUS:
STOP = "stop"
ACTIVE = "active"
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.proxy_table:
await self.proxy_table[srv_id].next(STATUS.STOP)
del self.proxy_table[srv_id]
async def init(self):
FiregexTables().init()
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.id in self.proxy_table:
continue
self.proxy_table[srv.id] = ServiceManager(srv, self.db)
await self.proxy_table[srv.id].next(srv.status)
def get(self,srv_id):
if srv_id in self.proxy_table:
return self.proxy_table[srv_id]
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass
class ServiceManager:
def __init__(self, srv: Service, db):
self.srv = srv
self.db = db
self.status = STATUS.STOP
self.filters: Dict[int, FiregexFilter] = {}
self.lock = asyncio.Lock()
self.interceptor = None
async def _update_filters_from_db(self):
regexes = [
Regex.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
#add new filters
for f in new_filters:
if not f in old_filters:
filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
if self.interceptor: await self.interceptor.reload(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id)
async def next(self,to):
async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
await self.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
await self.restart()
def _stats_updater(self,filter:RegexFilter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
async def start(self):
if not self.interceptor:
FiregexTables().delete_by_srv(self.srv)
self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int))
await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE)
async def stop(self):
FiregexTables().delete_by_srv(self.srv)
if self.interceptor:
await self.interceptor.stop()
self.interceptor = None
async def restart(self):
await self.stop()
await self.start()
async def update_filters(self):
async with self.lock:
await self._update_filters_from_db()

View File

@@ -0,0 +1,13 @@
class Service:
def __init__(self, service_id: str, active: bool, public_port: int, proxy_port: int, name: str, proto: str, ip_int: str):
self.service_id = service_id
self.active = active
self.public_port = public_port
self.proxy_port = proxy_port
self.name = name
self.proto = proto
self.ip_int = ip_int
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["service_id"], active=var["active"], public_port=var["public_port"], proxy_port=var["proxy_port"], name=var["name"], proto=var["proto"], ip_int=var["ip_int"])

View File

@@ -1,49 +1,19 @@
from typing import List
from utils import ip_parse, ip_family, NFTableManager
from ipaddress import ip_interface class FiregexFilter():
import nftables, traceback def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None):
self.id = int(id) if id else None
def ip_parse(ip:str): self.queue = queue
return str(ip_interface(ip).network) self.target = target
self.proto = proto
def ip_family(ip:str): self.port = int(port)
return "ip6" if ip_interface(ip).version == 6 else "ip" self.ip_int = str(ip_int)
class Singleton(object):
__instance = None
def __new__(class_, *args, **kwargs):
if not isinstance(class_.__instance, class_):
class_.__instance = object.__new__(class_, *args, **kwargs)
return class_.__instance
class NFTableManager(Singleton):
table_name = "firegex"
def __init__(self, init_cmd, reset_cmd):
self.__init_cmds = init_cmd
self.__reset_cmds = reset_cmd
self.nft = nftables.Nftables()
def raw_cmd(self, *cmds):
return self.nft.json_cmd({"nftables": list(cmds)})
def cmd(self, *cmds):
code, out, err = self.raw_cmd(*cmds)
if code == 0: return out
else: raise Exception(err)
def init(self):
self.reset()
self.raw_cmd({"add":{"table":{"name":self.table_name,"family":"inet"}}})
self.cmd(*self.__init_cmds)
def reset(self):
self.raw_cmd(*self.__reset_cmds)
def list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"]
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False
class FiregexTables(NFTableManager): class FiregexTables(NFTableManager):
prerouting_porthijack = "prerouting_porthijack" prerouting_porthijack = "prerouting_porthijack"
@@ -100,3 +70,26 @@ class FiregexTables(NFTableManager):
{'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'sport'}}, 'value': int(public_port)}} {'mangle': {'key': {'payload': {'protocol': str(proto), 'field': 'sport'}}, 'value': int(public_port)}}
] ]
}}}) }}})
def get(self) -> List[FiregexFilter]:
res = []
for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]):
queue_str = filter["expr"][2]["queue"]["num"]
queue = None
if isinstance(queue_str,dict): queue = int(queue_str["range"][0]), int(queue_str["range"][1])
else: queue = int(queue_str), int(queue_str)
ip_int = None
if isinstance(filter["expr"][0]["match"]["right"],str):
ip_int = str(ip_parse(filter["expr"][0]["match"]["right"]))
else:
ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}'
res.append(FiregexFilter(
target=filter["chain"],
id=int(filter["handle"]),
queue=queue,
proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"],
port=filter["expr"][1]["match"]["right"],
ip_int=ip_int
))
return res

View File

@@ -0,0 +1,160 @@
import secrets
import sqlite3
from typing import List, Union
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from utils.sqlite import SQLite
from utils import ip_parse, refactor_name, refresh_frontend
from utils.models import ResetRequest, StatusMessageModel
from modules.porthijack.nftables import FiregexTables
from modules.porthijack.firewall import STATUS, FirewallManager
class ServiceModel(BaseModel):
service_id: str
active: bool
public_port: int
proxy_port: int
name: str
proto: str
ip_int: str
class RenameForm(BaseModel):
name:str
class ServiceAddForm(BaseModel):
name: str
public_port: int
proxy_port: int
proto: str
ip_int: str
class ServiceAddResponse(BaseModel):
status:str
service_id: Union[None,str]
class GeneralStatModel(BaseModel):
services: int
app = APIRouter()
db = SQLite('db/port-hijacking.db', {
'services': {
'service_id': 'VARCHAR(100) PRIMARY KEY',
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1))',
'public_port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'proxy_port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'name': 'VARCHAR(100) NOT NULL UNIQUE',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))',
'ip_int': 'VARCHAR(100) NOT NULL',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (public_port, ip_int, proto);",
]
})
async def reset(params: ResetRequest):
if not params.delete:
db.backup()
await firewall.close()
FiregexTables().reset()
if params.delete:
db.delete()
db.init()
else:
db.restore()
await firewall.init()
async def startup():
db.init()
await firewall.init()
async def shutdown():
db.backup()
await firewall.close()
db.disconnect()
db.restore()
def gen_service_id():
while True:
res = secrets.token_hex(8)
if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
break
return res
firewall = FirewallManager(db)
@app.get('/stats', response_model=GeneralStatModel)
async def get_general_stats():
"""Get firegex general status about services"""
return db.query("""
SELECT
(SELECT COUNT(*) FROM services) services
""")[0]
@app.get('/services', response_model=List[ServiceModel])
async def get_service_list():
"""Get the list of existent firegex services"""
return db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_int FROM services;")
@app.get('/service/{service_id}', response_model=ServiceModel)
async def get_service_by_id(service_id: str, ):
"""Get info about a specific service using his id"""
res = db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_int FROM services WHERE service_id = ?;", service_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
return res[0]
@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
async def service_stop(service_id: str, ):
"""Request the stop of a specific service"""
await firewall.get(service_id).next(STATUS.STOP)
await refresh_frontend()
return {'status': 'ok'}
@app.get('/service/{service_id}/start', response_model=StatusMessageModel)
async def service_start(service_id: str, ):
"""Request the start of a specific service"""
await firewall.get(service_id).next(STATUS.ACTIVE)
await refresh_frontend()
return {'status': 'ok'}
@app.get('/service/{service_id}/delete', response_model=StatusMessageModel)
async def service_delete(service_id: str, ):
"""Request the deletion of a specific service"""
db.query('DELETE FROM services WHERE service_id = ?;', service_id)
await firewall.remove(service_id)
await refresh_frontend()
return {'status': 'ok'}
@app.post('/service/{service_id}/rename', response_model=StatusMessageModel)
async def service_rename(service_id: str, form: RenameForm, ):
"""Request to change the name of a specific service"""
form.name = refactor_name(form.name)
if not form.name: return {'status': 'The name cannot be empty!'}
try:
db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
except sqlite3.IntegrityError:
return {'status': 'This name is already used'}
await refresh_frontend()
return {'status': 'ok'}
@app.post('/services/add', response_model=ServiceAddResponse)
async def add_new_service(form: ServiceAddForm, ):
"""Add a new service"""
try:
form.ip_int = ip_parse(form.ip_int)
except ValueError:
return {"status":"Invalid address"}
if form.proto not in ["tcp", "udp"]:
return {"status":"Invalid protocol"}
srv_id = None
try:
srv_id = gen_service_id()
db.query("INSERT INTO services (service_id, active, public_port, proxy_port, name, proto, ip_int) VALUES (?, ?, ?, ?, ?, ?, ?)",
srv_id, False, form.public_port, form.proxy_port , form.name, form.proto, form.ip_int)
except sqlite3.IntegrityError:
return {'status': 'This type of service already exists'}
await firewall.reload()
await refresh_frontend()
return {'status': 'ok', 'service_id': srv_id}

View File

@@ -81,6 +81,13 @@ class NFTableManager(Singleton):
def reset(self): def reset(self):
self.raw_cmd(*self.__reset_cmds) self.raw_cmd(*self.__reset_cmds)
def list(self): def list_rules(self, tables = None, chains = None):
for filter in [ele["rule"] for ele in self.raw_list() if "rule" in ele ]:
if tables and filter["table"] not in tables: continue
if chains and filter["chain"] not in chains: continue
yield filter
def raw_list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"] return self.cmd({"list": {"ruleset": None}})["nftables"]