asyncio in proxy manager

This commit is contained in:
nik012003
2022-06-28 19:38:28 +02:00
parent 8fb4218689
commit 4971281f5a
3 changed files with 219 additions and 272 deletions

View File

@@ -1,6 +1,6 @@
import threading
from proxy import Filter, Proxy
import random, string, os, threading, sqlite3, time, atexit, socket
from kthread import KThread
import random, string, os, sqlite3, socket, asyncio
from base64 import b64decode
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
@@ -69,209 +69,157 @@ class STATUS:
STOP = "stop"
PAUSE = "pause"
ACTIVE = "active"
class ProxyManager:
def __init__(self, db:SQLite):
class ServiceNotFoundException(Exception):
pass
class ServiceManager:
def __init__(self, id, db):
self.id = id
self.db = db
self.proxy_table = {}
self.lock = threading.Lock()
atexit.register(self.close)
def __clean_proxy_table(self):
with self.lock:
for key in list(self.proxy_table.keys()):
if not self.proxy_table[key]["thread"].is_alive():
del self.proxy_table[key]
def close(self):
with self.lock:
for key in list(self.proxy_table.keys()):
if self.proxy_table[key]["thread"].is_alive():
self.proxy_table[key]["thread"].kill()
del self.proxy_table[key]
def reload(self):
self.__clean_proxy_table()
with self.lock:
for srv in self.db.query('SELECT service_id, status FROM services;'):
srv_id, n_status = srv["service_id"], srv["status"]
if srv_id in self.proxy_table:
continue
update_signal = threading.Event()
callback_signal = threading.Event()
req_status = [n_status]
thread = KThread(target=self.service_manager, args=(srv_id, req_status, update_signal, callback_signal))
self.proxy_table[srv_id] = {
"thread":thread,
"event":update_signal,
"callback":callback_signal,
"next_status":req_status
}
thread.start()
callback_signal.wait()
callback_signal.clear()
def get_service_data(self, id):
self.proxy = Proxy(
internal_host=LOCALHOST_IP,
callback_blocked_update=self._stats_updater
)
self.status = STATUS.STOP
self.filters = {}
self._proxy_update()
self.lock = asyncio.Lock()
self.starter = None
def _update_port_from_db(self):
res = self.db.query("""
SELECT
service_id `id`,
status,
public_port,
internal_port
FROM services WHERE service_id = ?;
""", id)
if len(res) == 0: return None
else: res = res[0]
res["filters"] = self.db.query("""
""", self.id)
if len(res) == 0: raise ServiceNotFoundException()
self.proxy.internal_port = res[0]["internal_port"]
self.proxy.public_port = res[0]["public_port"]
def _proxy_update(self):
self._update_port_from_db()
self._update_filters_from_db()
def _update_filters_from_db(self):
res = self.db.query("""
SELECT
regex, mode, regex_id `id`, is_blacklist,
blocked_packets n_packets, is_case_sensitive
FROM regexes WHERE service_id = ?;
""", id)
return res
""", self.id)
def change_status(self, id, to):
with self.lock:
if id in self.proxy_table:
if self.proxy_table[id]["thread"].is_alive():
self.proxy_table[id]["next_status"][0] = to
self.proxy_table[id]["event"].set()
self.proxy_table[id]["callback"].wait()
self.proxy_table[id]["callback"].clear()
else:
del self.proxy_table[id]
def fire_update(self, id):
with self.lock:
if id in self.proxy_table:
if self.proxy_table[id]["thread"].is_alive():
self.proxy_table[id]["event"].set()
self.proxy_table[id]["callback"].wait()
self.proxy_table[id]["callback"].clear()
else:
del self.proxy_table[id]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f["id"] for f in res])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
for f in new_filters:
if not f in old_filters:
filter_info = [ele for ele in res if ele["id"] == f][0]
self.filters[f] = Filter(
is_case_sensitive=filter_info["is_case_sensitive"],
c_to_s=filter_info["mode"] in ["C","B"],
s_to_c=filter_info["mode"] in ["S","B"],
is_blacklist=filter_info["is_blacklist"],
regex=b64decode(filter_info["regex"]),
blocked_packets=filter_info["n_packets"],
code=f
)
self.proxy.filters = list(self.filters.values())
def __update_status_db(self, id, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, id)
def __proxy_starter(self, id, proxy:Proxy, next_status):
def func():
while True:
if check_port_is_open(proxy.public_port):
self.__update_status_db(id, next_status)
proxy.start(in_pause=(next_status==STATUS.PAUSE))
self.__update_status_db(id, STATUS.STOP)
return
else:
time.sleep(.5)
thread = KThread(target=func)
thread.start()
return thread
def service_manager(self, id, next_status, signal:threading.Event, callback):
proxy = None
thr_starter:KThread = None
filters = {}
while True:
restart_required = False
reload_required = False
data = self.get_service_data(id)
#Close thread
if data is None:
if proxy and proxy.isactive():
proxy.stop()
callback.set()
return
if data["status"] == STATUS.STOP:
if thr_starter and thr_starter.is_alive(): thr_starter.kill()
#Filter check
old_filters = set(filters.keys())
new_filters = set([f["id"] for f in data["filters"]])
#remove old filters
for f in old_filters:
if not f in new_filters:
reload_required = True
del filters[f]
for f in new_filters:
if not f in old_filters:
reload_required = True
filter_info = [ele for ele in data['filters'] if ele["id"] == f][0]
filters[f] = Filter(
is_case_sensitive=filter_info["is_case_sensitive"],
c_to_s=filter_info["mode"] in ["C","B"],
s_to_c=filter_info["mode"] in ["S","B"],
is_blacklist=filter_info["is_blacklist"],
regex=b64decode(filter_info["regex"]),
blocked_packets=filter_info["n_packets"],
code=f
)
def stats_updater(filter:Filter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
if not proxy:
proxy = Proxy(
internal_port=data['internal_port'],
public_port=data['public_port'],
filters=list(filters.values()),
internal_host=LOCALHOST_IP,
callback_blocked_update=stats_updater
)
#Port checks
if proxy.internal_port != data['internal_port'] or proxy.public_port != data['public_port']:
proxy.internal_port = data['internal_port']
proxy.public_port = data['public_port']
restart_required = True
#Update filters
if reload_required:
proxy.filters = list(filters.values())
#proxy status managment
if data["status"] != next_status[0]:
async def next(self,to):
async with self.lock:
if self.status != to:
# ACTIVE -> PAUSE or PAUSE -> ACTIVE
if (data["status"], next_status[0]) in [(STATUS.ACTIVE, STATUS.PAUSE), (STATUS.PAUSE, STATUS.ACTIVE)]:
if restart_required:
proxy.restart(in_pause=next_status[0])
else:
if next_status[0] == STATUS.ACTIVE: proxy.reload()
else: proxy.pause()
self.__update_status_db(id, next_status[0])
reload_required = restart_required = False
if (self.status, to) in [(STATUS.ACTIVE, STATUS.PAUSE)]:
await self.proxy.pause()
self._set_status(to)
elif (self.status, to) in [(STATUS.PAUSE, STATUS.ACTIVE)]:
await self.proxy.reload()
self._set_status(to)
# ACTIVE -> STOP
elif (data["status"],next_status[0]) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy
if thr_starter and thr_starter.is_alive(): thr_starter.kill()
proxy.stop()
next_status[0] = STATUS.STOP
self.__update_status_db(id, STATUS.STOP)
reload_required = restart_required = False
elif (self.status,to) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy
if self.starter: self.starter.cancel()
await self.proxy.stop()
self._set_status(to)
# STOP -> ACTIVE or STOP -> PAUSE
elif (data["status"], next_status[0]) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]:
self.__update_status_db(id, STATUS.WAIT)
thr_starter = self.__proxy_starter(id, proxy, next_status[0])
reload_required = restart_required = False
if data["status"] != STATUS.STOP:
if restart_required: proxy.restart(in_pause=(data["status"] == STATUS.PAUSE))
elif reload_required and data["status"] != STATUS.PAUSE: proxy.reload()
elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]:
self._set_status(STATUS.WAIT)
self.__proxy_starter(to)
callback.set()
signal.wait()
signal.clear()
def _stats_updater(self,filter:Filter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
async def update_port(self):
async with self.lock:
self._update_port_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
await self.proxy.restart(in_pause=(self.status == STATUS.PAUSE))
def _set_status(self,status):
self.status = status
self.__update_status_db(self.id,status)
async def update_filters(self):
async with self.lock:
self._update_filters_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
await self.proxy.reload()
def __proxy_starter(self,to):
async def func():
while True:
if check_port_is_open(self.proxy.public_port):
self._set_status(to)
await self.proxy.start(in_pause=(to==STATUS.PAUSE))
self._set_status(STATUS.STOP)
return
else:
await asyncio.sleep(.5)
self.starter = asyncio.create_task(func())
class ProxyManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,id):
async with self.lock:
if id in self.proxy_table:
await self.proxy_table[id].proxy.stop()
del self.proxy_table[id]
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT service_id, status FROM services;'):
srv_id, req_status = srv["service_id"], srv["status"]
if srv_id in self.proxy_table:
continue
self.proxy_table[srv_id] = ServiceManager(srv_id,self.db)
await self.proxy_table[srv_id].next(req_status)
def get(self,id):
return self.proxy_table[id]
def check_port_is_open(port):
try:
@@ -293,5 +241,4 @@ def gen_internal_port(db):
res = random.randint(30000, 45000)
if len(db.query('SELECT 1 FROM services WHERE internal_port = ?;', res)) == 0:
break
return res
return res