From Threads to Multiprocess
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Dict
|
||||
from proxy import Filter, Proxy
|
||||
import os, sqlite3, socket, asyncio
|
||||
from base64 import b64decode
|
||||
@@ -10,7 +11,6 @@ class SQLite():
|
||||
self.conn = None
|
||||
self.cur = None
|
||||
self.db_name = db_name
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def connect(self) -> None:
|
||||
try:
|
||||
@@ -27,8 +27,7 @@ class SQLite():
|
||||
self.conn.row_factory = dict_factory
|
||||
|
||||
def disconnect(self) -> None:
|
||||
with self.lock:
|
||||
self.conn.close()
|
||||
self.conn.close()
|
||||
|
||||
def create_schema(self, tables = {}) -> None:
|
||||
cur = self.conn.cursor()
|
||||
@@ -39,9 +38,8 @@ class SQLite():
|
||||
def query(self, query, *values):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
with self.lock:
|
||||
cur.execute(query, values)
|
||||
return cur.fetchall()
|
||||
cur.execute(query, values)
|
||||
return cur.fetchall()
|
||||
finally:
|
||||
cur.close()
|
||||
try: self.conn.commit()
|
||||
@@ -100,10 +98,7 @@ class ServiceManager:
|
||||
def __init__(self, port, db):
|
||||
self.port = port
|
||||
self.db = db
|
||||
self.proxy = Proxy(
|
||||
callback_blocked_update=self._stats_updater,
|
||||
public_port=port
|
||||
)
|
||||
self.proxy = Proxy(port)
|
||||
self.status = STATUS.STOP
|
||||
self.filters = {}
|
||||
self._update_filters_from_db()
|
||||
@@ -139,7 +134,7 @@ class ServiceManager:
|
||||
blocked_packets=filter_info["n_packets"],
|
||||
code=f
|
||||
)
|
||||
self.proxy.filters = list(self.filters.values())
|
||||
self.proxy.set_filters(self.filters.values())
|
||||
|
||||
def __update_status_db(self, status):
|
||||
self.db.query("UPDATE services SET status = ? WHERE port = ?;", status, self.port)
|
||||
@@ -161,7 +156,12 @@ class ServiceManager:
|
||||
|
||||
|
||||
def _stats_updater(self,filter:Filter):
|
||||
print(filter, filter.blocked, filter.code)
|
||||
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
|
||||
|
||||
def update_stats(self):
|
||||
for ele in self.proxy.filters:
|
||||
self._stats_updater(ele)
|
||||
|
||||
def _set_status(self,status):
|
||||
self.status = status
|
||||
@@ -174,8 +174,11 @@ class ServiceManager:
|
||||
class ProxyManager:
|
||||
def __init__(self, db:SQLite):
|
||||
self.db = db
|
||||
self.proxy_table = {}
|
||||
self.proxy_table: Dict[ServiceManager] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def init_updater(self):
|
||||
asyncio.create_task(self._stats_updater())
|
||||
|
||||
async def close(self):
|
||||
for key in list(self.proxy_table.keys()):
|
||||
@@ -197,6 +200,17 @@ class ProxyManager:
|
||||
self.proxy_table[srv_port] = ServiceManager(srv_port,self.db)
|
||||
await self.proxy_table[srv_port].next(req_status)
|
||||
|
||||
async def _stats_updater(self):
|
||||
while True:
|
||||
print("ALIVE!")
|
||||
try:
|
||||
for key in list(self.proxy_table.keys()):
|
||||
self.proxy_table[key].update_stats()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
def get(self,port):
|
||||
if port in self.proxy_table:
|
||||
return self.proxy_table[port]
|
||||
|
||||
Reference in New Issue
Block a user