Moduled Firegex, Merging pt1 (not finished and not working yet)

This commit is contained in:
DomySh
2022-07-21 20:25:39 +02:00
parent 3143f6e474
commit 63ba0e94e7
42 changed files with 2832 additions and 577 deletions

View File

@@ -6,8 +6,7 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from passlib.context import CryptContext
from fastapi_socketio import SocketManager
from modules import SQLite
from modules.firegex import FiregexTables
from utils.sqlite import SQLite
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, refresh_frontend, DEBUG
from utils.loader import frontend_deploy, load_routers
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel

View File

@@ -1,2 +0,0 @@
from .firewall import FirewallManager
from .sqlite import SQLite

View File

@@ -1,276 +0,0 @@
from typing import Dict, List, Set
from utils import ip_parse, ip_family, run_func
from modules.sqlite import Service
import re, os, asyncio
import traceback, nftables
from modules.sqlite import Regex
QUEUE_BASE_NUM = 1000
class FiregexFilter():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None):
self.nftables = nftables.Nftables()
self.id = int(id) if id else None
self.queue = queue
self.target = target
self.proto = proto
self.port = int(port)
self.ip_int = str(ip_int)
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False
class FiregexTables:
def __init__(self):
self.table_name = "firegex"
self.nft = nftables.Nftables()
def raw_cmd(self, *cmds):
return self.nft.json_cmd({"nftables": list(cmds)})
def cmd(self, *cmds):
code, out, err = self.raw_cmd(*cmds)
if code == 0: return out
else: raise Exception(err)
def init(self):
self.reset()
code, out, err = self.raw_cmd({"create":{"table":{"name":self.table_name,"family":"inet"}}})
if code == 0:
self.cmd(
{"create":{"chain":{
"family":"inet",
"table":self.table_name,
"name":"input",
"type":"filter",
"hook":"prerouting",
"prio":-150,
"policy":"accept"
}}},
{"create":{"chain":{
"family":"inet",
"table":self.table_name,
"name":"output",
"type":"filter",
"hook":"postrouting",
"prio":-150,
"policy":"accept"
}}}
)
def reset(self):
self.raw_cmd(
{"flush":{"table":{"name":"firegex","family":"inet"}}},
{"delete":{"table":{"name":"firegex","family":"inet"}}},
)
def list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"]
def add_output(self, queue_range, proto, port, ip_int):
init, end = queue_range
if init > end: init, end = end, init
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
self.cmd({ "insert":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": "output",
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, #ip_int
{'match': {'left': {'meta': {'key': 'l4proto'}}, 'op': '==', 'right': str(proto)}},
{'match': {"left": { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(port)}},
{"queue": {"num": str(init) if init == end else f"{init}-{end}", "flags": ["bypass"]}}
]
}}})
def add_input(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
self.cmd({"insert":{"rule":{
"family": "inet",
"table": self.table_name,
"chain": "input",
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}}, #ip_int
{'match': {"left": { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(port)}},
{"queue": {"num": str(init) if init == end else f"{init}-{end}", "flags": ["bypass"]}}
]
}}})
def get(self) -> List[FiregexFilter]:
res = []
for filter in [ele["rule"] for ele in self.list() if "rule" in ele and ele["rule"]["table"] == self.table_name]:
queue_str = str(filter["expr"][2]["queue"]["num"]).split("-")
queue = None
if len(queue_str) == 1: queue = int(queue_str[0]), int(queue_str[0])
else: queue = int(queue_str[0]), int(queue_str[1])
ip_int = None
if isinstance(filter["expr"][0]["match"]["right"],str):
ip_int = str(ip_parse(filter["expr"][0]["match"]["right"]))
else:
ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}'
res.append(FiregexFilter(
target=filter["chain"],
id=int(filter["handle"]),
queue=queue,
proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"],
port=filter["expr"][1]["match"]["right"],
ip_int=ip_int
))
return res
async def add(self, filter:FiregexFilter):
if filter in self.get(): return None
return await FiregexInterceptor.start( filter=filter, n_queues=int(os.getenv("N_THREADS_NFQUEUE","1")))
def delete_by_srv(self, srv:Service):
for filter in self.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int):
self.cmd({"delete":{"rule": {"handle": filter.id, "table": self.table_name, "chain": filter.target, "family": "inet"}}})
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None,
update_func = None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.update_func = update_func
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex, update_func = None):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
update_func = update_func
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if it's invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.input_mode:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
if self.output_mode:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
async def update(self):
if self.update_func:
await run_func(self.update_func, self)
class FiregexInterceptor:
def __init__(self):
self.filter:FiregexFilter
self.filter_map_lock:asyncio.Lock
self.filter_map: Dict[str, RegexFilter]
self.regex_filters: Set[RegexFilter]
self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process
self.n_queues:int
self.update_task: asyncio.Task
@classmethod
async def start(cls, filter: FiregexFilter, n_queues:int = 1):
self = cls()
self.filter = filter
self.n_queues = n_queues
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
input_range, output_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_blocked())
FiregexTables().add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
FiregexTables().add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
return self
async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./cppqueue")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, str(self.n_queues),
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
)
line_fut = self.process.stdout.readuntil()
try:
line_fut = await asyncio.wait_for(line_fut, timeout=3)
except asyncio.TimeoutError:
self.process.kill()
raise Exception("Invalid binary output")
line = line_fut.decode()
if line.startswith("QUEUES "):
params = line.split()
return (int(params[2]), int(params[3])), (int(params[5]), int(params[6]))
else:
self.process.kill()
raise Exception("Invalid binary output")
async def update_blocked(self):
try:
while True:
line = (await self.process.stdout.readuntil()).decode()
if line.startswith("BLOCKED"):
regex_id = line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
await self.filter_map[regex_id].update()
except asyncio.CancelledError: pass
except asyncio.IncompleteReadError: pass
except Exception:
traceback.print_exc()
async def stop(self):
self.update_task.cancel()
if self.process and self.process.returncode is None:
self.process.kill()
async def _update_config(self, filters_codes):
async with self.update_config_lock:
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
async def reload(self, filters:List[RegexFilter]):
async with self.filter_map_lock:
self.filter_map = self.compile_filters(filters)
filters_codes = self.get_filter_codes()
await self._update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def compile_filters(self, filters:List[RegexFilter]):
res = {}
for filter_obj in filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
return res

