push: code changes

This commit is contained in:
Domingo Dirutigliano
2025-02-25 23:53:04 +01:00
parent 8652f40235
commit 6a11dd0d16
37 changed files with 1306 additions and 640 deletions

1
.gitignore vendored
View File

@@ -13,6 +13,7 @@
/frontend/coverage /frontend/coverage
/fgex-lib/firegex.egg-info /fgex-lib/firegex.egg-info
/fgex-lib/dist /fgex-lib/dist
/fgex-lib/build
/fgex-lib/fgex-pip/fgex.egg-info /fgex-lib/fgex-pip/fgex.egg-info
/fgex-lib/fgex-pip/dist /fgex-lib/fgex-pip/dist
/backend/db/ /backend/db/

View File

@@ -16,7 +16,7 @@ RUN bun run build
FROM --platform=$TARGETARCH registry.fedoraproject.org/fedora:latest FROM --platform=$TARGETARCH registry.fedoraproject.org/fedora:latest
RUN dnf -y update && dnf install -y python3.13-devel @development-tools gcc-c++ \ RUN dnf -y update && dnf install -y python3.13-devel @development-tools gcc-c++ \
libnetfilter_queue-devel libnfnetlink-devel libmnl-devel libcap-ng-utils nftables \ libnetfilter_queue-devel libnfnetlink-devel libmnl-devel libcap-ng-utils nftables \
vectorscan-devel libtins-devel python3-nftables libpcap-devel boost-devel uv vectorscan-devel libtins-devel python3-nftables libpcap-devel boost-devel uv redis
RUN mkdir -p /execute/modules RUN mkdir -p /execute/modules
WORKDIR /execute WORKDIR /execute

View File

