Refactor code pt 1 (not tested)

This commit is contained in:
DomySh
2022-07-12 20:18:54 +02:00
parent 632d0a1b12
commit 1e94c26fd6
9 changed files with 597 additions and 534 deletions

5
backend/.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"python.linting.pylintEnabled": false,
"python.linting.mypyEnabled": true,
"python.linting.enabled": true
}

View File

@@ -5,12 +5,14 @@ from typing import List, Union
from fastapi import FastAPI, HTTPException, WebSocket, Depends from fastapi import FastAPI, HTTPException, WebSocket, Depends
from pydantic import BaseModel, BaseSettings from pydantic import BaseModel, BaseSettings
from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import FileResponse, StreamingResponse
from utils import *
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from fastapi_socketio import SocketManager from fastapi_socketio import SocketManager
from ipaddress import ip_interface from ipaddress import ip_interface
from modules import SQLite, FirewallManager
from modules.firewall import STATUS
from utils import refactor_name, gen_service_id
ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER" ON_DOCKER = len(sys.argv) > 1 and sys.argv[1] == "DOCKER"
DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG" DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG"
@@ -18,8 +20,7 @@ DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG"
# DB init # DB init
if not os.path.exists("db"): os.mkdir("db") if not os.path.exists("db"): os.mkdir("db")
db = SQLite('db/firegex.db') db = SQLite('db/firegex.db')
conf = KeyValueStorage(db) firewall = FirewallManager(db)
firewall = ProxyManager(db)
class Settings(BaseSettings): class Settings(BaseSettings):
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256"
@@ -35,8 +36,8 @@ crypto = CryptContext(schemes=["bcrypt"], deprecated="auto")
app = FastAPI(debug=DEBUG, redoc_url=None) app = FastAPI(debug=DEBUG, redoc_url=None)
sio = SocketManager(app, "/sock", socketio_path="") sio = SocketManager(app, "/sock", socketio_path="")
def APP_STATUS(): return "init" if conf.get("password") is None else "run" def APP_STATUS(): return "init" if db.get("password") is None else "run"
def JWT_SECRET(): return conf.get("secret") def JWT_SECRET(): return db.get("secret")
async def refresh_frontend(): async def refresh_frontend():
await sio.emit("update","Refresh") await sio.emit("update","Refresh")
@@ -49,7 +50,7 @@ async def startup_event():
db.init() db.init()
await firewall.init(refresh_frontend) await firewall.init(refresh_frontend)
await refresh_frontend() await refresh_frontend()
if not JWT_SECRET(): conf.put("secret", secrets.token_hex(32)) if not JWT_SECRET(): db.put("secret", secrets.token_hex(32))
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
@@ -108,7 +109,7 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()):
if form.password == "": if form.password == "":
return {"status":"Cannot insert an empty password!"} return {"status":"Cannot insert an empty password!"}
await asyncio.sleep(0.3) # No bruteforce :) await asyncio.sleep(0.3) # No bruteforce :)
if crypto.verify(form.password, conf.get("password")): if crypto.verify(form.password, db.get("password")):
return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"} return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"}
raise HTTPException(406,"Wrong password!") raise HTTPException(406,"Wrong password!")
@@ -124,10 +125,10 @@ async def change_password(form: PasswordChangeForm, auth: bool = Depends(is_logg
if form.password == "": if form.password == "":
return {"status":"Cannot insert an empty password!"} return {"status":"Cannot insert an empty password!"}
if form.expire: if form.expire:
conf.put("secret", secrets.token_hex(32)) db.put("secret", secrets.token_hex(32))
hash_psw = crypto.hash(form.password) hash_psw = crypto.hash(form.password)
conf.put("password",hash_psw) db.put("password",hash_psw)
await refresh_frontend() await refresh_frontend()
return {"status":"ok", "access_token": create_access_token({"logged_in": True})} return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
@@ -139,7 +140,7 @@ async def set_password(form: PasswordForm):
if form.password == "": if form.password == "":
return {"status":"Cannot insert an empty password!"} return {"status":"Cannot insert an empty password!"}
hash_psw = crypto.hash(form.password) hash_psw = crypto.hash(form.password)
conf.put("password",hash_psw) db.put("password",hash_psw)
await refresh_frontend() await refresh_frontend()
return {"status":"ok", "access_token": create_access_token({"logged_in": True})} return {"status":"ok", "access_token": create_access_token({"logged_in": True})}

View File

@@ -0,0 +1,2 @@
from .firewall import FirewallManager
from .sqlite import SQLite

170
backend/modules/firegex.py Normal file
View File

@@ -0,0 +1,170 @@
from typing import List
from pypacker import interceptor
from pypacker.layer3 import ip, ip6
from pypacker.layer4 import tcp, udp
from ipaddress import ip_interface
from modules.iptables import IPTables
import os, traceback
from modules.sqlite import Service
class FilterTypes:
INPUT = "FIREGEX-INPUT"
OUTPUT = "FIREGEX-OUTPUT"
QUEUE_BASE_NUM = 1000
class FiregexFilter():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None):
self.target = target
self.id = int(id) if id else None
self.queue = queue
self.proto = proto
self.port = int(port)
self.ip_int = str(ip_int)
self.func = func
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter):
return self.port == o.port and self.proto == o.proto and ip_interface(self.ip_int) == ip_interface(o.ip_int)
return False
def ipv6(self):
return ip_interface(self.ip_int).version == 6
def ipv4(self):
return ip_interface(self.ip_int).version == 4
def input_func(self):
def none(pkt): return True
def wrap(pkt): return self.func(pkt, True)
return wrap if self.func else none
def output_func(self):
def none(pkt): return True
def wrap(pkt): return self.func(pkt, False)
return wrap if self.func else none
class FiregexTables(IPTables):
def __init__(self, ipv6=False):
super().__init__(ipv6, "mangle")
self.create_chain(FilterTypes.INPUT)
self.add_chain_to_input(FilterTypes.INPUT)
self.create_chain(FilterTypes.OUTPUT)
self.add_chain_to_output(FilterTypes.OUTPUT)
def target_in_chain(self, chain, target):
for filter in self.list()[chain]:
if filter.target == target:
return True
return False
def add_chain_to_input(self, chain):
if not self.target_in_chain("PREROUTING", str(chain)):
self.insert_rule("PREROUTING", str(chain))
def add_chain_to_output(self, chain):
if not self.target_in_chain("POSTROUTING", str(chain)):
self.insert_rule("POSTROUTING", str(chain))
def add_output(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
self.append_rule(FilterTypes.OUTPUT,"NFQUEUE"
* (["-p", str(proto)] if proto else []),
* (["-s", str(ip_int)] if ip_int else []),
* (["--sport", str(port)] if port else []),
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
"--queue-bypass"
)
def add_input(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
self.append_rule(FilterTypes.INPUT, "NFQUEUE",
* (["-p", str(proto)] if proto else []),
* (["-d", str(ip_int)] if ip_int else []),
* (["--dport", str(port)] if port else []),
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
"--queue-bypass"
)
def get(self) -> List[FiregexFilter]:
res = []
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
for filter in self.list()[filter_type]:
port = filter.sport() if filter_type == FilterTypes.OUTPUT else filter.dport()
queue = filter.nfqueue()
if queue and port:
res.append(FiregexFilter(
target=filter_type,
id=filter.id,
queue=queue,
proto=filter.prot,
port=port,
ip_int=filter.source if filter_type == FilterTypes.OUTPUT else filter.destination
))
return res
def add(self, filter:FiregexFilter):
if filter in self.get(): return None
return FiregexInterceptor( iptables=self, filter=filter, n_threads=int(os.getenv("N_THREADS_NFQUEUE","1")))
def delete_all(self):
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
self.flush_chain(filter_type)
def delete_by_srv(self, srv:Service):
for filter in self.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int):
self.delete_rule(filter.target, filter.id)
class FiregexInterceptor:
def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1):
self.filter = filter
self.ipv6 = self.filter.ipv6()
self.itor_input, codes = self._start_queue(filter.input_func(), n_threads)
iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
self.itor_output, codes = self._start_queue(filter.output_func(), n_threads)
iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
def _start_queue(self,func,n_threads):
def func_wrap(ll_data, ll_proto_id, data, ctx, *args):
pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data)
try:
data = None
if not pkt_parsed[tcp.TCP] is None:
data = pkt_parsed[tcp.TCP].body_bytes
if not pkt_parsed[tcp.TCP] is None:
data = pkt_parsed[udp.UDP].body_bytes
if data:
if func(data):
return data, interceptor.NF_ACCEPT
elif pkt_parsed[tcp.TCP]:
pkt_parsed[tcp.TCP].flags &= 0x00
pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK
pkt_parsed[tcp.TCP].body_bytes = b""
return pkt_parsed.bin(), interceptor.NF_ACCEPT
else: return b"", interceptor.NF_DROP
else: return data, interceptor.NF_ACCEPT
except Exception:
traceback.print_exc()
return data, interceptor.NF_ACCEPT
ictor = interceptor.Interceptor()
starts = QUEUE_BASE_NUM
while True:
if starts >= 65536:
raise Exception("Netfilter queue is full!")
queue_ids = list(range(starts,starts+n_threads))
try:
ictor.start(func_wrap, queue_ids=queue_ids)
break
except interceptor.UnableToBindException as e:
starts = e.queue_id + 1
return ictor, (starts, starts+n_threads-1)
def stop(self):
self.itor_input.stop()
self.itor_output.stop()