View File

View File

@@ -0,0 +1,151 @@
from typing import Dict, List, Set
from utils.firegextables import FiregexFilter, FiregexTables
from utils import ip_parse, ip_family, run_func
from modules.nfregex.models import Service, Regex
import re, os, asyncio
import traceback
QUEUE_BASE_NUM = 1000
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None,
update_func = None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.update_func = update_func
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex, update_func = None):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
update_func = update_func
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if it's invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.input_mode:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
if self.output_mode:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
async def update(self):
if self.update_func:
await run_func(self.update_func, self)
class FiregexInterceptor:
def __init__(self):
self.filter:FiregexFilter
self.filter_map_lock:asyncio.Lock
self.filter_map: Dict[str, RegexFilter]
self.regex_filters: Set[RegexFilter]
self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process
self.n_queues:int
self.update_task: asyncio.Task
@classmethod
async def start(cls, filter: FiregexFilter, n_queues:int = int(os.getenv("NTHREADS","1"))):
self = cls()
self.filter = filter
self.n_queues = n_queues
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
input_range, output_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_blocked())
if not filter in FiregexTables().get():
FiregexTables().add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
FiregexTables().add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
return self
async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppqueue")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, str(self.n_queues),
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
)
line_fut = self.process.stdout.readuntil()
try:
line_fut = await asyncio.wait_for(line_fut, timeout=3)
except asyncio.TimeoutError:
self.process.kill()
raise Exception("Invalid binary output")
line = line_fut.decode()
if line.startswith("QUEUES "):
params = line.split()
return (int(params[2]), int(params[3])), (int(params[5]), int(params[6]))
else:
self.process.kill()
raise Exception("Invalid binary output")
async def update_blocked(self):
try:
while True:
line = (await self.process.stdout.readuntil()).decode()
if line.startswith("BLOCKED"):
regex_id = line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
await self.filter_map[regex_id].update()
except asyncio.CancelledError: pass
except asyncio.IncompleteReadError: pass
except Exception:
traceback.print_exc()
async def stop(self):
self.update_task.cancel()
if self.process and self.process.returncode is None:
self.process.kill()
async def _update_config(self, filters_codes):
async with self.update_config_lock:
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
async def reload(self, filters:List[RegexFilter]):
async with self.filter_map_lock:
self.filter_map = self.compile_filters(filters)
filters_codes = self.get_filter_codes()
await self._update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def compile_filters(self, filters:List[RegexFilter]):
res = {}
for filter_obj in filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
return res
def delete_by_srv(srv:Service):
nft = FiregexTables()
for filter in nft.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_parse(filter.ip_int) == ip_parse(srv.ip_int):
nft.cmd({"delete":{"rule": {"handle": filter.id, "table": nft.table_name, "chain": filter.target, "family": "inet"}}})

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import Dict
from modules.firegex import FiregexFilter, FiregexTables, RegexFilter
from modules.sqlite import Regex, SQLite, Service
from modules.nfregex.firegex import FiregexFilter, FiregexInterceptor, FiregexTables, RegexFilter, delete_by_srv
from modules.nfregex.models import Regex, Service
from utils.sqlite import SQLite
class STATUS:
STOP = "stop"
@@ -93,13 +94,13 @@ class ServiceManager:
async def start(self):
if not self.interceptor:
FiregexTables().delete_by_srv(self.srv)
self.interceptor = await FiregexTables().add(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int))
delete_by_srv(self.srv)
self.interceptor = await FiregexInterceptor.start(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int))
await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE)
async def stop(self):
FiregexTables().delete_by_srv(self.srv)
delete_by_srv(self.srv)
if self.interceptor:
await self.interceptor.stop()
self.interceptor = None

View File

@@ -0,0 +1,30 @@
import base64
class Service:
def __init__(self, id: str, status: str, port: int, name: str, proto: str, ip_int: str):
self.id = id
self.status = status
self.port = port
self.name = name
self.proto = proto
self.ip_int = ip_int
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["service_id"], status=var["status"], port=var["port"], name=var["name"], proto=var["proto"], ip_int=var["ip_int"])
class Regex:
def __init__(self, id: int, regex: bytes, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool):
self.regex = regex
self.mode = mode
self.service_id = service_id
self.is_blacklist = is_blacklist
self.blocked_packets = blocked_packets
self.id = id
self.is_case_sensitive = is_case_sensitive
self.active = active
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["regex_id"], regex=base64.b64decode(var["regex"]), mode=var["mode"], service_id=var["service_id"], is_blacklist=var["is_blacklist"], blocked_packets=var["blocked_packets"], is_case_sensitive=var["is_case_sensitive"], active=var["active"])

BIN
backend/modules/proxy Executable file

Binary file not shown.