@@ -9,12 +9,13 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from utils.sqlite import SQLite from utils.sqlite import SQLite
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager, NORELOAD
from utils.loader import frontend_deploy, load_routers from utils.loader import frontend_deploy, load_routers
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import socketio import socketio
from socketio.exceptions import ConnectionRefusedError
# DB init # DB init
db = SQLite('db/firegex.db') db = SQLite('db/firegex.db')
@@ -52,7 +53,6 @@ if DEBUG:
allow_headers=["*"], allow_headers=["*"],
) )
utils.socketio = socketio.AsyncServer( utils.socketio = socketio.AsyncServer(
async_mode="asgi", async_mode="asgi",
cors_allowed_origins=[], cors_allowed_origins=[],
@@ -69,9 +69,6 @@ def set_psw(psw: str):
hash_psw = crypto.hash(psw) hash_psw = crypto.hash(psw)
db.put("password",hash_psw) db.put("password",hash_psw)
@utils.socketio.on("update")
async def updater(): pass
def create_access_token(data: dict): def create_access_token(data: dict):
to_encode = data.copy() to_encode = data.copy()
encoded_jwt = jwt.encode(to_encode, JWT_SECRET(), algorithm=JWT_ALGORITHM) encoded_jwt = jwt.encode(to_encode, JWT_SECRET(), algorithm=JWT_ALGORITHM)
@@ -90,6 +87,28 @@ async def check_login(token: str = Depends(oauth2_scheme)):
return False return False
return logged_in return logged_in
@utils.socketio.on("connect")
async def sio_connect(sid, environ, auth):
if not auth or not await check_login(auth.get("token")):
raise ConnectionRefusedError("Unauthorized")
utils.sid_list.add(sid)
@utils.socketio.on("disconnect")
async def sio_disconnect(sid):
try:
utils.sid_list.remove(sid)
except KeyError:
pass
async def disconnect_all():
while True:
if len(utils.sid_list) == 0:
break
await utils.socketio.disconnect(utils.sid_list.pop())
@utils.socketio.on("update")
async def updater(): pass
async def is_loggined(auth: bool = Depends(check_login)): async def is_loggined(auth: bool = Depends(check_login)):
if not auth: if not auth:
raise HTTPException( raise HTTPException(
@@ -122,6 +141,7 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()):
return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"} return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"}
raise HTTPException(406,"Wrong password!") raise HTTPException(406,"Wrong password!")
@app.post('/api/set-password', response_model=ChangePasswordModel) @app.post('/api/set-password', response_model=ChangePasswordModel)
async def set_password(form: PasswordForm): async def set_password(form: PasswordForm):
"""Set the password of firegex""" """Set the password of firegex"""
@@ -143,6 +163,7 @@ async def change_password(form: PasswordChangeForm):
return {"status":"Cannot insert an empty password!"} return {"status":"Cannot insert an empty password!"}
if form.expire: if form.expire:
db.put("secret", secrets.token_hex(32)) db.put("secret", secrets.token_hex(32))
await disconnect_all()
set_psw(form.password) set_psw(form.password)
await refresh_frontend() await refresh_frontend()
@@ -200,7 +221,7 @@ if __name__ == '__main__':
"app:app", "app:app",
host="::" if DEBUG else None, host="::" if DEBUG else None,
port=FIREGEX_PORT, port=FIREGEX_PORT,
reload=DEBUG, reload=DEBUG and not NORELOAD,
access_log=True, access_log=True,
workers=1, # Firewall module can't be replicated in multiple workers workers=1, # Firewall module can't be replicated in multiple workers
# Later the firewall module will be moved to a separate process # Later the firewall module will be moved to a separate process

View File

@@ -17,6 +17,7 @@ enum class FilterAction{ DROP, ACCEPT, MANGLE, NOACTION };
enum class L4Proto { TCP, UDP, RAW }; enum class L4Proto { TCP, UDP, RAW };
typedef Tins::TCPIP::StreamIdentifier stream_id; typedef Tins::TCPIP::StreamIdentifier stream_id;
//TODO DUBBIO: I PACCHETTI INVIATI A PYTHON SONO GIA' FIXATI?
template<typename T> template<typename T>
class PktRequest { class PktRequest {
@@ -25,6 +26,9 @@ class PktRequest {
mnl_socket* nl = nullptr; mnl_socket* nl = nullptr;
uint16_t res_id; uint16_t res_id;
uint32_t packet_id; uint32_t packet_id;
size_t _original_size;
size_t _data_original_size;
bool need_tcp_fixing = false;
public: public:
bool is_ipv6; bool is_ipv6;
Tins::IP* ipv4 = nullptr; Tins::IP* ipv4 = nullptr;
@@ -39,17 +43,27 @@ class PktRequest {
size_t data_size; size_t data_size;
stream_id sid; stream_id sid;
int64_t* tcp_in_offset = nullptr;
int64_t* tcp_out_offset = nullptr;
T* ctx; T* ctx;
private: private:
inline void fetch_data_size(Tins::PDU* pdu){ static size_t inner_data_size(Tins::PDU* pdu){
if (pdu == nullptr){
return 0;
}
auto inner = pdu->inner_pdu(); auto inner = pdu->inner_pdu();
if (inner == nullptr){ if (inner == nullptr){
data_size = 0; return 0;
}else{
data_size = inner->size();
} }
return inner->size();
}
inline void fetch_data_size(Tins::PDU* pdu){
data_size = inner_data_size(pdu);
_data_original_size = data_size;
} }
L4Proto fill_l4_info(){ L4Proto fill_l4_info(){
@@ -86,22 +100,91 @@ class PktRequest {
} }
} }
bool need_tcp_fix(){
return (tcp_in_offset != nullptr && *tcp_in_offset != 0) || (tcp_out_offset != nullptr && *tcp_out_offset != 0);
}
Tins::PDU::serialization_type reserialize_raw_data(const uint8_t* data, const size_t& data_size){
if (is_ipv6){
Tins::IPv6 ipv6_new = Tins::IPv6(data, data_size);
if (tcp){
Tins::TCP* tcp_new = ipv6_new.find_pdu<Tins::TCP>();
}
return ipv6_new.serialize();
}else{
Tins::IP ipv4_new = Tins::IP(data, data_size);
if (tcp){
Tins::TCP* tcp_new = ipv4_new.find_pdu<Tins::TCP>();
}
return ipv4_new.serialize();
}
}
void _fix_ack_seq_tcp(Tins::TCP* this_tcp){
need_tcp_fixing = need_tcp_fix();
#ifdef DEBUG
if (need_tcp_fixing){
cerr << "[DEBUG] Fixing ack_seq with offsets " << *tcp_in_offset << " " << *tcp_out_offset << endl;
}
#endif
if(this_tcp == nullptr){
return;
}
if (is_input){
if (tcp_in_offset != nullptr){
this_tcp->seq(this_tcp->seq() + *tcp_in_offset);
}
if (tcp_out_offset != nullptr){
this_tcp->ack_seq(this_tcp->ack_seq() - *tcp_out_offset);
}
}else{
if (tcp_in_offset != nullptr){
this_tcp->ack_seq(this_tcp->ack_seq() - *tcp_in_offset);
}
if (tcp_out_offset != nullptr){
this_tcp->seq(this_tcp->seq() + *tcp_out_offset);
}
}
#ifdef DEBUG
if (need_tcp_fixing){
size_t new_size = inner_data_size(this_tcp);
cerr << "[DEBUG] FIXED PKT " << (is_input?"-> IN ":"<- OUT") << " [SEQ: " << this_tcp->seq() << "] \t[ACK: " << this_tcp->ack_seq() << "] \t[SIZE: " << new_size << "]" << endl;
}
#endif
}
public: public:
PktRequest(const char* payload, size_t plen, T* ctx, mnl_socket* nl, nfgenmsg *nfg, nfqnl_msg_packet_hdr *ph, bool is_input): PktRequest(const char* payload, size_t plen, T* ctx, mnl_socket* nl, nfgenmsg *nfg, nfqnl_msg_packet_hdr *ph, bool is_input):
ctx(ctx), nl(nl), res_id(nfg->res_id), ctx(ctx), nl(nl), res_id(nfg->res_id),
packet_id(ph->packet_id), is_input(is_input), packet_id(ph->packet_id), is_input(is_input),
packet(string(payload, plen)), packet(string(payload, plen)),
is_ipv6((payload[0] & 0xf0) == 0x60){ action(FilterAction::NOACTION),
is_ipv6((payload[0] & 0xf0) == 0x60)
{
if (is_ipv6){ if (is_ipv6){
ipv6 = new Tins::IPv6((uint8_t*)packet.c_str(), plen); ipv6 = new Tins::IPv6((uint8_t*)packet.c_str(), plen);
sid = stream_id::make_identifier(*ipv6); sid = stream_id::make_identifier(*ipv6);
_original_size = ipv6->size();
}else{ }else{
ipv4 = new Tins::IP((uint8_t*)packet.c_str(), plen); ipv4 = new Tins::IP((uint8_t*)packet.data(), plen);
sid = stream_id::make_identifier(*ipv4); sid = stream_id::make_identifier(*ipv4);
_original_size = ipv4->size();
} }
l4_proto = fill_l4_info(); l4_proto = fill_l4_info();
data = packet.data()+(plen-data_size); data = packet.data()+(plen-data_size);
#ifdef DEBUG
if (tcp){
cerr << "[DEBUG] NEW_PACKET " << (is_input?"-> IN ":"<- OUT") << " [SEQ: " << tcp->seq() << "] \t[ACK: " << tcp->ack_seq() << "] \t[SIZE: " << data_size << "]" << endl;
}
#endif
}
void fix_tcp_ack(){
if (tcp){
_fix_ack_seq_tcp(tcp);
}
} }
void drop(){ void drop(){
@@ -113,6 +196,14 @@ class PktRequest {
} }
} }
size_t data_original_size(){
return _data_original_size;
}
size_t original_size(){
return _original_size;
}
void accept(){ void accept(){
if (action == FilterAction::NOACTION){ if (action == FilterAction::NOACTION){
action = FilterAction::ACCEPT; action = FilterAction::ACCEPT;
@@ -131,7 +222,26 @@ class PktRequest {
} }
} }
void mangle_custom_pkt(const uint8_t* pkt, size_t pkt_size){ void reject(){
if (tcp){
//If the packet has data, we have to remove it
delete tcp->release_inner_pdu();
//For the first matched data or only for data packets, we set FIN bit
//This only for client packets, because this will trigger server to close the connection
//Packets will be filtered anyway also if client don't send packets
if (_data_original_size != 0 && is_input){
tcp->set_flag(Tins::TCP::FIN,1);
tcp->set_flag(Tins::TCP::ACK,1);
tcp->set_flag(Tins::TCP::SYN,0);
}
//Send the edited packet to the kernel
mangle();
}else{
drop();
}
}
void mangle_custom_pkt(uint8_t* pkt, const size_t& pkt_size){
if (action == FilterAction::NOACTION){ if (action == FilterAction::NOACTION){
action = FilterAction::MANGLE; action = FilterAction::MANGLE;
perfrom_action(pkt, pkt_size); perfrom_action(pkt, pkt_size);
@@ -149,26 +259,58 @@ class PktRequest {
delete ipv6; delete ipv6;
} }
inline Tins::PDU::serialization_type serialize(){
if (is_ipv6){
return ipv6->serialize();
}else{
return ipv4->serialize();
}
}
private: private:
void perfrom_action(const uint8_t* custom_data = nullptr, size_t custom_data_size = 0){ void perfrom_action(uint8_t* custom_data = nullptr, size_t custom_data_size = 0){
char buf[MNL_SOCKET_BUFFER_SIZE]; char buf[MNL_SOCKET_BUFFER_SIZE];
struct nlmsghdr *nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(res_id)); struct nlmsghdr *nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(res_id));
switch (action) switch (action)
{ {
case FilterAction::ACCEPT: case FilterAction::ACCEPT:
if (need_tcp_fixing){
Tins::PDU::serialization_type data = serialize();
nfq_nlmsg_verdict_put_pkt(nlh_verdict, data.data(), data.size());
}
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT ); nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT );
break; break;
case FilterAction::DROP: case FilterAction::DROP:
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP ); nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP );
break; break;
case FilterAction::MANGLE:{ case FilterAction::MANGLE:{
if (custom_data != nullptr){ //If not custom data, use the data in the packets
nfq_nlmsg_verdict_put_pkt(nlh_verdict, custom_data, custom_data_size); Tins::PDU::serialization_type data;
}else if (is_ipv6){ if (custom_data == nullptr){
nfq_nlmsg_verdict_put_pkt(nlh_verdict, ipv6->serialize().data(), ipv6->size()); data = serialize();
}else{ }else{
nfq_nlmsg_verdict_put_pkt(nlh_verdict, ipv4->serialize().data(), ipv4->size()); try{
data = reserialize_raw_data(custom_data, custom_data_size);
}catch(...){
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP );
action = FilterAction::DROP;
break;
} }
}
#ifdef DEBUG
size_t new_size = _data_original_size+((int64_t)custom_data_size) - ((int64_t)_original_size);
cerr << "[DEBUG] MANGLEDPKT " << (is_input?"-> IN ":"<- OUT") << " [SIZE: " << new_size << "]" << endl;
#endif
if (tcp && custom_data_size != _original_size){
int64_t delta = ((int64_t)custom_data_size) - ((int64_t)_original_size);
if (is_input && tcp_in_offset != nullptr){
*tcp_in_offset += delta;
}else if (!is_input && tcp_out_offset != nullptr){
*tcp_out_offset += delta;
}
}
nfq_nlmsg_verdict_put_pkt(nlh_verdict, data.data(), data.size());
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT ); nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT );
break; break;
} }

View File

@@ -4,11 +4,11 @@
#include "pyproxy/settings.cpp" #include "pyproxy/settings.cpp"
#include "pyproxy/pyproxy.cpp" #include "pyproxy/pyproxy.cpp"
#include "classes/netfilter.cpp" #include "classes/netfilter.cpp"
#include <syncstream>
#include <iostream> #include <iostream>
#include <stdexcept> #include <stdexcept>
#include <cstdlib> #include <cstdlib>
#include <endian.h> #include <endian.h>
#include "utils.cpp"
using namespace std; using namespace std;
using namespace Firegex::PyProxy; using namespace Firegex::PyProxy;
@@ -33,13 +33,13 @@ def invalid_curl_agent(http):
The code is now edited adding an intestation and a end statement: The code is now edited adding an intestation and a end statement:
```python ```python
global __firegex_pyfilter_enabled, __firegex_proto <user_code>
__firegex_pyfilter_enabled = ["invalid_curl_agent", "func3"] # This list is dynamically generated by firegex backend __firegex_pyfilter_enabled = ["invalid_curl_agent", "func3"] # This list is dynamically generated by firegex backend
__firegex_proto = "http" __firegex_proto = "http"
import firegex.nfproxy.internals import firegex.nfproxy.internals
<user_code> firegex.nfproxy.internals.compile(globals(), locals()) # This function can save other global variables, to use by the packet handler and is used generally to check and optimize the code
firegex.nfproxy.internals.compile() # This function can save other global variables, to use by the packet handler and is used generally to check and optimize the code
```` ````
(First lines are the same to keep line of code consistent on exceptions messages)
This code will be executed only once, and is needed to build the global and local context to use This code will be executed only once, and is needed to build the global and local context to use
The globals and locals generated here are copied for each connection, and are used to handle the packets The globals and locals generated here are copied for each connection, and are used to handle the packets
@@ -82,60 +82,53 @@ firegex lib will give you all the needed possibilities to do this is many ways
Final note: is not raccomanded to use variables that starts with __firegex_ in your code, because they may break the nfproxy Final note: is not raccomanded to use variables that starts with __firegex_ in your code, because they may break the nfproxy
*/ */
ssize_t read_check(int __fd, void *__buf, size_t __nbytes){
ssize_t bytes = read(__fd, __buf, __nbytes);
if (bytes == 0){
cerr << "[fatal] [updater] read() returned EOF" << endl;
throw invalid_argument("read() returned EOF");
}
if (bytes < 0){
cerr << "[fatal] [updater] read() returned an error" << bytes << endl;
throw invalid_argument("read() returned an error");
}
return bytes;
}
void config_updater (){ void config_updater (){
while (true){ while (true){
PyThreadState* state = PyEval_SaveThread(); // Release GIL while doing IO operation
uint32_t code_size; uint32_t code_size;
read_check(STDIN_FILENO, &code_size, 4); memcpy(&code_size, control_socket.recv(4).c_str(), 4);
//Python will send number always in little endian code_size = be32toh(code_size);
code_size = le32toh(code_size); string code = control_socket.recv(code_size);
string code; #ifdef DEBUG
code.resize(code_size); cerr << "[DEBUG] [updater] Received code: " << code << endl;
read_check(STDIN_FILENO, code.data(), code_size); #endif
cerr << "[info] [updater] Updating configuration" << endl; cerr << "[info] [updater] Updating configuration" << endl;
PyEval_AcquireThread(state); //Restore GIL before executing python code
try{ try{
config.reset(new PyCodeConfig(code)); config.reset(new PyCodeConfig(code));
cerr << "[info] [updater] Config update done" << endl; cerr << "[info] [updater] Config update done" << endl;
osyncstream(cout) << "ACK OK" << endl; control_socket << "ACK OK" << endl;
}catch(const std::exception& e){ }catch(const std::exception& e){
cerr << "[error] [updater] Failed to build new configuration!" << endl; cerr << "[error] [updater] Failed to build new configuration!" << endl;
osyncstream(cout) << "ACK FAIL " << e.what() << endl; control_socket << "ACK FAIL " << e.what() << endl;
} }
} }
} }
int main(int argc, char *argv[]){
int main(int argc, char *argv[]) {
// Connect to the python backend using the unix socket
init_control_socket();
// Initialize the python interpreter
Py_Initialize(); Py_Initialize();
atexit(Py_Finalize); atexit(Py_Finalize);
init_handle_packet_code(); //Compile the static code used to handle packets init_handle_packet_code(); //Compile the static code used to handle packets
if (freopen(nullptr, "rb", stdin) == nullptr){ // We need to read from stdin binary data
cerr << "[fatal] [main] Failed to reopen stdin in binary mode" << endl;
return 1;
}
int n_of_threads = 1; int n_of_threads = 1;
char * n_threads_str = getenv("NTHREADS"); char * n_threads_str = getenv("NTHREADS");
if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str); if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str);
if(n_of_threads <= 0) n_of_threads = 1; if(n_of_threads <= 0) n_of_threads = 1;
config.reset(new PyCodeConfig()); config.reset(new PyCodeConfig());
MultiThreadQueue<PyProxyQueue> queue(n_of_threads); MultiThreadQueue<PyProxyQueue> queue(n_of_threads);
osyncstream(cout) << "QUEUE " << queue.queue_num() << endl; control_socket << "QUEUE " << queue.queue_num() << endl;
cerr << "[info] [main] Queue: " << queue.queue_num() << " threads assigned: " << n_of_threads << endl; cerr << "[info] [main] Queue: " << queue.queue_num() << " threads assigned: " << n_of_threads << endl;
thread qthr([&](){ thread qthr([&](){

View File

@@ -33,7 +33,8 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
public: public:
stream_ctx sctx; stream_ctx sctx;
StreamFollower follower; StreamFollower follower;
PyGILState_STATE gstate; PyThreadState * gtstate = nullptr;
PyInterpreterConfig py_thread_config = { PyInterpreterConfig py_thread_config = {
.use_main_obmalloc = 0, .use_main_obmalloc = 0,
.allow_fork = 0, .allow_fork = 0,
@@ -44,24 +45,23 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
.gil = PyInterpreterConfig_OWN_GIL, .gil = PyInterpreterConfig_OWN_GIL,
}; };
PyThreadState *tstate = NULL; PyThreadState *tstate = NULL;
PyStatus pystatus;
struct {
bool matching_has_been_called = false;
bool already_closed = false;
bool rejected = true;
NfQueue::PktRequest<PyProxyQueue>* pkt; NfQueue::PktRequest<PyProxyQueue>* pkt;
} match_ctx; tcp_ack_seq_ctx* current_tcp_ack = nullptr;
void before_loop() override { void before_loop() override {
// Create thred structure for python PyStatus pystatus;
gstate = PyGILState_Ensure();
// Create a new interpreter for the thread // Create a new interpreter for the thread
gtstate = PyThreadState_New(PyInterpreterState_Main());
PyEval_AcquireThread(gtstate);
pystatus = Py_NewInterpreterFromConfig(&tstate, &py_thread_config); pystatus = Py_NewInterpreterFromConfig(&tstate, &py_thread_config);
if (PyStatus_Exception(pystatus)) { if(tstate == nullptr){
Py_ExitStatusException(pystatus);
cerr << "[fatal] [main] Failed to create new interpreter" << endl; cerr << "[fatal] [main] Failed to create new interpreter" << endl;
exit(EXIT_FAILURE); throw invalid_argument("Failed to create new interpreter (null tstate)");
}
if (PyStatus_Exception(pystatus)) {
cerr << "[fatal] [main] Failed to create new interpreter" << endl;
Py_ExitStatusException(pystatus);
throw invalid_argument("Failed to create new interpreter (pystatus exc)");
} }
// Setting callbacks for the stream follower // Setting callbacks for the stream follower
follower.new_stream_callback(bind(on_new_stream, placeholders::_1, this)); follower.new_stream_callback(bind(on_new_stream, placeholders::_1, this));
@@ -69,21 +69,24 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
} }
inline void print_blocked_reason(const string& func_name){ inline void print_blocked_reason(const string& func_name){
osyncstream(cout) << "BLOCKED " << func_name << endl; control_socket << "BLOCKED " << func_name << endl;
} }
inline void print_mangle_reason(const string& func_name){ inline void print_mangle_reason(const string& func_name){
osyncstream(cout) << "MANGLED " << func_name << endl; control_socket << "MANGLED " << func_name << endl;
} }
inline void print_exception_reason(){ inline void print_exception_reason(){
osyncstream(cout) << "EXCEPTION" << endl; control_socket << "EXCEPTION" << endl;
} }
//If the stream has already been matched, drop all data, and try to close the connection //If the stream has already been matched, drop all data, and try to close the connection
static void keep_fin_packet(PyProxyQueue* proxy_info){ static void keep_fin_packet(PyProxyQueue* pyq){
proxy_info->match_ctx.matching_has_been_called = true; pyq->pkt->reject();// This is needed because the callback has to take the updated pkt pointer!
proxy_info->match_ctx.already_closed = true; }
static void keep_dropped(PyProxyQueue* pyq){
pyq->pkt->drop();// This is needed because the callback has to take the updated pkt pointer!
} }
void filter_action(NfQueue::PktRequest<PyProxyQueue>* pkt, Stream& stream){ void filter_action(NfQueue::PktRequest<PyProxyQueue>* pkt, Stream& stream){
@@ -92,36 +95,45 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
if (stream_search == sctx.streams_ctx.end()){ if (stream_search == sctx.streams_ctx.end()){
shared_ptr<PyCodeConfig> conf = config; shared_ptr<PyCodeConfig> conf = config;
//If config is not set, ignore the stream //If config is not set, ignore the stream
if (conf->glob == nullptr || conf->local == nullptr){ PyObject* compiled_code = conf->compiled_code();
if (compiled_code == nullptr){
stream.client_data_callback(nullptr); stream.client_data_callback(nullptr);
stream.server_data_callback(nullptr); stream.server_data_callback(nullptr);
return pkt->accept(); return pkt->accept();
} }
stream_match = new pyfilter_ctx(conf->glob, conf->local); stream_match = new pyfilter_ctx(compiled_code);
Py_DECREF(compiled_code);
sctx.streams_ctx.insert_or_assign(pkt->sid, stream_match); sctx.streams_ctx.insert_or_assign(pkt->sid, stream_match);
}else{ }else{
stream_match = stream_search->second; stream_match = stream_search->second;
} }
auto result = stream_match->handle_packet(pkt); auto result = stream_match->handle_packet(pkt);
switch(result.action){ switch(result.action){
case PyFilterResponse::ACCEPT: case PyFilterResponse::ACCEPT:
pkt->accept(); return pkt->accept();
case PyFilterResponse::DROP: case PyFilterResponse::DROP:
print_blocked_reason(*result.filter_match_by); print_blocked_reason(*result.filter_match_by);
sctx.clean_stream_by_id(pkt->sid); sctx.clean_stream_by_id(pkt->sid);
stream.client_data_callback(nullptr); stream.client_data_callback(bind(keep_dropped, this));
stream.server_data_callback(nullptr); stream.server_data_callback(bind(keep_dropped, this));
break; return pkt->drop();
case PyFilterResponse::REJECT: case PyFilterResponse::REJECT:
print_blocked_reason(*result.filter_match_by);
sctx.clean_stream_by_id(pkt->sid); sctx.clean_stream_by_id(pkt->sid);
stream.client_data_callback(bind(keep_fin_packet, this)); stream.client_data_callback(bind(keep_fin_packet, this));
stream.server_data_callback(bind(keep_fin_packet, this)); stream.server_data_callback(bind(keep_fin_packet, this));
pkt->ctx->match_ctx.rejected = true; //Handler will take care of the rest return pkt->reject();
break;
case PyFilterResponse::MANGLE: case PyFilterResponse::MANGLE:
pkt->mangle_custom_pkt((uint8_t*)result.mangled_packet->data(), result.mangled_packet->size());
if (pkt->get_action() == NfQueue::FilterAction::DROP){
cerr << "[error] [filter_action] Failed to mangle: the packet sent is not serializzable... the packet was dropped" << endl;
print_blocked_reason(*result.filter_match_by);
print_exception_reason();
}else{
print_mangle_reason(*result.filter_match_by); print_mangle_reason(*result.filter_match_by);
pkt->mangle_custom_pkt((uint8_t*)result.mangled_packet->c_str(), result.mangled_packet->size()); }
break; return;
case PyFilterResponse::EXCEPTION: case PyFilterResponse::EXCEPTION:
case PyFilterResponse::INVALID: case PyFilterResponse::INVALID:
print_exception_reason(); print_exception_reason();
@@ -129,16 +141,15 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
//Free the packet data //Free the packet data
stream.client_data_callback(nullptr); stream.client_data_callback(nullptr);
stream.server_data_callback(nullptr); stream.server_data_callback(nullptr);
pkt->accept(); return pkt->accept();
break;
} }
} }
static void on_data_recv(Stream& stream, PyProxyQueue* proxy_info, string data) { static void on_data_recv(Stream& stream, PyProxyQueue* proxy_info, string data) {
proxy_info->match_ctx.matching_has_been_called = true; proxy_info->pkt->data = data.data();
proxy_info->match_ctx.already_closed = false; proxy_info->pkt->data_size = data.size();
proxy_info->filter_action(proxy_info->match_ctx.pkt, stream); proxy_info->filter_action(proxy_info->pkt, stream);
} }
//Input data filtering //Input data filtering
@@ -152,77 +163,77 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
} }
// A stream was terminated. The second argument is the reason why it was terminated // A stream was terminated. The second argument is the reason why it was terminated
static void on_stream_close(Stream& stream, PyProxyQueue* proxy_info) { static void on_stream_close(Stream& stream, PyProxyQueue* pyq) {
stream_id stream_id = stream_id::make_identifier(stream); stream_id stream_id = stream_id::make_identifier(stream);
proxy_info->sctx.clean_stream_by_id(stream_id); pyq->sctx.clean_stream_by_id(stream_id);
pyq->sctx.clean_tcp_ack_by_id(stream_id);
} }
static void on_new_stream(Stream& stream, PyProxyQueue* proxy_info) { static void on_new_stream(Stream& stream, PyProxyQueue* pyq) {
stream.auto_cleanup_payloads(true); stream.auto_cleanup_payloads(true);
if (stream.is_partial_stream()) { if (stream.is_partial_stream()) {
stream.enable_recovery_mode(10 * 1024); stream.enable_recovery_mode(10 * 1024);
} }
stream.client_data_callback(bind(on_client_data, placeholders::_1, proxy_info));
stream.server_data_callback(bind(on_server_data, placeholders::_1, proxy_info)); if (pyq->current_tcp_ack != nullptr){
stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, proxy_info)); pyq->current_tcp_ack->reset();
}else{
pyq->current_tcp_ack = new tcp_ack_seq_ctx();
pyq->sctx.tcp_ack_ctx.insert_or_assign(pyq->pkt->sid, pyq->current_tcp_ack);
pyq->pkt->tcp_in_offset = &pyq->current_tcp_ack->in_tcp_offset;
pyq->pkt->tcp_out_offset = &pyq->current_tcp_ack->out_tcp_offset;
} }
//Should not happen, but with this we can be sure about this
auto tcp_ack_search = pyq->sctx.tcp_ack_ctx.find(pyq->pkt->sid);
if (tcp_ack_search != pyq->sctx.tcp_ack_ctx.end()){
tcp_ack_search->second->reset();
}
stream.client_data_callback(bind(on_client_data, placeholders::_1, pyq));
stream.server_data_callback(bind(on_server_data, placeholders::_1, pyq));
stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, pyq));
}
void handle_next_packet(NfQueue::PktRequest<PyProxyQueue>* _pkt) override{
pkt = _pkt; // Setting packet context
void handle_next_packet(NfQueue::PktRequest<PyProxyQueue>* pkt) override{
if (pkt->l4_proto != NfQueue::L4Proto::TCP){ if (pkt->l4_proto != NfQueue::L4Proto::TCP){
throw invalid_argument("Only TCP and UDP are supported"); throw invalid_argument("Only TCP and UDP are supported");
} }
Tins::PDU* application_layer = pkt->tcp->inner_pdu();
u_int16_t payload_size = 0; auto tcp_ack_search = sctx.tcp_ack_ctx.find(pkt->sid);
if (application_layer != nullptr){ if (tcp_ack_search != sctx.tcp_ack_ctx.end()){
payload_size = application_layer->size(); current_tcp_ack = tcp_ack_search->second;
pkt->tcp_in_offset = &current_tcp_ack->in_tcp_offset;
pkt->tcp_out_offset = &current_tcp_ack->out_tcp_offset;
}else{
current_tcp_ack = nullptr;
//If necessary will be created by libtis new_stream callback
} }
match_ctx.matching_has_been_called = false;
match_ctx.pkt = pkt;
if (pkt->is_ipv6){ if (pkt->is_ipv6){
pkt->fix_tcp_ack();
follower.process_packet(*pkt->ipv6); follower.process_packet(*pkt->ipv6);
}else{ }else{
pkt->fix_tcp_ack();
follower.process_packet(*pkt->ipv4); follower.process_packet(*pkt->ipv4);
} }
// Do an action only is an ordered packet has been received
if (match_ctx.matching_has_been_called){
bool empty_payload = payload_size == 0;
//In this 2 cases we have to remove all data about the stream
if (!match_ctx.rejected || match_ctx.already_closed){
sctx.clean_stream_by_id(pkt->sid);
//If the packet has data, we have to remove it
if (!empty_payload){
Tins::PDU* data_layer = pkt->tcp->release_inner_pdu();
if (data_layer != nullptr){
delete data_layer;
}
}
//For the first matched data or only for data packets, we set FIN bit
//This only for client packets, because this will trigger server to close the connection
//Packets will be filtered anyway also if client don't send packets
if ((!match_ctx.rejected || !empty_payload) && pkt->is_input){
pkt->tcp->set_flag(Tins::TCP::FIN,1);
pkt->tcp->set_flag(Tins::TCP::ACK,1);
pkt->tcp->set_flag(Tins::TCP::SYN,0);
}
//Send the edited packet to the kernel
return pkt->mangle();
}else{
//Fallback to the default action //Fallback to the default action
if (pkt->get_action() == NfQueue::FilterAction::NOACTION){ if (pkt->get_action() == NfQueue::FilterAction::NOACTION){
return pkt->accept(); return pkt->accept();
} }
} }
}else{
return pkt->accept();
}
}
~PyProxyQueue() { ~PyProxyQueue() {
// Closing first the interpreter // Closing first the interpreter
Py_EndInterpreter(tstate); Py_EndInterpreter(tstate);
// Releasing the GIL and the thread data structure PyEval_ReleaseThread(tstate);
PyGILState_Release(gstate); PyThreadState_Clear(tstate);
PyThreadState_Delete(tstate);
sctx.clean(); sctx.clean();
} }

View File

@@ -2,58 +2,73 @@
#define PROXY_TUNNEL_SETTINGS_CPP #define PROXY_TUNNEL_SETTINGS_CPP
#include <Python.h> #include <Python.h>
#include <marshal.h>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <iostream> #include <iostream>
#include "../utils.cpp"
using namespace std; using namespace std;
namespace Firegex { namespace Firegex {
namespace PyProxy { namespace PyProxy {
class PyCodeConfig;
shared_ptr<PyCodeConfig> config;
PyObject* py_handle_packet_code = nullptr;
UnixClientConnection control_socket;
class PyCodeConfig{ class PyCodeConfig{
public: public:
PyObject* glob = nullptr; string encoded_code;
PyObject* local = nullptr;
private:
void _clean(){
Py_XDECREF(glob);
Py_XDECREF(local);
}
public:
PyCodeConfig(const string& pycode){ PyCodeConfig(const string& pycode){
PyObject* compiled_code = Py_CompileStringExFlags(pycode.c_str(), "<pyfilter>", Py_file_input, NULL, 2); PyObject* compiled_code = Py_CompileStringExFlags(pycode.c_str(), "<pyfilter>", Py_file_input, NULL, 2);
if (compiled_code == nullptr){ if (compiled_code == nullptr){
std::cerr << "[fatal] [main] Failed to compile the code" << endl; std::cerr << "[fatal] [main] Failed to compile the code" << endl;
_clean();
throw invalid_argument("Failed to compile the code"); throw invalid_argument("Failed to compile the code");
} }
glob = PyDict_New(); PyObject* glob = PyDict_New();
local = PyDict_New(); PyObject* result = PyEval_EvalCode(compiled_code, glob, glob);
PyObject* result = PyEval_EvalCode(compiled_code, glob, local); Py_DECREF(glob);
Py_XDECREF(compiled_code);
if (!result){ if (!result){
PyErr_Print(); PyErr_Print();
_clean(); Py_DECREF(compiled_code);
std::cerr << "[fatal] [main] Failed to execute the code" << endl; std::cerr << "[fatal] [main] Failed to execute the code" << endl;
throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided"); throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided");
} }
Py_DECREF(result); Py_DECREF(result);
PyObject* code_dump = PyMarshal_WriteObjectToString(compiled_code, 4);
Py_DECREF(compiled_code);
if (code_dump == nullptr){
PyErr_Print();
std::cerr << "[fatal] [main] Failed to dump the code" << endl;
throw invalid_argument("Failed to dump the code");
}
if (!PyBytes_Check(code_dump)){
std::cerr << "[fatal] [main] Failed to dump the code" << endl;
throw invalid_argument("Failed to dump the code");
}
encoded_code = string(PyBytes_AsString(code_dump), PyBytes_Size(code_dump));
Py_DECREF(code_dump);
} }
PyCodeConfig(){}
~PyCodeConfig(){ PyObject* compiled_code(){
_clean(); if (encoded_code.empty()) return nullptr;
return PyMarshal_ReadObjectFromString(encoded_code.c_str(), encoded_code.size());
} }
PyCodeConfig(){}
}; };
shared_ptr<PyCodeConfig> config; void init_control_socket(){
PyObject* py_handle_packet_code = nullptr; char * socket_path = getenv("FIREGEX_NFPROXY_SOCK");
if (socket_path == nullptr) throw invalid_argument("FIREGEX_NFPROXY_SOCK not set");
if (strlen(socket_path) >= 108) throw invalid_argument("FIREGEX_NFPROXY_SOCK too long");
control_socket = UnixClientConnection(socket_path);
}
void init_handle_packet_code(){ void init_handle_packet_code(){
py_handle_packet_code = Py_CompileStringExFlags( py_handle_packet_code = Py_CompileStringExFlags(

View File

@@ -27,10 +27,21 @@ enum PyFilterResponse {
INVALID = 5 INVALID = 5
}; };
const PyFilterResponse VALID_PYTHON_RESPONSE[4] = {
PyFilterResponse::ACCEPT,
PyFilterResponse::DROP,
PyFilterResponse::REJECT,
PyFilterResponse::MANGLE
};
struct py_filter_response { struct py_filter_response {
PyFilterResponse action; PyFilterResponse action;
string* filter_match_by = nullptr; string* filter_match_by = nullptr;
string* mangled_packet = nullptr; string* mangled_packet = nullptr;
py_filter_response(PyFilterResponse action, string* filter_match_by = nullptr, string* mangled_packet = nullptr):
action(action), filter_match_by(filter_match_by), mangled_packet(mangled_packet){}
~py_filter_response(){ ~py_filter_response(){
delete mangled_packet; delete mangled_packet;
delete filter_match_by; delete filter_match_by;
@@ -39,34 +50,35 @@ struct py_filter_response {
typedef Tins::TCPIP::StreamIdentifier stream_id; typedef Tins::TCPIP::StreamIdentifier stream_id;
struct tcp_ack_seq_ctx{
//Can be negative, so we use int64_t (for a uint64_t value)
int64_t in_tcp_offset = 0;
int64_t out_tcp_offset = 0;
tcp_ack_seq_ctx(){}
void reset(){
in_tcp_offset = 0;
out_tcp_offset = 0;
}
};
struct pyfilter_ctx { struct pyfilter_ctx {
PyObject * glob = nullptr; PyObject * glob = nullptr;
PyObject * local = nullptr;
pyfilter_ctx(PyObject * original_glob, PyObject * original_local){ pyfilter_ctx(PyObject * compiled_code){
PyObject *copy = PyImport_ImportModule("copy"); glob = PyDict_New();
if (copy == nullptr){ PyObject* result = PyEval_EvalCode(compiled_code, glob, glob);
if (!result){
PyErr_Print(); PyErr_Print();
throw invalid_argument("Failed to import copy module"); Py_XDECREF(glob);
std::cerr << "[fatal] [main] Failed to compile the code" << endl;
throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided");
} }
PyObject *deepcopy = PyObject_GetAttrString(copy, "deepcopy"); Py_XDECREF(result);
glob = PyObject_CallFunctionObjArgs(deepcopy, original_glob, NULL);
if (glob == nullptr){
PyErr_Print();
throw invalid_argument("Failed to deepcopy the global dict");
}
local = PyObject_CallFunctionObjArgs(deepcopy, original_local, NULL);
if (local == nullptr){
PyErr_Print();
throw invalid_argument("Failed to deepcopy the local dict");
}
Py_DECREF(copy);
} }
~pyfilter_ctx(){ ~pyfilter_ctx(){
Py_XDECREF(glob); Py_DECREF(glob);
Py_XDECREF(local);
} }
inline void set_item_to_glob(const char* key, PyObject* value){ inline void set_item_to_glob(const char* key, PyObject* value){
@@ -84,15 +96,12 @@ struct pyfilter_ctx {
} }
} }
inline void set_item_to_local(const char* key, PyObject* value){
set_item_to_dict(local, key, value);
}
inline void set_item_to_dict(PyObject* dict, const char* key, PyObject* value){ inline void set_item_to_dict(PyObject* dict, const char* key, PyObject* value){
if (PyDict_SetItemString(dict, key, value) != 0){ if (PyDict_SetItemString(dict, key, value) != 0){
PyErr_Print(); PyErr_Print();
throw invalid_argument("Failed to set item to dict"); throw invalid_argument("Failed to set item to dict");
} }
Py_DECREF(value);
} }
py_filter_response handle_packet( py_filter_response handle_packet(
@@ -101,6 +110,7 @@ struct pyfilter_ctx {
PyObject * packet_info = PyDict_New(); PyObject * packet_info = PyDict_New();
set_item_to_dict(packet_info, "data", PyBytes_FromStringAndSize(pkt->data, pkt->data_size)); set_item_to_dict(packet_info, "data", PyBytes_FromStringAndSize(pkt->data, pkt->data_size));
set_item_to_dict(packet_info, "l4_size", PyLong_FromLong(pkt->data_original_size()));
set_item_to_dict(packet_info, "raw_packet", PyBytes_FromStringAndSize(pkt->packet.c_str(), pkt->packet.size())); set_item_to_dict(packet_info, "raw_packet", PyBytes_FromStringAndSize(pkt->packet.c_str(), pkt->packet.size()));
set_item_to_dict(packet_info, "is_input", PyBool_FromLong(pkt->is_input)); set_item_to_dict(packet_info, "is_input", PyBool_FromLong(pkt->is_input));
set_item_to_dict(packet_info, "is_ipv6", PyBool_FromLong(pkt->is_ipv6)); set_item_to_dict(packet_info, "is_ipv6", PyBool_FromLong(pkt->is_ipv6));
@@ -108,92 +118,156 @@ struct pyfilter_ctx {
// Set packet info to the global context // Set packet info to the global context
set_item_to_glob("__firegex_packet_info", packet_info); set_item_to_glob("__firegex_packet_info", packet_info);
PyObject * result = PyEval_EvalCode(py_handle_packet_code, glob, local); PyObject * result = PyEval_EvalCode(py_handle_packet_code, glob, glob);
del_item_from_glob("__firegex_packet_info"); del_item_from_glob("__firegex_packet_info");
Py_DECREF(packet_info);
Py_DECREF(packet_info);
if (!result){ if (!result){
PyErr_Print(); PyErr_Print();
return py_filter_response{PyFilterResponse::EXCEPTION, nullptr}; #ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Exception raised" << endl;
#endif
return py_filter_response(PyFilterResponse::EXCEPTION);
} }
Py_DECREF(result); Py_DECREF(result);
result = get_item_from_glob("__firegex_pyfilter_result"); result = get_item_from_glob("__firegex_pyfilter_result");
if (result == nullptr){ if (result == nullptr){
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; #ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result found" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
} }
if (!PyDict_Check(result)){ if (!PyDict_Check(result)){
PyErr_Print(); PyErr_Print();
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Result is not a dict" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; return py_filter_response(PyFilterResponse::INVALID);
} }
PyObject* action = PyDict_GetItemString(result, "action"); PyObject* action = PyDict_GetItemString(result, "action");
if (action == nullptr){ if (action == nullptr){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result action found" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; return py_filter_response(PyFilterResponse::INVALID);
} }
if (!PyLong_Check(action)){ if (!PyLong_Check(action)){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Action is not a long" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; return py_filter_response(PyFilterResponse::INVALID);
} }
PyFilterResponse action_enum = (PyFilterResponse)PyLong_AsLong(action); PyFilterResponse action_enum = (PyFilterResponse)PyLong_AsLong(action);
if (action_enum == PyFilterResponse::ACCEPT || action_enum == PyFilterResponse::EXCEPTION || action_enum == PyFilterResponse::INVALID){ //Check action_enum
bool valid = false;
for (auto valid_action: VALID_PYTHON_RESPONSE){
if (action_enum == valid_action){
valid = true;
break;
}
}
if (!valid){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Invalid action" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{action_enum, nullptr, nullptr}; return py_filter_response(PyFilterResponse::INVALID);
}else{ }
if (action_enum == PyFilterResponse::ACCEPT){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response(action_enum);
}
PyObject *func_name_py = PyDict_GetItemString(result, "matched_by"); PyObject *func_name_py = PyDict_GetItemString(result, "matched_by");
if (func_name_py == nullptr){ if (func_name_py == nullptr){
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; #ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result matched_by found" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
} }
if (!PyUnicode_Check(func_name_py)){ if (!PyUnicode_Check(func_name_py)){
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; #ifdef DEBUG
cerr << "[DEBUG] [handle_packet] matched_by is not a string" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
} }
string* func_name = new string(PyUnicode_AsUTF8(func_name_py)); string* func_name = new string(PyUnicode_AsUTF8(func_name_py));
if (action_enum == PyFilterResponse::DROP || action_enum == PyFilterResponse::REJECT){ if (action_enum == PyFilterResponse::DROP || action_enum == PyFilterResponse::REJECT){
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{action_enum, func_name, nullptr}; return py_filter_response(action_enum, func_name);
} }
if (action_enum != PyFilterResponse::MANGLE){ if (action_enum == PyFilterResponse::MANGLE){
PyObject* mangled_packet = PyDict_GetItemString(result, "mangled_packet"); PyObject* mangled_packet = PyDict_GetItemString(result, "mangled_packet");
if (mangled_packet == nullptr){ if (mangled_packet == nullptr){
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; #ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result mangled_packet found" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
} }
if (!PyBytes_Check(mangled_packet)){ if (!PyBytes_Check(mangled_packet)){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] mangled_packet is not a bytes" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; return py_filter_response(PyFilterResponse::INVALID);
} }
string* pkt_str = new string(PyBytes_AsString(mangled_packet), PyBytes_Size(mangled_packet)); string* pkt_str = new string(PyBytes_AsString(mangled_packet), PyBytes_Size(mangled_packet));
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::MANGLE, func_name, pkt_str}; return py_filter_response(PyFilterResponse::MANGLE, func_name, pkt_str);
}
} }
//Should never reach this point, but just in case of new action not managed...
del_item_from_glob("__firegex_pyfilter_result"); del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr}; return py_filter_response(PyFilterResponse::INVALID);
} }
}; };
typedef map<stream_id, pyfilter_ctx*> matching_map; typedef map<stream_id, pyfilter_ctx*> matching_map;
typedef map<stream_id, tcp_ack_seq_ctx*> tcp_ack_map;
struct stream_ctx { struct stream_ctx {
matching_map streams_ctx; matching_map streams_ctx;
tcp_ack_map tcp_ack_ctx;
void clean_stream_by_id(stream_id sid){ void clean_stream_by_id(stream_id sid){
auto stream_search = streams_ctx.find(sid); auto stream_search = streams_ctx.find(sid);
if (stream_search != streams_ctx.end()){ if (stream_search != streams_ctx.end()){
auto stream_match = stream_search->second; auto stream_match = stream_search->second;
delete stream_match; delete stream_match;
streams_ctx.erase(stream_search->first);
} }
} }
void clean_tcp_ack_by_id(stream_id sid){
auto tcp_ack_search = tcp_ack_ctx.find(sid);
if (tcp_ack_search != tcp_ack_ctx.end()){
auto tcp_ack = tcp_ack_search->second;
delete tcp_ack;
tcp_ack_ctx.erase(tcp_ack_search->first);
}
}
void clean(){ void clean(){
for (auto ele: streams_ctx){ for (auto ele: streams_ctx){
delete ele.second; delete ele.second;
} }
for (auto ele: tcp_ack_ctx){
delete ele.second;
}
tcp_ack_ctx.clear();
streams_ctx.clear();
} }
}; };

View File

@@ -37,13 +37,7 @@ public:
stream_ctx sctx; stream_ctx sctx;
u_int16_t latest_config_ver = 0; u_int16_t latest_config_ver = 0;
StreamFollower follower; StreamFollower follower;
struct {
bool matching_has_been_called = false;
bool already_closed = false;
bool result;
NfQueue::PktRequest<RegexNfQueue>* pkt; NfQueue::PktRequest<RegexNfQueue>* pkt;
} match_ctx;
bool filter_action(NfQueue::PktRequest<RegexNfQueue>* pkt){ bool filter_action(NfQueue::PktRequest<RegexNfQueue>* pkt){
shared_ptr<RegexRules> conf = regex_config; shared_ptr<RegexRules> conf = regex_config;
@@ -119,49 +113,23 @@ public:
return true; return true;
} }
void handle_next_packet(NfQueue::PktRequest<RegexNfQueue>* pkt) override{ void handle_next_packet(NfQueue::PktRequest<RegexNfQueue>* _pkt) override{
bool empty_payload = pkt->data_size == 0; pkt = _pkt; // Setting packet context
if (pkt->tcp){ if (pkt->tcp){
match_ctx.matching_has_been_called = false;
match_ctx.pkt = pkt;
if (pkt->ipv4){ if (pkt->ipv4){
follower.process_packet(*pkt->ipv4); follower.process_packet(*pkt->ipv4);
}else{ }else{
follower.process_packet(*pkt->ipv6); follower.process_packet(*pkt->ipv6);
} }
//Fallback to the default action
// Do an action only is an ordered packet has been received if (pkt->get_action() == NfQueue::FilterAction::NOACTION){
if (match_ctx.matching_has_been_called){
//In this 2 cases we have to remove all data about the stream
if (!match_ctx.result || match_ctx.already_closed){
sctx.clean_stream_by_id(pkt->sid);
//If the packet has data, we have to remove it
if (!empty_payload){
Tins::PDU* data_layer = pkt->tcp->release_inner_pdu();
if (data_layer != nullptr){
delete data_layer;
}
}
//For the first matched data or only for data packets, we set FIN bit
//This only for client packets, because this will trigger server to close the connection
//Packets will be filtered anyway also if client don't send packets
if ((!match_ctx.result || !empty_payload) && pkt->is_input){
pkt->tcp->set_flag(Tins::TCP::FIN,1);
pkt->tcp->set_flag(Tins::TCP::ACK,1);
pkt->tcp->set_flag(Tins::TCP::SYN,0);
}
//Send the edited packet to the kernel
return pkt->mangle();
}
}
return pkt->accept(); return pkt->accept();
}
}else{ }else{
if (!pkt->udp){ if (!pkt->udp){
throw invalid_argument("Only TCP and UDP are supported"); throw invalid_argument("Only TCP and UDP are supported");
} }
if(empty_payload){ if(pkt->data_size == 0){
return pkt->accept(); return pkt->accept();
}else if (filter_action(pkt)){ }else if (filter_action(pkt)){
return pkt->accept(); return pkt->accept();
@@ -170,22 +138,21 @@ public:
} }
} }
} }
//If the stream has already been matched, drop all data, and try to close the connection //If the stream has already been matched, drop all data, and try to close the connection
static void keep_fin_packet(RegexNfQueue* nfq){ static void keep_fin_packet(RegexNfQueue* nfq){
nfq->match_ctx.matching_has_been_called = true; nfq->pkt->reject();// This is needed because the callback has to take the updated pkt pointer!
nfq->match_ctx.already_closed = true;
} }
static void on_data_recv(Stream& stream, RegexNfQueue* nfq, string data) { static void on_data_recv(Stream& stream, RegexNfQueue* nfq, string data) {
nfq->match_ctx.matching_has_been_called = true; nfq->pkt->data = data.data();
nfq->match_ctx.already_closed = false; nfq->pkt->data_size = data.size();
bool result = nfq->filter_action(nfq->match_ctx.pkt); if (!nfq->filter_action(nfq->pkt)){
if (!result){ nfq->sctx.clean_stream_by_id(nfq->pkt->sid);
nfq->sctx.clean_stream_by_id(nfq->match_ctx.pkt->sid);
stream.client_data_callback(bind(keep_fin_packet, nfq)); stream.client_data_callback(bind(keep_fin_packet, nfq));
stream.server_data_callback(bind(keep_fin_packet, nfq)); stream.server_data_callback(bind(keep_fin_packet, nfq));
nfq->pkt->reject();
} }
nfq->match_ctx.result = result;
} }
//Input data filtering //Input data filtering

View File

@@ -17,7 +17,6 @@ namespace Regex {
typedef Tins::TCPIP::StreamIdentifier stream_id; typedef Tins::TCPIP::StreamIdentifier stream_id;
typedef map<stream_id, hs_stream_t*> matching_map; typedef map<stream_id, hs_stream_t*> matching_map;
#ifdef DEBUG
ostream& operator<<(ostream& os, const Tins::TCPIP::StreamIdentifier::address_type &sid){ ostream& operator<<(ostream& os, const Tins::TCPIP::StreamIdentifier::address_type &sid){
bool first_print = false; bool first_print = false;
for (auto ele: sid){ for (auto ele: sid){
@@ -33,7 +32,6 @@ ostream& operator<<(ostream& os, const stream_id &sid){
os << sid.max_address << ":" << sid.max_address_port << " -> " << sid.min_address << ":" << sid.min_address_port; os << sid.max_address << ":" << sid.max_address_port << " -> " << sid.min_address << ":" << sid.min_address_port;
return os; return os;
} }
#endif
struct stream_ctx { struct stream_ctx {
matching_map in_hs_streams; matching_map in_hs_streams;

View File

@@ -1,10 +1,17 @@
#ifndef UTILS_CPP
#define UTILS_CPP
#include <string> #include <string>
#include <unistd.h> #include <unistd.h>
#include <queue> #include <queue>
#include <condition_variable> #include <condition_variable>
#include <sys/socket.h>
#ifndef UTILS_CPP #include <sys/un.h>
#define UTILS_CPP #include <stdexcept>
#include <cstring>
#include <iostream>
#include <cerrno>
#include <sstream>
bool unhexlify(std::string const &hex, std::string &newString) { bool unhexlify(std::string const &hex, std::string &newString) {
try{ try{
@@ -22,6 +29,113 @@ bool unhexlify(std::string const &hex, std::string &newString) {
} }
} }
class UnixClientConnection {
public:
int sockfd = -1;
struct sockaddr_un addr;
private:
// Internal buffer to accumulate the output until flush
std::ostringstream streamBuffer;
public:
UnixClientConnection(){};
UnixClientConnection(const char* path) {
sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sockfd == -1) {
throw std::runtime_error(std::string("socket error: ") + std::strerror(errno));
}
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
if (connect(sockfd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != 0) {
throw std::runtime_error(std::string("connect error: ") + std::strerror(errno));
}
}
// Delete copy constructor and assignment operator to avoid resource duplication
UnixClientConnection(const UnixClientConnection&) = delete;
UnixClientConnection& operator=(const UnixClientConnection&) = delete;
// Move constructor
UnixClientConnection(UnixClientConnection&& other) noexcept
: sockfd(other.sockfd), addr(other.addr) {
other.sockfd = -1;
}
// Move assignment operator
UnixClientConnection& operator=(UnixClientConnection&& other) noexcept {
if (this != &other) {
if (sockfd != -1) {
close(sockfd);
}
sockfd = other.sockfd;
addr = other.addr;
other.sockfd = -1;
}
return *this;
}
void send(const std::string& data) {
if (::write(sockfd, data.c_str(), data.size()) == -1) {
throw std::runtime_error(std::string("write error: ") + std::strerror(errno));
}
}
std::string recv(size_t size) {
std::string buffer(size, '\0');
ssize_t bytesRead = ::read(sockfd, &buffer[0], size);
if (bytesRead <= 0) {
throw std::runtime_error(std::string("read error: ") + std::strerror(errno));
}
buffer.resize(bytesRead); // resize to actual bytes read
return buffer;
}
// Template overload for generic types
template<typename T>
UnixClientConnection& operator<<(const T& data) {
streamBuffer << data;
return *this;
}
// Overload for manipulators (e.g., std::endl)
UnixClientConnection& operator<<(std::ostream& (*manip)(std::ostream&)) {
// Check if the manipulator is std::endl (or equivalent flush)
if (manip == static_cast<std::ostream& (*)(std::ostream&)>(std::endl)){
streamBuffer << '\n'; // Add a newline
std::string packet = streamBuffer.str();
streamBuffer.str(""); // Clear the buffer
// Send the accumulated data as one packet
send(packet);
}
if (static_cast<std::ostream& (*)(std::ostream&)>(std::flush)) {
std::string packet = streamBuffer.str();
streamBuffer.str(""); // Clear the buffer
// Send the accumulated data as one packet
send(packet);
} else {
// For other manipulators, simply pass them to the buffer
streamBuffer << manip;
}
return *this;
}
// Overload operator<< to allow printing connection info
friend std::ostream& operator<<(std::ostream& os, const UnixClientConnection& conn) {
os << "UnixClientConnection(sockfd=" << conn.sockfd
<< ", path=" << conn.addr.sun_path << ")";
return os;
}
~UnixClientConnection() {
if (sockfd != -1) {
close(sockfd);
}
}
};
#ifdef USE_PIPES_FOR_BLOKING_QUEUE #ifdef USE_PIPES_FOR_BLOKING_QUEUE
template<typename T> template<typename T>

View File

@@ -1,4 +1,4 @@
from modules.firewall.models import * from modules.firewall.models import FirewallSettings, Action, Rule, Protocol, Mode, Table
from utils import nftables_int_to_json, ip_family, NFTableManager, is_ip_parse from utils import nftables_int_to_json, ip_family, NFTableManager, is_ip_parse
import copy import copy
@@ -9,7 +9,8 @@ class FiregexTables(NFTableManager):
filter_table = "filter" filter_table = "filter"
mangle_table = "mangle" mangle_table = "mangle"
def init_comands(self, policy:str=Action.ACCEPT, opt: FirewallSettings|None = None): def init_comands(self, policy:str=Action.ACCEPT, opt:
FirewallSettings|None = None):
rules = [ rules = [
{"add":{"table":{"name":self.filter_table,"family":"ip"}}}, {"add":{"table":{"name":self.filter_table,"family":"ip"}}},
{"add":{"table":{"name":self.filter_table,"family":"ip6"}}}, {"add":{"table":{"name":self.filter_table,"family":"ip6"}}},
@@ -41,7 +42,8 @@ class FiregexTables(NFTableManager):
{"add":{"chain":{"family":"ip","table":self.mangle_table,"name":self.rules_chain_out}}}, {"add":{"chain":{"family":"ip","table":self.mangle_table,"name":self.rules_chain_out}}},
{"add":{"chain":{"family":"ip6","table":self.mangle_table,"name":self.rules_chain_out}}}, {"add":{"chain":{"family":"ip6","table":self.mangle_table,"name":self.rules_chain_out}}},
] ]
if opt is None: return rules if opt is None:
return rules
if opt.allow_loopback: if opt.allow_loopback:
rules.extend([ rules.extend([
@@ -194,13 +196,18 @@ class FiregexTables(NFTableManager):
def chain_to_firegex(self, chain:str, table:str): def chain_to_firegex(self, chain:str, table:str):
if table == self.filter_table: if table == self.filter_table:
match chain: match chain:
case "INPUT": return self.rules_chain_in case "INPUT":
case "OUTPUT": return self.rules_chain_out return self.rules_chain_in
case "FORWARD": return self.rules_chain_fwd case "OUTPUT":
return self.rules_chain_out
case "FORWARD":
return self.rules_chain_fwd
elif table == self.mangle_table: elif table == self.mangle_table:
match chain: match chain:
case "PREROUTING": return self.rules_chain_in case "PREROUTING":
case "POSTROUTING": return self.rules_chain_out return self.rules_chain_in
case "POSTROUTING":
return self.rules_chain_out
return None return None
def insert_firegex_chains(self): def insert_firegex_chains(self):
@@ -214,7 +221,8 @@ class FiregexTables(NFTableManager):
if r.get("family") == family and r.get("table") == table and r.get("chain") == chain and r.get("expr") == rule_to_add: if r.get("family") == family and r.get("table") == table and r.get("chain") == chain and r.get("expr") == rule_to_add:
found = True found = True
break break
if found: continue if found:
continue
yield { "add":{ "rule": { yield { "add":{ "rule": {
"family": family, "family": family,
"table": table, "table": table,
@@ -274,7 +282,7 @@ class FiregexTables(NFTableManager):
ip_filters.append({"match": { "op": "==", "left": { "meta": { "key": "oifname" } }, "right": srv.dst} }) ip_filters.append({"match": { "op": "==", "left": { "meta": { "key": "oifname" } }, "right": srv.dst} })
port_filters = [] port_filters = []
if not srv.proto in [Protocol.ANY, Protocol.BOTH]: if srv.proto not in [Protocol.ANY, Protocol.BOTH]:
if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}}) port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}})
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}}) port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}})

View File

@@ -1,11 +1,10 @@
from modules.nfproxy.nftables import FiregexTables from modules.nfproxy.nftables import FiregexTables
from utils import run_func
from modules.nfproxy.models import Service, PyFilter from modules.nfproxy.models import Service, PyFilter
import os import os
import asyncio import asyncio
from utils import DEBUG
import traceback import traceback
from fastapi import HTTPException from fastapi import HTTPException
import time
nft = FiregexTables() nft = FiregexTables()
@@ -13,29 +12,37 @@ class FiregexInterceptor:
def __init__(self): def __init__(self):
self.srv:Service self.srv:Service
self._stats_updater_cb:callable
self.filter_map_lock:asyncio.Lock self.filter_map_lock:asyncio.Lock
self.filter_map: dict[str, PyFilter] self.filter_map: dict[str, PyFilter]
self.pyfilters: set[PyFilter]
self.update_config_lock:asyncio.Lock self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process self.process:asyncio.subprocess.Process
self.update_task: asyncio.Task self.update_task: asyncio.Task
self.server_task: asyncio.Task
self.sock_path: str
self.unix_sock: asyncio.Server
self.ack_arrived = False self.ack_arrived = False
self.ack_status = None self.ack_status = None
self.ack_fail_what = "Unknown" self.ack_fail_what = "Queue response timed-out"
self.ack_lock = asyncio.Lock() self.ack_lock = asyncio.Lock()
self.sock_reader:asyncio.StreamReader = None
async def _call_stats_updater_callback(self, filter: PyFilter): self.sock_writer:asyncio.StreamWriter = None
if self._stats_updater_cb: self.sock_conn_lock:asyncio.Lock
await run_func(self._stats_updater_cb(filter)) self.last_time_exception = 0
@classmethod @classmethod
async def start(cls, srv: Service, stats_updater_cb:callable): async def start(cls, srv: Service):
self = cls() self = cls()
self._stats_updater_cb = stats_updater_cb
self.srv = srv self.srv = srv
self.filter_map_lock = asyncio.Lock() self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock() self.update_config_lock = asyncio.Lock()
self.sock_conn_lock = asyncio.Lock()
if not self.sock_conn_lock.locked():
await self.sock_conn_lock.acquire()
self.sock_path = f"/tmp/firegex_nfproxy_{srv.id}.sock"
if os.path.exists(self.sock_path):
os.remove(self.sock_path)
self.unix_sock = await asyncio.start_unix_server(self._server_listener,path=self.sock_path)
self.server_task = asyncio.create_task(self.unix_sock.serve_forever())
queue_range = await self._start_binary() queue_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_stats()) self.update_task = asyncio.create_task(self.update_stats())
nft.add(self.srv, queue_range) nft.add(self.srv, queue_range)
@@ -46,19 +53,20 @@ class FiregexInterceptor:
async def _start_binary(self): async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cpproxy") proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cpproxy")
self.process = await asyncio.create_subprocess_exec( self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, proxy_binary_path, stdin=asyncio.subprocess.DEVNULL,
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
env={ env={
"NTHREADS": os.getenv("NTHREADS","1"), "NTHREADS": os.getenv("NTHREADS","1"),
"FIREGEX_NFQUEUE_FAIL_OPEN": "1" if self.srv.fail_open else "0", "FIREGEX_NFQUEUE_FAIL_OPEN": "1" if self.srv.fail_open else "0",
"FIREGEX_NFPROXY_SOCK": self.sock_path
}, },
) )
line_fut = self.process.stdout.readuntil()
try: try:
line_fut = await asyncio.wait_for(line_fut, timeout=3) async with asyncio.timeout(3):
await self.sock_conn_lock.acquire()
line_fut = await self.sock_reader.readuntil()
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.process.kill() self.process.kill()
raise Exception("Invalid binary output") raise Exception("Binary don't returned queue number until timeout")
line = line_fut.decode() line = line_fut.decode()
if line.startswith("QUEUE "): if line.startswith("QUEUE "):
params = line.split() params = line.split()
@@ -67,25 +75,45 @@ class FiregexInterceptor:
self.process.kill() self.process.kill()
raise Exception("Invalid binary output") raise Exception("Invalid binary output")
async def _server_listener(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
if self.sock_reader or self.sock_writer:
writer.write_eof() # Technically never reached
writer.close()
reader.feed_eof()
return
self.sock_reader = reader
self.sock_writer = writer
self.sock_conn_lock.release()
async def update_stats(self): async def update_stats(self):
try: try:
while True: while True:
line = (await self.process.stdout.readuntil()).decode() try:
if DEBUG: line = (await self.sock_reader.readuntil()).decode()
print(line) except Exception as e:
self.ack_arrived = False
self.ack_status = False
self.ack_fail_what = "Can't read from nfq client"
self.ack_lock.release()
await self.stop()
raise HTTPException(status_code=500, detail="Can't read from nfq client") from e
if line.startswith("BLOCKED "): if line.startswith("BLOCKED "):
filter_id = line.split()[1] filter_name = line.split()[1]
print("BLOCKED", filter_name)
async with self.filter_map_lock: async with self.filter_map_lock:
if filter_id in self.filter_map: print("LOCKED MAP LOCK")
self.filter_map[filter_id].blocked_packets+=1 if filter_name in self.filter_map:
await self.filter_map[filter_id].update() print("ADDING BLOCKED PACKET")
self.filter_map[filter_name].blocked_packets+=1
await self.filter_map[filter_name].update()
if line.startswith("MANGLED "): if line.startswith("MANGLED "):
filter_id = line.split()[1] filter_name = line.split()[1]
async with self.filter_map_lock: async with self.filter_map_lock:
if filter_id in self.filter_map: if filter_name in self.filter_map:
self.filter_map[filter_id].edited_packets+=1 self.filter_map[filter_name].edited_packets+=1
await self.filter_map[filter_id].update() await self.filter_map[filter_name].update()
if line.startswith("EXCEPTION"): if line.startswith("EXCEPTION"):
self.last_time_exception = time.time()
print("TODO EXCEPTION HANDLING") # TODO print("TODO EXCEPTION HANDLING") # TODO
if line.startswith("ACK "): if line.startswith("ACK "):
self.ack_arrived = True self.ack_arrived = True
@@ -101,22 +129,29 @@ class FiregexInterceptor:
traceback.print_exc() traceback.print_exc()
async def stop(self): async def stop(self):
self.server_task.cancel()
self.update_task.cancel() self.update_task.cancel()
self.unix_sock.close()
if os.path.exists(self.sock_path):
os.remove(self.sock_path)
if self.process and self.process.returncode is None: if self.process and self.process.returncode is None:
self.process.kill() self.process.kill()
async def _update_config(self, code): async def _update_config(self, code):
async with self.update_config_lock: async with self.update_config_lock:
self.process.stdin.write(len(code).to_bytes(4, byteorder='big')+code.encode()) if self.sock_writer:
await self.process.stdin.drain() self.sock_writer.write(len(code).to_bytes(4, byteorder='big')+code.encode())
await self.sock_writer.drain()
try: try:
async with asyncio.timeout(3): async with asyncio.timeout(3):
await self.ack_lock.acquire() await self.ack_lock.acquire()
except TimeoutError: except TimeoutError:
pass self.ack_fail_what = "Queue response timed-out"
if not self.ack_arrived or not self.ack_status: if not self.ack_arrived or not self.ack_status:
await self.stop() await self.stop()
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}") raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
else:
raise HTTPException(status_code=400, detail="Socket not ready")
async def reload(self, filters:list[PyFilter]): async def reload(self, filters:list[PyFilter]):
async with self.filter_map_lock: async with self.filter_map_lock:
@@ -125,12 +160,13 @@ class FiregexInterceptor:
filter_file = f.read() filter_file = f.read()
else: else:
filter_file = "" filter_file = ""
self.filter_map = {ele.name: ele for ele in filters}
await self._update_config( await self._update_config(
"global __firegex_pyfilter_enabled\n" +
filter_file + "\n\n" +
"__firegex_pyfilter_enabled = [" + ", ".join([repr(f.name) for f in filters]) + "]\n" + "__firegex_pyfilter_enabled = [" + ", ".join([repr(f.name) for f in filters]) + "]\n" +
"__firegex_proto = " + repr(self.srv.proto) + "\n" + "__firegex_proto = " + repr(self.srv.proto) + "\n" +
"import firegex.nfproxy.internals\n\n" + "import firegex.nfproxy.internals\n" +
filter_file + "\n\n" + "firegex.nfproxy.internals.compile(globals())\n"
"firegex.nfproxy.internals.compile()"
) )

View File

@@ -15,18 +15,18 @@ class ServiceManager:
self.srv = srv self.srv = srv
self.db = db self.db = db
self.status = STATUS.STOP self.status = STATUS.STOP
self.filters: dict[int, FiregexFilter] = {} self.filters: dict[str, FiregexFilter] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.interceptor = None self.interceptor = None
async def _update_filters_from_db(self): async def _update_filters_from_db(self):
pyfilters = [ pyfilters = [
PyFilter.from_dict(ele) for ele in PyFilter.from_dict(ele, self.db) for ele in
self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id) self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id)
] ]
#Filter check #Filter check
old_filters = set(self.filters.keys()) old_filters = set(self.filters.keys())
new_filters = set([f.id for f in pyfilters]) new_filters = set([f.name for f in pyfilters])
#remove old filters #remove old filters
for f in old_filters: for f in old_filters:
if f not in new_filters: if f not in new_filters:
@@ -34,7 +34,7 @@ class ServiceManager:
#add new filters #add new filters
for f in new_filters: for f in new_filters:
if f not in old_filters: if f not in old_filters:
self.filters[f] = [ele for ele in pyfilters if ele.id == f][0] self.filters[f] = [ele for ele in pyfilters if ele.name == f][0]
if self.interceptor: if self.interceptor:
await self.interceptor.reload(self.filters.values()) await self.interceptor.reload(self.filters.values())
@@ -43,16 +43,11 @@ class ServiceManager:
async def next(self,to): async def next(self,to):
async with self.lock: async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): if to == STATUS.STOP:
await self.stop() await self.stop()
self._set_status(to) if to == STATUS.ACTIVE:
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
await self.restart() await self.restart()
def _stats_updater(self,filter:PyFilter):
self.db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE filter_id = ?;", filter.blocked_packets, filter.edited_packets, filter.id)
def _set_status(self,status): def _set_status(self,status):
self.status = status self.status = status
self.__update_status_db(status) self.__update_status_db(status)
@@ -60,7 +55,7 @@ class ServiceManager:
async def start(self): async def start(self):
if not self.interceptor: if not self.interceptor:
nft.delete(self.srv) nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(self.srv, self._stats_updater) self.interceptor = await FiregexInterceptor.start(self.srv)
await self._update_filters_from_db() await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE) self._set_status(STATUS.ACTIVE)
@@ -69,6 +64,7 @@ class ServiceManager:
if self.interceptor: if self.interceptor:
await self.interceptor.stop() await self.interceptor.stop()
self.interceptor = None self.interceptor = None
self._set_status(STATUS.STOP)
async def restart(self): async def restart(self):
await self.stop() await self.stop()

View File

@@ -15,13 +15,19 @@ class Service:
class PyFilter: class PyFilter:
def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other): def __init__(self, name: str, blocked_packets: int, edited_packets: int, active: bool, db, **other):
self.id = filter_id
self.name = name self.name = name
self.blocked_packets = blocked_packets self.blocked_packets = blocked_packets
self.edited_packets = edited_packets self.edited_packets = edited_packets
self.active = active self.active = active
self.__db = db
async def update(self):
self.__db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE name = ?;", self.blocked_packets, self.edited_packets, self.name)
def __repr__(self):
return f"<PyFilter {self.name}>"
@classmethod @classmethod
def from_dict(cls, var: dict): def from_dict(cls, var: dict, db):
return cls(**var) return cls(**var, db=db)

View File

@@ -1,6 +1,14 @@
from modules.nfproxy.models import Service from modules.nfproxy.models import Service
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
def convert_protocol_to_l4(proto:str):
if proto == "tcp":
return "tcp"
elif proto == "http":
return "tcp"
else:
raise Exception("Invalid protocol")
class FiregexFilter: class FiregexFilter:
def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int): def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int):
self.id = id self.id = id
@@ -11,7 +19,7 @@ class FiregexFilter:
def __eq__(self, o: object) -> bool: def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter) or isinstance(o, Service): if isinstance(o, FiregexFilter) or isinstance(o, Service):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return self.port == o.port and self.proto == convert_protocol_to_l4(o.proto) and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False return False
class FiregexTables(NFTableManager): class FiregexTables(NFTableManager):
@@ -61,7 +69,7 @@ class FiregexTables(NFTableManager):
"chain": self.output_chain, "chain": self.output_chain,
"expr": [ "expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}}, {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}}, {'match': {"left": { "payload": {"protocol": convert_protocol_to_l4(str(srv.proto)), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
{"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1338}}, {"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1338}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
] ]
@@ -72,7 +80,7 @@ class FiregexTables(NFTableManager):
"chain": self.input_chain, "chain": self.input_chain,
"expr": [ "expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}}, {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}}, {'match': {"left": { "payload": {"protocol": convert_protocol_to_l4(str(srv.proto)), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
{"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1337}}, {"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1337}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}} {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
] ]

