Filter with queue and iptables script developping

This commit is contained in:
DomySh
2022-07-07 19:16:58 +02:00
parent a1f9036eeb
commit 2d06fc46d8
5 changed files with 259 additions and 294 deletions

View File

@@ -192,12 +192,6 @@ async def get_service_by_id(service_id: str, auth: bool = Depends(is_loggined)):
class StatusMessageModel(BaseModel):
status:str
@app.get('/api/service/{service_id}/stop', response_model=StatusMessageModel)
async def service_stop(service_id: str, auth: bool = Depends(is_loggined)):
"""Request the stop of a specific service"""
await firewall.get(service_id).next(STATUS.STOP)
return {'status': 'ok'}
@app.get('/api/service/{service_id}/pause', response_model=StatusMessageModel)
async def service_pause(service_id: str, auth: bool = Depends(is_loggined)):
"""Request the pause of a specific service"""

View File

@@ -1,4 +1,214 @@
import re, os, asyncio
from typing import List, Set
from netfilterqueue import NetfilterQueue
from threading import Lock, Thread
from scapy.all import IP, TCP, UDP
from subprocess import Popen, PIPE
import os, pcre2, traceback, asyncio
QUEUE_BASE_NUM = 1000
def bind_queues(func, len_list=1):
if len_list <= 0: raise Exception("len must be >= 1")
queue_list = []
starts = QUEUE_BASE_NUM
end = starts
def func_wrap(pkt):
pkt_parsed = IP(pkt.get_payload())
try:
if pkt_parsed[UDP if UDP in pkt_parsed else TCP].payload: func(pkt, pkt_parsed)
else: pkt.accept()
except Exception:
traceback.print_exc()
pkt.accept()
while True:
if starts >= 65536:
raise Exception("Netfilter queue is full!")
try:
for _ in range(len_list):
queue_list.append(NetfilterQueue())
queue_list[-1].bind(end, func_wrap)
end+=1
end-=1
break
except OSError:
del queue_list[-1]
for ele in queue_list:
ele.unbind()
queue_list = []
starts = end = end+1
return queue_list, (starts, end)
class FilterTypes:
INPUT = "FIREGEX-INPUT"
OUTPUT = "FIREGEX-OUTPUT"
class ProtoTypes:
TCP = "tcp"
UDP = "udp"
class IPTables:
def command(params):
if os.geteuid() != 0:
exit("You need to have root privileges to run this script.\nPlease try again, this time using 'sudo'. Exiting.")
return Popen(["iptables"]+params, stdout=PIPE, stderr=PIPE).communicate()
def list_filters(param):
stdout, strerr = IPTables.command(["-L", str(param), "--line-number", "-n"])
output = [ele.split() for ele in stdout.decode().split("\n")]
return [{
"id": ele[0],
"target": ele[1],
"prot": ele[2],
"opt": ele[3],
"source": ele[4],
"destination": ele[5],
"details": " ".join(ele[6:]) if len(ele) >= 7 else "",
} for ele in output if len(ele) >= 6 and ele[0].isnumeric()]
def delete_command(param, id):
IPTables.command(["-R", str(param), str(id)])
def create_chain(name):
IPTables.command(["-N", str(name)])
def flush_chain(name):
IPTables.command(["-F", str(name)])
def add_chain_to_input(name):
IPTables.command(["-I", "INPUT", "-j", str(name)])
def add_chain_to_output(name):
IPTables.command(["-I", "OUTPUT", "-j", str(name)])
def add_s_to_c(proto, port, queue_range):
init, end = queue_range
if init > end: init, end = end, init
IPTables.command([
"-A", FilterTypes.OUTPUT, "-p", str(proto),
"--sport", str(port), "-j", "NFQUEUE",
"--queue-num" if init == end else "--queue-balance",
f"{init}" if init == end else f"{init}:{end}", "--queue-bypass"
])
def add_c_to_s(proto, port, queue_range):
init, end = queue_range
if init > end: init, end = end, init
IPTables.command([
"-A", FilterTypes.INPUT, "-p", str(proto),
"--dport", str(port), "-j", "NFQUEUE",
"--queue-num" if init == end else "--queue-balance",
f"{init}" if init == end else f"{init}:{end}", "--queue-bypass"
])
class FiregexFilter():
def __init__(self, type, number, queue, proto, port):
self.type = type
self.id = int(number)
self.queue = queue
self.proto = proto
self.port = int(port)
def __repr__(self) -> str:
return f"<FiregexFilter type={self.type} id={self.id} port={self.port} proto={self.proto} queue={self.queue}>"
def delete(self):
IPTables.delete_command(self.type, self.id)
class FiregexFilterManager:
def __init__(self):
IPTables.create_chain(FilterTypes.INPUT)
IPTables.create_chain(FilterTypes.OUTPUT)
input_found = False
output_found = False
for filter in IPTables.list_filters("INPUT"):
if filter["target"] == FilterTypes.INPUT:
input_found = True
break
for filter in IPTables.list_filters("OUTPUT"):
if filter["target"] == FilterTypes.OUTPUT:
output_found = True
break
if not input_found: IPTables.add_chain_to_input(FilterTypes.INPUT)
if not output_found: IPTables.add_chain_to_output(FilterTypes.OUTPUT)
def get(self) -> List[FiregexFilter]:
res = []
balanced_mode = pcre2.PCRE2(b"NFQUEUE balance ([0-9]+):([0-9]+)")
num_mode = pcre2.PCRE2(b"NFQUEUE num ([0-9]+)")
port_selected = pcre2.PCRE2(b"[sd]pt:([0-9]+)")
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
for filter in IPTables.list_filters(filter_type):
queue_num = None
balanced = balanced_mode.search(filter["details"].encode())
numbered = num_mode.search(filter["details"].encode())
port = port_selected.search(filter["details"].encode())
if balanced: queue_num = (int(balanced.group(1).decode()), int(balanced.group(2).decode()))
if numbered: queue_num = (int(numbered.group(1).decode()), int(numbered.group(1).decode()))
if queue_num and port:
res.append(FiregexFilter(
type=filter_type,
number=filter["id"],
queue=queue_num,
proto=filter["prot"],
port=port.group(1).decode()
))
return res
def add(self, proto, port, func_c_to_s, func_s_to_c, n_threads = 1):
for ele in self.get():
if int(port) == ele.port: return None
queues_c_to_s, codes = bind_queues(func_c_to_s, n_threads)
IPTables.add_c_to_s(proto, port, codes)
queues_s_to_c, codes = bind_queues(func_s_to_c, n_threads)
IPTables.add_s_to_c(proto, port, codes)
return queues_c_to_s + queues_s_to_c
def delete_all(self):
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
IPTables.flush_chain(filter_type)
def delete_by_port(self, port):
for filter in self.get():
if filter.port == int(port):
filter.delete()
def c_to_s(pkt, data):
print("SENDING", bytes(data[TCP].payload).decode())
if "bug" in bytes(data[TCP].payload).decode():
pkt.drop()
return
pkt.accept()
def s_to_c(pkt, data):
print("RECIVING", bytes(data[TCP].payload).decode())
pkt.accept()
"""
try:
manager.delete_all()
thr_list = []
q_list = manager.add("test_service",ProtoTypes.TCP, 8080, c_to_s, s_to_c)
print(manager.get())
for q in q_list:
thr_list.append(Thread(target=q.run))
thr_list[-1].start()
for t in thr_list:
t.join()
except KeyboardInterrupt:
for q in q_list:
q.unbind()
manager.delete_by_service("test_service")
#sudo iptables -I OUTPUT -p tcp --sport 8080 -j NFQUEUE --queue-num 10001 --queue-bypass -m comment --comment "&firegex&servid& Text"
#sudo iptables -I INPUT -p tcp --dport 8080 -j NFQUEUE --queue-num 10000 --queue-bypass -m comment --comment "&firegex&servid& Text"
"""
class Filter:
def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None):
@@ -10,61 +220,60 @@ class Filter:
self.s_to_c = s_to_c
self.blocked = blocked_packets
self.code = code
self.compiled_regex = self.compile()
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if is invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.c_to_s:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
if self.s_to_c:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
return pcre2.PCRE2(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data):
return True if self.compiled_regex.search(data) else False
class Proxy:
def __init__(self, internal_port=0, public_port=0, callback_blocked_update=None, filters=None, public_host="0.0.0.0", internal_host="127.0.0.1"):
self.filter_map = {}
self.filter_map_lock = asyncio.Lock()
def __init__(self, public_port, callback_blocked_update=None, filters=None):
self.manager = FiregexFilterManager()
self.update_config_lock = asyncio.Lock()
self.status_change = asyncio.Lock()
self.public_host = public_host
self.public_port = public_port
self.internal_host = internal_host
self.internal_port = internal_port
self.filters = set(filters) if filters else set([])
self.process = None
self.port = public_port
self.filters: Set[Filter] = set(filters) if filters else set([])
self.use_filters = True
self.callback_blocked_update = callback_blocked_update
async def start(self, in_pause=False):
await self.status_change.acquire()
if not self.isactive():
try:
self.filter_map = self.compile_filters()
filters_codes = self.get_filter_codes() if not in_pause else []
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./proxy")
async def start(self):
self.manager.delete_by_port(self.port)
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port),
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
)
await self.update_config(filters_codes)
finally:
self.status_change.release()
def c_to_s(pkt, data):
packet = bytes(data[TCP].payload)
try:
while True:
buff = await self.process.stdout.readuntil()
stdout_line = buff.decode()
if stdout_line.startswith("BLOCKED"):
regex_id = stdout_line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id])
except Exception:
return await self.process.wait()
else:
self.status_change.release()
for filter in self.filters:
if filter.c_to_s:
match = filter.check(packet)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
self.callback_blocked_update(filter)
pkt.drop()
return
except IndexError:
pass
pkt.accept()
def s_to_c(pkt, data):
packet = bytes(data[TCP].payload)
try:
for filter in self.filters:
if filter.s_to_c:
match = filter.check(packet)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
self.callback_blocked_update(filter)
pkt.drop()
return
except IndexError:
pass
pkt.accept()
self.manager.add(ProtoTypes.TCP, self.port, c_to_s, s_to_c)
async def stop(self):
async with self.status_change:
@@ -105,12 +314,3 @@ class Proxy:
else:
await self.start(in_pause=True)
def compile_filters(self):
res = {}
for filter_obj in self.filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
return res

