import sqlite3
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from utils.sqlite import SQLite
from utils import ip_parse, ip_family, socketio_emit, PortType
from utils.models import ResetRequest, StatusMessageModel
from modules.firewall.nftables import FiregexTables
from modules.firewall.firewall import FirewallManager

# Accepted values for the corresponding request fields. They mirror the CHECK
# constraints declared on the "rules" table below, so invalid input is
# rejected with a precise message before it reaches the database.
_VALID_PROTOCOLS = ("tcp", "udp", "any")
_VALID_ACTIONS = ("accept", "drop", "reject")
_VALID_POLICIES = ("accept", "drop", "reject")
_VALID_MODES = ("O", "I")


class RuleModel(BaseModel):
    """A single firewall rule as exchanged with the API clients."""

    active: bool
    name: str
    proto: str
    ip_src: str
    ip_dst: str
    port_src_from: PortType
    port_dst_from: PortType
    port_src_to: PortType
    port_dst_to: PortType
    action: str
    mode: str


class RuleFormAdd(BaseModel):
    """Payload of ``POST /rules/set``: the full rule list plus the policy."""

    rules: list[RuleModel]
    policy: str


class RuleInfo(BaseModel):
    """Response of ``GET /rules``: stored rules, policy and enabled flag."""

    rules: list[RuleModel]
    policy: str
    enabled: bool


class RenameForm(BaseModel):
    """Payload carrying a new name."""

    name: str


class FirewallSettings(BaseModel):
    """Firewall options read and written through the ``/settings`` routes."""

    keep_rules: bool
    allow_loopback: bool
    allow_established: bool


app = APIRouter()

db = SQLite('db/firewall-rules.db', {
    'rules': {
        'rule_id': 'INT PRIMARY KEY CHECK (rule_id >= 0)',
        'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("O", "I"))',  # O = out, I = in
        'name': 'VARCHAR(100) NOT NULL',
        'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1))',
        'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp", "any"))',
        'ip_src': 'VARCHAR(100) NOT NULL',
        'port_src_from': 'INT CHECK(port_src_from > 0 and port_src_from < 65536)',
        'port_src_to': 'INT CHECK(port_src_to > 0 and port_src_to < 65536 and port_src_from <= port_src_to)',
        'ip_dst': 'VARCHAR(100) NOT NULL',
        'port_dst_from': 'INT CHECK(port_dst_from > 0 and port_dst_from < 65536)',
        'port_dst_to': 'INT CHECK(port_dst_to > 0 and port_dst_to < 65536 and port_dst_from <= port_dst_to)',
        'action': 'VARCHAR(10) NOT NULL CHECK (action IN ("accept", "drop", "reject"))',
    },
    'QUERY':[
        "CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, ip_src, ip_dst, port_src_from, port_src_to, port_dst_from, port_dst_to, mode);"
    ]
})

firewall = FirewallManager(db)


async def reset(params: ResetRequest):
    """Reset the firewall module.

    When ``params.delete`` is false the database is backed up before the
    firewall is closed and restored afterwards; when it is true the database
    is deleted and re-initialised instead. In both cases the firewall is
    closed, ``FiregexTables().reset()`` is called and the firewall is
    initialised again.
    """
    if not params.delete:
        db.backup()
    await firewall.close()
    FiregexTables().reset()
    if params.delete:
        db.delete()
        db.init()
    else:
        db.restore()
    await firewall.init()


async def startup():
    """Initialise the database and then the firewall manager."""
    db.init()
    await firewall.init()


async def shutdown():
    """Close the module, leaving the firewall up when ``keep_rules`` is set."""
    # Read the setting before touching the database or the firewall.
    keep_rules = firewall.keep_rules
    db.backup()
    if not keep_rules:
        await firewall.close()
    db.disconnect()
    db.restore()


async def refresh_frontend(additional: Optional[list[str]] = None):
    """Emit a "firewall" event, plus any ``additional`` ones, via socketio_emit.

    ``additional`` defaults to ``None`` (treated as an empty list) so that no
    mutable object is shared between calls.
    """
    extra = [] if additional is None else additional
    await socketio_emit(["firewall"] + extra)


async def apply_changes():
    """Reload the firewall, notify the frontend and return an ok status."""
    await firewall.reload()
    await refresh_frontend()
    return {'status': 'ok'}


@app.get("/settings", response_model=FirewallSettings)
async def get_settings():
    """Get the firewall settings"""
    return {
        "keep_rules": firewall.keep_rules,
        "allow_loopback": firewall.allow_loopback,
        "allow_established": firewall.allow_established
    }