View File

@@ -79,7 +79,7 @@ class FiregexInterceptor:
self.update_task: asyncio.Task self.update_task: asyncio.Task
self.ack_arrived = False self.ack_arrived = False
self.ack_status = None self.ack_status = None
self.ack_fail_what = "Unknown" self.ack_fail_what = "Queue response timed-out"
self.ack_lock = asyncio.Lock() self.ack_lock = asyncio.Lock()
@classmethod @classmethod
@@ -158,7 +158,7 @@ class FiregexInterceptor:
async with asyncio.timeout(3): async with asyncio.timeout(3):
await self.ack_lock.acquire() await self.ack_lock.acquire()
except TimeoutError: except TimeoutError:
pass self.ack_fail_what = "Queue response timed-out"
if not self.ack_arrived or not self.ack_status: if not self.ack_arrived or not self.ack_status:
await self.stop() await self.stop()
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}") raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")

View File

@@ -45,11 +45,9 @@ class ServiceManager:
async def next(self,to): async def next(self,to):
async with self.lock: async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): if to == STATUS.STOP:
await self.stop() await self.stop()
self._set_status(to) if to == STATUS.ACTIVE:
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
await self.restart() await self.restart()
def _stats_updater(self,filter:RegexFilter): def _stats_updater(self,filter:RegexFilter):
@@ -71,6 +69,7 @@ class ServiceManager:
if self.interceptor: if self.interceptor:
await self.interceptor.stop() await self.interceptor.stop()
self.interceptor = None self.interceptor = None
self._set_status(STATUS.STOP)
async def restart(self): async def restart(self):
await self.stop() await self.stop()