View File

View File

@@ -0,0 +1,116 @@
import re, os, asyncio
class Filter:
def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if c_to_s == s_to_c: c_to_s = s_to_c = True # (False, False) == (True, True)
self.c_to_s = c_to_s
self.s_to_c = s_to_c
self.blocked = blocked_packets
self.code = code
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if is invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.c_to_s:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
if self.s_to_c:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
class Proxy:
def __init__(self, internal_port=0, public_port=0, callback_blocked_update=None, filters=None, public_host="0.0.0.0", internal_host="127.0.0.1"):
self.filter_map = {}
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
self.status_change = asyncio.Lock()
self.public_host = public_host
self.public_port = public_port
self.internal_host = internal_host
self.internal_port = internal_port
self.filters = set(filters) if filters else set([])
self.process = None
self.callback_blocked_update = callback_blocked_update
async def start(self, in_pause=False):
await self.status_change.acquire()
if not self.isactive():
try:
self.filter_map = self.compile_filters()
filters_codes = self.get_filter_codes() if not in_pause else []
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../proxy")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port),
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
)
await self.update_config(filters_codes)
finally:
self.status_change.release()
try:
while True:
buff = await self.process.stdout.readuntil()
stdout_line = buff.decode()
if stdout_line.startswith("BLOCKED"):
regex_id = stdout_line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id])
except Exception:
return await self.process.wait()
else:
self.status_change.release()
async def stop(self):
async with self.status_change:
if self.isactive():
self.process.kill()
return False
return True
async def restart(self, in_pause=False):
status = await self.stop()
await self.start(in_pause=in_pause)
return status
async def update_config(self, filters_codes):
async with self.update_config_lock:
if (self.isactive()):
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
async def reload(self):
if self.isactive():
async with self.filter_map_lock:
self.filter_map = self.compile_filters()
filters_codes = self.get_filter_codes()
await self.update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def isactive(self):
return self.process and self.process.returncode is None
async def pause(self):
if self.isactive():
await self.update_config([])
else:
await self.start(in_pause=True)
def compile_filters(self):
res = {}
for filter_obj in self.filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
return res

View File

@@ -0,0 +1,196 @@
import secrets
from modules.regexproxy.proxy import Filter, Proxy
import random, socket, asyncio
from base64 import b64decode
from utils.sqlite import SQLite
class STATUS:
WAIT = "wait"
STOP = "stop"
PAUSE = "pause"
ACTIVE = "active"
class ServiceNotFoundException(Exception): pass
class ServiceManager:
def __init__(self, id, db):
self.id = id
self.db = db
self.proxy = Proxy(
internal_host="127.0.0.1",
callback_blocked_update=self._stats_updater
)
self.status = STATUS.STOP
self.wanted_status = STATUS.STOP
self.filters = {}
self._update_port_from_db()
self._update_filters_from_db()
self.lock = asyncio.Lock()
self.starter = None
def _update_port_from_db(self):
res = self.db.query("""
SELECT
public_port,
internal_port
FROM services WHERE service_id = ?;
""", self.id)
if len(res) == 0: raise ServiceNotFoundException()
self.proxy.internal_port = res[0]["internal_port"]
self.proxy.public_port = res[0]["public_port"]
def _update_filters_from_db(self):
res = self.db.query("""
SELECT
regex, mode, regex_id `id`, is_blacklist,
blocked_packets n_packets, is_case_sensitive
FROM regexes WHERE service_id = ? AND active=1;
""", self.id)
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f["id"] for f in res])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
for f in new_filters:
if not f in old_filters:
filter_info = [ele for ele in res if ele["id"] == f][0]
self.filters[f] = Filter(
is_case_sensitive=filter_info["is_case_sensitive"],
c_to_s=filter_info["mode"] in ["C","B"],
s_to_c=filter_info["mode"] in ["S","B"],
is_blacklist=filter_info["is_blacklist"],
regex=b64decode(filter_info["regex"]),
blocked_packets=filter_info["n_packets"],
code=f
)
self.proxy.filters = list(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.id)
async def next(self,to):
async with self.lock:
return await self._next(to)
async def _next(self, to):
if self.status != to:
# ACTIVE -> PAUSE or PAUSE -> ACTIVE
if (self.status, to) in [(STATUS.ACTIVE, STATUS.PAUSE)]:
await self.proxy.pause()
self._set_status(to)
elif (self.status, to) in [(STATUS.PAUSE, STATUS.ACTIVE)]:
await self.proxy.reload()
self._set_status(to)
# ACTIVE -> STOP
elif (self.status,to) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy
if self.starter: self.starter.cancel()
await self.proxy.stop()
self._set_status(to)
# STOP -> ACTIVE or STOP -> PAUSE
elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]:
self.wanted_status = to
self._set_status(STATUS.WAIT)
self.__proxy_starter(to)
def _stats_updater(self,filter:Filter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
async def update_port(self):
async with self.lock:
self._update_port_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
next_status = self.status if self.status != STATUS.WAIT else self.wanted_status
await self._next(STATUS.STOP)
await self._next(next_status)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
async def update_filters(self):
async with self.lock:
self._update_filters_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
await self.proxy.reload()
def __proxy_starter(self,to):
async def func():
try:
while True:
if check_port_is_open(self.proxy.public_port):
self._set_status(to)
await self.proxy.start(in_pause=(to==STATUS.PAUSE))
self._set_status(STATUS.STOP)
return
else:
await asyncio.sleep(.5)
except asyncio.CancelledError:
self._set_status(STATUS.STOP)
await self.proxy.stop()
self.starter = asyncio.create_task(func())
class ProxyManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table:dict = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,id):
async with self.lock:
if id in self.proxy_table:
await self.proxy_table[id].next(STATUS.STOP)
del self.proxy_table[id]
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT service_id, status FROM services;'):
srv_id, req_status = srv["service_id"], srv["status"]
if srv_id in self.proxy_table:
continue
self.proxy_table[srv_id] = ServiceManager(srv_id,self.db)
await self.proxy_table[srv_id].next(req_status)
def get(self,id):
if id in self.proxy_table:
return self.proxy_table[id]
else:
raise ServiceNotFoundException()
def check_port_is_open(port):
try:
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('0.0.0.0',port))
sock.close()
return True
except Exception:
return False
def gen_service_id(db):
while True:
res = secrets.token_hex(8)
if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
break
return res
def gen_internal_port(db):
while True:
res = random.randint(30000, 45000)
if len(db.query('SELECT 1 FROM services WHERE internal_port = ?;', res)) == 0:
break
return res

