push: code changes

This commit is contained in:
Domingo Dirutigliano
2025-02-25 23:53:04 +01:00
parent 8652f40235
commit 6a11dd0d16
37 changed files with 1306 additions and 640 deletions

View File

@@ -9,12 +9,13 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from passlib.context import CryptContext
from utils.sqlite import SQLite
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager
from utils import API_VERSION, FIREGEX_PORT, JWT_ALGORITHM, get_interfaces, socketio_emit, DEBUG, SysctlManager, NORELOAD
from utils.loader import frontend_deploy, load_routers
from utils.models import ChangePasswordModel, IpInterface, PasswordChangeForm, PasswordForm, ResetRequest, StatusModel, StatusMessageModel
from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
import socketio
from socketio.exceptions import ConnectionRefusedError
# DB init
db = SQLite('db/firegex.db')
@@ -52,7 +53,6 @@ if DEBUG:
allow_headers=["*"],
)
utils.socketio = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins=[],
@@ -69,9 +69,6 @@ def set_psw(psw: str):
hash_psw = crypto.hash(psw)
db.put("password",hash_psw)
@utils.socketio.on("update")
async def updater(): pass
def create_access_token(data: dict):
to_encode = data.copy()
encoded_jwt = jwt.encode(to_encode, JWT_SECRET(), algorithm=JWT_ALGORITHM)
@@ -90,6 +87,28 @@ async def check_login(token: str = Depends(oauth2_scheme)):
return False
return logged_in
@utils.socketio.on("connect")
async def sio_connect(sid, environ, auth):
if not auth or not await check_login(auth.get("token")):
raise ConnectionRefusedError("Unauthorized")
utils.sid_list.add(sid)
@utils.socketio.on("disconnect")
async def sio_disconnect(sid):
try:
utils.sid_list.remove(sid)
except KeyError:
pass
async def disconnect_all():
while True:
if len(utils.sid_list) == 0:
break
await utils.socketio.disconnect(utils.sid_list.pop())
@utils.socketio.on("update")
async def updater(): pass
async def is_loggined(auth: bool = Depends(check_login)):
if not auth:
raise HTTPException(
@@ -122,6 +141,7 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()):
return {"access_token": create_access_token({"logged_in": True}), "token_type": "bearer"}
raise HTTPException(406,"Wrong password!")
@app.post('/api/set-password', response_model=ChangePasswordModel)
async def set_password(form: PasswordForm):
"""Set the password of firegex"""
@@ -143,6 +163,7 @@ async def change_password(form: PasswordChangeForm):
return {"status":"Cannot insert an empty password!"}
if form.expire:
db.put("secret", secrets.token_hex(32))
await disconnect_all()
set_psw(form.password)
await refresh_frontend()
@@ -200,7 +221,7 @@ if __name__ == '__main__':
"app:app",
host="::" if DEBUG else None,
port=FIREGEX_PORT,
reload=DEBUG,
reload=DEBUG and not NORELOAD,
access_log=True,
workers=1, # Firewall module can't be replicated in multiple workers
# Later the firewall module will be moved to a separate process

View File

