Files
firegex-traffic-viewer/backend/routers/firewall.py
Domingo Dirutigliano b2e0ecc005 mode in the unique index
2023-09-22 20:53:39 +02:00

172 lines
6.3 KiB
Python

import sqlite3
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from utils.sqlite import SQLite
from utils import ip_parse, ip_family, refactor_name, refresh_frontend, PortType
from utils.models import ResetRequest, StatusMessageModel
from modules.firewall.nftables import FiregexTables
from modules.firewall.firewall import FirewallManager
class RuleModel(BaseModel):
    """A single firewall rule as sent to / returned by the API.

    Field order matches the column order selected by ``get_rule_list``.
    """
    active: bool  # stored in the `active` column (0/1 per the DB CHECK)
    name: str  # display name; not part of the uniqueness index
    proto: str  # "tcp", "udp" or "any" (enforced by parse_and_check_rule and the DB CHECK)
    ip_src: str  # normalized with ip_parse; must share its family with ip_dst
    ip_dst: str  # normalized with ip_parse; must share its family with ip_src
    # PortType is project-defined; NOTE(review): presumably bounded to 1..65535 like the DB CHECK — confirm.
    port_src_from: PortType
    port_dst_from: PortType
    port_src_to: PortType
    port_dst_to: PortType
    action: str  # "accept", "drop" or "reject" (enforced by parse_and_check_rule and the DB CHECK)
    mode:str  # "O" (out) or "I" (in); the DB CHECK rejects any other value
class RuleAddResponse(BaseModel):
    """Response of ``POST /rules/set``."""
    # Either a status string ('ok' or an error message), or a list of
    # per-rule validation errors shaped like {"rule": <index>, "status": <message>}.
    status:str|list[dict]
class RenameForm(BaseModel):
    """Request body for renaming a rule."""
    name:str  # raw new name; sanitized with refactor_name before it is stored
class GeneralStatModel(BaseModel):
    """Aggregate statistics returned by ``GET /stats``."""
    rules: int  # total number of rows in the `rules` table
app = APIRouter()

# SQLite store for the firewall rules.
# String constants inside the CHECK constraints are written as single-quoted
# SQL literals. SQLite accepts double-quoted ones only through a legacy
# fallback (an unresolved identifier becomes a string), which can be disabled
# at build time (SQLITE_DQS=0) and is discouraged by the SQLite docs.
# NOTE(review): databases created before this change presumably keep the
# schema they were created with — confirm how utils.sqlite creates tables.
db = SQLite('db/firewall-rules.db', {
    'rules': {
        'rule_id': 'INT PRIMARY KEY CHECK (rule_id >= 0)',
        # O = out, I = in. No other value (e.g. a "both" mode) passes this CHECK.
        'mode': "VARCHAR(1) NOT NULL CHECK (mode IN ('O', 'I'))",
        'name': 'VARCHAR(100) NOT NULL',
        'active': 'BOOLEAN NOT NULL CHECK (active IN (0, 1))',
        'proto': "VARCHAR(3) NOT NULL CHECK (proto IN ('tcp', 'udp', 'any'))",
        'ip_src': 'VARCHAR(100) NOT NULL',
        'port_src_from': 'INT CHECK(port_src_from > 0 and port_src_from < 65536)',
        'port_src_to': 'INT CHECK(port_src_to > 0 and port_src_to < 65536 and port_src_from <= port_src_to)',
        'ip_dst': 'VARCHAR(100) NOT NULL',
        'port_dst_from': 'INT CHECK(port_dst_from > 0 and port_dst_from < 65536)',
        'port_dst_to': 'INT CHECK(port_dst_to > 0 and port_dst_to < 65536 and port_dst_from <= port_dst_to)',
        'action': "VARCHAR(10) NOT NULL CHECK (action IN ('accept', 'drop', 'reject'))",
    },
    'QUERY': [
        # Rules are unique on everything except `name` and `active`. Because
        # `mode` is part of the key, the same match/action may exist once per
        # direction.
        "CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, ip_src, ip_dst, port_src_from, port_src_to, port_dst_from, port_dst_to, action, mode);"
    ]
})
async def reset(params: ResetRequest):
    """Tear down and rebuild the firewall state.

    When ``params.delete`` is false, the rules DB is backed up before the
    firegex tables are reset and then restored. When it is true, the DB is
    deleted and re-initialised empty. In both cases the firewall manager is
    closed first and re-initialised last.
    """
    keep_data = not params.delete
    if keep_data:
        db.backup()
    await firewall.close()
    FiregexTables().reset()
    if keep_data:
        db.restore()
    else:
        db.delete()
        db.init()
    await firewall.init()
async def startup():
    """Initialise the rules database, then bring up the firewall manager.

    The DB is initialised first because the manager is constructed on top of
    ``db``. NOTE(review): presumably registered as a lifecycle hook by the
    code that mounts this router — confirm.
    """
    db.init()
    await firewall.init()
async def shutdown():
    """Back up the rules DB, stop the firewall manager, then close the DB."""
    db.backup()
    await firewall.close()
    db.disconnect()
    # NOTE(review): restore() runs after disconnect(); its semantics live in
    # utils.sqlite — confirm this ordering is intentional.
    db.restore()
async def apply_changes() -> dict[str, str]:
    """Push the current DB rules to the firewall and notify the frontend.

    Returns:
        ``{'status': 'ok'}``, used directly as the HTTP response by the
        mutating endpoints.
    """
    await firewall.reload()
    await refresh_frontend()
    return {'status': 'ok'}
