Code refactoring and adding port-hijacking backup commit
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import Dict, List, Set
|
||||
from modules.nfregex.nftables import FiregexFilter, FiregexTables
|
||||
from modules.nfregex.nftables import FiregexTables
|
||||
from utils import ip_parse, run_func
|
||||
from modules.nfregex.models import Service, Regex
|
||||
import re, os, asyncio
|
||||
@@ -54,7 +54,7 @@ class RegexFilter:
|
||||
class FiregexInterceptor:
|
||||
|
||||
def __init__(self):
|
||||
self.filter:FiregexFilter
|
||||
self.srv:Service
|
||||
self.filter_map_lock:asyncio.Lock
|
||||
self.filter_map: Dict[str, RegexFilter]
|
||||
self.regex_filters: Set[RegexFilter]
|
||||
@@ -63,16 +63,14 @@ class FiregexInterceptor:
|
||||
self.update_task: asyncio.Task
|
||||
|
||||
@classmethod
|
||||
async def start(cls, filter: FiregexFilter):
|
||||
async def start(cls, srv: Service):
|
||||
self = cls()
|
||||
self.filter = filter
|
||||
self.srv = srv
|
||||
self.filter_map_lock = asyncio.Lock()
|
||||
self.update_config_lock = asyncio.Lock()
|
||||
input_range, output_range = await self._start_binary()
|
||||
self.update_task = asyncio.create_task(self.update_blocked())
|
||||
if not filter in nft.get():
|
||||
nft.add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
|
||||
nft.add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
|
||||
nft.add(self.srv, input_range, output_range)
|
||||
return self
|
||||
|
||||
async def _start_binary(self):
|
||||
|
||||
@@ -9,38 +9,40 @@ class STATUS:
|
||||
STOP = "stop"
|
||||
ACTIVE = "active"
|
||||
|
||||
nft = FiregexTables()
|
||||
|
||||
class FirewallManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.proxy_table: Dict[str, ServiceManager] = {}
|
||||
self.service_table: Dict[str, ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.proxy_table.keys()):
|
||||
for key in list(self.service_table.keys()):
|
||||
await self.remove(key)
|
||||
|
||||
async def remove(self,srv_id):
|
||||
async with self.lock:
|
||||
if srv_id in self.proxy_table:
|
||||
await self.proxy_table[srv_id].next(STATUS.STOP)
|
||||
del self.proxy_table[srv_id]
|
||||
if srv_id in self.service_table:
|
||||
await self.service_table[srv_id].next(STATUS.STOP)
|
||||
del self.service_table[srv_id]
|
||||
|
||||
async def init(self):
|
||||
FiregexTables().init()
|
||||
nft.init()
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
for srv in self.db.query('SELECT * FROM services;'):
|
||||
srv = Service.from_dict(srv)
|
||||
if srv.id in self.proxy_table:
|
||||
if srv.id in self.service_table:
|
||||
continue
|
||||
self.proxy_table[srv.id] = ServiceManager(srv, self.db)
|
||||
await self.proxy_table[srv.id].next(srv.status)
|
||||
self.service_table[srv.id] = ServiceManager(srv, self.db)
|
||||
await self.service_table[srv.id].next(srv.status)
|
||||
|
||||
def get(self,srv_id):
|
||||
if srv_id in self.proxy_table:
|
||||
return self.proxy_table[srv_id]
|
||||
if srv_id in self.service_table:
|
||||
return self.service_table[srv_id]
|
||||
else:
|
||||
raise ServiceNotFoundException()
|
||||
|
||||
@@ -95,13 +97,13 @@ class ServiceManager:
|
||||
|
||||
async def start(self):
|
||||
if not self.interceptor:
|
||||
FiregexTables().delete(self.srv)
|
||||
self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv))
|
||||
nft.delete(self.srv)
|
||||
self.interceptor = await FiregexInterceptor.start(self.srv)
|
||||
await self._update_filters_from_db()
|
||||
self._set_status(STATUS.ACTIVE)
|
||||
|
||||
async def stop(self):
|
||||
FiregexTables().delete(self.srv)
|
||||
nft.delete(self.srv)
|
||||
if self.interceptor:
|
||||
await self.interceptor.stop()
|
||||
self.interceptor = None
|
||||
|
||||
@@ -2,9 +2,9 @@ from typing import List
|
||||
from modules.nfregex.models import Service
|
||||
from utils import ip_parse, ip_family, NFTableManager
|
||||
|
||||
class FiregexFilter():
|
||||
def __init__(self, proto:str, port:int, ip_int:str, target:str=None, id=None):
|
||||
self.id = int(id) if id else None
|
||||
class FiregexFilter:
|
||||
def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int):
|
||||
self.id = id
|
||||
self.target = target
|
||||
self.proto = proto
|
||||
self.port = int(port)
|
||||
@@ -13,6 +13,8 @@ class FiregexFilter():
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, FiregexFilter):
|
||||
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
|
||||
elif isinstance(o, Service):
|
||||
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
|
||||
return False
|
||||
|
||||
class FiregexTables(NFTableManager):
|
||||
@@ -47,10 +49,14 @@ class FiregexTables(NFTableManager):
|
||||
])
|
||||
|
||||
def add(self, srv:Service, queue_range_input, queue_range_output):
|
||||
|
||||
for ele in self.get():
|
||||
if ele.__eq__(srv): return
|
||||
|
||||
ip_int = ip_parse(srv.ip_int)
|
||||
ip_addr = str(ip_int).split("/")[0]
|
||||
ip_addr_cidr = int(str(ip_int).split("/")[1])
|
||||
|
||||
|
||||
init, end = queue_range_output
|
||||
if init > end: init, end = end, init
|
||||
self.cmd({ "insert":{ "rule": {
|
||||
@@ -97,6 +103,11 @@ class FiregexTables(NFTableManager):
|
||||
|
||||
def delete(self, srv:Service):
|
||||
for filter in self.get():
|
||||
if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int):
|
||||
self.cmd({"delete":{"rule": {"handle": filter.id, "table": self.table_name, "chain": filter.target, "family": "inet"}}})
|
||||
if filter.__eq__(srv):
|
||||
self.cmd({ "delete":{ "rule": {
|
||||
"family": "inet",
|
||||
"table": self.table_name,
|
||||
"chain": filter.target,
|
||||
"handle": filter.id
|
||||
}}})
|
||||
|
||||
Reference in New Issue
Block a user