Firewall refactor

This commit is contained in:
Domingo Dirutigliano
2023-09-28 20:45:58 +02:00
parent 99e4989cfe
commit 71edfc29c4
12 changed files with 212 additions and 166 deletions

View File

@@ -154,7 +154,7 @@ if __name__ == '__main__':
os.chdir(os.path.dirname(os.path.realpath(__file__)))
uvicorn.run(
"app:app",
host=None,
host="0.0.0.0",
port=FIREGEX_PORT,
reload=DEBUG,
access_log=True,

View File

@@ -2,6 +2,7 @@ import asyncio
from modules.firewall.nftables import FiregexTables
from modules.firewall.models import Rule
from utils.sqlite import SQLite
from modules.firewall.models import Action
nft = FiregexTables()
@@ -25,14 +26,15 @@ class FirewallManager:
map(Rule.from_dict, self.db.query('SELECT * FROM rules WHERE active = 1 ORDER BY rule_id;')),
policy=self.policy,
allow_loopback=self.allow_loopback,
allow_established=self.allow_established
allow_established=self.allow_established,
allow_icmp=self.allow_icmp
)
else:
nft.reset()
@property
def policy(self):
return self.db.get("POLICY", "accept")
return self.db.get("POLICY", Action.ACCEPT)
@policy.setter
def policy(self, value):
@@ -61,6 +63,14 @@ class FirewallManager:
@allow_loopback.setter
def allow_loopback(self, value):
self.db.set("allow_loopback", "1" if value else "0")
@property
def allow_icmp(self):
return self.db.get("allow_icmp", "1") == "1"
@allow_icmp.setter
def allow_icmp(self, value):
self.db.set("allow_icmp", "1" if value else "0")
@property
def allow_established(self):

View File

@@ -1,27 +1,47 @@
from enum import Enum
class Rule:
def __init__(self, proto: str, ip_src:str, ip_dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str):
def __init__(self, proto: str, src:str, dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str):
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
self.src = src
self.dst = dst
self.port_src_from = port_src_from
self.port_dst_from = port_dst_from
self.port_src_to = port_src_to
self.port_dst_to = port_dst_to
self.action = action
self.input_mode = mode in ["I"]
self.output_mode = mode in ["O"]
self.input_mode = mode == "in"
self.output_mode = mode == "out"
self.forward_mode = mode == "forward"
@classmethod
def from_dict(cls, var: dict):
return cls(
proto=var["proto"],
ip_src=var["ip_src"],
ip_dst=var["ip_dst"],
src=var["src"],
dst=var["dst"],
port_dst_from=var["port_dst_from"],
port_dst_to=var["port_dst_to"],
port_src_from=var["port_src_from"],
port_src_to=var["port_src_to"],
action=var["action"],
mode=var["mode"]
)
)
class Protocol(str, Enum):
TCP = "tcp",
UDP = "udp",
BOTH = "both",
ANY = "any"
class Mode(str, Enum):
IN = "in",
OUT = "out",
FORWARD = "forward"
class Action(str, Enum):
ACCEPT = "accept",
DROP = "drop",
REJECT = "reject"

View File