# Firewall manager bound to the rules DB. It is defined after the lifecycle
# functions that use it; that works because they only look up `firewall`
# when called, not when defined.
firewall = FirewallManager(db)
@app.get('/stats', response_model=GeneralStatModel)
async def get_general_stats():
    """Return aggregate statistics (the total rule count) for firegex rules"""
    stats_rows = db.query("SELECT (SELECT COUNT(*) FROM rules) rules")
    return stats_rows[0]
@app.get('/rules', response_model=list[RuleModel])
async def get_rule_list():
    """List every stored firegex rule, ordered by its rule id (its position)"""
    list_sql = (
        "SELECT active, name, proto, ip_src, ip_dst, "
        "port_src_from, port_dst_from, port_src_to, port_dst_to, "
        "action, mode FROM rules ORDER BY rule_id;"
    )
    return db.query(list_sql)
@app.get('/rule/{rule_id}/disable', response_model=StatusMessageModel)
async def service_disable(rule_id: str):
    """Mark one rule as inactive and reload the firewall"""
    rule_exists = len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) > 0
    if not rule_exists:
        return {'status': 'Rule not found'}
    db.query('UPDATE rules SET active = 0 WHERE rule_id = ?;', rule_id)
    return await apply_changes()
@app.get('/rule/{rule_id}/enable', response_model=StatusMessageModel)
async def service_start(rule_id: str):
    """Mark one rule as active and reload the firewall"""
    rule_exists = len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) > 0
    if not rule_exists:
        return {'status': 'Rule not found'}
    db.query('UPDATE rules SET active = 1 WHERE rule_id = ?;', rule_id)
    return await apply_changes()
@app.post('/service/{rule_id}/rename', response_model=StatusMessageModel)
async def service_rename(rule_id: str, form: RenameForm):
    """Change the display name of one rule.

    Unlike the enable/disable routes, this one is mounted under '/service/'.
    The name is sanitized with refactor_name and must not end up empty.
    """
    rule_exists = len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) > 0
    if not rule_exists:
        return {'status': 'Rule not found'}
    new_name = refactor_name(form.name)
    form.name = new_name
    if not new_name:
        return {'status': 'The name cannot be empty!'}
    try:
        db.query('UPDATE rules SET name=? WHERE rule_id = ?;', new_name, rule_id)
    except sqlite3.IntegrityError:
        # NOTE(review): the visible schema has no UNIQUE constraint on `name`,
        # so this branch looks unreachable — confirm.
        return {'status': 'This name is already used'}
    await refresh_frontend()
    return {'status': 'ok'}
def parse_and_check_rule(rule:RuleModel) -> RuleModel | dict[str, str]:
    """Normalize and validate a rule submitted by the client.

    The rule is modified in place. Both addresses are canonicalized with
    ``ip_parse``, and each port range is reordered so that ``*_from <= *_to``.

    Args:
        rule: the rule to normalize and check.

    Returns:
        The same ``rule`` object when it is valid. Otherwise a dict
        ``{"status": "<error message>"}`` describing the first problem found.
    """
    try:
        rule.ip_src = ip_parse(rule.ip_src)
        rule.ip_dst = ip_parse(rule.ip_dst)
    except ValueError:
        return {"status":"Invalid address"}
    # Accept ranges given in either order; the DB CHECK requires from <= to.
    rule.port_dst_from, rule.port_dst_to = min(rule.port_dst_from, rule.port_dst_to), max(rule.port_dst_from, rule.port_dst_to)
    rule.port_src_from, rule.port_src_to = min(rule.port_src_from, rule.port_src_to), max(rule.port_src_from, rule.port_src_to)
    if ip_family(rule.ip_dst) != ip_family(rule.ip_src):
        return {"status":"Destination and source addresses must be of the same family"}
    if rule.proto not in ["tcp", "udp", "any"]:
        return {"status":"Invalid protocol"}
    if rule.action not in ["accept", "drop", "reject"]:
        return {"status":"Invalid action"}
    # Mirror the DB CHECK on `mode`. Without this check an invalid mode only
    # surfaces as an IntegrityError at insert time, which add_new_service
    # reports misleadingly as a possible duplicate rule.
    if rule.mode not in ["O", "I"]:
        return {"status":"Invalid mode"}
    return rule
@app.post('/rules/set', response_model=RuleAddResponse)
async def add_new_service(form: list[RuleModel]):
    """Replace the whole rule set with the submitted list.

    Every rule is validated first. If any rule fails, nothing is written and
    the response lists one {"rule": <index>, "status": <message>} entry per
    invalid rule. Otherwise all existing rules are deleted and the new ones
    are inserted with rule_id equal to their list position, then the firewall
    is reloaded. NOTE(review): whether db.queries runs the delete+inserts
    atomically depends on utils.sqlite — confirm.
    """
    insert_sql = (
        "\n"
        "INSERT INTO rules (\n"
        "rule_id, active, name,\n"
        "proto,\n"
        "ip_src, ip_dst,\n"
        "port_src_from, port_dst_from,\n"
        "port_src_to, port_dst_to,\n"
        "action, mode\n"
        ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?, ?)"
    )
    checked = [parse_and_check_rule(rule) for rule in form]
    errors = [
        {"rule": idx} | outcome
        for idx, outcome in enumerate(checked)
        if isinstance(outcome, dict)
    ]
    if errors:
        return {'status': errors}
    statements = ["DELETE FROM rules"]
    statements += [
        (
            insert_sql,
            rid, rule.active, rule.name,
            rule.proto,
            rule.ip_src, rule.ip_dst,
            rule.port_src_from, rule.port_dst_from,
            rule.port_src_to, rule.port_dst_to,
            rule.action, rule.mode,
        )
        for rid, rule in enumerate(checked)
    ]
    try:
        db.queries(statements)
    except sqlite3.IntegrityError:
        return {'status': 'Error saving the rules: maybe there are duplicated rules'}
    return await apply_changes()