170 lines
6.7 KiB
Python
170 lines
6.7 KiB
Python
from typing import List
|
|
from pypacker import interceptor
|
|
from pypacker.layer3 import ip, ip6
|
|
from pypacker.layer4 import tcp, udp
|
|
from ipaddress import ip_interface
|
|
from modules.iptables import IPTables
|
|
import os, traceback
|
|
|
|
from modules.sqlite import Service
|
|
|
|
class FilterTypes:
    """Names of the two iptables chains that hold the Firegex NFQUEUE rules."""

    # Chain whose rules match on destination address / destination port.
    INPUT = "FIREGEX-INPUT"
    # Chain whose rules match on source address / source port.
    OUTPUT = "FIREGEX-OUTPUT"
|
|
|
|
# First netfilter queue number tried when searching for free queues to bind.
QUEUE_BASE_NUM = 1000
|
|
|
|
class FiregexFilter:
    """Description of one filtered service: protocol, port and interface address.

    Two filters compare equal when protocol, port and address match; the
    queue, target, id and callback are not part of the comparison.
    """

    def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None):
        """Store the filter fields, normalising ``id``/``port`` to int and ``ip_int`` to str.

        Args:
            proto: Protocol name of the filtered traffic.
            port: Port of the filtered service (anything ``int()`` accepts).
            ip_int: Address (optionally with prefix) of the filtered service.
            queue: Queue information attached to the filter, if any.
            target: Name of the chain the filter belongs to, if any.
            id: Rule identifier; any falsy value (None, 0, "") is stored as None.
            func: Optional callback invoked as ``func(pkt, is_input)``.
        """
        self.target = target
        self.queue = queue
        self.func = func
        self.proto = proto
        self.id = None if not id else int(id)
        self.port = int(port)
        self.ip_int = str(ip_int)

    def __eq__(self, o: object) -> bool:
        """Return True when ``o`` is a FiregexFilter with the same port, proto and address."""
        if not isinstance(o, FiregexFilter):
            return False
        # Cheap comparisons first; addresses are only parsed when these match.
        if not (self.port == o.port and self.proto == o.proto):
            return False
        return ip_interface(self.ip_int) == ip_interface(o.ip_int)

    def _ip_version(self):
        """Return the IP version (4 or 6) of ``ip_int``."""
        return ip_interface(self.ip_int).version

    def ipv6(self):
        """Return True when ``ip_int`` is an IPv6 address."""
        return self._ip_version() == 6

    def ipv4(self):
        """Return True when ``ip_int`` is an IPv4 address."""
        return self._ip_version() == 4

    def _direction_func(self, is_input):
        """Build a one-argument callback bound to a traffic direction.

        Without a configured ``func`` the callback accepts everything
        (always returns True); otherwise it forwards to
        ``self.func(pkt, is_input)``.
        """
        def accept_all(pkt):
            return True

        def delegate(pkt):
            return self.func(pkt, is_input)

        return delegate if self.func else accept_all

    def input_func(self):
        """Return the callback used for the input direction (``is_input=True``)."""
        return self._direction_func(True)

    def output_func(self):
        """Return the callback used for the output direction (``is_input=False``)."""
        return self._direction_func(False)