@@ -1,34 +1,13 @@
from modules.firewall.models import Rule
from utils import nftables_int_to_json, ip_parse, ip_family, NFTableManager
class FiregexHijackRule():
def __init__(self, proto:str, ip_src:str, ip_dst:str, port_src_from:int, port_dst_from:int, port_src_to:int, port_dst_to:int, action:str, target:str, id:int):
self.id = id
self.target = target
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
self.port_src_from = min(port_src_from, port_src_to)
self.port_dst_from = min(port_dst_from, port_dst_to)
self.port_src_to = max(port_src_from, port_src_to)
self.port_dst_to = max(port_dst_from, port_dst_to)
self.action = action
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexHijackRule) or isinstance(o, Rule):
return self.action == o.action and self.proto == o.proto and\
ip_parse(self.ip_src) == ip_parse(o.ip_src) and ip_parse(self.ip_dst) == ip_parse(o.ip_dst) and\
int(self.port_src_from) == int(o.port_src_from) and int(self.port_dst_from) == int(o.port_dst_from) and\
int(self.port_src_to) == int(o.port_src_to) and int(self.port_dst_to) == int(o.port_dst_to)
return False
from modules.firewall.models import Rule, Protocol, Mode, Action
from utils import nftables_int_to_json, ip_family, NFTableManager, is_ip_parse
import copy
class FiregexTables(NFTableManager):
rules_chain_in = "firewall_rules_in"
rules_chain_out = "firewall_rules_out"
rules_chain_fwd = "firewall_rules_fwd"
def init_comands(self, policy:str="accept", policy_out:str="accept", allow_loopback=False, allow_established=False):
def init_comands(self, policy:str=Action.ACCEPT, allow_loopback=False, allow_established=False, allow_icmp=False):
return [
{"add":{"chain":{
"family":"inet",
@@ -36,7 +15,16 @@ class FiregexTables(NFTableManager):
"name":self.rules_chain_in,
"type":"filter",
"hook":"prerouting",
"prio":-150,
"prio":0,
"policy":policy
}}},
{"add":{"chain":{
"family":"inet",
"table":self.table_name,
"name":self.rules_chain_fwd,
"type":"filter",
"hook":"forward",
"prio":0,
"policy":policy
}}},
{"add":{"chain":{
@@ -45,24 +33,41 @@ class FiregexTables(NFTableManager):
"name":self.rules_chain_out,
"type":"filter",
"hook":"postrouting",
"prio":-150,
"policy":policy_out
"prio":0,
"policy":Action.ACCEPT
}}},
] + ([
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_out,
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "iif"}}, "right": "lo"}},{"accept": None}]
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "iif" }}, "right": "lo"}},{"accept": None}]
}}},
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_in,
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "iif"}}, "right": "lo"}},{"accept": None}]
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "iif" }}, "right": "lo"}},{"accept": None}]
}}}
] if allow_loopback else []) + ([
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_in,
"expr": [{ "match": {"op": "in", "left": { "ct": { "key": "state" }},"right": ["established"]} }, { "accept": None }]
}}}
] if allow_established else [])
] if allow_established else []) + ([
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_in,
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "l4proto" } }, "right": "icmp"} }, { "accept": None }]
}}},
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_fwd,
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "l4proto" } }, "right": "icmp"} }, { "accept": None }]
}}},
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_in,
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "l4proto" } }, "right": "ipv6-icmp"} }, { "accept": None }]
}}},
{ "add":{ "rule": {
"family": "inet", "table": self.table_name, "chain": self.rules_chain_fwd,
"expr": [{ "match": { "op": "==", "left": { "meta": { "key": "l4proto" } }, "right": "ipv6-icmp"} }, { "accept": None }]
}}}
] if allow_icmp else [])
def __init__(self):
super().__init__(self.init_comands(),[
@@ -70,39 +75,57 @@ class FiregexTables(NFTableManager):
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_fwd}}},
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_fwd}}},
])
def set(self, srvs:list[Rule], policy:str="accept", allow_loopback=False, allow_established=False):
def set(self, srvs:list[Rule], policy:str="accept", allow_loopback=False, allow_established=False, allow_icmp=False):
srvs = list(srvs)
self.reset()
if policy == "reject":
policy = "drop"
if policy == Action.REJECT:
policy = Action.DROP
srvs.append(Rule(
proto="any",
ip_src="any",
ip_dst="any",
proto=Protocol.ANY,
src="",
dst="",
port_src_from=1,
port_dst_from=1,
port_src_to=65535,
port_dst_to=65535,
action="reject",
mode="I"
action=Action.REJECT,
mode=Mode.IN
))
rules = self.init_comands(policy, allow_loopback=allow_loopback, allow_established=allow_established) + self.get_rules(*srvs)
rules = self.init_comands(policy, allow_loopback=allow_loopback, allow_established=allow_established, allow_icmp=allow_icmp) + self.get_rules(*srvs)
self.cmd(*rules)
def get_rules(self,*srvs:Rule):
rules = []
for srv in srvs:
final_srvs:list[Rule] = []
for ele in srvs:
if ele.proto == Protocol.BOTH:
udp_rule = copy.deepcopy(ele)
udp_rule.proto = Protocol.UDP.value
ele.proto = Protocol.TCP.value
final_srvs.append(udp_rule)
final_srvs.append(ele)
for srv in final_srvs:
ip_filters = []
if srv.ip_src.lower() != "any" and srv.ip_dst.lower() != "any":
ip_filters = [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_src), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_src)}},
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_dst), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_dst)}},
]
if srv.src != "":
if is_ip_parse(srv.src):
ip_filters.append({'match': {'left': {'payload': {'protocol': ip_family(srv.src), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.src)}})
else:
ip_filters.append({"match": { "op": "==", "left": { "meta": { "key": "iifname" } }, "right": srv.src} })
if srv.dst != "":
if is_ip_parse(srv.dst):
ip_filters.append({'match': {'left': {'payload': {'protocol': ip_family(srv.dst), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.dst)}})
else:
ip_filters.append({"match": { "op": "==", "left": { "meta": { "key": "oifname" } }, "right": srv.dst} })
port_filters = []
if srv.proto != "any":
if not srv.proto in [Protocol.ANY, Protocol.BOTH]:
if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}})
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}})
@@ -110,13 +133,13 @@ class FiregexTables(NFTableManager):
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '>=', 'right': int(srv.port_dst_from)}})
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '<=', 'right': int(srv.port_dst_to)}})
if len(port_filters) == 0:
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '!=', 'right': 0}}) #filter the protocol if no port is specified
port_filters.append({'match': {'left': {'meta': {'key': 'l4proto'}}, 'op': '==', 'right': srv.proto}}) #filter the protocol if no port is specified
end_rules = [{'accept': None} if srv.action == "accept" else {'reject': {}} if (srv.action == "reject" and not srv.output_mode) else {'drop': None}]
rules.append({ "add":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": self.rules_chain_out if srv.output_mode else self.rules_chain_in,
"chain": self.rules_chain_out if srv.output_mode else self.rules_chain_in if srv.input_mode else self.rules_chain_fwd,
"expr": ip_filters + port_filters + end_rules
#If srv.output_mode is True, then the rule is in the output chain, so the reject action is not allowed
}}})

