diff --git a/.gitignore b/.gitignore index d74d9af..88bc9d8 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ /frontend/coverage /fgex-lib/firegex.egg-info /fgex-lib/dist +/fgex-lib/build /fgex-lib/fgex-pip/fgex.egg-info /fgex-lib/fgex-pip/dist /backend/db/ diff --git a/Dockerfile b/Dockerfile index 4599907..feb8659 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ RUN bun run build FROM --platform=$TARGETARCH registry.fedoraproject.org/fedora:latest RUN dnf -y update && dnf install -y python3.13-devel @development-tools gcc-c++ \ libnetfilter_queue-devel libnfnetlink-devel libmnl-devel libcap-ng-utils nftables \ - vectorscan-devel libtins-devel python3-nftables libpcap-devel boost-devel uv + vectorscan-devel libtins-devel python3-nftables libpcap-devel boost-devel uv redis RUN mkdir -p /execute/modules WORKDIR /execute diff --git a/backend/app.py b/backend/app.py index b6646f2..f12224c 100644 --- a/backend/app.py +++ b/backend/app.py @@ -9,12 +9,13 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import jwt from passlib.context import CryptContext from utils.sqlite import SQLite -from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager +from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager, NORELOAD from utils.loader import frontend_deploy, load_routers from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel from contextlib import asynccontextmanager from fastapi.middleware.cors import CORSMiddleware import socketio +from socketio.exceptions import ConnectionRefusedError # DB init db = SQLite('db/firegex.db') @@ -52,7 +53,6 @@ if DEBUG: allow_headers=["*"], ) - utils.socketio = socketio.AsyncServer( async_mode="asgi", cors_allowed_origins=[], @@ -69,9 +69,6 @@ def set_psw(psw: str): hash_psw = crypto.hash(psw) db.put("password",hash_psw) -@utils.socketio.on("update") -async def updater(): pass - def create_access_token(data: dict): to_encode = data.copy() encoded_jwt = jwt.encode(to_encode, JWT_SECRET(), algorithm=JWT_ALGORITHM) @@ -90,6 +87,28 @@ async def check_login(token: str = Depends(oauth2_scheme)): return False return logged_in +@utils.socketio.on("connect") +async def sio_connect(sid, environ, auth): + if not auth or not await check_login(auth.get("token")): + raise ConnectionRefusedError("Unauthorized") + utils.sid_list.add(sid) + +@utils.socketio.on("disconnect") +async def sio_disconnect(sid): + try: + utils.sid_list.remove(sid) + except KeyError: + pass + +async def disconnect_all(): + while True: + if len(utils.sid_list) == 0: + break + await utils.socketio.disconnect(utils.sid_list.pop()) + +@utils.socketio.on("update") +async def updater(): pass + async def is_loggined(auth: bool = Depends(check_login)): if not auth: raise HTTPException( @@ -122,6 +141,7 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()): return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"} raise HTTPException(406,"Wrong password!") + @app.post('/api/set-password', response_model=ChangePasswordModel) async def set_password(form: PasswordForm): """Set the password of firegex""" @@ -143,6 +163,7 @@ async def change_password(form: PasswordChangeForm): return {"status":"Cannot insert an empty password!"} if form.expire: db.put("secret", secrets.token_hex(32)) + await disconnect_all() set_psw(form.password) await refresh_frontend() @@ -200,7 +221,7 @@ if __name__ == '__main__': "app:app", host="::" if DEBUG else None, port=FIREGEX_PORT, - reload=DEBUG, + reload=DEBUG and not NORELOAD, access_log=True, workers=1, # Firewall module can't be replicated in multiple workers # Later the firewall module will be moved to a separate process diff --git a/backend/binsrc/classes/nfqueue.cpp b/backend/binsrc/classes/nfqueue.cpp index 582e683..7bfe9c4 100644 --- a/backend/binsrc/classes/nfqueue.cpp +++ b/backend/binsrc/classes/nfqueue.cpp @@ -17,6 +17,7 @@ enum class FilterAction{ DROP, ACCEPT, MANGLE, NOACTION }; enum class L4Proto { TCP, UDP, RAW }; typedef Tins::TCPIP::StreamIdentifier stream_id; +//TODO DUBBIO: I PACCHETTI INVIATI A PYTHON SONO GIA' FIXATI? template class PktRequest { @@ -25,6 +26,9 @@ class PktRequest { mnl_socket* nl = nullptr; uint16_t res_id; uint32_t packet_id; + size_t _original_size; + size_t _data_original_size; + bool need_tcp_fixing = false; public: bool is_ipv6; Tins::IP* ipv4 = nullptr; @@ -39,17 +43,27 @@ class PktRequest { size_t data_size; stream_id sid; + int64_t* tcp_in_offset = nullptr; + int64_t* tcp_out_offset = nullptr; + T* ctx; private: - inline void fetch_data_size(Tins::PDU* pdu){ + static size_t inner_data_size(Tins::PDU* pdu){ + if (pdu == nullptr){ + return 0; + } auto inner = pdu->inner_pdu(); if (inner == nullptr){ - data_size = 0; - }else{ - data_size = inner->size(); + return 0; } + return inner->size(); + } + + inline void fetch_data_size(Tins::PDU* pdu){ + data_size = inner_data_size(pdu); + _data_original_size = data_size; } L4Proto fill_l4_info(){ @@ -86,23 +100,92 @@ class PktRequest { } } + bool need_tcp_fix(){ + return (tcp_in_offset != nullptr && *tcp_in_offset != 0) || (tcp_out_offset != nullptr && *tcp_out_offset != 0); + } + + Tins::PDU::serialization_type reserialize_raw_data(const uint8_t* data, const size_t& data_size){ + if (is_ipv6){ + Tins::IPv6 ipv6_new = Tins::IPv6(data, data_size); + if (tcp){ + Tins::TCP* tcp_new = ipv6_new.find_pdu(); + } + return ipv6_new.serialize(); + }else{ + Tins::IP ipv4_new = Tins::IP(data, data_size); + if (tcp){ + Tins::TCP* tcp_new = ipv4_new.find_pdu(); + } + return ipv4_new.serialize(); + } + } + + void _fix_ack_seq_tcp(Tins::TCP* this_tcp){ + need_tcp_fixing = need_tcp_fix(); + #ifdef DEBUG + if (need_tcp_fixing){ + cerr << "[DEBUG] Fixing ack_seq with offsets " << *tcp_in_offset << " " << *tcp_out_offset << endl; + } + #endif + if(this_tcp == nullptr){ + return; + } + if (is_input){ + if (tcp_in_offset != nullptr){ + this_tcp->seq(this_tcp->seq() + *tcp_in_offset); + } + if (tcp_out_offset != nullptr){ + this_tcp->ack_seq(this_tcp->ack_seq() - *tcp_out_offset); + } + }else{ + if (tcp_in_offset != nullptr){ + this_tcp->ack_seq(this_tcp->ack_seq() - *tcp_in_offset); + } + if (tcp_out_offset != nullptr){ + this_tcp->seq(this_tcp->seq() + *tcp_out_offset); + } + } + #ifdef DEBUG + if (need_tcp_fixing){ + size_t new_size = inner_data_size(this_tcp); + cerr << "[DEBUG] FIXED PKT " << (is_input?"-> IN ":"<- OUT") << " [SEQ: " << this_tcp->seq() << "] \t[ACK: " << this_tcp->ack_seq() << "] \t[SIZE: " << new_size << "]" << endl; + } + #endif + } + + public: PktRequest(const char* payload, size_t plen, T* ctx, mnl_socket* nl, nfgenmsg *nfg, nfqnl_msg_packet_hdr *ph, bool is_input): ctx(ctx), nl(nl), res_id(nfg->res_id), packet_id(ph->packet_id), is_input(is_input), packet(string(payload, plen)), - is_ipv6((payload[0] & 0xf0) == 0x60){ - if (is_ipv6){ - ipv6 = new Tins::IPv6((uint8_t*)packet.c_str(), plen); - sid = stream_id::make_identifier(*ipv6); - }else{ - ipv4 = new Tins::IP((uint8_t*)packet.c_str(), plen); - sid = stream_id::make_identifier(*ipv4); - } - l4_proto = fill_l4_info(); - data = packet.data()+(plen-data_size); + action(FilterAction::NOACTION), + is_ipv6((payload[0] & 0xf0) == 0x60) + { + if (is_ipv6){ + ipv6 = new Tins::IPv6((uint8_t*)packet.c_str(), plen); + sid = stream_id::make_identifier(*ipv6); + _original_size = ipv6->size(); + }else{ + ipv4 = new Tins::IP((uint8_t*)packet.data(), plen); + sid = stream_id::make_identifier(*ipv4); + _original_size = ipv4->size(); } + l4_proto = fill_l4_info(); + data = packet.data()+(plen-data_size); + #ifdef DEBUG + if (tcp){ + cerr << "[DEBUG] NEW_PACKET " << (is_input?"-> IN ":"<- OUT") << " [SEQ: " << tcp->seq() << "] \t[ACK: " << tcp->ack_seq() << "] \t[SIZE: " << data_size << "]" << endl; + } + #endif + } + + void fix_tcp_ack(){ + if (tcp){ + _fix_ack_seq_tcp(tcp); + } + } void drop(){ if (action == FilterAction::NOACTION){ @@ -113,6 +196,14 @@ class PktRequest { } } + size_t data_original_size(){ + return _data_original_size; + } + + size_t original_size(){ + return _original_size; + } + void accept(){ if (action == FilterAction::NOACTION){ action = FilterAction::ACCEPT; @@ -131,7 +222,26 @@ class PktRequest { } } - void mangle_custom_pkt(const uint8_t* pkt, size_t pkt_size){ + void reject(){ + if (tcp){ + //If the packet has data, we have to remove it + delete tcp->release_inner_pdu(); + //For the first matched data or only for data packets, we set FIN bit + //This only for client packets, because this will trigger server to close the connection + //Packets will be filtered anyway also if client don't send packets + if (_data_original_size != 0 && is_input){ + tcp->set_flag(Tins::TCP::FIN,1); + tcp->set_flag(Tins::TCP::ACK,1); + tcp->set_flag(Tins::TCP::SYN,0); + } + //Send the edited packet to the kernel + mangle(); + }else{ + drop(); + } + } + + void mangle_custom_pkt(uint8_t* pkt, const size_t& pkt_size){ if (action == FilterAction::NOACTION){ action = FilterAction::MANGLE; perfrom_action(pkt, pkt_size); @@ -149,26 +259,58 @@ class PktRequest { delete ipv6; } + inline Tins::PDU::serialization_type serialize(){ + if (is_ipv6){ + return ipv6->serialize(); + }else{ + return ipv4->serialize(); + } + } + private: - void perfrom_action(const uint8_t* custom_data = nullptr, size_t custom_data_size = 0){ + void perfrom_action(uint8_t* custom_data = nullptr, size_t custom_data_size = 0){ char buf[MNL_SOCKET_BUFFER_SIZE]; struct nlmsghdr *nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(res_id)); switch (action) { case FilterAction::ACCEPT: + if (need_tcp_fixing){ + Tins::PDU::serialization_type data = serialize(); + nfq_nlmsg_verdict_put_pkt(nlh_verdict, data.data(), data.size()); + } nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT ); break; case FilterAction::DROP: nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP ); break; case FilterAction::MANGLE:{ - if (custom_data != nullptr){ - nfq_nlmsg_verdict_put_pkt(nlh_verdict, custom_data, custom_data_size); - }else if (is_ipv6){ - nfq_nlmsg_verdict_put_pkt(nlh_verdict, ipv6->serialize().data(), ipv6->size()); + //If not custom data, use the data in the packets + Tins::PDU::serialization_type data; + if (custom_data == nullptr){ + data = serialize(); }else{ - nfq_nlmsg_verdict_put_pkt(nlh_verdict, ipv4->serialize().data(), ipv4->size()); + try{ + data = reserialize_raw_data(custom_data, custom_data_size); + }catch(...){ + nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP ); + action = FilterAction::DROP; + break; + } } + #ifdef DEBUG + size_t new_size = _data_original_size+((int64_t)custom_data_size) - ((int64_t)_original_size); + cerr << "[DEBUG] MANGLEDPKT " << (is_input?"-> IN ":"<- OUT") << " [SIZE: " << new_size << "]" << endl; + #endif + if (tcp && custom_data_size != _original_size){ + int64_t delta = ((int64_t)custom_data_size) - ((int64_t)_original_size); + + if (is_input && tcp_in_offset != nullptr){ + *tcp_in_offset += delta; + }else if (!is_input && tcp_out_offset != nullptr){ + *tcp_out_offset += delta; + } + } + nfq_nlmsg_verdict_put_pkt(nlh_verdict, data.data(), data.size()); nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT ); break; } diff --git a/backend/binsrc/nfproxy.cpp b/backend/binsrc/nfproxy.cpp index 96c12d1..1d44efc 100644 --- a/backend/binsrc/nfproxy.cpp +++ b/backend/binsrc/nfproxy.cpp @@ -4,11 +4,11 @@ #include "pyproxy/settings.cpp" #include "pyproxy/pyproxy.cpp" #include "classes/netfilter.cpp" -#include #include #include #include #include +#include "utils.cpp" using namespace std; using namespace Firegex::PyProxy; @@ -33,13 +33,13 @@ def invalid_curl_agent(http): The code is now edited adding an intestation and a end statement: ```python -global __firegex_pyfilter_enabled, __firegex_proto + __firegex_pyfilter_enabled = ["invalid_curl_agent", "func3"] # This list is dynamically generated by firegex backend __firegex_proto = "http" import firegex.nfproxy.internals - -firegex.nfproxy.internals.compile() # This function can save other global variables, to use by the packet handler and is used generally to check and optimize the code +firegex.nfproxy.internals.compile(globals(), locals()) # This function can save other global variables, to use by the packet handler and is used generally to check and optimize the code ```` +(First lines are the same to keep line of code consistent on exceptions messages) This code will be executed only once, and is needed to build the global and local context to use The globals and locals generated here are copied for each connection, and are used to handle the packets @@ -82,60 +82,53 @@ firegex lib will give you all the needed possibilities to do this is many ways Final note: is not raccomanded to use variables that starts with __firegex_ in your code, because they may break the nfproxy */ -ssize_t read_check(int __fd, void *__buf, size_t __nbytes){ - ssize_t bytes = read(__fd, __buf, __nbytes); - if (bytes == 0){ - cerr << "[fatal] [updater] read() returned EOF" << endl; - throw invalid_argument("read() returned EOF"); - } - if (bytes < 0){ - cerr << "[fatal] [updater] read() returned an error" << bytes << endl; - throw invalid_argument("read() returned an error"); - } - return bytes; -} + void config_updater (){ while (true){ + PyThreadState* state = PyEval_SaveThread(); // Release GIL while doing IO operation uint32_t code_size; - read_check(STDIN_FILENO, &code_size, 4); - //Python will send number always in little endian - code_size = le32toh(code_size); - string code; - code.resize(code_size); - read_check(STDIN_FILENO, code.data(), code_size); + memcpy(&code_size, control_socket.recv(4).c_str(), 4); + code_size = be32toh(code_size); + string code = control_socket.recv(code_size); + #ifdef DEBUG + cerr << "[DEBUG] [updater] Received code: " << code << endl; + #endif cerr << "[info] [updater] Updating configuration" << endl; + PyEval_AcquireThread(state); //Restore GIL before executing python code try{ config.reset(new PyCodeConfig(code)); cerr << "[info] [updater] Config update done" << endl; - osyncstream(cout) << "ACK OK" << endl; + control_socket << "ACK OK" << endl; }catch(const std::exception& e){ cerr << "[error] [updater] Failed to build new configuration!" << endl; - osyncstream(cout) << "ACK FAIL " << e.what() << endl; + control_socket << "ACK FAIL " << e.what() << endl; } } } -int main(int argc, char *argv[]){ +int main(int argc, char *argv[]) { + // Connect to the python backend using the unix socket + init_control_socket(); + + // Initialize the python interpreter Py_Initialize(); atexit(Py_Finalize); init_handle_packet_code(); //Compile the static code used to handle packets - if (freopen(nullptr, "rb", stdin) == nullptr){ // We need to read from stdin binary data - cerr << "[fatal] [main] Failed to reopen stdin in binary mode" << endl; - return 1; - } int n_of_threads = 1; char * n_threads_str = getenv("NTHREADS"); if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str); if(n_of_threads <= 0) n_of_threads = 1; config.reset(new PyCodeConfig()); + MultiThreadQueue queue(n_of_threads); - osyncstream(cout) << "QUEUE " << queue.queue_num() << endl; + control_socket << "QUEUE " << queue.queue_num() << endl; + cerr << "[info] [main] Queue: " << queue.queue_num() << " threads assigned: " << n_of_threads << endl; thread qthr([&](){ diff --git a/backend/binsrc/pyproxy/pyproxy.cpp b/backend/binsrc/pyproxy/pyproxy.cpp index 1f2c51c..41f4540 100644 --- a/backend/binsrc/pyproxy/pyproxy.cpp +++ b/backend/binsrc/pyproxy/pyproxy.cpp @@ -33,7 +33,8 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue { public: stream_ctx sctx; StreamFollower follower; - PyGILState_STATE gstate; + PyThreadState * gtstate = nullptr; + PyInterpreterConfig py_thread_config = { .use_main_obmalloc = 0, .allow_fork = 0, @@ -44,24 +45,23 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue { .gil = PyInterpreterConfig_OWN_GIL, }; PyThreadState *tstate = NULL; - PyStatus pystatus; - - struct { - bool matching_has_been_called = false; - bool already_closed = false; - bool rejected = true; - NfQueue::PktRequest* pkt; - } match_ctx; + NfQueue::PktRequest* pkt; + tcp_ack_seq_ctx* current_tcp_ack = nullptr; void before_loop() override { - // Create thred structure for python - gstate = PyGILState_Ensure(); + PyStatus pystatus; // Create a new interpreter for the thread + gtstate = PyThreadState_New(PyInterpreterState_Main()); + PyEval_AcquireThread(gtstate); pystatus = Py_NewInterpreterFromConfig(&tstate, &py_thread_config); - if (PyStatus_Exception(pystatus)) { - Py_ExitStatusException(pystatus); + if(tstate == nullptr){ cerr << "[fatal] [main] Failed to create new interpreter" << endl; - exit(EXIT_FAILURE); + throw invalid_argument("Failed to create new interpreter (null tstate)"); + } + if (PyStatus_Exception(pystatus)) { + cerr << "[fatal] [main] Failed to create new interpreter" << endl; + Py_ExitStatusException(pystatus); + throw invalid_argument("Failed to create new interpreter (pystatus exc)"); } // Setting callbacks for the stream follower follower.new_stream_callback(bind(on_new_stream, placeholders::_1, this)); @@ -69,21 +69,24 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue { } inline void print_blocked_reason(const string& func_name){ - osyncstream(cout) << "BLOCKED " << func_name << endl; + control_socket << "BLOCKED " << func_name << endl; } inline void print_mangle_reason(const string& func_name){ - osyncstream(cout) << "MANGLED " << func_name << endl; + control_socket << "MANGLED " << func_name << endl; } inline void print_exception_reason(){ - osyncstream(cout) << "EXCEPTION" << endl; + control_socket << "EXCEPTION" << endl; } //If the stream has already been matched, drop all data, and try to close the connection - static void keep_fin_packet(PyProxyQueue* proxy_info){ - proxy_info->match_ctx.matching_has_been_called = true; - proxy_info->match_ctx.already_closed = true; + static void keep_fin_packet(PyProxyQueue* pyq){ + pyq->pkt->reject();// This is needed because the callback has to take the updated pkt pointer! + } + + static void keep_dropped(PyProxyQueue* pyq){ + pyq->pkt->drop();// This is needed because the callback has to take the updated pkt pointer! } void filter_action(NfQueue::PktRequest* pkt, Stream& stream){ @@ -92,36 +95,45 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue { if (stream_search == sctx.streams_ctx.end()){ shared_ptr conf = config; //If config is not set, ignore the stream - if (conf->glob == nullptr || conf->local == nullptr){ + PyObject* compiled_code = conf->compiled_code(); + if (compiled_code == nullptr){ stream.client_data_callback(nullptr); stream.server_data_callback(nullptr); return pkt->accept(); } - stream_match = new pyfilter_ctx(conf->glob, conf->local); + stream_match = new pyfilter_ctx(compiled_code); + Py_DECREF(compiled_code); sctx.streams_ctx.insert_or_assign(pkt->sid, stream_match); }else{ stream_match = stream_search->second; - } + } + auto result = stream_match->handle_packet(pkt); switch(result.action){ case PyFilterResponse::ACCEPT: - pkt->accept(); + return pkt->accept(); case PyFilterResponse::DROP: print_blocked_reason(*result.filter_match_by); sctx.clean_stream_by_id(pkt->sid); - stream.client_data_callback(nullptr); - stream.server_data_callback(nullptr); - break; + stream.client_data_callback(bind(keep_dropped, this)); + stream.server_data_callback(bind(keep_dropped, this)); + return pkt->drop(); case PyFilterResponse::REJECT: + print_blocked_reason(*result.filter_match_by); sctx.clean_stream_by_id(pkt->sid); stream.client_data_callback(bind(keep_fin_packet, this)); stream.server_data_callback(bind(keep_fin_packet, this)); - pkt->ctx->match_ctx.rejected = true; //Handler will take care of the rest - break; + return pkt->reject(); case PyFilterResponse::MANGLE: - print_mangle_reason(*result.filter_match_by); - pkt->mangle_custom_pkt((uint8_t*)result.mangled_packet->c_str(), result.mangled_packet->size()); - break; + pkt->mangle_custom_pkt((uint8_t*)result.mangled_packet->data(), result.mangled_packet->size()); + if (pkt->get_action() == NfQueue::FilterAction::DROP){ + cerr << "[error] [filter_action] Failed to mangle: the packet sent is not serializzable... the packet was dropped" << endl; + print_blocked_reason(*result.filter_match_by); + print_exception_reason(); + }else{ + print_mangle_reason(*result.filter_match_by); + } + return; case PyFilterResponse::EXCEPTION: case PyFilterResponse::INVALID: print_exception_reason(); @@ -129,16 +141,15 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue { //Free the packet data stream.client_data_callback(nullptr); stream.server_data_callback(nullptr); - pkt->accept(); - break; + return pkt->accept(); } } static void on_data_recv(Stream& stream, PyProxyQueue* proxy_info, string data) { - proxy_info->match_ctx.matching_has_been_called = true; - proxy_info->match_ctx.already_closed = false; - proxy_info->filter_action(proxy_info->match_ctx.pkt, stream); + proxy_info->pkt->data = data.data(); + proxy_info->pkt->data_size = data.size(); + proxy_info->filter_action(proxy_info->pkt, stream); } //Input data filtering @@ -152,77 +163,77 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue { } // A stream was terminated. The second argument is the reason why it was terminated - static void on_stream_close(Stream& stream, PyProxyQueue* proxy_info) { + static void on_stream_close(Stream& stream, PyProxyQueue* pyq) { stream_id stream_id = stream_id::make_identifier(stream); - proxy_info->sctx.clean_stream_by_id(stream_id); + pyq->sctx.clean_stream_by_id(stream_id); + pyq->sctx.clean_tcp_ack_by_id(stream_id); } - static void on_new_stream(Stream& stream, PyProxyQueue* proxy_info) { + static void on_new_stream(Stream& stream, PyProxyQueue* pyq) { stream.auto_cleanup_payloads(true); if (stream.is_partial_stream()) { stream.enable_recovery_mode(10 * 1024); } - stream.client_data_callback(bind(on_client_data, placeholders::_1, proxy_info)); - stream.server_data_callback(bind(on_server_data, placeholders::_1, proxy_info)); - stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, proxy_info)); + + if (pyq->current_tcp_ack != nullptr){ + pyq->current_tcp_ack->reset(); + }else{ + pyq->current_tcp_ack = new tcp_ack_seq_ctx(); + pyq->sctx.tcp_ack_ctx.insert_or_assign(pyq->pkt->sid, pyq->current_tcp_ack); + pyq->pkt->tcp_in_offset = &pyq->current_tcp_ack->in_tcp_offset; + pyq->pkt->tcp_out_offset = &pyq->current_tcp_ack->out_tcp_offset; + } + + //Should not happen, but with this we can be sure about this + auto tcp_ack_search = pyq->sctx.tcp_ack_ctx.find(pyq->pkt->sid); + if (tcp_ack_search != pyq->sctx.tcp_ack_ctx.end()){ + tcp_ack_search->second->reset(); + } + + stream.client_data_callback(bind(on_client_data, placeholders::_1, pyq)); + stream.server_data_callback(bind(on_server_data, placeholders::_1, pyq)); + stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, pyq)); } + void handle_next_packet(NfQueue::PktRequest* _pkt) override{ + pkt = _pkt; // Setting packet context - void handle_next_packet(NfQueue::PktRequest* pkt) override{ if (pkt->l4_proto != NfQueue::L4Proto::TCP){ throw invalid_argument("Only TCP and UDP are supported"); } - Tins::PDU* application_layer = pkt->tcp->inner_pdu(); - u_int16_t payload_size = 0; - if (application_layer != nullptr){ - payload_size = application_layer->size(); + + auto tcp_ack_search = sctx.tcp_ack_ctx.find(pkt->sid); + if (tcp_ack_search != sctx.tcp_ack_ctx.end()){ + current_tcp_ack = tcp_ack_search->second; + pkt->tcp_in_offset = ¤t_tcp_ack->in_tcp_offset; + pkt->tcp_out_offset = ¤t_tcp_ack->out_tcp_offset; + }else{ + current_tcp_ack = nullptr; + //If necessary will be created by libtis new_stream callback } - match_ctx.matching_has_been_called = false; - match_ctx.pkt = pkt; + if (pkt->is_ipv6){ + pkt->fix_tcp_ack(); follower.process_packet(*pkt->ipv6); }else{ + pkt->fix_tcp_ack(); follower.process_packet(*pkt->ipv4); } - // Do an action only is an ordered packet has been received - if (match_ctx.matching_has_been_called){ - bool empty_payload = payload_size == 0; - //In this 2 cases we have to remove all data about the stream - if (!match_ctx.rejected || match_ctx.already_closed){ - sctx.clean_stream_by_id(pkt->sid); - //If the packet has data, we have to remove it - if (!empty_payload){ - Tins::PDU* data_layer = pkt->tcp->release_inner_pdu(); - if (data_layer != nullptr){ - delete data_layer; - } - } - //For the first matched data or only for data packets, we set FIN bit - //This only for client packets, because this will trigger server to close the connection - //Packets will be filtered anyway also if client don't send packets - if ((!match_ctx.rejected || !empty_payload) && pkt->is_input){ - pkt->tcp->set_flag(Tins::TCP::FIN,1); - pkt->tcp->set_flag(Tins::TCP::ACK,1); - pkt->tcp->set_flag(Tins::TCP::SYN,0); - } - //Send the edited packet to the kernel - return pkt->mangle(); - }else{ - //Fallback to the default action - if (pkt->get_action() == NfQueue::FilterAction::NOACTION){ - return pkt->accept(); - } - } - }else{ + + //Fallback to the default action + if (pkt->get_action() == NfQueue::FilterAction::NOACTION){ return pkt->accept(); } } ~PyProxyQueue() { // Closing first the interpreter + Py_EndInterpreter(tstate); - // Releasing the GIL and the thread data structure - PyGILState_Release(gstate); + PyEval_ReleaseThread(tstate); + PyThreadState_Clear(tstate); + PyThreadState_Delete(tstate); + sctx.clean(); } diff --git a/backend/binsrc/pyproxy/settings.cpp b/backend/binsrc/pyproxy/settings.cpp index 80f9a08..91b8cc2 100644 --- a/backend/binsrc/pyproxy/settings.cpp +++ b/backend/binsrc/pyproxy/settings.cpp @@ -2,58 +2,73 @@ #define PROXY_TUNNEL_SETTINGS_CPP #include - +#include #include #include #include +#include "../utils.cpp" using namespace std; namespace Firegex { namespace PyProxy { +class PyCodeConfig; + +shared_ptr config; +PyObject* py_handle_packet_code = nullptr; +UnixClientConnection control_socket; class PyCodeConfig{ public: - PyObject* glob = nullptr; - PyObject* local = nullptr; - - private: - void _clean(){ - Py_XDECREF(glob); - Py_XDECREF(local); - } - public: + string encoded_code; PyCodeConfig(const string& pycode){ - PyObject* compiled_code = Py_CompileStringExFlags(pycode.c_str(), "", Py_file_input, NULL, 2); if (compiled_code == nullptr){ std::cerr << "[fatal] [main] Failed to compile the code" << endl; - _clean(); throw invalid_argument("Failed to compile the code"); } - glob = PyDict_New(); - local = PyDict_New(); - PyObject* result = PyEval_EvalCode(compiled_code, glob, local); - Py_XDECREF(compiled_code); + PyObject* glob = PyDict_New(); + PyObject* result = PyEval_EvalCode(compiled_code, glob, glob); + Py_DECREF(glob); if (!result){ PyErr_Print(); - _clean(); + Py_DECREF(compiled_code); std::cerr << "[fatal] [main] Failed to execute the code" << endl; throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided"); } Py_DECREF(result); + PyObject* code_dump = PyMarshal_WriteObjectToString(compiled_code, 4); + Py_DECREF(compiled_code); + if (code_dump == nullptr){ + PyErr_Print(); + std::cerr << "[fatal] [main] Failed to dump the code" << endl; + throw invalid_argument("Failed to dump the code"); + } + if (!PyBytes_Check(code_dump)){ + std::cerr << "[fatal] [main] Failed to dump the code" << endl; + throw invalid_argument("Failed to dump the code"); + } + encoded_code = string(PyBytes_AsString(code_dump), PyBytes_Size(code_dump)); + Py_DECREF(code_dump); } - PyCodeConfig(){} - ~PyCodeConfig(){ - _clean(); + PyObject* compiled_code(){ + if (encoded_code.empty()) return nullptr; + return PyMarshal_ReadObjectFromString(encoded_code.c_str(), encoded_code.size()); } + + PyCodeConfig(){} }; -shared_ptr config; -PyObject* py_handle_packet_code = nullptr; +void init_control_socket(){ + char * socket_path = getenv("FIREGEX_NFPROXY_SOCK"); + if (socket_path == nullptr) throw invalid_argument("FIREGEX_NFPROXY_SOCK not set"); + if (strlen(socket_path) >= 108) throw invalid_argument("FIREGEX_NFPROXY_SOCK too long"); + control_socket = UnixClientConnection(socket_path); +} + void init_handle_packet_code(){ py_handle_packet_code = Py_CompileStringExFlags( diff --git a/backend/binsrc/pyproxy/stream_ctx.cpp b/backend/binsrc/pyproxy/stream_ctx.cpp index 633ca50..761e20d 100644 --- a/backend/binsrc/pyproxy/stream_ctx.cpp +++ b/backend/binsrc/pyproxy/stream_ctx.cpp @@ -27,10 +27,21 @@ enum PyFilterResponse { INVALID = 5 }; +const PyFilterResponse VALID_PYTHON_RESPONSE[4] = { + PyFilterResponse::ACCEPT, + PyFilterResponse::DROP, + PyFilterResponse::REJECT, + PyFilterResponse::MANGLE +}; + struct py_filter_response { PyFilterResponse action; string* filter_match_by = nullptr; string* mangled_packet = nullptr; + + py_filter_response(PyFilterResponse action, string* filter_match_by = nullptr, string* mangled_packet = nullptr): + action(action), filter_match_by(filter_match_by), mangled_packet(mangled_packet){} + ~py_filter_response(){ delete mangled_packet; delete filter_match_by; @@ -39,34 +50,35 @@ struct py_filter_response { typedef Tins::TCPIP::StreamIdentifier stream_id; +struct tcp_ack_seq_ctx{ + //Can be negative, so we use int64_t (for a uint64_t value) + int64_t in_tcp_offset = 0; + int64_t out_tcp_offset = 0; + tcp_ack_seq_ctx(){} + void reset(){ + in_tcp_offset = 0; + out_tcp_offset = 0; + } +}; + struct pyfilter_ctx { PyObject * glob = nullptr; - PyObject * local = nullptr; - pyfilter_ctx(PyObject * original_glob, PyObject * original_local){ - PyObject *copy = PyImport_ImportModule("copy"); - if (copy == nullptr){ + pyfilter_ctx(PyObject * compiled_code){ + glob = PyDict_New(); + PyObject* result = PyEval_EvalCode(compiled_code, glob, glob); + if (!result){ PyErr_Print(); - throw invalid_argument("Failed to import copy module"); + Py_XDECREF(glob); + std::cerr << "[fatal] [main] Failed to compile the code" << endl; + throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided"); } - PyObject *deepcopy = PyObject_GetAttrString(copy, "deepcopy"); - glob = PyObject_CallFunctionObjArgs(deepcopy, original_glob, NULL); - if (glob == nullptr){ - PyErr_Print(); - throw invalid_argument("Failed to deepcopy the global dict"); - } - local = PyObject_CallFunctionObjArgs(deepcopy, original_local, NULL); - if (local == nullptr){ - PyErr_Print(); - throw invalid_argument("Failed to deepcopy the local dict"); - } - Py_DECREF(copy); + Py_XDECREF(result); } ~pyfilter_ctx(){ - Py_XDECREF(glob); - Py_XDECREF(local); + Py_DECREF(glob); } inline void set_item_to_glob(const char* key, PyObject* value){ @@ -84,15 +96,12 @@ struct pyfilter_ctx { } } - inline void set_item_to_local(const char* key, PyObject* value){ - set_item_to_dict(local, key, value); - } - inline void set_item_to_dict(PyObject* dict, const char* key, PyObject* value){ if (PyDict_SetItemString(dict, key, value) != 0){ PyErr_Print(); throw invalid_argument("Failed to set item to dict"); } + Py_DECREF(value); } py_filter_response handle_packet( @@ -101,6 +110,7 @@ struct pyfilter_ctx { PyObject * packet_info = PyDict_New(); set_item_to_dict(packet_info, "data", PyBytes_FromStringAndSize(pkt->data, pkt->data_size)); + set_item_to_dict(packet_info, "l4_size", PyLong_FromLong(pkt->data_original_size())); set_item_to_dict(packet_info, "raw_packet", PyBytes_FromStringAndSize(pkt->packet.c_str(), pkt->packet.size())); set_item_to_dict(packet_info, "is_input", PyBool_FromLong(pkt->is_input)); set_item_to_dict(packet_info, "is_ipv6", PyBool_FromLong(pkt->is_ipv6)); @@ -108,92 +118,156 @@ struct pyfilter_ctx { // Set packet info to the global context set_item_to_glob("__firegex_packet_info", packet_info); - PyObject * result = PyEval_EvalCode(py_handle_packet_code, glob, local); + PyObject * result = PyEval_EvalCode(py_handle_packet_code, glob, glob); del_item_from_glob("__firegex_packet_info"); - Py_DECREF(packet_info); + Py_DECREF(packet_info); if (!result){ PyErr_Print(); - return py_filter_response{PyFilterResponse::EXCEPTION, nullptr}; + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] Exception raised" << endl; + #endif + return py_filter_response(PyFilterResponse::EXCEPTION); } + + Py_DECREF(result); result = get_item_from_glob("__firegex_pyfilter_result"); if (result == nullptr){ - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] No result found" << endl; + #endif + return py_filter_response(PyFilterResponse::INVALID); } if (!PyDict_Check(result)){ PyErr_Print(); + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] Result is not a dict" << endl; + #endif del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; + return py_filter_response(PyFilterResponse::INVALID); } PyObject* action = PyDict_GetItemString(result, "action"); if (action == nullptr){ + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] No result action found" << endl; + #endif del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; + return py_filter_response(PyFilterResponse::INVALID); } if (!PyLong_Check(action)){ + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] Action is not a long" << endl; + #endif del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; + return py_filter_response(PyFilterResponse::INVALID); } PyFilterResponse action_enum = (PyFilterResponse)PyLong_AsLong(action); - if (action_enum == PyFilterResponse::ACCEPT || action_enum == PyFilterResponse::EXCEPTION || action_enum == PyFilterResponse::INVALID){ - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{action_enum, nullptr, nullptr}; - }else{ - PyObject *func_name_py = PyDict_GetItemString(result, "matched_by"); - if (func_name_py == nullptr){ - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; - } - if (!PyUnicode_Check(func_name_py)){ - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; - } - string* func_name = new string(PyUnicode_AsUTF8(func_name_py)); - if (action_enum == PyFilterResponse::DROP || action_enum == PyFilterResponse::REJECT){ - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{action_enum, func_name, nullptr}; - } - if (action_enum != PyFilterResponse::MANGLE){ - PyObject* mangled_packet = PyDict_GetItemString(result, "mangled_packet"); - if (mangled_packet == nullptr){ - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; - } - if (!PyBytes_Check(mangled_packet)){ - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; - } - string* pkt_str = new string(PyBytes_AsString(mangled_packet), PyBytes_Size(mangled_packet)); - del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::MANGLE, func_name, pkt_str}; + //Check action_enum + bool valid = false; + for (auto valid_action: VALID_PYTHON_RESPONSE){ + if (action_enum == valid_action){ + valid = true; + break; } } + if (!valid){ + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] Invalid action" << endl; + #endif + del_item_from_glob("__firegex_pyfilter_result"); + return py_filter_response(PyFilterResponse::INVALID); + } + + if (action_enum == PyFilterResponse::ACCEPT){ + del_item_from_glob("__firegex_pyfilter_result"); + return py_filter_response(action_enum); + } + PyObject *func_name_py = PyDict_GetItemString(result, "matched_by"); + if (func_name_py == nullptr){ + del_item_from_glob("__firegex_pyfilter_result"); + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] No result matched_by found" << endl; + #endif + return py_filter_response(PyFilterResponse::INVALID); + } + if (!PyUnicode_Check(func_name_py)){ + del_item_from_glob("__firegex_pyfilter_result"); + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] matched_by is not a string" << endl; + #endif + return py_filter_response(PyFilterResponse::INVALID); + } + string* func_name = new string(PyUnicode_AsUTF8(func_name_py)); + if (action_enum == PyFilterResponse::DROP || action_enum == PyFilterResponse::REJECT){ + del_item_from_glob("__firegex_pyfilter_result"); + return py_filter_response(action_enum, func_name); + } + if (action_enum == PyFilterResponse::MANGLE){ + PyObject* mangled_packet = PyDict_GetItemString(result, "mangled_packet"); + if (mangled_packet == nullptr){ + del_item_from_glob("__firegex_pyfilter_result"); + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] No result mangled_packet found" << endl; + #endif + return py_filter_response(PyFilterResponse::INVALID); + } + if (!PyBytes_Check(mangled_packet)){ + #ifdef DEBUG + cerr << "[DEBUG] [handle_packet] mangled_packet is not a bytes" << endl; + #endif + del_item_from_glob("__firegex_pyfilter_result"); + return py_filter_response(PyFilterResponse::INVALID); + } + string* pkt_str = new string(PyBytes_AsString(mangled_packet), PyBytes_Size(mangled_packet)); + del_item_from_glob("__firegex_pyfilter_result"); + return py_filter_response(PyFilterResponse::MANGLE, func_name, pkt_str); + } + + //Should never reach this point, but just in case of new action not managed... del_item_from_glob("__firegex_pyfilter_result"); - return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; + return py_filter_response(PyFilterResponse::INVALID); } }; typedef map matching_map; +typedef map tcp_ack_map; struct stream_ctx { matching_map streams_ctx; + tcp_ack_map tcp_ack_ctx; void clean_stream_by_id(stream_id sid){ auto stream_search = streams_ctx.find(sid); if (stream_search != streams_ctx.end()){ auto stream_match = stream_search->second; delete stream_match; + streams_ctx.erase(stream_search->first); } } + + void clean_tcp_ack_by_id(stream_id sid){ + auto tcp_ack_search = tcp_ack_ctx.find(sid); + if (tcp_ack_search != tcp_ack_ctx.end()){ + auto tcp_ack = tcp_ack_search->second; + delete tcp_ack; + tcp_ack_ctx.erase(tcp_ack_search->first); + } + } + void clean(){ for (auto ele: streams_ctx){ delete ele.second; } + for (auto ele: tcp_ack_ctx){ + delete ele.second; + } + tcp_ack_ctx.clear(); + streams_ctx.clear(); } }; diff --git a/backend/binsrc/regex/regexfilter.cpp b/backend/binsrc/regex/regexfilter.cpp index 0ea15d2..c84b12b 100644 --- a/backend/binsrc/regex/regexfilter.cpp +++ b/backend/binsrc/regex/regexfilter.cpp @@ -37,13 +37,7 @@ public: stream_ctx sctx; u_int16_t latest_config_ver = 0; StreamFollower follower; - struct { - bool matching_has_been_called = false; - bool already_closed = false; - bool result; - NfQueue::PktRequest* pkt; - } match_ctx; - + NfQueue::PktRequest* pkt; bool filter_action(NfQueue::PktRequest* pkt){ shared_ptr conf = regex_config; @@ -119,49 +113,23 @@ public: return true; } - void handle_next_packet(NfQueue::PktRequest* pkt) override{ - bool empty_payload = pkt->data_size == 0; + void handle_next_packet(NfQueue::PktRequest* _pkt) override{ + pkt = _pkt; // Setting packet context if (pkt->tcp){ - match_ctx.matching_has_been_called = false; - match_ctx.pkt = pkt; - if (pkt->ipv4){ follower.process_packet(*pkt->ipv4); }else{ follower.process_packet(*pkt->ipv6); } - - // Do an action only is an ordered packet has been received - if (match_ctx.matching_has_been_called){ - - //In this 2 cases we have to remove all data about the stream - if (!match_ctx.result || match_ctx.already_closed){ - sctx.clean_stream_by_id(pkt->sid); - //If the packet has data, we have to remove it - if (!empty_payload){ - Tins::PDU* data_layer = pkt->tcp->release_inner_pdu(); - if (data_layer != nullptr){ - delete data_layer; - } - } - //For the first matched data or only for data packets, we set FIN bit - //This only for client packets, because this will trigger server to close the connection - //Packets will be filtered anyway also if client don't send packets - if ((!match_ctx.result || !empty_payload) && pkt->is_input){ - pkt->tcp->set_flag(Tins::TCP::FIN,1); - pkt->tcp->set_flag(Tins::TCP::ACK,1); - pkt->tcp->set_flag(Tins::TCP::SYN,0); - } - //Send the edited packet to the kernel - return pkt->mangle(); - } + //Fallback to the default action + if (pkt->get_action() == NfQueue::FilterAction::NOACTION){ + return pkt->accept(); } - return pkt->accept(); }else{ if (!pkt->udp){ throw invalid_argument("Only TCP and UDP are supported"); } - if(empty_payload){ + if(pkt->data_size == 0){ return pkt->accept(); }else if (filter_action(pkt)){ return pkt->accept(); @@ -170,22 +138,21 @@ public: } } } + //If the stream has already been matched, drop all data, and try to close the connection static void keep_fin_packet(RegexNfQueue* nfq){ - nfq->match_ctx.matching_has_been_called = true; - nfq->match_ctx.already_closed = true; + nfq->pkt->reject();// This is needed because the callback has to take the updated pkt pointer! } static void on_data_recv(Stream& stream, RegexNfQueue* nfq, string data) { - nfq->match_ctx.matching_has_been_called = true; - nfq->match_ctx.already_closed = false; - bool result = nfq->filter_action(nfq->match_ctx.pkt); - if (!result){ - nfq->sctx.clean_stream_by_id(nfq->match_ctx.pkt->sid); + nfq->pkt->data = data.data(); + nfq->pkt->data_size = data.size(); + if (!nfq->filter_action(nfq->pkt)){ + nfq->sctx.clean_stream_by_id(nfq->pkt->sid); stream.client_data_callback(bind(keep_fin_packet, nfq)); stream.server_data_callback(bind(keep_fin_packet, nfq)); + nfq->pkt->reject(); } - nfq->match_ctx.result = result; } //Input data filtering diff --git a/backend/binsrc/regex/stream_ctx.cpp b/backend/binsrc/regex/stream_ctx.cpp index dc1c3fe..3ee6e3d 100644 --- a/backend/binsrc/regex/stream_ctx.cpp +++ b/backend/binsrc/regex/stream_ctx.cpp @@ -17,7 +17,6 @@ namespace Regex { typedef Tins::TCPIP::StreamIdentifier stream_id; typedef map matching_map; -#ifdef DEBUG ostream& operator<<(ostream& os, const Tins::TCPIP::StreamIdentifier::address_type &sid){ bool first_print = false; for (auto ele: sid){ @@ -33,7 +32,6 @@ ostream& operator<<(ostream& os, const stream_id &sid){ os << sid.max_address << ":" << sid.max_address_port << " -> " << sid.min_address << ":" << sid.min_address_port; return os; } -#endif struct stream_ctx { matching_map in_hs_streams; diff --git a/backend/binsrc/utils.cpp b/backend/binsrc/utils.cpp index a4d889a..59ca77b 100644 --- a/backend/binsrc/utils.cpp +++ b/backend/binsrc/utils.cpp @@ -1,10 +1,17 @@ +#ifndef UTILS_CPP +#define UTILS_CPP + #include #include #include #include - -#ifndef UTILS_CPP -#define UTILS_CPP +#include +#include +#include +#include +#include +#include +#include bool unhexlify(std::string const &hex, std::string &newString) { try{ @@ -22,6 +29,113 @@ bool unhexlify(std::string const &hex, std::string &newString) { } } +class UnixClientConnection { +public: + int sockfd = -1; + struct sockaddr_un addr; +private: + // Internal buffer to accumulate the output until flush + std::ostringstream streamBuffer; +public: + + UnixClientConnection(){}; + + UnixClientConnection(const char* path) { + sockfd = socket(AF_UNIX, SOCK_STREAM, 0); + if (sockfd == -1) { + throw std::runtime_error(std::string("socket error: ") + std::strerror(errno)); + } + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1); + if (connect(sockfd, reinterpret_cast(&addr), sizeof(addr)) != 0) { + throw std::runtime_error(std::string("connect error: ") + std::strerror(errno)); + } + } + + // Delete copy constructor and assignment operator to avoid resource duplication + UnixClientConnection(const UnixClientConnection&) = delete; + UnixClientConnection& operator=(const UnixClientConnection&) = delete; + + // Move constructor + UnixClientConnection(UnixClientConnection&& other) noexcept + : sockfd(other.sockfd), addr(other.addr) { + other.sockfd = -1; + } + + // Move assignment operator + UnixClientConnection& operator=(UnixClientConnection&& other) noexcept { + if (this != &other) { + if (sockfd != -1) { + close(sockfd); + } + sockfd = other.sockfd; + addr = other.addr; + other.sockfd = -1; + } + return *this; + } + + void send(const std::string& data) { + if (::write(sockfd, data.c_str(), data.size()) == -1) { + throw std::runtime_error(std::string("write error: ") + std::strerror(errno)); + } + } + + std::string recv(size_t size) { + std::string buffer(size, '\0'); + ssize_t bytesRead = ::read(sockfd, &buffer[0], size); + if (bytesRead <= 0) { + throw std::runtime_error(std::string("read error: ") + std::strerror(errno)); + } + buffer.resize(bytesRead); // resize to actual bytes read + return buffer; + } + + // Template overload for generic types + template + UnixClientConnection& operator<<(const T& data) { + streamBuffer << data; + return *this; + } + + // Overload for manipulators (e.g., std::endl) + UnixClientConnection& operator<<(std::ostream& (*manip)(std::ostream&)) { + // Check if the manipulator is std::endl (or equivalent flush) + if (manip == static_cast(std::endl)){ + streamBuffer << '\n'; // Add a newline + std::string packet = streamBuffer.str(); + streamBuffer.str(""); // Clear the buffer + // Send the accumulated data as one packet + send(packet); + } + if (static_cast(std::flush)) { + std::string packet = streamBuffer.str(); + streamBuffer.str(""); // Clear the buffer + // Send the accumulated data as one packet + send(packet); + } else { + // For other manipulators, simply pass them to the buffer + streamBuffer << manip; + } + return *this; + } + + // Overload operator<< to allow printing connection info + friend std::ostream& operator<<(std::ostream& os, const UnixClientConnection& conn) { + os << "UnixClientConnection(sockfd=" << conn.sockfd + << ", path=" << conn.addr.sun_path << ")"; + return os; + } + + ~UnixClientConnection() { + if (sockfd != -1) { + close(sockfd); + } + } +}; + + #ifdef USE_PIPES_FOR_BLOKING_QUEUE template diff --git a/backend/modules/firewall/nftables.py b/backend/modules/firewall/nftables.py index c27ab1c..6822fca 100644 --- a/backend/modules/firewall/nftables.py +++ b/backend/modules/firewall/nftables.py @@ -1,4 +1,4 @@ -from modules.firewall.models import * +from modules.firewall.models import FirewallSettings, Action, Rule, Protocol, Mode, Table from utils import nftables_int_to_json, ip_family, NFTableManager, is_ip_parse import copy @@ -9,7 +9,8 @@ class FiregexTables(NFTableManager): filter_table = "filter" mangle_table = "mangle" - def init_comands(self, policy:str=Action.ACCEPT, opt: FirewallSettings|None = None): + def init_comands(self, policy:str=Action.ACCEPT, opt: + FirewallSettings|None = None): rules = [ {"add":{"table":{"name":self.filter_table,"family":"ip"}}}, {"add":{"table":{"name":self.filter_table,"family":"ip6"}}}, @@ -41,7 +42,8 @@ class FiregexTables(NFTableManager): {"add":{"chain":{"family":"ip","table":self.mangle_table,"name":self.rules_chain_out}}}, {"add":{"chain":{"family":"ip6","table":self.mangle_table,"name":self.rules_chain_out}}}, ] - if opt is None: return rules + if opt is None: + return rules if opt.allow_loopback: rules.extend([ @@ -194,13 +196,18 @@ class FiregexTables(NFTableManager): def chain_to_firegex(self, chain:str, table:str): if table == self.filter_table: match chain: - case "INPUT": return self.rules_chain_in - case "OUTPUT": return self.rules_chain_out - case "FORWARD": return self.rules_chain_fwd + case "INPUT": + return self.rules_chain_in + case "OUTPUT": + return self.rules_chain_out + case "FORWARD": + return self.rules_chain_fwd elif table == self.mangle_table: match chain: - case "PREROUTING": return self.rules_chain_in - case "POSTROUTING": return self.rules_chain_out + case "PREROUTING": + return self.rules_chain_in + case "POSTROUTING": + return self.rules_chain_out return None def insert_firegex_chains(self): @@ -214,7 +221,8 @@ class FiregexTables(NFTableManager): if r.get("family") == family and r.get("table") == table and r.get("chain") == chain and r.get("expr") == rule_to_add: found = True break - if found: continue + if found: + continue yield { "add":{ "rule": { "family": family, "table": table, @@ -274,7 +282,7 @@ class FiregexTables(NFTableManager): ip_filters.append({"match": { "op": "==", "left": { "meta": { "key": "oifname" } }, "right": srv.dst} }) port_filters = [] - if not srv.proto in [Protocol.ANY, Protocol.BOTH]: + if srv.proto not in [Protocol.ANY, Protocol.BOTH]: if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}}) port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}}) diff --git a/backend/modules/nfproxy/firegex.py b/backend/modules/nfproxy/firegex.py index 095eb17..13eea5a 100644 --- a/backend/modules/nfproxy/firegex.py +++ b/backend/modules/nfproxy/firegex.py @@ -1,11 +1,10 @@ from modules.nfproxy.nftables import FiregexTables -from utils import run_func from modules.nfproxy.models import Service, PyFilter import os import asyncio -from utils import DEBUG import traceback from fastapi import HTTPException +import time nft = FiregexTables() @@ -13,29 +12,37 @@ class FiregexInterceptor: def __init__(self): self.srv:Service - self._stats_updater_cb:callable self.filter_map_lock:asyncio.Lock self.filter_map: dict[str, PyFilter] - self.pyfilters: set[PyFilter] self.update_config_lock:asyncio.Lock self.process:asyncio.subprocess.Process self.update_task: asyncio.Task + self.server_task: asyncio.Task + self.sock_path: str + self.unix_sock: asyncio.Server self.ack_arrived = False self.ack_status = None - self.ack_fail_what = "Unknown" + self.ack_fail_what = "Queue response timed-out" self.ack_lock = asyncio.Lock() - - async def _call_stats_updater_callback(self, filter: PyFilter): - if self._stats_updater_cb: - await run_func(self._stats_updater_cb(filter)) + self.sock_reader:asyncio.StreamReader = None + self.sock_writer:asyncio.StreamWriter = None + self.sock_conn_lock:asyncio.Lock + self.last_time_exception = 0 @classmethod - async def start(cls, srv: Service, stats_updater_cb:callable): + async def start(cls, srv: Service): self = cls() - self._stats_updater_cb = stats_updater_cb self.srv = srv self.filter_map_lock = asyncio.Lock() self.update_config_lock = asyncio.Lock() + self.sock_conn_lock = asyncio.Lock() + if not self.sock_conn_lock.locked(): + await self.sock_conn_lock.acquire() + self.sock_path = f"/tmp/firegex_nfproxy_{srv.id}.sock" + if os.path.exists(self.sock_path): + os.remove(self.sock_path) + self.unix_sock = await asyncio.start_unix_server(self._server_listener,path=self.sock_path) + self.server_task = asyncio.create_task(self.unix_sock.serve_forever()) queue_range = await self._start_binary() self.update_task = asyncio.create_task(self.update_stats()) nft.add(self.srv, queue_range) @@ -46,19 +53,20 @@ class FiregexInterceptor: async def _start_binary(self): proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cpproxy") self.process = await asyncio.create_subprocess_exec( - proxy_binary_path, - stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, + proxy_binary_path, stdin=asyncio.subprocess.DEVNULL, env={ "NTHREADS": os.getenv("NTHREADS","1"), "FIREGEX_NFQUEUE_FAIL_OPEN": "1" if self.srv.fail_open else "0", + "FIREGEX_NFPROXY_SOCK": self.sock_path }, ) - line_fut = self.process.stdout.readuntil() try: - line_fut = await asyncio.wait_for(line_fut, timeout=3) + async with asyncio.timeout(3): + await self.sock_conn_lock.acquire() + line_fut = await self.sock_reader.readuntil() except asyncio.TimeoutError: self.process.kill() - raise Exception("Invalid binary output") + raise Exception("Binary don't returned queue number until timeout") line = line_fut.decode() if line.startswith("QUEUE "): params = line.split() @@ -67,25 +75,45 @@ class FiregexInterceptor: self.process.kill() raise Exception("Invalid binary output") + async def _server_listener(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter): + if self.sock_reader or self.sock_writer: + writer.write_eof() # Technically never reached + writer.close() + reader.feed_eof() + return + self.sock_reader = reader + self.sock_writer = writer + self.sock_conn_lock.release() + async def update_stats(self): try: while True: - line = (await self.process.stdout.readuntil()).decode() - if DEBUG: - print(line) + try: + line = (await self.sock_reader.readuntil()).decode() + except Exception as e: + self.ack_arrived = False + self.ack_status = False + self.ack_fail_what = "Can't read from nfq client" + self.ack_lock.release() + await self.stop() + raise HTTPException(status_code=500, detail="Can't read from nfq client") from e if line.startswith("BLOCKED "): - filter_id = line.split()[1] + filter_name = line.split()[1] + print("BLOCKED", filter_name) async with self.filter_map_lock: - if filter_id in self.filter_map: - self.filter_map[filter_id].blocked_packets+=1 - await self.filter_map[filter_id].update() + print("LOCKED MAP LOCK") + if filter_name in self.filter_map: + print("ADDING BLOCKED PACKET") + self.filter_map[filter_name].blocked_packets+=1 + await self.filter_map[filter_name].update() if line.startswith("MANGLED "): - filter_id = line.split()[1] + filter_name = line.split()[1] async with self.filter_map_lock: - if filter_id in self.filter_map: - self.filter_map[filter_id].edited_packets+=1 - await self.filter_map[filter_id].update() + if filter_name in self.filter_map: + self.filter_map[filter_name].edited_packets+=1 + await self.filter_map[filter_name].update() if line.startswith("EXCEPTION"): + self.last_time_exception = time.time() print("TODO EXCEPTION HANDLING") # TODO if line.startswith("ACK "): self.ack_arrived = True @@ -101,22 +129,29 @@ class FiregexInterceptor: traceback.print_exc() async def stop(self): + self.server_task.cancel() self.update_task.cancel() + self.unix_sock.close() + if os.path.exists(self.sock_path): + os.remove(self.sock_path) if self.process and self.process.returncode is None: self.process.kill() async def _update_config(self, code): async with self.update_config_lock: - self.process.stdin.write(len(code).to_bytes(4, byteorder='big')+code.encode()) - await self.process.stdin.drain() - try: - async with asyncio.timeout(3): - await self.ack_lock.acquire() - except TimeoutError: - pass - if not self.ack_arrived or not self.ack_status: - await self.stop() - raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}") + if self.sock_writer: + self.sock_writer.write(len(code).to_bytes(4, byteorder='big')+code.encode()) + await self.sock_writer.drain() + try: + async with asyncio.timeout(3): + await self.ack_lock.acquire() + except TimeoutError: + self.ack_fail_what = "Queue response timed-out" + if not self.ack_arrived or not self.ack_status: + await self.stop() + raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}") + else: + raise HTTPException(status_code=400, detail="Socket not ready") async def reload(self, filters:list[PyFilter]): async with self.filter_map_lock: @@ -125,12 +160,13 @@ class FiregexInterceptor: filter_file = f.read() else: filter_file = "" + self.filter_map = {ele.name: ele for ele in filters} await self._update_config( - "global __firegex_pyfilter_enabled\n" + + + filter_file + "\n\n" + "__firegex_pyfilter_enabled = [" + ", ".join([repr(f.name) for f in filters]) + "]\n" + "__firegex_proto = " + repr(self.srv.proto) + "\n" + - "import firegex.nfproxy.internals\n\n" + - filter_file + "\n\n" + - "firegex.nfproxy.internals.compile()" + "import firegex.nfproxy.internals\n" + + "firegex.nfproxy.internals.compile(globals())\n" ) diff --git a/backend/modules/nfproxy/firewall.py b/backend/modules/nfproxy/firewall.py index 59002d9..c424686 100644 --- a/backend/modules/nfproxy/firewall.py +++ b/backend/modules/nfproxy/firewall.py @@ -15,18 +15,18 @@ class ServiceManager: self.srv = srv self.db = db self.status = STATUS.STOP - self.filters: dict[int, FiregexFilter] = {} + self.filters: dict[str, FiregexFilter] = {} self.lock = asyncio.Lock() self.interceptor = None async def _update_filters_from_db(self): pyfilters = [ - PyFilter.from_dict(ele) for ele in + PyFilter.from_dict(ele, self.db) for ele in self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id) ] #Filter check old_filters = set(self.filters.keys()) - new_filters = set([f.id for f in pyfilters]) + new_filters = set([f.name for f in pyfilters]) #remove old filters for f in old_filters: if f not in new_filters: @@ -34,7 +34,7 @@ class ServiceManager: #add new filters for f in new_filters: if f not in old_filters: - self.filters[f] = [ele for ele in pyfilters if ele.id == f][0] + self.filters[f] = [ele for ele in pyfilters if ele.name == f][0] if self.interceptor: await self.interceptor.reload(self.filters.values()) @@ -43,16 +43,11 @@ class ServiceManager: async def next(self,to): async with self.lock: - if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): + if to == STATUS.STOP: await self.stop() - self._set_status(to) - # PAUSE -> ACTIVE - elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): + if to == STATUS.ACTIVE: await self.restart() - def _stats_updater(self,filter:PyFilter): - self.db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE filter_id = ?;", filter.blocked_packets, filter.edited_packets, filter.id) - def _set_status(self,status): self.status = status self.__update_status_db(status) @@ -60,7 +55,7 @@ class ServiceManager: async def start(self): if not self.interceptor: nft.delete(self.srv) - self.interceptor = await FiregexInterceptor.start(self.srv, self._stats_updater) + self.interceptor = await FiregexInterceptor.start(self.srv) await self._update_filters_from_db() self._set_status(STATUS.ACTIVE) @@ -69,6 +64,7 @@ class ServiceManager: if self.interceptor: await self.interceptor.stop() self.interceptor = None + self._set_status(STATUS.STOP) async def restart(self): await self.stop() diff --git a/backend/modules/nfproxy/models.py b/backend/modules/nfproxy/models.py index 4417db0..bb691cd 100644 --- a/backend/modules/nfproxy/models.py +++ b/backend/modules/nfproxy/models.py @@ -15,13 +15,19 @@ class Service: class PyFilter: - def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other): - self.id = filter_id + def __init__(self, name: str, blocked_packets: int, edited_packets: int, active: bool, db, **other): self.name = name self.blocked_packets = blocked_packets self.edited_packets = edited_packets self.active = active + self.__db = db + + async def update(self): + self.__db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE name = ?;", self.blocked_packets, self.edited_packets, self.name) + + def __repr__(self): + return f"" @classmethod - def from_dict(cls, var: dict): - return cls(**var) + def from_dict(cls, var: dict, db): + return cls(**var, db=db) diff --git a/backend/modules/nfproxy/nftables.py b/backend/modules/nfproxy/nftables.py index eafa129..84c24c9 100644 --- a/backend/modules/nfproxy/nftables.py +++ b/backend/modules/nfproxy/nftables.py @@ -1,6 +1,14 @@ from modules.nfproxy.models import Service from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json +def convert_protocol_to_l4(proto:str): + if proto == "tcp": + return "tcp" + elif proto == "http": + return "tcp" + else: + raise Exception("Invalid protocol") + class FiregexFilter: def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int): self.id = id @@ -11,7 +19,7 @@ class FiregexFilter: def __eq__(self, o: object) -> bool: if isinstance(o, FiregexFilter) or isinstance(o, Service): - return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) + return self.port == o.port and self.proto == convert_protocol_to_l4(o.proto) and ip_parse(self.ip_int) == ip_parse(o.ip_int) return False class FiregexTables(NFTableManager): @@ -61,7 +69,7 @@ class FiregexTables(NFTableManager): "chain": self.output_chain, "expr": [ {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}}, - {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}}, + {'match': {"left": { "payload": {"protocol": convert_protocol_to_l4(str(srv.proto)), "field": "sport"}}, "op": "==", "right": int(srv.port)}}, {"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1338}}, {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} ] @@ -72,7 +80,7 @@ class FiregexTables(NFTableManager): "chain": self.input_chain, "expr": [ {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}}, - {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}}, + {'match': {"left": { "payload": {"protocol": convert_protocol_to_l4(str(srv.proto)), "field": "dport"}}, "op": "==", "right": int(srv.port)}}, {"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1337}}, {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} ] diff --git a/backend/modules/nfregex/firegex.py b/backend/modules/nfregex/firegex.py index 5e6b2b0..701ca9d 100644 --- a/backend/modules/nfregex/firegex.py +++ b/backend/modules/nfregex/firegex.py @@ -79,7 +79,7 @@ class FiregexInterceptor: self.update_task: asyncio.Task self.ack_arrived = False self.ack_status = None - self.ack_fail_what = "Unknown" + self.ack_fail_what = "Queue response timed-out" self.ack_lock = asyncio.Lock() @classmethod @@ -158,7 +158,7 @@ class FiregexInterceptor: async with asyncio.timeout(3): await self.ack_lock.acquire() except TimeoutError: - pass + self.ack_fail_what = "Queue response timed-out" if not self.ack_arrived or not self.ack_status: await self.stop() raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}") diff --git a/backend/modules/nfregex/firewall.py b/backend/modules/nfregex/firewall.py index d0d5479..ec9231e 100644 --- a/backend/modules/nfregex/firewall.py +++ b/backend/modules/nfregex/firewall.py @@ -45,11 +45,9 @@ class ServiceManager: async def next(self,to): async with self.lock: - if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): + if to == STATUS.STOP: await self.stop() - self._set_status(to) - # PAUSE -> ACTIVE - elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): + if to == STATUS.ACTIVE: await self.restart() def _stats_updater(self,filter:RegexFilter): @@ -71,6 +69,7 @@ class ServiceManager: if self.interceptor: await self.interceptor.stop() self.interceptor = None + self._set_status(STATUS.STOP) async def restart(self): await self.stop() diff --git a/backend/routers/nfproxy.py b/backend/routers/nfproxy.py index efcc664..77405d1 100644 --- a/backend/routers/nfproxy.py +++ b/backend/routers/nfproxy.py @@ -10,6 +10,10 @@ from utils.models import ResetRequest, StatusMessageModel import os from firegex.nfproxy.internals import get_filter_names from fastapi.responses import PlainTextResponse +from modules.nfproxy.nftables import convert_protocol_to_l4 +import asyncio +import traceback +from utils import DEBUG class ServiceModel(BaseModel): service_id: str @@ -28,12 +32,10 @@ class RenameForm(BaseModel): class SettingsForm(BaseModel): port: PortType|None = None - proto: str|None = None ip_int: str|None = None fail_open: bool|None = None class PyFilterModel(BaseModel): - filter_id: int name: str blocked_packets: int edited_packets: int @@ -52,6 +54,7 @@ class ServiceAddResponse(BaseModel): class SetPyFilterForm(BaseModel): code: str + sid: str|None = None app = APIRouter() @@ -62,12 +65,12 @@ db = SQLite('db/nft-pyfilters.db', { 'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)', 'name': 'VARCHAR(100) NOT NULL UNIQUE', 'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "http"))', + 'l4_proto': 'VARCHAR(3) NOT NULL CHECK (l4_proto IN ("tcp", "udp"))', 'ip_int': 'VARCHAR(100) NOT NULL', 'fail_open': 'BOOLEAN NOT NULL CHECK (fail_open IN (0, 1)) DEFAULT 1', }, 'pyfilter': { - 'filter_id': 'INTEGER PRIMARY KEY', - 'name': 'VARCHAR(100) NOT NULL', + 'name': 'VARCHAR(100) PRIMARY KEY', 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', 'edited_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', 'service_id': 'VARCHAR(100) NOT NULL', @@ -75,7 +78,7 @@ db = SQLite('db/nft-pyfilters.db', { 'FOREIGN KEY (service_id)':'REFERENCES services (service_id)', }, 'QUERY':[ - "CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, proto);", + "CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, l4_proto);", "CREATE UNIQUE INDEX IF NOT EXISTS unique_pyfilter_service ON pyfilter (name, service_id);" ] }) @@ -132,7 +135,7 @@ async def get_service_list(): s.proto proto, s.ip_int ip_int, s.fail_open fail_open, - COUNT(f.filter_id) n_filters, + COUNT(f.name) n_filters, COALESCE(SUM(f.blocked_packets),0) blocked_packets, COALESCE(SUM(f.edited_packets),0) edited_packets FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id @@ -151,7 +154,7 @@ async def get_service_by_id(service_id: str): s.proto proto, s.ip_int ip_int, s.fail_open fail_open, - COUNT(f.filter_id) n_filters, + COUNT(f.name) n_filters, COALESCE(SUM(f.blocked_packets),0) blocked_packets, COALESCE(SUM(f.edited_packets),0) edited_packets FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id @@ -202,9 +205,6 @@ async def service_rename(service_id: str, form: RenameForm): @app.put('/services/{service_id}/settings', response_model=StatusMessageModel) async def service_settings(service_id: str, form: SettingsForm): """Request to change the settings of a specific service (will cause a restart)""" - - if form.proto is not None and form.proto not in ["tcp", "udp"]: - raise HTTPException(status_code=400, detail="Invalid protocol") if form.port is not None and (form.port < 1 or form.port > 65535): raise HTTPException(status_code=400, detail="Invalid port") @@ -245,38 +245,38 @@ async def get_service_pyfilter_list(service_id: str): raise HTTPException(status_code=400, detail="This service does not exists!") return db.query(""" SELECT - filter_id, name, blocked_packets, edited_packets, active + name, blocked_packets, edited_packets, active FROM pyfilter WHERE service_id = ?; """, service_id) -@app.get('/pyfilters/{filter_id}', response_model=PyFilterModel) -async def get_pyfilter_by_id(filter_id: int): +@app.get('/pyfilters/{filter_name}', response_model=PyFilterModel) +async def get_pyfilter_by_id(filter_name: str): """Get pyfilter info using his id""" res = db.query(""" SELECT - filter_id, name, blocked_packets, edited_packets, active - FROM pyfilter WHERE filter_id = ?; - """, filter_id) + name, blocked_packets, edited_packets, active + FROM pyfilter WHERE name = ?; + """, filter_name) if len(res) == 0: raise HTTPException(status_code=400, detail="This filter does not exists!") return res[0] -@app.post('/pyfilters/{filter_id}/enable', response_model=StatusMessageModel) -async def pyfilter_enable(filter_id: int): +@app.post('/pyfilters/{filter_name}/enable', response_model=StatusMessageModel) +async def pyfilter_enable(filter_name: str): """Request the enabling of a pyfilter""" - res = db.query('SELECT * FROM pyfilter WHERE filter_id = ?;', filter_id) + res = db.query('SELECT * FROM pyfilter WHERE name = ?;', filter_name) if len(res) != 0: - db.query('UPDATE pyfilter SET active=1 WHERE filter_id = ?;', filter_id) + db.query('UPDATE pyfilter SET active=1 WHERE name = ?;', filter_name) await firewall.get(res[0]["service_id"]).update_filters() await refresh_frontend() return {'status': 'ok'} -@app.post('/pyfilters/{filter_id}/disable', response_model=StatusMessageModel) -async def pyfilter_disable(filter_id: int): +@app.post('/pyfilters/{filter_name}/disable', response_model=StatusMessageModel) +async def pyfilter_disable(filter_name: str): """Request the deactivation of a pyfilter""" - res = db.query('SELECT * FROM pyfilter WHERE filter_id = ?;', filter_id) + res = db.query('SELECT * FROM pyfilter WHERE name = ?;', filter_name) if len(res) != 0: - db.query('UPDATE pyfilter SET active=0 WHERE filter_id = ?;', filter_id) + db.query('UPDATE pyfilter SET active=0 WHERE name = ?;', filter_name) await firewall.get(res[0]["service_id"]).update_filters() await refresh_frontend() return {'status': 'ok'} @@ -293,8 +293,8 @@ async def add_new_service(form: ServiceAddForm): srv_id = None try: srv_id = gen_service_id() - db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int, fail_open) VALUES (?, ?, ?, ?, ?, ?, ?)", - srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int, form.fail_open) + db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int, fail_open, l4_proto) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int, form.fail_open, convert_protocol_to_l4(form.proto)) except sqlite3.IntegrityError: raise HTTPException(status_code=400, detail="This type of service already exists") await firewall.reload() @@ -308,29 +308,41 @@ async def set_pyfilters(service_id: str, form: SetPyFilterForm): if len(service) == 0: raise HTTPException(status_code=400, detail="This service does not exists!") service = service[0] + service_id = service["service_id"] srv_proto = service["proto"] + try: - found_filters = get_filter_names(form.code, srv_proto) - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) + async with asyncio.timeout(8): + try: + found_filters = get_filter_names(form.code, srv_proto) + except Exception as e: + if DEBUG: + traceback.print_exc() + raise HTTPException(status_code=400, detail="Compile error: "+str(e)) + + # Remove filters that are not in the new code + existing_filters = db.query("SELECT name FROM pyfilter WHERE service_id = ?;", service_id) + existing_filters = [ele["name"] for ele in existing_filters] + for filter in existing_filters: + if filter not in found_filters: + db.query("DELETE FROM pyfilter WHERE name = ?;", filter) + + # Add filters that are in the new code but not in the database + for filter in found_filters: + if not db.query("SELECT 1 FROM pyfilter WHERE service_id = ? AND name = ?;", service_id, filter): + db.query("INSERT INTO pyfilter (name, service_id) VALUES (?, ?);", filter, service["service_id"]) + + # Eventually edited filters will be reloaded + os.makedirs("db/nfproxy_filters", exist_ok=True) + with open(f"db/nfproxy_filters/{service_id}.py", "w") as f: + f.write(form.code) + await firewall.get(service_id).update_filters() + await refresh_frontend() + except asyncio.TimeoutError: + if DEBUG: + traceback.print_exc() + raise HTTPException(status_code=400, detail="The operation took too long") - # Remove filters that are not in the new code - existing_filters = db.query("SELECT filter_id FROM pyfilter WHERE service_id = ?;", service_id) - for filter in existing_filters: - if filter["name"] not in found_filters: - db.query("DELETE FROM pyfilter WHERE filter_id = ?;", filter["filter_id"]) - - # Add filters that are in the new code but not in the database - for filter in found_filters: - if not db.query("SELECT 1 FROM pyfilter WHERE service_id = ? AND name = ?;", service_id, filter): - db.query("INSERT INTO pyfilter (name, service_id) VALUES (?, ?);", filter, service["service_id"]) - - # Eventually edited filters will be reloaded - os.makedirs("db/nfproxy_filters", exist_ok=True) - with open(f"db/nfproxy_filters/{service_id}.py", "w") as f: - f.write(form.code) - await firewall.get(service_id).update_filters() - await refresh_frontend() return {'status': 'ok'} @app.get('/services/{service_id}/pyfilters/code', response_class=PlainTextResponse) @@ -343,7 +355,3 @@ async def get_pyfilters(service_id: str): return f.read() except FileNotFoundError: return "" - -#TODO check all the APIs and add -# 1. API to change the python filter file (DONE) -# 2. a socketio mechanism to lock the previous feature \ No newline at end of file diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py index 1d9c23a..c4fc13d 100644 --- a/backend/utils/__init__.py +++ b/backend/utils/__init__.py @@ -8,15 +8,22 @@ import nftables from socketio import AsyncServer from fastapi import Path from typing import Annotated +from functools import wraps +from pydantic import BaseModel, ValidationError +import traceback +from utils.models import StatusMessageModel +from typing import List LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) socketio:AsyncServer = None +sid_list:set = set() ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) ROUTERS_DIR = os.path.join(ROOT_DIR,"routers") ON_DOCKER = "DOCKER" in sys.argv DEBUG = "DEBUG" in sys.argv +NORELOAD = "NORELOAD" in sys.argv FIREGEX_PORT = int(os.getenv("PORT","4444")) JWT_ALGORITHM: str = "HS256" API_VERSION = "{{VERSION_PLACEHOLDER}}" if "{" not in "{{VERSION_PLACEHOLDER}}" else "0.0.0" @@ -153,4 +160,50 @@ class NFTableManager(Singleton): def raw_list(self): return self.cmd({"list": {"ruleset": None}})["nftables"] - +def _json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json"): + res = obj.model_dump(mode=mode, exclude_unset=not unset) + if convert_keys: + for from_k, to_k in convert_keys.items(): + if from_k in res: + res[to_k] = res.pop(from_k) + if exclude: + for ele in exclude: + if ele in res: + del res[ele] + return res + +def json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json") -> dict: + if isinstance(obj, list): + return [_json_like(ele, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode) for ele in obj] + return _json_like(obj, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode) + +def register_event(sio_server: AsyncServer, event_name: str, model: BaseModel, response_model: BaseModel|None = None): + def decorator(func): + @sio_server.on(event_name) # Automatically registers the event + @wraps(func) + async def wrapper(sid, data): + try: + # Parse and validate incoming data + parsed_data = model.model_validate(data) + except ValidationError: + return json_like(StatusMessageModel(status=f"Invalid {event_name} request")) + + # Call the original function with the parsed data + result = await func(sid, parsed_data) + # If a response model is provided, validate the output + if response_model: + try: + parsed_result = response_model.model_validate(result) + except ValidationError: + traceback.print_exc() + return json_like(StatusMessageModel(status=f"SERVER ERROR: Invalid {event_name} response")) + else: + parsed_result = result + # Emit the validated result + if parsed_result: + if isinstance(parsed_result, BaseModel): + return json_like(parsed_result) + return parsed_result + return wrapper + return decorator + diff --git a/backend/utils/loader.py b/backend/utils/loader.py index 435c8c2..5e5dd32 100644 --- a/backend/utils/loader.py +++ b/backend/utils/loader.py @@ -7,6 +7,7 @@ from starlette.responses import StreamingResponse from fastapi.responses import FileResponse from utils import DEBUG, ON_DOCKER, ROUTERS_DIR, list_files, run_func from utils.models import ResetRequest +import asyncio REACT_BUILD_DIR: str = "../frontend/build/" if not ON_DOCKER else "frontend/" REACT_HTML_PATH: str = os.path.join(REACT_BUILD_DIR,"index.html") @@ -87,12 +88,9 @@ def load_routers(app): if router.shutdown: shutdowns.append(router.shutdown) async def reset(reset_option:ResetRequest): - for func in resets: - await run_func(func, reset_option) + await asyncio.gather(*[run_func(func, reset_option) for func in resets]) async def startup(): - for func in startups: - await run_func(func) + await asyncio.gather(*[run_func(func) for func in startups]) async def shutdown(): - for func in shutdowns: - await run_func(func) + await asyncio.gather(*[run_func(func) for func in shutdowns]) return reset, startup, shutdown diff --git a/docs/FiregexInternals.png b/docs/FiregexInternals.png index 6a19f3c..7ecc3e0 100644 Binary files a/docs/FiregexInternals.png and b/docs/FiregexInternals.png differ diff --git a/docs/Firegex_Screenshot.png b/docs/Firegex_Screenshot.png index 935d573..2532aba 100644 Binary files a/docs/Firegex_Screenshot.png and b/docs/Firegex_Screenshot.png differ diff --git a/fgex-lib/firegex/nfproxy/__init__.py b/fgex-lib/firegex/nfproxy/__init__.py index 40d6559..8948e4c 100644 --- a/fgex-lib/firegex/nfproxy/__init__.py +++ b/fgex-lib/firegex/nfproxy/__init__.py @@ -1,11 +1,23 @@ import functools +from firegex.nfproxy.params import RawPacket +from enum import Enum -ACCEPT = 0 -DROP = 1 -REJECT = 2 -MANGLE = 3 -EXCEPTION = 4 -INVALID = 5 +class Action(Enum): + ACCEPT = 0 + DROP = 1 + REJECT = 2 + MANGLE = 3 + +class FullStreamAction(Enum): + FLUSH = 0 + ACCEPT = 1 + REJECT = 2 + DROP = 3 + +ACCEPT = Action.ACCEPT +DROP = Action.DROP +REJECT = Action.REJECT +MANGLE = Action.MANGLE def pyfilter(func): """ @@ -27,12 +39,14 @@ def get_pyfilters(): """Returns the list of functions marked with @pyfilter.""" return list(pyfilter.registry) +def clear_pyfilter_registry(): + """Clears the pyfilter registry.""" + if hasattr(pyfilter, "registry"): + pyfilter.registry.clear() - - - - - - - - +__all__ = [ + "ACCEPT", "DROP", "REJECT", "MANGLE", "EXCEPTION", "INVALID", + "Action", "FullStreamAction", + "pyfilter", + "RawPacket" +] \ No newline at end of file diff --git a/fgex-lib/firegex/nfproxy/internals.py b/fgex-lib/firegex/nfproxy/internals.py index fb9fa98..cfa9169 100644 --- a/fgex-lib/firegex/nfproxy/internals.py +++ b/fgex-lib/firegex/nfproxy/internals.py @@ -1,21 +1,7 @@ from inspect import signature from firegex.nfproxy.params import RawPacket, NotReadyToRun -from firegex.nfproxy import ACCEPT, DROP, REJECT, MANGLE, EXCEPTION, INVALID - -RESULTS = [ - ACCEPT, - DROP, - REJECT, - MANGLE, - EXCEPTION, - INVALID -] -FULL_STREAM_ACTIONS = [ - "flush" - "accept", - "reject", - "drop" -] +from firegex.nfproxy import Action, FullStreamAction +from dataclasses import dataclass, field type_annotations_associations = { "tcp": { @@ -26,136 +12,178 @@ type_annotations_associations = { } } -def _generate_filter_structure(filters: list[str], proto:str, glob:dict, local:dict): +@dataclass +class FilterHandler: + func: callable + name: str + params: dict[type, callable] + proto: str + +class internal_data: + filter_call_info: list[FilterHandler] = [] + stream: list[RawPacket] = [] + stream_size: int = 0 + stream_max_size: int = 1*8e20 + full_stream_action: str = "flush" + filter_glob: dict = {} + +@dataclass +class PacketHandlerResult: + glob: dict = field(repr=False) + action: Action = Action.ACCEPT + matched_by: str = None + mangled_packet: bytes = None + + def set_result(self) -> None: + self.glob["__firegex_pyfilter_result"] = { + "action": self.action.value, + "matched_by": self.matched_by, + "mangled_packet": self.mangled_packet + } + + def reset_result(self) -> None: + self.glob["__firegex_pyfilter_result"] = None + +def context_call(func, *args, **kargs): + internal_data.filter_glob["__firegex_tmp_args"] = args + internal_data.filter_glob["__firegex_tmp_kargs"] = kargs + internal_data.filter_glob["__firege_tmp_call"] = func + res = eval("__firege_tmp_call(*__firegex_tmp_args, **__firegex_tmp_kargs)", internal_data.filter_glob, internal_data.filter_glob) + del internal_data.filter_glob["__firegex_tmp_args"] + del internal_data.filter_glob["__firegex_tmp_kargs"] + del internal_data.filter_glob["__firege_tmp_call"] + return res + +def generate_filter_structure(filters: list[str], proto:str, glob:dict) -> list[FilterHandler]: if proto not in type_annotations_associations.keys(): raise Exception("Invalid protocol") - res = [] - valid_annotation_type = type_annotations_associations[proto] def add_func_to_list(func): if not callable(func): raise Exception(f"{func} is not a function") sig = signature(func) - params_function = [] + params_function = {} for k, v in sig.parameters.items(): if v.annotation in valid_annotation_type.keys(): - params_function.append((v.annotation, valid_annotation_type[v.annotation])) + params_function[v.annotation] = valid_annotation_type[v.annotation] else: raise Exception(f"Invalid type annotation {v.annotation} for function {func.__name__}") - res.append((func, params_function)) + + res.append( + FilterHandler( + func=func, + name=func.__name__, + params=params_function, + proto=proto + ) + ) for filter in filters: if not isinstance(filter, str): raise Exception("Invalid filter list: must be a list of strings") if filter in glob.keys(): add_func_to_list(glob[filter]) - elif filter in local.keys(): - add_func_to_list(local[filter]) else: raise Exception(f"Filter {filter} not found") - return res -def get_filters_info(code:str, proto:str): +def get_filters_info(code:str, proto:str) -> list[FilterHandler]: glob = {} - local = {} - exec(code, glob, local) - exec("import firegex.nfproxy", glob, local) - filters = eval("firegex.nfproxy.get_pyfilters()", glob, local) - return _generate_filter_structure(filters, proto, glob, local) + exec(code, glob, glob) + exec("import firegex.nfproxy", glob, glob) + filters = eval("firegex.nfproxy.get_pyfilters()", glob, glob) + try: + return generate_filter_structure(filters, proto, glob) + finally: + exec("firegex.nfproxy.clear_pyfilter_registry()", glob, glob) + -def get_filter_names(code:str, proto:str): - return [ele[0].__name__ for ele in get_filters_info(code, proto)] +def get_filter_names(code:str, proto:str) -> list[str]: + return [ele.name for ele in get_filters_info(code, proto)] -def compile(): - glob = globals() - local = locals() - filters = glob["__firegex_pyfilter_enabled"] - proto = glob["__firegex_proto"] - glob["__firegex_func_list"] = _generate_filter_structure(filters, proto, glob, local) - glob["__firegex_stream"] = [] - glob["__firegex_stream_size"] = 0 +def handle_packet() -> None: + cache_call = {} # Cache of the data handler calls - if "FGEX_STREAM_MAX_SIZE" in local and int(local["FGEX_STREAM_MAX_SIZE"]) > 0: - glob["__firegex_stream_max_size"] = int(local["FGEX_STREAM_MAX_SIZE"]) - elif "FGEX_STREAM_MAX_SIZE" in glob and int(glob["FGEX_STREAM_MAX_SIZE"]) > 0: - glob["__firegex_stream_max_size"] = int(glob["FGEX_STREAM_MAX_SIZE"]) - else: - glob["__firegex_stream_max_size"] = 1*8e20 # 1MB default value + pkt_info = RawPacket.fetch_from_global(internal_data.filter_glob) + cache_call[RawPacket] = pkt_info - if "FGEX_FULL_STREAM_ACTION" in local and local["FGEX_FULL_STREAM_ACTION"] in FULL_STREAM_ACTIONS: - glob["__firegex_full_stream_action"] = local["FGEX_FULL_STREAM_ACTION"] - else: - glob["__firegex_full_stream_action"] = "flush" + final_result = Action.ACCEPT + data_size = len(pkt_info.data) + + result = PacketHandlerResult(internal_data.filter_glob) + + if internal_data.stream_size+data_size > internal_data.stream_max_size: + match internal_data.full_stream_action: + case FullStreamAction.FLUSH: + internal_data.stream = [] + internal_data.stream_size = 0 + case FullStreamAction.ACCEPT: + result.action = Action.ACCEPT + return result.set_result() + case FullStreamAction.REJECT: + result.action = Action.REJECT + result.matched_by = "@MAX_STREAM_SIZE_REACHED" + return result.set_result() + case FullStreamAction.REJECT: + result.action = Action.DROP + result.matched_by = "@MAX_STREAM_SIZE_REACHED" + return result.set_result() + + internal_data.stream.append(pkt_info) + internal_data.stream_size += data_size - glob["__firegex_pyfilter_result"] = None - -def handle_packet(): - glob = globals() - func_list = glob["__firegex_func_list"] - final_result = ACCEPT - cache_call = {} - cache_call[RawPacket] = RawPacket.fetch_from_global() - data_size = len(cache_call[RawPacket].data) - if glob["__firegex_stream_size"]+data_size > glob["__firegex_stream_max_size"]: - match glob["__firegex_full_stream_action"]: - case "flush": - glob["__firegex_stream"] = [] - glob["__firegex_stream_size"] = 0 - case "accept": - glob["__firegex_pyfilter_result"] = { - "action": ACCEPT, - "matched_by": None, - "mangled_packet": None - } - return - case "reject": - glob["__firegex_pyfilter_result"] = { - "action": REJECT, - "matched_by": "@MAX_STREAM_SIZE_REACHED", - "mangled_packet": None - } - return - case "drop": - glob["__firegex_pyfilter_result"] = { - "action": DROP, - "matched_by": "@MAX_STREAM_SIZE_REACHED", - "mangled_packet": None - } - return - glob["__firegex_stream"].append(cache_call[RawPacket]) - glob["__firegex_stream_size"] += data_size func_name = None mangled_packet = None - for filter in func_list: + for filter in internal_data.filter_call_info: final_params = [] - for ele in filter[1]: - if ele[0] not in cache_call.keys(): + for data_type, data_func in filter.params.items(): + if data_type not in cache_call.keys(): try: - cache_call[ele[0]] = ele[1]() + cache_call[data_type] = data_func(internal_data.filter_glob) except NotReadyToRun: - cache_call[ele[0]] = None - if cache_call[ele[0]] is None: + cache_call[data_type] = None + if cache_call[data_type] is None: continue # Parsing raised NotReadyToRun, skip filter - final_params.append(cache_call[ele[0]]) - res = filter[0](*final_params) + final_params.append(cache_call[data_type]) + + res = context_call(filter.func, *final_params) + if res is None: continue #ACCEPTED - if res == MANGLE: - if RawPacket not in cache_call.keys(): - continue #Packet not modified - pkt:RawPacket = cache_call[RawPacket] - mangled_packet = pkt.raw_packet - break - elif res != ACCEPT: + if not isinstance(res, Action): + raise Exception(f"Invalid return type {type(res)} for function {filter.name}") + if res == Action.MANGLE: + mangled_packet = pkt_info.raw_packet + if res != Action.ACCEPT: + func_name = filter.name final_result = res - func_name = filter[0].__name__ break - glob["__firegex_pyfilter_result"] = { - "action": final_result, - "matched_by": func_name, - "mangled_packet": mangled_packet - } + + result.action = final_result + result.matched_by = func_name + result.mangled_packet = mangled_packet + + return result.set_result() + +def compile(glob:dict) -> None: + internal_data.filter_glob = glob + + filters = glob["__firegex_pyfilter_enabled"] + proto = glob["__firegex_proto"] + + internal_data.filter_call_info = generate_filter_structure(filters, proto, glob) + + if "FGEX_STREAM_MAX_SIZE" in glob and int(glob["FGEX_STREAM_MAX_SIZE"]) > 0: + internal_data.stream_max_size = int(glob["FGEX_STREAM_MAX_SIZE"]) + else: + internal_data.stream_max_size = 1*8e20 # 1MB default value + + if "FGEX_FULL_STREAM_ACTION" in glob and isinstance(glob["FGEX_FULL_STREAM_ACTION"], FullStreamAction): + internal_data.full_stream_action = glob["FGEX_FULL_STREAM_ACTION"] + else: + internal_data.full_stream_action = FullStreamAction.FLUSH + + PacketHandlerResult(glob).reset_result() diff --git a/fgex-lib/firegex/nfproxy/params.py b/fgex-lib/firegex/nfproxy/params.py index e2969b5..025ff3e 100644 --- a/fgex-lib/firegex/nfproxy/params.py +++ b/fgex-lib/firegex/nfproxy/params.py @@ -9,12 +9,15 @@ class RawPacket: is_input: bool, is_ipv6: bool, is_tcp: bool, + l4_size: int, ): self.__data = bytes(data) self.__raw_packet = bytes(raw_packet) self.__is_input = bool(is_input) self.__is_ipv6 = bool(is_ipv6) self.__is_tcp = bool(is_tcp) + self.__l4_size = int(l4_size) + self.__raw_packet_header_size = len(self.__raw_packet)-self.__l4_size @property def is_input(self) -> bool: @@ -33,19 +36,25 @@ class RawPacket: return self.__data @property - def proto_header(self) -> bytes: - return self.__raw_packet[:self.proto_header_len] + def l4_size(self) -> int: + return self.__l4_size @property - def proto_header_len(self) -> int: - return len(self.__raw_packet) - len(self.__data) + def raw_packet_header_len(self) -> int: + return self.__raw_packet_header_size - @data.setter - def data(self, v:bytes): + @property + def l4_data(self) -> bytes: + return self.__raw_packet[self.raw_packet_header_len:] + + @l4_data.setter + def l4_data(self, v:bytes): if not isinstance(v, bytes): raise Exception("Invalid data type, data MUST be of type bytes") - self.__raw_packet = self.proto_header + v - self.__data = v + #if len(v) != self.__l4_size: + # raise Exception("Invalid data size, must be equal to the original packet header size (due to a technical limitation)") + self.__raw_packet = self.__raw_packet[:self.raw_packet_header_len]+v + self.__l4_size = len(v) @property def raw_packet(self) -> bytes: @@ -55,17 +64,16 @@ class RawPacket: def raw_packet(self, v:bytes): if not isinstance(v, bytes): raise Exception("Invalid data type, data MUST be of type bytes") - if len(v) < self.proto_header_len: - raise Exception("Invalid packet length") - header_len = self.proto_header_len - self.__data = v[header_len:] + #if len(v) != len(self.__raw_packet): + # raise Exception("Invalid data size, must be equal to the original packet size (due to a technical limitation)") + if len(v) < self.raw_packet_header_len: + raise Exception("Invalid data size, must be greater than the original packet header size") self.__raw_packet = v + self.__l4_size = len(v)-self.raw_packet_header_len - @staticmethod - def fetch_from_global(): - glob = globals() + @classmethod + def fetch_from_global(cls, glob): if "__firegex_packet_info" not in glob.keys(): raise Exception("Packet info not found") - return RawPacket(**glob["__firegex_packet_info"]) - + return cls(**glob["__firegex_packet_info"]) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c13dd32..a988fdc 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -15,8 +15,21 @@ import { useQueryClient } from '@tanstack/react-query'; import NFProxy from './pages/NFProxy'; import ServiceDetailsNFProxy from './pages/NFProxy/ServiceDetails'; - -const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket"], path:"/sock/socket.io" }):io({transports: ["websocket"], path:"/sock/socket.io"}); +export const socket = import.meta.env.DEV? + io("ws://"+DEV_IP_BACKEND, { + path:"/sock/socket.io", + transports: ['websocket'], + auth: { + token: localStorage.getItem("access_token") + } + }): + io({ + path:"/sock/socket.io", + transports: ['websocket'], + auth: { + token: localStorage.getItem("access_token") + } + }) function App() { @@ -25,33 +38,20 @@ function App() { const [reqError, setReqError] = useState() const [error, setError] = useState() const [loadinBtn, setLoadingBtn] = useState(false); - const queryClient = useQueryClient() + const getStatus = () =>{ - getstatus().then( res =>{ - setSystemStatus(res) - setReqError(undefined) - setLoading(false) - }).catch(err=>{ - setReqError(err.toString()) - setLoading(false) - setTimeout(getStatus, 500) - }) + getstatus().then( res =>{ + setSystemStatus(res) + setReqError(undefined) + }).catch(err=>{ + setReqError(err.toString()) + setTimeout(getStatus, 500) + }).finally( ()=>setLoading(false) ) } useEffect(()=>{ getStatus() - socket.on("update", (data) => { - queryClient.invalidateQueries({ queryKey: data }) - }) - socket.on("connect_error", (err) => { - errorNotify("Socket.Io connection failed! ",`Error message: [${err.message}]`) - getStatus() - }); - return () => { - socket.off("update") - socket.off("connect_error") - } },[]) const form = useForm({ @@ -145,19 +145,7 @@ function App() { :null} }else if (systemStatus.status === "run" && systemStatus.loggined){ - return - }> - } > - } /> - - } > - } /> - - } /> - } /> - } /> - - + return }else{ return Error launching Firegex! 🔥 @@ -167,4 +155,41 @@ function App() { } } +const PageRouting = ({ getStatus }:{ getStatus:()=>void }) => { + + const queryClient = useQueryClient() + + + useEffect(()=>{ + getStatus() + socket.on("update", (data) => { + queryClient.invalidateQueries({ queryKey: data }) + }) + socket.on("connect_error", (err) => { + errorNotify("Socket.Io connection failed! ",`Error message: [${err.message}]`) + getStatus() + }); + return () => { + socket.off("update") + socket.off("connect_error") + } +},[]) + + return + }> + } > + } /> + + } > + } /> + + } /> + } /> + } /> + + +} + + + export default App; diff --git a/frontend/src/components/NFProxy/AddEditService.tsx b/frontend/src/components/NFProxy/AddEditService.tsx index d9287e2..daa5585 100644 --- a/frontend/src/components/NFProxy/AddEditService.tsx +++ b/frontend/src/components/NFProxy/AddEditService.tsx @@ -26,7 +26,7 @@ function AddEditService({ opened, onClose, edit }:{ opened:boolean, onClose:()=> validate:{ name: (value) => edit? null : value !== "" ? null : "Service name is required", port: (value) => (value>0 && value<65536) ? null : "Invalid port", - proto: (value) => ["tcp","udp"].includes(value) ? null : "Invalid protocol", + proto: (value) => ["tcp","http"].includes(value) ? null : "Invalid protocol", ip_int: (value) => (value.match(regex_ipv6) || value.match(regex_ipv4)) ? null : "Invalid IP address", } }) @@ -50,7 +50,7 @@ function AddEditService({ opened, onClose, edit }:{ opened:boolean, onClose:()=> const submitRequest = ({ name, port, autostart, proto, ip_int, fail_open }:ServiceAddForm) =>{ setSubmitLoading(true) if (edit){ - nfproxy.settings(edit.service_id, { port, proto, ip_int, fail_open }).then( res => { + nfproxy.settings(edit.service_id, { port, ip_int, fail_open }).then( res => { if (!res){ setSubmitLoading(false) close(); @@ -111,13 +111,13 @@ function AddEditService({ opened, onClose, edit }:{ opened:boolean, onClose:()=> /> - + />} diff --git a/frontend/src/components/NFProxy/UploadFilterModal.tsx b/frontend/src/components/NFProxy/UploadFilterModal.tsx new file mode 100644 index 0000000..1208393 --- /dev/null +++ b/frontend/src/components/NFProxy/UploadFilterModal.tsx @@ -0,0 +1,54 @@ +import { Button, FileButton, Group, Modal, Notification, Space } from "@mantine/core"; +import { nfproxy, Service } from "./utils"; +import { useEffect, useState } from "react"; +import { ImCross } from "react-icons/im"; +import { okNotify } from "../../js/utils"; + +export const UploadFilterModal = ({ opened, onClose, service }: { opened: boolean, onClose: () => void, service?: Service }) => { + const close = () =>{ + onClose() + setError(null) + } + + const [submitLoading, setSubmitLoading] = useState(false) + const [error, setError] = useState(null) + const [file, setFile] = useState(null); + + useEffect(() => { + if (opened && file){ + file.bytes().then( code => { + console.log(code.toString()) + setSubmitLoading(true) + nfproxy.setpyfilterscode(service?.service_id??"",code.toString()).then( res => { + if (!res){ + setSubmitLoading(false) + close(); + okNotify(`Service ${name} code updated`, `Successfully updated code for service ${name}`) + } + }).catch( err => { + setSubmitLoading(false) + setError("Error: "+err) + }) + }) + } + }, [opened, file]) + + return + + + + {(props) => } + + + + {error?<> + + } color="red" onClose={()=>{setError(null)}}> + Error: {error} + + :null} + + + + +} \ No newline at end of file diff --git a/frontend/src/components/NFProxy/utils.ts b/frontend/src/components/NFProxy/utils.ts index 492f560..f5287ca 100644 --- a/frontend/src/components/NFProxy/utils.ts +++ b/frontend/src/components/NFProxy/utils.ts @@ -25,7 +25,6 @@ export type ServiceAddForm = { export type ServiceSettings = { port?:number, - proto?:string, ip_int?:string, fail_open?: boolean, } @@ -55,12 +54,12 @@ export const nfproxy = { serviceinfo: async (service_id:string) => { return await getapi(`nfproxy/services/${service_id}`) as Service; }, - pyfilterenable: async (regex_id:number) => { - const { status } = await postapi(`nfproxy/pyfilters/${regex_id}/enable`) as ServerResponse; + pyfilterenable: async (filter_name:string) => { + const { status } = await postapi(`nfproxy/pyfilters/${filter_name}/enable`) as ServerResponse; return status === "ok"?undefined:status }, - pyfilterdisable: async (regex_id:number) => { - const { status } = await postapi(`nfproxy/pyfilters/${regex_id}/disable`) as ServerResponse; + pyfilterdisable: async (filter_name:string) => { + const { status } = await postapi(`nfproxy/pyfilters/${filter_name}/disable`) as ServerResponse; return status === "ok"?undefined:status }, servicestart: async (service_id:string) => { diff --git a/frontend/src/components/PyFilterView/index.tsx b/frontend/src/components/PyFilterView/index.tsx index 9a16108..3602029 100644 --- a/frontend/src/components/PyFilterView/index.tsx +++ b/frontend/src/components/PyFilterView/index.tsx @@ -1,7 +1,7 @@ import { Text, Badge, Space, ActionIcon, Tooltip, Box } from '@mantine/core'; import { useState } from 'react'; import { PyFilter } from '../../js/models'; -import { errorNotify, okNotify } from '../../js/utils'; +import { errorNotify, isMediumScreen, okNotify } from '../../js/utils'; import { FaPause, FaPlay } from 'react-icons/fa'; import { FaFilter } from "react-icons/fa"; import { nfproxy } from '../NFProxy/utils'; @@ -9,42 +9,39 @@ import { FaPencilAlt } from 'react-icons/fa'; export default function PyFilterView({ filterInfo }:{ filterInfo:PyFilter }) { - const [deleteTooltipOpened, setDeleteTooltipOpened] = useState(false); const [statusTooltipOpened, setStatusTooltipOpened] = useState(false); + const isMedium = isMediumScreen() const changeRegexStatus = () => { - (filterInfo.active?nfproxy.pyfilterdisable:nfproxy.pyfilterenable)(filterInfo.filter_id).then(res => { + (filterInfo.active?nfproxy.pyfilterdisable:nfproxy.pyfilterenable)(filterInfo.name).then(res => { if(!res){ - okNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivated":"activated"} successfully!`,`Filter with id '${filterInfo.filter_id}' has been ${filterInfo.active?"deactivated":"activated"}!`) + okNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivated":"activated"} successfully!`,`Filter '${filterInfo.name}' has been ${filterInfo.active?"deactivated":"activated"}!`) }else{ errorNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivation":"activation"} failed!`,`Error: ${res}`) } }).catch( err => errorNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivation":"activation"} failed!`,`Error: ${err}`)) } - return - - - - {filterInfo.name} - - - - setStatusTooltipOpened(false)} onBlur={() => setStatusTooltipOpened(false)} - onMouseEnter={() => setStatusTooltipOpened(true)} onMouseLeave={() => setStatusTooltipOpened(false)} - >{filterInfo.active?:} - - - - {filterInfo.blocked_packets} - - {filterInfo.edited_packets} - - {filterInfo.active?"ACTIVE":"DISABLED"} - - - - + return + + + + {filterInfo.name} + + + {isMedium?<> + {filterInfo.blocked_packets} + + {filterInfo.edited_packets} + + :null} + + setStatusTooltipOpened(false)} onBlur={() => setStatusTooltipOpened(false)} + onMouseEnter={() => setStatusTooltipOpened(true)} onMouseLeave={() => setStatusTooltipOpened(false)} + >{filterInfo.active?:} + + + } diff --git a/frontend/src/index.css b/frontend/src/index.css index df910b8..192a8f7 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -96,6 +96,20 @@ body { opacity: 0.8; } +.firegex__regexview__pyfilter_text{ + padding: 6px; + padding-left: 15px; + padding-right: 15px; + background-color: var(--fourth_color); + border: 1px solid #444; + overflow-x: hidden; + border-radius: 8px; +} + +.firegex__regexview__pyfilter_text:hover{ + overflow-x: auto; +} + .firegex__porthijack__servicerow__row{ width: 95%; padding: 15px 0px; diff --git a/frontend/src/js/models.ts b/frontend/src/js/models.ts index 6f992f4..1c1a128 100644 --- a/frontend/src/js/models.ts +++ b/frontend/src/js/models.ts @@ -51,7 +51,6 @@ export type RegexAddForm = { } export type PyFilter = { - filter_id:number, name:string, blocked_packets:number, edited_packets:number, diff --git a/frontend/src/js/utils.tsx b/frontend/src/js/utils.tsx index 55b1493..39cbdc9 100644 --- a/frontend/src/js/utils.tsx +++ b/frontend/src/js/utils.tsx @@ -72,9 +72,14 @@ export async function genericapi(method:string,path:string,data:any = undefined, const errorDefault = res.statusText return res.json().then( res => reject(getErrorMessageFromServerResponse(res, errorDefault)) ).catch( _err => reject(errorDefault)) } - res.json().then( res => resolve(res) ).catch( err => reject(err)) - }) - .catch(err => { + res.text().then(t => { + try{ + resolve(JSON.parse(t)) + }catch(e){ + resolve(t) + } + }).catch( err => reject(err)) + }).catch(err => { reject(err) }) }); diff --git a/frontend/src/pages/NFProxy/ServiceDetails.tsx b/frontend/src/pages/NFProxy/ServiceDetails.tsx index 95117f8..855900e 100644 --- a/frontend/src/pages/NFProxy/ServiceDetails.tsx +++ b/frontend/src/pages/NFProxy/ServiceDetails.tsx @@ -162,22 +162,22 @@ export default function ServiceDetailsNFProxy() { + + {filterCode.data?<> <FaPython style={{ marginBottom: -3 }} size={30} /><Space w="xs" />Filter code : null} - + {(!filtersList.data || filtersList.data.length == 0)?<> + No filters found! Edit the proxy file Install the firegex client:<Space w="xs" /><Code mb={-4} >pip install fgex</Code> Then run the command:<Space w="xs" /><Code mb={-4} >fgex nfproxy</Code> - : - - {filtersList.data?.map( (filterInfo) => )} - + :<>{filtersList.data?.map( (filterInfo) => )} } { + if (files?.length??0 > 0) + setFile(files![0]) + } + }); + const [file, setFile] = useState(null); + useEffect(() => { + if (!srv) return + const service = services.data?.find(s => s.service_id === srv) + if (!service) return + if (file){ + console.log("Uploading code") + const notify_id = notifications.show( + { + title: "Uploading code", + message: `Uploading code for service ${service.name}`, + color: "blue", + icon: , + autoClose: false, + loading: true, + } + ) + file.text() + .then( code => nfproxy.setpyfilterscode(service?.service_id??"",code.toString())) + .then( res => { + if (!res){ + notifications.update({ + id: notify_id, + title: "Code uploaded", + message: `Successfully uploaded code for service ${service.name}`, + color: "green", + icon: , + autoClose: 5000, + loading: false, + }) + }else{ + notifications.update({ + id: notify_id, + title: "Code upload failed", + message: `Error: ${res}`, + color: "red", + icon: , + autoClose: 5000, + loading: false, + }) + } + }).catch( err => { + notifications.update({ + id: notify_id, + title: "Code upload failed", + message: `Error: ${err}`, + color: "red", + icon: , + autoClose: 5000, + loading: false, + }) + }).finally(()=>{setFile(null)}) + } + }, [file]) useEffect(()=> { if(services.isError) @@ -37,7 +102,7 @@ export default function NFProxy({ children }: { children: any }) { <ThemeIcon radius="md" size="md" variant='filled' color='lime' ><TbPlugConnected size={20} /></ThemeIcon><Space w="xs" />Netfilter Proxy {isMedium?:} - General stats: + {isMedium?"General stats:":null} Services: {services.isLoading?0:services.data?.length} @@ -50,8 +115,16 @@ export default function NFProxy({ children }: { children: any }) { {isMedium?null:} - {/* Will become the null a button to edit the source code? TODO */} - { srv?null + { srv? + + setTooltipAddOpened(false)} onBlur={() => setTooltipAddOpened(false)} + onMouseEnter={() => setTooltipAddOpened(true)} + onMouseLeave={() => setTooltipAddOpened(false)} onClick={fileDialog.open}> + + + : setOpen(true)} size="lg" radius="md" variant="filled" onFocus={() => setTooltipAddOpened(false)} onBlur={() => setTooltipAddOpened(false)} @@ -85,7 +158,9 @@ export default function NFProxy({ children }: { children: any }) { } {srv?children:null} - {!srv?:null} + {!srv? + :null + } } diff --git a/frontend/src/pages/NFRegex/index.tsx b/frontend/src/pages/NFRegex/index.tsx index 19b0905..83ec802 100644 --- a/frontend/src/pages/NFRegex/index.tsx +++ b/frontend/src/pages/NFRegex/index.tsx @@ -38,7 +38,7 @@ function NFRegex({ children }: { children: any }) { <ThemeIcon radius="md" size="md" variant='filled' color='grape' ><BsRegex size={20} /></ThemeIcon><Space w="xs" />Netfilter Regex {isMedium?:} - General stats: + {isMedium?"General stats:":null} Services: {services.isLoading?0:services.data?.length}