497
backend/nfqueue/proxy.cpp Normal file
View File

@@ -0,0 +1,497 @@
/*
Copyright (c) 2007 Arash Partow (http://www.partow.net)
URL: http://www.partow.net/programming/tcpproxy/index.html
Modified and adapted by Pwnzer0tt1
*/
#include <cstdlib>
#include <cstddef>
#include <iostream>
#include <string>
#include <mutex>
#include <boost/thread.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/bind/bind.hpp>
#include <boost/asio.hpp>
#include <boost/thread/mutex.hpp>
#include <jpcre2.hpp>
typedef jpcre2::select<char> jp;
using namespace std;
bool unhexlify(string const &hex, string &newString) {
try{
int len = hex.length();
for(int i=0; i< len; i+=2)
{
std::string byte = hex.substr(i,2);
char chr = (char) (int)strtol(byte.c_str(), NULL, 16);
newString.push_back(chr);
}
return true;
}
catch (...){
return false;
}
}
typedef pair<string,jp::Regex> regex_rule_pair;
typedef vector<regex_rule_pair> regex_rule_vector;
struct regex_rules{
regex_rule_vector regex_s_c_w, regex_c_s_w, regex_s_c_b, regex_c_s_b;
regex_rule_vector* getByCode(char code){
switch(code){
case 'C': // Client to server Blacklist
return &regex_c_s_b; break;
case 'c': // Client to server Whitelist
return &regex_c_s_w; break;
case 'S': // Server to client Blacklist
return &regex_s_c_b; break;
case 's': // Server to client Whitelist
return &regex_s_c_w; break;
}
throw invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
}
void add(const char* arg){
//Integrity checks
size_t arg_len = strlen(arg);
if (arg_len < 2 || arg_len%2 != 0) return;
if (arg[0] != '0' && arg[0] != '1') return;
if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's') return;
string hex(arg+2), expr;
if (!unhexlify(hex, expr)) return;
//Push regex
jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
if (regex){
#ifdef DEBUG
cerr << "Added regex " << expr << " " << arg << endl;
#endif
getByCode(arg[1])->push_back(make_pair(string(arg), regex));
} else {
cerr << "Regex " << arg << " was not compiled successfully" << endl;
}
}
};
shared_ptr<regex_rules> regex_config;
mutex update_mutex;
#ifdef MULTI_THREAD
mutex stdout_mutex;
#endif
bool filter_data(unsigned char* data, const size_t& bytes_transferred, regex_rule_vector const &blacklist, regex_rule_vector const &whitelist){
#ifdef DEBUG_PACKET
cerr << "---------------- Packet ----------------" << endl;
for(int i=0;i<bytes_transferred;i++) cerr << data[i];
cerr << endl;
for(int i=0;i<bytes_transferred;i++) fprintf(stderr, "%x", data[i]);
cerr << endl;
cerr << "---------------- End Packet ----------------" << endl;
#endif
string str_data((char *) data, bytes_transferred);
for (regex_rule_pair ele:blacklist){
try{
if(ele.second.match(str_data)){
#ifdef MULTI_THREAD
std::unique_lock<std::mutex> lck(stdout_mutex);
#endif
cout << "BLOCKED " << ele.first << endl;
return false;
}
} catch(...){
cerr << "Error while matching regex: " << ele.first << endl;
}
}
for (regex_rule_pair ele:whitelist){
try{
if(!ele.second.match(str_data)){
#ifdef MULTI_THREAD
std::unique_lock<std::mutex> lck(stdout_mutex);
#endif
cout << "BLOCKED " << ele.first << endl;
return false;
}
} catch(...){
cerr << "Error while matching regex: " << ele.first << endl;
}
}
#ifdef DEBUG
cerr << "Packet Accepted!" << endl;
#endif
return true;
}
namespace tcp_proxy
{
namespace ip = boost::asio::ip;
class bridge : public boost::enable_shared_from_this<bridge>
{
public:
typedef ip::tcp::socket socket_type;
typedef boost::shared_ptr<bridge> ptr_type;
bridge(boost::asio::io_context& ios)
: downstream_socket_(ios),
upstream_socket_ (ios),
thread_safety(ios)
{}
socket_type& downstream_socket()
{
// Client socket
return downstream_socket_;
}
socket_type& upstream_socket()
{
// Remote server socket
return upstream_socket_;
}
void start(const string& upstream_host, unsigned short upstream_port)
{
// Attempt connection to remote server (upstream side)
upstream_socket_.async_connect(
ip::tcp::endpoint(
boost::asio::ip::address::from_string(upstream_host),
upstream_port),
boost::asio::bind_executor(thread_safety,
boost::bind(
&bridge::handle_upstream_connect,
shared_from_this(),
boost::asio::placeholders::error)));
}
void handle_upstream_connect(const boost::system::error_code& error)
{
if (!error)
{
// Setup async read from remote server (upstream)
upstream_socket_.async_read_some(
boost::asio::buffer(upstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_upstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
// Setup async read from client (downstream)
downstream_socket_.async_read_some(
boost::asio::buffer(downstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_downstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
close();
}
private:
/*
Section A: Remote Server --> Proxy --> Client
Process data recieved from remote sever then send to client.
*/
// Read from remote server complete, now send data to client
void handle_upstream_read(const boost::system::error_code& error,
const size_t& bytes_transferred) // Da Server a Client
{
if (!error)
{
shared_ptr<regex_rules> regex_old_config = regex_config;
if (filter_data(upstream_data_, bytes_transferred, regex_old_config->regex_s_c_b, regex_old_config->regex_s_c_w)){
async_write(downstream_socket_,
boost::asio::buffer(upstream_data_,bytes_transferred),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_downstream_write,
shared_from_this(),
boost::asio::placeholders::error)));
}else{
close();
}
}
else
close();
}
// Write to client complete, Async read from remote server
void handle_downstream_write(const boost::system::error_code& error)
{
if (!error)
{
upstream_socket_.async_read_some(
boost::asio::buffer(upstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_upstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
close();
}
// *** End Of Section A ***
/*
Section B: Client --> Proxy --> Remove Server
Process data recieved from client then write to remove server.
*/
// Read from client complete, now send data to remote server
void handle_downstream_read(const boost::system::error_code& error,
const size_t& bytes_transferred) // Da Client a Server
{
if (!error)
{
shared_ptr<regex_rules> regex_old_config = regex_config;
if (filter_data(downstream_data_, bytes_transferred, regex_old_config->regex_c_s_b, regex_old_config->regex_c_s_w)){
async_write(upstream_socket_,
boost::asio::buffer(downstream_data_,bytes_transferred),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_upstream_write,
shared_from_this(),
boost::asio::placeholders::error)));
}else{
close();
}
}
else
close();
}
// Write to remote server complete, Async read from client
void handle_upstream_write(const boost::system::error_code& error)
{
if (!error)
{
downstream_socket_.async_read_some(
boost::asio::buffer(downstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_downstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
close();
}
// *** End Of Section B ***
void close()
{
boost::mutex::scoped_lock lock(mutex_);
if (downstream_socket_.is_open())
{
downstream_socket_.close();
}
if (upstream_socket_.is_open())
{
upstream_socket_.close();
}
}
socket_type downstream_socket_;
socket_type upstream_socket_;
enum { max_data_length = 8192 }; //8KB
unsigned char downstream_data_[max_data_length];
unsigned char upstream_data_ [max_data_length];
boost::asio::io_context::strand thread_safety;
boost::mutex mutex_;
public:
class acceptor
{
public:
acceptor(boost::asio::io_context& io_context,
const string& local_host, unsigned short local_port,
const string& upstream_host, unsigned short upstream_port)
: io_context_(io_context),
localhost_address(boost::asio::ip::address_v4::from_string(local_host)),
acceptor_(io_context_,ip::tcp::endpoint(localhost_address,local_port)),
upstream_port_(upstream_port),
upstream_host_(upstream_host)
{}
bool accept_connections()
{
try
{
session_ = boost::shared_ptr<bridge>(new bridge(io_context_));
acceptor_.async_accept(session_->downstream_socket(),
boost::asio::bind_executor(session_->thread_safety,
boost::bind(&acceptor::handle_accept,
this,
boost::asio::placeholders::error)));
}
catch(exception& e)
{
cerr << "acceptor exception: " << e.what() << endl;
return false;
}
return true;
}
private:
void handle_accept(const boost::system::error_code& error)
{
if (!error)
{
session_->start(upstream_host_,upstream_port_);
if (!accept_connections())
{
cerr << "Failure during call to accept." << endl;
}
}
else
{
cerr << "Error: " << error.message() << endl;
}
}
boost::asio::io_context& io_context_;
ip::address_v4 localhost_address;
ip::tcp::acceptor acceptor_;
ptr_type session_;
unsigned short upstream_port_;
string upstream_host_;
};
};
}
void update_config (boost::asio::streambuf &input_buffer){
#ifdef DEBUG
cerr << "Updating configuration" << endl;
#endif
std::istream config_stream(&input_buffer);
std::unique_lock<std::mutex> lck(update_mutex);
regex_rules *regex_new_config = new regex_rules();
string data;
while(true){
config_stream >> data;
if (config_stream.eof()) break;
regex_new_config->add(data.c_str());
}
regex_config.reset(regex_new_config);
}
class async_updater
{
public:
async_updater(boost::asio::io_context& io_context) : input_(io_context, ::dup(STDIN_FILENO)), thread_safety(io_context)
{
boost::asio::async_read_until(input_, input_buffer_, '\n',
boost::asio::bind_executor(thread_safety,
boost::bind(&async_updater::on_update, this,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
void on_update(const boost::system::error_code& error, std::size_t length)
{
if (!error)
{
update_config(input_buffer_);
boost::asio::async_read_until(input_, input_buffer_, '\n',
boost::asio::bind_executor(thread_safety,
boost::bind(&async_updater::on_update, this,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
{
close();
}
}
void close()
{
input_.close();
}
private:
boost::asio::posix::stream_descriptor input_;
boost::asio::io_context::strand thread_safety;
boost::asio::streambuf input_buffer_;
};
int main(int argc, char* argv[])
{
if (argc < 5)
{
cerr << "usage: tcpproxy_server <local host ip> <local port> <forward host ip> <forward port>" << endl;
return 1;
}
const unsigned short local_port = static_cast<unsigned short>(::atoi(argv[2]));
const unsigned short forward_port = static_cast<unsigned short>(::atoi(argv[4]));
const string local_host = argv[1];
const string forward_host = argv[3];
boost::asio::io_context ios;
boost::asio::streambuf buf;
boost::asio::posix::stream_descriptor cin_in(ios, ::dup(STDIN_FILENO));
boost::asio::read_until(cin_in, buf,'\n');
update_config(buf);
async_updater updater(ios);
#ifdef DEBUG
cerr << "Starting Proxy" << endl;
#endif
try
{
tcp_proxy::bridge::acceptor acceptor(ios,
local_host, local_port,
forward_host, forward_port);
acceptor.accept_connections();
#ifdef MULTI_THREAD
boost::thread_group tg;
#ifdef THREAD_NUM
for (unsigned i = 0; i < THREAD_NUM; ++i)
#else
for (unsigned i = 0; i < thread::hardware_concurrency(); ++i)
#endif
tg.create_thread(boost::bind(&boost::asio::io_context::run, &ios));
tg.join_all();
#else
ios.run();
#endif
}
catch(exception& e)
{
cerr << "Error: " << e.what() << endl;
return 1;
}
#ifdef DEBUG
cerr << "Proxy stopped!" << endl;
#endif
return 0;
}

BIN
backend/regextcpproxy.db Normal file

Binary file not shown.

View File

@@ -1,13 +1,14 @@
from base64 import b64decode
import re
import secrets
import sqlite3
from typing import List, Union
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from modules.firegex import FiregexTables
from modules.firewall import STATUS, FirewallManager
from modules.sqlite import SQLite
from utils import gen_service_id, ip_parse, refactor_name, refresh_frontend
from modules.nfregex.firegex import FiregexTables
from modules.nfregex.firewall import STATUS, FirewallManager
from utils.sqlite import SQLite
from utils import ip_parse, refactor_name, refresh_frontend
from utils.models import ResetRequest, StatusMessageModel
class GeneralStatModel(BaseModel):
@@ -58,29 +59,6 @@ class ServiceAddResponse(BaseModel):
app = APIRouter()
async def reset(params: ResetRequest):
if not params.delete:
db.backup()
await firewall.close()
FiregexTables().reset()
if params.delete:
db.delete()
db.init()
else:
db.restore()
await firewall.init()
async def startup():
db.init()
await firewall.init()
async def shutdown():
db.backup()
await firewall.close()
db.disconnect()
db.restore()
db = SQLite('db/nft-regex.db', {
'services': {
'service_id': 'VARCHAR(100) PRIMARY KEY',
@@ -107,6 +85,36 @@ db = SQLite('db/nft-regex.db', {
]
})
async def reset(params: ResetRequest):
if not params.delete:
db.backup()
await firewall.close()
FiregexTables().reset()
if params.delete:
db.delete()
db.init()
else:
db.restore()
await firewall.init()
async def startup():
db.init()
await firewall.init()
async def shutdown():
db.backup()
await firewall.close()
db.disconnect()
db.restore()
def gen_service_id():
while True:
res = secrets.token_hex(8)
if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
break
return res
firewall = FirewallManager(db)
@app.get('/stats', response_model=GeneralStatModel)
@@ -271,7 +279,7 @@ async def add_new_service(form: ServiceAddForm, ):
return {"status":"Invalid protocol"}
srv_id = None
try:
srv_id = gen_service_id(db)
srv_id = gen_service_id()
db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int) VALUES (?, ?, ?, ?, ?, ?)",
srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int)
except sqlite3.IntegrityError:

View File

@@ -0,0 +1,297 @@
from base64 import b64decode
import sqlite3, re
from typing import List, Union
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from modules.regexproxy.utils import STATUS, ProxyManager, gen_internal_port, gen_service_id
from utils.sqlite import SQLite
from utils.models import ResetRequest, StatusMessageModel
from utils.firegextables import FiregexTables
app = APIRouter()
db = SQLite("regextcpproxy.db",{
'services': {
'status': 'VARCHAR(100) NOT NULL',
'service_id': 'VARCHAR(100) PRIMARY KEY',
'internal_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536)',
'public_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE',
'name': 'VARCHAR(100) NOT NULL'
},
'regexes': {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'active' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1)) DEFAULT 1',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
]
})
firewall = ProxyManager(db)
async def reset(params: ResetRequest):
if not params.delete:
db.backup()
await firewall.close()
FiregexTables().reset()
if params.delete:
db.delete()
db.init()
else:
db.restore()
await firewall.reload()
async def startup():
db.init()
await firewall.reload()
async def shutdown():
db.backup()
await firewall.close()
db.disconnect()
db.restore()
class GeneralStatModel(BaseModel):
closed:int
regexes: int
services: int
@app.get('/stats', response_model=GeneralStatModel)
async def get_general_stats():
"""Get firegex general status about services"""
return db.query("""
SELECT
(SELECT COALESCE(SUM(blocked_packets),0) FROM regexes) closed,
(SELECT COUNT(*) FROM regexes) regexes,
(SELECT COUNT(*) FROM services) services
""")[0]
class ServiceModel(BaseModel):
id:str
status: str
public_port: int
internal_port: int
name: str
n_regex: int
n_packets: int
@app.get('/services', response_model=List[ServiceModel])
async def get_service_list():
"""Get the list of existent firegex services"""
return db.query("""
SELECT
s.service_id `id`,
s.status status,
s.public_port public_port,
s.internal_port internal_port,
s.name name,
COUNT(r.regex_id) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id
GROUP BY s.service_id;
""")
@app.get('/service/{service_id}', response_model=ServiceModel)
async def get_service_by_id(service_id: str, ):
"""Get info about a specific service using his id"""
res = db.query("""
SELECT
s.service_id `id`,
s.status status,
s.public_port public_port,
s.internal_port internal_port,
s.name name,
COUNT(r.regex_id) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ?
GROUP BY s.service_id;
""", service_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
return res[0]
@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
async def service_stop(service_id: str, ):
"""Request the stop of a specific service"""
await firewall.get(service_id).next(STATUS.STOP)
return {'status': 'ok'}
@app.get('/service/{service_id}/pause', response_model=StatusMessageModel)
async def service_pause(service_id: str, ):
"""Request the pause of a specific service"""
await firewall.get(service_id).next(STATUS.PAUSE)
return {'status': 'ok'}
@app.get('/service/{service_id}/start', response_model=StatusMessageModel)
async def service_start(service_id: str, ):
"""Request the start of a specific service"""
await firewall.get(service_id).next(STATUS.ACTIVE)
return {'status': 'ok'}
@app.get('/service/{service_id}/delete', response_model=StatusMessageModel)
async def service_delete(service_id: str, ):
"""Request the deletion of a specific service"""
db.query('DELETE FROM services WHERE service_id = ?;', service_id)
db.query('DELETE FROM regexes WHERE service_id = ?;', service_id)
await firewall.remove(service_id)
return {'status': 'ok'}
@app.get('/service/{service_id}/regen-port', response_model=StatusMessageModel)
async def regen_service_port(service_id: str, ):
"""Request the regeneration of a the internal proxy port of a specific service"""
db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id)
await firewall.get(service_id).update_port()
return {'status': 'ok'}
class ChangePortForm(BaseModel):
port: Union[int, None]
internalPort: Union[int, None]
@app.post('/service/{service_id}/change-ports', response_model=StatusMessageModel)
async def change_service_ports(service_id: str, change_port:ChangePortForm ):
"""Choose and change the ports of the service"""
if change_port.port is None and change_port.internalPort is None:
return {'status': 'Invalid Request!'}
try:
sql_inj = ""
query:List[Union[str,int]] = []
if not change_port.port is None:
sql_inj+=" public_port = ? "
query.append(change_port.port)
if not change_port.port is None and not change_port.internalPort is None:
sql_inj += ","
if not change_port.internalPort is None:
sql_inj+=" internal_port = ? "
query.append(change_port.internalPort)
query.append(service_id)
db.query(f'UPDATE services SET {sql_inj} WHERE service_id = ?;', *query)
except sqlite3.IntegrityError:
return {'status': 'Name or/and port of the service has been already assigned to another service'}
await firewall.get(service_id).update_port()
return {'status': 'ok'}
class RegexModel(BaseModel):
regex:str
mode:str
id:int
service_id:str
is_blacklist: bool
n_packets:int
is_case_sensitive:bool
active:bool
@app.get('/service/{service_id}/regexes', response_model=List[RegexModel])
async def get_service_regexe_list(service_id: str, ):
"""Get the list of the regexes of a service"""
return db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
blocked_packets n_packets, is_case_sensitive, active
FROM regexes WHERE service_id = ?;
""", service_id)
@app.get('/regex/{regex_id}', response_model=RegexModel)
async def get_regex_by_id(regex_id: int, ):
"""Get regex info using his id"""
res = db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
blocked_packets n_packets, is_case_sensitive, active
FROM regexes WHERE `id` = ?;
""", regex_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This regex does not exists!")
return res[0]
@app.get('/regex/{regex_id}/delete', response_model=StatusMessageModel)
async def regex_delete(regex_id: int, ):
"""Delete a regex using his id"""
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
return {'status': 'ok'}
@app.get('/regex/{regex_id}/enable', response_model=StatusMessageModel)
async def regex_enable(regex_id: int, ):
"""Request the enabling of a regex"""
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('UPDATE regexes SET active=1 WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
return {'status': 'ok'}
@app.get('/regex/{regex_id}/disable', response_model=StatusMessageModel)
async def regex_disable(regex_id: int, ):
"""Request the deactivation of a regex"""
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('UPDATE regexes SET active=0 WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
return {'status': 'ok'}
class RegexAddForm(BaseModel):
service_id: str
regex: str
mode: str
active: Union[bool,None]
is_blacklist: bool
is_case_sensitive: bool
@app.post('/regexes/add', response_model=StatusMessageModel)
async def add_new_regex(form: RegexAddForm, ):
"""Add a new regex"""
try:
re.compile(b64decode(form.regex))
except Exception:
return {"status":"Invalid regex"}
try:
db.query("INSERT INTO regexes (service_id, regex, is_blacklist, mode, is_case_sensitive, active ) VALUES (?, ?, ?, ?, ?, ?);",
form.service_id, form.regex, form.is_blacklist, form.mode, form.is_case_sensitive, True if form.active is None else form.active )
except sqlite3.IntegrityError:
return {'status': 'An identical regex already exists'}
await firewall.get(form.service_id).update_filters()
return {'status': 'ok'}
class ServiceAddForm(BaseModel):
name: str
port: int
internalPort: Union[int, None]
class ServiceAddStatus(BaseModel):
status:str
id: Union[str,None]
class RenameForm(BaseModel):
name:str
@app.post('/service/{service_id}/rename', response_model=StatusMessageModel)
async def service_rename(service_id: str, form: RenameForm, ):
"""Request to change the name of a specific service"""
if not form.name: return {'status': 'The name cannot be empty!'}
db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
return {'status': 'ok'}
@app.post('/services/add', response_model=ServiceAddStatus)
async def add_new_service(form: ServiceAddForm, ):
"""Add a new service"""
serv_id = gen_service_id(db)
try:
internal_port = form.internalPort if form.internalPort else gen_internal_port(db)
db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)",
form.name, serv_id, internal_port, form.port, 'stop')
except sqlite3.IntegrityError:
return {'status': 'Name or/and ports of the service has been already assigned to another service'}
await firewall.reload()
return {'status': 'ok', "id": serv_id }

View File

@@ -1,6 +1,6 @@
import asyncio
from ipaddress import ip_interface
import os, socket, secrets, psutil
import os, socket, psutil
import sys
from fastapi_socketio import SocketManager
@@ -30,13 +30,6 @@ def refactor_name(name:str):
while " " in name: name = name.replace(" "," ")
return name
def gen_service_id(db):
while True:
res = secrets.token_hex(8)
if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
break
return res
def list_files(mypath):
from os import listdir
from os.path import isfile, join

View File

@@ -0,0 +1,125 @@
from typing import List
import nftables
from utils import ip_parse, ip_family
class FiregexFilter():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None):
self.nftables = nftables.Nftables()
self.id = int(id) if id else None
self.queue = queue
self.target = target
self.proto = proto
self.port = int(port)
self.ip_int = str(ip_int)
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False
class FiregexTables:
def __init__(self):
self.table_name = "firegex"
self.nft = nftables.Nftables()
def raw_cmd(self, *cmds):
return self.nft.json_cmd({"nftables": list(cmds)})
def cmd(self, *cmds):
code, out, err = self.raw_cmd(*cmds)
if code == 0: return out
else: raise Exception(err)
def init(self):
self.reset()
code, out, err = self.raw_cmd({"create":{"table":{"name":self.table_name,"family":"inet"}}})
if code == 0:
self.cmd(
{"create":{"chain":{
"family":"inet",
"table":self.table_name,
"name":"input",
"type":"filter",
"hook":"prerouting",
"prio":-150,
"policy":"accept"
}}},
{"create":{"chain":{
"family":"inet",
"table":self.table_name,
"name":"output",
"type":"filter",
"hook":"postrouting",
"prio":-150,
"policy":"accept"
}}}
)
def reset(self):
self.raw_cmd(
{"flush":{"table":{"name":"firegex","family":"inet"}}},
{"delete":{"table":{"name":"firegex","family":"inet"}}},
)
def list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"]
def add_output(self, queue_range, proto, port, ip_int):
init, end = queue_range
if init > end: init, end = end, init
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
self.cmd({ "insert":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": "output",
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
{'match': {"left": { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(port)}},
{"queue": {"num": str(init) if init == end else f"{init}-{end}", "flags": ["bypass"]}}
]
}}})
def add_input(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
ip_int = ip_parse(ip_int)
ip_addr = str(ip_int).split("/")[0]
ip_addr_cidr = int(str(ip_int).split("/")[1])
self.cmd({"insert":{"rule":{
"family": "inet",
"table": self.table_name,
"chain": "input",
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
{'match': {"left": { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(port)}},
{"queue": {"num": str(init) if init == end else f"{init}-{end}", "flags": ["bypass"]}}
]
}}})
def get(self) -> List[FiregexFilter]:
res = []
for filter in [ele["rule"] for ele in self.list() if "rule" in ele and ele["rule"]["table"] == self.table_name]:
queue_str = str(filter["expr"][2]["queue"]["num"]).split("-")
queue = None
if len(queue_str) == 1: queue = int(queue_str[0]), int(queue_str[0])
else: queue = int(queue_str[0]), int(queue_str[1])
ip_int = None
if isinstance(filter["expr"][0]["match"]["right"],str):
ip_int = str(ip_parse(filter["expr"][0]["match"]["right"]))
else:
ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}'
res.append(FiregexFilter(
target=filter["chain"],
id=int(filter["handle"]),
queue=queue,
proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"],
port=filter["expr"][1]["match"]["right"],
ip_int=ip_int
))
return res

View File

@@ -93,33 +93,3 @@ class SQLite():
self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value))
else:
self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
class Service:
def __init__(self, id: str, status: str, port: int, name: str, proto: str, ip_int: str):
self.id = id
self.status = status
self.port = port
self.name = name
self.proto = proto
self.ip_int = ip_int
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["service_id"], status=var["status"], port=var["port"], name=var["name"], proto=var["proto"], ip_int=var["ip_int"])
class Regex:
def __init__(self, id: int, regex: bytes, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool):
self.regex = regex
self.mode = mode
self.service_id = service_id
self.is_blacklist = is_blacklist
self.blocked_packets = blocked_packets
self.id = id
self.is_case_sensitive = is_case_sensitive
self.active = active
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["regex_id"], regex=base64.b64decode(var["regex"]), mode=var["mode"], service_id=var["service_id"], is_blacklist=var["is_blacklist"], blocked_packets=var["blocked_packets"], is_case_sensitive=var["is_case_sensitive"], active=var["active"])