196
backend/modules/firewall.py Normal file
View File

@@ -0,0 +1,196 @@
import traceback, asyncio, pcre
from typing import Dict
from modules.firegex import FiregexFilter, FiregexTables
from modules.sqlite import Regex, SQLite, Service
class STATUS:
STOP = "stop"
ACTIVE = "active"
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
self.updater_task = None
def init_updater(self, callback = None):
if not self.updater_task:
self.updater_task = asyncio.create_task(self._stats_updater(callback))
def close_updater(self):
if self.updater_task: self.updater_task.cancel()
async def close(self):
self.close_updater()
if self.updater_task: self.updater_task.cancel()
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.proxy_table:
await self.proxy_table[srv_id].next(STATUS.STOP)
del self.proxy_table[srv_id]
async def init(self, callback = None):
self.init_updater(callback)
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.id in self.proxy_table:
continue
self.proxy_table[srv.id] = ServiceManager(srv, self.db)
await self.proxy_table[srv.id].next(srv["status"])
async def _stats_updater(self, callback):
try:
while True:
try:
for key in list(self.proxy_table.keys()):
self.proxy_table[key].update_stats()
except Exception:
traceback.print_exc()
if callback:
if asyncio.iscoroutinefunction(callback): await callback()
else: callback()
await asyncio.sleep(5)
except asyncio.CancelledError:
self.updater_task = None
return
def get(self,srv_id):
if srv_id in self.proxy_table:
return self.proxy_table[srv_id]
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"]
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data):
return True if self.compiled_regex.search(data) else False
class ServiceManager:
def __init__(self, srv: Service, db):
self.srv = srv
self.db = db
self.iptables = FiregexTables(self.srv.ipv6)
self.status = STATUS.STOP
self.filters: Dict[int, FiregexFilter] = {}
self._update_filters_from_db()
self.lock = asyncio.Lock()
self.interceptor = None
# TODO I don't like so much this method
def _update_filters_from_db(self):
regexes = [
Regex.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
#add new filters
for f in new_filters:
if not f in old_filters:
filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = FiregexFilter.from_regex(filter)
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv["service_id"])
async def next(self,to):
async with self.lock:
return self._next(to)
def _next(self, to):
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
self.proxy.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
self.proxy.restart()
def _stats_updater(self,filter:RegexFilter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
def update_stats(self):
for ele in self.filters.values():
self._stats_updater(ele)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
def start(self):
if not self.interceptor:
self.iptables.delete_by_srv(self.srv)
def regex_filter(pkt, by_client):
try:
for filter in self.filters.values():
if (by_client and filter.input_mode) or (not by_client and filter.output_mode):
match = filter.check(pkt)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
return False
except IndexError: pass
return True
self.interceptor = self.iptables.add(self.srv["proto"], self.srv["port"], self.srv["ip_int"], regex_filter)
self._set_status(STATUS.ACTIVE)
def stop(self):
self.iptables.delete_by_srv(self.srv)
if self.interceptor:
self.interceptor.stop()
self.interceptor = None
def restart(self):
self.stop()
self.start()
async def update_filters(self):
async with self.lock:
self._update_filters_from_db()

View File

@@ -0,0 +1,82 @@
import os, re
from subprocess import PIPE, Popen
from typing import Dict, List, Tuple, Union
class Rule():
def __init__(self, id, target, prot, opt, source, destination, details):
self.id = id
self.target = target
self.prot = prot
self.opt = opt
self.source = source
self.destination = destination
self.details = details
def dport(self) -> Union[int, None]:
port = re.findall(r"dpt:([0-9]+)", self.details)
return int(port[0]) if port else None
def sport(self) -> Union[int, None]:
port = re.findall(r"spt:([0-9]+)", self.details)
return int(port[0]) if port else None
def nfqueue(self) -> Union[Tuple[int,int], None]:
balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", self.details)
numbered = re.findall(r"NFQUEUE num ([0-9]+)", self.details)
queue_num = None
if balanced: queue_num = (int(balanced[0][0]), int(balanced[0][1]))
if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
return queue_num
class IPTables:
def __init__(self, ipv6=False, table="filter"):
self.ipv6 = ipv6
self.table = table
def command(self, params) -> Tuple[bytes, bytes]:
params = ["-t", self.table] + params
if os.geteuid() != 0:
exit("You need to have root privileges to run this script.\nPlease try again, this time using 'sudo'. Exiting.")
return Popen(["ip6tables"]+params if self.ipv6 else ["iptables"]+params, stdout=PIPE, stderr=PIPE).communicate()
def list(self) -> Dict[str, List[Rule]]:
stdout, strerr = self.command(["-L", "--line-number", "-n"])
lines = stdout.decode().split("\n")
res: Dict[str, List[Rule]] = {}
chain_name = ""
for line in lines:
if line.startswith("Chain"):
chain_name = line.split()[1]
res[chain_name] = []
elif line.split()[0].isnumeric():
parsed = re.findall(r"([^ ]*)[ ]{,10}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]+([^ ]*)[ ]+(.*)", line)
if len(parsed) > 0:
parsed = parsed[0]
res[chain_name].append(Rule(
id=parsed[0].strip(),
target=parsed[1].strip(),
prot=parsed[2].strip(),
opt=parsed[3].strip(),
source=parsed[4].strip(),
destination=parsed[5].strip(),
details=" ".join(parsed[6:]).strip() if len(parsed[0]) >= 7 else ""
))
return res
def delete_rule(self, chain, id) -> None:
self.command(["-D", str(chain), str(id)])
def create_chain(self, name) -> None:
self.command(["-N", str(name)])
def flush_chain(self, name) -> None:
self.command(["-F", str(name)])
def insert_rule(self, chain, rule, *args, rulenum=1) -> None:
self.command(["-I", str(chain), str(rulenum), "-j", str(rule), *args])
def append_rule(self, chain, rule, *args) -> None:
self.command(["-A", str(chain), "-j", str(rule), *args])

130
backend/modules/sqlite.py Normal file
View File

@@ -0,0 +1,130 @@
from typing import Union
import json, sqlite3, os
from hashlib import md5
class SQLite():
def __init__(self, db_name: str) -> None:
self.conn: Union[None, sqlite3.Connection] = None
self.cur = None
self.db_name = db_name
self.schema = {
'services': {
'service_id': 'VARCHAR(100) PRIMARY KEY',
'status': 'VARCHAR(100) NOT NULL',
'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'name': 'VARCHAR(100) NOT NULL UNIQUE',
'ipv6': 'BOOLEAN NOT NULL CHECK (ipv6 IN (0, 1)) DEFAULT 0',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))',
'ip_int': 'VARCHAR(100) NOT NULL',
},
'regexes': {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1)) DEFAULT 1',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (ipv6, port, ip_int, proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
]
}
self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest()
def connect(self) -> None:
try:
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
except Exception:
with open(self.db_name, 'x'): pass
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
self.conn.row_factory = dict_factory
def disconnect(self) -> None:
if self.conn: self.conn.close()
def create_schema(self, tables = {}) -> None:
if self.conn:
cur = self.conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);")
for t in tables:
if t == "QUERY": continue
cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
cur.close()
def query(self, query, *values):
cur = self.conn.cursor()
try:
cur.execute(query, values)
return cur.fetchall()
finally:
cur.close()
try: self.conn.commit()
except Exception: pass
def delete(self):
self.disconnect()
os.remove(self.db_name)
def init(self):
self.connect()
try:
if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct")
except Exception:
self.delete()
self.connect()
self.create_schema(self.schema)
self.put('DB_VERSION', self.DB_VER)
def get(self, key):
q = self.query('SELECT value FROM keys_values WHERE key = ?', key)
if len(q) == 0:
return None
else:
return q[0]["value"]
def put(self, key, value):
if self.get(key) is None:
self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value))
else:
self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
class Service:
def __init__(self, id: str, status: str, port: int, name: str, ipv6: bool, proto: str, ip_int: str):
self.id = id
self.status = status
self.port = port
self.name = name
self.ipv6 = ipv6
self.proto = proto
self.ip_int = ip_int
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["service_id"], status=var["status"], port=var["port"], name=var["name"], ipv6=var["ipv6"], proto=var["proto"], ip_int=var["ip_int"])
class Regex:
def __init__(self, id: int, regex: str, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool):
self.regex = regex
self.mode = mode
self.service_id = service_id
self.is_blacklist = is_blacklist
self.blocked_packets = blocked_packets
self.id = id
self.is_case_sensitive = is_case_sensitive
self.active = active
@classmethod
def from_dict(cls, var: dict):
return cls(id=var["regex_id"], regex=var["regex"], mode=var["mode"], service_id=var["service_id"], is_blacklist=var["is_blacklist"], blocked_packets=var["blocked_packets"], is_case_sensitive=var["is_case_sensitive"], active=var["active"])