View File

@@ -94,8 +94,6 @@ class KeyValueStorage:
self.db.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
class STATUS:
WAIT = "wait"
STOP = "stop"
PAUSE = "pause"
ACTIVE = "active"
@@ -106,11 +104,9 @@ class ServiceManager:
self.id = id
self.db = db
self.proxy = Proxy(
internal_host=LOCALHOST_IP,
callback_blocked_update=self._stats_updater
)
self.status = STATUS.STOP
self.wanted_status = STATUS.STOP
self.status = STATUS.PAUSE
self.filters = {}
self._update_port_from_db()
self._update_filters_from_db()
@@ -125,7 +121,6 @@ class ServiceManager:
FROM services WHERE service_id = ?;
""", self.id)
if len(res) == 0: raise ServiceNotFoundException()
self.proxy.internal_port = res[0]["internal_port"]
self.proxy.public_port = res[0]["public_port"]
def _update_filters_from_db(self):
@@ -168,27 +163,15 @@ class ServiceManager:
async def _next(self, to):
if self.status != to:
# ACTIVE -> PAUSE or PAUSE -> ACTIVE
# ACTIVE -> PAUSE
if (self.status, to) in [(STATUS.ACTIVE, STATUS.PAUSE)]:
await self.proxy.pause()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) in [(STATUS.PAUSE, STATUS.ACTIVE)]:
await self.proxy.reload()
self._set_status(to)
# ACTIVE -> STOP
elif (self.status,to) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy
if self.starter: self.starter.cancel()
await self.proxy.stop()
self._set_status(to)
# STOP -> ACTIVE or STOP -> PAUSE
elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]:
self.wanted_status = to
self._set_status(STATUS.WAIT)
self.__proxy_starter(to)
def _stats_updater(self,filter:Filter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
@@ -196,10 +179,8 @@ class ServiceManager:
async def update_port(self):
async with self.lock:
self._update_port_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
next_status = self.status if self.status != STATUS.WAIT else self.wanted_status
await self._next(STATUS.STOP)
await self._next(next_status)
if self.status in [STATUS.ACTIVE]:
await self.proxy.reload()
def _set_status(self,status):
self.status = status
@@ -211,22 +192,6 @@ class ServiceManager:
self._update_filters_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
await self.proxy.reload()
def __proxy_starter(self,to):
async def func():
try:
while True:
if check_port_is_open(self.proxy.public_port):
self._set_status(to)
await self.proxy.start(in_pause=(to==STATUS.PAUSE))
self._set_status(STATUS.STOP)
return
else:
await asyncio.sleep(.5)
except asyncio.CancelledError:
self._set_status(STATUS.STOP)
await self.proxy.stop()
self.starter = asyncio.create_task(func())
class ProxyManager:
def __init__(self, db:SQLite):
@@ -241,7 +206,7 @@ class ProxyManager:
async def remove(self,id):
async with self.lock:
if id in self.proxy_table:
await self.proxy_table[id].next(STATUS.STOP)
await self.proxy_table[id].next(STATUS.PAUSE)
del self.proxy_table[id]
async def reload(self):