View File

@@ -10,6 +10,10 @@ from utils.models import ResetRequest, StatusMessageModel
import os import os
from firegex.nfproxy.internals import get_filter_names from firegex.nfproxy.internals import get_filter_names
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from modules.nfproxy.nftables import convert_protocol_to_l4
import asyncio
import traceback
from utils import DEBUG
class ServiceModel(BaseModel): class ServiceModel(BaseModel):
service_id: str service_id: str
@@ -28,12 +32,10 @@ class RenameForm(BaseModel):
class SettingsForm(BaseModel): class SettingsForm(BaseModel):
port: PortType|None = None port: PortType|None = None
proto: str|None = None
ip_int: str|None = None ip_int: str|None = None
fail_open: bool|None = None fail_open: bool|None = None
class PyFilterModel(BaseModel): class PyFilterModel(BaseModel):
filter_id: int
name: str name: str
blocked_packets: int blocked_packets: int
edited_packets: int edited_packets: int
@@ -52,6 +54,7 @@ class ServiceAddResponse(BaseModel):
class SetPyFilterForm(BaseModel): class SetPyFilterForm(BaseModel):
code: str code: str
sid: str|None = None
app = APIRouter() app = APIRouter()
@@ -62,12 +65,12 @@ db = SQLite('db/nft-pyfilters.db', {
'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)', 'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'name': 'VARCHAR(100) NOT NULL UNIQUE', 'name': 'VARCHAR(100) NOT NULL UNIQUE',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "http"))', 'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "http"))',
'l4_proto': 'VARCHAR(3) NOT NULL CHECK (l4_proto IN ("tcp", "udp"))',
'ip_int': 'VARCHAR(100) NOT NULL', 'ip_int': 'VARCHAR(100) NOT NULL',
'fail_open': 'BOOLEAN NOT NULL CHECK (fail_open IN (0, 1)) DEFAULT 1', 'fail_open': 'BOOLEAN NOT NULL CHECK (fail_open IN (0, 1)) DEFAULT 1',
}, },
'pyfilter': { 'pyfilter': {
'filter_id': 'INTEGER PRIMARY KEY', 'name': 'VARCHAR(100) PRIMARY KEY',
'name': 'VARCHAR(100) NOT NULL',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'edited_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', 'edited_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'service_id': 'VARCHAR(100) NOT NULL', 'service_id': 'VARCHAR(100) NOT NULL',
@@ -75,7 +78,7 @@ db = SQLite('db/nft-pyfilters.db', {
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)', 'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
}, },
'QUERY':[ 'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, proto);", "CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, l4_proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_pyfilter_service ON pyfilter (name, service_id);" "CREATE UNIQUE INDEX IF NOT EXISTS unique_pyfilter_service ON pyfilter (name, service_id);"
] ]
}) })
@@ -132,7 +135,7 @@ async def get_service_list():
s.proto proto, s.proto proto,
s.ip_int ip_int, s.ip_int ip_int,
s.fail_open fail_open, s.fail_open fail_open,
COUNT(f.filter_id) n_filters, COUNT(f.name) n_filters,
COALESCE(SUM(f.blocked_packets),0) blocked_packets, COALESCE(SUM(f.blocked_packets),0) blocked_packets,
COALESCE(SUM(f.edited_packets),0) edited_packets COALESCE(SUM(f.edited_packets),0) edited_packets
FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id
@@ -151,7 +154,7 @@ async def get_service_by_id(service_id: str):
s.proto proto, s.proto proto,
s.ip_int ip_int, s.ip_int ip_int,
s.fail_open fail_open, s.fail_open fail_open,
COUNT(f.filter_id) n_filters, COUNT(f.name) n_filters,
COALESCE(SUM(f.blocked_packets),0) blocked_packets, COALESCE(SUM(f.blocked_packets),0) blocked_packets,
COALESCE(SUM(f.edited_packets),0) edited_packets COALESCE(SUM(f.edited_packets),0) edited_packets
FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id
@@ -203,9 +206,6 @@ async def service_rename(service_id: str, form: RenameForm):
async def service_settings(service_id: str, form: SettingsForm): async def service_settings(service_id: str, form: SettingsForm):
"""Request to change the settings of a specific service (will cause a restart)""" """Request to change the settings of a specific service (will cause a restart)"""
if form.proto is not None and form.proto not in ["tcp", "udp"]:
raise HTTPException(status_code=400, detail="Invalid protocol")
if form.port is not None and (form.port < 1 or form.port > 65535): if form.port is not None and (form.port < 1 or form.port > 65535):
raise HTTPException(status_code=400, detail="Invalid port") raise HTTPException(status_code=400, detail="Invalid port")
@@ -245,38 +245,38 @@ async def get_service_pyfilter_list(service_id: str):
raise HTTPException(status_code=400, detail="This service does not exists!") raise HTTPException(status_code=400, detail="This service does not exists!")
return db.query(""" return db.query("""
SELECT SELECT
filter_id, name, blocked_packets, edited_packets, active name, blocked_packets, edited_packets, active
FROM pyfilter WHERE service_id = ?; FROM pyfilter WHERE service_id = ?;
""", service_id) """, service_id)
@app.get('/pyfilters/{filter_id}', response_model=PyFilterModel) @app.get('/pyfilters/{filter_name}', response_model=PyFilterModel)
async def get_pyfilter_by_id(filter_id: int): async def get_pyfilter_by_id(filter_name: str):
"""Get pyfilter info using his id""" """Get pyfilter info using his id"""
res = db.query(""" res = db.query("""
SELECT SELECT
filter_id, name, blocked_packets, edited_packets, active name, blocked_packets, edited_packets, active
FROM pyfilter WHERE filter_id = ?; FROM pyfilter WHERE name = ?;
""", filter_id) """, filter_name)
if len(res) == 0: if len(res) == 0:
raise HTTPException(status_code=400, detail="This filter does not exists!") raise HTTPException(status_code=400, detail="This filter does not exists!")
return res[0] return res[0]
@app.post('/pyfilters/{filter_id}/enable', response_model=StatusMessageModel) @app.post('/pyfilters/{filter_name}/enable', response_model=StatusMessageModel)
async def pyfilter_enable(filter_id: int): async def pyfilter_enable(filter_name: str):
"""Request the enabling of a pyfilter""" """Request the enabling of a pyfilter"""
res = db.query('SELECT * FROM pyfilter WHERE filter_id = ?;', filter_id) res = db.query('SELECT * FROM pyfilter WHERE name = ?;', filter_name)
if len(res) != 0: if len(res) != 0:
db.query('UPDATE pyfilter SET active=1 WHERE filter_id = ?;', filter_id) db.query('UPDATE pyfilter SET active=1 WHERE name = ?;', filter_name)
await firewall.get(res[0]["service_id"]).update_filters() await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend() await refresh_frontend()
return {'status': 'ok'} return {'status': 'ok'}
@app.post('/pyfilters/{filter_id}/disable', response_model=StatusMessageModel) @app.post('/pyfilters/{filter_name}/disable', response_model=StatusMessageModel)
async def pyfilter_disable(filter_id: int): async def pyfilter_disable(filter_name: str):
"""Request the deactivation of a pyfilter""" """Request the deactivation of a pyfilter"""
res = db.query('SELECT * FROM pyfilter WHERE filter_id = ?;', filter_id) res = db.query('SELECT * FROM pyfilter WHERE name = ?;', filter_name)
if len(res) != 0: if len(res) != 0:
db.query('UPDATE pyfilter SET active=0 WHERE filter_id = ?;', filter_id) db.query('UPDATE pyfilter SET active=0 WHERE name = ?;', filter_name)
await firewall.get(res[0]["service_id"]).update_filters() await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend() await refresh_frontend()
return {'status': 'ok'} return {'status': 'ok'}
@@ -293,8 +293,8 @@ async def add_new_service(form: ServiceAddForm):
srv_id = None srv_id = None
try: try:
srv_id = gen_service_id() srv_id = gen_service_id()
db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int, fail_open) VALUES (?, ?, ?, ?, ?, ?, ?)", db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int, fail_open, l4_proto) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int, form.fail_open) srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int, form.fail_open, convert_protocol_to_l4(form.proto))
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="This type of service already exists") raise HTTPException(status_code=400, detail="This type of service already exists")
await firewall.reload() await firewall.reload()
@@ -308,17 +308,24 @@ async def set_pyfilters(service_id: str, form: SetPyFilterForm):
if len(service) == 0: if len(service) == 0:
raise HTTPException(status_code=400, detail="This service does not exists!") raise HTTPException(status_code=400, detail="This service does not exists!")
service = service[0] service = service[0]
service_id = service["service_id"]
srv_proto = service["proto"] srv_proto = service["proto"]
try:
async with asyncio.timeout(8):
try: try:
found_filters = get_filter_names(form.code, srv_proto) found_filters = get_filter_names(form.code, srv_proto)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) if DEBUG:
traceback.print_exc()
raise HTTPException(status_code=400, detail="Compile error: "+str(e))
# Remove filters that are not in the new code # Remove filters that are not in the new code
existing_filters = db.query("SELECT filter_id FROM pyfilter WHERE service_id = ?;", service_id) existing_filters = db.query("SELECT name FROM pyfilter WHERE service_id = ?;", service_id)
existing_filters = [ele["name"] for ele in existing_filters]
for filter in existing_filters: for filter in existing_filters:
if filter["name"] not in found_filters: if filter not in found_filters:
db.query("DELETE FROM pyfilter WHERE filter_id = ?;", filter["filter_id"]) db.query("DELETE FROM pyfilter WHERE name = ?;", filter)
# Add filters that are in the new code but not in the database # Add filters that are in the new code but not in the database
for filter in found_filters: for filter in found_filters:
@@ -331,6 +338,11 @@ async def set_pyfilters(service_id: str, form: SetPyFilterForm):
f.write(form.code) f.write(form.code)
await firewall.get(service_id).update_filters() await firewall.get(service_id).update_filters()
await refresh_frontend() await refresh_frontend()
except asyncio.TimeoutError:
if DEBUG:
traceback.print_exc()
raise HTTPException(status_code=400, detail="The operation took too long")
return {'status': 'ok'} return {'status': 'ok'}
@app.get('/services/{service_id}/pyfilters/code', response_class=PlainTextResponse) @app.get('/services/{service_id}/pyfilters/code', response_class=PlainTextResponse)
@@ -343,7 +355,3 @@ async def get_pyfilters(service_id: str):
return f.read() return f.read()
except FileNotFoundError: except FileNotFoundError:
return "" return ""
#TODO check all the APIs and add
# 1. API to change the python filter file (DONE)
# 2. a socketio mechanism to lock the previous feature

