diff --git a/backend/modules/firewall/firewall.py b/backend/modules/firewall/firewall.py index 2ea594c..e0a3c04 100644 --- a/backend/modules/firewall/firewall.py +++ b/backend/modules/firewall/firewall.py @@ -15,10 +15,10 @@ class FirewallManager: nft.reset() async def init(self): - FiregexTables().init() + nft.init() await self.reload() async def reload(self): async with self.lock: - nft.set(map(Rule.from_dict, self.db.query('SELECT * FROM rules WHERE active = 1 ORDER BY rule_id;'))) + nft.set(map(Rule.from_dict, self.db.query('SELECT * FROM rules WHERE active = 1 ORDER BY rule_id;')), policy=self.db.get('POLICY', 'accept')) diff --git a/backend/modules/firewall/models.py b/backend/modules/firewall/models.py index b7a2aa0..c124cdd 100644 --- a/backend/modules/firewall/models.py +++ b/backend/modules/firewall/models.py @@ -1,8 +1,5 @@ class Rule: - def __init__(self, rule_id: int, name: str, active: bool, proto: str, ip_src:str, ip_dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str): - self.rule_id = rule_id - self.active = active - self.name = name + def __init__(self, proto: str, ip_src:str, ip_dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str): self.proto = proto self.ip_src = ip_src self.ip_dst = ip_dst @@ -18,9 +15,6 @@ class Rule: @classmethod def from_dict(cls, var: dict): return cls( - rule_id=var["rule_id"], - active=var["active"], - name=var["name"], proto=var["proto"], ip_src=var["ip_src"], ip_dst=var["ip_dst"], diff --git a/backend/modules/firewall/nftables.py b/backend/modules/firewall/nftables.py index a5b3c37..b44bc0f 100644 --- a/backend/modules/firewall/nftables.py +++ b/backend/modules/firewall/nftables.py @@ -28,8 +28,8 @@ class FiregexTables(NFTableManager): rules_chain_in = "firewall_rules_in" rules_chain_out = "firewall_rules_out" - def __init__(self): - super().__init__([ + def init_comands(self, policy:str="accept", policy_out:str="accept"): + return [ {"add":{"chain":{ "family":"inet", "table":self.table_name, @@ -37,7 +37,7 @@ class FiregexTables(NFTableManager): "type":"filter", "hook":"prerouting", "prio":-300, - "policy":"accept" + "policy":policy }}}, {"add":{"chain":{ "family":"inet", @@ -46,24 +46,38 @@ class FiregexTables(NFTableManager): "type":"filter", "hook":"postrouting", "prio":-300, - "policy":"accept" + "policy":policy_out }}}, - ],[ + ] + + def __init__(self): + super().__init__(self.init_comands(),[ {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}}, {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}}, {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}}, {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}}, ]) - def delete_all(self): - self.cmd( - {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}}, - {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}}, - ) - - def set(self, srv:list[Rule]): - self.delete_all() - for ele in srv: self.add(ele) + def set(self, srvs:list[Rule], policy:str="accept"): + srvs = list(srvs) + self.reset() + if policy == "reject": + policy = "drop" + srvs.extend([ + Rule( + proto="any", + ip_src=iprule, + ip_dst=iprule, + port_src_from=1, + port_dst_from=1, + port_src_to=65535, + port_dst_to=65535, + action="reject", + mode="I" + ) for iprule in ["0.0.0.0/0", "::/0"] + ]) + self.cmd(*self.init_comands(policy)) + for ele in srvs[::-1]: self.add(ele) def add(self, srv:Rule): port_filters = [] diff --git a/backend/modules/porthijack/firewall.py b/backend/modules/porthijack/firewall.py index ebb4cfc..29e2f06 100644 --- a/backend/modules/porthijack/firewall.py +++ b/backend/modules/porthijack/firewall.py @@ -56,7 +56,7 @@ class FirewallManager: del self.service_table[srv_id] async def init(self): - FiregexTables().init() + nft.init() await self.reload() async def reload(self): diff --git a/backend/routers/firewall.py b/backend/routers/firewall.py index 7fc9721..13d11e9 100644 --- a/backend/routers/firewall.py +++ b/backend/routers/firewall.py @@ -19,7 +19,11 @@ class RuleModel(BaseModel): port_dst_to: PortType action: str mode:str - + +class RuleForm(BaseModel): + rules: list[RuleModel] + policy: str + class RuleAddResponse(BaseModel): status:str|list[dict] @@ -51,6 +55,8 @@ db = SQLite('db/firewall-rules.db', { ] }) +firewall = FirewallManager(db) + async def reset(params: ResetRequest): if not params.delete: db.backup() @@ -79,18 +85,18 @@ async def apply_changes(): await refresh_frontend() return {'status': 'ok'} -firewall = FirewallManager(db) - @app.get('/stats', response_model=GeneralStatModel) async def get_general_stats(): """Get firegex general status about rules""" return db.query("SELECT (SELECT COUNT(*) FROM rules) rules")[0] -@app.get('/rules', response_model=list[RuleModel]) +@app.get('/rules', response_model=RuleForm) async def get_rule_list(): """Get the list of existent firegex rules""" - return db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;") - + return { + "policy": db.get("POLICY", "accept"), + "rules": db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;") + } @app.get('/rule/{rule_id}/disable', response_model=StatusMessageModel) async def service_disable(rule_id: str): """Request disabling a specific rule""" @@ -141,10 +147,12 @@ def parse_and_check_rule(rule:RuleModel): @app.post('/rules/set', response_model=RuleAddResponse) -async def add_new_service(form: list[RuleModel]): +async def add_new_service(form: RuleForm): """Add a new service""" - form = [parse_and_check_rule(ele) for ele in form] - errors = [({"rule":i} | ele) for i, ele in enumerate(form) if isinstance(ele, dict)] + if form.policy not in ["accept", "drop", "reject"]: + return {"status": "Invalid policy"} + rules = [parse_and_check_rule(ele) for ele in form.rules] + errors = [({"rule":i} | ele) for i, ele in enumerate(rules) if isinstance(ele, dict)] if len(errors) > 0: return {'status': errors} try: @@ -164,8 +172,9 @@ async def add_new_service(form: list[RuleModel]): ele.port_src_from, ele.port_dst_from, ele.port_src_to, ele.port_dst_to, ele.action, ele.mode - ) for rid, ele in enumerate(form)] + ) for rid, ele in enumerate(rules)] ) + db.set("POLICY", form.policy) except sqlite3.IntegrityError: return {'status': 'Error saving the rules: maybe there are duplicated rules'} return await apply_changes() diff --git a/backend/utils/sqlite.py b/backend/utils/sqlite.py index 6bc688a..426265c 100644 --- a/backend/utils/sqlite.py +++ b/backend/utils/sqlite.py @@ -99,10 +99,10 @@ class SQLite(): self.create_schema(self.schema) self.put('DB_VERSION', self.DB_VER) - def get(self, key): + def get(self, key, default = None): q = self.query('SELECT value FROM keys_values WHERE key = ?', key) if len(q) == 0: - return None + return default else: return q[0]["value"] @@ -111,3 +111,6 @@ class SQLite(): self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value)) else: self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key) + + def set(self, key, value): + return self.put(key, value)