removed signals in the proxy binary communication

This commit is contained in:
DomySh
2022-06-28 16:00:41 +02:00
parent 631326eeb7
commit b9a2f136b0
4 changed files with 110 additions and 75 deletions

View File

@@ -17,6 +17,11 @@ firewall = ProxyManager(db)
app = FastAPI(debug=DEBUG) app = FastAPI(debug=DEBUG)
@app.on_event("shutdown")
def shutdown_event():
firewall.close()
db.disconnect()
app.add_middleware(SessionMiddleware, secret_key=os.urandom(32)) app.add_middleware(SessionMiddleware, secret_key=os.urandom(32))
SESSION_TOKEN = secrets.token_hex(8) SESSION_TOKEN = secrets.token_hex(8)
APP_STATUS = "init" APP_STATUS = "init"
@@ -117,7 +122,7 @@ async def get_services(request: Request):
s.public_port public_port, s.public_port public_port,
s.internal_port internal_port, s.internal_port internal_port,
s.name name, s.name name,
COUNT(*) n_regex, COUNT(r.regex_id) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id
GROUP BY s.service_id; GROUP BY s.service_id;
@@ -133,7 +138,7 @@ async def get_service(request: Request, service_id: str):
s.public_port public_port, s.public_port public_port,
s.internal_port internal_port, s.internal_port internal_port,
s.name name, s.name name,
COUNT(*) n_regex, COUNT(r.regex_id) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ? FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ?
GROUP BY s.service_id; GROUP BY s.service_id;

View File

@@ -1,5 +1,3 @@
from signal import SIGUSR1
from secrets import token_urlsafe
import subprocess, re, os import subprocess, re, os
from threading import Lock from threading import Lock
@@ -31,6 +29,7 @@ class Proxy:
self.filter_map = {} self.filter_map = {}
self.filter_map_lock = Lock() self.filter_map_lock = Lock()
self.update_config_lock = Lock() self.update_config_lock = Lock()
self.status_change = Lock()
self.public_host = public_host self.public_host = public_host
self.public_port = public_port self.public_port = public_port
self.internal_host = internal_host self.internal_host = internal_host
@@ -40,16 +39,20 @@ class Proxy:
self.callback_blocked_update = callback_blocked_update self.callback_blocked_update = callback_blocked_update
def start(self, in_pause=False): def start(self, in_pause=False):
self.status_change.acquire()
if not self.isactive(): if not self.isactive():
self.filter_map = self.compile_filters() try:
filters_codes = list(self.filter_map.keys()) if not in_pause else [] self.filter_map = self.compile_filters()
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./proxy") filters_codes = list(self.filter_map.keys()) if not in_pause else []
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./proxy")
self.process = subprocess.Popen( self.process = subprocess.Popen(
[ proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port)], [ proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port)],
stdout=subprocess.PIPE, stdin=subprocess.PIPE, universal_newlines=True stdout=subprocess.PIPE, stdin=subprocess.PIPE, universal_newlines=True
) )
self.update_config(filters_codes, sendsignal=False) self.update_config(filters_codes)
finally:
self.status_change.release()
for stdout_line in iter(self.process.stdout.readline, ""): for stdout_line in iter(self.process.stdout.readline, ""):
if stdout_line.startswith("BLOCKED"): if stdout_line.startswith("BLOCKED"):
@@ -59,33 +62,33 @@ class Proxy:
if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id]) if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id])
self.process.stdout.close() self.process.stdout.close()
return self.process.wait() return self.process.wait()
else:
self.status_change.release()
def stop(self): def stop(self):
if self.isactive(): with self.status_change:
self.process.terminate() if self.isactive():
try: self.process.terminate()
self.process.wait(timeout=3) try:
except Exception: self.process.wait(timeout=3)
self.process.kill() except Exception:
return False self.process.kill()
finally: return False
self.process = None finally:
return True self.process = None
return True
def restart(self, in_pause=False): def restart(self, in_pause=False):
status = self.stop() status = self.stop()
self.start(in_pause=in_pause) self.start(in_pause=in_pause)
return status return status
def update_config(self, filters_codes, sendsignal=True): def update_config(self, filters_codes):
with self.update_config_lock: with self.update_config_lock:
if (self.isactive()): if (self.isactive()):
self.process.stdin.write(" ".join(filters_codes)) self.process.stdin.write(" ".join(filters_codes)+"\n")
self.process.stdin.write(" END ")
self.process.stdin.flush() self.process.stdin.flush()
if sendsignal:
self.process.send_signal(SIGUSR1)
def reload(self): def reload(self):
if self.isactive(): if self.isactive():