|
|
|
|
class FiregexTables(IPTables):
    """Wrapper around the iptables "mangle" table that manages the Firegex chains."""

    def __init__(self, ipv6=False):
        """Open the "mangle" table and make sure both Firegex chains exist and are hooked.

        Args:
            ipv6: Forwarded to the ``IPTables`` base constructor.
        """
        super().__init__(ipv6, "mangle")
        self.create_chain(FilterTypes.INPUT)
        self.add_chain_to_input(FilterTypes.INPUT)
        self.create_chain(FilterTypes.OUTPUT)
        self.add_chain_to_output(FilterTypes.OUTPUT)

    def target_in_chain(self, chain, target):
        """Return True when some rule listed under ``chain`` has ``target`` as its target."""
        return any(rule.target == target for rule in self.list()[chain])

    def add_chain_to_input(self, chain):
        """Insert a jump to ``chain`` into PREROUTING unless one is already there."""
        if not self.target_in_chain("PREROUTING", str(chain)):
            self.insert_rule("PREROUTING", str(chain))

    def add_chain_to_output(self, chain):
        """Insert a jump to ``chain`` into POSTROUTING unless one is already there."""
        if not self.target_in_chain("POSTROUTING", str(chain)):
            self.insert_rule("POSTROUTING", str(chain))

    @staticmethod
    def _nfqueue_rule_args(queue_range, proto, addr_flag, ip_int, port_flag, port):
        """Build the match/target arguments for one NFQUEUE rule.

        Args:
            queue_range: ``(first, last)`` queue numbers; swapped if reversed.
            proto: Protocol for ``-p``; omitted when falsy.
            addr_flag: ``"-s"`` or ``"-d"``.
            ip_int: Address for ``addr_flag``; omitted when falsy.
            port_flag: ``"--sport"`` or ``"--dport"``.
            port: Port for ``port_flag``; omitted when falsy.

        Returns:
            List of string arguments, always ending with ``--queue-bypass``.
        """
        init, end = queue_range
        if init > end:
            init, end = end, init
        args = []
        if proto:
            args += ["-p", str(proto)]
        if ip_int:
            args += [addr_flag, str(ip_int)]
        if port:
            args += [port_flag, str(port)]
        # A single queue uses --queue-num; several queues are load-balanced.
        if init == end:
            args += ["--queue-num", f"{init}"]
        else:
            args += ["--queue-balance", f"{init}:{end}"]
        args.append("--queue-bypass")
        return args

    def add_output(self, queue_range, proto = None, port = None, ip_int = None):
        """Append an NFQUEUE rule to the output chain, matching on source address/port.

        Args:
            queue_range: ``(first, last)`` queue numbers to send packets to.
            proto: Optional protocol to match.
            port: Optional source port to match.
            ip_int: Optional source address to match.
        """
        # The original call lacked the comma after "NFQUEUE", which turned the
        # first unpacked argument into a string*list multiplication (TypeError).
        self.append_rule(
            FilterTypes.OUTPUT, "NFQUEUE",
            *self._nfqueue_rule_args(queue_range, proto, "-s", ip_int, "--sport", port),
        )

    def add_input(self, queue_range, proto = None, port = None, ip_int = None):
        """Append an NFQUEUE rule to the input chain, matching on destination address/port.

        Args:
            queue_range: ``(first, last)`` queue numbers to send packets to.
            proto: Optional protocol to match.
            port: Optional destination port to match.
            ip_int: Optional destination address to match.
        """
        self.append_rule(
            FilterTypes.INPUT, "NFQUEUE",
            *self._nfqueue_rule_args(queue_range, proto, "-d", ip_int, "--dport", port),
        )

    def get(self) -> List[FiregexFilter]:
        """Return a FiregexFilter for every listed rule that has both a queue and a port.

        Output-chain rules are read through their source port/address,
        input-chain rules through their destination port/address.
        """
        res = []
        for filter_type in (FilterTypes.INPUT, FilterTypes.OUTPUT):
            is_output = filter_type == FilterTypes.OUTPUT
            for rule in self.list()[filter_type]:
                port = rule.sport() if is_output else rule.dport()
                queue = rule.nfqueue()
                # Rules without a queue or a port are skipped.
                if queue and port:
                    res.append(FiregexFilter(
                        target=filter_type,
                        id=rule.id,
                        queue=queue,
                        proto=rule.prot,
                        port=port,
                        ip_int=rule.source if is_output else rule.destination,
                    ))
        return res

    def add(self, filter:FiregexFilter):
        """Start intercepting traffic for ``filter`` unless an equal filter already exists.

        Returns:
            A new FiregexInterceptor, or None when an equal filter is already
            present. The number of queues comes from the N_THREADS_NFQUEUE
            environment variable (default "1").
        """
        if filter in self.get():
            return None
        n_threads = int(os.getenv("N_THREADS_NFQUEUE","1"))
        return FiregexInterceptor(iptables=self, filter=filter, n_threads=n_threads)

    def delete_all(self):
        """Flush both Firegex chains."""
        for filter_type in (FilterTypes.INPUT, FilterTypes.OUTPUT):
            self.flush_chain(filter_type)

    def delete_by_srv(self, srv:Service):
        """Delete every rule whose port, protocol and address match ``srv``."""
        # NOTE(review): ids are read once before any deletion; if delete_rule
        # renumbers the remaining rules of a chain, later ids may be stale —
        # confirm against the IPTables implementation.
        for rule_filter in self.get():
            same_service = (
                rule_filter.port == srv.port
                and rule_filter.proto == srv.proto
                and ip_interface(rule_filter.ip_int) == ip_interface(srv.ip_int)
            )
            if same_service:
                self.delete_rule(rule_filter.target, rule_filter.id)
