diff --git a/backend/app.py b/backend/app.py index 7054f36..c0530d1 100644 --- a/backend/app.py +++ b/backend/app.py @@ -12,17 +12,50 @@ DEBUG = len(sys.argv) > 1 and sys.argv[1] == "DEBUG" # DB init if not os.path.exists("db"): os.mkdir("db") db = SQLite('db/firegex.db') -db.connect() conf = KeyValueStorage(db) firewall = ProxyManager(db) app = FastAPI(debug=DEBUG) @app.on_event("shutdown") -def shutdown_event(): - firewall.close() +async def shutdown_event(): + await firewall.close() db.disconnect() +@app.on_event("startup") +async def startup_event(): + global APP_STATUS + db.connect() + db.create_schema({ + 'services': { + 'status': 'VARCHAR(100) NOT NULL', + 'service_id': 'VARCHAR(100) PRIMARY KEY', + 'internal_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE', + 'public_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE', + 'name': 'VARCHAR(100) NOT NULL' + }, + 'regexes': { + 'regex': 'TEXT NOT NULL', + 'mode': 'VARCHAR(1) NOT NULL', + 'service_id': 'VARCHAR(100) NOT NULL', + 'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))', + 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', + 'regex_id': 'INTEGER PRIMARY KEY', + 'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))', + 'FOREIGN KEY (service_id)':'REFERENCES services (service_id)', + }, + 'keys_values': { + 'key': 'VARCHAR(100) PRIMARY KEY', + 'value': 'VARCHAR(100) NOT NULL', + }, + }) + db.query("CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);") + + if not conf.get("password") is None: + APP_STATUS = "run" + + await firewall.reload() + app.add_middleware(SessionMiddleware, secret_key=os.urandom(32)) SESSION_TOKEN = secrets.token_hex(8) APP_STATUS = "init" @@ -151,19 +184,19 @@ async def get_service(request: Request, service_id: str): @app.get('/api/service/{service_id}/stop') async def get_service_stop(request: Request, service_id: str): login_check(request) - firewall.change_status(service_id,STATUS.STOP) + await firewall.get(service_id).next(STATUS.STOP) return {'status': 'ok'} @app.get('/api/service/{service_id}/pause') async def get_service_pause(request: Request, service_id: str): login_check(request) - firewall.change_status(service_id,STATUS.PAUSE) + await firewall.get(service_id).next(STATUS.PAUSE) return {'status': 'ok'} @app.get('/api/service/{service_id}/start') async def get_service_start(request: Request, service_id: str): login_check(request) - firewall.change_status(service_id,STATUS.ACTIVE) + await firewall.get(service_id).next(STATUS.ACTIVE) return {'status': 'ok'} @app.get('/api/service/{service_id}/delete') @@ -171,7 +204,7 @@ async def get_service_delete(request: Request, service_id: str): login_check(request) db.query('DELETE FROM services WHERE service_id = ?;', service_id) db.query('DELETE FROM regexes WHERE service_id = ?;', service_id) - firewall.fire_update(service_id) + await firewall.remove(service_id) return {'status': 'ok'} @@ -179,7 +212,7 @@ async def get_service_delete(request: Request, service_id: str): async def get_regen_port(request: Request, service_id: str): login_check(request) db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id) - firewall.fire_update(service_id) + await firewall.get(service_id).update_port() return {'status': 'ok'} @@ -212,7 +245,7 @@ async def get_regex_delete(request: Request, regex_id: int): if len(res) != 0: db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id) - firewall.fire_update(res[0]["service_id"]) + await firewall.get(res[0]["service_id"]).update_filters() return {'status': 'ok'} @@ -236,7 +269,7 @@ async def post_regexes_add(request: Request, form: RegexAddForm): except sqlite3.IntegrityError: return {'status': 'An identical regex already exists'} - firewall.fire_update(form.service_id) + await firewall.get(form.service_id).update_filters() return {'status': 'ok'} class ServiceAddForm(BaseModel): @@ -250,7 +283,7 @@ async def post_services_add(request: Request, form: ServiceAddForm): try: db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)", form.name, serv_id, gen_internal_port(db), form.port, 'stop') - firewall.reload() + await firewall.reload() except sqlite3.IntegrityError: return {'status': 'Name or/and port of the service has been already assigned to another service'} @@ -300,40 +333,11 @@ async def catch_all(request: Request, full_path:str): if __name__ == '__main__': - db.create_schema({ - 'services': { - 'status': 'VARCHAR(100) NOT NULL', - 'service_id': 'VARCHAR(100) PRIMARY KEY', - 'internal_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE', - 'public_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE', - 'name': 'VARCHAR(100) NOT NULL' - }, - 'regexes': { - 'regex': 'TEXT NOT NULL', - 'mode': 'VARCHAR(1) NOT NULL', - 'service_id': 'VARCHAR(100) NOT NULL', - 'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))', - 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', - 'regex_id': 'INTEGER PRIMARY KEY', - 'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))', - 'FOREIGN KEY (service_id)':'REFERENCES services (service_id)', - }, - 'keys_values': { - 'key': 'VARCHAR(100) PRIMARY KEY', - 'value': 'VARCHAR(100) NOT NULL', - }, - }) - db.query("CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);") - - if not conf.get("password") is None: - APP_STATUS = "run" - - firewall.reload() # os.environ {PORT = Backend Port (Main Port), F_PORT = Frontend Port} uvicorn.run( "app:app", host="0.0.0.0", port=int(os.getenv("PORT","4444")), reload=DEBUG, - access_log=DEBUG, + access_log=DEBUG ) diff --git a/backend/proxy/__init__.py b/backend/proxy/__init__.py index 36fcf88..c2e2738 100755 --- a/backend/proxy/__init__.py +++ b/backend/proxy/__init__.py @@ -1,5 +1,4 @@ -import subprocess, re, os -from threading import Lock +import subprocess, re, os, asyncio #c++ -o proxy proxy.cpp @@ -25,11 +24,11 @@ class Filter: yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex() class Proxy: - def __init__(self, internal_port, public_port, callback_blocked_update=None, filters=None, public_host="0.0.0.0", internal_host="127.0.0.1"): + def __init__(self, internal_port=0, public_port=0, callback_blocked_update=None, filters=None, public_host="0.0.0.0", internal_host="127.0.0.1"): self.filter_map = {} - self.filter_map_lock = Lock() - self.update_config_lock = Lock() - self.status_change = Lock() + self.filter_map_lock = asyncio.Lock() + self.update_config_lock = asyncio.Lock() + self.status_change = asyncio.Lock() self.public_host = public_host self.public_port = public_port self.internal_host = internal_host @@ -38,75 +37,72 @@ class Proxy: self.process = None self.callback_blocked_update = callback_blocked_update - def start(self, in_pause=False): - self.status_change.acquire() + async def start(self, in_pause=False): + await self.status_change.acquire() if not self.isactive(): try: self.filter_map = self.compile_filters() filters_codes = list(self.filter_map.keys()) if not in_pause else [] proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./proxy") - self.process = subprocess.Popen( - [ proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port)], - stdout=subprocess.PIPE, stdin=subprocess.PIPE, universal_newlines=True + self.process = await asyncio.create_subprocess_exec( + proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port), + stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE ) - self.update_config(filters_codes) + await self.update_config(filters_codes) finally: self.status_change.release() - - for stdout_line in iter(self.process.stdout.readline, ""): - if stdout_line.startswith("BLOCKED"): - regex_id = stdout_line.split()[1] - with self.filter_map_lock: - self.filter_map[regex_id].blocked+=1 - if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id]) - self.process.stdout.close() - return self.process.wait() + try: + while True: + buff = await self.process.stdout.readuntil() + stdout_line = buff.decode() + if stdout_line.startswith("BLOCKED"): + regex_id = stdout_line.split()[1] + async with self.filter_map_lock: + self.filter_map[regex_id].blocked+=1 + if self.callback_blocked_update: await self.callback_blocked_update(self.filter_map[regex_id]) + except Exception: + return await self.process.wait() else: self.status_change.release() - def stop(self): - with self.status_change: + async def stop(self): + async with self.status_change: if self.isactive(): - self.process.terminate() - try: - self.process.wait(timeout=3) - except Exception: - self.process.kill() - return False - finally: - self.process = None + self.process.kill() + self.process = None + return False return True - def restart(self, in_pause=False): - status = self.stop() - self.start(in_pause=in_pause) + async def restart(self, in_pause=False): + status = await self.stop() + await self.start(in_pause=in_pause) return status - def update_config(self, filters_codes): - with self.update_config_lock: + async def update_config(self, filters_codes): + async with self.update_config_lock: if (self.isactive()): - self.process.stdin.write(" ".join(filters_codes)+"\n") - self.process.stdin.flush() + self.process.stdin.write((" ".join(filters_codes)+"\n").encode()) + await self.process.stdin.drain() - def reload(self): + async def reload(self): if self.isactive(): - with self.filter_map_lock: + async with self.filter_map_lock: self.filter_map = self.compile_filters() filters_codes = list(self.filter_map.keys()) - self.update_config(filters_codes) + await self.update_config(filters_codes) def isactive(self): - if self.process and not self.process.poll() is None: + if self.process and not self.process.returncode is None: self.process = None return True if self.process else False - def pause(self): + async def pause(self): if self.isactive(): - self.update_config([]) + await self.update_config([]) else: - self.start(in_pause=True) + await self.start(in_pause=True) def compile_filters(self): res = {} diff --git a/backend/utils.py b/backend/utils.py index 663233a..bc9c3bc 100755 --- a/backend/utils.py +++ b/backend/utils.py @@ -1,6 +1,6 @@ +import threading from proxy import Filter, Proxy -import random, string, os, threading, sqlite3, time, atexit, socket -from kthread import KThread +import random, string, os, sqlite3, socket, asyncio from base64 import b64decode LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) @@ -69,209 +69,157 @@ class STATUS: STOP = "stop" PAUSE = "pause" ACTIVE = "active" - -class ProxyManager: - def __init__(self, db:SQLite): + +class ServiceNotFoundException(Exception): + pass + +class ServiceManager: + def __init__(self, id, db): + self.id = id self.db = db - self.proxy_table = {} - self.lock = threading.Lock() - atexit.register(self.close) - - def __clean_proxy_table(self): - with self.lock: - for key in list(self.proxy_table.keys()): - if not self.proxy_table[key]["thread"].is_alive(): - del self.proxy_table[key] - - def close(self): - with self.lock: - for key in list(self.proxy_table.keys()): - if self.proxy_table[key]["thread"].is_alive(): - self.proxy_table[key]["thread"].kill() - del self.proxy_table[key] - - def reload(self): - self.__clean_proxy_table() - with self.lock: - for srv in self.db.query('SELECT service_id, status FROM services;'): - srv_id, n_status = srv["service_id"], srv["status"] - if srv_id in self.proxy_table: - continue - update_signal = threading.Event() - callback_signal = threading.Event() - req_status = [n_status] - thread = KThread(target=self.service_manager, args=(srv_id, req_status, update_signal, callback_signal)) - self.proxy_table[srv_id] = { - "thread":thread, - "event":update_signal, - "callback":callback_signal, - "next_status":req_status - } - thread.start() - callback_signal.wait() - callback_signal.clear() - - def get_service_data(self, id): + self.proxy = Proxy( + internal_host=LOCALHOST_IP, + callback_blocked_update=self._stats_updater + ) + self.status = STATUS.STOP + self.filters = {} + self._proxy_update() + self.lock = asyncio.Lock() + self.starter = None + + def _update_port_from_db(self): res = self.db.query(""" SELECT - service_id `id`, - status, public_port, internal_port FROM services WHERE service_id = ?; - """, id) - if len(res) == 0: return None - else: res = res[0] - res["filters"] = self.db.query(""" + """, self.id) + if len(res) == 0: raise ServiceNotFoundException() + self.proxy.internal_port = res[0]["internal_port"] + self.proxy.public_port = res[0]["public_port"] + + def _proxy_update(self): + self._update_port_from_db() + self._update_filters_from_db() + + def _update_filters_from_db(self): + res = self.db.query(""" SELECT regex, mode, regex_id `id`, is_blacklist, blocked_packets n_packets, is_case_sensitive FROM regexes WHERE service_id = ?; - """, id) - return res + """, self.id) - def change_status(self, id, to): - with self.lock: - if id in self.proxy_table: - if self.proxy_table[id]["thread"].is_alive(): - self.proxy_table[id]["next_status"][0] = to - self.proxy_table[id]["event"].set() - self.proxy_table[id]["callback"].wait() - self.proxy_table[id]["callback"].clear() - else: - del self.proxy_table[id] - - def fire_update(self, id): - with self.lock: - if id in self.proxy_table: - if self.proxy_table[id]["thread"].is_alive(): - self.proxy_table[id]["event"].set() - self.proxy_table[id]["callback"].wait() - self.proxy_table[id]["callback"].clear() - else: - del self.proxy_table[id] + #Filter check + old_filters = set(self.filters.keys()) + new_filters = set([f["id"] for f in res]) + + #remove old filters + for f in old_filters: + if not f in new_filters: + del self.filters[f] + + for f in new_filters: + if not f in old_filters: + filter_info = [ele for ele in res if ele["id"] == f][0] + self.filters[f] = Filter( + is_case_sensitive=filter_info["is_case_sensitive"], + c_to_s=filter_info["mode"] in ["C","B"], + s_to_c=filter_info["mode"] in ["S","B"], + is_blacklist=filter_info["is_blacklist"], + regex=b64decode(filter_info["regex"]), + blocked_packets=filter_info["n_packets"], + code=f + ) + self.proxy.filters = list(self.filters.values()) def __update_status_db(self, id, status): self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, id) - def __proxy_starter(self, id, proxy:Proxy, next_status): - def func(): - while True: - if check_port_is_open(proxy.public_port): - self.__update_status_db(id, next_status) - proxy.start(in_pause=(next_status==STATUS.PAUSE)) - self.__update_status_db(id, STATUS.STOP) - return - else: - time.sleep(.5) - - thread = KThread(target=func) - thread.start() - return thread - - def service_manager(self, id, next_status, signal:threading.Event, callback): - - proxy = None - thr_starter:KThread = None - filters = {} - - while True: - restart_required = False - reload_required = False - - data = self.get_service_data(id) - - #Close thread - if data is None: - if proxy and proxy.isactive(): - proxy.stop() - callback.set() - return - - if data["status"] == STATUS.STOP: - if thr_starter and thr_starter.is_alive(): thr_starter.kill() - - #Filter check - old_filters = set(filters.keys()) - new_filters = set([f["id"] for f in data["filters"]]) - - #remove old filters - for f in old_filters: - if not f in new_filters: - reload_required = True - del filters[f] - - for f in new_filters: - if not f in old_filters: - reload_required = True - filter_info = [ele for ele in data['filters'] if ele["id"] == f][0] - filters[f] = Filter( - is_case_sensitive=filter_info["is_case_sensitive"], - c_to_s=filter_info["mode"] in ["C","B"], - s_to_c=filter_info["mode"] in ["S","B"], - is_blacklist=filter_info["is_blacklist"], - regex=b64decode(filter_info["regex"]), - blocked_packets=filter_info["n_packets"], - code=f - ) - - - def stats_updater(filter:Filter): - self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code) - - if not proxy: - proxy = Proxy( - internal_port=data['internal_port'], - public_port=data['public_port'], - filters=list(filters.values()), - internal_host=LOCALHOST_IP, - callback_blocked_update=stats_updater - ) - - #Port checks - if proxy.internal_port != data['internal_port'] or proxy.public_port != data['public_port']: - proxy.internal_port = data['internal_port'] - proxy.public_port = data['public_port'] - restart_required = True - - #Update filters - if reload_required: - proxy.filters = list(filters.values()) - - #proxy status managment - if data["status"] != next_status[0]: + async def next(self,to): + async with self.lock: + if self.status != to: # ACTIVE -> PAUSE or PAUSE -> ACTIVE - if (data["status"], next_status[0]) in [(STATUS.ACTIVE, STATUS.PAUSE), (STATUS.PAUSE, STATUS.ACTIVE)]: - if restart_required: - proxy.restart(in_pause=next_status[0]) - else: - if next_status[0] == STATUS.ACTIVE: proxy.reload() - else: proxy.pause() - self.__update_status_db(id, next_status[0]) - reload_required = restart_required = False + if (self.status, to) in [(STATUS.ACTIVE, STATUS.PAUSE)]: + await self.proxy.pause() + self._set_status(to) + + elif (self.status, to) in [(STATUS.PAUSE, STATUS.ACTIVE)]: + await self.proxy.reload() + self._set_status(to) # ACTIVE -> STOP - elif (data["status"],next_status[0]) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy - if thr_starter and thr_starter.is_alive(): thr_starter.kill() - proxy.stop() - next_status[0] = STATUS.STOP - self.__update_status_db(id, STATUS.STOP) - reload_required = restart_required = False + elif (self.status,to) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy + if self.starter: self.starter.cancel() + await self.proxy.stop() + self._set_status(to) # STOP -> ACTIVE or STOP -> PAUSE - elif (data["status"], next_status[0]) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]: - self.__update_status_db(id, STATUS.WAIT) - thr_starter = self.__proxy_starter(id, proxy, next_status[0]) - reload_required = restart_required = False - - if data["status"] != STATUS.STOP: - if restart_required: proxy.restart(in_pause=(data["status"] == STATUS.PAUSE)) - elif reload_required and data["status"] != STATUS.PAUSE: proxy.reload() + elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]: + self._set_status(STATUS.WAIT) + self.__proxy_starter(to) - callback.set() - signal.wait() - signal.clear() - + + def _stats_updater(self,filter:Filter): + self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code) + + async def update_port(self): + async with self.lock: + self._update_port_from_db() + if self.status in [STATUS.PAUSE, STATUS.ACTIVE]: + await self.proxy.restart(in_pause=(self.status == STATUS.PAUSE)) + + def _set_status(self,status): + self.status = status + self.__update_status_db(self.id,status) + + + async def update_filters(self): + async with self.lock: + self._update_filters_from_db() + if self.status in [STATUS.PAUSE, STATUS.ACTIVE]: + await self.proxy.reload() + + def __proxy_starter(self,to): + async def func(): + while True: + if check_port_is_open(self.proxy.public_port): + self._set_status(to) + await self.proxy.start(in_pause=(to==STATUS.PAUSE)) + self._set_status(STATUS.STOP) + return + else: + await asyncio.sleep(.5) + self.starter = asyncio.create_task(func()) + +class ProxyManager: + def __init__(self, db:SQLite): + self.db = db + self.proxy_table = {} + self.lock = asyncio.Lock() + + async def close(self): + for key in list(self.proxy_table.keys()): + await self.remove(key) + + async def remove(self,id): + async with self.lock: + if id in self.proxy_table: + await self.proxy_table[id].proxy.stop() + del self.proxy_table[id] + + async def reload(self): + async with self.lock: + for srv in self.db.query('SELECT service_id, status FROM services;'): + srv_id, req_status = srv["service_id"], srv["status"] + if srv_id in self.proxy_table: + continue + + self.proxy_table[srv_id] = ServiceManager(srv_id,self.db) + await self.proxy_table[srv_id].next(req_status) + + def get(self,id): + return self.proxy_table[id] def check_port_is_open(port): try: @@ -293,5 +241,4 @@ def gen_internal_port(db): res = random.randint(30000, 45000) if len(db.query('SELECT 1 FROM services WHERE internal_port = ?;', res)) == 0: break - return res - + return res \ No newline at end of file