View File

@@ -1,268 +0,0 @@
from typing import List
from pypacker import interceptor
from pypacker.layer3 import ip, ip6
from pypacker.layer4 import tcp, udp
from subprocess import Popen, PIPE
import os, traceback, pcre, re
from ipaddress import ip_interface
QUEUE_BASE_NUM = 1000
class FilterTypes:
INPUT = "FIREGEX-INPUT"
OUTPUT = "FIREGEX-OUTPUT"
class ProtoTypes:
TCP = "tcp"
UDP = "udp"
class IPTables:
def __init__(self, ipv6=False, table="mangle"):
self.ipv6 = ipv6
self.table = table
def command(self, params):
params = ["-t", self.table] + params
if os.geteuid() != 0:
exit("You need to have root privileges to run this script.\nPlease try again, this time using 'sudo'. Exiting.")
return Popen(["ip6tables"]+params if self.ipv6 else ["iptables"]+params, stdout=PIPE, stderr=PIPE).communicate()
def list_filters(self, param):
stdout, strerr = self.command(["-L", str(param), "--line-number", "-n"])
output = [re.findall(r"([^ ]*)[ ]{,10}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]{,5}([^ ]*)[ ]+([^ ]*)[ ]+(.*)", ele) for ele in stdout.decode().split("\n")]
return [{
"id": ele[0][0].strip(),
"target": ele[0][1].strip(),
"prot": ele[0][2].strip(),
"opt": ele[0][3].strip(),
"source": ele[0][4].strip(),
"destination": ele[0][5].strip(),
"details": " ".join(ele[0][6:]).strip() if len(ele[0]) >= 7 else "",
} for ele in output if len(ele) > 0 and ele[0][0].isnumeric()]
def delete_command(self, param, id):
self.command(["-D", str(param), str(id)])
def create_chain(self, name):
self.command(["-N", str(name)])
def flush_chain(self, name):
self.command(["-F", str(name)])
def add_chain_to_input(self, name):
if not self.find_if_filter_exists("PREROUTING", str(name)):
self.command(["-I", "PREROUTING", "-j", str(name)])
def add_chain_to_output(self, name):
if not self.find_if_filter_exists("POSTROUTING", str(name)):
self.command(["-I", "POSTROUTING", "-j", str(name)])
def find_if_filter_exists(self, type, target):
for filter in self.list_filters(type):
if filter["target"] == target:
return True
return False
def add_s_to_c(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
self.command(["-A", FilterTypes.OUTPUT,
* (["-p", str(proto)] if proto else []),
* (["-s", str(ip_int)] if ip_int else []),
* (["--sport", str(port)] if port else []),
"-j", "NFQUEUE",
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
"--queue-bypass"
])
def add_c_to_s(self, queue_range, proto = None, port = None, ip_int = None):
init, end = queue_range
if init > end: init, end = end, init
self.command(["-A", FilterTypes.INPUT,
* (["-p", str(proto)] if proto else []),
* (["-d", str(ip_int)] if ip_int else []),
* (["--dport", str(port)] if port else []),
"-j", "NFQUEUE",
* (["--queue-num", f"{init}"] if init == end else ["--queue-balance", f"{init}:{end}"]),
"--queue-bypass"
])
class FiregexFilter():
def __init__(self, type, number, queue, proto, port, ipv6, ip_int):
self.type = type
self.id = int(number)
self.queue = queue
self.proto = proto
self.port = int(port)
self.iptable = IPTables(ipv6)
self.ip_int = str(ip_int)
def __repr__(self) -> str:
return f"<FiregexFilter type={self.type} id={self.id} port={self.port} proto={self.proto} queue={self.queue}>"
def delete(self):
self.iptable.delete_command(self.type, self.id)
class Interceptor:
def __init__(self, iptables, ip_int, c_to_s, s_to_c, proto, ipv6, port, n_threads):
self.proto = proto
self.ipv6 = ipv6
self.itor_c_to_s, codes = self._start_queue(c_to_s, n_threads)
iptables.add_c_to_s(queue_range=codes, proto=proto, port=port, ip_int=ip_int)
self.itor_s_to_c, codes = self._start_queue(s_to_c, n_threads)
iptables.add_s_to_c(queue_range=codes, proto=proto, port=port, ip_int=ip_int)
def _start_queue(self,func,n_threads):
def func_wrap(ll_data, ll_proto_id, data, ctx, *args):
pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data)
try:
level4 = None
if self.proto == ProtoTypes.TCP: level4 = pkt_parsed[tcp.TCP].body_bytes
elif self.proto == ProtoTypes.UDP: level4 = pkt_parsed[udp.UDP].body_bytes
if level4:
if func(level4):
return data, interceptor.NF_ACCEPT
elif self.proto == ProtoTypes.TCP:
pkt_parsed[tcp.TCP].flags &= 0x00
pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK
pkt_parsed[tcp.TCP].body_bytes = b""
return pkt_parsed.bin(), interceptor.NF_ACCEPT
else: return b"", interceptor.NF_DROP
else: return pkt_parsed.bin(), interceptor.NF_ACCEPT
except Exception:
traceback.print_exc()
return pkt_parsed.bin(), interceptor.NF_ACCEPT
ictor = interceptor.Interceptor()
starts = QUEUE_BASE_NUM
while True:
if starts >= 65536:
raise Exception("Netfilter queue is full!")
queue_ids = list(range(starts,starts+n_threads))
try:
ictor.start(func_wrap, queue_ids=queue_ids)
break
except interceptor.UnableToBindException as e:
starts = e.queue_id + 1
return ictor, (starts, starts+n_threads-1)
def stop(self):
self.itor_c_to_s.stop()
self.itor_s_to_c.stop()
class FiregexFilterManager:
def __init__(self, srv):
self.ipv6 = srv["ipv6"]
self.iptables = IPTables(self.ipv6)
self.iptables.create_chain(FilterTypes.INPUT)
self.iptables.create_chain(FilterTypes.OUTPUT)
self.iptables.add_chain_to_input(FilterTypes.INPUT)
self.iptables.add_chain_to_output(FilterTypes.OUTPUT)
def get(self) -> List[FiregexFilter]:
res = []
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
for filter in self.iptables.list_filters(filter_type):
queue_num = None
balanced = re.findall(r"NFQUEUE balance ([0-9]+):([0-9]+)", filter["details"])
numbered = re.findall(r"NFQUEUE num ([0-9]+)", filter["details"])
port = re.findall(r"[sd]pt:([0-9]+)", filter["details"])
if balanced: queue_num = (int(balanced[0][0]), int(balanced[0][1]))
if numbered: queue_num = (int(numbered[0]), int(numbered[0]))
if queue_num and port:
res.append(FiregexFilter(
type=filter_type,
number=filter["id"],
queue=queue_num,
proto=filter["prot"],
port=int(port[0]),
ipv6=self.ipv6,
ip_int=filter["source"] if filter_type == FilterTypes.OUTPUT else filter["destination"]
))
return res
def add(self, proto, port, ip_int, func):
for ele in self.get():
if int(port) == ele.port and proto == ele.proto and ip_interface(ip_int) == ip_interface(ele.ip_int):
return None
def c_to_s(pkt): return func(pkt, True)
def s_to_c(pkt): return func(pkt, False)
itor = Interceptor( iptables=self.iptables, ip_int=ip_int,
c_to_s=c_to_s, s_to_c=s_to_c,
proto=proto, ipv6=self.ipv6, port=port,
n_threads=int(os.getenv("N_THREADS_NFQUEUE","1")))
return itor
def delete_all(self):
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
self.iptables.flush_chain(filter_type)
def delete_by_srv(self, srv):
for filter in self.get():
if filter.port == int(srv["port"]) and filter.proto == srv["proto"] and ip_interface(filter.ip_int) == ip_interface(srv["ip_int"]):
filter.delete()
class Filter:
def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if c_to_s == s_to_c: c_to_s = s_to_c = True # (False, False) == (True, True)
self.c_to_s = c_to_s
self.s_to_c = s_to_c
self.blocked = blocked_packets
self.code = code
self.compiled_regex = self.compile()
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data):
return True if self.compiled_regex.search(data) else False
class Proxy:
def __init__(self, srv, filters=None):
self.srv = srv
self.manager = FiregexFilterManager(self.srv)
self.filters: List[Filter] = filters if filters else []
self.interceptor = None
def set_filters(self, filters):
elements_to_pop = len(self.filters)
for ele in filters:
self.filters.append(ele)
for _ in range(elements_to_pop):
self.filters.pop(0)
def start(self):
if not self.interceptor:
self.manager.delete_by_srv(self.srv)
def regex_filter(pkt, by_client):
try:
for filter in self.filters:
if (by_client and filter.c_to_s) or (not by_client and filter.s_to_c):
match = filter.check(pkt)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
return False
except IndexError: pass
return True
self.interceptor = self.manager.add(self.srv["proto"], self.srv["port"], self.srv["ip_int"], regex_filter)
def stop(self):
self.manager.delete_by_srv(self.srv)
if self.interceptor:
self.interceptor.stop()
self.interceptor = None
def restart(self):
self.stop()
self.start()