|
|
|
|
class FiregexInterceptor:
    """Binds netfilter queues for one filter and installs the matching NFQUEUE rules."""

    def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1):
        """Start the input and output queues and register their rules.

        Args:
            iptables: Table wrapper used to append the NFQUEUE rules.
            filter: Filter whose proto/port/address and callbacks are used.
            n_threads: Number of consecutive queues bound per direction.
        """
        self.filter = filter
        self.ipv6 = self.filter.ipv6()
        self.itor_input, codes = self._start_queue(filter.input_func(), n_threads)
        iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
        self.itor_output, codes = self._start_queue(filter.output_func(), n_threads)
        iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)

    def _start_queue(self,func,n_threads):
        """Bind ``n_threads`` consecutive free queues and start an interceptor on them.

        Args:
            func: Callable taking the transport payload bytes and returning a
                truthy value to let the packet through.
            n_threads: Number of consecutive queue ids to bind.

        Returns:
            ``(interceptor, (first_queue, last_queue))``.

        Raises:
            Exception: When the search for free queues reaches 65536.
        """
        def func_wrap(ll_data, ll_proto_id, data, ctx, *args):
            """Verdict callback: return ``(packet_bytes, verdict)`` for one packet.

            ``data`` is the raw packet; ``ll_data``, ``ll_proto_id`` and
            ``ctx`` are not used here.
            """
            pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data)
            try:
                # Keep the payload in its own variable so the raw packet in
                # `data` is still available for the accept paths below.
                payload = None
                tcp_layer = pkt_parsed[tcp.TCP]
                udp_layer = pkt_parsed[udp.UDP]
                if tcp_layer is not None:
                    payload = tcp_layer.body_bytes
                if udp_layer is not None:
                    payload = udp_layer.body_bytes

                # Packets without a payload are passed through unmodified.
                if not payload:
                    return data, interceptor.NF_ACCEPT
                if func(payload):
                    return data, interceptor.NF_ACCEPT
                if tcp_layer is not None:
                    # Rejected TCP payload: forward an empty FIN+ACK segment
                    # instead of the original packet.
                    tcp_layer.flags = tcp.TH_FIN | tcp.TH_ACK
                    tcp_layer.body_bytes = b""
                    return pkt_parsed.bin(), interceptor.NF_ACCEPT
                # Rejected non-TCP payload: drop the packet.
                return b"", interceptor.NF_DROP
            except Exception:
                # Best effort: on any failure log it and let the packet pass.
                traceback.print_exc()
                return data, interceptor.NF_ACCEPT

        ictor = interceptor.Interceptor()
        starts = QUEUE_BASE_NUM
        while True:
            if starts >= 65536:
                raise Exception("Netfilter queue is full!")
            queue_ids = list(range(starts,starts+n_threads))
            try:
                ictor.start(func_wrap, queue_ids=queue_ids)
                break
            except interceptor.UnableToBindException as e:
                # Retry with the range that starts right after the busy queue.
                starts = e.queue_id + 1
        return ictor, (starts, starts+n_threads-1)

    def stop(self):
        """Stop both the input and the output interceptor."""
        self.itor_input.stop()
        self.itor_output.stop()