Refactor code pt 1 (not tested)

This commit is contained in:
DomySh
2022-07-12 20:18:54 +02:00
parent 632d0a1b12
commit 1e94c26fd6
9 changed files with 597 additions and 534 deletions

View File

@@ -0,0 +1,2 @@
from .firewall import FirewallManager
from .sqlite import SQLite

170
backend/modules/firegex.py Normal file
View File

@@ -0,0 +1,170 @@
from typing import List
from pypacker import interceptor
from pypacker.layer3 import ip, ip6
from pypacker.layer4 import tcp, udp
from ipaddress import ip_interface
from modules.iptables import IPTables
import os, traceback
from modules.sqlite import Service
class FilterTypes:
INPUT = "FIREGEX-INPUT"
OUTPUT = "FIREGEX-OUTPUT"
QUEUE_BASE_NUM = 1000
class FiregexFilter():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None):
self.target = target
self.id = int(id) if id else None
self.queue = queue
self.proto = proto
self.port = int(port)
self.ip_int = str(ip_int)
self.func = func
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter):
return self.port == o.port and self.proto == o.proto and ip_interface(self.ip_int) == ip_interface(o.ip_int)
return False
def ipv6(self):
return ip_interface(self.ip_int).version == 6
def ipv4(self):
return ip_interface(self.ip_int).version == 4
def input_func(self):
def none(pkt): return True
def wrap(pkt): return self.func(pkt, True)
return wrap if self.func else none
def output_func(self):
def none(pkt): return True
def wrap(pkt): return self.func(pkt, False)
return wrap if self.func else none
class FiregexTables(IPTables):
def __init__(self, ipv6=False):
super().__init__(ipv6, "mangle")
self.create_chain(FilterTypes.INPUT)
self.add_chain_to_input(FilterTypes.INPUT)
self.create_chain(FilterTypes.OUTPUT)
self.add_chain_to_output(FilterTypes.OUTPUT)
def target_in_chain(self, chain, target):
for filter in self.list()[chain]:
if filter.target == target:
return True
return False
def add_chain_to_input(self, chain):
if not self.target_in_chain("PREROUTING", str(chain)):
self.insert_rule("PREROUTING", str(chain))
def add_chain_to_output(self, chain):
if not self.target_in_chain("POSTROUTING", str(chain)):
self.insert_rule("POSTROUTING", str(chain))
def add_output(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
self.append_rule(FilterTypes.OUTPUT,"NFQUEUE"
* (["-p", str(proto)] if proto else []),
* (["-s", str(ip_int)] if ip_int else []),
* (["--sport", str(port)] if port else []),
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
"--queue-bypass"
)
def add_input(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
self.append_rule(FilterTypes.INPUT, "NFQUEUE",
* (["-p", str(proto)] if proto else []),
* (["-d", str(ip_int)] if ip_int else []),
* (["--dport", str(port)] if port else []),
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
"--queue-bypass"
)
def get(self) -> List[FiregexFilter]:
res = []
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
for filter in self.list()[filter_type]:
port = filter.sport() if filter_type == FilterTypes.OUTPUT else filter.dport()
queue = filter.nfqueue()
if queue and port:
res.append(FiregexFilter(
target=filter_type,
id=filter.id,
queue=queue,
proto=filter.prot,
port=port,
ip_int=filter.source if filter_type == FilterTypes.OUTPUT else filter.destination
))
return res
def add(self, filter:FiregexFilter):
if filter in self.get(): return None
return FiregexInterceptor( iptables=self, filter=filter, n_threads=int(os.getenv("N_THREADS_NFQUEUE","1")))
def delete_all(self):
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
self.flush_chain(filter_type)
def delete_by_srv(self, srv:Service):
for filter in self.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int):
self.delete_rule(filter.target, filter.id)
class FiregexInterceptor:
def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1):
self.filter = filter
self.ipv6 = self.filter.ipv6()
self.itor_input, codes = self._start_queue(filter.input_func(), n_threads)
iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
self.itor_output, codes = self._start_queue(filter.output_func(), n_threads)
iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
def _start_queue(self,func,n_threads):
def func_wrap(ll_data, ll_proto_id, data, ctx, *args):
pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data)
try:
data = None
if not pkt_parsed[tcp.TCP] is None:
data = pkt_parsed[tcp.TCP].body_bytes
if not pkt_parsed[tcp.TCP] is None:
data = pkt_parsed[udp.UDP].body_bytes
if data:
if func(data):
return data, interceptor.NF_ACCEPT
elif pkt_parsed[tcp.TCP]:
pkt_parsed[tcp.TCP].flags &= 0x00
pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK
pkt_parsed[tcp.TCP].body_bytes = b""
return pkt_parsed.bin(), interceptor.NF_ACCEPT
else: return b"", interceptor.NF_DROP
else: return data, interceptor.NF_ACCEPT
except Exception:
traceback.print_exc()
return data, interceptor.NF_ACCEPT
ictor = interceptor.Interceptor()
starts = QUEUE_BASE_NUM
while True:
if starts >= 65536:
raise Exception("Netfilter queue is full!")
queue_ids = list(range(starts,starts+n_threads))
try:
ictor.start(func_wrap, queue_ids=queue_ids)
break
except interceptor.UnableToBindException as e:
starts = e.queue_id + 1
return ictor, (starts, starts+n_threads-1)
def stop(self):
self.itor_input.stop()
self.itor_output.stop()