View File

@@ -1,262 +1,7 @@
from hashlib import md5 import os, socket, secrets
import traceback
from typing import Dict
from proxy import Filter, Proxy
import os, sqlite3, socket, asyncio, re
import secrets, json
from base64 import b64decode
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
class SQLite():
def __init__(self, db_name) -> None:
self.conn = None
self.cur = None
self.db_name = db_name
self.schema = {
'services': {
'service_id': 'VARCHAR(100) PRIMARY KEY',
'status': 'VARCHAR(100) NOT NULL',
'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'name': 'VARCHAR(100) NOT NULL UNIQUE',
'ipv6': 'BOOLEAN NOT NULL CHECK (ipv6 IN (0, 1)) DEFAULT 0',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))',
'ip_int': 'VARCHAR(100) NOT NULL',
},
'regexes': {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1)) DEFAULT 1',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'keys_values': {
'key': 'VARCHAR(100) PRIMARY KEY',
'value': 'VARCHAR(100) NOT NULL',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (ipv6, port, ip_int, proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
]
}
self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest()
def connect(self) -> None:
try:
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
except Exception:
with open(self.db_name, 'x'):
pass
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
self.conn.row_factory = dict_factory
def disconnect(self) -> None:
if self.conn: self.conn.close()
def create_schema(self, tables = {}) -> None:
cur = self.conn.cursor()
for t in tables:
if t == "QUERY": continue
cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
cur.close()
def query(self, query, *values):
cur = self.conn.cursor()
try:
cur.execute(query, values)
return cur.fetchall()
finally:
cur.close()
try: self.conn.commit()
except Exception: pass
def delete(self):
self.disconnect()
os.remove(self.db_name)
def init(self):
self.connect()
try:
current_ver = self.query("SELECT value FROM keys_values WHERE key = 'DB_VERSION'")[0]['value']
if current_ver != self.DB_VER: raise Exception("DB_VERSION is not correct")
except Exception:
self.delete()
self.connect()
self.create_schema(self.schema)
self.query("INSERT INTO keys_values (key, value) VALUES ('DB_VERSION', ?)", self.DB_VER)
class KeyValueStorage:
def __init__(self, db):
self.db = db
def get(self, key):
q = self.db.query('SELECT value FROM keys_values WHERE key = ?', key)
if len(q) == 0:
return None
else:
return q[0]["value"]
def put(self, key, value):
if self.get(key) is None:
self.db.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value))
else:
self.db.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
class STATUS:
STOP = "stop"
ACTIVE = "active"
class ServiceNotFoundException(Exception): pass
class ServiceManager:
def __init__(self, srv, db):
self.srv = srv
self.db = db
self.proxy = Proxy(srv)
self.status = STATUS.STOP
self.filters = {}
self._update_filters_from_db()
self.lock = asyncio.Lock()
self.starter = None
def _update_filters_from_db(self):
res = self.db.query("""
SELECT
regex, mode, regex_id id, is_blacklist,
blocked_packets n_packets, is_case_sensitive
FROM regexes WHERE service_id = ? AND active=1;
""", self.srv["service_id"])
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f["id"] for f in res])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
for f in new_filters:
if not f in old_filters:
filter_info = [ele for ele in res if ele["id"] == f][0]
self.filters[f] = Filter(
is_case_sensitive=filter_info["is_case_sensitive"],
c_to_s=filter_info["mode"] in ["C","B"],
s_to_c=filter_info["mode"] in ["S","B"],
is_blacklist=filter_info["is_blacklist"],
regex=b64decode(filter_info["regex"]),
blocked_packets=filter_info["n_packets"],
code=f
)
self.proxy.set_filters(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv["service_id"])
async def next(self,to):
async with self.lock:
return self._next(to)
def _next(self, to):
if self.status != to:
# ACTIVE -> PAUSE
if (self.status, to) in [(STATUS.ACTIVE, STATUS.STOP)]:
self.proxy.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE)]:
self.proxy.restart()
self._set_status(to)
def _stats_updater(self,filter:Filter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
def update_stats(self):
for ele in self.proxy.filters:
self._stats_updater(ele)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
async def update_filters(self):
async with self.lock:
self._update_filters_from_db()
class ProxyManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table: Dict[ServiceManager] = {}
self.lock = asyncio.Lock()
self.updater_task = None
def init_updater(self, callback = None):
if not self.updater_task:
self.updater_task = asyncio.create_task(self._stats_updater(callback))
def close_updater(self):
if self.updater_task: self.updater_task.cancel()
async def close(self):
self.close_updater()
if self.updater_task: self.updater_task.cancel()
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.proxy_table:
await self.proxy_table[srv_id].next(STATUS.STOP)
del self.proxy_table[srv_id]
async def init(self, callback = None):
self.init_updater(callback)
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv_id = srv["service_id"]
if srv_id in self.proxy_table:
continue
self.proxy_table[srv_id] = ServiceManager(srv, self.db)
await self.proxy_table[srv_id].next(srv["status"])
async def _stats_updater(self, callback):
try:
while True:
try:
for key in list(self.proxy_table.keys()):
self.proxy_table[key].update_stats()
except Exception:
traceback.print_exc()
if callback:
if asyncio.iscoroutinefunction(callback): await callback()
else: callback()
await asyncio.sleep(5)
except asyncio.CancelledError:
self.updater_task = None
return
def get(self,srv_id):
if srv_id in self.proxy_table:
return self.proxy_table[srv_id]
else:
raise ServiceNotFoundException()
def refactor_name(name:str): def refactor_name(name:str):
name = name.strip() name = name.strip()
while " " in name: name = name.replace(" "," ") while " " in name: name = name.replace(" "," ")