View File

@@ -6,27 +6,29 @@ from utils import ip_parse, ip_family, socketio_emit, PortType
from utils.models import ResetRequest, StatusMessageModel
from modules.firewall.nftables import FiregexTables
from modules.firewall.firewall import FirewallManager
from modules.firewall.models import Protocol, Mode, Action
class RuleModel(BaseModel):
active: bool
name: str
proto: str
ip_src: str
ip_dst: str
proto: Protocol
src: str
dst: str
port_src_from: PortType
port_dst_from: PortType
port_src_to: PortType
port_dst_to: PortType
action: str
mode:str
action: Action
mode:Mode
class RuleFormAdd(BaseModel):
rules: list[RuleModel]
policy: str
policy: Action
class RuleInfo(BaseModel):
rules: list[RuleModel]
policy: str
policy: Action
enabled: bool
class RenameForm(BaseModel):
@@ -36,29 +38,30 @@ class FirewallSettings(BaseModel):
keep_rules: bool
allow_loopback: bool
allow_established: bool
app = APIRouter()
allow_icmp: bool
db = SQLite('db/firewall-rules.db', {
'rules': {
'rule_id': 'INT PRIMARY KEY CHECK (rule_id >= 0)',
'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("O", "I"))', # O = out, I = in, B = both
'mode': 'VARCHAR(10) NOT NULL CHECK (mode IN ("in", "out", "forward"))',
'name': 'VARCHAR(100) NOT NULL',
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1))',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp", "any"))',
'ip_src': 'VARCHAR(100) NOT NULL',
'proto': 'VARCHAR(10) NOT NULL CHECK (proto IN ("tcp", "udp", "both", "any"))',
'src': 'VARCHAR(100) NOT NULL',
'port_src_from': 'INT CHECK(port_src_from > 0 and port_src_from < 65536)',
'port_src_to': 'INT CHECK(port_src_to > 0 and port_src_to < 65536 and port_src_from <= port_src_to)',
'ip_dst': 'VARCHAR(100) NOT NULL',
'dst': 'VARCHAR(100) NOT NULL',
'port_dst_from': 'INT CHECK(port_dst_from > 0 and port_dst_from < 65536)',
'port_dst_to': 'INT CHECK(port_dst_to > 0 and port_dst_to < 65536 and port_dst_from <= port_dst_to)',
'action': 'VARCHAR(10) NOT NULL CHECK (action IN ("accept", "drop", "reject"))',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, ip_src, ip_dst, port_src_from, port_src_to, port_dst_from, port_dst_to, mode);"
"CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, src, dst, port_src_from, port_src_to, port_dst_from, port_dst_to, mode);"
]
})
app = APIRouter()
firewall = FirewallManager(db)
async def reset(params: ResetRequest):
@@ -101,7 +104,8 @@ async def get_settings():
return {
"keep_rules": firewall.keep_rules,
"allow_loopback": firewall.allow_loopback,
"allow_established": firewall.allow_established
"allow_established": firewall.allow_established,
"allow_icmp": firewall.allow_icmp
}
@app.post("/settings/set", response_model=StatusMessageModel)
@@ -110,14 +114,15 @@ async def set_settings(form: FirewallSettings):
firewall.keep_rules = form.keep_rules
firewall.allow_loopback = form.allow_loopback
firewall.allow_established = form.allow_established
return {'status': 'ok'}
firewall.allow_icmp = form.allow_icmp
return await apply_changes()
@app.get('/rules', response_model=RuleInfo)
async def get_rule_list():
"""Get the list of existent firegex rules"""
return {
"policy": firewall.policy,
"rules": db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;"),
"rules": db.query("SELECT active, name, proto, src, dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;"),
"enabled": firewall.enabled
}
@@ -135,31 +140,34 @@ async def disable_firewall():
def parse_and_check_rule(rule:RuleModel):
if rule.ip_src.lower().strip() == "any" or rule.ip_dst.lower().split() == "any":
rule.ip_dst = rule.ip_src = "any"
else:
try:
rule.ip_src = ip_parse(rule.ip_src)
rule.ip_dst = ip_parse(rule.ip_dst)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid address")
if ip_family(rule.ip_dst) != ip_family(rule.ip_src):
raise HTTPException(status_code=400, detail="Destination and source addresses must be of the same family")
is_src_ip = is_dst_ip = True
try:
rule.src = ip_parse(rule.src)
except ValueError:
is_src_ip = False
try:
rule.dst = ip_parse(rule.dst)
except ValueError:
is_dst_ip = False
if not is_src_ip and "/" in rule.src: # Slash is not allowed in ip interfaces names
raise HTTPException(status_code=400, detail="Invalid source address")
if not is_dst_ip and "/" in rule.dst:
raise HTTPException(status_code=400, detail="Invalid destination address")
if is_src_ip and is_dst_ip and ip_family(rule.dst) != ip_family(rule.src):
raise HTTPException(status_code=400, detail="Destination and source addresses must be of the same family")
rule.port_dst_from, rule.port_dst_to = min(rule.port_dst_from, rule.port_dst_to), max(rule.port_dst_from, rule.port_dst_to)
rule.port_src_from, rule.port_src_to = min(rule.port_src_from, rule.port_src_to), max(rule.port_src_from, rule.port_src_to)
if rule.proto not in ["tcp", "udp", "any"]:
raise HTTPException(status_code=400, detail="Invalid protocol")
if rule.action not in ["accept", "drop", "reject"]:
raise HTTPException(status_code=400, detail="Invalid action")
return rule
@app.post('/rules/set', response_model=StatusMessageModel)
async def add_new_service(form: RuleFormAdd):
"""Add a new service"""
if form.policy not in ["accept", "drop", "reject"]:
raise HTTPException(status_code=400, detail="Invalid policy")
rules = [parse_and_check_rule(ele) for ele in form.rules]
try:
db.queries(["DELETE FROM rules"]+
@@ -167,20 +175,20 @@ async def add_new_service(form: RuleFormAdd):
INSERT INTO rules (
rule_id, active, name,
proto,
ip_src, ip_dst,
src, dst,
port_src_from, port_dst_from,
port_src_to, port_dst_to,
action, mode
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?, ?)""",
rid, ele.active, ele.name,
ele.proto,
ele.ip_src, ele.ip_dst,
ele.src, ele.dst,
ele.port_src_from, ele.port_dst_from,
ele.port_src_to, ele.port_dst_to,
ele.action, ele.mode
) for rid, ele in enumerate(rules)]
)
firewall.policy = form.policy
firewall.policy = form.policy.value
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="Error saving the rules: maybe there are duplicated rules")
return await apply_changes()

View File

@@ -70,6 +70,13 @@ def list_files(mypath):
def ip_parse(ip:str):
return str(ip_interface(ip).network)
def is_ip_parse(ip:str):
try:
ip_parse(ip)
return True
except Exception:
return False
def addr_parse(ip:str):
return str(ip_address(ip))