View File

@@ -7,7 +7,6 @@
#include <cstddef> #include <cstddef>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <csignal>
#include <regex> #include <regex>
#include <mutex> #include <mutex>
@@ -20,8 +19,6 @@
using namespace std; using namespace std;
boost::asio::io_service *ios_loop = nullptr;
bool unhexlify(string const &hex, string &newString) { bool unhexlify(string const &hex, string &newString) {
try{ try{
int len = hex.length(); int len = hex.length();
@@ -146,7 +143,7 @@ namespace tcp_proxy
typedef ip::tcp::socket socket_type; typedef ip::tcp::socket socket_type;
typedef boost::shared_ptr<bridge> ptr_type; typedef boost::shared_ptr<bridge> ptr_type;
bridge(boost::asio::io_service& ios) bridge(boost::asio::io_context& ios)
: downstream_socket_(ios), : downstream_socket_(ios),
upstream_socket_ (ios), upstream_socket_ (ios),
thread_safety(ios) thread_safety(ios)
@@ -320,7 +317,7 @@ namespace tcp_proxy
enum { max_data_length = 8192 }; //8KB enum { max_data_length = 8192 }; //8KB
unsigned char downstream_data_[max_data_length]; unsigned char downstream_data_[max_data_length];
unsigned char upstream_data_ [max_data_length]; unsigned char upstream_data_ [max_data_length];
boost::asio::io_service::strand thread_safety; boost::asio::io_context::strand thread_safety;
boost::mutex mutex_; boost::mutex mutex_;
public: public:
@@ -328,12 +325,12 @@ namespace tcp_proxy
{ {
public: public:
acceptor(boost::asio::io_service& io_service, acceptor(boost::asio::io_context& io_context,
const string& local_host, unsigned short local_port, const string& local_host, unsigned short local_port,
const string& upstream_host, unsigned short upstream_port) const string& upstream_host, unsigned short upstream_port)
: io_service_(io_service), : io_context_(io_context),
localhost_address(boost::asio::ip::address_v4::from_string(local_host)), localhost_address(boost::asio::ip::address_v4::from_string(local_host)),
acceptor_(io_service_,ip::tcp::endpoint(localhost_address,local_port)), acceptor_(io_context_,ip::tcp::endpoint(localhost_address,local_port)),
upstream_port_(upstream_port), upstream_port_(upstream_port),
upstream_host_(upstream_host) upstream_host_(upstream_host)
{} {}
@@ -342,7 +339,7 @@ namespace tcp_proxy
{ {
try try
{ {
session_ = boost::shared_ptr<bridge>(new bridge(io_service_)); session_ = boost::shared_ptr<bridge>(new bridge(io_context_));
acceptor_.async_accept(session_->downstream_socket(), acceptor_.async_accept(session_->downstream_socket(),
boost::asio::bind_executor(session_->thread_safety, boost::asio::bind_executor(session_->thread_safety,
@@ -378,7 +375,7 @@ namespace tcp_proxy
} }
} }
boost::asio::io_service& io_service_; boost::asio::io_context& io_context_;
ip::address_v4 localhost_address; ip::address_v4 localhost_address;
ip::tcp::acceptor acceptor_; ip::tcp::acceptor acceptor_;
ptr_type session_; ptr_type session_;
@@ -389,35 +386,64 @@ namespace tcp_proxy
}; };
} }
void update_config (boost::asio::streambuf &input_buffer){
void update_regex(){
#ifdef DEBUG
cerr << "Updating configuration" << endl;
#endif
std::unique_lock<std::mutex> lck(update_mutex);
regex_rules *regex_new_config = new regex_rules();
string data;
while(true){
cin >> data;
if (data == "END") break;
regex_new_config->add(data.c_str());
}
regex_config.reset(regex_new_config);
}
void signal_handler(int signal_num)
{
if (signal_num == SIGUSR1){
update_regex();
}else if(signal_num == SIGTERM){
if (ios_loop != nullptr) ios_loop->stop();
#ifdef DEBUG #ifdef DEBUG
cerr << "Close Requested" << endl; cerr << "Updating configuration" << endl;
#endif #endif
exit(0); std::istream config_stream(&input_buffer);
} std::unique_lock<std::mutex> lck(update_mutex);
regex_rules *regex_new_config = new regex_rules();
string data;
while(!config_stream.eof()){
config_stream >> data;
regex_new_config->add(data.c_str());
}
regex_config.reset(regex_new_config);
} }
class async_updater
{
public:
async_updater(boost::asio::io_context& io_context) : input_(io_context, ::dup(STDIN_FILENO)), thread_safety(io_context)
{
boost::asio::async_read_until(input_, input_buffer_, '\n',
boost::asio::bind_executor(thread_safety,
boost::bind(&async_updater::on_update, this,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
void on_update(const boost::system::error_code& error, std::size_t length)
{
if (!error)
{
update_config(input_buffer_);
boost::asio::async_read_until(input_, input_buffer_, '\n',
boost::asio::bind_executor(thread_safety,
boost::bind(&async_updater::on_update, this,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
{
close();
}
}
void close()
{
input_.close();
}
private:
boost::asio::posix::stream_descriptor input_;
boost::asio::io_context::strand thread_safety;
boost::asio::streambuf input_buffer_;
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
if (argc < 5) if (argc < 5)
@@ -431,14 +457,14 @@ int main(int argc, char* argv[])
const string local_host = argv[1]; const string local_host = argv[1];
const string forward_host = argv[3]; const string forward_host = argv[3];
update_regex(); boost::asio::io_context ios;
signal(SIGUSR1, signal_handler); boost::asio::streambuf buf;
boost::asio::posix::stream_descriptor cin_in(ios, ::dup(STDIN_FILENO));
boost::asio::read_until(cin_in, buf,'\n');
update_config(buf);
boost::asio::io_service ios; async_updater updater(ios);
ios_loop = &ios;
signal(SIGTERM, signal_handler);
#ifdef DEBUG #ifdef DEBUG
cerr << "Starting Proxy" << endl; cerr << "Starting Proxy" << endl;
@@ -457,7 +483,7 @@ int main(int argc, char* argv[])
#else #else
for (unsigned i = 0; i < thread::hardware_concurrency(); ++i) for (unsigned i = 0; i < thread::hardware_concurrency(); ++i)
#endif #endif
tg.create_thread(boost::bind(&boost::asio::io_service::run, &ios)); tg.create_thread(boost::bind(&boost::asio::io_context::run, &ios));
tg.join_all(); tg.join_all();
#else #else

View File

@@ -28,7 +28,8 @@ class SQLite():
self.conn.row_factory = dict_factory self.conn.row_factory = dict_factory
def disconnect(self) -> None: def disconnect(self) -> None:
self.conn.close() with self.lock:
self.conn.close()
def create_schema(self, tables = {}) -> None: def create_schema(self, tables = {}) -> None:
cur = self.conn.cursor() cur = self.conn.cursor()
@@ -75,7 +76,7 @@ class ProxyManager:
self.db = db self.db = db
self.proxy_table = {} self.proxy_table = {}
self.lock = threading.Lock() self.lock = threading.Lock()
atexit.register(self.clear) atexit.register(self.close)
def __clean_proxy_table(self): def __clean_proxy_table(self):
with self.lock: with self.lock:
@@ -83,7 +84,7 @@ class ProxyManager:
if not self.proxy_table[key]["thread"].is_alive(): if not self.proxy_table[key]["thread"].is_alive():
del self.proxy_table[key] del self.proxy_table[key]
def clear(self): def close(self):
with self.lock: with self.lock:
for key in list(self.proxy_table.keys()): for key in list(self.proxy_table.keys()):
if self.proxy_table[key]["thread"].is_alive(): if self.proxy_table[key]["thread"].is_alive():