View File

@@ -8,15 +8,22 @@ import nftables
from socketio import AsyncServer from socketio import AsyncServer
from fastapi import Path from fastapi import Path
from typing import Annotated from typing import Annotated
from functools import wraps
from pydantic import BaseModel, ValidationError
import traceback
from utils.models import StatusMessageModel
from typing import List
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
socketio:AsyncServer = None socketio:AsyncServer = None
sid_list:set = set()
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers") ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")
ON_DOCKER = "DOCKER" in sys.argv ON_DOCKER = "DOCKER" in sys.argv
DEBUG = "DEBUG" in sys.argv DEBUG = "DEBUG" in sys.argv
NORELOAD = "NORELOAD" in sys.argv
FIREGEX_PORT = int(os.getenv("PORT","4444")) FIREGEX_PORT = int(os.getenv("PORT","4444"))
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256"
API_VERSION = "{{VERSION_PLACEHOLDER}}" if "{" not in "{{VERSION_PLACEHOLDER}}" else "0.0.0" API_VERSION = "{{VERSION_PLACEHOLDER}}" if "{" not in "{{VERSION_PLACEHOLDER}}" else "0.0.0"
@@ -153,4 +160,50 @@ class NFTableManager(Singleton):
def raw_list(self): def raw_list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"] return self.cmd({"list": {"ruleset": None}})["nftables"]
def _json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json"):
res = obj.model_dump(mode=mode, exclude_unset=not unset)
if convert_keys:
for from_k, to_k in convert_keys.items():
if from_k in res:
res[to_k] = res.pop(from_k)
if exclude:
for ele in exclude:
if ele in res:
del res[ele]
return res
def json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json") -> dict:
if isinstance(obj, list):
return [_json_like(ele, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode) for ele in obj]
return _json_like(obj, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode)
def register_event(sio_server: AsyncServer, event_name: str, model: BaseModel, response_model: BaseModel|None = None):
def decorator(func):
@sio_server.on(event_name) # Automatically registers the event
@wraps(func)
async def wrapper(sid, data):
try:
# Parse and validate incoming data
parsed_data = model.model_validate(data)
except ValidationError:
return json_like(StatusMessageModel(status=f"Invalid {event_name} request"))
# Call the original function with the parsed data
result = await func(sid, parsed_data)
# If a response model is provided, validate the output
if response_model:
try:
parsed_result = response_model.model_validate(result)
except ValidationError:
traceback.print_exc()
return json_like(StatusMessageModel(status=f"SERVER ERROR: Invalid {event_name} response"))
else:
parsed_result = result
# Emit the validated result
if parsed_result:
if isinstance(parsed_result, BaseModel):
return json_like(parsed_result)
return parsed_result
return wrapper
return decorator

