diff --git a/.gitignore b/.gitignore index 40095c0..03e9b7f 100644 --- a/.gitignore +++ b/.gitignore @@ -29,7 +29,6 @@ /firegex-compose-tmp-file.yml /firegex.py /tests/benchmark.csv -/backend/modules/nfproxy/socks/ # misc **/.DS_Store **/.env.local diff --git a/Dockerfile b/Dockerfile index 1e55e4e..d26b7a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ RUN pip3 install --no-cache-dir --break-system-packages -r /execute/requirements COPY ./backend/binsrc /execute/binsrc RUN g++ binsrc/nfqueue.cpp -o modules/cppqueue -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libhs libmnl) -#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cppproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl) +#RUN g++ binsrc/nfproxy-tun.cpp -o modules/cpproxy -std=c++23 -O3 -lnetfilter_queue -pthread -lnfnetlink $(pkg-config --cflags --libs libtins libmnl) COPY ./backend/ /execute/ COPY --from=frontend /app/dist/ ./frontend/ diff --git a/backend/binsrc/classes/netfilter.cpp b/backend/binsrc/classes/netfilter.cpp index 5ac7c1b..cb58213 100644 --- a/backend/binsrc/classes/netfilter.cpp +++ b/backend/binsrc/classes/netfilter.cpp @@ -5,9 +5,10 @@ #include #include #include -#include #include #include +#include +#include using namespace std; diff --git a/backend/binsrc/nfproxy-tun.cpp b/backend/binsrc/nfproxy-tun.cpp index 9443f04..b44135a 100644 --- a/backend/binsrc/nfproxy-tun.cpp +++ b/backend/binsrc/nfproxy-tun.cpp @@ -1,20 +1,48 @@ +#include "proxytun/settings.cpp" #include "proxytun/proxytun.cpp" -#include "utils.hpp" -#include +#include "classes/netfilter.cpp" #include +#include using namespace std; +void config_updater (){ + while (true){ + //TODO read config getline(cin, line); + if (cin.eof()){ + cerr << "[fatal] [updater] cin.eof()" << endl; + exit(EXIT_FAILURE); + } + if (cin.bad()){ + cerr << "[fatal] [updater] cin.bad()" << endl; + exit(EXIT_FAILURE); + } + cerr << "[info] [updater] Updating configuration" << endl; + + try{ + //TODO add data config.reset(new PyCodeConfig("")); + cerr << "[info] [updater] Config update done" << endl; + osyncstream(cout) << "ACK OK" << endl; + }catch(const std::exception& e){ + cerr << "[error] [updater] Failed to build new configuration!" << endl; + osyncstream(cout) << "ACK FAIL " << e.what() << endl; + } + } +} + int main(int argc, char *argv[]){ int n_of_threads = 1; char * n_threads_str = getenv("NTHREADS"); if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str); if(n_of_threads <= 0) n_of_threads = 1; + + config.reset(new PyCodeConfig("")); - NFQueueSequence queues(n_of_threads); + NFQueueSequence queues(n_of_threads); queues.start(); osyncstream(cout) << "QUEUES " << queues.init() << " " << queues.end() << endl; cerr << "[info] [main] Queues: " << queues.init() << ":" << queues.end() << " threads assigned: " << n_of_threads << endl; + config_updater(); } diff --git a/backend/binsrc/proxytun/proxytun.cpp b/backend/binsrc/proxytun/proxytun.cpp index dbce409..50aad35 100644 --- a/backend/binsrc/proxytun/proxytun.cpp +++ b/backend/binsrc/proxytun/proxytun.cpp @@ -1,5 +1,5 @@ -#ifndef PROXY_TUNNEL_CPP -#define PROXY_TUNNEL_CPP +#ifndef PROXY_TUNNEL_CLASS_CPP +#define PROXY_TUNNEL_CLASS_CPP #include #include @@ -12,88 +12,157 @@ #include #include #include +#include +#include +#include #include #include "../classes/netfilter.cpp" -#include +#include "stream_ctx.cpp" +#include "settings.cpp" using Tins::TCPIP::Stream; using Tins::TCPIP::StreamFollower; using namespace std; -typedef Tins::TCPIP::StreamIdentifier stream_id; - -class SocketTunnelQueue: public NfQueueExecutor { +class PyProxyQueue: public NfQueueExecutor { public: - - StreamFollower follower; + stream_ctx sctx; void before_loop() override { - follower.new_stream_callback(bind(on_new_stream, placeholders::_1)); - follower.stream_termination_callback(bind(on_stream_close, placeholders::_1)); + sctx.follower.new_stream_callback(bind(on_new_stream, placeholders::_1, &sctx)); + sctx.follower.stream_termination_callback(bind(on_stream_close, placeholders::_1, &sctx)); } void * callback_data_fetch() override{ - return nullptr; + return &sctx; } - static bool filter_action(){ + static bool filter_action(packet_info& info){ + shared_ptr conf = config; + auto stream_search = info.sctx->streams_ctx.find(info.sid); + pyfilter_ctx stream_match; + if (stream_search == info.sctx->streams_ctx.end()){ + // TODO: New pyfilter_ctx + }else{ + stream_match = stream_search->second; + } + + bool has_matched = false; + //TODO exec filtering action + + if (has_matched){ + // Say to firegex what filter has matched + //osyncstream(cout) << "BLOCKED " << rules_vector[match_res.matched] << endl; + return false; + } return true; } + + //If the stream has already been matched, drop all data, and try to close the connection + static void keep_fin_packet(stream_ctx* sctx){ + sctx->match_info.matching_has_been_called = true; + sctx->match_info.already_closed = true; + } - static void on_data_recv(Stream& stream, string data, bool is_input) { - bool result = filter_action(); + static void on_data_recv(Stream& stream, stream_ctx* sctx, string data) { + sctx->match_info.matching_has_been_called = true; + sctx->match_info.already_closed = false; + bool result = filter_action(*sctx->match_info.pkt_info); if (!result){ - stream.ignore_client_data(); - stream.ignore_server_data(); + sctx->clean_stream_by_id(sctx->match_info.pkt_info->sid); + stream.client_data_callback(bind(keep_fin_packet, sctx)); + stream.server_data_callback(bind(keep_fin_packet, sctx)); } + sctx->match_info.result = result; } //Input data filtering - static void on_client_data(Stream& stream) { - on_data_recv(stream, string(stream.client_payload().begin(), stream.client_payload().end()), true); + static void on_client_data(Stream& stream, stream_ctx* sctx) { + sctx->match_info.pkt_info->is_input = true; + on_data_recv(stream, sctx, string(stream.client_payload().begin(), stream.client_payload().end())); } //Server data filtering - static void on_server_data(Stream& stream) { - on_data_recv(stream, string(stream.server_payload().begin(), stream.server_payload().end()), false); + static void on_server_data(Stream& stream, stream_ctx* sctx) { + sctx->match_info.pkt_info->is_input = false; + on_data_recv(stream, sctx, string(stream.server_payload().begin(), stream.server_payload().end())); } - // A stream was terminated. The second argument is the reason why it was terminated - static void on_stream_close(Stream& stream) { + static void on_stream_close(Stream& stream, stream_ctx* sctx) { stream_id stream_id = stream_id::make_identifier(stream); + sctx->clean_stream_by_id(stream_id); } - static void on_new_stream(Stream& stream) { + static void on_new_stream(Stream& stream, stream_ctx* sctx) { stream.auto_cleanup_payloads(true); if (stream.is_partial_stream()) { - return; + //TODO take a decision about this... + stream.enable_recovery_mode(10 * 1024); } - stream.client_data_callback(bind(on_client_data, placeholders::_1)); - stream.server_data_callback(bind(on_server_data, placeholders::_1)); - stream.stream_closed_callback(bind(on_stream_close, placeholders::_1)); + stream.client_data_callback(bind(on_client_data, placeholders::_1, sctx)); + stream.server_data_callback(bind(on_server_data, placeholders::_1, sctx)); + stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, sctx)); } - - + template - static void build_verdict(T packet, uint8_t *payload, uint16_t plen, nlmsghdr *nlh_verdict, nfqnl_msg_packet_hdr *ph){ - sctx->tcp_match_util.matching_has_been_called = false; + static void build_verdict(T packet, uint8_t *payload, uint16_t plen, nlmsghdr *nlh_verdict, nfqnl_msg_packet_hdr *ph, stream_ctx* sctx, bool is_ipv6){ + Tins::TCP* tcp = packet.template find_pdu(); + if (!tcp){ + throw invalid_argument("Only TCP and UDP are supported"); + } + Tins::PDU* application_layer = tcp->inner_pdu(); + u_int16_t payload_size = 0; + if (application_layer != nullptr){ + payload_size = application_layer->size(); + } + packet_info pktinfo{ + payload: string(payload+plen - payload_size, payload+plen), + sid: stream_id::make_identifier(packet), + is_ipv6: is_ipv6, + sctx: sctx, + packet_pdu: &packet, + tcp: tcp, + }; + sctx->match_info.matching_has_been_called = false; + sctx->match_info.pkt_info = &pktinfo; sctx->follower.process_packet(packet); - if (sctx->tcp_match_util.matching_has_been_called && !sctx->tcp_match_util.result){ - Tins::PDU* data_layer = tcp->release_inner_pdu(); - if (data_layer != nullptr){ - delete data_layer; + // Do an action only is an ordered packet has been received + if (sctx->match_info.matching_has_been_called){ + bool empty_payload = pktinfo.payload.empty(); + //In this 2 cases we have to remove all data about the stream + if (!sctx->match_info.result || sctx->match_info.already_closed){ + #ifdef DEBUG + cerr << "[DEBUG] [NetfilterQueue.build_verdict] Stream matched, removing all data about it" << endl; + #endif + sctx->clean_stream_by_id(pktinfo.sid); + //If the packet has data, we have to remove it + if (!empty_payload){ + Tins::PDU* data_layer = tcp->release_inner_pdu(); + if (data_layer != nullptr){ + delete data_layer; + } + } + //For the first matched data or only for data packets, we set FIN bit + //This only for client packets, because this will trigger server to close the connection + //Packets will be filtered anyway also if client don't send packets + if ((!sctx->match_info.result || !empty_payload) && is_input){ + tcp->set_flag(Tins::TCP::FIN,1); + tcp->set_flag(Tins::TCP::ACK,1); + tcp->set_flag(Tins::TCP::SYN,0); + } + //Send the edited packet to the kernel + nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); } - tcp->set_flag(Tins::TCP::FIN,1); - tcp->set_flag(Tins::TCP::ACK,1); - tcp->set_flag(Tins::TCP::SYN,0); - nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); } nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); + } static int queue_cb(const nlmsghdr *nlh, const mnl_socket* nl, void *data_ptr) { + stream_ctx* sctx = (stream_ctx*)data_ptr; + //Extract attributes from the nlmsghdr nlattr *attr[NFQA_MAX+1] = {}; @@ -116,23 +185,25 @@ class SocketTunnelQueue: public NfQueueExecutor { struct nlmsghdr *nlh_verdict; nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id)); - // Check IP protocol version if ( (payload[0] & 0xf0) == 0x40 ){ - build_verdict(Tins::IP(payload, plen), payload, plen, nlh_verdict, ph); + build_verdict(Tins::IP(payload, plen), payload, plen, nlh_verdict, ph, sctx, false); }else{ - build_verdict(Tins::IPv6(payload, plen), payload, plen, nlh_verdict, ph); + build_verdict(Tins::IPv6(payload, plen), payload, plen, nlh_verdict, ph, sctx, true); } if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) { throw runtime_error( "mnl_socket_send" ); } - return MNL_CB_OK; } - SocketTunnelQueue(int queue) : NfQueueExecutor(queue, &queue_cb) {} + PyProxyQueue(int queue) : NfQueueExecutor(queue, &queue_cb) {} + + ~PyProxyQueue() { + sctx.clean(); + } }; -#endif // PROXY_TUNNEL_CPP \ No newline at end of file +#endif // PROXY_TUNNEL_CLASS_CPP \ No newline at end of file diff --git a/backend/binsrc/proxytun/settings.cpp b/backend/binsrc/proxytun/settings.cpp new file mode 100644 index 0000000..fc43c51 --- /dev/null +++ b/backend/binsrc/proxytun/settings.cpp @@ -0,0 +1,26 @@ +#ifndef PROXY_TUNNEL_SETTINGS_CPP +#define PROXY_TUNNEL_SETTINGS_CPP + +#include +#include +#include +#include "../utils.hpp" +#include +#include +#include + +using namespace std; + +class PyCodeConfig{ + public: + const string code; + public: + PyCodeConfig(string pycode): code(pycode){} + + ~PyCodeConfig(){} +}; + +shared_ptr config; + +#endif // PROXY_TUNNEL_SETTINGS_CPP + diff --git a/backend/binsrc/proxytun/stream_ctx.cpp b/backend/binsrc/proxytun/stream_ctx.cpp new file mode 100644 index 0000000..b2ade14 --- /dev/null +++ b/backend/binsrc/proxytun/stream_ctx.cpp @@ -0,0 +1,60 @@ + +#ifndef STREAM_CTX_CPP +#define STREAM_CTX_CPP + +#include +#include +#include + +using Tins::TCPIP::Stream; +using Tins::TCPIP::StreamFollower; +using namespace std; + +typedef Tins::TCPIP::StreamIdentifier stream_id; + +struct pyfilter_ctx { + void * pyglob; // TODO python glob??? + string pycode; +}; + +typedef map matching_map; + +struct packet_info; + +struct tcp_stream_tmp { + bool matching_has_been_called = false; + bool already_closed = false; + bool result; + packet_info *pkt_info; +}; + +struct stream_ctx { + matching_map streams_ctx; + StreamFollower follower; + tcp_stream_tmp match_info; + void clean_stream_by_id(stream_id sid){ + auto stream_search = streams_ctx.find(sid); + if (stream_search != streams_ctx.end()){ + auto stream_match = stream_search->second; + //DEALLOC PY GLOB TODO + } + } + void clean(){ + for (auto ele: streams_ctx){ + //TODO dealloc ele.second.pyglob + } + } +}; + +struct packet_info { + string payload; + stream_id sid; + bool is_input; + bool is_ipv6; + stream_ctx* sctx; + Tins::PDU* packet_pdu; + Tins::TCP* tcp; +}; + + +#endif // STREAM_CTX_CPP \ No newline at end of file diff --git a/backend/binsrc/regex/regexfilter.cpp b/backend/binsrc/regex/regexfilter.cpp index bd86817..2690fa3 100644 --- a/backend/binsrc/regex/regexfilter.cpp +++ b/backend/binsrc/regex/regexfilter.cpp @@ -124,17 +124,22 @@ class RegexQueue: public NfQueueExecutor { } if (match_res.has_matched){ auto rules_vector = info.is_input ? conf->input_ruleset.regexes : conf->output_ruleset.regexes; - stringstream msg; - msg << "BLOCKED " << rules_vector[match_res.matched] << "\n"; - osyncstream(cout) << msg.str() << flush; + osyncstream(cout) << "BLOCKED " << rules_vector[match_res.matched] << endl; return false; } return true; } + + //If the stream has already been matched, drop all data, and try to close the connection + static void keep_fin_packet(stream_ctx* sctx){ + sctx->match_info.matching_has_been_called = true; + sctx->match_info.already_closed = true; + } static void on_data_recv(Stream& stream, stream_ctx* sctx, string data) { - sctx->tcp_match_util.matching_has_been_called = true; - bool result = filter_action(*sctx->tcp_match_util.pkt_info); + sctx->match_info.matching_has_been_called = true; + sctx->match_info.already_closed = false; + bool result = filter_action(*sctx->match_info.pkt_info); #ifdef DEBUG cerr << "[DEBUG] [NetfilterQueue.on_data_recv] result: " << result << endl; #endif @@ -142,11 +147,11 @@ class RegexQueue: public NfQueueExecutor { #ifdef DEBUG cerr << "[DEBUG] [NetfilterQueue.on_data_recv] Stream matched, removing all data about it" << endl; #endif - sctx->clean_stream_by_id(sctx->tcp_match_util.pkt_info->sid); - stream.ignore_client_data(); - stream.ignore_server_data(); + sctx->clean_stream_by_id(sctx->match_info.pkt_info->sid); + stream.client_data_callback(bind(keep_fin_packet, sctx)); + stream.server_data_callback(bind(keep_fin_packet, sctx)); } - sctx->tcp_match_util.result = result; + sctx->match_info.result = result; } //Input data filtering @@ -159,7 +164,6 @@ class RegexQueue: public NfQueueExecutor { on_data_recv(stream, sctx, string(stream.server_payload().begin(), stream.server_payload().end())); } - // A stream was terminated. The second argument is the reason why it was terminated static void on_stream_close(Stream& stream, stream_ctx* sctx) { stream_id stream_id = stream_id::make_identifier(stream); @@ -176,18 +180,17 @@ class RegexQueue: public NfQueueExecutor { stream.auto_cleanup_payloads(true); if (stream.is_partial_stream()) { #ifdef DEBUG - cerr << "[DEBUG] [NetfilterQueue.on_new_stream] Partial stream detected, skipping" << endl; + cerr << "[DEBUG] [NetfilterQueue.on_new_stream] Partial stream detected" << endl; #endif - return; + stream.enable_recovery_mode(10 * 1024); } stream.client_data_callback(bind(on_client_data, placeholders::_1, sctx)); stream.server_data_callback(bind(on_server_data, placeholders::_1, sctx)); stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, sctx)); } - - + template - static void build_verdict(T packet, uint8_t *payload, uint16_t plen, nlmsghdr *nlh_verdict, nfqnl_msg_packet_hdr *ph, stream_ctx* sctx, bool is_input){ + static void build_verdict(T packet, uint8_t *payload, uint16_t plen, nlmsghdr *nlh_verdict, nfqnl_msg_packet_hdr *ph, stream_ctx* sctx, bool is_input, bool is_ipv6){ Tins::TCP* tcp = packet.template find_pdu(); if (tcp){ @@ -197,15 +200,17 @@ class RegexQueue: public NfQueueExecutor { payload_size = application_layer->size(); } packet_info pktinfo{ - packet: string(payload, payload+plen), payload: string(payload+plen - payload_size, payload+plen), sid: stream_id::make_identifier(packet), is_input: is_input, is_tcp: true, + is_ipv6: is_ipv6, sctx: sctx, + packet_pdu: &packet, + layer4_pdu: tcp, }; - sctx->tcp_match_util.matching_has_been_called = false; - sctx->tcp_match_util.pkt_info = &pktinfo; + sctx->match_info.matching_has_been_called = false; + sctx->match_info.pkt_info = &pktinfo; #ifdef DEBUG cerr << "[DEBUG] [NetfilterQueue.build_verdict] TCP Packet received " << packet.src_addr() << ":" << tcp->sport() << " -> " << packet.dst_addr() << ":" << tcp->dport() << " thr: " << this_thread::get_id() << ", sending to libtins StreamFollower" << endl; #endif @@ -217,15 +222,33 @@ class RegexQueue: public NfQueueExecutor { cerr << "[DEBUG] [NetfilterQueue.build_verdict] StreamFollower has NOT called matching functions" << endl; } #endif - if (sctx->tcp_match_util.matching_has_been_called && !sctx->tcp_match_util.result){ - Tins::PDU* data_layer = tcp->release_inner_pdu(); - if (data_layer != nullptr){ - delete data_layer; + // Do an action only is an ordered packet has been received + if (sctx->match_info.matching_has_been_called){ + bool empty_payload = pktinfo.payload.empty(); + //In this 2 cases we have to remove all data about the stream + if (!sctx->match_info.result || sctx->match_info.already_closed){ + #ifdef DEBUG + cerr << "[DEBUG] [NetfilterQueue.build_verdict] Stream matched, removing all data about it" << endl; + #endif + sctx->clean_stream_by_id(pktinfo.sid); + //If the packet has data, we have to remove it + if (!empty_payload){ + Tins::PDU* data_layer = tcp->release_inner_pdu(); + if (data_layer != nullptr){ + delete data_layer; + } + } + //For the first matched data or only for data packets, we set FIN bit + //This only for client packets, because this will trigger server to close the connection + //Packets will be filtered anyway also if client don't send packets + if ((!sctx->match_info.result || !empty_payload) && is_input){ + tcp->set_flag(Tins::TCP::FIN,1); + tcp->set_flag(Tins::TCP::ACK,1); + tcp->set_flag(Tins::TCP::SYN,0); + } + //Send the edited packet to the kernel + nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); } - tcp->set_flag(Tins::TCP::FIN,1); - tcp->set_flag(Tins::TCP::ACK,1); - tcp->set_flag(Tins::TCP::SYN,0); - nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); } nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); }else{ @@ -242,12 +265,14 @@ class RegexQueue: public NfQueueExecutor { nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); } packet_info pktinfo{ - packet: string(payload, payload+plen), payload: string(payload+plen - payload_size, payload+plen), sid: stream_id::make_identifier(packet), is_input: is_input, is_tcp: false, + is_ipv6: is_ipv6, sctx: sctx, + packet_pdu: &packet, + layer4_pdu: udp, }; if (filter_action(pktinfo)){ nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); @@ -298,9 +323,9 @@ class RegexQueue: public NfQueueExecutor { // Check IP protocol version if ( (payload[0] & 0xf0) == 0x40 ){ - build_verdict(Tins::IP(payload, plen), payload, plen, nlh_verdict, ph, sctx, is_input); + build_verdict(Tins::IP(payload, plen), payload, plen, nlh_verdict, ph, sctx, is_input, false); }else{ - build_verdict(Tins::IPv6(payload, plen), payload, plen, nlh_verdict, ph, sctx, is_input); + build_verdict(Tins::IPv6(payload, plen), payload, plen, nlh_verdict, ph, sctx, is_input, true); } if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) { diff --git a/backend/binsrc/regex/stream_ctx.cpp b/backend/binsrc/regex/stream_ctx.cpp index 36df1fb..8b12e45 100644 --- a/backend/binsrc/regex/stream_ctx.cpp +++ b/backend/binsrc/regex/stream_ctx.cpp @@ -51,6 +51,7 @@ struct packet_info; struct tcp_stream_tmp { bool matching_has_been_called = false; + bool already_closed = false; bool result; packet_info *pkt_info; }; @@ -62,7 +63,7 @@ struct stream_ctx { hs_scratch_t* out_scratch = nullptr; u_int16_t latest_config_ver = 0; StreamFollower follower; - tcp_stream_tmp tcp_match_util; + tcp_stream_tmp match_info; void clean_scratches(){ if (out_scratch != nullptr){ @@ -131,12 +132,14 @@ struct stream_ctx { }; struct packet_info { - string packet; string payload; stream_id sid; bool is_input; bool is_tcp; + bool is_ipv6; stream_ctx* sctx; + Tins::PDU* packet_pdu; + Tins::PDU* layer4_pdu; }; diff --git a/backend/modules/nfproxy/firegex.py b/backend/modules/nfproxy/firegex.py index 20651a5..0ccf94f 100644 --- a/backend/modules/nfproxy/firegex.py +++ b/backend/modules/nfproxy/firegex.py @@ -3,8 +3,9 @@ from utils import run_func from modules.nfproxy.models import Service, PyFilter import os import asyncio -import socket -import shutil +from utils import DEBUG +import traceback +from fastapi import HTTPException nft = FiregexTables() @@ -13,22 +14,16 @@ class FiregexInterceptor: def __init__(self): self.srv:Service self._stats_updater_cb:callable + self.filter_map_lock:asyncio.Lock + self.filter_map: dict[str, PyFilter] + self.pyfilters: set[PyFilter] + self.update_config_lock:asyncio.Lock self.process:asyncio.subprocess.Process - self.base_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "socks", self.srv.id - ) - self.n_threads = int(os.getenv("NTHREADS","1")) - - self.connection_socket = os.path.join(self.base_dir, "connection.sock") - self.vedict_sockets = [os.path.join(self.base_dir, f"vedict{i}.sock") for i in range(self.n_threads)] - self.socks = [] - - def add_sock(self, path): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) - sock.bind(path) - self.socks.append(sock) - return sock + self.update_task: asyncio.Task + self.ack_arrived = False + self.ack_status = None + self.ack_fail_what = "" + self.ack_lock = asyncio.Lock() async def _call_stats_updater_callback(self, filter: PyFilter): if self._stats_updater_cb: @@ -37,25 +32,24 @@ class FiregexInterceptor: @classmethod async def start(cls, srv: Service, stats_updater_cb:callable): self = cls() - self.srv = srv self._stats_updater_cb = stats_updater_cb - os.makedirs(self.base_dir, exist_ok=True) - self.add_sock(self.connection_socket) - for path in self.vedict_sockets: - self.add_sock(path) + self.srv = srv + self.filter_map_lock = asyncio.Lock() + self.update_config_lock = asyncio.Lock() queue_range = await self._start_binary() - # TODO starts python workers + self.update_task = asyncio.create_task(self.update_stats()) nft.add(self.srv, queue_range) + if not self.ack_lock.locked(): + await self.ack_lock.acquire() return self async def _start_binary(self): - proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppproxy") + proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cpproxy") self.process = await asyncio.create_subprocess_exec( proxy_binary_path, stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, + env={"NTHREADS": os.getenv("NTHREADS","1")}, ) - self.process.stdin.write(self.base_dir.encode().hex().encode()+b" 3\n") - await self.process.stdin.drain() line_fut = self.process.stdout.readuntil() try: line_fut = await asyncio.wait_for(line_fut, timeout=3) @@ -70,14 +64,58 @@ class FiregexInterceptor: self.process.kill() raise Exception("Invalid binary output") + async def update_stats(self): + try: + while True: + line = (await self.process.stdout.readuntil()).decode() + if DEBUG: + print(line) + if line.startswith("BLOCKED "): + filter_id = line.split()[1] + async with self.filter_map_lock: + if filter_id in self.filter_map: + self.filter_map[filter_id].blocked_packets+=1 + await self.filter_map[filter_id].update() + if line.startswith("EDITED "): + filter_id = line.split()[1] + async with self.filter_map_lock: + if filter_id in self.filter_map: + self.filter_map[filter_id].edited_packets+=1 + await self.filter_map[filter_id].update() + if line.startswith("ACK "): + self.ack_arrived = True + self.ack_status = line.split()[1].upper() == "OK" + if not self.ack_status: + self.ack_fail_what = " ".join(line.split()[2:]) + self.ack_lock.release() + except asyncio.CancelledError: + pass + except asyncio.IncompleteReadError: + pass + except Exception: + traceback.print_exc() + async def stop(self): + self.update_task.cancel() if self.process and self.process.returncode is None: self.process.kill() - for sock in self.socks: - sock.close() - shutil.rmtree(self.base_dir) + + async def _update_config(self, filters_codes): + async with self.update_config_lock: + # TODO write compiled code correctly + # self.process.stdin.write((" ".join(filters_codes)+"\n").encode()) + await self.process.stdin.drain() + try: + async with asyncio.timeout(3): + await self.ack_lock.acquire() + except TimeoutError: + pass + if not self.ack_arrived or not self.ack_status: + raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}") async def reload(self, filters:list[PyFilter]): - # filters are the functions to use in the workers (other functions are disabled or not flagged as filters) - # TODO update filters in python workers (prob for new filters added) (reading from file????) - pass \ No newline at end of file + async with self.filter_map_lock: + self.filter_map = self.compile_filters(filters) + # TODO COMPILE CODE + #await self._update_config(filters_codes) TODO pass the compiled code + diff --git a/backend/routers/nfproxy.py b/backend/routers/nfproxy.py index 8580404..c80aa72 100644 --- a/backend/routers/nfproxy.py +++ b/backend/routers/nfproxy.py @@ -8,7 +8,6 @@ from utils.sqlite import SQLite from utils import ip_parse, refactor_name, socketio_emit, PortType from utils.models import ResetRequest, StatusMessageModel -# TODO copied file, review class ServiceModel(BaseModel): service_id: str status: str diff --git a/proxy-client/firegex/__main__.py b/proxy-client/firegex/__main__.py index 56d0693..adcf48a 100644 --- a/proxy-client/firegex/__main__.py +++ b/proxy-client/firegex/__main__.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # TODO implement cli start function -from firegexproxy.cli import run +from firegex.cli import run if __name__ == "__main__": run()