Moduled Firegex, Merging pt1 (not finished and not working yet)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from ipaddress import ip_interface
|
||||
import os, socket, secrets, psutil
|
||||
import os, socket, psutil
|
||||
import sys
|
||||
from fastapi_socketio import SocketManager
|
||||
|
||||
@@ -30,13 +30,6 @@ def refactor_name(name:str):
|
||||
while " " in name: name = name.replace(" "," ")
|
||||
return name
|
||||
|
||||
def gen_service_id(db):
|
||||
while True:
|
||||
res = secrets.token_hex(8)
|
||||
if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
def list_files(mypath):
|
||||
from os import listdir
|
||||
from os.path import isfile, join
|
||||
|
||||
125
backend/utils/firegextables.py
Normal file
125
backend/utils/firegextables.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from typing import List
|
||||
import nftables
|
||||
from utils import ip_parse, ip_family
|
||||
|
||||
class FiregexFilter():
|
||||
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target:str=None, id=None):
|
||||
self.nftables = nftables.Nftables()
|
||||
self.id = int(id) if id else None
|
||||
self.queue = queue
|
||||
self.target = target
|
||||
self.proto = proto
|
||||
self.port = int(port)
|
||||
self.ip_int = str(ip_int)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, FiregexFilter):
|
||||
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
|
||||
return False
|
||||
|
||||
class FiregexTables:
|
||||
|
||||
def __init__(self):
|
||||
self.table_name = "firegex"
|
||||
self.nft = nftables.Nftables()
|
||||
|
||||
def raw_cmd(self, *cmds):
|
||||
return self.nft.json_cmd({"nftables": list(cmds)})
|
||||
|
||||
def cmd(self, *cmds):
|
||||
code, out, err = self.raw_cmd(*cmds)
|
||||
|
||||
if code == 0: return out
|
||||
else: raise Exception(err)
|
||||
|
||||
def init(self):
|
||||
self.reset()
|
||||
code, out, err = self.raw_cmd({"create":{"table":{"name":self.table_name,"family":"inet"}}})
|
||||
if code == 0:
|
||||
self.cmd(
|
||||
{"create":{"chain":{
|
||||
"family":"inet",
|
||||
"table":self.table_name,
|
||||
"name":"input",
|
||||
"type":"filter",
|
||||
"hook":"prerouting",
|
||||
"prio":-150,
|
||||
"policy":"accept"
|
||||
}}},
|
||||
{"create":{"chain":{
|
||||
"family":"inet",
|
||||
"table":self.table_name,
|
||||
"name":"output",
|
||||
"type":"filter",
|
||||
"hook":"postrouting",
|
||||
"prio":-150,
|
||||
"policy":"accept"
|
||||
}}}
|
||||
)
|
||||
|
||||
|
||||
def reset(self):
|
||||
self.raw_cmd(
|
||||
{"flush":{"table":{"name":"firegex","family":"inet"}}},
|
||||
{"delete":{"table":{"name":"firegex","family":"inet"}}},
|
||||
)
|
||||
|
||||
def list(self):
|
||||
return self.cmd({"list": {"ruleset": None}})["nftables"]
|
||||
|
||||
def add_output(self, queue_range, proto, port, ip_int):
|
||||
init, end = queue_range
|
||||
if init > end: init, end = end, init
|
||||
ip_int = ip_parse(ip_int)
|
||||
ip_addr = str(ip_int).split("/")[0]
|
||||
ip_addr_cidr = int(str(ip_int).split("/")[1])
|
||||
self.cmd({ "insert":{ "rule": {
|
||||
"family": "inet",
|
||||
"table": self.table_name,
|
||||
"chain": "output",
|
||||
"expr": [
|
||||
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'saddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
|
||||
{'match': {"left": { "payload": {"protocol": str(proto), "field": "sport"}}, "op": "==", "right": int(port)}},
|
||||
{"queue": {"num": str(init) if init == end else f"{init}-{end}", "flags": ["bypass"]}}
|
||||
]
|
||||
}}})
|
||||
|
||||
def add_input(self, queue_range, proto = None, port = None, ip_int = None):
|
||||
init, end = queue_range
|
||||
if init > end: init, end = end, init
|
||||
ip_int = ip_parse(ip_int)
|
||||
ip_addr = str(ip_int).split("/")[0]
|
||||
ip_addr_cidr = int(str(ip_int).split("/")[1])
|
||||
self.cmd({"insert":{"rule":{
|
||||
"family": "inet",
|
||||
"table": self.table_name,
|
||||
"chain": "input",
|
||||
"expr": [
|
||||
{'match': {'left': {'payload': {'protocol': ip_family(ip_int), 'field': 'daddr'}}, 'op': '==', 'right': {"prefix": {"addr": ip_addr, "len": ip_addr_cidr}}}},
|
||||
{'match': {"left": { "payload": {"protocol": str(proto), "field": "dport"}}, "op": "==", "right": int(port)}},
|
||||
{"queue": {"num": str(init) if init == end else f"{init}-{end}", "flags": ["bypass"]}}
|
||||
]
|
||||
}}})
|
||||
|
||||
def get(self) -> List[FiregexFilter]:
|
||||
res = []
|
||||
for filter in [ele["rule"] for ele in self.list() if "rule" in ele and ele["rule"]["table"] == self.table_name]:
|
||||
queue_str = str(filter["expr"][2]["queue"]["num"]).split("-")
|
||||
queue = None
|
||||
if len(queue_str) == 1: queue = int(queue_str[0]), int(queue_str[0])
|
||||
else: queue = int(queue_str[0]), int(queue_str[1])
|
||||
ip_int = None
|
||||
if isinstance(filter["expr"][0]["match"]["right"],str):
|
||||
ip_int = str(ip_parse(filter["expr"][0]["match"]["right"]))
|
||||
else:
|
||||
ip_int = f'{filter["expr"][0]["match"]["right"]["prefix"]["addr"]}/{filter["expr"][0]["match"]["right"]["prefix"]["len"]}'
|
||||
res.append(FiregexFilter(
|
||||
target=filter["chain"],
|
||||
id=int(filter["handle"]),
|
||||
queue=queue,
|
||||
proto=filter["expr"][1]["match"]["left"]["payload"]["protocol"],
|
||||
port=filter["expr"][1]["match"]["right"],
|
||||
ip_int=ip_int
|
||||
))
|
||||
return res
|
||||
|
||||
95
backend/utils/sqlite.py
Normal file
95
backend/utils/sqlite.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from typing import Union
|
||||
import json, sqlite3, os
|
||||
from hashlib import md5
|
||||
import base64
|
||||
|
||||
class SQLite():
|
||||
def __init__(self, db_name: str, schema:dict = None) -> None:
|
||||
self.conn: Union[None, sqlite3.Connection] = None
|
||||
self.cur = None
|
||||
self.db_name = db_name
|
||||
self.__backup = None
|
||||
self.schema = {} if schema is None else schema
|
||||
self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest()
|
||||
|
||||
def connect(self) -> None:
|
||||
try:
|
||||
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
|
||||
except Exception:
|
||||
path_name = os.path.dirname(self.db_name)
|
||||
if not os.path.exists(path_name): os.makedirs(path_name)
|
||||
with open(self.db_name, 'x'): pass
|
||||
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
|
||||
def dict_factory(cursor, row):
|
||||
d = {}
|
||||
for idx, col in enumerate(cursor.description):
|
||||
d[col[0]] = row[idx]
|
||||
return d
|
||||
self.conn.row_factory = dict_factory
|
||||
|
||||
def backup(self):
|
||||
with open(self.db_name, "rb") as f:
|
||||
self.__backup = f.read()
|
||||
|
||||
def restore(self):
|
||||
were_active = True if self.conn else False
|
||||
self.disconnect()
|
||||
if self.__backup:
|
||||
with open(self.db_name, "wb") as f:
|
||||
f.write(self.__backup)
|
||||
self.__backup = None
|
||||
if were_active: self.connect()
|
||||
|
||||
def delete_backup(self):
|
||||
self.__backup = None
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.conn: self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
def create_schema(self, tables = {}) -> None:
|
||||
if self.conn:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);")
|
||||
for t in tables:
|
||||
if t == "QUERY": continue
|
||||
cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
|
||||
if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
|
||||
cur.close()
|
||||
|
||||
def query(self, query, *values):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute(query, values)
|
||||
return cur.fetchall()
|
||||
finally:
|
||||
cur.close()
|
||||
try: self.conn.commit()
|
||||
except Exception: pass
|
||||
|
||||
def delete(self):
|
||||
self.disconnect()
|
||||
os.remove(self.db_name)
|
||||
|
||||
def init(self):
|
||||
self.connect()
|
||||
try:
|
||||
if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct")
|
||||
except Exception:
|
||||
self.delete()
|
||||
self.connect()
|
||||
self.create_schema(self.schema)
|
||||
self.put('DB_VERSION', self.DB_VER)
|
||||
|
||||
def get(self, key):
|
||||
q = self.query('SELECT value FROM keys_values WHERE key = ?', key)
|
||||
if len(q) == 0:
|
||||
return None
|
||||
else:
|
||||
return q[0]["value"]
|
||||
|
||||
def put(self, key, value):
|
||||
if self.get(key) is None:
|
||||
self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value))
|
||||
else:
|
||||
self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
|
||||
Reference in New Issue
Block a user