View File

@@ -7,6 +7,7 @@ from starlette.responses import StreamingResponse
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from utils import DEBUG, ON_DOCKER, ROUTERS_DIR, list_files, run_func from utils import DEBUG, ON_DOCKER, ROUTERS_DIR, list_files, run_func
from utils.models import ResetRequest from utils.models import ResetRequest
import asyncio
REACT_BUILD_DIR: str = "../frontend/build/" if not ON_DOCKER else "frontend/" REACT_BUILD_DIR: str = "../frontend/build/" if not ON_DOCKER else "frontend/"
REACT_HTML_PATH: str = os.path.join(REACT_BUILD_DIR,"index.html") REACT_HTML_PATH: str = os.path.join(REACT_BUILD_DIR,"index.html")
@@ -87,12 +88,9 @@ def load_routers(app):
if router.shutdown: if router.shutdown:
shutdowns.append(router.shutdown) shutdowns.append(router.shutdown)
async def reset(reset_option:ResetRequest): async def reset(reset_option:ResetRequest):
for func in resets: await asyncio.gather(*[run_func(func, reset_option) for func in resets])
await run_func(func, reset_option)
async def startup(): async def startup():
for func in startups: await asyncio.gather(*[run_func(func) for func in startups])
await run_func(func)
async def shutdown(): async def shutdown():
for func in shutdowns: await asyncio.gather(*[run_func(func) for func in shutdowns])
await run_func(func)
return reset, startup, shutdown return reset, startup, shutdown

Binary file not shown.

Before

Width:  |  Height:  |  Size: 116 KiB

After

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 322 KiB

After

Width:  |  Height:  |  Size: 327 KiB

View File

@@ -1,11 +1,23 @@
import functools import functools
from firegex.nfproxy.params import RawPacket
from enum import Enum
ACCEPT = 0 class Action(Enum):
DROP = 1 ACCEPT = 0
REJECT = 2 DROP = 1
MANGLE = 3 REJECT = 2
EXCEPTION = 4 MANGLE = 3
INVALID = 5
class FullStreamAction(Enum):
FLUSH = 0
ACCEPT = 1
REJECT = 2
DROP = 3
ACCEPT = Action.ACCEPT
DROP = Action.DROP
REJECT = Action.REJECT
MANGLE = Action.MANGLE
def pyfilter(func): def pyfilter(func):
""" """
@@ -27,12 +39,14 @@ def get_pyfilters():
"""Returns the list of functions marked with @pyfilter.""" """Returns the list of functions marked with @pyfilter."""
return list(pyfilter.registry) return list(pyfilter.registry)
def clear_pyfilter_registry():
"""Clears the pyfilter registry."""
if hasattr(pyfilter, "registry"):
pyfilter.registry.clear()
__all__ = [
"ACCEPT", "DROP", "REJECT", "MANGLE", "EXCEPTION", "INVALID",
"Action", "FullStreamAction",
"pyfilter",
"RawPacket"
]

View File

@@ -1,21 +1,7 @@
from inspect import signature from inspect import signature
from firegex.nfproxy.params import RawPacket, NotReadyToRun from firegex.nfproxy.params import RawPacket, NotReadyToRun
from firegex.nfproxy import ACCEPT, DROP, REJECT, MANGLE, EXCEPTION, INVALID from firegex.nfproxy import Action, FullStreamAction
from dataclasses import dataclass, field
RESULTS = [
ACCEPT,
DROP,
REJECT,
MANGLE,
EXCEPTION,
INVALID
]
FULL_STREAM_ACTIONS = [
"flush"
"accept",
"reject",
"drop"
]
type_annotations_associations = { type_annotations_associations = {
"tcp": { "tcp": {
@@ -26,136 +12,178 @@ type_annotations_associations = {
} }
} }
def _generate_filter_structure(filters: list[str], proto:str, glob:dict, local:dict): @dataclass
class FilterHandler:
func: callable
name: str
params: dict[type, callable]
proto: str
class internal_data:
filter_call_info: list[FilterHandler] = []
stream: list[RawPacket] = []
stream_size: int = 0
stream_max_size: int = 1*8e20
full_stream_action: str = "flush"
filter_glob: dict = {}
@dataclass
class PacketHandlerResult:
glob: dict = field(repr=False)
action: Action = Action.ACCEPT
matched_by: str = None
mangled_packet: bytes = None
def set_result(self) -> None:
self.glob["__firegex_pyfilter_result"] = {
"action": self.action.value,
"matched_by": self.matched_by,
"mangled_packet": self.mangled_packet
}
def reset_result(self) -> None:
self.glob["__firegex_pyfilter_result"] = None
def context_call(func, *args, **kargs):
internal_data.filter_glob["__firegex_tmp_args"] = args
internal_data.filter_glob["__firegex_tmp_kargs"] = kargs
internal_data.filter_glob["__firege_tmp_call"] = func
res = eval("__firege_tmp_call(*__firegex_tmp_args, **__firegex_tmp_kargs)", internal_data.filter_glob, internal_data.filter_glob)
del internal_data.filter_glob["__firegex_tmp_args"]
del internal_data.filter_glob["__firegex_tmp_kargs"]
del internal_data.filter_glob["__firege_tmp_call"]
return res
def generate_filter_structure(filters: list[str], proto:str, glob:dict) -> list[FilterHandler]:
if proto not in type_annotations_associations.keys(): if proto not in type_annotations_associations.keys():
raise Exception("Invalid protocol") raise Exception("Invalid protocol")
res = [] res = []
valid_annotation_type = type_annotations_associations[proto] valid_annotation_type = type_annotations_associations[proto]
def add_func_to_list(func): def add_func_to_list(func):
if not callable(func): if not callable(func):
raise Exception(f"{func} is not a function") raise Exception(f"{func} is not a function")
sig = signature(func) sig = signature(func)
params_function = [] params_function = {}
for k, v in sig.parameters.items(): for k, v in sig.parameters.items():
if v.annotation in valid_annotation_type.keys(): if v.annotation in valid_annotation_type.keys():
params_function.append((v.annotation, valid_annotation_type[v.annotation])) params_function[v.annotation] = valid_annotation_type[v.annotation]
else: else:
raise Exception(f"Invalid type annotation {v.annotation} for function {func.__name__}") raise Exception(f"Invalid type annotation {v.annotation} for function {func.__name__}")
res.append((func, params_function))
res.append(
FilterHandler(
func=func,
name=func.__name__,
params=params_function,
proto=proto
)
)
for filter in filters: for filter in filters:
if not isinstance(filter, str): if not isinstance(filter, str):
raise Exception("Invalid filter list: must be a list of strings") raise Exception("Invalid filter list: must be a list of strings")
if filter in glob.keys(): if filter in glob.keys():
add_func_to_list(glob[filter]) add_func_to_list(glob[filter])
elif filter in local.keys():
add_func_to_list(local[filter])
else: else:
raise Exception(f"Filter {filter} not found") raise Exception(f"Filter {filter} not found")
return res return res
def get_filters_info(code:str, proto:str): def get_filters_info(code:str, proto:str) -> list[FilterHandler]:
glob = {} glob = {}
local = {} exec(code, glob, glob)
exec(code, glob, local) exec("import firegex.nfproxy", glob, glob)
exec("import firegex.nfproxy", glob, local) filters = eval("firegex.nfproxy.get_pyfilters()", glob, glob)
filters = eval("firegex.nfproxy.get_pyfilters()", glob, local) try:
return _generate_filter_structure(filters, proto, glob, local) return generate_filter_structure(filters, proto, glob)
finally:
exec("firegex.nfproxy.clear_pyfilter_registry()", glob, glob)
def get_filter_names(code:str, proto:str):
return [ele[0].__name__ for ele in get_filters_info(code, proto)]
def compile(): def get_filter_names(code:str, proto:str) -> list[str]:
glob = globals() return [ele.name for ele in get_filters_info(code, proto)]
local = locals()
filters = glob["__firegex_pyfilter_enabled"]
proto = glob["__firegex_proto"]
glob["__firegex_func_list"] = _generate_filter_structure(filters, proto, glob, local)
glob["__firegex_stream"] = []
glob["__firegex_stream_size"] = 0
if "FGEX_STREAM_MAX_SIZE" in local and int(local["FGEX_STREAM_MAX_SIZE"]) > 0: def handle_packet() -> None:
glob["__firegex_stream_max_size"] = int(local["FGEX_STREAM_MAX_SIZE"]) cache_call = {} # Cache of the data handler calls
elif "FGEX_STREAM_MAX_SIZE" in glob and int(glob["FGEX_STREAM_MAX_SIZE"]) > 0:
glob["__firegex_stream_max_size"] = int(glob["FGEX_STREAM_MAX_SIZE"])
else:
glob["__firegex_stream_max_size"] = 1*8e20 # 1MB default value
if "FGEX_FULL_STREAM_ACTION" in local and local["FGEX_FULL_STREAM_ACTION"] in FULL_STREAM_ACTIONS: pkt_info = RawPacket.fetch_from_global(internal_data.filter_glob)
glob["__firegex_full_stream_action"] = local["FGEX_FULL_STREAM_ACTION"] cache_call[RawPacket] = pkt_info
else:
glob["__firegex_full_stream_action"] = "flush"
glob["__firegex_pyfilter_result"] = None final_result = Action.ACCEPT
data_size = len(pkt_info.data)
result = PacketHandlerResult(internal_data.filter_glob)
if internal_data.stream_size+data_size > internal_data.stream_max_size:
match internal_data.full_stream_action:
case FullStreamAction.FLUSH:
internal_data.stream = []
internal_data.stream_size = 0
case FullStreamAction.ACCEPT:
result.action = Action.ACCEPT
return result.set_result()
case FullStreamAction.REJECT:
result.action = Action.REJECT
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
case FullStreamAction.REJECT:
result.action = Action.DROP
result.matched_by = "@MAX_STREAM_SIZE_REACHED"
return result.set_result()
internal_data.stream.append(pkt_info)
internal_data.stream_size += data_size
def handle_packet():
glob = globals()
func_list = glob["__firegex_func_list"]
final_result = ACCEPT
cache_call = {}
cache_call[RawPacket] = RawPacket.fetch_from_global()
data_size = len(cache_call[RawPacket].data)
if glob["__firegex_stream_size"]+data_size > glob["__firegex_stream_max_size"]:
match glob["__firegex_full_stream_action"]:
case "flush":
glob["__firegex_stream"] = []
glob["__firegex_stream_size"] = 0
case "accept":
glob["__firegex_pyfilter_result"] = {
"action": ACCEPT,
"matched_by": None,
"mangled_packet": None
}
return
case "reject":
glob["__firegex_pyfilter_result"] = {
"action": REJECT,
"matched_by": "@MAX_STREAM_SIZE_REACHED",
"mangled_packet": None
}
return
case "drop":
glob["__firegex_pyfilter_result"] = {
"action": DROP,
"matched_by": "@MAX_STREAM_SIZE_REACHED",
"mangled_packet": None
}
return
glob["__firegex_stream"].append(cache_call[RawPacket])
glob["__firegex_stream_size"] += data_size
func_name = None func_name = None
mangled_packet = None mangled_packet = None
for filter in func_list: for filter in internal_data.filter_call_info:
final_params = [] final_params = []
for ele in filter[1]: for data_type, data_func in filter.params.items():
if ele[0] not in cache_call.keys(): if data_type not in cache_call.keys():
try: try:
cache_call[ele[0]] = ele[1]() cache_call[data_type] = data_func(internal_data.filter_glob)
except NotReadyToRun: except NotReadyToRun:
cache_call[ele[0]] = None cache_call[data_type] = None
if cache_call[ele[0]] is None: if cache_call[data_type] is None:
continue # Parsing raised NotReadyToRun, skip filter continue # Parsing raised NotReadyToRun, skip filter
final_params.append(cache_call[ele[0]]) final_params.append(cache_call[data_type])
res = filter[0](*final_params)
res = context_call(filter.func, *final_params)
if res is None: if res is None:
continue #ACCEPTED continue #ACCEPTED
if res == MANGLE: if not isinstance(res, Action):
if RawPacket not in cache_call.keys(): raise Exception(f"Invalid return type {type(res)} for function {filter.name}")
continue #Packet not modified if res == Action.MANGLE:
pkt:RawPacket = cache_call[RawPacket] mangled_packet = pkt_info.raw_packet
mangled_packet = pkt.raw_packet if res != Action.ACCEPT:
break func_name = filter.name
elif res != ACCEPT:
final_result = res final_result = res
func_name = filter[0].__name__
break break
glob["__firegex_pyfilter_result"] = {
"action": final_result,
"matched_by": func_name,
"mangled_packet": mangled_packet
}
result.action = final_result
result.matched_by = func_name
result.mangled_packet = mangled_packet
return result.set_result()
def compile(glob:dict) -> None:
internal_data.filter_glob = glob
filters = glob["__firegex_pyfilter_enabled"]
proto = glob["__firegex_proto"]
internal_data.filter_call_info = generate_filter_structure(filters, proto, glob)
if "FGEX_STREAM_MAX_SIZE" in glob and int(glob["FGEX_STREAM_MAX_SIZE"]) > 0:
internal_data.stream_max_size = int(glob["FGEX_STREAM_MAX_SIZE"])
else:
internal_data.stream_max_size = 1*8e20 # 1MB default value
if "FGEX_FULL_STREAM_ACTION" in glob and isinstance(glob["FGEX_FULL_STREAM_ACTION"], FullStreamAction):
internal_data.full_stream_action = glob["FGEX_FULL_STREAM_ACTION"]
else:
internal_data.full_stream_action = FullStreamAction.FLUSH
PacketHandlerResult(glob).reset_result()

View File

@@ -9,12 +9,15 @@ class RawPacket:
is_input: bool, is_input: bool,
is_ipv6: bool, is_ipv6: bool,
is_tcp: bool, is_tcp: bool,
l4_size: int,
): ):
self.__data = bytes(data) self.__data = bytes(data)
self.__raw_packet = bytes(raw_packet) self.__raw_packet = bytes(raw_packet)
self.__is_input = bool(is_input) self.__is_input = bool(is_input)
self.__is_ipv6 = bool(is_ipv6) self.__is_ipv6 = bool(is_ipv6)
self.__is_tcp = bool(is_tcp) self.__is_tcp = bool(is_tcp)
self.__l4_size = int(l4_size)
self.__raw_packet_header_size = len(self.__raw_packet)-self.__l4_size
@property @property
def is_input(self) -> bool: def is_input(self) -> bool:
@@ -33,19 +36,25 @@ class RawPacket:
return self.__data return self.__data
@property @property
def proto_header(self) -> bytes: def l4_size(self) -> int:
return self.__raw_packet[:self.proto_header_len] return self.__l4_size
@property @property
def proto_header_len(self) -> int: def raw_packet_header_len(self) -> int:
return len(self.__raw_packet) - len(self.__data) return self.__raw_packet_header_size
@data.setter @property
def data(self, v:bytes): def l4_data(self) -> bytes:
return self.__raw_packet[self.raw_packet_header_len:]
@l4_data.setter
def l4_data(self, v:bytes):
if not isinstance(v, bytes): if not isinstance(v, bytes):
raise Exception("Invalid data type, data MUST be of type bytes") raise Exception("Invalid data type, data MUST be of type bytes")
self.__raw_packet = self.proto_header + v #if len(v) != self.__l4_size:
self.__data = v # raise Exception("Invalid data size, must be equal to the original packet header size (due to a technical limitation)")
self.__raw_packet = self.__raw_packet[:self.raw_packet_header_len]+v
self.__l4_size = len(v)
@property @property
def raw_packet(self) -> bytes: def raw_packet(self) -> bytes:
@@ -55,17 +64,16 @@ class RawPacket:
def raw_packet(self, v:bytes): def raw_packet(self, v:bytes):
if not isinstance(v, bytes): if not isinstance(v, bytes):
raise Exception("Invalid data type, data MUST be of type bytes") raise Exception("Invalid data type, data MUST be of type bytes")
if len(v) < self.proto_header_len: #if len(v) != len(self.__raw_packet):
raise Exception("Invalid packet length") # raise Exception("Invalid data size, must be equal to the original packet size (due to a technical limitation)")
header_len = self.proto_header_len if len(v) < self.raw_packet_header_len:
self.__data = v[header_len:] raise Exception("Invalid data size, must be greater than the original packet header size")
self.__raw_packet = v self.__raw_packet = v
self.__l4_size = len(v)-self.raw_packet_header_len
@staticmethod @classmethod
def fetch_from_global(): def fetch_from_global(cls, glob):
glob = globals()
if "__firegex_packet_info" not in glob.keys(): if "__firegex_packet_info" not in glob.keys():
raise Exception("Packet info not found") raise Exception("Packet info not found")
return RawPacket(**glob["__firegex_packet_info"]) return cls(**glob["__firegex_packet_info"])

