Refactor code pt 1 (not tested)
This commit is contained in:
170
backend/modules/firegex.py
Normal file
170
backend/modules/firegex.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from typing import List
|
||||
from pypacker import interceptor
|
||||
from pypacker.layer3 import ip, ip6
|
||||
from pypacker.layer4 import tcp, udp
|
||||
from ipaddress import ip_interface
|
||||
from modules.iptables import IPTables
|
||||
import os, traceback
|
||||
|
||||
from modules.sqlite import Service
|
||||
|
||||
class FilterTypes:
|
||||
INPUT = "FIREGEX-INPUT"
|
||||
OUTPUT = "FIREGEX-OUTPUT"
|
||||
|
||||
QUEUE_BASE_NUM = 1000
|
||||
|
||||
class FiregexFilter():
|
||||
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None):
|
||||
self.target = target
|
||||
self.id = int(id) if id else None
|
||||
self.queue = queue
|
||||
self.proto = proto
|
||||
self.port = int(port)
|
||||
self.ip_int = str(ip_int)
|
||||
self.func = func
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, FiregexFilter):
|
||||
return self.port == o.port and self.proto == o.proto and ip_interface(self.ip_int) == ip_interface(o.ip_int)
|
||||
return False
|
||||
|
||||
def ipv6(self):
|
||||
return ip_interface(self.ip_int).version == 6
|
||||
|
||||
def ipv4(self):
|
||||
return ip_interface(self.ip_int).version == 4
|
||||
|
||||
def input_func(self):
|
||||
def none(pkt): return True
|
||||
def wrap(pkt): return self.func(pkt, True)
|
||||
return wrap if self.func else none
|
||||
|
||||
def output_func(self):
|
||||
def none(pkt): return True
|
||||
def wrap(pkt): return self.func(pkt, False)
|
||||
return wrap if self.func else none
|
||||
|
||||
class FiregexTables(IPTables):
|
||||
|
||||
def __init__(self, ipv6=False):
|
||||
super().__init__(ipv6, "mangle")
|
||||
self.create_chain(FilterTypes.INPUT)
|
||||
self.add_chain_to_input(FilterTypes.INPUT)
|
||||
self.create_chain(FilterTypes.OUTPUT)
|
||||
self.add_chain_to_output(FilterTypes.OUTPUT)
|
||||
|
||||
def target_in_chain(self, chain, target):
|
||||
for filter in self.list()[chain]:
|
||||
if filter.target == target:
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_chain_to_input(self, chain):
|
||||
if not self.target_in_chain("PREROUTING", str(chain)):
|
||||
self.insert_rule("PREROUTING", str(chain))
|
||||
|
||||
def add_chain_to_output(self, chain):
|
||||
if not self.target_in_chain("POSTROUTING", str(chain)):
|
||||
self.insert_rule("POSTROUTING", str(chain))
|
||||
|
||||
def add_output(self, queue_range, proto = None, port = None, ip_int = None):
|
||||
init, end = queue_range
|
||||
if init > end: init, end = end, init
|
||||
self.append_rule(FilterTypes.OUTPUT,"NFQUEUE"
|
||||
* (["-p", str(proto)] if proto else []),
|
||||
* (["-s", str(ip_int)] if ip_int else []),
|
||||
* (["--sport", str(port)] if port else []),
|
||||
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
|
||||
"--queue-bypass"
|
||||
)
|
||||
|
||||
def add_input(self, queue_range, proto = None, port = None, ip_int = None):
|
||||
init, end = queue_range
|
||||
if init > end: init, end = end, init
|
||||
self.append_rule(FilterTypes.INPUT, "NFQUEUE",
|
||||
* (["-p", str(proto)] if proto else []),
|
||||
* (["-d", str(ip_int)] if ip_int else []),
|
||||
* (["--dport", str(port)] if port else []),
|
||||
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
|
||||
"--queue-bypass"
|
||||
)
|
||||
|
||||
def get(self) -> List[FiregexFilter]:
|
||||
res = []
|
||||
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
|
||||
for filter in self.list()[filter_type]:
|
||||
port = filter.sport() if filter_type == FilterTypes.OUTPUT else filter.dport()
|
||||
queue = filter.nfqueue()
|
||||
if queue and port:
|
||||
res.append(FiregexFilter(
|
||||
target=filter_type,
|
||||
id=filter.id,
|
||||
queue=queue,
|
||||
proto=filter.prot,
|
||||
port=port,
|
||||
ip_int=filter.source if filter_type == FilterTypes.OUTPUT else filter.destination
|
||||
))
|
||||
return res
|
||||
|
||||
def add(self, filter:FiregexFilter):
|
||||
if filter in self.get(): return None
|
||||
return FiregexInterceptor( iptables=self, filter=filter, n_threads=int(os.getenv("N_THREADS_NFQUEUE","1")))
|
||||
|
||||
def delete_all(self):
|
||||
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
|
||||
self.flush_chain(filter_type)
|
||||
|
||||
def delete_by_srv(self, srv:Service):
|
||||
for filter in self.get():
|
||||
if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int):
|
||||
self.delete_rule(filter.target, filter.id)
|
||||
|
||||
class FiregexInterceptor:
|
||||
def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1):
|
||||
self.filter = filter
|
||||
self.ipv6 = self.filter.ipv6()
|
||||
self.itor_input, codes = self._start_queue(filter.input_func(), n_threads)
|
||||
iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
|
||||
self.itor_output, codes = self._start_queue(filter.output_func(), n_threads)
|
||||
iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
|
||||
|
||||
def _start_queue(self,func,n_threads):
|
||||
def func_wrap(ll_data, ll_proto_id, data, ctx, *args):
|
||||
pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data)
|
||||
try:
|
||||
data = None
|
||||
if not pkt_parsed[tcp.TCP] is None:
|
||||
data = pkt_parsed[tcp.TCP].body_bytes
|
||||
if not pkt_parsed[tcp.TCP] is None:
|
||||
data = pkt_parsed[udp.UDP].body_bytes
|
||||
if data:
|
||||
if func(data):
|
||||
return data, interceptor.NF_ACCEPT
|
||||
elif pkt_parsed[tcp.TCP]:
|
||||
pkt_parsed[tcp.TCP].flags &= 0x00
|
||||
pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK
|
||||
pkt_parsed[tcp.TCP].body_bytes = b""
|
||||
return pkt_parsed.bin(), interceptor.NF_ACCEPT
|
||||
else: return b"", interceptor.NF_DROP
|
||||
else: return data, interceptor.NF_ACCEPT
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return data, interceptor.NF_ACCEPT
|
||||
|
||||
ictor = interceptor.Interceptor()
|
||||
starts = QUEUE_BASE_NUM
|
||||
while True:
|
||||
if starts >= 65536:
|
||||
raise Exception("Netfilter queue is full!")
|
||||
queue_ids = list(range(starts,starts+n_threads))
|
||||
try:
|
||||
ictor.start(func_wrap, queue_ids=queue_ids)
|
||||
break
|
||||
except interceptor.UnableToBindException as e:
|
||||
starts = e.queue_id + 1
|
||||
return ictor, (starts, starts+n_threads-1)
|
||||
|
||||
def stop(self):
|
||||
self.itor_input.stop()
|
||||
self.itor_output.stop()
|
||||
Reference in New Issue
Block a user