From b9a2f136b0f49d0398193737b1b16b8e2b92b6a8 Mon Sep 17 00:00:00 2001 From: DomySh Date: Tue, 28 Jun 2022 16:00:41 +0200 Subject: [PATCH] removed signals in the proxy binary communication --- backend/app.py | 9 ++- backend/proxy/__init__.py | 55 +++++++++--------- backend/proxy/proxy.cpp | 114 +++++++++++++++++++++++--------------- backend/utils.py | 7 ++- 4 files changed, 110 insertions(+), 75 deletions(-) diff --git a/backend/app.py b/backend/app.py index 47f9cd7..daf0141 100644 --- a/backend/app.py +++ b/backend/app.py @@ -17,6 +17,11 @@ firewall = ProxyManager(db) app = FastAPI(debug=DEBUG) +@app.on_event("shutdown") +def shutdown_event(): + firewall.close() + db.disconnect() + app.add_middleware(SessionMiddleware, secret_key=os.urandom(32)) SESSION_TOKEN = secrets.token_hex(8) APP_STATUS = "init" @@ -117,7 +122,7 @@ async def get_services(request: Request): s.public_port public_port, s.internal_port internal_port, s.name name, - COUNT(*) n_regex, + COUNT(r.regex_id) n_regex, COALESCE(SUM(r.blocked_packets),0) n_packets FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id GROUP BY s.service_id; @@ -133,7 +138,7 @@ async def get_service(request: Request, service_id: str): s.public_port public_port, s.internal_port internal_port, s.name name, - COUNT(*) n_regex, + COUNT(r.regex_id) n_regex, COALESCE(SUM(r.blocked_packets),0) n_packets FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ? GROUP BY s.service_id; diff --git a/backend/proxy/__init__.py b/backend/proxy/__init__.py index 1ef5eaa..36fcf88 100755 --- a/backend/proxy/__init__.py +++ b/backend/proxy/__init__.py @@ -1,5 +1,3 @@ -from signal import SIGUSR1 -from secrets import token_urlsafe import subprocess, re, os from threading import Lock @@ -31,6 +29,7 @@ class Proxy: self.filter_map = {} self.filter_map_lock = Lock() self.update_config_lock = Lock() + self.status_change = Lock() self.public_host = public_host self.public_port = public_port self.internal_host = internal_host @@ -40,16 +39,20 @@ class Proxy: self.callback_blocked_update = callback_blocked_update def start(self, in_pause=False): + self.status_change.acquire() if not self.isactive(): - self.filter_map = self.compile_filters() - filters_codes = list(self.filter_map.keys()) if not in_pause else [] - proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./proxy") + try: + self.filter_map = self.compile_filters() + filters_codes = list(self.filter_map.keys()) if not in_pause else [] + proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./proxy") - self.process = subprocess.Popen( - [ proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port)], - stdout=subprocess.PIPE, stdin=subprocess.PIPE, universal_newlines=True - ) - self.update_config(filters_codes, sendsignal=False) + self.process = subprocess.Popen( + [ proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port)], + stdout=subprocess.PIPE, stdin=subprocess.PIPE, universal_newlines=True + ) + self.update_config(filters_codes) + finally: + self.status_change.release() for stdout_line in iter(self.process.stdout.readline, ""): if stdout_line.startswith("BLOCKED"): @@ -59,33 +62,33 @@ class Proxy: if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id]) self.process.stdout.close() return self.process.wait() + else: + self.status_change.release() + def stop(self): - if self.isactive(): - self.process.terminate() - try: - self.process.wait(timeout=3) - except Exception: - self.process.kill() - return False - finally: - self.process = None - return True + with self.status_change: + if self.isactive(): + self.process.terminate() + try: + self.process.wait(timeout=3) + except Exception: + self.process.kill() + return False + finally: + self.process = None + return True def restart(self, in_pause=False): status = self.stop() self.start(in_pause=in_pause) return status - def update_config(self, filters_codes, sendsignal=True): + def update_config(self, filters_codes): with self.update_config_lock: if (self.isactive()): - self.process.stdin.write(" ".join(filters_codes)) - self.process.stdin.write(" END ") + self.process.stdin.write(" ".join(filters_codes)+"\n") self.process.stdin.flush() - if sendsignal: - self.process.send_signal(SIGUSR1) - def reload(self): if self.isactive(): diff --git a/backend/proxy/proxy.cpp b/backend/proxy/proxy.cpp index c401aa1..9987bb3 100644 --- a/backend/proxy/proxy.cpp +++ b/backend/proxy/proxy.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include @@ -20,8 +19,6 @@ using namespace std; -boost::asio::io_service *ios_loop = nullptr; - bool unhexlify(string const &hex, string &newString) { try{ int len = hex.length(); @@ -146,7 +143,7 @@ namespace tcp_proxy typedef ip::tcp::socket socket_type; typedef boost::shared_ptr ptr_type; - bridge(boost::asio::io_service& ios) + bridge(boost::asio::io_context& ios) : downstream_socket_(ios), upstream_socket_ (ios), thread_safety(ios) @@ -320,7 +317,7 @@ namespace tcp_proxy enum { max_data_length = 8192 }; //8KB unsigned char downstream_data_[max_data_length]; unsigned char upstream_data_ [max_data_length]; - boost::asio::io_service::strand thread_safety; + boost::asio::io_context::strand thread_safety; boost::mutex mutex_; public: @@ -328,12 +325,12 @@ namespace tcp_proxy { public: - acceptor(boost::asio::io_service& io_service, + acceptor(boost::asio::io_context& io_context, const string& local_host, unsigned short local_port, const string& upstream_host, unsigned short upstream_port) - : io_service_(io_service), + : io_context_(io_context), localhost_address(boost::asio::ip::address_v4::from_string(local_host)), - acceptor_(io_service_,ip::tcp::endpoint(localhost_address,local_port)), + acceptor_(io_context_,ip::tcp::endpoint(localhost_address,local_port)), upstream_port_(upstream_port), upstream_host_(upstream_host) {} @@ -342,7 +339,7 @@ namespace tcp_proxy { try { - session_ = boost::shared_ptr(new bridge(io_service_)); + session_ = boost::shared_ptr(new bridge(io_context_)); acceptor_.async_accept(session_->downstream_socket(), boost::asio::bind_executor(session_->thread_safety, @@ -378,7 +375,7 @@ namespace tcp_proxy } } - boost::asio::io_service& io_service_; + boost::asio::io_context& io_context_; ip::address_v4 localhost_address; ip::tcp::acceptor acceptor_; ptr_type session_; @@ -389,35 +386,64 @@ namespace tcp_proxy }; } - -void update_regex(){ - #ifdef DEBUG - cerr << "Updating configuration" << endl; - #endif - std::unique_lock lck(update_mutex); - regex_rules *regex_new_config = new regex_rules(); - string data; - while(true){ - cin >> data; - if (data == "END") break; - regex_new_config->add(data.c_str()); - } - regex_config.reset(regex_new_config); -} - -void signal_handler(int signal_num) -{ - if (signal_num == SIGUSR1){ - update_regex(); - }else if(signal_num == SIGTERM){ - if (ios_loop != nullptr) ios_loop->stop(); +void update_config (boost::asio::streambuf &input_buffer){ #ifdef DEBUG - cerr << "Close Requested" << endl; + cerr << "Updating configuration" << endl; #endif - exit(0); - } + std::istream config_stream(&input_buffer); + std::unique_lock lck(update_mutex); + regex_rules *regex_new_config = new regex_rules(); + string data; + while(!config_stream.eof()){ + config_stream >> data; + regex_new_config->add(data.c_str()); + } + regex_config.reset(regex_new_config); } +class async_updater +{ +public: + async_updater(boost::asio::io_context& io_context) : input_(io_context, ::dup(STDIN_FILENO)), thread_safety(io_context) + { + + boost::asio::async_read_until(input_, input_buffer_, '\n', + boost::asio::bind_executor(thread_safety, + boost::bind(&async_updater::on_update, this, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred))); + } + + void on_update(const boost::system::error_code& error, std::size_t length) + { + if (!error) + { + update_config(input_buffer_); + boost::asio::async_read_until(input_, input_buffer_, '\n', + boost::asio::bind_executor(thread_safety, + boost::bind(&async_updater::on_update, this, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred))); + } + else + { + close(); + } + } + + void close() + { + input_.close(); + } + +private: + boost::asio::posix::stream_descriptor input_; + boost::asio::io_context::strand thread_safety; + boost::asio::streambuf input_buffer_; +}; + + + int main(int argc, char* argv[]) { if (argc < 5) @@ -431,14 +457,14 @@ int main(int argc, char* argv[]) const string local_host = argv[1]; const string forward_host = argv[3]; - update_regex(); + boost::asio::io_context ios; - signal(SIGUSR1, signal_handler); - - boost::asio::io_service ios; - ios_loop = &ios; - - signal(SIGTERM, signal_handler); + boost::asio::streambuf buf; + boost::asio::posix::stream_descriptor cin_in(ios, ::dup(STDIN_FILENO)); + boost::asio::read_until(cin_in, buf,'\n'); + update_config(buf); + + async_updater updater(ios); #ifdef DEBUG cerr << "Starting Proxy" << endl; @@ -457,7 +483,7 @@ int main(int argc, char* argv[]) #else for (unsigned i = 0; i < thread::hardware_concurrency(); ++i) #endif - tg.create_thread(boost::bind(&boost::asio::io_service::run, &ios)); + tg.create_thread(boost::bind(&boost::asio::io_context::run, &ios)); tg.join_all(); #else @@ -474,4 +500,4 @@ int main(int argc, char* argv[]) #endif return 0; -} +} \ No newline at end of file diff --git a/backend/utils.py b/backend/utils.py index 846f228..66ec185 100755 --- a/backend/utils.py +++ b/backend/utils.py @@ -28,7 +28,8 @@ class SQLite(): self.conn.row_factory = dict_factory def disconnect(self) -> None: - self.conn.close() + with self.lock: + self.conn.close() def create_schema(self, tables = {}) -> None: cur = self.conn.cursor() @@ -75,7 +76,7 @@ class ProxyManager: self.db = db self.proxy_table = {} self.lock = threading.Lock() - atexit.register(self.clear) + atexit.register(self.close) def __clean_proxy_table(self): with self.lock: @@ -83,7 +84,7 @@ class ProxyManager: if not self.proxy_table[key]["thread"].is_alive(): del self.proxy_table[key] - def clear(self): + def close(self): with self.lock: for key in list(self.proxy_table.keys()): if self.proxy_table[key]["thread"].is_alive():