View File

@@ -15,8 +15,21 @@ import { useQueryClient } from '@tanstack/react-query';
import NFProxy from './pages/NFProxy'; import NFProxy from './pages/NFProxy';
import ServiceDetailsNFProxy from './pages/NFProxy/ServiceDetails'; import ServiceDetailsNFProxy from './pages/NFProxy/ServiceDetails';
export const socket = import.meta.env.DEV?
const socket = IS_DEV?io("ws://"+DEV_IP_BACKEND, {transports: ["websocket"], path:"/sock/socket.io" }):io({transports: ["websocket"], path:"/sock/socket.io"}); io("ws://"+DEV_IP_BACKEND, {
path:"/sock/socket.io",
transports: ['websocket'],
auth: {
token: localStorage.getItem("access_token")
}
}):
io({
path:"/sock/socket.io",
transports: ['websocket'],
auth: {
token: localStorage.getItem("access_token")
}
})
function App() { function App() {
@@ -25,33 +38,20 @@ function App() {
const [reqError, setReqError] = useState<undefined|string>() const [reqError, setReqError] = useState<undefined|string>()
const [error, setError] = useState<string|null>() const [error, setError] = useState<string|null>()
const [loadinBtn, setLoadingBtn] = useState(false); const [loadinBtn, setLoadingBtn] = useState(false);
const queryClient = useQueryClient()
const getStatus = () =>{ const getStatus = () =>{
getstatus().then( res =>{ getstatus().then( res =>{
setSystemStatus(res) setSystemStatus(res)
setReqError(undefined) setReqError(undefined)
setLoading(false)
}).catch(err=>{ }).catch(err=>{
setReqError(err.toString()) setReqError(err.toString())
setLoading(false)
setTimeout(getStatus, 500) setTimeout(getStatus, 500)
}) }).finally( ()=>setLoading(false) )
} }
useEffect(()=>{ useEffect(()=>{
getStatus() getStatus()
socket.on("update", (data) => {
queryClient.invalidateQueries({ queryKey: data })
})
socket.on("connect_error", (err) => {
errorNotify("Socket.Io connection failed! ",`Error message: [${err.message}]`)
getStatus()
});
return () => {
socket.off("update")
socket.off("connect_error")
}
},[]) },[])
const form = useForm({ const form = useForm({
@@ -145,6 +145,36 @@ function App() {
</Notification><Space h="md" /></>:null} </Notification><Space h="md" /></>:null}
</Box> </Box>
}else if (systemStatus.status === "run" && systemStatus.loggined){ }else if (systemStatus.status === "run" && systemStatus.loggined){
return <PageRouting getStatus={getStatus} />
}else{
return <Box className='center-flex-row' style={{padding:"100px"}}>
<Title order={1} style={{textAlign:"center"}}>Error launching Firegex! 🔥</Title>
<Space h="md" />
<Title order={4} style={{textAlign:"center"}}>Error communicating with backend</Title>
</Box>
}
}
const PageRouting = ({ getStatus }:{ getStatus:()=>void }) => {
const queryClient = useQueryClient()
useEffect(()=>{
getStatus()
socket.on("update", (data) => {
queryClient.invalidateQueries({ queryKey: data })
})
socket.on("connect_error", (err) => {
errorNotify("Socket.Io connection failed! ",`Error message: [${err.message}]`)
getStatus()
});
return () => {
socket.off("update")
socket.off("connect_error")
}
},[])
return <Routes> return <Routes>
<Route element={<MainLayout><Outlet /></MainLayout>}> <Route element={<MainLayout><Outlet /></MainLayout>}>
<Route path="nfregex" element={<NFRegex><Outlet /></NFRegex>} > <Route path="nfregex" element={<NFRegex><Outlet /></NFRegex>} >
@@ -157,14 +187,9 @@ function App() {
<Route path="porthijack" element={<PortHijack />} /> <Route path="porthijack" element={<PortHijack />} />
<Route path="*" element={<HomeRedirector />} /> <Route path="*" element={<HomeRedirector />} />
</Route> </Route>
</Routes> </Routes>
}else{
return <Box className='center-flex-row' style={{padding:"100px"}}>
<Title order={1} style={{textAlign:"center"}}>Error launching Firegex! 🔥</Title>
<Space h="md" />
<Title order={4} style={{textAlign:"center"}}>Error communicating with backend</Title>
</Box>
}
} }
export default App; export default App;

View File

@@ -26,7 +26,7 @@ function AddEditService({ opened, onClose, edit }:{ opened:boolean, onClose:()=>
validate:{ validate:{
name: (value) => edit? null : value !== "" ? null : "Service name is required", name: (value) => edit? null : value !== "" ? null : "Service name is required",
port: (value) => (value>0 && value<65536) ? null : "Invalid port", port: (value) => (value>0 && value<65536) ? null : "Invalid port",
proto: (value) => ["tcp","udp"].includes(value) ? null : "Invalid protocol", proto: (value) => ["tcp","http"].includes(value) ? null : "Invalid protocol",
ip_int: (value) => (value.match(regex_ipv6) || value.match(regex_ipv4)) ? null : "Invalid IP address", ip_int: (value) => (value.match(regex_ipv6) || value.match(regex_ipv4)) ? null : "Invalid IP address",
} }
}) })
@@ -50,7 +50,7 @@ function AddEditService({ opened, onClose, edit }:{ opened:boolean, onClose:()=>
const submitRequest = ({ name, port, autostart, proto, ip_int, fail_open }:ServiceAddForm) =>{ const submitRequest = ({ name, port, autostart, proto, ip_int, fail_open }:ServiceAddForm) =>{
setSubmitLoading(true) setSubmitLoading(true)
if (edit){ if (edit){
nfproxy.settings(edit.service_id, { port, proto, ip_int, fail_open }).then( res => { nfproxy.settings(edit.service_id, { port, ip_int, fail_open }).then( res => {
if (!res){ if (!res){
setSubmitLoading(false) setSubmitLoading(false)
close(); close();
@@ -111,13 +111,13 @@ function AddEditService({ opened, onClose, edit }:{ opened:boolean, onClose:()=>
/> />
</Box> </Box>
<Box className="flex-spacer"></Box> <Box className="flex-spacer"></Box>
<SegmentedControl {edit?null:<SegmentedControl
data={[ data={[
{ label: 'TCP', value: 'tcp' }, { label: 'TCP', value: 'tcp' },
{ label: 'UDP', value: 'udp' }, { label: 'HTTP', value: 'http' },
]} ]}
{...form.getInputProps('proto')} {...form.getInputProps('proto')}
/> />}
</Box> </Box>
<Group justify='flex-end' mt="md" mb="sm"> <Group justify='flex-end' mt="md" mb="sm">

View File

@@ -0,0 +1,54 @@
import { Button, FileButton, Group, Modal, Notification, Space } from "@mantine/core";
import { nfproxy, Service } from "./utils";
import { useEffect, useState } from "react";
import { ImCross } from "react-icons/im";
import { okNotify } from "../../js/utils";
export const UploadFilterModal = ({ opened, onClose, service }: { opened: boolean, onClose: () => void, service?: Service }) => {
const close = () =>{
onClose()
setError(null)
}
const [submitLoading, setSubmitLoading] = useState(false)
const [error, setError] = useState<string|null>(null)
const [file, setFile] = useState<File | null>(null);
useEffect(() => {
if (opened && file){
file.bytes().then( code => {
console.log(code.toString())
setSubmitLoading(true)
nfproxy.setpyfilterscode(service?.service_id??"",code.toString()).then( res => {
if (!res){
setSubmitLoading(false)
close();
okNotify(`Service ${name} code updated`, `Successfully updated code for service ${name}`)
}
}).catch( err => {
setSubmitLoading(false)
setError("Error: "+err)
})
})
}
}, [opened, file])
return <Modal opened={opened && service != null} onClose={onClose} title="Upload filter Code" size="xl" closeOnClickOutside={false} centered>
<Space h="md" />
<Group justify="center">
<FileButton onChange={setFile} accept=".py" multiple={false}>
{(props) => <Button {...props}>Upload filter python code</Button>}
</FileButton>
</Group>
{error?<>
<Space h="md" />
<Notification icon={<ImCross size={14} />} color="red" onClose={()=>{setError(null)}}>
Error: {error}
</Notification>
</>:null}
<Space h="md" />
</Modal>
}

View File

@@ -25,7 +25,6 @@ export type ServiceAddForm = {
export type ServiceSettings = { export type ServiceSettings = {
port?:number, port?:number,
proto?:string,
ip_int?:string, ip_int?:string,
fail_open?: boolean, fail_open?: boolean,
} }
@@ -55,12 +54,12 @@ export const nfproxy = {
serviceinfo: async (service_id:string) => { serviceinfo: async (service_id:string) => {
return await getapi(`nfproxy/services/${service_id}`) as Service; return await getapi(`nfproxy/services/${service_id}`) as Service;
}, },
pyfilterenable: async (regex_id:number) => { pyfilterenable: async (filter_name:string) => {
const { status } = await postapi(`nfproxy/pyfilters/${regex_id}/enable`) as ServerResponse; const { status } = await postapi(`nfproxy/pyfilters/${filter_name}/enable`) as ServerResponse;
return status === "ok"?undefined:status return status === "ok"?undefined:status
}, },
pyfilterdisable: async (regex_id:number) => { pyfilterdisable: async (filter_name:string) => {
const { status } = await postapi(`nfproxy/pyfilters/${regex_id}/disable`) as ServerResponse; const { status } = await postapi(`nfproxy/pyfilters/${filter_name}/disable`) as ServerResponse;
return status === "ok"?undefined:status return status === "ok"?undefined:status
}, },
servicestart: async (service_id:string) => { servicestart: async (service_id:string) => {

View File

@@ -1,7 +1,7 @@
import { Text, Badge, Space, ActionIcon, Tooltip, Box } from '@mantine/core'; import { Text, Badge, Space, ActionIcon, Tooltip, Box } from '@mantine/core';
import { useState } from 'react'; import { useState } from 'react';
import { PyFilter } from '../../js/models'; import { PyFilter } from '../../js/models';
import { errorNotify, okNotify } from '../../js/utils'; import { errorNotify, isMediumScreen, okNotify } from '../../js/utils';
import { FaPause, FaPlay } from 'react-icons/fa'; import { FaPause, FaPlay } from 'react-icons/fa';
import { FaFilter } from "react-icons/fa"; import { FaFilter } from "react-icons/fa";
import { nfproxy } from '../NFProxy/utils'; import { nfproxy } from '../NFProxy/utils';
@@ -9,42 +9,39 @@ import { FaPencilAlt } from 'react-icons/fa';
export default function PyFilterView({ filterInfo }:{ filterInfo:PyFilter }) { export default function PyFilterView({ filterInfo }:{ filterInfo:PyFilter }) {
const [deleteTooltipOpened, setDeleteTooltipOpened] = useState(false);
const [statusTooltipOpened, setStatusTooltipOpened] = useState(false); const [statusTooltipOpened, setStatusTooltipOpened] = useState(false);
const isMedium = isMediumScreen()
const changeRegexStatus = () => { const changeRegexStatus = () => {
(filterInfo.active?nfproxy.pyfilterdisable:nfproxy.pyfilterenable)(filterInfo.filter_id).then(res => { (filterInfo.active?nfproxy.pyfilterdisable:nfproxy.pyfilterenable)(filterInfo.name).then(res => {
if(!res){ if(!res){
okNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivated":"activated"} successfully!`,`Filter with id '${filterInfo.filter_id}' has been ${filterInfo.active?"deactivated":"activated"}!`) okNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivated":"activated"} successfully!`,`Filter '${filterInfo.name}' has been ${filterInfo.active?"deactivated":"activated"}!`)
}else{ }else{
errorNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivation":"activation"} failed!`,`Error: ${res}`) errorNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivation":"activation"} failed!`,`Error: ${res}`)
} }
}).catch( err => errorNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivation":"activation"} failed!`,`Error: ${err}`)) }).catch( err => errorNotify(`Filter ${filterInfo.name} ${filterInfo.active?"deactivation":"activation"} failed!`,`Error: ${err}`))
} }
return <Box className="firegex__regexview__box"> return <Box my="sm" display="flex" style={{alignItems:"center"}}>
<Box>
<Box className='center-flex' style={{width: "100%"}}> <Text className="firegex__regexview__pyfilter_text" style={{ width: "100%", alignItems: "center"}} display="flex" >
<Box className="firegex__regexview__outer_regex_text"> <Badge size="sm" radius="lg" mr="sm" color={filterInfo.active?"lime":"red"} variant="filled" />
<Text className="firegex__regexview__regex_text">{filterInfo.name}</Text> {filterInfo.name}
</Box> <Box className='flex-spacer' />
<Space w="xs" /> <Space w="xs" />
{isMedium?<>
<Badge size="md" radius="md" color="yellow" variant="filled"><FaFilter style={{ marginBottom: -2, marginRight: 2}} /> {filterInfo.blocked_packets}</Badge>
<Space w="xs" />
<Badge size="md" radius="md" color="orange" variant="filled"><FaPencilAlt style={{ marginBottom: -1, marginRight: 2}} /> {filterInfo.edited_packets}</Badge>
<Space w="lg" />
</>:null}
<Tooltip label={filterInfo.active?"Deactivate":"Activate"} zIndex={0} color={filterInfo.active?"orange":"teal"} opened={statusTooltipOpened}> <Tooltip label={filterInfo.active?"Deactivate":"Activate"} zIndex={0} color={filterInfo.active?"orange":"teal"} opened={statusTooltipOpened}>
<ActionIcon color={filterInfo.active?"orange":"teal"} onClick={changeRegexStatus} size="xl" radius="md" variant="filled" <ActionIcon color={filterInfo.active?"orange":"teal"} onClick={changeRegexStatus} size="lg" radius="md" variant="filled"
onFocus={() => setStatusTooltipOpened(false)} onBlur={() => setStatusTooltipOpened(false)} onFocus={() => setStatusTooltipOpened(false)} onBlur={() => setStatusTooltipOpened(false)}
onMouseEnter={() => setStatusTooltipOpened(true)} onMouseLeave={() => setStatusTooltipOpened(false)} onMouseEnter={() => setStatusTooltipOpened(true)} onMouseLeave={() => setStatusTooltipOpened(false)}
>{filterInfo.active?<FaPause size="20px" />:<FaPlay size="20px" />}</ActionIcon> >{filterInfo.active?<FaPause size="20px" />:<FaPlay size="20px" />}</ActionIcon>
</Tooltip> </Tooltip>
</Box> </Text>
<Box display="flex" mt="sm" ml="xs">
<Badge size="md" color="yellow" variant="filled"><FaFilter style={{ marginBottom: -2}} /> {filterInfo.blocked_packets}</Badge>
<Space w="xs" />
<Badge size="md" color="orange" variant="filled"><FaPencilAlt size={18} /> {filterInfo.edited_packets}</Badge>
<Space w="xs" />
<Badge size="md" color={filterInfo.active?"lime":"red"} variant="filled">{filterInfo.active?"ACTIVE":"DISABLED"}</Badge>
</Box>
</Box>
</Box> </Box>
} }

View File

@@ -96,6 +96,20 @@ body {
opacity: 0.8; opacity: 0.8;
} }
.firegex__regexview__pyfilter_text{
padding: 6px;
padding-left: 15px;
padding-right: 15px;
background-color: var(--fourth_color);
border: 1px solid #444;
overflow-x: hidden;
border-radius: 8px;
}
.firegex__regexview__pyfilter_text:hover{
overflow-x: auto;
}
.firegex__porthijack__servicerow__row{ .firegex__porthijack__servicerow__row{
width: 95%; width: 95%;
padding: 15px 0px; padding: 15px 0px;

View File

@@ -51,7 +51,6 @@ export type RegexAddForm = {
} }
export type PyFilter = { export type PyFilter = {
filter_id:number,
name:string, name:string,
blocked_packets:number, blocked_packets:number,
edited_packets:number, edited_packets:number,

View File

@@ -72,9 +72,14 @@ export async function genericapi(method:string,path:string,data:any = undefined,
const errorDefault = res.statusText const errorDefault = res.statusText
return res.json().then( res => reject(getErrorMessageFromServerResponse(res, errorDefault)) ).catch( _err => reject(errorDefault)) return res.json().then( res => reject(getErrorMessageFromServerResponse(res, errorDefault)) ).catch( _err => reject(errorDefault))
} }
res.json().then( res => resolve(res) ).catch( err => reject(err)) res.text().then(t => {
}) try{
.catch(err => { resolve(JSON.parse(t))
}catch(e){
resolve(t)
}
}).catch( err => reject(err))
}).catch(err => {
reject(err) reject(err)
}) })
}); });

View File

@@ -162,22 +162,22 @@ export default function ServiceDetailsNFProxy() {
</Tooltip> </Tooltip>
</Box> </Box>
</Box> </Box>
<Divider my="xl" /> <Divider my="xl" />
{filterCode.data?<> {filterCode.data?<>
<Title order={3} style={{textAlign:"center"}} className="center-flex"><FaPython style={{ marginBottom: -3 }} size={30} /><Space w="xs" />Filter code</Title> <Title order={3} style={{textAlign:"center"}} className="center-flex"><FaPython style={{ marginBottom: -3 }} size={30} /><Space w="xs" />Filter code</Title>
<CodeHighlight code={filterCode.data} language="python" mt="lg" /> <CodeHighlight code={filterCode.data} language="python" mt="lg" />
</>: null} </>: null}
<Space h="xl" />
{(!filtersList.data || filtersList.data.length == 0)?<> {(!filtersList.data || filtersList.data.length == 0)?<>
<Space h="xl" />
<Title className='center-flex' style={{textAlign:"center"}} order={3}>No filters found! Edit the proxy file</Title> <Title className='center-flex' style={{textAlign:"center"}} order={3}>No filters found! Edit the proxy file</Title>
<Space h="xs" /> <Space h="xs" />
<Title className='center-flex' style={{textAlign:"center"}} order={3}>Install the firegex client:<Space w="xs" /><Code mb={-4} >pip install fgex</Code></Title> <Title className='center-flex' style={{textAlign:"center"}} order={3}>Install the firegex client:<Space w="xs" /><Code mb={-4} >pip install fgex</Code></Title>
<Space h="xs" /> <Space h="xs" />
<Title className='center-flex' style={{textAlign:"center"}} order={3}>Then run the command:<Space w="xs" /><Code mb={-4} >fgex nfproxy</Code></Title> <Title className='center-flex' style={{textAlign:"center"}} order={3}>Then run the command:<Space w="xs" /><Code mb={-4} >fgex nfproxy</Code></Title>
</>: </>:<>{filtersList.data?.map( (filterInfo) => <PyFilterView filterInfo={filterInfo} />)}</>
<Grid>
{filtersList.data?.map( (filterInfo) => <Grid.Col key={filterInfo.filter_id} span={{ lg:6, xs: 12 }}><PyFilterView filterInfo={filterInfo} /></Grid.Col>)}
</Grid>
} }
<YesNoModal <YesNoModal
title='Are you sure to delete this service?' title='Are you sure to delete this service?'

View File

@@ -1,4 +1,4 @@
import { ActionIcon, Badge, Box, LoadingOverlay, Space, ThemeIcon, Title, Tooltip } from '@mantine/core'; import { ActionIcon, Badge, Box, FileButton, LoadingOverlay, Space, ThemeIcon, Title, Tooltip } from '@mantine/core';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { BsPlusLg } from "react-icons/bs"; import { BsPlusLg } from "react-icons/bs";
import { useNavigate, useParams } from 'react-router-dom'; import { useNavigate, useParams } from 'react-router-dom';
@@ -7,9 +7,11 @@ import { errorNotify, getErrorMessage, isMediumScreen } from '../../js/utils';
import AddEditService from '../../components/NFProxy/AddEditService'; import AddEditService from '../../components/NFProxy/AddEditService';
import { useQueryClient } from '@tanstack/react-query'; import { useQueryClient } from '@tanstack/react-query';
import { TbPlugConnected, TbReload } from 'react-icons/tb'; import { TbPlugConnected, TbReload } from 'react-icons/tb';
import { nfproxyServiceQuery } from '../../components/NFProxy/utils'; import { nfproxy, nfproxyServiceQuery } from '../../components/NFProxy/utils';
import { FaFilter, FaPencilAlt, FaServer } from 'react-icons/fa'; import { FaFilter, FaPencilAlt, FaServer } from 'react-icons/fa';
import { VscRegex } from 'react-icons/vsc'; import { MdUploadFile } from "react-icons/md";
import { notifications } from '@mantine/notifications';
import { useFileDialog } from '@mantine/hooks';
export default function NFProxy({ children }: { children: any }) { export default function NFProxy({ children }: { children: any }) {
@@ -23,6 +25,69 @@ export default function NFProxy({ children }: { children: any }) {
const [tooltipAddOpened, setTooltipAddOpened] = useState(false); const [tooltipAddOpened, setTooltipAddOpened] = useState(false);
const isMedium = isMediumScreen() const isMedium = isMediumScreen()
const services = nfproxyServiceQuery() const services = nfproxyServiceQuery()
const fileDialog = useFileDialog({
accept: ".py",
multiple: false,
resetOnOpen: true,
onChange: (files) => {
if (files?.length??0 > 0)
setFile(files![0])
}
});
const [file, setFile] = useState<File | null>(null);
useEffect(() => {
if (!srv) return
const service = services.data?.find(s => s.service_id === srv)
if (!service) return
if (file){
console.log("Uploading code")
const notify_id = notifications.show(
{
title: "Uploading code",
message: `Uploading code for service ${service.name}`,
color: "blue",
icon: <MdUploadFile size={20} />,
autoClose: false,
loading: true,
}
)
file.text()
.then( code => nfproxy.setpyfilterscode(service?.service_id??"",code.toString()))
.then( res => {
if (!res){
notifications.update({
id: notify_id,
title: "Code uploaded",
message: `Successfully uploaded code for service ${service.name}`,
color: "green",
icon: <MdUploadFile size={20} />,
autoClose: 5000,
loading: false,
})
}else{
notifications.update({
id: notify_id,
title: "Code upload failed",
message: `Error: ${res}`,
color: "red",
icon: <MdUploadFile size={20} />,
autoClose: 5000,
loading: false,
})
}
}).catch( err => {
notifications.update({
id: notify_id,
title: "Code upload failed",
message: `Error: ${err}`,
color: "red",
icon: <MdUploadFile size={20} />,
autoClose: 5000,
loading: false,
})
}).finally(()=>{setFile(null)})
}
}, [file])
useEffect(()=> { useEffect(()=> {
if(services.isError) if(services.isError)
@@ -37,7 +102,7 @@ export default function NFProxy({ children }: { children: any }) {
<Title order={5} className="center-flex"><ThemeIcon radius="md" size="md" variant='filled' color='lime' ><TbPlugConnected size={20} /></ThemeIcon><Space w="xs" />Netfilter Proxy</Title> <Title order={5} className="center-flex"><ThemeIcon radius="md" size="md" variant='filled' color='lime' ><TbPlugConnected size={20} /></ThemeIcon><Space w="xs" />Netfilter Proxy</Title>
{isMedium?<Box className='flex-spacer' />:<Space h="sm" />} {isMedium?<Box className='flex-spacer' />:<Space h="sm" />}
<Box className='center-flex' > <Box className='center-flex' >
General stats: {isMedium?"General stats:":null}
<Space w="xs" /> <Space w="xs" />
<Badge size="md" radius="sm" color="green" variant="filled"><FaServer style={{ marginBottom: -1, marginRight: 4}} />Services: {services.isLoading?0:services.data?.length}</Badge> <Badge size="md" radius="sm" color="green" variant="filled"><FaServer style={{ marginBottom: -1, marginRight: 4}} />Services: {services.isLoading?0:services.data?.length}</Badge>
<Space w="xs" /> <Space w="xs" />
@@ -50,8 +115,16 @@ export default function NFProxy({ children }: { children: any }) {
</Box> </Box>
{isMedium?null:<Space h="md" />} {isMedium?null:<Space h="md" />}
<Box className='center-flex' > <Box className='center-flex' >
{/* Will become the null a button to edit the source code? TODO */} { srv?
{ srv?null <Tooltip label="Upload a new filter code" position='bottom' color="blue" opened={tooltipAddOpened}>
<ActionIcon
color="blue" size="lg" radius="md" variant="filled"
onFocus={() => setTooltipAddOpened(false)} onBlur={() => setTooltipAddOpened(false)}
onMouseEnter={() => setTooltipAddOpened(true)}
onMouseLeave={() => setTooltipAddOpened(false)} onClick={fileDialog.open}>
<MdUploadFile size={18} />
</ActionIcon>
</Tooltip>
: <Tooltip label="Add a new service" position='bottom' color="blue" opened={tooltipAddOpened}> : <Tooltip label="Add a new service" position='bottom' color="blue" opened={tooltipAddOpened}>
<ActionIcon color="blue" onClick={()=>setOpen(true)} size="lg" radius="md" variant="filled" <ActionIcon color="blue" onClick={()=>setOpen(true)} size="lg" radius="md" variant="filled"
onFocus={() => setTooltipAddOpened(false)} onBlur={() => setTooltipAddOpened(false)} onFocus={() => setTooltipAddOpened(false)} onBlur={() => setTooltipAddOpened(false)}
@@ -85,7 +158,9 @@ export default function NFProxy({ children }: { children: any }) {
</>} </>}
</Box> </Box>
{srv?children:null} {srv?children:null}
{!srv?<AddEditService opened={open} onClose={closeModal} />:null} {!srv?
<AddEditService opened={open} onClose={closeModal} />:null
}
</> </>
} }

View File

@@ -38,7 +38,7 @@ function NFRegex({ children }: { children: any }) {
<Title order={5} className="center-flex"><ThemeIcon radius="md" size="md" variant='filled' color='grape' ><BsRegex size={20} /></ThemeIcon><Space w="xs" />Netfilter Regex</Title> <Title order={5} className="center-flex"><ThemeIcon radius="md" size="md" variant='filled' color='grape' ><BsRegex size={20} /></ThemeIcon><Space w="xs" />Netfilter Regex</Title>
{isMedium?<Box className='flex-spacer' />:<Space h="sm" />} {isMedium?<Box className='flex-spacer' />:<Space h="sm" />}
<Box className='center-flex' > <Box className='center-flex' >
General stats: {isMedium?"General stats:":null}
<Space w="xs" /> <Space w="xs" />
<Badge size="md" radius="sm" color="green" variant="filled"><FaServer style={{ marginBottom: -1, marginRight: 4}} />Services: {services.isLoading?0:services.data?.length}</Badge> <Badge size="md" radius="sm" color="green" variant="filled"><FaServer style={{ marginBottom: -1, marginRight: 4}} />Services: {services.isLoading?0:services.data?.length}</Badge>
<Space w="xs" /> <Space w="xs" />