Refactor code pt 1 (not tested)
This commit is contained in:
130
backend/modules/sqlite.py
Normal file
130
backend/modules/sqlite.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from typing import Union
|
||||
import json, sqlite3, os
|
||||
from hashlib import md5
|
||||
|
||||
class SQLite():
|
||||
def __init__(self, db_name: str) -> None:
|
||||
self.conn: Union[None, sqlite3.Connection] = None
|
||||
self.cur = None
|
||||
self.db_name = db_name
|
||||
self.schema = {
|
||||
'services': {
|
||||
'service_id': 'VARCHAR(100) PRIMARY KEY',
|
||||
'status': 'VARCHAR(100) NOT NULL',
|
||||
'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
|
||||
'name': 'VARCHAR(100) NOT NULL UNIQUE',
|
||||
'ipv6': 'BOOLEAN NOT NULL CHECK (ipv6 IN (0, 1)) DEFAULT 0',
|
||||
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp"))',
|
||||
'ip_int': 'VARCHAR(100) NOT NULL',
|
||||
},
|
||||
'regexes': {
|
||||
'regex': 'TEXT NOT NULL',
|
||||
'mode': 'VARCHAR(1) NOT NULL',
|
||||
'service_id': 'VARCHAR(100) NOT NULL',
|
||||
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
|
||||
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
|
||||
'regex_id': 'INTEGER PRIMARY KEY',
|
||||
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
|
||||
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1)) DEFAULT 1',
|
||||
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
|
||||
},
|
||||
'QUERY':[
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (ipv6, port, ip_int, proto);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
|
||||
]
|
||||
}
|
||||
self.DB_VER = md5(json.dumps(self.schema).encode()).hexdigest()
|
||||
|
||||
def connect(self) -> None:
|
||||
try:
|
||||
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
|
||||
except Exception:
|
||||
with open(self.db_name, 'x'): pass
|
||||
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
|
||||
def dict_factory(cursor, row):
|
||||
d = {}
|
||||
for idx, col in enumerate(cursor.description):
|
||||
d[col[0]] = row[idx]
|
||||
return d
|
||||
self.conn.row_factory = dict_factory
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.conn: self.conn.close()
|
||||
|
||||
def create_schema(self, tables = {}) -> None:
|
||||
if self.conn:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);")
|
||||
for t in tables:
|
||||
if t == "QUERY": continue
|
||||
cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
|
||||
if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
|
||||
cur.close()
|
||||
|
||||
def query(self, query, *values):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute(query, values)
|
||||
return cur.fetchall()
|
||||
finally:
|
||||
cur.close()
|
||||
try: self.conn.commit()
|
||||
except Exception: pass
|
||||
|
||||
def delete(self):
|
||||
self.disconnect()
|
||||
os.remove(self.db_name)
|
||||
|
||||
def init(self):
|
||||
self.connect()
|
||||
try:
|
||||
if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct")
|
||||
except Exception:
|
||||
self.delete()
|
||||
self.connect()
|
||||
self.create_schema(self.schema)
|
||||
self.put('DB_VERSION', self.DB_VER)
|
||||
|
||||
def get(self, key):
|
||||
q = self.query('SELECT value FROM keys_values WHERE key = ?', key)
|
||||
if len(q) == 0:
|
||||
return None
|
||||
else:
|
||||
return q[0]["value"]
|
||||
|
||||
def put(self, key, value):
|
||||
if self.get(key) is None:
|
||||
self.query('INSERT INTO keys_values (key, value) VALUES (?, ?);', key, str(value))
|
||||
else:
|
||||
self.query('UPDATE keys_values SET value=? WHERE key = ?;', str(value), key)
|
||||
|
||||
|
||||
class Service:
|
||||
def __init__(self, id: str, status: str, port: int, name: str, ipv6: bool, proto: str, ip_int: str):
|
||||
self.id = id
|
||||
self.status = status
|
||||
self.port = port
|
||||
self.name = name
|
||||
self.ipv6 = ipv6
|
||||
self.proto = proto
|
||||
self.ip_int = ip_int
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, var: dict):
|
||||
return cls(id=var["service_id"], status=var["status"], port=var["port"], name=var["name"], ipv6=var["ipv6"], proto=var["proto"], ip_int=var["ip_int"])
|
||||
|
||||
|
||||
class Regex:
|
||||
def __init__(self, id: int, regex: str, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool):
|
||||
self.regex = regex
|
||||
self.mode = mode
|
||||
self.service_id = service_id
|
||||
self.is_blacklist = is_blacklist
|
||||
self.blocked_packets = blocked_packets
|
||||
self.id = id
|
||||
self.is_case_sensitive = is_case_sensitive
|
||||
self.active = active
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, var: dict):
|
||||
return cls(id=var["regex_id"], regex=var["regex"], mode=var["mode"], service_id=var["service_id"], is_blacklist=var["is_blacklist"], blocked_packets=var["blocked_packets"], is_case_sensitive=var["is_case_sensitive"], active=var["active"])
|
||||
Reference in New Issue
Block a user