@@ -17,6 +17,7 @@ enum class FilterAction{ DROP, ACCEPT, MANGLE, NOACTION };
enum class L4Proto { TCP, UDP, RAW };
typedef Tins::TCPIP::StreamIdentifier stream_id;
//TODO DUBBIO: I PACCHETTI INVIATI A PYTHON SONO GIA' FIXATI?
template<typename T>
class PktRequest {
@@ -25,6 +26,9 @@ class PktRequest {
mnl_socket* nl = nullptr;
uint16_t res_id;
uint32_t packet_id;
size_t _original_size;
size_t _data_original_size;
bool need_tcp_fixing = false;
public:
bool is_ipv6;
Tins::IP* ipv4 = nullptr;
@@ -39,17 +43,27 @@ class PktRequest {
size_t data_size;
stream_id sid;
int64_t* tcp_in_offset = nullptr;
int64_t* tcp_out_offset = nullptr;
T* ctx;
private:
inline void fetch_data_size(Tins::PDU* pdu){
static size_t inner_data_size(Tins::PDU* pdu){
if (pdu == nullptr){
return 0;
}
auto inner = pdu->inner_pdu();
if (inner == nullptr){
data_size = 0;
}else{
data_size = inner->size();
return 0;
}
return inner->size();
}
inline void fetch_data_size(Tins::PDU* pdu){
data_size = inner_data_size(pdu);
_data_original_size = data_size;
}
L4Proto fill_l4_info(){
@@ -86,23 +100,92 @@ class PktRequest {
}
}
bool need_tcp_fix(){
return (tcp_in_offset != nullptr && *tcp_in_offset != 0) || (tcp_out_offset != nullptr && *tcp_out_offset != 0);
}
Tins::PDU::serialization_type reserialize_raw_data(const uint8_t* data, const size_t& data_size){
if (is_ipv6){
Tins::IPv6 ipv6_new = Tins::IPv6(data, data_size);
if (tcp){
Tins::TCP* tcp_new = ipv6_new.find_pdu<Tins::TCP>();
}
return ipv6_new.serialize();
}else{
Tins::IP ipv4_new = Tins::IP(data, data_size);
if (tcp){
Tins::TCP* tcp_new = ipv4_new.find_pdu<Tins::TCP>();
}
return ipv4_new.serialize();
}
}
void _fix_ack_seq_tcp(Tins::TCP* this_tcp){
need_tcp_fixing = need_tcp_fix();
#ifdef DEBUG
if (need_tcp_fixing){
cerr << "[DEBUG] Fixing ack_seq with offsets " << *tcp_in_offset << " " << *tcp_out_offset << endl;
}
#endif
if(this_tcp == nullptr){
return;
}
if (is_input){
if (tcp_in_offset != nullptr){
this_tcp->seq(this_tcp->seq() + *tcp_in_offset);
}
if (tcp_out_offset != nullptr){
this_tcp->ack_seq(this_tcp->ack_seq() - *tcp_out_offset);
}
}else{
if (tcp_in_offset != nullptr){
this_tcp->ack_seq(this_tcp->ack_seq() - *tcp_in_offset);
}
if (tcp_out_offset != nullptr){
this_tcp->seq(this_tcp->seq() + *tcp_out_offset);
}
}
#ifdef DEBUG
if (need_tcp_fixing){
size_t new_size = inner_data_size(this_tcp);
cerr << "[DEBUG] FIXED PKT " << (is_input?"-> IN ":"<- OUT") << " [SEQ: " << this_tcp->seq() << "] \t[ACK: " << this_tcp->ack_seq() << "] \t[SIZE: " << new_size << "]" << endl;
}
#endif
}
public:
PktRequest(const char* payload, size_t plen, T* ctx, mnl_socket* nl, nfgenmsg *nfg, nfqnl_msg_packet_hdr *ph, bool is_input):
ctx(ctx), nl(nl), res_id(nfg->res_id),
packet_id(ph->packet_id), is_input(is_input),
packet(string(payload, plen)),
is_ipv6((payload[0] & 0xf0) == 0x60){
if (is_ipv6){
ipv6 = new Tins::IPv6((uint8_t*)packet.c_str(), plen);
sid = stream_id::make_identifier(*ipv6);
}else{
ipv4 = new Tins::IP((uint8_t*)packet.c_str(), plen);
sid = stream_id::make_identifier(*ipv4);
}
l4_proto = fill_l4_info();
data = packet.data()+(plen-data_size);
action(FilterAction::NOACTION),
is_ipv6((payload[0] & 0xf0) == 0x60)
{
if (is_ipv6){
ipv6 = new Tins::IPv6((uint8_t*)packet.c_str(), plen);
sid = stream_id::make_identifier(*ipv6);
_original_size = ipv6->size();
}else{
ipv4 = new Tins::IP((uint8_t*)packet.data(), plen);
sid = stream_id::make_identifier(*ipv4);
_original_size = ipv4->size();
}
l4_proto = fill_l4_info();
data = packet.data()+(plen-data_size);
#ifdef DEBUG
if (tcp){
cerr << "[DEBUG] NEW_PACKET " << (is_input?"-> IN ":"<- OUT") << " [SEQ: " << tcp->seq() << "] \t[ACK: " << tcp->ack_seq() << "] \t[SIZE: " << data_size << "]" << endl;
}
#endif
}
void fix_tcp_ack(){
if (tcp){
_fix_ack_seq_tcp(tcp);
}
}
void drop(){
if (action == FilterAction::NOACTION){
@@ -113,6 +196,14 @@ class PktRequest {
}
}
size_t data_original_size(){
return _data_original_size;
}
size_t original_size(){
return _original_size;
}
void accept(){
if (action == FilterAction::NOACTION){
action = FilterAction::ACCEPT;
@@ -131,7 +222,26 @@ class PktRequest {
}
}
void mangle_custom_pkt(const uint8_t* pkt, size_t pkt_size){
void reject(){
if (tcp){
//If the packet has data, we have to remove it
delete tcp->release_inner_pdu();
//For the first matched data or only for data packets, we set FIN bit
//This only for client packets, because this will trigger server to close the connection
//Packets will be filtered anyway also if client don't send packets
if (_data_original_size != 0 && is_input){
tcp->set_flag(Tins::TCP::FIN,1);
tcp->set_flag(Tins::TCP::ACK,1);
tcp->set_flag(Tins::TCP::SYN,0);
}
//Send the edited packet to the kernel
mangle();
}else{
drop();
}
}
void mangle_custom_pkt(uint8_t* pkt, const size_t& pkt_size){
if (action == FilterAction::NOACTION){
action = FilterAction::MANGLE;
perfrom_action(pkt, pkt_size);
@@ -149,26 +259,58 @@ class PktRequest {
delete ipv6;
}
inline Tins::PDU::serialization_type serialize(){
if (is_ipv6){
return ipv6->serialize();
}else{
return ipv4->serialize();
}
}
private:
void perfrom_action(const uint8_t* custom_data = nullptr, size_t custom_data_size = 0){
void perfrom_action(uint8_t* custom_data = nullptr, size_t custom_data_size = 0){
char buf[MNL_SOCKET_BUFFER_SIZE];
struct nlmsghdr *nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(res_id));
switch (action)
{
case FilterAction::ACCEPT:
if (need_tcp_fixing){
Tins::PDU::serialization_type data = serialize();
nfq_nlmsg_verdict_put_pkt(nlh_verdict, data.data(), data.size());
}
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT );
break;
case FilterAction::DROP:
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP );
break;
case FilterAction::MANGLE:{
if (custom_data != nullptr){
nfq_nlmsg_verdict_put_pkt(nlh_verdict, custom_data, custom_data_size);
}else if (is_ipv6){
nfq_nlmsg_verdict_put_pkt(nlh_verdict, ipv6->serialize().data(), ipv6->size());
//If not custom data, use the data in the packets
Tins::PDU::serialization_type data;
if (custom_data == nullptr){
data = serialize();
}else{
nfq_nlmsg_verdict_put_pkt(nlh_verdict, ipv4->serialize().data(), ipv4->size());
try{
data = reserialize_raw_data(custom_data, custom_data_size);
}catch(...){
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_DROP );
action = FilterAction::DROP;
break;
}
}
#ifdef DEBUG
size_t new_size = _data_original_size+((int64_t)custom_data_size) - ((int64_t)_original_size);
cerr << "[DEBUG] MANGLEDPKT " << (is_input?"-> IN ":"<- OUT") << " [SIZE: " << new_size << "]" << endl;
#endif
if (tcp && custom_data_size != _original_size){
int64_t delta = ((int64_t)custom_data_size) - ((int64_t)_original_size);
if (is_input && tcp_in_offset != nullptr){
*tcp_in_offset += delta;
}else if (!is_input && tcp_out_offset != nullptr){
*tcp_out_offset += delta;
}
}
nfq_nlmsg_verdict_put_pkt(nlh_verdict, data.data(), data.size());
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(packet_id), NF_ACCEPT );
break;
}

View File

@@ -4,11 +4,11 @@
#include "pyproxy/settings.cpp"
#include "pyproxy/pyproxy.cpp"
#include "classes/netfilter.cpp"
#include <syncstream>
#include <iostream>
#include <stdexcept>
#include <cstdlib>
#include <endian.h>
#include "utils.cpp"
using namespace std;
using namespace Firegex::PyProxy;
@@ -33,13 +33,13 @@ def invalid_curl_agent(http):
The code is now edited adding an intestation and a end statement:
```python
global __firegex_pyfilter_enabled, __firegex_proto
<user_code>
__firegex_pyfilter_enabled = ["invalid_curl_agent", "func3"] # This list is dynamically generated by firegex backend
__firegex_proto = "http"
import firegex.nfproxy.internals
<user_code>
firegex.nfproxy.internals.compile() # This function can save other global variables, to use by the packet handler and is used generally to check and optimize the code
firegex.nfproxy.internals.compile(globals(), locals()) # This function can save other global variables, to use by the packet handler and is used generally to check and optimize the code
````
(First lines are the same to keep line of code consistent on exceptions messages)
This code will be executed only once, and is needed to build the global and local context to use
The globals and locals generated here are copied for each connection, and are used to handle the packets
@@ -82,60 +82,53 @@ firegex lib will give you all the needed possibilities to do this is many ways
Final note: is not raccomanded to use variables that starts with __firegex_ in your code, because they may break the nfproxy
*/
ssize_t read_check(int __fd, void *__buf, size_t __nbytes){
ssize_t bytes = read(__fd, __buf, __nbytes);
if (bytes == 0){
cerr << "[fatal] [updater] read() returned EOF" << endl;
throw invalid_argument("read() returned EOF");
}
if (bytes < 0){
cerr << "[fatal] [updater] read() returned an error" << bytes << endl;
throw invalid_argument("read() returned an error");
}
return bytes;
}
void config_updater (){
while (true){
PyThreadState* state = PyEval_SaveThread(); // Release GIL while doing IO operation
uint32_t code_size;
read_check(STDIN_FILENO, &code_size, 4);
//Python will send number always in little endian
code_size = le32toh(code_size);
string code;
code.resize(code_size);
read_check(STDIN_FILENO, code.data(), code_size);
memcpy(&code_size, control_socket.recv(4).c_str(), 4);
code_size = be32toh(code_size);
string code = control_socket.recv(code_size);
#ifdef DEBUG
cerr << "[DEBUG] [updater] Received code: " << code << endl;
#endif
cerr << "[info] [updater] Updating configuration" << endl;
PyEval_AcquireThread(state); //Restore GIL before executing python code
try{
config.reset(new PyCodeConfig(code));
cerr << "[info] [updater] Config update done" << endl;
osyncstream(cout) << "ACK OK" << endl;
control_socket << "ACK OK" << endl;
}catch(const std::exception& e){
cerr << "[error] [updater] Failed to build new configuration!" << endl;
osyncstream(cout) << "ACK FAIL " << e.what() << endl;
control_socket << "ACK FAIL " << e.what() << endl;
}
}
}
int main(int argc, char *argv[]){
int main(int argc, char *argv[]) {
// Connect to the python backend using the unix socket
init_control_socket();
// Initialize the python interpreter
Py_Initialize();
atexit(Py_Finalize);
init_handle_packet_code(); //Compile the static code used to handle packets
if (freopen(nullptr, "rb", stdin) == nullptr){ // We need to read from stdin binary data
cerr << "[fatal] [main] Failed to reopen stdin in binary mode" << endl;
return 1;
}
int n_of_threads = 1;
char * n_threads_str = getenv("NTHREADS");
if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str);
if(n_of_threads <= 0) n_of_threads = 1;
config.reset(new PyCodeConfig());
MultiThreadQueue<PyProxyQueue> queue(n_of_threads);
osyncstream(cout) << "QUEUE " << queue.queue_num() << endl;
control_socket << "QUEUE " << queue.queue_num() << endl;
cerr << "[info] [main] Queue: " << queue.queue_num() << " threads assigned: " << n_of_threads << endl;
thread qthr([&](){

View File

@@ -33,7 +33,8 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
public:
stream_ctx sctx;
StreamFollower follower;
PyGILState_STATE gstate;
PyThreadState * gtstate = nullptr;
PyInterpreterConfig py_thread_config = {
.use_main_obmalloc = 0,
.allow_fork = 0,
@@ -44,24 +45,23 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
.gil = PyInterpreterConfig_OWN_GIL,
};
PyThreadState *tstate = NULL;
PyStatus pystatus;
struct {
bool matching_has_been_called = false;
bool already_closed = false;
bool rejected = true;
NfQueue::PktRequest<PyProxyQueue>* pkt;
} match_ctx;
NfQueue::PktRequest<PyProxyQueue>* pkt;
tcp_ack_seq_ctx* current_tcp_ack = nullptr;
void before_loop() override {
// Create thred structure for python
gstate = PyGILState_Ensure();
PyStatus pystatus;
// Create a new interpreter for the thread
gtstate = PyThreadState_New(PyInterpreterState_Main());
PyEval_AcquireThread(gtstate);
pystatus = Py_NewInterpreterFromConfig(&tstate, &py_thread_config);
if (PyStatus_Exception(pystatus)) {
Py_ExitStatusException(pystatus);
if(tstate == nullptr){
cerr << "[fatal] [main] Failed to create new interpreter" << endl;
exit(EXIT_FAILURE);
throw invalid_argument("Failed to create new interpreter (null tstate)");
}
if (PyStatus_Exception(pystatus)) {
cerr << "[fatal] [main] Failed to create new interpreter" << endl;
Py_ExitStatusException(pystatus);
throw invalid_argument("Failed to create new interpreter (pystatus exc)");
}
// Setting callbacks for the stream follower
follower.new_stream_callback(bind(on_new_stream, placeholders::_1, this));
@@ -69,21 +69,24 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
}
inline void print_blocked_reason(const string& func_name){
osyncstream(cout) << "BLOCKED " << func_name << endl;
control_socket << "BLOCKED " << func_name << endl;
}
inline void print_mangle_reason(const string& func_name){
osyncstream(cout) << "MANGLED " << func_name << endl;
control_socket << "MANGLED " << func_name << endl;
}
inline void print_exception_reason(){
osyncstream(cout) << "EXCEPTION" << endl;
control_socket << "EXCEPTION" << endl;
}
//If the stream has already been matched, drop all data, and try to close the connection
static void keep_fin_packet(PyProxyQueue* proxy_info){
proxy_info->match_ctx.matching_has_been_called = true;
proxy_info->match_ctx.already_closed = true;
static void keep_fin_packet(PyProxyQueue* pyq){
pyq->pkt->reject();// This is needed because the callback has to take the updated pkt pointer!
}
static void keep_dropped(PyProxyQueue* pyq){
pyq->pkt->drop();// This is needed because the callback has to take the updated pkt pointer!
}
void filter_action(NfQueue::PktRequest<PyProxyQueue>* pkt, Stream& stream){
@@ -92,36 +95,45 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
if (stream_search == sctx.streams_ctx.end()){
shared_ptr<PyCodeConfig> conf = config;
//If config is not set, ignore the stream
if (conf->glob == nullptr || conf->local == nullptr){
PyObject* compiled_code = conf->compiled_code();
if (compiled_code == nullptr){
stream.client_data_callback(nullptr);
stream.server_data_callback(nullptr);
return pkt->accept();
}
stream_match = new pyfilter_ctx(conf->glob, conf->local);
stream_match = new pyfilter_ctx(compiled_code);
Py_DECREF(compiled_code);
sctx.streams_ctx.insert_or_assign(pkt->sid, stream_match);
}else{
stream_match = stream_search->second;
}
}
auto result = stream_match->handle_packet(pkt);
switch(result.action){
case PyFilterResponse::ACCEPT:
pkt->accept();
return pkt->accept();
case PyFilterResponse::DROP:
print_blocked_reason(*result.filter_match_by);
sctx.clean_stream_by_id(pkt->sid);
stream.client_data_callback(nullptr);
stream.server_data_callback(nullptr);
break;
stream.client_data_callback(bind(keep_dropped, this));
stream.server_data_callback(bind(keep_dropped, this));
return pkt->drop();
case PyFilterResponse::REJECT:
print_blocked_reason(*result.filter_match_by);
sctx.clean_stream_by_id(pkt->sid);
stream.client_data_callback(bind(keep_fin_packet, this));
stream.server_data_callback(bind(keep_fin_packet, this));
pkt->ctx->match_ctx.rejected = true; //Handler will take care of the rest
break;
return pkt->reject();
case PyFilterResponse::MANGLE:
print_mangle_reason(*result.filter_match_by);
pkt->mangle_custom_pkt((uint8_t*)result.mangled_packet->c_str(), result.mangled_packet->size());
break;
pkt->mangle_custom_pkt((uint8_t*)result.mangled_packet->data(), result.mangled_packet->size());
if (pkt->get_action() == NfQueue::FilterAction::DROP){
cerr << "[error] [filter_action] Failed to mangle: the packet sent is not serializzable... the packet was dropped" << endl;
print_blocked_reason(*result.filter_match_by);
print_exception_reason();
}else{
print_mangle_reason(*result.filter_match_by);
}
return;
case PyFilterResponse::EXCEPTION:
case PyFilterResponse::INVALID:
print_exception_reason();
@@ -129,16 +141,15 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
//Free the packet data
stream.client_data_callback(nullptr);
stream.server_data_callback(nullptr);
pkt->accept();
break;
return pkt->accept();
}
}
static void on_data_recv(Stream& stream, PyProxyQueue* proxy_info, string data) {
proxy_info->match_ctx.matching_has_been_called = true;
proxy_info->match_ctx.already_closed = false;
proxy_info->filter_action(proxy_info->match_ctx.pkt, stream);
proxy_info->pkt->data = data.data();
proxy_info->pkt->data_size = data.size();
proxy_info->filter_action(proxy_info->pkt, stream);
}
//Input data filtering
@@ -152,77 +163,77 @@ class PyProxyQueue: public NfQueue::ThreadNfQueue<PyProxyQueue> {
}
// A stream was terminated. The second argument is the reason why it was terminated
static void on_stream_close(Stream& stream, PyProxyQueue* proxy_info) {
static void on_stream_close(Stream& stream, PyProxyQueue* pyq) {
stream_id stream_id = stream_id::make_identifier(stream);
proxy_info->sctx.clean_stream_by_id(stream_id);
pyq->sctx.clean_stream_by_id(stream_id);
pyq->sctx.clean_tcp_ack_by_id(stream_id);
}
static void on_new_stream(Stream& stream, PyProxyQueue* proxy_info) {
static void on_new_stream(Stream& stream, PyProxyQueue* pyq) {
stream.auto_cleanup_payloads(true);
if (stream.is_partial_stream()) {
stream.enable_recovery_mode(10 * 1024);
}
stream.client_data_callback(bind(on_client_data, placeholders::_1, proxy_info));
stream.server_data_callback(bind(on_server_data, placeholders::_1, proxy_info));
stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, proxy_info));
if (pyq->current_tcp_ack != nullptr){
pyq->current_tcp_ack->reset();
}else{
pyq->current_tcp_ack = new tcp_ack_seq_ctx();
pyq->sctx.tcp_ack_ctx.insert_or_assign(pyq->pkt->sid, pyq->current_tcp_ack);
pyq->pkt->tcp_in_offset = &pyq->current_tcp_ack->in_tcp_offset;
pyq->pkt->tcp_out_offset = &pyq->current_tcp_ack->out_tcp_offset;
}
//Should not happen, but with this we can be sure about this
auto tcp_ack_search = pyq->sctx.tcp_ack_ctx.find(pyq->pkt->sid);
if (tcp_ack_search != pyq->sctx.tcp_ack_ctx.end()){
tcp_ack_search->second->reset();
}
stream.client_data_callback(bind(on_client_data, placeholders::_1, pyq));
stream.server_data_callback(bind(on_server_data, placeholders::_1, pyq));
stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, pyq));
}
void handle_next_packet(NfQueue::PktRequest<PyProxyQueue>* _pkt) override{
pkt = _pkt; // Setting packet context
void handle_next_packet(NfQueue::PktRequest<PyProxyQueue>* pkt) override{
if (pkt->l4_proto != NfQueue::L4Proto::TCP){
throw invalid_argument("Only TCP and UDP are supported");
}
Tins::PDU* application_layer = pkt->tcp->inner_pdu();
u_int16_t payload_size = 0;
if (application_layer != nullptr){
payload_size = application_layer->size();
auto tcp_ack_search = sctx.tcp_ack_ctx.find(pkt->sid);
if (tcp_ack_search != sctx.tcp_ack_ctx.end()){
current_tcp_ack = tcp_ack_search->second;
pkt->tcp_in_offset = &current_tcp_ack->in_tcp_offset;
pkt->tcp_out_offset = &current_tcp_ack->out_tcp_offset;
}else{
current_tcp_ack = nullptr;
//If necessary will be created by libtis new_stream callback
}
match_ctx.matching_has_been_called = false;
match_ctx.pkt = pkt;
if (pkt->is_ipv6){
pkt->fix_tcp_ack();
follower.process_packet(*pkt->ipv6);
}else{
pkt->fix_tcp_ack();
follower.process_packet(*pkt->ipv4);
}
// Do an action only is an ordered packet has been received
if (match_ctx.matching_has_been_called){
bool empty_payload = payload_size == 0;
//In this 2 cases we have to remove all data about the stream
if (!match_ctx.rejected || match_ctx.already_closed){
sctx.clean_stream_by_id(pkt->sid);
//If the packet has data, we have to remove it
if (!empty_payload){
Tins::PDU* data_layer = pkt->tcp->release_inner_pdu();
if (data_layer != nullptr){
delete data_layer;
}
}
//For the first matched data or only for data packets, we set FIN bit
//This only for client packets, because this will trigger server to close the connection
//Packets will be filtered anyway also if client don't send packets
if ((!match_ctx.rejected || !empty_payload) && pkt->is_input){
pkt->tcp->set_flag(Tins::TCP::FIN,1);
pkt->tcp->set_flag(Tins::TCP::ACK,1);
pkt->tcp->set_flag(Tins::TCP::SYN,0);
}
//Send the edited packet to the kernel
return pkt->mangle();
}else{
//Fallback to the default action
if (pkt->get_action() == NfQueue::FilterAction::NOACTION){
return pkt->accept();
}
}
}else{
//Fallback to the default action
if (pkt->get_action() == NfQueue::FilterAction::NOACTION){
return pkt->accept();
}
}
~PyProxyQueue() {
// Closing first the interpreter
Py_EndInterpreter(tstate);
// Releasing the GIL and the thread data structure
PyGILState_Release(gstate);
PyEval_ReleaseThread(tstate);
PyThreadState_Clear(tstate);
PyThreadState_Delete(tstate);
sctx.clean();
}

View File

@@ -2,58 +2,73 @@
#define PROXY_TUNNEL_SETTINGS_CPP
#include <Python.h>
#include <marshal.h>
#include <vector>
#include <memory>
#include <iostream>
#include "../utils.cpp"
using namespace std;
namespace Firegex {
namespace PyProxy {
class PyCodeConfig;
shared_ptr<PyCodeConfig> config;
PyObject* py_handle_packet_code = nullptr;
UnixClientConnection control_socket;
class PyCodeConfig{
public:
PyObject* glob = nullptr;
PyObject* local = nullptr;
private:
void _clean(){
Py_XDECREF(glob);
Py_XDECREF(local);
}
public:
string encoded_code;
PyCodeConfig(const string& pycode){
PyObject* compiled_code = Py_CompileStringExFlags(pycode.c_str(), "<pyfilter>", Py_file_input, NULL, 2);
if (compiled_code == nullptr){
std::cerr << "[fatal] [main] Failed to compile the code" << endl;
_clean();
throw invalid_argument("Failed to compile the code");
}
glob = PyDict_New();
local = PyDict_New();
PyObject* result = PyEval_EvalCode(compiled_code, glob, local);
Py_XDECREF(compiled_code);
PyObject* glob = PyDict_New();
PyObject* result = PyEval_EvalCode(compiled_code, glob, glob);
Py_DECREF(glob);
if (!result){
PyErr_Print();
_clean();
Py_DECREF(compiled_code);
std::cerr << "[fatal] [main] Failed to execute the code" << endl;
throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided");
}
Py_DECREF(result);
PyObject* code_dump = PyMarshal_WriteObjectToString(compiled_code, 4);
Py_DECREF(compiled_code);
if (code_dump == nullptr){
PyErr_Print();
std::cerr << "[fatal] [main] Failed to dump the code" << endl;
throw invalid_argument("Failed to dump the code");
}
if (!PyBytes_Check(code_dump)){
std::cerr << "[fatal] [main] Failed to dump the code" << endl;
throw invalid_argument("Failed to dump the code");
}
encoded_code = string(PyBytes_AsString(code_dump), PyBytes_Size(code_dump));
Py_DECREF(code_dump);
}
PyCodeConfig(){}
~PyCodeConfig(){
_clean();
PyObject* compiled_code(){
if (encoded_code.empty()) return nullptr;
return PyMarshal_ReadObjectFromString(encoded_code.c_str(), encoded_code.size());
}
PyCodeConfig(){}
};
shared_ptr<PyCodeConfig> config;
PyObject* py_handle_packet_code = nullptr;
void init_control_socket(){
char * socket_path = getenv("FIREGEX_NFPROXY_SOCK");
if (socket_path == nullptr) throw invalid_argument("FIREGEX_NFPROXY_SOCK not set");
if (strlen(socket_path) >= 108) throw invalid_argument("FIREGEX_NFPROXY_SOCK too long");
control_socket = UnixClientConnection(socket_path);
}
void init_handle_packet_code(){
py_handle_packet_code = Py_CompileStringExFlags(

View File

@@ -27,10 +27,21 @@ enum PyFilterResponse {
INVALID = 5
};
const PyFilterResponse VALID_PYTHON_RESPONSE[4] = {
PyFilterResponse::ACCEPT,
PyFilterResponse::DROP,
PyFilterResponse::REJECT,
PyFilterResponse::MANGLE
};
struct py_filter_response {
PyFilterResponse action;
string* filter_match_by = nullptr;
string* mangled_packet = nullptr;
py_filter_response(PyFilterResponse action, string* filter_match_by = nullptr, string* mangled_packet = nullptr):
action(action), filter_match_by(filter_match_by), mangled_packet(mangled_packet){}
~py_filter_response(){
delete mangled_packet;
delete filter_match_by;
@@ -39,34 +50,35 @@ struct py_filter_response {
typedef Tins::TCPIP::StreamIdentifier stream_id;
struct tcp_ack_seq_ctx{
//Can be negative, so we use int64_t (for a uint64_t value)
int64_t in_tcp_offset = 0;
int64_t out_tcp_offset = 0;
tcp_ack_seq_ctx(){}
void reset(){
in_tcp_offset = 0;
out_tcp_offset = 0;
}
};
struct pyfilter_ctx {
PyObject * glob = nullptr;
PyObject * local = nullptr;
pyfilter_ctx(PyObject * original_glob, PyObject * original_local){
PyObject *copy = PyImport_ImportModule("copy");
if (copy == nullptr){
pyfilter_ctx(PyObject * compiled_code){
glob = PyDict_New();
PyObject* result = PyEval_EvalCode(compiled_code, glob, glob);
if (!result){
PyErr_Print();
throw invalid_argument("Failed to import copy module");
Py_XDECREF(glob);
std::cerr << "[fatal] [main] Failed to compile the code" << endl;
throw invalid_argument("Failed to execute the code, maybe an invalid filter code has been provided");
}
PyObject *deepcopy = PyObject_GetAttrString(copy, "deepcopy");
glob = PyObject_CallFunctionObjArgs(deepcopy, original_glob, NULL);
if (glob == nullptr){
PyErr_Print();
throw invalid_argument("Failed to deepcopy the global dict");
}
local = PyObject_CallFunctionObjArgs(deepcopy, original_local, NULL);
if (local == nullptr){
PyErr_Print();
throw invalid_argument("Failed to deepcopy the local dict");
}
Py_DECREF(copy);
Py_XDECREF(result);
}
~pyfilter_ctx(){
Py_XDECREF(glob);
Py_XDECREF(local);
Py_DECREF(glob);
}
inline void set_item_to_glob(const char* key, PyObject* value){
@@ -84,15 +96,12 @@ struct pyfilter_ctx {
}
}
inline void set_item_to_local(const char* key, PyObject* value){
set_item_to_dict(local, key, value);
}
inline void set_item_to_dict(PyObject* dict, const char* key, PyObject* value){
if (PyDict_SetItemString(dict, key, value) != 0){
PyErr_Print();
throw invalid_argument("Failed to set item to dict");
}
Py_DECREF(value);
}
py_filter_response handle_packet(
@@ -101,6 +110,7 @@ struct pyfilter_ctx {
PyObject * packet_info = PyDict_New();
set_item_to_dict(packet_info, "data", PyBytes_FromStringAndSize(pkt->data, pkt->data_size));
set_item_to_dict(packet_info, "l4_size", PyLong_FromLong(pkt->data_original_size()));
set_item_to_dict(packet_info, "raw_packet", PyBytes_FromStringAndSize(pkt->packet.c_str(), pkt->packet.size()));
set_item_to_dict(packet_info, "is_input", PyBool_FromLong(pkt->is_input));
set_item_to_dict(packet_info, "is_ipv6", PyBool_FromLong(pkt->is_ipv6));
@@ -108,92 +118,156 @@ struct pyfilter_ctx {
// Set packet info to the global context
set_item_to_glob("__firegex_packet_info", packet_info);
PyObject * result = PyEval_EvalCode(py_handle_packet_code, glob, local);
PyObject * result = PyEval_EvalCode(py_handle_packet_code, glob, glob);
del_item_from_glob("__firegex_packet_info");
Py_DECREF(packet_info);
Py_DECREF(packet_info);
if (!result){
PyErr_Print();
return py_filter_response{PyFilterResponse::EXCEPTION, nullptr};
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Exception raised" << endl;
#endif
return py_filter_response(PyFilterResponse::EXCEPTION);
}
Py_DECREF(result);
result = get_item_from_glob("__firegex_pyfilter_result");
if (result == nullptr){
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result found" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
}
if (!PyDict_Check(result)){
PyErr_Print();
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Result is not a dict" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
return py_filter_response(PyFilterResponse::INVALID);
}
PyObject* action = PyDict_GetItemString(result, "action");
if (action == nullptr){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result action found" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
return py_filter_response(PyFilterResponse::INVALID);
}
if (!PyLong_Check(action)){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Action is not a long" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
return py_filter_response(PyFilterResponse::INVALID);
}
PyFilterResponse action_enum = (PyFilterResponse)PyLong_AsLong(action);
if (action_enum == PyFilterResponse::ACCEPT || action_enum == PyFilterResponse::EXCEPTION || action_enum == PyFilterResponse::INVALID){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{action_enum, nullptr, nullptr};
}else{
PyObject *func_name_py = PyDict_GetItemString(result, "matched_by");
if (func_name_py == nullptr){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
}
if (!PyUnicode_Check(func_name_py)){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
}
string* func_name = new string(PyUnicode_AsUTF8(func_name_py));
if (action_enum == PyFilterResponse::DROP || action_enum == PyFilterResponse::REJECT){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{action_enum, func_name, nullptr};
}
if (action_enum != PyFilterResponse::MANGLE){
PyObject* mangled_packet = PyDict_GetItemString(result, "mangled_packet");
if (mangled_packet == nullptr){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
}
if (!PyBytes_Check(mangled_packet)){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
}
string* pkt_str = new string(PyBytes_AsString(mangled_packet), PyBytes_Size(mangled_packet));
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::MANGLE, func_name, pkt_str};
//Check action_enum
bool valid = false;
for (auto valid_action: VALID_PYTHON_RESPONSE){
if (action_enum == valid_action){
valid = true;
break;
}
}
if (!valid){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] Invalid action" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response(PyFilterResponse::INVALID);
}
if (action_enum == PyFilterResponse::ACCEPT){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response(action_enum);
}
PyObject *func_name_py = PyDict_GetItemString(result, "matched_by");
if (func_name_py == nullptr){
del_item_from_glob("__firegex_pyfilter_result");
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result matched_by found" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
}
if (!PyUnicode_Check(func_name_py)){
del_item_from_glob("__firegex_pyfilter_result");
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] matched_by is not a string" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
}
string* func_name = new string(PyUnicode_AsUTF8(func_name_py));
if (action_enum == PyFilterResponse::DROP || action_enum == PyFilterResponse::REJECT){
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response(action_enum, func_name);
}
if (action_enum == PyFilterResponse::MANGLE){
PyObject* mangled_packet = PyDict_GetItemString(result, "mangled_packet");
if (mangled_packet == nullptr){
del_item_from_glob("__firegex_pyfilter_result");
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] No result mangled_packet found" << endl;
#endif
return py_filter_response(PyFilterResponse::INVALID);
}
if (!PyBytes_Check(mangled_packet)){
#ifdef DEBUG
cerr << "[DEBUG] [handle_packet] mangled_packet is not a bytes" << endl;
#endif
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response(PyFilterResponse::INVALID);
}
string* pkt_str = new string(PyBytes_AsString(mangled_packet), PyBytes_Size(mangled_packet));
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response(PyFilterResponse::MANGLE, func_name, pkt_str);
}
//Should never reach this point, but just in case of new action not managed...
del_item_from_glob("__firegex_pyfilter_result");
return py_filter_response{PyFilterResponse::INVALID, nullptr, nullptr};
return py_filter_response(PyFilterResponse::INVALID);
}
};
typedef map<stream_id, pyfilter_ctx*> matching_map;
typedef map<stream_id, tcp_ack_seq_ctx*> tcp_ack_map;
struct stream_ctx {
matching_map streams_ctx;
tcp_ack_map tcp_ack_ctx;
void clean_stream_by_id(stream_id sid){
auto stream_search = streams_ctx.find(sid);
if (stream_search != streams_ctx.end()){
auto stream_match = stream_search->second;
delete stream_match;
streams_ctx.erase(stream_search->first);
}
}
void clean_tcp_ack_by_id(stream_id sid){
auto tcp_ack_search = tcp_ack_ctx.find(sid);
if (tcp_ack_search != tcp_ack_ctx.end()){
auto tcp_ack = tcp_ack_search->second;
delete tcp_ack;
tcp_ack_ctx.erase(tcp_ack_search->first);
}
}
void clean(){
for (auto ele: streams_ctx){
delete ele.second;
}
for (auto ele: tcp_ack_ctx){
delete ele.second;
}
tcp_ack_ctx.clear();
streams_ctx.clear();
}
};

View File

@@ -37,13 +37,7 @@ public:
stream_ctx sctx;
u_int16_t latest_config_ver = 0;
StreamFollower follower;
struct {
bool matching_has_been_called = false;
bool already_closed = false;
bool result;
NfQueue::PktRequest<RegexNfQueue>* pkt;
} match_ctx;
NfQueue::PktRequest<RegexNfQueue>* pkt;
bool filter_action(NfQueue::PktRequest<RegexNfQueue>* pkt){
shared_ptr<RegexRules> conf = regex_config;
@@ -119,49 +113,23 @@ public:
return true;
}
void handle_next_packet(NfQueue::PktRequest<RegexNfQueue>* pkt) override{
bool empty_payload = pkt->data_size == 0;
void handle_next_packet(NfQueue::PktRequest<RegexNfQueue>* _pkt) override{
pkt = _pkt; // Setting packet context
if (pkt->tcp){
match_ctx.matching_has_been_called = false;
match_ctx.pkt = pkt;
if (pkt->ipv4){
follower.process_packet(*pkt->ipv4);
}else{
follower.process_packet(*pkt->ipv6);
}
// Do an action only is an ordered packet has been received
if (match_ctx.matching_has_been_called){
//In this 2 cases we have to remove all data about the stream
if (!match_ctx.result || match_ctx.already_closed){
sctx.clean_stream_by_id(pkt->sid);
//If the packet has data, we have to remove it
if (!empty_payload){
Tins::PDU* data_layer = pkt->tcp->release_inner_pdu();
if (data_layer != nullptr){
delete data_layer;
}
}
//For the first matched data or only for data packets, we set FIN bit
//This only for client packets, because this will trigger server to close the connection
//Packets will be filtered anyway also if client don't send packets
if ((!match_ctx.result || !empty_payload) && pkt->is_input){
pkt->tcp->set_flag(Tins::TCP::FIN,1);
pkt->tcp->set_flag(Tins::TCP::ACK,1);
pkt->tcp->set_flag(Tins::TCP::SYN,0);
}
//Send the edited packet to the kernel
return pkt->mangle();
}
//Fallback to the default action
if (pkt->get_action() == NfQueue::FilterAction::NOACTION){
return pkt->accept();
}
return pkt->accept();
}else{
if (!pkt->udp){
throw invalid_argument("Only TCP and UDP are supported");
}
if(empty_payload){
if(pkt->data_size == 0){
return pkt->accept();
}else if (filter_action(pkt)){
return pkt->accept();
@@ -170,22 +138,21 @@ public:
}
}
}
//If the stream has already been matched, drop all data, and try to close the connection
static void keep_fin_packet(RegexNfQueue* nfq){
nfq->match_ctx.matching_has_been_called = true;
nfq->match_ctx.already_closed = true;
nfq->pkt->reject();// This is needed because the callback has to take the updated pkt pointer!
}
static void on_data_recv(Stream& stream, RegexNfQueue* nfq, string data) {
nfq->match_ctx.matching_has_been_called = true;
nfq->match_ctx.already_closed = false;
bool result = nfq->filter_action(nfq->match_ctx.pkt);
if (!result){
nfq->sctx.clean_stream_by_id(nfq->match_ctx.pkt->sid);
nfq->pkt->data = data.data();
nfq->pkt->data_size = data.size();
if (!nfq->filter_action(nfq->pkt)){
nfq->sctx.clean_stream_by_id(nfq->pkt->sid);
stream.client_data_callback(bind(keep_fin_packet, nfq));
stream.server_data_callback(bind(keep_fin_packet, nfq));
nfq->pkt->reject();
}
nfq->match_ctx.result = result;
}
//Input data filtering

View File

@@ -17,7 +17,6 @@ namespace Regex {
typedef Tins::TCPIP::StreamIdentifier stream_id;
typedef map<stream_id, hs_stream_t*> matching_map;
#ifdef DEBUG
ostream& operator<<(ostream& os, const Tins::TCPIP::StreamIdentifier::address_type &sid){
bool first_print = false;
for (auto ele: sid){
@@ -33,7 +32,6 @@ ostream& operator<<(ostream& os, const stream_id &sid){
os << sid.max_address << ":" << sid.max_address_port << " -> " << sid.min_address << ":" << sid.min_address_port;
return os;
}
#endif
struct stream_ctx {
matching_map in_hs_streams;

View File

@@ -1,10 +1,17 @@
#ifndef UTILS_CPP
#define UTILS_CPP
#include <string>
#include <unistd.h>
#include <queue>
#include <condition_variable>
#ifndef UTILS_CPP
#define UTILS_CPP
#include <sys/socket.h>
#include <sys/un.h>
#include <stdexcept>
#include <cstring>
#include <iostream>
#include <cerrno>
#include <sstream>
bool unhexlify(std::string const &hex, std::string &newString) {
try{
@@ -22,6 +29,113 @@ bool unhexlify(std::string const &hex, std::string &newString) {
}
}
class UnixClientConnection {
public:
int sockfd = -1;
struct sockaddr_un addr;
private:
// Internal buffer to accumulate the output until flush
std::ostringstream streamBuffer;
public:
UnixClientConnection(){};
UnixClientConnection(const char* path) {
sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sockfd == -1) {
throw std::runtime_error(std::string("socket error: ") + std::strerror(errno));
}
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
if (connect(sockfd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != 0) {
throw std::runtime_error(std::string("connect error: ") + std::strerror(errno));
}
}
// Delete copy constructor and assignment operator to avoid resource duplication
UnixClientConnection(const UnixClientConnection&) = delete;
UnixClientConnection& operator=(const UnixClientConnection&) = delete;
// Move constructor
UnixClientConnection(UnixClientConnection&& other) noexcept
: sockfd(other.sockfd), addr(other.addr) {
other.sockfd = -1;
}
// Move assignment operator
UnixClientConnection& operator=(UnixClientConnection&& other) noexcept {
if (this != &other) {
if (sockfd != -1) {
close(sockfd);
}
sockfd = other.sockfd;
addr = other.addr;
other.sockfd = -1;
}
return *this;
}
void send(const std::string& data) {
if (::write(sockfd, data.c_str(), data.size()) == -1) {
throw std::runtime_error(std::string("write error: ") + std::strerror(errno));
}
}
std::string recv(size_t size) {
std::string buffer(size, '\0');
ssize_t bytesRead = ::read(sockfd, &buffer[0], size);
if (bytesRead <= 0) {
throw std::runtime_error(std::string("read error: ") + std::strerror(errno));
}
buffer.resize(bytesRead); // resize to actual bytes read
return buffer;
}
// Template overload for generic types
template<typename T>
UnixClientConnection& operator<<(const T& data) {
streamBuffer << data;
return *this;
}
// Overload for manipulators (e.g., std::endl)
UnixClientConnection& operator<<(std::ostream& (*manip)(std::ostream&)) {
// Check if the manipulator is std::endl (or equivalent flush)
if (manip == static_cast<std::ostream& (*)(std::ostream&)>(std::endl)){
streamBuffer << '\n'; // Add a newline
std::string packet = streamBuffer.str();
streamBuffer.str(""); // Clear the buffer
// Send the accumulated data as one packet
send(packet);
}
if (static_cast<std::ostream& (*)(std::ostream&)>(std::flush)) {
std::string packet = streamBuffer.str();
streamBuffer.str(""); // Clear the buffer
// Send the accumulated data as one packet
send(packet);
} else {
// For other manipulators, simply pass them to the buffer
streamBuffer << manip;
}
return *this;
}
// Overload operator<< to allow printing connection info
friend std::ostream& operator<<(std::ostream& os, const UnixClientConnection& conn) {
os << "UnixClientConnection(sockfd=" << conn.sockfd
<< ", path=" << conn.addr.sun_path << ")";
return os;
}
~UnixClientConnection() {
if (sockfd != -1) {
close(sockfd);
}
}
};
#ifdef USE_PIPES_FOR_BLOKING_QUEUE
template<typename T>

View File

@@ -1,4 +1,4 @@
from modules.firewall.models import *
from modules.firewall.models import FirewallSettings, Action, Rule, Protocol, Mode, Table
from utils import nftables_int_to_json, ip_family, NFTableManager, is_ip_parse
import copy
@@ -9,7 +9,8 @@ class FiregexTables(NFTableManager):
filter_table = "filter"
mangle_table = "mangle"
def init_comands(self, policy:str=Action.ACCEPT, opt: FirewallSettings|None = None):
def init_comands(self, policy:str=Action.ACCEPT, opt:
FirewallSettings|None = None):
rules = [
{"add":{"table":{"name":self.filter_table,"family":"ip"}}},
{"add":{"table":{"name":self.filter_table,"family":"ip6"}}},
@@ -41,7 +42,8 @@ class FiregexTables(NFTableManager):
{"add":{"chain":{"family":"ip","table":self.mangle_table,"name":self.rules_chain_out}}},
{"add":{"chain":{"family":"ip6","table":self.mangle_table,"name":self.rules_chain_out}}},
]
if opt is None: return rules
if opt is None:
return rules
if opt.allow_loopback:
rules.extend([
@@ -194,13 +196,18 @@ class FiregexTables(NFTableManager):
def chain_to_firegex(self, chain:str, table:str):
if table == self.filter_table:
match chain:
case "INPUT": return self.rules_chain_in
case "OUTPUT": return self.rules_chain_out
case "FORWARD": return self.rules_chain_fwd
case "INPUT":
return self.rules_chain_in
case "OUTPUT":
return self.rules_chain_out
case "FORWARD":
return self.rules_chain_fwd
elif table == self.mangle_table:
match chain:
case "PREROUTING": return self.rules_chain_in
case "POSTROUTING": return self.rules_chain_out
case "PREROUTING":
return self.rules_chain_in
case "POSTROUTING":
return self.rules_chain_out
return None
def insert_firegex_chains(self):
@@ -214,7 +221,8 @@ class FiregexTables(NFTableManager):
if r.get("family") == family and r.get("table") == table and r.get("chain") == chain and r.get("expr") == rule_to_add:
found = True
break
if found: continue
if found:
continue
yield { "add":{ "rule": {
"family": family,
"table": table,
@@ -274,7 +282,7 @@ class FiregexTables(NFTableManager):
ip_filters.append({"match": { "op": "==", "left": { "meta": { "key": "oifname" } }, "right": srv.dst} })
port_filters = []
if not srv.proto in [Protocol.ANY, Protocol.BOTH]:
if srv.proto not in [Protocol.ANY, Protocol.BOTH]:
if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}})
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}})

View File

@@ -1,11 +1,10 @@
from modules.nfproxy.nftables import FiregexTables
from utils import run_func
from modules.nfproxy.models import Service, PyFilter
import os
import asyncio
from utils import DEBUG
import traceback
from fastapi import HTTPException
import time
nft = FiregexTables()
@@ -13,29 +12,37 @@ class FiregexInterceptor:
def __init__(self):
self.srv:Service
self._stats_updater_cb:callable
self.filter_map_lock:asyncio.Lock
self.filter_map: dict[str, PyFilter]
self.pyfilters: set[PyFilter]
self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process
self.update_task: asyncio.Task
self.server_task: asyncio.Task
self.sock_path: str
self.unix_sock: asyncio.Server
self.ack_arrived = False
self.ack_status = None
self.ack_fail_what = "Unknown"
self.ack_fail_what = "Queue response timed-out"
self.ack_lock = asyncio.Lock()
async def _call_stats_updater_callback(self, filter: PyFilter):
if self._stats_updater_cb:
await run_func(self._stats_updater_cb(filter))
self.sock_reader:asyncio.StreamReader = None
self.sock_writer:asyncio.StreamWriter = None
self.sock_conn_lock:asyncio.Lock
self.last_time_exception = 0
@classmethod
async def start(cls, srv: Service, stats_updater_cb:callable):
async def start(cls, srv: Service):
self = cls()
self._stats_updater_cb = stats_updater_cb
self.srv = srv
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
self.sock_conn_lock = asyncio.Lock()
if not self.sock_conn_lock.locked():
await self.sock_conn_lock.acquire()
self.sock_path = f"/tmp/firegex_nfproxy_{srv.id}.sock"
if os.path.exists(self.sock_path):
os.remove(self.sock_path)
self.unix_sock = await asyncio.start_unix_server(self._server_listener,path=self.sock_path)
self.server_task = asyncio.create_task(self.unix_sock.serve_forever())
queue_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_stats())
nft.add(self.srv, queue_range)
@@ -46,19 +53,20 @@ class FiregexInterceptor:
async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cpproxy")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path,
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
proxy_binary_path, stdin=asyncio.subprocess.DEVNULL,
env={
"NTHREADS": os.getenv("NTHREADS","1"),
"FIREGEX_NFQUEUE_FAIL_OPEN": "1" if self.srv.fail_open else "0",
"FIREGEX_NFPROXY_SOCK": self.sock_path
},
)
line_fut = self.process.stdout.readuntil()
try:
line_fut = await asyncio.wait_for(line_fut, timeout=3)
async with asyncio.timeout(3):
await self.sock_conn_lock.acquire()
line_fut = await self.sock_reader.readuntil()
except asyncio.TimeoutError:
self.process.kill()
raise Exception("Invalid binary output")
raise Exception("Binary don't returned queue number until timeout")
line = line_fut.decode()
if line.startswith("QUEUE "):
params = line.split()
@@ -67,25 +75,45 @@ class FiregexInterceptor:
self.process.kill()
raise Exception("Invalid binary output")
async def _server_listener(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
if self.sock_reader or self.sock_writer:
writer.write_eof() # Technically never reached
writer.close()
reader.feed_eof()
return
self.sock_reader = reader
self.sock_writer = writer
self.sock_conn_lock.release()
async def update_stats(self):
try:
while True:
line = (await self.process.stdout.readuntil()).decode()
if DEBUG:
print(line)
try:
line = (await self.sock_reader.readuntil()).decode()
except Exception as e:
self.ack_arrived = False
self.ack_status = False
self.ack_fail_what = "Can't read from nfq client"
self.ack_lock.release()
await self.stop()
raise HTTPException(status_code=500, detail="Can't read from nfq client") from e
if line.startswith("BLOCKED "):
filter_id = line.split()[1]
filter_name = line.split()[1]
print("BLOCKED", filter_name)
async with self.filter_map_lock:
if filter_id in self.filter_map:
self.filter_map[filter_id].blocked_packets+=1
await self.filter_map[filter_id].update()
print("LOCKED MAP LOCK")
if filter_name in self.filter_map:
print("ADDING BLOCKED PACKET")
self.filter_map[filter_name].blocked_packets+=1
await self.filter_map[filter_name].update()
if line.startswith("MANGLED "):
filter_id = line.split()[1]
filter_name = line.split()[1]
async with self.filter_map_lock:
if filter_id in self.filter_map:
self.filter_map[filter_id].edited_packets+=1
await self.filter_map[filter_id].update()
if filter_name in self.filter_map:
self.filter_map[filter_name].edited_packets+=1
await self.filter_map[filter_name].update()
if line.startswith("EXCEPTION"):
self.last_time_exception = time.time()
print("TODO EXCEPTION HANDLING") # TODO
if line.startswith("ACK "):
self.ack_arrived = True
@@ -101,22 +129,29 @@ class FiregexInterceptor:
traceback.print_exc()
async def stop(self):
self.server_task.cancel()
self.update_task.cancel()
self.unix_sock.close()
if os.path.exists(self.sock_path):
os.remove(self.sock_path)
if self.process and self.process.returncode is None:
self.process.kill()
async def _update_config(self, code):
async with self.update_config_lock:
self.process.stdin.write(len(code).to_bytes(4, byteorder='big')+code.encode())
await self.process.stdin.drain()
try:
async with asyncio.timeout(3):
await self.ack_lock.acquire()
except TimeoutError:
pass
if not self.ack_arrived or not self.ack_status:
await self.stop()
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
if self.sock_writer:
self.sock_writer.write(len(code).to_bytes(4, byteorder='big')+code.encode())
await self.sock_writer.drain()
try:
async with asyncio.timeout(3):
await self.ack_lock.acquire()
except TimeoutError:
self.ack_fail_what = "Queue response timed-out"
if not self.ack_arrived or not self.ack_status:
await self.stop()
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
else:
raise HTTPException(status_code=400, detail="Socket not ready")
async def reload(self, filters:list[PyFilter]):
async with self.filter_map_lock:
@@ -125,12 +160,13 @@ class FiregexInterceptor:
filter_file = f.read()
else:
filter_file = ""
self.filter_map = {ele.name: ele for ele in filters}
await self._update_config(
"global __firegex_pyfilter_enabled\n" +
filter_file + "\n\n" +
"__firegex_pyfilter_enabled = [" + ", ".join([repr(f.name) for f in filters]) + "]\n" +
"__firegex_proto = " + repr(self.srv.proto) + "\n" +
"import firegex.nfproxy.internals\n\n" +
filter_file + "\n\n" +
"firegex.nfproxy.internals.compile()"
"import firegex.nfproxy.internals\n" +
"firegex.nfproxy.internals.compile(globals())\n"
)

View File

@@ -15,18 +15,18 @@ class ServiceManager:
self.srv = srv
self.db = db
self.status = STATUS.STOP
self.filters: dict[int, FiregexFilter] = {}
self.filters: dict[str, FiregexFilter] = {}
self.lock = asyncio.Lock()
self.interceptor = None
async def _update_filters_from_db(self):
pyfilters = [
PyFilter.from_dict(ele) for ele in
PyFilter.from_dict(ele, self.db) for ele in
self.db.query("SELECT * FROM pyfilter WHERE service_id = ? AND active=1;", self.srv.id)
]
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f.id for f in pyfilters])
new_filters = set([f.name for f in pyfilters])
#remove old filters
for f in old_filters:
if f not in new_filters:
@@ -34,7 +34,7 @@ class ServiceManager:
#add new filters
for f in new_filters:
if f not in old_filters:
self.filters[f] = [ele for ele in pyfilters if ele.id == f][0]
self.filters[f] = [ele for ele in pyfilters if ele.name == f][0]
if self.interceptor:
await self.interceptor.reload(self.filters.values())
@@ -43,16 +43,11 @@ class ServiceManager:
async def next(self,to):
async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
if to == STATUS.STOP:
await self.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
if to == STATUS.ACTIVE:
await self.restart()
def _stats_updater(self,filter:PyFilter):
self.db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE filter_id = ?;", filter.blocked_packets, filter.edited_packets, filter.id)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
@@ -60,7 +55,7 @@ class ServiceManager:
async def start(self):
if not self.interceptor:
nft.delete(self.srv)
self.interceptor = await FiregexInterceptor.start(self.srv, self._stats_updater)
self.interceptor = await FiregexInterceptor.start(self.srv)
await self._update_filters_from_db()
self._set_status(STATUS.ACTIVE)
@@ -69,6 +64,7 @@ class ServiceManager:
if self.interceptor:
await self.interceptor.stop()
self.interceptor = None
self._set_status(STATUS.STOP)
async def restart(self):
await self.stop()

View File

@@ -15,13 +15,19 @@ class Service:
class PyFilter:
def __init__(self, filter_id:int, name: str, blocked_packets: int, edited_packets: int, active: bool, **other):
self.id = filter_id
def __init__(self, name: str, blocked_packets: int, edited_packets: int, active: bool, db, **other):
self.name = name
self.blocked_packets = blocked_packets
self.edited_packets = edited_packets
self.active = active
self.__db = db
async def update(self):
self.__db.query("UPDATE pyfilter SET blocked_packets = ?, edited_packets = ? WHERE name = ?;", self.blocked_packets, self.edited_packets, self.name)
def __repr__(self):
return f"<PyFilter {self.name}>"
@classmethod
def from_dict(cls, var: dict):
return cls(**var)
def from_dict(cls, var: dict, db):
return cls(**var, db=db)

View File

@@ -1,6 +1,14 @@
from modules.nfproxy.models import Service
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
def convert_protocol_to_l4(proto:str):
if proto == "tcp":
return "tcp"
elif proto == "http":
return "tcp"
else:
raise Exception("Invalid protocol")
class FiregexFilter:
def __init__(self, proto:str, port:int, ip_int:str, target:str, id:int):
self.id = id
@@ -11,7 +19,7 @@ class FiregexFilter:
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter) or isinstance(o, Service):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return self.port == o.port and self.proto == convert_protocol_to_l4(o.proto) and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False
class FiregexTables(NFTableManager):
@@ -61,7 +69,7 @@ class FiregexTables(NFTableManager):
"chain": self.output_chain,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
{'match': {"left": { "payload": {"protocol": convert_protocol_to_l4(str(srv.proto)), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
{"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1338}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
]
@@ -72,7 +80,7 @@ class FiregexTables(NFTableManager):
"chain": self.input_chain,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
{'match': {"left": { "payload": {"protocol": convert_protocol_to_l4(str(srv.proto)), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
{"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1337}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
]

View File

@@ -79,7 +79,7 @@ class FiregexInterceptor:
self.update_task: asyncio.Task
self.ack_arrived = False
self.ack_status = None
self.ack_fail_what = "Unknown"
self.ack_fail_what = "Queue response timed-out"
self.ack_lock = asyncio.Lock()
@classmethod
@@ -158,7 +158,7 @@ class FiregexInterceptor:
async with asyncio.timeout(3):
await self.ack_lock.acquire()
except TimeoutError:
pass
self.ack_fail_what = "Queue response timed-out"
if not self.ack_arrived or not self.ack_status:
await self.stop()
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")

View File

@@ -45,11 +45,9 @@ class ServiceManager:
async def next(self,to):
async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
if to == STATUS.STOP:
await self.stop()
self._set_status(to)
# PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
if to == STATUS.ACTIVE:
await self.restart()
def _stats_updater(self,filter:RegexFilter):
@@ -71,6 +69,7 @@ class ServiceManager:
if self.interceptor:
await self.interceptor.stop()
self.interceptor = None
self._set_status(STATUS.STOP)
async def restart(self):
await self.stop()

View File

@@ -10,6 +10,10 @@ from utils.models import ResetRequest, StatusMessageModel
import os
from firegex.nfproxy.internals import get_filter_names
from fastapi.responses import PlainTextResponse
from modules.nfproxy.nftables import convert_protocol_to_l4
import asyncio
import traceback
from utils import DEBUG
class ServiceModel(BaseModel):
service_id: str
@@ -28,12 +32,10 @@ class RenameForm(BaseModel):
class SettingsForm(BaseModel):
port: PortType|None = None
proto: str|None = None
ip_int: str|None = None
fail_open: bool|None = None
class PyFilterModel(BaseModel):
filter_id: int
name: str
blocked_packets: int
edited_packets: int
@@ -52,6 +54,7 @@ class ServiceAddResponse(BaseModel):
class SetPyFilterForm(BaseModel):
code: str
sid: str|None = None
app = APIRouter()
@@ -62,12 +65,12 @@ db = SQLite('db/nft-pyfilters.db', {
'port': 'INT NOT NULL CHECK(port > 0 and port < 65536)',
'name': 'VARCHAR(100) NOT NULL UNIQUE',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "http"))',
'l4_proto': 'VARCHAR(3) NOT NULL CHECK (l4_proto IN ("tcp", "udp"))',
'ip_int': 'VARCHAR(100) NOT NULL',
'fail_open': 'BOOLEAN NOT NULL CHECK (fail_open IN (0, 1)) DEFAULT 1',
},
'pyfilter': {
'filter_id': 'INTEGER PRIMARY KEY',
'name': 'VARCHAR(100) NOT NULL',
'name': 'VARCHAR(100) PRIMARY KEY',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'edited_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'service_id': 'VARCHAR(100) NOT NULL',
@@ -75,7 +78,7 @@ db = SQLite('db/nft-pyfilters.db', {
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, l4_proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_pyfilter_service ON pyfilter (name, service_id);"
]
})
@@ -132,7 +135,7 @@ async def get_service_list():
s.proto proto,
s.ip_int ip_int,
s.fail_open fail_open,
COUNT(f.filter_id) n_filters,
COUNT(f.name) n_filters,
COALESCE(SUM(f.blocked_packets),0) blocked_packets,
COALESCE(SUM(f.edited_packets),0) edited_packets
FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id
@@ -151,7 +154,7 @@ async def get_service_by_id(service_id: str):
s.proto proto,
s.ip_int ip_int,
s.fail_open fail_open,
COUNT(f.filter_id) n_filters,
COUNT(f.name) n_filters,
COALESCE(SUM(f.blocked_packets),0) blocked_packets,
COALESCE(SUM(f.edited_packets),0) edited_packets
FROM services s LEFT JOIN pyfilter f ON s.service_id = f.service_id
@@ -202,9 +205,6 @@ async def service_rename(service_id: str, form: RenameForm):
@app.put('/services/{service_id}/settings', response_model=StatusMessageModel)
async def service_settings(service_id: str, form: SettingsForm):
"""Request to change the settings of a specific service (will cause a restart)"""
if form.proto is not None and form.proto not in ["tcp", "udp"]:
raise HTTPException(status_code=400, detail="Invalid protocol")
if form.port is not None and (form.port < 1 or form.port > 65535):
raise HTTPException(status_code=400, detail="Invalid port")
@@ -245,38 +245,38 @@ async def get_service_pyfilter_list(service_id: str):
raise HTTPException(status_code=400, detail="This service does not exists!")
return db.query("""
SELECT
filter_id, name, blocked_packets, edited_packets, active
name, blocked_packets, edited_packets, active
FROM pyfilter WHERE service_id = ?;
""", service_id)
@app.get('/pyfilters/{filter_id}', response_model=PyFilterModel)
async def get_pyfilter_by_id(filter_id: int):
@app.get('/pyfilters/{filter_name}', response_model=PyFilterModel)
async def get_pyfilter_by_id(filter_name: str):
"""Get pyfilter info using his id"""
res = db.query("""
SELECT
filter_id, name, blocked_packets, edited_packets, active
FROM pyfilter WHERE filter_id = ?;
""", filter_id)
name, blocked_packets, edited_packets, active
FROM pyfilter WHERE name = ?;
""", filter_name)
if len(res) == 0:
raise HTTPException(status_code=400, detail="This filter does not exists!")
return res[0]
@app.post('/pyfilters/{filter_id}/enable', response_model=StatusMessageModel)
async def pyfilter_enable(filter_id: int):
@app.post('/pyfilters/{filter_name}/enable', response_model=StatusMessageModel)
async def pyfilter_enable(filter_name: str):
"""Request the enabling of a pyfilter"""
res = db.query('SELECT * FROM pyfilter WHERE filter_id = ?;', filter_id)
res = db.query('SELECT * FROM pyfilter WHERE name = ?;', filter_name)
if len(res) != 0:
db.query('UPDATE pyfilter SET active=1 WHERE filter_id = ?;', filter_id)
db.query('UPDATE pyfilter SET active=1 WHERE name = ?;', filter_name)
await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend()
return {'status': 'ok'}
@app.post('/pyfilters/{filter_id}/disable', response_model=StatusMessageModel)
async def pyfilter_disable(filter_id: int):
@app.post('/pyfilters/{filter_name}/disable', response_model=StatusMessageModel)
async def pyfilter_disable(filter_name: str):
"""Request the deactivation of a pyfilter"""
res = db.query('SELECT * FROM pyfilter WHERE filter_id = ?;', filter_id)
res = db.query('SELECT * FROM pyfilter WHERE name = ?;', filter_name)
if len(res) != 0:
db.query('UPDATE pyfilter SET active=0 WHERE filter_id = ?;', filter_id)
db.query('UPDATE pyfilter SET active=0 WHERE name = ?;', filter_name)
await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend()
return {'status': 'ok'}
@@ -293,8 +293,8 @@ async def add_new_service(form: ServiceAddForm):
srv_id = None
try:
srv_id = gen_service_id()
db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int, fail_open) VALUES (?, ?, ?, ?, ?, ?, ?)",
srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int, form.fail_open)
db.query("INSERT INTO services (service_id ,name, port, status, proto, ip_int, fail_open, l4_proto) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
srv_id, refactor_name(form.name), form.port, STATUS.STOP, form.proto, form.ip_int, form.fail_open, convert_protocol_to_l4(form.proto))
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="This type of service already exists")
await firewall.reload()
@@ -308,29 +308,41 @@ async def set_pyfilters(service_id: str, form: SetPyFilterForm):
if len(service) == 0:
raise HTTPException(status_code=400, detail="This service does not exists!")
service = service[0]
service_id = service["service_id"]
srv_proto = service["proto"]
try:
found_filters = get_filter_names(form.code, srv_proto)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
async with asyncio.timeout(8):
try:
found_filters = get_filter_names(form.code, srv_proto)
except Exception as e:
if DEBUG:
traceback.print_exc()
raise HTTPException(status_code=400, detail="Compile error: "+str(e))
# Remove filters that are not in the new code
existing_filters = db.query("SELECT name FROM pyfilter WHERE service_id = ?;", service_id)
existing_filters = [ele["name"] for ele in existing_filters]
for filter in existing_filters:
if filter not in found_filters:
db.query("DELETE FROM pyfilter WHERE name = ?;", filter)
# Add filters that are in the new code but not in the database
for filter in found_filters:
if not db.query("SELECT 1 FROM pyfilter WHERE service_id = ? AND name = ?;", service_id, filter):
db.query("INSERT INTO pyfilter (name, service_id) VALUES (?, ?);", filter, service["service_id"])
# Eventually edited filters will be reloaded
os.makedirs("db/nfproxy_filters", exist_ok=True)
with open(f"db/nfproxy_filters/{service_id}.py", "w") as f:
f.write(form.code)
await firewall.get(service_id).update_filters()
await refresh_frontend()
except asyncio.TimeoutError:
if DEBUG:
traceback.print_exc()
raise HTTPException(status_code=400, detail="The operation took too long")
# Remove filters that are not in the new code
existing_filters = db.query("SELECT filter_id FROM pyfilter WHERE service_id = ?;", service_id)
for filter in existing_filters:
if filter["name"] not in found_filters:
db.query("DELETE FROM pyfilter WHERE filter_id = ?;", filter["filter_id"])
# Add filters that are in the new code but not in the database
for filter in found_filters:
if not db.query("SELECT 1 FROM pyfilter WHERE service_id = ? AND name = ?;", service_id, filter):
db.query("INSERT INTO pyfilter (name, service_id) VALUES (?, ?);", filter, service["service_id"])
# Eventually edited filters will be reloaded
os.makedirs("db/nfproxy_filters", exist_ok=True)
with open(f"db/nfproxy_filters/{service_id}.py", "w") as f:
f.write(form.code)
await firewall.get(service_id).update_filters()
await refresh_frontend()
return {'status': 'ok'}
@app.get('/services/{service_id}/pyfilters/code', response_class=PlainTextResponse)
@@ -343,7 +355,3 @@ async def get_pyfilters(service_id: str):
return f.read()
except FileNotFoundError:
return ""
#TODO check all the APIs and add
# 1. API to change the python filter file (DONE)
# 2. a socketio mechanism to lock the previous feature

View File

@@ -8,15 +8,22 @@ import nftables
from socketio import AsyncServer
from fastapi import Path
from typing import Annotated
from functools import wraps
from pydantic import BaseModel, ValidationError
import traceback
from utils.models import StatusMessageModel
from typing import List
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
socketio:AsyncServer = None
sid_list:set = set()
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
ROUTERS_DIR = os.path.join(ROOT_DIR,"routers")
ON_DOCKER = "DOCKER" in sys.argv
DEBUG = "DEBUG" in sys.argv
NORELOAD = "NORELOAD" in sys.argv
FIREGEX_PORT = int(os.getenv("PORT","4444"))
JWT_ALGORITHM: str = "HS256"
API_VERSION = "{{VERSION_PLACEHOLDER}}" if "{" not in "{{VERSION_PLACEHOLDER}}" else "0.0.0"
@@ -153,4 +160,50 @@ class NFTableManager(Singleton):
def raw_list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"]
def _json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json"):
res = obj.model_dump(mode=mode, exclude_unset=not unset)
if convert_keys:
for from_k, to_k in convert_keys.items():
if from_k in res:
res[to_k] = res.pop(from_k)
if exclude:
for ele in exclude:
if ele in res:
del res[ele]
return res
def json_like(obj: BaseModel|List[BaseModel], unset=False, convert_keys:dict[str, str]=None, exclude:list[str]=None, mode:str="json") -> dict:
if isinstance(obj, list):
return [_json_like(ele, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode) for ele in obj]
return _json_like(obj, unset=unset, convert_keys=convert_keys, exclude=exclude, mode=mode)
def register_event(sio_server: AsyncServer, event_name: str, model: BaseModel, response_model: BaseModel|None = None):
def decorator(func):
@sio_server.on(event_name) # Automatically registers the event
@wraps(func)
async def wrapper(sid, data):
try:
# Parse and validate incoming data
parsed_data = model.model_validate(data)
except ValidationError:
return json_like(StatusMessageModel(status=f"Invalid {event_name} request"))
# Call the original function with the parsed data
result = await func(sid, parsed_data)
# If a response model is provided, validate the output
if response_model:
try:
parsed_result = response_model.model_validate(result)
except ValidationError:
traceback.print_exc()
return json_like(StatusMessageModel(status=f"SERVER ERROR: Invalid {event_name} response"))
else:
parsed_result = result
# Emit the validated result
if parsed_result:
if isinstance(parsed_result, BaseModel):
return json_like(parsed_result)
return parsed_result
return wrapper
return decorator

View File

@@ -7,6 +7,7 @@ from starlette.responses import StreamingResponse
from fastapi.responses import FileResponse
from utils import DEBUG, ON_DOCKER, ROUTERS_DIR, list_files, run_func
from utils.models import ResetRequest
import asyncio
REACT_BUILD_DIR: str = "../frontend/build/" if not ON_DOCKER else "frontend/"
REACT_HTML_PATH: str = os.path.join(REACT_BUILD_DIR,"index.html")
@@ -87,12 +88,9 @@ def load_routers(app):
if router.shutdown:
shutdowns.append(router.shutdown)
async def reset(reset_option:ResetRequest):
for func in resets:
await run_func(func, reset_option)
await asyncio.gather(*[run_func(func, reset_option) for func in resets])
async def startup():
for func in startups:
await run_func(func)
await asyncio.gather(*[run_func(func) for func in startups])
async def shutdown():
for func in shutdowns:
await run_func(func)
await asyncio.gather(*[run_func(func) for func in shutdowns])
return reset, startup, shutdown