adding firewall function to firegex!
This commit is contained in:
0
backend/modules/firewall/__init__.py
Normal file
0
backend/modules/firewall/__init__.py
Normal file
24
backend/modules/firewall/firewall.py
Normal file
24
backend/modules/firewall/firewall.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import asyncio
|
||||
from modules.firewall.nftables import FiregexTables
|
||||
from modules.firewall.models import Rule
|
||||
from utils.sqlite import SQLite
|
||||
|
||||
nft = FiregexTables()
|
||||
|
||||
class FirewallManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
async with self.lock:
|
||||
nft.reset()
|
||||
|
||||
async def init(self):
|
||||
FiregexTables().init()
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
nft.set(map(Rule.from_dict, self.db.query('SELECT * FROM rules WHERE active = 1 ORDER BY rule_id;')))
|
||||
|
||||
33
backend/modules/firewall/models.py
Normal file
33
backend/modules/firewall/models.py
Normal file
@@ -0,0 +1,33 @@
|
||||
class Rule:
|
||||
def __init__(self, rule_id: int, name: str, active: bool, proto: str, ip_src:str, ip_dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str):
|
||||
self.rule_id = rule_id
|
||||
self.active = active
|
||||
self.name = name
|
||||
self.proto = proto
|
||||
self.ip_src = ip_src
|
||||
self.ip_dst = ip_dst
|
||||
self.port_src_from = port_src_from
|
||||
self.port_dst_from = port_dst_from
|
||||
self.port_src_to = port_src_to
|
||||
self.port_dst_to = port_dst_to
|
||||
self.action = action
|
||||
self.input_mode = mode in ["I"]
|
||||
self.output_mode = mode in ["O"]
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, var: dict):
|
||||
return cls(
|
||||
rule_id=var["rule_id"],
|
||||
active=var["active"],
|
||||
name=var["name"],
|
||||
proto=var["proto"],
|
||||
ip_src=var["ip_src"],
|
||||
ip_dst=var["ip_dst"],
|
||||
port_dst_from=var["port_dst_from"],
|
||||
port_dst_to=var["port_dst_to"],
|
||||
port_src_from=var["port_src_from"],
|
||||
port_src_to=var["port_src_to"],
|
||||
action=var["action"],
|
||||
mode=var["mode"]
|
||||
)
|
||||
88
backend/modules/firewall/nftables.py
Normal file
88
backend/modules/firewall/nftables.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from modules.firewall.models import Rule
|
||||
from utils import nftables_int_to_json, ip_parse, ip_family, NFTableManager, nftables_json_to_int
|
||||
|
||||
|
||||
class FiregexHijackRule():
|
||||
def __init__(self, proto:str, ip_src:str, ip_dst:str, port_src_from:int, port_dst_from:int, port_src_to:int, port_dst_to:int, action:str, target:str, id:int):
|
||||
self.id = id
|
||||
self.target = target
|
||||
self.proto = proto
|
||||
self.ip_src = ip_src
|
||||
self.ip_dst = ip_dst
|
||||
self.port_src_from = min(port_src_from, port_src_to)
|
||||
self.port_dst_from = min(port_dst_from, port_dst_to)
|
||||
self.port_src_to = max(port_src_from, port_src_to)
|
||||
self.port_dst_to = max(port_dst_from, port_dst_to)
|
||||
self.action = action
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, FiregexHijackRule) or isinstance(o, Rule):
|
||||
return self.action == o.action and self.proto == o.proto and\
|
||||
ip_parse(self.ip_src) == ip_parse(o.ip_src) and ip_parse(self.ip_dst) == ip_parse(o.ip_dst) and\
|
||||
int(self.port_src_from) == int(o.port_src_from) and int(self.port_dst_from) == int(o.port_dst_from) and\
|
||||
int(self.port_src_to) == int(o.port_src_to) and int(self.port_dst_to) == int(o.port_dst_to)
|
||||
return False
|
||||
|
||||
|
||||
class FiregexTables(NFTableManager):
|
||||
rules_chain_in = "firewall_rules_in"
|
||||
rules_chain_out = "firewall_rules_out"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__([
|
||||
{"add":{"chain":{
|
||||
"family":"inet",
|
||||
"table":self.table_name,
|
||||
"name":self.rules_chain_in,
|
||||
"type":"filter",
|
||||
"hook":"prerouting",
|
||||
"prio":-300,
|
||||
"policy":"accept"
|
||||
}}},
|
||||
{"add":{"chain":{
|
||||
"family":"inet",
|
||||
"table":self.table_name,
|
||||
"name":self.rules_chain_out,
|
||||
"type":"filter",
|
||||
"hook":"postrouting",
|
||||
"prio":-300,
|
||||
"policy":"accept"
|
||||
}}},
|
||||
],[
|
||||
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
|
||||
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
|
||||
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
|
||||
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
|
||||
])
|
||||
|
||||
def delete_all(self):
|
||||
self.cmd(
|
||||
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
|
||||
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
|
||||
)
|
||||
|
||||
def set(self, srv:list[Rule]):
|
||||
self.delete_all()
|
||||
for ele in srv: self.add(ele)
|
||||
|
||||
def add(self, srv:Rule):
|
||||
port_filters = []
|
||||
if srv.proto != "any":
|
||||
if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port
|
||||
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}})
|
||||
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}})
|
||||
if srv.port_dst_from != 1 or srv.port_dst_to != 65535: #Any Port
|
||||
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '>=', 'right': int(srv.port_dst_from)}})
|
||||
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '<=', 'right': int(srv.port_dst_to)}})
|
||||
if len(port_filters) == 0:
|
||||
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '!=', 'right': 0}}) #filter the protocol if no port is specified
|
||||
|
||||
self.cmd({ "insert":{ "rule": {
|
||||
"family": "inet",
|
||||
"table": self.table_name,
|
||||
"chain": self.rules_chain_out if srv.output_mode else self.rules_chain_in,
|
||||
"expr": [
|
||||
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_src), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_src)}},
|
||||
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_dst), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_dst)}},
|
||||
] + port_filters + [{'accept': None} if srv.action == "accept" else {'reject': {}} if srv.action == "reject" else {'drop': None}]
|
||||
}}})
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import Dict, List, Set
|
||||
from modules.nfregex.nftables import FiregexTables
|
||||
from utils import ip_parse, run_func
|
||||
from modules.nfregex.models import Service, Regex
|
||||
@@ -56,8 +55,8 @@ class FiregexInterceptor:
|
||||
def __init__(self):
|
||||
self.srv:Service
|
||||
self.filter_map_lock:asyncio.Lock
|
||||
self.filter_map: Dict[str, RegexFilter]
|
||||
self.regex_filters: Set[RegexFilter]
|
||||
self.filter_map: dict[str, RegexFilter]
|
||||
self.regex_filters: set[RegexFilter]
|
||||
self.update_config_lock:asyncio.Lock
|
||||
self.process:asyncio.subprocess.Process
|
||||
self.update_task: asyncio.Task
|
||||
@@ -118,7 +117,7 @@ class FiregexInterceptor:
|
||||
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
|
||||
await self.process.stdin.drain()
|
||||
|
||||
async def reload(self, filters:List[RegexFilter]):
|
||||
async def reload(self, filters:list[RegexFilter]):
|
||||
async with self.filter_map_lock:
|
||||
self.filter_map = self.compile_filters(filters)
|
||||
filters_codes = self.get_filter_codes()
|
||||
@@ -129,7 +128,7 @@ class FiregexInterceptor:
|
||||
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
|
||||
return filters_codes
|
||||
|
||||
def compile_filters(self, filters:List[RegexFilter]):
|
||||
def compile_filters(self, filters:list[RegexFilter]):
|
||||
res = {}
|
||||
for filter_obj in filters:
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from modules.nfregex.firegex import FiregexInterceptor, RegexFilter
|
||||
from modules.nfregex.nftables import FiregexTables, FiregexFilter
|
||||
from modules.nfregex.models import Regex, Service
|
||||
@@ -11,49 +10,13 @@ class STATUS:
|
||||
|
||||
nft = FiregexTables()
|
||||
|
||||
class FirewallManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.service_table: Dict[str, ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.service_table.keys()):
|
||||
await self.remove(key)
|
||||
|
||||
async def remove(self,srv_id):
|
||||
async with self.lock:
|
||||
if srv_id in self.service_table:
|
||||
await self.service_table[srv_id].next(STATUS.STOP)
|
||||
del self.service_table[srv_id]
|
||||
|
||||
async def init(self):
|
||||
nft.init()
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
for srv in self.db.query('SELECT * FROM services;'):
|
||||
srv = Service.from_dict(srv)
|
||||
if srv.id in self.service_table:
|
||||
continue
|
||||
self.service_table[srv.id] = ServiceManager(srv, self.db)
|
||||
await self.service_table[srv.id].next(srv.status)
|
||||
|
||||
def get(self,srv_id):
|
||||
if srv_id in self.service_table:
|
||||
return self.service_table[srv_id]
|
||||
else:
|
||||
raise ServiceNotFoundException()
|
||||
|
||||
class ServiceNotFoundException(Exception): pass
|
||||
|
||||
class ServiceManager:
|
||||
def __init__(self, srv: Service, db):
|
||||
self.srv = srv
|
||||
self.db = db
|
||||
self.status = STATUS.STOP
|
||||
self.filters: Dict[int, FiregexFilter] = {}
|
||||
self.filters: dict[int, FiregexFilter] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
self.interceptor = None
|
||||
|
||||
@@ -114,4 +77,41 @@ class ServiceManager:
|
||||
|
||||
async def update_filters(self):
|
||||
async with self.lock:
|
||||
await self._update_filters_from_db()
|
||||
await self._update_filters_from_db()
|
||||
|
||||
class FirewallManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.service_table: dict[str, ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.service_table.keys()):
|
||||
await self.remove(key)
|
||||
|
||||
async def remove(self,srv_id):
|
||||
async with self.lock:
|
||||
if srv_id in self.service_table:
|
||||
await self.service_table[srv_id].next(STATUS.STOP)
|
||||
del self.service_table[srv_id]
|
||||
|
||||
async def init(self):
|
||||
nft.init()
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
for srv in self.db.query('SELECT * FROM services;'):
|
||||
srv = Service.from_dict(srv)
|
||||
if srv.id in self.service_table:
|
||||
continue
|
||||
self.service_table[srv.id] = ServiceManager(srv, self.db)
|
||||
await self.service_table[srv.id].next(srv.status)
|
||||
|
||||
def get(self,srv_id) -> ServiceManager:
|
||||
if srv_id in self.service_table:
|
||||
return self.service_table[srv_id]
|
||||
else:
|
||||
raise ServiceNotFoundException()
|
||||
|
||||
class ServiceNotFoundException(Exception): pass
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
from modules.nfregex.models import Service
|
||||
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
|
||||
|
||||
@@ -11,9 +10,7 @@ class FiregexFilter:
|
||||
self.ip_int = str(ip_int)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, FiregexFilter):
|
||||
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
|
||||
elif isinstance(o, Service):
|
||||
if isinstance(o, FiregexFilter) or isinstance(o, Service):
|
||||
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
|
||||
return False
|
||||
|
||||
@@ -80,7 +77,7 @@ class FiregexTables(NFTableManager):
|
||||
}}})
|
||||
|
||||
|
||||
def get(self) -> List[FiregexFilter]:
|
||||
def get(self) -> list[FiregexFilter]:
|
||||
res = []
|
||||
for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]):
|
||||
ip_int = None
|
||||
|
||||
@@ -1,47 +1,10 @@
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from modules.porthijack.nftables import FiregexTables
|
||||
from modules.porthijack.models import Service
|
||||
from utils.sqlite import SQLite
|
||||
|
||||
nft = FiregexTables()
|
||||
|
||||
class FirewallManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.service_table: Dict[str, ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.service_table.keys()):
|
||||
await self.remove(key)
|
||||
|
||||
async def remove(self,srv_id):
|
||||
async with self.lock:
|
||||
if srv_id in self.service_table:
|
||||
await self.service_table[srv_id].disable()
|
||||
del self.service_table[srv_id]
|
||||
|
||||
async def init(self):
|
||||
FiregexTables().init()
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
for srv in self.db.query('SELECT * FROM services;'):
|
||||
srv = Service.from_dict(srv)
|
||||
if srv.service_id in self.service_table:
|
||||
continue
|
||||
self.service_table[srv.service_id] = ServiceManager(srv, self.db)
|
||||
if srv.active:
|
||||
await self.service_table[srv.service_id].enable()
|
||||
|
||||
def get(self,srv_id):
|
||||
if srv_id in self.service_table:
|
||||
return self.service_table[srv_id]
|
||||
else:
|
||||
raise ServiceNotFoundException()
|
||||
|
||||
class ServiceNotFoundException(Exception): pass
|
||||
|
||||
class ServiceManager:
|
||||
@@ -74,4 +37,41 @@ class ServiceManager:
|
||||
|
||||
async def restart(self):
|
||||
await self.disable()
|
||||
await self.enable()
|
||||
await self.enable()
|
||||
|
||||
class FirewallManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.service_table: dict[str, ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.service_table.keys()):
|
||||
await self.remove(key)
|
||||
|
||||
async def remove(self,srv_id):
|
||||
async with self.lock:
|
||||
if srv_id in self.service_table:
|
||||
await self.service_table[srv_id].disable()
|
||||
del self.service_table[srv_id]
|
||||
|
||||
async def init(self):
|
||||
FiregexTables().init()
|
||||
await self.reload()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
for srv in self.db.query('SELECT * FROM services;'):
|
||||
srv = Service.from_dict(srv)
|
||||
if srv.service_id in self.service_table:
|
||||
continue
|
||||
self.service_table[srv.service_id] = ServiceManager(srv, self.db)
|
||||
if srv.active:
|
||||
await self.service_table[srv.service_id].enable()
|
||||
|
||||
def get(self,srv_id) -> ServiceManager:
|
||||
if srv_id in self.service_table:
|
||||
return self.service_table[srv_id]
|
||||
else:
|
||||
raise ServiceNotFoundException()
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
from modules.porthijack.models import Service
|
||||
from utils import addr_parse, ip_parse, ip_family, NFTableManager, nftables_json_to_int
|
||||
|
||||
@@ -13,9 +12,7 @@ class FiregexHijackRule():
|
||||
self.ip_dst = str(ip_dst)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, FiregexHijackRule):
|
||||
return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src)
|
||||
elif isinstance(o, Service):
|
||||
if isinstance(o, FiregexHijackRule) or isinstance(o, Service):
|
||||
return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src)
|
||||
return False
|
||||
|
||||
@@ -79,10 +76,9 @@ class FiregexTables(NFTableManager):
|
||||
}}})
|
||||
|
||||
|
||||
def get(self) -> List[FiregexHijackRule]:
|
||||
def get(self) -> list[FiregexHijackRule]:
|
||||
res = []
|
||||
for filter in self.list_rules(tables=[self.table_name], chains=[self.prerouting_porthijack,self.postrouting_porthijack]):
|
||||
filter["expr"][0]["match"]["right"]
|
||||
res.append(FiregexHijackRule(
|
||||
target=filter["chain"],
|
||||
id=int(filter["handle"]),
|
||||
|
||||
@@ -145,7 +145,7 @@ class ServiceManager:
|
||||
class ProxyManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.proxy_table:dict = {}
|
||||
self.proxy_table: dict[str, ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
@@ -168,7 +168,7 @@ class ProxyManager:
|
||||
self.proxy_table[srv_id] = ServiceManager(srv_id,self.db)
|
||||
await self.proxy_table[srv_id].next(req_status)
|
||||
|
||||
def get(self,id):
|
||||
def get(self,id) -> ServiceManager:
|
||||
if id in self.proxy_table:
|
||||
return self.proxy_table[id]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user