196
backend/modules/firewall.py Normal file
View File

@@ -0,0 +1,196 @@
import traceback, asyncio, pcre
from typing import Dict
from modules.firegex import FiregexFilter, FiregexTables
from modules.sqlite import Regex, SQLite, Service
class STATUS:
STOP = "stop"
ACTIVE = "active"
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
self.updater_task = None
def init_updater(self, callback = None):
if not self.updater_task:
self.updater_task = asyncio.create_task(self._stats_updater(callback))
def close_updater(self):
if self.updater_task: self.updater_task.cancel()
async def close(self):
self.close_updater()
if self.updater_task: self.updater_task.cancel()
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.proxy_table:
await self.proxy_table[srv_id].next(STATUS.STOP)
del self.proxy_table[srv_id]
async def init(self, callback = None):
self.init_updater(callback)
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.id in self.proxy_table:
continue
self.proxy_table[srv.id] = ServiceManager(srv, self.db)
await self.proxy_table[srv.id].next(srv["status"])
async def _stats_updater(self, callback):
try:
while True:
try:
for key in list(self.proxy_table.keys()):
self.proxy_table[key].update_stats()
except Exception:
traceback.print_exc()
if callback:
if asyncio.iscoroutinefunction(callback): await callback()
else: callback()
await asyncio.sleep(5)
except asyncio.CancelledError:
self.updater_task = None
return
def get(self,srv_id):
if srv_id in self.proxy_table:
return self.proxy_table[srv_id]
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"]
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data):
return True if self.compiled_regex.search(data) else False
class ServiceManager:
def __init__(self, srv: Service, db):
self.srv = srv
self.db = db
self.iptables = FiregexTables(self.srv.ipv6)
self.status = STATUS.STOP
self.filters: Dict[int, FiregexFilter] = {}
self._update_filters_from_db()
self.lock = asyncio.Lock()
self.interceptor = None
# TODO I don't like so much this method
def _update_filters_from_db(self):
regexes = [
Regex.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
#add new filters
for f in new_filters:
if not f in old_filters:
filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = FiregexFilter.from_regex(filter)
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv["service_id"])
async def next(self,to):
async with self.lock:
return self._next(to)
def _next(self, to):
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
self.proxy.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
self.proxy.restart()
def _stats_updater(self,filter:RegexFilter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
def update_stats(self):
for ele in self.filters.values():
self._stats_updater(ele)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
def start(self):
if not self.interceptor:
self.iptables.delete_by_srv(self.srv)
def regex_filter(pkt, by_client):
try:
for filter in self.filters.values():
if (by_client and filter.input_mode) or (not by_client and filter.output_mode):
match = filter.check(pkt)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
return False
except IndexError: pass
return True
self.interceptor = self.iptables.add(self.srv["proto"], self.srv["port"], self.srv["ip_int"], regex_filter)
self._set_status(STATUS.ACTIVE)
def stop(self):
self.iptables.delete_by_srv(self.srv)
if self.interceptor:
self.interceptor.stop()
self.interceptor = None
def restart(self):
self.stop()
self.start()
async def update_filters(self):
async with self.lock:
self._update_filters_from_db()

View File

@@ -0,0 +1,82 @@
import os, re
from subprocess import PIPE, Popen
from typing import Dict, List, Tuple, Union
class Rule():
def __init__(self, id, target, prot, opt, source, destination, details):
self.id = id
self.target = target
self.prot = prot
self.opt = opt
self.source = source
self.destination = destination
self.details = details
def dport(self) -> Union[int, None]:
port = re.findall(r"dpt:([0-9]+)", self.details)
return int(port[0]) if port else None
def sport(self) -> Union[int, None]:
port = re.findall(r"spt:([0-9]+)", self.details)
return int(port[0]) if port else None
def nfqueue(self) -> Union[Tuple[int,int], None]:
balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", self.details)
numbered = re.findall(r"NFQUEUE num ([0-9]+)", self.details)
queue_num = None
if balanced: queue_num = (int(balanced[0][0]), int(balanced[0][1]))
if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
return queue_num
class IPTables:
def __init__(self, ipv6=False, table="filter"):
self.ipv6 = ipv6
self.table = table
def command(self, params) -> Tuple[bytes, bytes]:
params = ["-t", self.table] + params
if os.geteuid() != 0:
exit("You need to have root privileges to run this script.\nPlease try again, this time using 'sudo'. Exiting.")
return Popen(["ip6tables"]+params if self.ipv6 else ["iptables"]+params, stdout=PIPE, stderr=PIPE).communicate()
def list(self) -> Dict[str, List[Rule]]:
stdout, strerr = self.command(["-L", "--line-number", "-n"])
lines = stdout.decode().split("\n")
res: Dict[str, List[Rule]] = {}
chain_name = ""
for line in lines:
if line.startswith("Chain"):
chain_name = line.split()[1]
res[chain_name] = []
elif line.split()[0].isnumeric():
parsed = re.findall(r"([^ ]*)[ ]{,10}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]+([^ ]*)[ ]+(.*)", line)
if len(parsed) > 0:
parsed = parsed[0]
res[chain_name].append(Rule(
id=parsed[0].strip(),
target=parsed[1].strip(),
prot=parsed[2].strip(),
opt=parsed[3].strip(),
source=parsed[4].strip(),
destination=parsed[5].strip(),
details=" ".join(parsed[6:]).strip() if len(parsed[0]) >= 7 else ""
))
return res
def delete_rule(self, chain, id) -> None:
self.command(["-D", str(chain), str(id)])
def create_chain(self, name) -> None:
self.command(["-N", str(name)])
def flush_chain(self, name) -> None:
self.command(["-F", str(name)])
def insert_rule(self, chain, rule, *args, rulenum=1) -> None:
self.command(["-I", str(chain), str(rulenum), "-j", str(rule), *args])
def append_rule(self, chain, rule, *args) -> None:
self.command(["-A", str(chain), "-j", str(rule), *args])

130
backend/modules/sqlite.py Normal file
View File

@@ -0,0 +1,130 @@
from typing import Union
import json, sqlite3, os
from hashlib import md5
class SQLite():
def __init__(self, db_name: str) -> None:
self.conn: Union[None, sqlite3.Connection] = None
self.cur = None
self.db_name = db_name
self.schema = {
'services': {
'service_id': 'VARCHAR(100) PRIMARY KEY',
'status': 'VARCHAR(100) NOT NULL',
'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'name': 'VARCHAR(100) NOT NULL UNIQUE',
'ipv6': 'BOOLEAN NOT NULL CHECK (ipv6 IN (0, 1)) DEFAULT 0',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))',
'ip_int': 'VARCHAR(100) NOT NULL',
},
'regexes': {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1)) DEFAULT 1',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (ipv6, port, ip_int, proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
]
}
self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest()
def connect(self) -> None:
try:
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
except Exception:
with open(self.db_name, 'x'): pass
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
self.conn.row_factory = dict_factory
def disconnect(self) -> None:
if self.conn: self.conn.close()
def create_schema(self, tables = {}) -> None:
if self.conn:
cur = self.conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);")
for t in tables:
if t == "QUERY": continue
cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
cur.close()
def query(self, query, *values):
cur = self.conn.cursor()
try:
cur.execute(query, values)
return cur.fetchall()
finally:
cur.close()
try: self.conn.commit()
except Exception: pass
def delete(self):
self.disconnect()
os.remove(self.db_name)
def init(self):
self.connect()
try:
if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct")
except Exception:
self.delete()
self.connect()
self.create_schema(self.schema)
self.put('DB_VERSION', self.DB_VER)
def get(self, key):
q = self.query('SELECT value FROM keys_values WHERE key = ?', key)
if len(q) == 0:
return None
else:
return q[0]["value"]
def put(self, key, value):
if self.get(key) is None:
self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value))
else:
self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
class Service:
def __init__(self, id: str, status: str, port: int, name: str, ipv6: bool, proto: str, ip_int: str):
self.id = id
self.status = status
self.port = port
self.name = name
self.ipv6 = ipv6
self.proto = proto
self.ip_int = ip_int
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["service_id"], status=var["status"], port=var["port"], name=var["name"], ipv6=var["ipv6"], proto=var["proto"], ip_int=var["ip_int"])
class Regex:
def __init__(self, id: int, regex: str, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool):
self.regex = regex
self.mode = mode
self.service_id = service_id
self.is_blacklist = is_blacklist
self.blocked_packets = blocked_packets
self.id = id
self.is_case_sensitive = is_case_sensitive
self.active = active
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["regex_id"], regex=var["regex"], mode=var["mode"], service_id=var["service_id"], is_blacklist=var["is_blacklist"], blocked_packets=var["blocked_packets"], is_case_sensitive=var["is_case_sensitive"], active=var["active"])