@app.post("/settings/set", response_model=StatusMessageModel)
async def set_settings(form: FirewallSettings):
    """Set the firewall settings"""
    firewall.keep_rules = form.keep_rules
    firewall.allow_loopback = form.allow_loopback
    firewall.allow_established = form.allow_established
    return {'status': 'ok'}


@app.get('/rules', response_model=RuleInfo)
async def get_rule_list():
    """Get the list of existent firegex rules"""
    return {
        "policy": firewall.policy,
        "rules": db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;"),
        "enabled": firewall.enabled
    }


@app.get('/enable', response_model=StatusMessageModel)
async def enable_firewall():
    """Request enabling the firewall"""
    firewall.enabled = True
    return await apply_changes()


@app.get('/disable', response_model=StatusMessageModel)
async def disable_firewall():
    """Request disabling the firewall"""
    firewall.enabled = False
    return await apply_changes()


def parse_and_check_rule(rule: RuleModel):
    """Normalise and validate ``rule`` in place, then return it.

    - If either address is "any" (case-insensitive, surrounding whitespace
      ignored) both addresses are set to "any"; otherwise both are passed
      through ``ip_parse`` and must belong to the same family.
    - Each port range is reordered so that ``*_from <= *_to``.
    - ``proto``, ``action`` and ``mode`` must be one of the accepted values.

    Raises:
        HTTPException: status 400 when an address cannot be parsed, the
            address families differ, or proto/action/mode is not accepted.
    """
    src_is_any = rule.ip_src.lower().strip() == "any"
    # strip(), not split(): split() returns a list, which never equals "any".
    dst_is_any = rule.ip_dst.lower().strip() == "any"
    if src_is_any or dst_is_any:
        rule.ip_dst = rule.ip_src = "any"
    else:
        try:
            rule.ip_src = ip_parse(rule.ip_src)
            rule.ip_dst = ip_parse(rule.ip_dst)
        except ValueError as err:
            raise HTTPException(status_code=400, detail="Invalid address") from err
        # Families are only compared for concrete addresses; "any" is never
        # handed to ip_family.
        if ip_family(rule.ip_dst) != ip_family(rule.ip_src):
            raise HTTPException(status_code=400, detail="Destination and source addresses must be of the same family")

    # Accept ranges given in either order; the table requires from <= to.
    rule.port_dst_from, rule.port_dst_to = min(rule.port_dst_from, rule.port_dst_to), max(rule.port_dst_from, rule.port_dst_to)
    rule.port_src_from, rule.port_src_to = min(rule.port_src_from, rule.port_src_to), max(rule.port_src_from, rule.port_src_to)

    if rule.proto not in _VALID_PROTOCOLS:
        raise HTTPException(status_code=400, detail="Invalid protocol")
    if rule.action not in _VALID_ACTIONS:
        raise HTTPException(status_code=400, detail="Invalid action")
    # Without this check a bad mode only fails on the table's CHECK constraint
    # and is reported with the misleading "duplicated rules" message.
    if rule.mode not in _VALID_MODES:
        raise HTTPException(status_code=400, detail="Invalid mode")
    return rule


@app.post('/rules/set', response_model=StatusMessageModel)
async def add_new_service(form: RuleFormAdd):
    """Replace the stored rules and the default policy with the submitted ones"""
    if form.policy not in _VALID_POLICIES:
        raise HTTPException(status_code=400, detail="Invalid policy")
    rules = [parse_and_check_rule(ele) for ele in form.rules]
    insert_rule = """
        INSERT INTO rules (
            rule_id, active, name,
            proto,
            ip_src, ip_dst,
            port_src_from, port_dst_from,
            port_src_to, port_dst_to,
            action, mode
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?, ?)"""
    try:
        # The position in the submitted list becomes rule_id, which is the
        # column GET /rules orders by.
        db.queries(
            ["DELETE FROM rules"]
            + [
                (
                    insert_rule,
                    rid, ele.active, ele.name,
                    ele.proto,
                    ele.ip_src, ele.ip_dst,
                    ele.port_src_from, ele.port_dst_from,
                    ele.port_src_to, ele.port_dst_to,
                    ele.action, ele.mode,
                )
                for rid, ele in enumerate(rules)
            ]
        )
        firewall.policy = form.policy
    except sqlite3.IntegrityError as err:
        raise HTTPException(status_code=400, detail="Error saving the rules: maybe there are duplicated rules") from err
    return await apply_changes()