Merge pull request #12 from Pwnzer0tt1/dev-cpp

Implementing new cpp nfqueue with hyperscan an stream regex assembling TCP packets with libtis
This commit is contained in:
Domingo Dirutigliano
2025-02-05 12:29:54 +01:00
committed by GitHub
59 changed files with 1827 additions and 3180 deletions

View File

@@ -1,6 +1,10 @@
import uvicorn, secrets, utils
import os, asyncio, logging
from fastapi import FastAPI, HTTPException, Depends, APIRouter, Request
import uvicorn
import secrets
import utils
import os
import asyncio
import logging
from fastapi import FastAPI, HTTPException, Depends, APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from passlib.context import CryptContext
@@ -30,7 +34,14 @@ async def lifespan(app):
yield
await shutdown_main()
app = FastAPI(debug=DEBUG, redoc_url=None, lifespan=lifespan)
app = FastAPI(
debug=DEBUG,
redoc_url=None,
lifespan=lifespan,
docs_url="/api/docs",
title="Firegex API",
version=API_VERSION,
)
utils.socketio = SocketManager(app, "/sock", socketio_path="")
if DEBUG:
@@ -94,7 +105,8 @@ async def get_app_status(auth: bool = Depends(check_login)):
@app.post("/api/login")
async def login_api(form: OAuth2PasswordRequestForm = Depends()):
"""Get a login token to use the firegex api"""
if APP_STATUS() != "run": raise HTTPException(status_code=400)
if APP_STATUS() != "run":
raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
await asyncio.sleep(0.3) # No bruteforce :)
@@ -105,7 +117,8 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()):
@app.post('/api/set-password', response_model=ChangePasswordModel)
async def set_password(form: PasswordForm):
"""Set the password of firegex"""
if APP_STATUS() != "init": raise HTTPException(status_code=400)
if APP_STATUS() != "init":
raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
set_psw(form.password)
@@ -115,7 +128,8 @@ async def set_password(form: PasswordForm):
@api.post('/change-password', response_model=ChangePasswordModel)
async def change_password(form: PasswordChangeForm):
"""Change the password of firegex"""
if APP_STATUS() != "run": raise HTTPException(status_code=400)
if APP_STATUS() != "run":
raise HTTPException(status_code=400)
if form.password == "":
return {"status":"Cannot insert an empty password!"}
@@ -144,7 +158,8 @@ async def startup_main():
except Exception as e:
logging.error(f"Error setting sysctls: {e}")
await startup()
if not JWT_SECRET(): db.put("secret", secrets.token_hex(32))
if not JWT_SECRET():
db.put("secret", secrets.token_hex(32))
await refresh_frontend()
async def shutdown_main():
@@ -175,9 +190,9 @@ if __name__ == '__main__':
os.chdir(os.path.dirname(os.path.realpath(__file__)))
uvicorn.run(
"app:app",
host="::" if DEBUG else None,
host=None, #"::" if DEBUG else None,
port=FIREGEX_PORT,
reload=DEBUG,
reload=False,#DEBUG,
access_log=True,
workers=1, # Multiple workers will cause a crash due to the creation
# of multiple processes with separated memory

View File

@@ -0,0 +1,530 @@
#include <linux/netfilter/nfnetlink_queue.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <tins/tins.h>
#include <tins/tcp_ip/stream_follower.h>
#include <tins/tcp_ip/stream_identifier.h>
#include <libmnl/libmnl.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/types.h>
#include <stdexcept>
#include <thread>
#include <hs.h>
#include <iostream>
using Tins::TCPIP::Stream;
using Tins::TCPIP::StreamFollower;
using namespace std;
#ifndef NETFILTER_CLASSES_HPP
#define NETFILTER_CLASSES_HPP
typedef Tins::TCPIP::StreamIdentifier stream_id;
typedef map<stream_id, hs_stream_t*> matching_map;
/* Considering to use unorder_map using this hash of stream_id
namespace std {
template<>
struct hash<stream_id> {
size_t operator()(const stream_id& sid) const
{
return std::hash<std::uint32_t>()(sid.max_address[0] + sid.max_address[1] + sid.max_address[2] + sid.max_address[3] + sid.max_address_port + sid.min_address[0] + sid.min_address[1] + sid.min_address[2] + sid.min_address[3] + sid.min_address_port);
}
};
}
*/
#ifdef DEBUG
ostream& operator<<(ostream& os, const Tins::TCPIP::StreamIdentifier::address_type &sid){
bool first_print = false;
for (auto ele: sid){
if (first_print || ele){
first_print = true;
os << (int)ele << ".";
}
}
return os;
}
ostream& operator<<(ostream& os, const stream_id &sid){
os << sid.max_address << ":" << sid.max_address_port << " -> " << sid.min_address << ":" << sid.min_address_port;
return os;
}
#endif
struct packet_info;
struct tcp_stream_tmp {
bool matching_has_been_called = false;
bool result;
packet_info *pkt_info;
};
struct stream_ctx {
matching_map in_hs_streams;
matching_map out_hs_streams;
hs_scratch_t* in_scratch = nullptr;
hs_scratch_t* out_scratch = nullptr;
u_int16_t latest_config_ver = 0;
StreamFollower follower;
mnl_socket* nl;
tcp_stream_tmp tcp_match_util;
void clean_scratches(){
if (out_scratch != nullptr){
hs_free_scratch(out_scratch);
out_scratch = nullptr;
}
if (in_scratch != nullptr){
hs_free_scratch(in_scratch);
in_scratch = nullptr;
}
}
void clean_stream_by_id(stream_id sid){
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.clean_stream_by_id] Cleaning stream context of " << sid << endl;
#endif
auto stream_search = in_hs_streams.find(sid);
hs_stream_t* stream_match;
if (stream_search != in_hs_streams.end()){
stream_match = stream_search->second;
if (hs_close_stream(stream_match, in_scratch, nullptr, nullptr) != HS_SUCCESS) {
cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
throw invalid_argument("Cannot close stream match on hyperscan");
}
in_hs_streams.erase(stream_search);
}
stream_search = out_hs_streams.find(sid);
if (stream_search != out_hs_streams.end()){
stream_match = stream_search->second;
if (hs_close_stream(stream_match, out_scratch, nullptr, nullptr) != HS_SUCCESS) {
cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
throw invalid_argument("Cannot close stream match on hyperscan");
}
out_hs_streams.erase(stream_search);
}
}
void clean(){
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.clean] Cleaning stream context" << endl;
#endif
if (in_scratch){
for(auto ele: in_hs_streams){
if (hs_close_stream(ele.second, in_scratch, nullptr, nullptr) != HS_SUCCESS) {
cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
throw invalid_argument("Cannot close stream match on hyperscan");
}
}
in_hs_streams.clear();
}
if (out_scratch){
for(auto ele: out_hs_streams){
if (hs_close_stream(ele.second, out_scratch, nullptr, nullptr) != HS_SUCCESS) {
cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
throw invalid_argument("Cannot close stream match on hyperscan");
}
}
out_hs_streams.clear();
}
clean_scratches();
}
};
struct packet_info {
string packet;
string payload;
stream_id sid;
bool is_input;
bool is_tcp;
stream_ctx* sctx;
};
typedef bool NetFilterQueueCallback(packet_info &);
template <NetFilterQueueCallback callback_func>
class NetfilterQueue {
public:
size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2);
char *buf = nullptr;
unsigned int portid;
u_int16_t queue_num;
stream_ctx sctx;
NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) {
sctx.nl = mnl_socket_open(NETLINK_NETFILTER);
if (sctx.nl == nullptr) { throw runtime_error( "mnl_socket_open" );}
if (mnl_socket_bind(sctx.nl, 0, MNL_SOCKET_AUTOPID) < 0) {
mnl_socket_close(sctx.nl);
throw runtime_error( "mnl_socket_bind" );
}
portid = mnl_socket_get_portid(sctx.nl);
buf = (char*) malloc(BUF_SIZE);
if (!buf) {
mnl_socket_close(sctx.nl);
throw runtime_error( "allocate receive buffer" );
}
if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) {
_clear();
throw runtime_error( "mnl_socket_send" );
}
//TEST if BIND was successful
if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND A NONE cmmand to generate an error meessage
_clear();
throw runtime_error( "mnl_socket_send" );
}
if (recv_packet() == -1) { //RECV the error message
_clear();
throw runtime_error( "mnl_socket_recvfrom" );
}
struct nlmsghdr *nlh = (struct nlmsghdr *) buf;
if (nlh->nlmsg_type != NLMSG_ERROR) {
_clear();
throw runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" );
}
//nfqnl_msg_config_cmd
nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh);
// error code taken from the linux kernel:
// https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27
#define ENOTSUPP 524 /* Operation is not supported */
if (error_msg->error != -ENOTSUPP) {
_clear();
throw invalid_argument( "queueid is already busy" );
}
//END TESTING BIND
nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff);
mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO));
mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO));
if (mnl_socket_sendto(sctx.nl, nlh, nlh->nlmsg_len) < 0) {
_clear();
throw runtime_error( "mnl_socket_send" );
}
}
static void on_data_recv(Stream& stream, stream_ctx* sctx, string data) {
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.on_data_recv] data: " << data << endl;
#endif
sctx->tcp_match_util.matching_has_been_called = true;
bool result = callback_func(*sctx->tcp_match_util.pkt_info);
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.on_data_recv] result: " << result << endl;
#endif
if (!result){
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.on_data_recv] Stream matched, removing all data about it" << endl;
#endif
sctx->clean_stream_by_id(sctx->tcp_match_util.pkt_info->sid);
stream.ignore_client_data();
stream.ignore_server_data();
}
sctx->tcp_match_util.result = result;
}
//Input data filtering
static void on_client_data(Stream& stream, stream_ctx* sctx) {
on_data_recv(stream, sctx, string(stream.client_payload().begin(), stream.client_payload().end()));
}
//Server data filtering
static void on_server_data(Stream& stream, stream_ctx* sctx) {
on_data_recv(stream, sctx, string(stream.server_payload().begin(), stream.server_payload().end()));
}
static void on_new_stream(Stream& stream, stream_ctx* sctx) {
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.on_new_stream] New stream detected" << endl;
#endif
if (stream.is_partial_stream()) {
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.on_new_stream] Partial stream detected, skipping" << endl;
#endif
return;
}
stream.auto_cleanup_payloads(true);
stream.client_data_callback(bind(on_client_data, placeholders::_1, sctx));
stream.server_data_callback(bind(on_server_data, placeholders::_1, sctx));
stream.stream_closed_callback(bind(on_stream_close, placeholders::_1, sctx));
}
// A stream was terminated. The second argument is the reason why it was terminated
static void on_stream_close(Stream& stream, stream_ctx* sctx) {
stream_id stream_id = stream_id::make_identifier(stream);
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.on_stream_close] Stream terminated, deleting all data" << endl;
#endif
sctx->clean_stream_by_id(stream_id);
}
void run(){
/*
* ENOBUFS is signalled to userspace when packets were lost
* on kernel side. In most cases, userspace isn't interested
* in this information, so turn it off.
*/
int ret = 1;
mnl_socket_setsockopt(sctx.nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int));
sctx.follower.new_stream_callback(bind(on_new_stream, placeholders::_1, &sctx));
sctx.follower.stream_termination_callback(bind(on_stream_close, placeholders::_1, &sctx));
for (;;) {
ret = recv_packet();
if (ret == -1) {
throw runtime_error( "mnl_socket_recvfrom" );
}
ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, &sctx);
if (ret < 0){
throw runtime_error( "mnl_cb_run" );
}
}
}
~NetfilterQueue() {
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.~NetfilterQueue] Destructor called" << endl;
#endif
send_config_cmd(NFQNL_CFG_CMD_UNBIND);
_clear();
}
private:
ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){
struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd);
return mnl_socket_sendto(sctx.nl, nlh, nlh->nlmsg_len);
}
ssize_t recv_packet(){
return mnl_socket_recvfrom(sctx.nl, buf, BUF_SIZE);
}
void _clear(){
if (buf != nullptr) {
free(buf);
buf = nullptr;
}
mnl_socket_close(sctx.nl);
sctx.nl = nullptr;
sctx.clean();
}
template<typename T>
static void build_verdict(T packet, uint8_t *payload, uint16_t plen, nlmsghdr *nlh_verdict, nfqnl_msg_packet_hdr *ph, stream_ctx* sctx, bool is_input){
Tins::TCP* tcp = packet.template find_pdu<Tins::TCP>();
if (tcp){
Tins::PDU* application_layer = tcp->inner_pdu();
u_int16_t payload_size = 0;
if (application_layer != nullptr){
payload_size = application_layer->size();
}
packet_info pktinfo{
packet: string(payload, payload+plen),
payload: string(payload+plen - payload_size, payload+plen),
sid: stream_id::make_identifier(packet),
is_input: is_input,
is_tcp: true,
sctx: sctx,
};
sctx->tcp_match_util.matching_has_been_called = false;
sctx->tcp_match_util.pkt_info = &pktinfo;
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.build_verdict] TCP Packet received " << packet.src_addr() << ":" << tcp->sport() << " -> " << packet.dst_addr() << ":" << tcp->dport() << ", sending to libtins StreamFollower" << endl;
#endif
sctx->follower.process_packet(packet);
#ifdef DEBUG
if (sctx->tcp_match_util.matching_has_been_called){
cerr << "[DEBUG] [NetfilterQueue.build_verdict] StreamFollower has called matching functions" << endl;
}else{
cerr << "[DEBUG] [NetfilterQueue.build_verdict] StreamFollower has NOT called matching functions" << endl;
}
#endif
if (sctx->tcp_match_util.matching_has_been_called && !sctx->tcp_match_util.result){
Tins::PDU* data_layer = tcp->release_inner_pdu();
if (data_layer != nullptr){
delete data_layer;
}
tcp->set_flag(Tins::TCP::FIN,1);
tcp->set_flag(Tins::TCP::ACK,1);
tcp->set_flag(Tins::TCP::SYN,0);
nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size());
}
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT );
}else{
Tins::UDP* udp = packet.template find_pdu<Tins::UDP>();
if (!udp){
throw invalid_argument("Only TCP and UDP are supported");
}
Tins::PDU* application_layer = udp->inner_pdu();
u_int16_t payload_size = 0;
if (application_layer != nullptr){
payload_size = application_layer->size();
}
if((udp->inner_pdu() == nullptr)){
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT );
}
packet_info pktinfo{
packet: string(payload, payload+plen),
payload: string(payload+plen - payload_size, payload+plen),
sid: stream_id::make_identifier(packet),
is_input: is_input,
is_tcp: false,
sctx: sctx,
};
if (callback_func(pktinfo)){
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT );
}else{
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP );
}
}
}
static int queue_cb(const nlmsghdr *nlh, void *data_ptr)
{
stream_ctx* sctx = (stream_ctx*)data_ptr;
//Extract attributes from the nlmsghdr
nlattr *attr[NFQA_MAX+1] = {};
if (nfq_nlmsg_parse(nlh, attr) < 0) {
perror("problems parsing");
return MNL_CB_ERROR;
}
if (attr[NFQA_PACKET_HDR] == nullptr) {
fputs("metaheader not set\n", stderr);
return MNL_CB_ERROR;
}
if (attr[NFQA_MARK] == nullptr) {
fputs("mark not set\n", stderr);
return MNL_CB_ERROR;
}
//Get Payload
uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
uint8_t *payload = (uint8_t *)mnl_attr_get_payload(attr[NFQA_PAYLOAD]);
//Return result to the kernel
struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh);
char buf[MNL_SOCKET_BUFFER_SIZE];
struct nlmsghdr *nlh_verdict;
struct nlattr *nest;
nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id));
bool is_input = ntohl(mnl_attr_get_u32(attr[NFQA_MARK])) & 0x1; // == 0x1337 that is odd
#ifdef DEBUG
cerr << "[DEBUG] [NetfilterQueue.queue_cb] Packet received" << endl;
cerr << "[DEBUG] [NetfilterQueue.queue_cb] Packet ID: " << ntohl(ph->packet_id) << endl;
cerr << "[DEBUG] [NetfilterQueue.queue_cb] Payload size: " << plen << endl;
cerr << "[DEBUG] [NetfilterQueue.queue_cb] Is input: " << is_input << endl;
#endif
// Check IP protocol version
if ( (payload[0] & 0xf0) == 0x40 ){
build_verdict(Tins::IP(payload, plen), payload, plen, nlh_verdict, ph, sctx, is_input);
}else{
build_verdict(Tins::IPv6(payload, plen), payload, plen, nlh_verdict, ph, sctx, is_input);
}
nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT);
mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42));
mnl_attr_nest_end(nlh_verdict, nest);
if (mnl_socket_sendto(sctx->nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) {
throw runtime_error( "mnl_socket_send" );
}
return MNL_CB_OK;
}
};
template <NetFilterQueueCallback func>
class NFQueueSequence{
private:
vector<NetfilterQueue<func> *> nfq;
uint16_t _init;
uint16_t _end;
vector<thread> threads;
public:
static const int QUEUE_BASE_NUM = 1000;
NFQueueSequence(uint16_t seq_len){
if (seq_len <= 0) throw invalid_argument("seq_len <= 0");
nfq = vector<NetfilterQueue<func>*>(seq_len);
_init = QUEUE_BASE_NUM;
while(nfq[0] == nullptr){
if (_init+seq_len-1 >= 65536){
throw runtime_error("NFQueueSequence: too many queues!");
}
for (int i=0;i<seq_len;i++){
try{
nfq[i] = new NetfilterQueue<func>(_init+i);
}catch(const invalid_argument e){
for(int j = 0; j < i; j++) {
delete nfq[j];
nfq[j] = nullptr;
}
_init += seq_len - i;
break;
}
}
}
_end = _init + seq_len - 1;
}
void start(){
if (threads.size() != 0) throw runtime_error("NFQueueSequence: already started!");
for (int i=0;i<nfq.size();i++){
threads.push_back(thread(&NetfilterQueue<func>::run, nfq[i]));
}
}
void join(){
for (int i=0;i<nfq.size();i++){
threads[i].join();
}
threads.clear();
}
uint16_t init(){
return _init;
}
uint16_t end(){
return _end;
}
~NFQueueSequence(){
for (int i=0;i<nfq.size();i++){
delete nfq[i];
}
}
};
#endif // NETFILTER_CLASSES_HPP

View File

@@ -1,294 +0,0 @@
#include <linux/netfilter/nfnetlink_queue.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <tins/tins.h>
#include <libmnl/libmnl.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/types.h>
#include <stdexcept>
#include <thread>
#ifndef NETFILTER_CLASSES_HPP
#define NETFILTER_CLASSES_HPP
typedef bool NetFilterQueueCallback(const uint8_t*,uint32_t);
Tins::PDU * find_transport_layer(Tins::PDU* pkt){
while(pkt != NULL){
if (pkt->pdu_type() == Tins::PDU::TCP || pkt->pdu_type() == Tins::PDU::UDP) {
return pkt;
}
pkt = pkt->inner_pdu();
}
return pkt;
}
template <NetFilterQueueCallback callback_func>
class NetfilterQueue {
public:
size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2);
char *buf = NULL;
unsigned int portid;
u_int16_t queue_num;
struct mnl_socket* nl = NULL;
NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) {
nl = mnl_socket_open(NETLINK_NETFILTER);
if (nl == NULL) { throw std::runtime_error( "mnl_socket_open" );}
if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
mnl_socket_close(nl);
throw std::runtime_error( "mnl_socket_bind" );
}
portid = mnl_socket_get_portid(nl);
buf = (char*) malloc(BUF_SIZE);
if (!buf) {
mnl_socket_close(nl);
throw std::runtime_error( "allocate receive buffer" );
}
if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) {
_clear();
throw std::runtime_error( "mnl_socket_send" );
}
//TEST if BIND was successful
if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND A NONE cmmand to generate an error meessage
_clear();
throw std::runtime_error( "mnl_socket_send" );
}
if (recv_packet() == -1) { //RECV the error message
_clear();
throw std::runtime_error( "mnl_socket_recvfrom" );
}
struct nlmsghdr *nlh = (struct nlmsghdr *) buf;
if (nlh->nlmsg_type != NLMSG_ERROR) {
_clear();
throw std::runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" );
}
//nfqnl_msg_config_cmd
nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh);
// error code taken from the linux kernel:
// https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27
#define ENOTSUPP 524 /* Operation is not supported */
if (error_msg->error != -ENOTSUPP) {
_clear();
throw std::invalid_argument( "queueid is already busy" );
}
//END TESTING BIND
nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff);
mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO));
mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO));
if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
_clear();
throw std::runtime_error( "mnl_socket_send" );
}
}
void run(){
/*
* ENOBUFS is signalled to userspace when packets were lost
* on kernel side. In most cases, userspace isn't interested
* in this information, so turn it off.
*/
int ret = 1;
mnl_socket_setsockopt(nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int));
for (;;) {
ret = recv_packet();
if (ret == -1) {
throw std::runtime_error( "mnl_socket_recvfrom" );
}
ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, nl);
if (ret < 0){
throw std::runtime_error( "mnl_cb_run" );
}
}
}
~NetfilterQueue() {
send_config_cmd(NFQNL_CFG_CMD_UNBIND);
_clear();
}
private:
ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){
struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd);
return mnl_socket_sendto(nl, nlh, nlh->nlmsg_len);
}
ssize_t recv_packet(){
return mnl_socket_recvfrom(nl, buf, BUF_SIZE);
}
void _clear(){
if (buf != NULL) {
free(buf);
buf = NULL;
}
mnl_socket_close(nl);
}
static int queue_cb(const struct nlmsghdr *nlh, void *data)
{
struct mnl_socket* nl = (struct mnl_socket*)data;
//Extract attributes from the nlmsghdr
struct nlattr *attr[NFQA_MAX+1] = {};
if (nfq_nlmsg_parse(nlh, attr) < 0) {
perror("problems parsing");
return MNL_CB_ERROR;
}
if (attr[NFQA_PACKET_HDR] == NULL) {
fputs("metaheader not set\n", stderr);
return MNL_CB_ERROR;
}
//Get Payload
uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
void *payload = mnl_attr_get_payload(attr[NFQA_PAYLOAD]);
//Return result to the kernel
struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh);
char buf[MNL_SOCKET_BUFFER_SIZE];
struct nlmsghdr *nlh_verdict;
struct nlattr *nest;
nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id));
/*
This define allow to avoid to allocate new heap memory for each packet.
The code under this comment is replicated for ipv6 and ip
Better solutions are welcome. :)
*/
#define PKT_HANDLE \
Tins::PDU *transport_layer = find_transport_layer(&packet); \
if(transport_layer->inner_pdu() == nullptr || transport_layer == nullptr){ \
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
}else{ \
int size = transport_layer->inner_pdu()->size(); \
if(callback_func((const uint8_t*)payload+plen - size, size)){ \
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
} else{ \
if (transport_layer->pdu_type() == Tins::PDU::TCP){ \
((Tins::TCP *)transport_layer)->release_inner_pdu(); \
((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::FIN,1); \
((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::ACK,1); \
((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::SYN,0); \
nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); \
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
}else{ \
nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP ); \
} \
} \
}
// Check IP protocol version
if ( (((uint8_t*)payload)[0] & 0xf0) == 0x40 ){
Tins::IP packet = Tins::IP((uint8_t*)payload,plen);
PKT_HANDLE
}else{
Tins::IPv6 packet = Tins::IPv6((uint8_t*)payload,plen);
PKT_HANDLE
}
/* example to set the connmark. First, start NFQA_CT section: */
nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT);
/* then, add the connmark attribute: */
mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42));
/* more conntrack attributes, e.g. CTA_LABELS could be set here */
/* end conntrack section */
mnl_attr_nest_end(nlh_verdict, nest);
if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) {
throw std::runtime_error( "mnl_socket_send" );
}
return MNL_CB_OK;
}
};
template <NetFilterQueueCallback func>
class NFQueueSequence{
private:
std::vector<NetfilterQueue<func> *> nfq;
uint16_t _init;
uint16_t _end;
std::vector<std::thread> threads;
public:
static const int QUEUE_BASE_NUM = 1000;
NFQueueSequence(uint16_t seq_len){
if (seq_len <= 0) throw std::invalid_argument("seq_len <= 0");
nfq = std::vector<NetfilterQueue<func>*>(seq_len);
_init = QUEUE_BASE_NUM;
while(nfq[0] == NULL){
if (_init+seq_len-1 >= 65536){
throw std::runtime_error("NFQueueSequence: too many queues!");
}
for (int i=0;i<seq_len;i++){
try{
nfq[i] = new NetfilterQueue<func>(_init+i);
}catch(const std::invalid_argument e){
for(int j = 0; j < i; j++) {
delete nfq[j];
nfq[j] = nullptr;
}
_init += seq_len - i;
break;
}
}
}
_end = _init + seq_len - 1;
}
void start(){
if (threads.size() != 0) throw std::runtime_error("NFQueueSequence: already started!");
for (int i=0;i<nfq.size();i++){
threads.push_back(std::thread(&NetfilterQueue<func>::run, nfq[i]));
}
}
void join(){
for (int i=0;i<nfq.size();i++){
threads[i].join();
}
threads.clear();
}
uint16_t init(){
return _init;
}
uint16_t end(){
return _end;
}
~NFQueueSequence(){
for (int i=0;i<nfq.size();i++){
delete nfq[i];
}
}
};
#endif // NETFILTER_CLASSES_HPP

View File

@@ -1,95 +0,0 @@
#include <iostream>
#include <cstring>
#include <jpcre2.hpp>
#include <sstream>
#include "../utils.hpp"
#ifndef REGEX_FILTER_HPP
#define REGEX_FILTER_HPP
typedef jpcre2::select<char> jp;
typedef std::pair<std::string,jp::Regex> regex_rule_pair;
typedef std::vector<regex_rule_pair> regex_rule_vector;
struct regex_rules{
regex_rule_vector output_whitelist, input_whitelist, output_blacklist, input_blacklist;
regex_rule_vector* getByCode(char code){
switch(code){
case 'C': // Client to server Blacklist
return &input_blacklist; break;
case 'c': // Client to server Whitelist
return &input_whitelist; break;
case 'S': // Server to client Blacklist
return &output_blacklist; break;
case 's': // Server to client Whitelist
return &output_whitelist; break;
}
throw std::invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
}
int add(const char* arg){
//Integrity checks
size_t arg_len = strlen(arg);
if (arg_len < 2 || arg_len%2 != 0){
std::cerr << "[warning] [regex_rules.add] invalid arg passed (" << arg << "), skipping..." << std::endl;
return -1;
}
if (arg[0] != '0' && arg[0] != '1'){
std::cerr << "[warning] [regex_rules.add] invalid is_case_sensitive (" << arg[0] << ") in '" << arg << "', must be '1' or '0', skipping..." << std::endl;
return -1;
}
if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's'){
std::cerr << "[warning] [regex_rules.add] invalid filter_type (" << arg[1] << ") in '" << arg << "', must be 'C', 'c', 'S' or 's', skipping..." << std::endl;
return -1;
}
std::string hex(arg+2), expr;
if (!unhexlify(hex, expr)){
std::cerr << "[warning] [regex_rules.add] invalid hex regex value (" << hex << "), skipping..." << std::endl;
return -1;
}
//Push regex
jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
if (regex){
std::cerr << "[info] [regex_rules.add] adding new regex filter: '" << expr << "'" << std::endl;
getByCode(arg[1])->push_back(std::make_pair(std::string(arg), regex));
} else {
std::cerr << "[warning] [regex_rules.add] compiling of '" << expr << "' regex failed, skipping..." << std::endl;
return -1;
}
return 0;
}
bool check(unsigned char* data, const size_t& bytes_transferred, const bool in_input){
std::string str_data((char *) data, bytes_transferred);
for (regex_rule_pair ele:(in_input?input_blacklist:output_blacklist)){
try{
if(ele.second.match(str_data)){
std::stringstream msg;
msg << "BLOCKED " << ele.first << "\n";
std::cout << msg.str() << std::flush;
return false;
}
} catch(...){
std::cerr << "[info] [regex_rules.check] Error while matching blacklist regex: " << ele.first << std::endl;
}
}
for (regex_rule_pair ele:(in_input?input_whitelist:output_whitelist)){
try{
std::cerr << "[debug] [regex_rules.check] regex whitelist match " << ele.second.getPattern() << std::endl;
if(!ele.second.match(str_data)){
std::stringstream msg;
msg << "BLOCKED " << ele.first << "\n";
std::cout << msg.str() << std::flush;
return false;
}
} catch(...){
std::cerr << "[info] [regex_rules.check] Error while matching whitelist regex: " << ele.first << std::endl;
}
}
return true;
}
};
#endif // REGEX_FILTER_HPP

View File

@@ -0,0 +1,174 @@
#include <iostream>
#include <cstring>
#include <sstream>
#include "../utils.hpp"
#include <vector>
#include <hs.h>
using namespace std;
#ifndef REGEX_FILTER_HPP
#define REGEX_FILTER_HPP
enum FilterDirection{ CTOS, STOC };
struct decoded_regex {
string regex;
FilterDirection direction;
bool is_case_sensitive;
};
struct regex_ruleset {
hs_database_t* hs_db = nullptr;
vector<string> regexes;
};
decoded_regex decode_regex(string regex){
size_t arg_len = regex.size();
if (arg_len < 2 || arg_len%2 != 0){
cerr << "[warning] [decode_regex] invalid arg passed (" << regex << "), skipping..." << endl;
throw runtime_error( "Invalid expression len (too small)" );
}
if (regex[0] != '0' && regex[0] != '1'){
cerr << "[warning] [decode_regex] invalid is_case_sensitive (" << regex[0] << ") in '" << regex << "', must be '1' or '0', skipping..." << endl;
throw runtime_error( "Invalid is_case_sensitive" );
}
if (regex[1] != 'C' && regex[1] != 'S'){
cerr << "[warning] [decode_regex] invalid filter_direction (" << regex[1] << ") in '" << regex << "', must be 'C', 'S', skipping..." << endl;
throw runtime_error( "Invalid filter_direction" );
}
string hex(regex.c_str()+2), expr;
if (!unhexlify(hex, expr)){
cerr << "[warning] [decode_regex] invalid hex regex value (" << hex << "), skipping..." << endl;
throw runtime_error( "Invalid hex regex encoded value" );
}
decoded_regex ruleset{
regex: expr,
direction: regex[1] == 'C' ? CTOS : STOC,
is_case_sensitive: regex[0] == '1'
};
return ruleset;
}
class RegexRules{
public:
regex_ruleset output_ruleset, input_ruleset;
private:
static inline u_int16_t glob_seq = 0;
u_int16_t version;
vector<pair<string, decoded_regex>> decoded_input_rules;
vector<pair<string, decoded_regex>> decoded_output_rules;
bool is_stream = true;
void free_dbs(){
if (output_ruleset.hs_db != nullptr){
hs_free_database(output_ruleset.hs_db);
output_ruleset.hs_db = nullptr;
}
if (input_ruleset.hs_db != nullptr){
hs_free_database(input_ruleset.hs_db);
input_ruleset.hs_db = nullptr;
}
}
void fill_ruleset(vector<pair<string, decoded_regex>> & decoded, regex_ruleset & ruleset){
size_t n_of_regex = decoded.size();
if (n_of_regex == 0){
return;
}
vector<const char*> regex_match_rules(n_of_regex);
vector<unsigned int> regex_array_ids(n_of_regex);
vector<unsigned int> regex_flags(n_of_regex);
for(int i = 0; i < n_of_regex; i++){
regex_match_rules[i] = decoded[i].second.regex.c_str();
regex_array_ids[i] = i;
regex_flags[i] = HS_FLAG_SINGLEMATCH | HS_FLAG_ALLOWEMPTY;
if (!decoded[i].second.is_case_sensitive){
regex_flags[i] |= HS_FLAG_CASELESS;
}
}
#ifdef DEBUG
cerr << "[DEBUG] [RegexRules.fill_ruleset] compiling " << n_of_regex << " regexes..." << endl;
for (int i = 0; i < n_of_regex; i++){
cerr << "[DEBUG] [RegexRules.fill_ruleset] regex[" << i << "]: " << decoded[i].first << " " << decoded[i].second.regex << endl;
cerr << "[DEBUG] [RegexRules.fill_ruleset] regex_match_rules[" << i << "]: " << regex_match_rules[i] << endl;
cerr << "[DEBUG] [RegexRules.fill_ruleset] regex_flags[" << i << "]: " << regex_flags[i] << endl;
cerr << "[DEBUG] [RegexRules.fill_ruleset] regex_array_ids[" << i << "]: " << regex_array_ids[i] << endl;
}
#endif
hs_database_t* rebuilt_db = nullptr;
hs_compile_error_t *compile_err = nullptr;
if (
hs_compile_multi(
regex_match_rules.data(),
regex_flags.data(),
regex_array_ids.data(),
n_of_regex,
is_stream?HS_MODE_STREAM:HS_MODE_BLOCK,
nullptr, &rebuilt_db, &compile_err
) != HS_SUCCESS
) {
cerr << "[warning] [RegexRules.fill_ruleset] hs_db failed to compile: '" << compile_err->message << "' skipping..." << endl;
hs_free_compile_error(compile_err);
throw runtime_error( "Failed to compile hyperscan db" );
}
ruleset.hs_db = rebuilt_db;
ruleset.regexes = vector<string>(n_of_regex);
for(int i = 0; i < n_of_regex; i++){
ruleset.regexes[i] = decoded[i].first;
}
}
public:
RegexRules(vector<string> raw_rules, bool is_stream){
this->is_stream = is_stream;
this->version = ++glob_seq; // 0 version is a invalid version (useful for some logics)
for(string ele : raw_rules){
try{
decoded_regex rule = decode_regex(ele);
if (rule.direction == FilterDirection::CTOS){
decoded_input_rules.push_back(make_pair(ele, rule));
}else{
decoded_output_rules.push_back(make_pair(ele, rule));
}
}catch(...){
throw current_exception();
}
}
fill_ruleset(decoded_input_rules, input_ruleset);
try{
fill_ruleset(decoded_output_rules, output_ruleset);
}catch(...){
free_dbs();
throw current_exception();
}
}
u_int16_t ver(){
return version;
}
RegexRules(bool is_stream){
vector<string> no_rules;
RegexRules(no_rules, is_stream);
}
bool stream_mode(){
return is_stream;
}
RegexRules(){
RegexRules(true);
}
~RegexRules(){
free_dbs();
}
};
#endif // REGEX_FILTER_HPP

View File

@@ -1,11 +1,11 @@
#include "classes/regex_filter.hpp"
#include "classes/netfilter.hpp"
#include "classes/regex_rules.cpp"
#include "classes/netfilter.cpp"
#include "utils.hpp"
#include <iostream>
using namespace std;
shared_ptr<regex_rules> regex_config;
shared_ptr<RegexRules> regex_config;
void config_updater (){
string line;
@@ -21,44 +21,158 @@ void config_updater (){
}
cerr << "[info] [updater] Updating configuration with line " << line << endl;
istringstream config_stream(line);
regex_rules *regex_new_config = new regex_rules();
vector<string> raw_rules;
while(!config_stream.eof()){
string data;
config_stream >> data;
if (data != "" && data != "\n"){
regex_new_config->add(data.c_str());
raw_rules.push_back(data);
}
}
regex_config.reset(regex_new_config);
cerr << "[info] [updater] Config update done" << endl;
try{
regex_config.reset(new RegexRules(raw_rules, regex_config->stream_mode()));
cerr << "[info] [updater] Config update done to ver "<< regex_config->ver() << endl;
cout << "ACK OK" << endl;
}catch(const std::exception& e){
cerr << "[error] [updater] Failed to build new configuration!" << endl;
cout << "ACK FAIL " << e.what() << endl;
}
}
}
template <bool is_input>
bool filter_callback(const uint8_t *data, uint32_t len){
shared_ptr<regex_rules> current_config = regex_config;
return current_config->check((unsigned char *)data, len, is_input);
void inline scratch_setup(regex_ruleset &conf, hs_scratch_t* & scratch){
if (scratch == nullptr && conf.hs_db != nullptr){
if (hs_alloc_scratch(conf.hs_db, &scratch) != HS_SUCCESS) {
throw invalid_argument("Cannot alloc scratch");
}
}
}
int main(int argc, char *argv[])
{
struct matched_data{
unsigned int matched = 0;
bool has_matched = false;
};
bool filter_callback(packet_info& info){
shared_ptr<RegexRules> conf = regex_config;
auto current_version = conf->ver();
if (current_version != info.sctx->latest_config_ver){
#ifdef DEBUG
cerr << "[DEBUG] [filter_callback] Configuration has changed (" << current_version << "!=" << info.sctx->latest_config_ver << "), cleaning scratch spaces" << endl;
#endif
info.sctx->clean();
info.sctx->latest_config_ver = current_version;
}
scratch_setup(conf->input_ruleset, info.sctx->in_scratch);
scratch_setup(conf->output_ruleset, info.sctx->out_scratch);
hs_database_t* regex_matcher = info.is_input ? conf->input_ruleset.hs_db : conf->output_ruleset.hs_db;
if (regex_matcher == nullptr){
return true;
}
#ifdef DEBUG
cerr << "[DEBUG] [filter_callback] Matching packet with " << (info.is_input ? "input" : "output") << " ruleset" << endl;
if (info.payload.size() <= 30){
cerr << "[DEBUG] [filter_callback] Packet: " << info.payload << endl;
}
#endif
matched_data match_res;
hs_error_t err;
hs_scratch_t* scratch_space = info.is_input ? info.sctx->in_scratch: info.sctx->out_scratch;
auto match_func = [](unsigned int id, auto from, auto to, auto flags, auto ctx){
auto res = (matched_data*)ctx;
res->has_matched = true;
res->matched = id;
return -1; // Stop matching
};
hs_stream_t* stream_match;
if (conf->stream_mode()){
matching_map* match_map = info.is_input ? &info.sctx->in_hs_streams : &info.sctx->out_hs_streams;
#ifdef DEBUG
cerr << "[DEBUG] [filter_callback] Dumping match_map " << match_map << endl;
for (auto ele: *match_map){
cerr << "[DEBUG] [filter_callback] " << ele.first << " -> " << ele.second << endl;
}
cerr << "[DEBUG] [filter_callback] End of match_map" << endl;
#endif
auto stream_search = match_map->find(info.sid);
if (stream_search == match_map->end()){
#ifdef DEBUG
cerr << "[DEBUG] [filter_callback] Creating new stream matcher for " << info.sid << endl;
#endif
if (hs_open_stream(regex_matcher, 0, &stream_match) != HS_SUCCESS) {
cerr << "[error] [filter_callback] Error opening the stream matcher (hs)" << endl;
throw invalid_argument("Cannot open stream match on hyperscan");
}
if (info.is_tcp){
match_map->insert_or_assign(info.sid, stream_match);
}
}else{
stream_match = stream_search->second;
}
#ifdef DEBUG
cerr << "[DEBUG] [filter_callback] Matching as a stream" << endl;
#endif
err = hs_scan_stream(
stream_match,info.payload.c_str(), info.payload.length(),
0, scratch_space, match_func, &match_res
);
}else{
#ifdef DEBUG
cerr << "[DEBUG] [filter_callback] Matching as a block" << endl;
#endif
err = hs_scan(
regex_matcher,info.payload.c_str(), info.payload.length(),
0, scratch_space, match_func, &match_res
);
}
if (
!info.is_tcp && conf->stream_mode() &&
hs_close_stream(stream_match, scratch_space, nullptr, nullptr) != HS_SUCCESS
){
cerr << "[error] [filter_callback] Error closing the stream matcher (hs)" << endl;
throw invalid_argument("Cannot close stream match on hyperscan");
}
if (err != HS_SUCCESS && err != HS_SCAN_TERMINATED) {
cerr << "[error] [filter_callback] Error while matching the stream (hs)" << endl;
throw invalid_argument("Error while matching the stream with hyperscan");
}
if (match_res.has_matched){
auto rules_vector = info.is_input ? conf->input_ruleset.regexes : conf->output_ruleset.regexes;
stringstream msg;
msg << "BLOCKED " << rules_vector[match_res.matched] << "\n";
cout << msg.str() << flush;
return false;
}
return true;
}
int main(int argc, char *argv[]){
int n_of_threads = 1;
char * n_threads_str = getenv("NTHREADS");
if (n_threads_str != NULL) n_of_threads = ::atoi(n_threads_str);
if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str);
if(n_of_threads <= 0) n_of_threads = 1;
if (n_of_threads % 2 != 0 ) n_of_threads++;
cerr << "[info] [main] Using " << n_of_threads << " threads" << endl;
regex_config.reset(new regex_rules());
NFQueueSequence<filter_callback<true>> input_queues(n_of_threads/2);
input_queues.start();
NFQueueSequence<filter_callback<false>> output_queues(n_of_threads/2);
output_queues.start();
cout << "QUEUES INPUT " << input_queues.init() << " " << input_queues.end() << " OUTPUT " << output_queues.init() << " " << output_queues.end() << endl;
cerr << "[info] [main] Input queues: " << input_queues.init() << ":" << input_queues.end() << " threads assigned: " << n_of_threads/2 << endl;
cerr << "[info] [main] Output queues: " << output_queues.init() << ":" << output_queues.end() << " threads assigned: " << n_of_threads/2 << endl;
char * matchmode = getenv("MATCH_MODE");
bool stream_mode = true;
if (matchmode != nullptr && strcmp(matchmode, "block") == 0){
stream_mode = false;
}
regex_config.reset(new RegexRules(stream_mode));
NFQueueSequence<filter_callback> queues(n_of_threads);
queues.start();
cout << "QUEUES " << queues.init() << " " << queues.end() << endl;
cerr << "[info] [main] Queues: " << queues.init() << ":" << queues.end() << " threads assigned: " << n_of_threads << " stream mode: " << stream_mode << endl;
config_updater();
}

View File

@@ -1,32 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "atomic_refcell"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c"
[[package]]
name = "libc"
version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]]
name = "nfq"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9c8f4c88952507d9df9400a6a2e48640fb460e21dcb2b4716eb3ff156d6db9e"
dependencies = [
"libc",
]
[[package]]
name = "nfqueue_regex"
version = "0.1.0"
dependencies = [
"atomic_refcell",
"nfq",
]

View File

@@ -1,11 +0,0 @@
[package]
name = "nfqueue_regex"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
atomic_refcell = "0.1.13"
nfq = "0.2.5"
#hyperscan = "0.3.2"

View File

@@ -1,150 +0,0 @@
use atomic_refcell::AtomicRefCell;
use nfq::{Queue, Verdict};
use std::cell::{Cell, RefCell};
use std::env;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::{AtomicPtr, AtomicU32};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::Arc;
use std::thread::{self, sleep, sleep_ms, JoinHandle};
enum WorkerMessage {
Error(String),
Dropped(usize),
}
impl ToString for WorkerMessage {
fn to_string(&self) -> String {
match self {
WorkerMessage::Error(e) => format!("E{}", e),
WorkerMessage::Dropped(d) => format!("D{}", d),
}
}
}
struct Pool {
_workers: Vec<Worker>,
pub start: u16,
pub end: u16,
}
const QUEUE_BASE_NUM: u16 = 1000;
impl Pool {
fn new(threads: u16, tx: Sender<WorkerMessage>, db: RefCell<&str>) -> Self {
// Find free queues
let mut start = QUEUE_BASE_NUM;
let mut queues: Vec<(Queue, u16)> = vec![];
while queues.len() != threads.into() {
for queue_num in
(start..start.checked_add(threads + 1).expect("No more queues left")).rev()
{
let mut queue = Queue::open().unwrap();
if queue.bind(queue_num).is_err() {
start = queue_num;
while let Some((mut q, num)) = queues.pop() {
let _ = q.unbind(num);
}
break;
};
queues.push((queue, queue_num));
}
}
Pool {
_workers: queues
.into_iter()
.map(|(queue, queue_num)| Worker::new(queue, queue_num, tx.clone()))
.collect(),
start,
end: (start + threads),
}
}
// fn join(self) {
// for worker in self._workers {
// let _ = worker.join();
// }
// }
}
struct Worker {
_inner: JoinHandle<()>,
}
impl Worker {
fn new(mut queue: Queue, _queue_num: u16, tx: Sender<WorkerMessage>) -> Self {
Worker {
_inner: thread::spawn(move || loop {
let mut msg = queue.recv().unwrap_or_else(|_| {
let _ = tx.send(WorkerMessage::Error("Fuck".to_string()));
panic!("");
});
msg.set_verdict(Verdict::Accept);
queue.verdict(msg).unwrap();
}),
}
}
}
struct InputOuputPools {
pub output_queue: Pool,
pub input_queue: Pool,
rx: Receiver<WorkerMessage>,
}
impl InputOuputPools {
fn new(threads: u16) -> InputOuputPools {
let (tx, rx) = mpsc::channel();
InputOuputPools {
output_queue: Pool::new(threads / 2, tx.clone(), RefCell::new("ciao")),
input_queue: Pool::new(threads / 2, tx, RefCell::new("miao")),
rx,
}
}
fn poll_events(&self) {
loop {
let event = self.rx.recv().expect("Channel has hung up");
println!("{}", event.to_string());
}
}
}
static mut DB: AtomicPtr<Arc<u32>> = AtomicPtr::new(std::ptr::null_mut() as *mut Arc<u32>);
fn main() -> std::io::Result<()> {
let mut my_x: Arc<u32> = Arc::new(0);
let my_x_ptr: *mut Arc<u32> = std::ptr::addr_of_mut!(my_x);
unsafe { DB.store(my_x_ptr, std::sync::atomic::Ordering::SeqCst) };
thread::spawn(|| loop {
let x_ptr = unsafe { DB.load(std::sync::atomic::Ordering::SeqCst) };
let x = unsafe { (*x_ptr).clone() };
dbg!(x);
//sleep_ms(1000);
});
for i in 0..1000000000 {
let mut my_x: Arc<u32> = Arc::new(i);
let my_x_ptr: *mut Arc<u32> = std::ptr::addr_of_mut!(my_x);
unsafe { DB.store(my_x_ptr, std::sync::atomic::Ordering::SeqCst) };
//sleep_ms(100);
}
let mut threads = env::var("NPROCS").unwrap_or_default().parse().unwrap_or(2);
if threads % 2 != 0 {
threads += 1;
}
let in_out_pools = InputOuputPools::new(threads);
eprintln!(
"[info] [main] Input queues: {}:{}",
in_out_pools.input_queue.start, in_out_pools.input_queue.end
);
eprintln!(
"[info] [main] Output queues: {}:{}",
in_out_pools.output_queue.start, in_out_pools.output_queue.end
);
in_out_pools.poll_events();
Ok(())
}

View File

@@ -1,493 +0,0 @@
/*
Copyright (c) 2007 Arash Partow (http://www.partow.net)
URL: http://www.partow.net/programming/tcpproxy/index.html
Modified and adapted by Pwnzer0tt1
*/
#include <cstdlib>
#include <cstddef>
#include <iostream>
#include <string>
#include <mutex>
#include <boost/thread.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/bind/bind.hpp>
#include <boost/asio.hpp>
#include <boost/thread/mutex.hpp>
#include <jpcre2.hpp>
typedef jpcre2::select<char> jp;
using namespace std;
bool unhexlify(string const &hex, string &newString) {
try{
int len = hex.length();
for(int i=0; i< len; i+=2)
{
std::string byte = hex.substr(i,2);
char chr = (char) (int)strtol(byte.c_str(), NULL, 16);
newString.push_back(chr);
}
return true;
}
catch (...){
return false;
}
}
typedef pair<string,jp::Regex> regex_rule_pair;
typedef vector<regex_rule_pair> regex_rule_vector;
struct regex_rules{
regex_rule_vector regex_s_c_w, regex_c_s_w, regex_s_c_b, regex_c_s_b;
regex_rule_vector* getByCode(char code){
switch(code){
case 'C': // Client to server Blacklist
return &regex_c_s_b; break;
case 'c': // Client to server Whitelist
return &regex_c_s_w; break;
case 'S': // Server to client Blacklist
return &regex_s_c_b; break;
case 's': // Server to client Whitelist
return &regex_s_c_w; break;
}
throw invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
}
void add(const char* arg){
//Integrity checks
size_t arg_len = strlen(arg);
if (arg_len < 2 || arg_len%2 != 0) return;
if (arg[0] != '0' && arg[0] != '1') return;
if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's') return;
string hex(arg+2), expr;
if (!unhexlify(hex, expr)) return;
//Push regex
jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
if (regex){
#ifdef DEBUG
cerr << "Added regex " << expr << " " << arg << endl;
#endif
getByCode(arg[1])->push_back(make_pair(string(arg), regex));
} else {
cerr << "Regex " << arg << " was not compiled successfully" << endl;
}
}
};
shared_ptr<regex_rules> regex_config;
mutex update_mutex;
bool filter_data(unsigned char* data, const size_t& bytes_transferred, regex_rule_vector const &blacklist, regex_rule_vector const &whitelist){
#ifdef DEBUG_PACKET
cerr << "---------------- Packet ----------------" << endl;
for(int i=0;i<bytes_transferred;i++) cerr << data[i];
cerr << endl;
for(int i=0;i<bytes_transferred;i++) fprintf(stderr, "%x", data[i]);
cerr << endl;
cerr << "---------------- End Packet ----------------" << endl;
#endif
string str_data((char *) data, bytes_transferred);
for (regex_rule_pair ele:blacklist){
try{
if(ele.second.match(str_data)){
stringstream msg;
msg << "BLOCKED " << ele.first << endl;
cout << msg.str() << std::flush;
return false;
}
} catch(...){
cerr << "Error while matching regex: " << ele.first << endl;
}
}
for (regex_rule_pair ele:whitelist){
try{
if(!ele.second.match(str_data)){
stringstream msg;
msg << "BLOCKED " << ele.first << endl;
cout << msg.str() << std::flush;
return false;
}
} catch(...){
cerr << "Error while matching regex: " << ele.first << endl;
}
}
#ifdef DEBUG
cerr << "Packet Accepted!" << endl;
#endif
return true;
}
namespace tcp_proxy
{
namespace ip = boost::asio::ip;
class bridge : public boost::enable_shared_from_this<bridge>
{
public:
typedef ip::tcp::socket socket_type;
typedef boost::shared_ptr<bridge> ptr_type;
bridge(boost::asio::io_context& ios)
: downstream_socket_(ios),
upstream_socket_ (ios),
thread_safety(ios)
{}
socket_type& downstream_socket()
{
// Client socket
return downstream_socket_;
}
socket_type& upstream_socket()
{
// Remote server socket
return upstream_socket_;
}
void start(const string& upstream_host, unsigned short upstream_port)
{
// Attempt connection to remote server (upstream side)
upstream_socket_.async_connect(
ip::tcp::endpoint(
boost::asio::ip::address::from_string(upstream_host),
upstream_port),
boost::asio::bind_executor(thread_safety,
boost::bind(
&bridge::handle_upstream_connect,
shared_from_this(),
boost::asio::placeholders::error)));
}
void handle_upstream_connect(const boost::system::error_code& error)
{
if (!error)
{
// Setup async read from remote server (upstream)
upstream_socket_.async_read_some(
boost::asio::buffer(upstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_upstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
// Setup async read from client (downstream)
downstream_socket_.async_read_some(
boost::asio::buffer(downstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_downstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
close();
}
private:
/*
Section A: Remote Server --> Proxy --> Client
Process data recieved from remote sever then send to client.
*/
// Read from remote server complete, now send data to client
void handle_upstream_read(const boost::system::error_code& error,
const size_t& bytes_transferred) // Da Server a Client
{
if (!error)
{
shared_ptr<regex_rules> regex_old_config = regex_config;
if (filter_data(upstream_data_, bytes_transferred, regex_old_config->regex_s_c_b, regex_old_config->regex_s_c_w)){
async_write(downstream_socket_,
boost::asio::buffer(upstream_data_,bytes_transferred),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_downstream_write,
shared_from_this(),
boost::asio::placeholders::error)));
}else{
close();
}
}
else
close();
}
// Write to client complete, Async read from remote server
void handle_downstream_write(const boost::system::error_code& error)
{
if (!error)
{
upstream_socket_.async_read_some(
boost::asio::buffer(upstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_upstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
close();
}
// *** End Of Section A ***
/*
Section B: Client --> Proxy --> Remove Server
Process data recieved from client then write to remove server.
*/
// Read from client complete, now send data to remote server
void handle_downstream_read(const boost::system::error_code& error,
const size_t& bytes_transferred) // Da Client a Server
{
if (!error)
{
shared_ptr<regex_rules> regex_old_config = regex_config;
if (filter_data(downstream_data_, bytes_transferred, regex_old_config->regex_c_s_b, regex_old_config->regex_c_s_w)){
async_write(upstream_socket_,
boost::asio::buffer(downstream_data_,bytes_transferred),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_upstream_write,
shared_from_this(),
boost::asio::placeholders::error)));
}else{
close();
}
}
else
close();
}
// Write to remote server complete, Async read from client
void handle_upstream_write(const boost::system::error_code& error)
{
if (!error)
{
downstream_socket_.async_read_some(
boost::asio::buffer(downstream_data_,max_data_length),
boost::asio::bind_executor(thread_safety,
boost::bind(&bridge::handle_downstream_read,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
close();
}
// *** End Of Section B ***
void close()
{
boost::mutex::scoped_lock lock(mutex_);
if (downstream_socket_.is_open())
{
downstream_socket_.close();
}
if (upstream_socket_.is_open())
{
upstream_socket_.close();
}
}
socket_type downstream_socket_;
socket_type upstream_socket_;
enum { max_data_length = 8192 }; //8KB
unsigned char downstream_data_[max_data_length];
unsigned char upstream_data_ [max_data_length];
boost::asio::io_context::strand thread_safety;
boost::mutex mutex_;
public:
class acceptor
{
public:
acceptor(boost::asio::io_context& io_context,
const string& local_host, unsigned short local_port,
const string& upstream_host, unsigned short upstream_port)
: io_context_(io_context),
localhost_address(boost::asio::ip::address_v4::from_string(local_host)),
acceptor_(io_context_,ip::tcp::endpoint(localhost_address,local_port)),
upstream_port_(upstream_port),
upstream_host_(upstream_host)
{}
bool accept_connections()
{
try
{
session_ = boost::shared_ptr<bridge>(new bridge(io_context_));
acceptor_.async_accept(session_->downstream_socket(),
boost::asio::bind_executor(session_->thread_safety,
boost::bind(&acceptor::handle_accept,
this,
boost::asio::placeholders::error)));
}
catch(exception& e)
{
cerr << "acceptor exception: " << e.what() << endl;
return false;
}
return true;
}
private:
void handle_accept(const boost::system::error_code& error)
{
if (!error)
{
session_->start(upstream_host_,upstream_port_);
if (!accept_connections())
{
cerr << "Failure during call to accept." << endl;
}
}
else
{
cerr << "Error: " << error.message() << endl;
}
}
boost::asio::io_context& io_context_;
ip::address_v4 localhost_address;
ip::tcp::acceptor acceptor_;
ptr_type session_;
unsigned short upstream_port_;
string upstream_host_;
};
};
}
void update_config (boost::asio::streambuf &input_buffer){
#ifdef DEBUG
cerr << "Updating configuration" << endl;
#endif
std::istream config_stream(&input_buffer);
std::unique_lock<std::mutex> lck(update_mutex);
regex_rules *regex_new_config = new regex_rules();
string data;
while(true){
config_stream >> data;
if (config_stream.eof()) break;
regex_new_config->add(data.c_str());
}
regex_config.reset(regex_new_config);
}
class async_updater
{
public:
async_updater(boost::asio::io_context& io_context) : input_(io_context, ::dup(STDIN_FILENO)), thread_safety(io_context)
{
boost::asio::async_read_until(input_, input_buffer_, '\n',
boost::asio::bind_executor(thread_safety,
boost::bind(&async_updater::on_update, this,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
void on_update(const boost::system::error_code& error, std::size_t length)
{
if (!error)
{
update_config(input_buffer_);
boost::asio::async_read_until(input_, input_buffer_, '\n',
boost::asio::bind_executor(thread_safety,
boost::bind(&async_updater::on_update, this,
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
}
else
{
close();
}
}
void close()
{
input_.close();
}
private:
boost::asio::posix::stream_descriptor input_;
boost::asio::io_context::strand thread_safety;
boost::asio::streambuf input_buffer_;
};
int main(int argc, char* argv[])
{
if (argc < 5)
{
cerr << "usage: tcpproxy_server <local host ip> <local port> <forward host ip> <forward port>" << endl;
return 1;
}
const unsigned short local_port = static_cast<unsigned short>(::atoi(argv[2]));
const unsigned short forward_port = static_cast<unsigned short>(::atoi(argv[4]));
const string local_host = argv[1];
const string forward_host = argv[3];
int threads = 1;
char * n_threads_str = getenv("NTHREADS");
if (n_threads_str != NULL) threads = ::atoi(n_threads_str);
boost::asio::io_context ios;
boost::asio::streambuf buf;
boost::asio::posix::stream_descriptor cin_in(ios, ::dup(STDIN_FILENO));
boost::asio::read_until(cin_in, buf,'\n');
update_config(buf);
async_updater updater(ios);
#ifdef DEBUG
cerr << "Starting Proxy" << endl;
#endif
try
{
tcp_proxy::bridge::acceptor acceptor(ios,
local_host, local_port,
forward_host, forward_port);
acceptor.accept_connections();
if (threads > 1){
boost::thread_group tg;
for (unsigned i = 0; i < threads; ++i)
tg.create_thread(boost::bind(&boost::asio::io_context::run, &ios));
tg.join_all();
}else{
ios.run();
}
}
catch(exception& e)
{
cerr << "Error: " << e.what() << endl;
return 1;
}
#ifdef DEBUG
cerr << "Proxy stopped!" << endl;
#endif
return 0;
}

View File

@@ -10,7 +10,7 @@ bool unhexlify(std::string const &hex, std::string &newString) {
for(int i=0; i< len; i+=2)
{
std::string byte = hex.substr(i,2);
char chr = (char) (int)strtol(byte.c_str(), NULL, 16);
char chr = (char) (int)strtol(byte.c_str(), nullptr, 16);
newString.push_back(chr);
}
return true;

View File

@@ -1,6 +1,6 @@
import asyncio
from modules.firewall.nftables import FiregexTables
from modules.firewall.models import *
from modules.firewall.models import Rule, FirewallSettings
from utils.sqlite import SQLite
from modules.firewall.models import Action
@@ -131,5 +131,5 @@ class FirewallManager:
return self.db.get("allow_dhcp", "1") == "1"
@drop_invalid.setter
def allow_dhcp(self, value):
def allow_dhcp_set(self, value):
self.db.set("allow_dhcp", "1" if value else "0")

View File

@@ -1,8 +1,12 @@
from modules.nfregex.nftables import FiregexTables
from utils import ip_parse, run_func
from utils import run_func
from modules.nfregex.models import Service, Regex
import re, os, asyncio
import re
import os
import asyncio
import traceback
from utils import DEBUG
from fastapi import HTTPException
nft = FiregexTables()
@@ -10,7 +14,6 @@ class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
@@ -19,8 +22,8 @@ class RegexFilter:
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
if input_mode == output_mode:
input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
@@ -32,19 +35,21 @@ class RegexFilter:
def from_regex(cls, regex:Regex, update_func = None):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
update_func = update_func
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
if isinstance(self.regex, str):
self.regex = self.regex.encode()
if not isinstance(self.regex, bytes):
raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if it's invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.input_mode:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
yield case_sensitive + "C" + self.regex.hex()
if self.output_mode:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
yield case_sensitive + "S" + self.regex.hex()
async def update(self):
if self.update_func:
@@ -60,6 +65,10 @@ class FiregexInterceptor:
self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process
self.update_task: asyncio.Task
self.ack_arrived = False
self.ack_status = None
self.ack_fail_what = ""
self.ack_lock = asyncio.Lock()
@classmethod
async def start(cls, srv: Service):
@@ -67,16 +76,19 @@ class FiregexInterceptor:
self.srv = srv
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
input_range, output_range = await self._start_binary()
queue_range = await self._start_binary()
self.update_task = asyncio.create_task(self.update_blocked())
nft.add(self.srv, input_range, output_range)
nft.add(self.srv, queue_range)
if not self.ack_lock.locked():
await self.ack_lock.acquire()
return self
async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../cppqueue")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path,
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
env={"MATCH_MODE": "stream" if self.srv.proto == "tcp" else "block", "NTHREADS": os.getenv("NTHREADS","1")},
)
line_fut = self.process.stdout.readuntil()
try:
@@ -87,7 +99,7 @@ class FiregexInterceptor:
line = line_fut.decode()
if line.startswith("QUEUES "):
params = line.split()
return (int(params[2]), int(params[3])), (int(params[5]), int(params[6]))
return (int(params[1]), int(params[2]))
else:
self.process.kill()
raise Exception("Invalid binary output")
@@ -96,14 +108,24 @@ class FiregexInterceptor:
try:
while True:
line = (await self.process.stdout.readuntil()).decode()
if line.startswith("BLOCKED"):
if DEBUG:
print(line)
if line.startswith("BLOCKED "):
regex_id = line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
await self.filter_map[regex_id].update()
except asyncio.CancelledError: pass
except asyncio.IncompleteReadError: pass
if line.startswith("ACK "):
self.ack_arrived = True
self.ack_status = line.split()[1].upper() == "OK"
if not self.ack_status:
self.ack_fail_what = " ".join(line.split()[2:])
self.ack_lock.release()
except asyncio.CancelledError:
pass
except asyncio.IncompleteReadError:
pass
except Exception:
traceback.print_exc()
@@ -116,6 +138,14 @@ class FiregexInterceptor:
async with self.update_config_lock:
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
try:
async with asyncio.timeout(3):
await self.ack_lock.acquire()
except TimeoutError:
pass
if not self.ack_arrived or not self.ack_status:
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
async def reload(self, filters:list[RegexFilter]):
async with self.filter_map_lock:
@@ -135,6 +165,7 @@ class FiregexInterceptor:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
except Exception:
pass
return res

View File

@@ -30,14 +30,15 @@ class ServiceManager:
new_filters = set([f.id for f in regexes])
#remove old filters
for f in old_filters:
if not f in new_filters:
if f not in new_filters:
del self.filters[f]
#add new filters
for f in new_filters:
if not f in old_filters:
if f not in old_filters:
filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
if self.interceptor: await self.interceptor.reload(self.filters.values())
if self.interceptor:
await self.interceptor.reload(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id)
@@ -114,4 +115,5 @@ class FirewallManager:
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass
class ServiceNotFoundException(Exception):
pass

View File

@@ -15,11 +15,10 @@ class Service:
class Regex:
def __init__(self, regex_id: int, regex: bytes, mode: str, service_id: str, is_blacklist: bool, blocked_packets: int, is_case_sensitive: bool, active: bool, **other):
def __init__(self, regex_id: int, regex: bytes, mode: str, service_id: str, blocked_packets: int, is_case_sensitive: bool, active: bool, **other):
self.regex = regex
self.mode = mode
self.service_id = service_id
self.is_blacklist = is_blacklist
self.blocked_packets = blocked_packets
self.id = regex_id
self.is_case_sensitive = is_case_sensitive

View File

@@ -45,36 +45,37 @@ class FiregexTables(NFTableManager):
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.output_chain}}},
])
def add(self, srv:Service, queue_range_input, queue_range_output):
def add(self, srv:Service, queue_range):
for ele in self.get():
if ele.__eq__(srv): return
init, end = queue_range_output
init, end = queue_range
if init > end: init, end = end, init
self.cmd({ "insert":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": self.output_chain,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
self.cmd(
{ "insert":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": self.output_chain,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
{"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1338}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
]
}}})
init, end = queue_range_input
if init > end: init, end = end, init
self.cmd({"insert":{"rule":{
"family": "inet",
"table": self.table_name,
"chain": self.input_chain,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
]
}}})
}}},
{"insert":{"rule":{
"family": "inet",
"table": self.table_name,
"chain": self.input_chain,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
{'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
{"mangle": {"key": {"meta": {"key": "mark"}},"value": 0x1337}},
{"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
]
}}}
)
def get(self) -> list[FiregexFilter]:

View File

@@ -5,7 +5,8 @@ from utils.sqlite import SQLite
nft = FiregexTables()
class ServiceNotFoundException(Exception): pass
class ServiceNotFoundException(Exception):
pass
class ServiceManager:
def __init__(self, srv: Service, db):
@@ -29,7 +30,8 @@ class ServiceManager:
async def refresh(self, srv:Service):
self.srv = srv
if self.active: await self.restart()
if self.active:
await self.restart()
def _set_status(self,active):
self.active = active

View File

@@ -50,7 +50,8 @@ class FiregexTables(NFTableManager):
def add(self, srv:Service):
for ele in self.get():
if ele.__eq__(srv): return
if ele.__eq__(srv):
return
self.cmd({ "insert":{ "rule": {
"family": "inet",

View File

@@ -1,116 +0,0 @@
import re, os, asyncio
class Filter:
def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if c_to_s == s_to_c: c_to_s = s_to_c = True # (False, False) == (True, True)
self.c_to_s = c_to_s
self.s_to_c = s_to_c
self.blocked = blocked_packets
self.code = code
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if is invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.c_to_s:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
if self.s_to_c:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
class Proxy:
def __init__(self, internal_port=0, public_port=0, callback_blocked_update=None, filters=None, public_host="0.0.0.0", internal_host="127.0.0.1"):
self.filter_map = {}
self.filter_map_lock = asyncio.Lock()
self.update_config_lock = asyncio.Lock()
self.status_change = asyncio.Lock()
self.public_host = public_host
self.public_port = public_port
self.internal_host = internal_host
self.internal_port = internal_port
self.filters = set(filters) if filters else set([])
self.process = None
self.callback_blocked_update = callback_blocked_update
async def start(self, in_pause=False):
await self.status_change.acquire()
if not self.isactive():
try:
self.filter_map = self.compile_filters()
filters_codes = self.get_filter_codes() if not in_pause else []
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../proxy")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port),
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
)
await self.update_config(filters_codes)
finally:
self.status_change.release()
try:
while True:
buff = await self.process.stdout.readuntil()
stdout_line = buff.decode()
if stdout_line.startswith("BLOCKED"):
regex_id = stdout_line.split()[1]
async with self.filter_map_lock:
if regex_id in self.filter_map:
self.filter_map[regex_id].blocked+=1
if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id])
except Exception:
return await self.process.wait()
else:
self.status_change.release()
async def stop(self):
async with self.status_change:
if self.isactive():
self.process.kill()
return False
return True
async def restart(self, in_pause=False):
status = await self.stop()
await self.start(in_pause=in_pause)
return status
async def update_config(self, filters_codes):
async with self.update_config_lock:
if (self.isactive()):
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
async def reload(self):
if self.isactive():
async with self.filter_map_lock:
self.filter_map = self.compile_filters()
filters_codes = self.get_filter_codes()
await self.update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def isactive(self):
return self.process and self.process.returncode is None
async def pause(self):
if self.isactive():
await self.update_config([])
else:
await self.start(in_pause=True)
def compile_filters(self):
res = {}
for filter_obj in self.filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
return res

View File

@@ -1,199 +0,0 @@
import secrets
from modules.regexproxy.proxy import Filter, Proxy
import random, socket, asyncio
from base64 import b64decode
from utils.sqlite import SQLite
from utils import socketio_emit
class STATUS:
WAIT = "wait"
STOP = "stop"
PAUSE = "pause"
ACTIVE = "active"
class ServiceNotFoundException(Exception): pass
class ServiceManager:
def __init__(self, id, db):
self.id = id
self.db = db
self.proxy = Proxy(
internal_host="127.0.0.1",
callback_blocked_update=self._stats_updater
)
self.status = STATUS.STOP
self.wanted_status = STATUS.STOP
self.filters = {}
self._update_port_from_db()
self._update_filters_from_db()
self.lock = asyncio.Lock()
self.starter = None
def _update_port_from_db(self):
res = self.db.query("""
SELECT
public_port,
internal_port
FROM services WHERE service_id = ?;
""", self.id)
if len(res) == 0: raise ServiceNotFoundException()
self.proxy.internal_port = res[0]["internal_port"]
self.proxy.public_port = res[0]["public_port"]
def _update_filters_from_db(self):
res = self.db.query("""
SELECT
regex, mode, regex_id `id`, is_blacklist,
blocked_packets n_packets, is_case_sensitive
FROM regexes WHERE service_id = ? AND active=1;
""", self.id)
#Filter check
old_filters = set(self.filters.keys())
new_filters = set([f["id"] for f in res])
#remove old filters
for f in old_filters:
if not f in new_filters:
del self.filters[f]
for f in new_filters:
if not f in old_filters:
filter_info = [ele for ele in res if ele["id"] == f][0]
self.filters[f] = Filter(
is_case_sensitive=filter_info["is_case_sensitive"],
c_to_s=filter_info["mode"] in ["C","B"],
s_to_c=filter_info["mode"] in ["S","B"],
is_blacklist=filter_info["is_blacklist"],
regex=b64decode(filter_info["regex"]),
blocked_packets=filter_info["n_packets"],
code=f
)
self.proxy.filters = list(self.filters.values())
def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.id)
async def next(self,to):
async with self.lock:
return await self._next(to)
async def _next(self, to):
if self.status != to:
# ACTIVE -> PAUSE or PAUSE -> ACTIVE
if (self.status, to) in [(STATUS.ACTIVE, STATUS.PAUSE)]:
await self.proxy.pause()
self._set_status(to)
elif (self.status, to) in [(STATUS.PAUSE, STATUS.ACTIVE)]:
await self.proxy.reload()
self._set_status(to)
# ACTIVE -> STOP
elif (self.status,to) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy
if self.starter: self.starter.cancel()
await self.proxy.stop()
self._set_status(to)
# STOP -> ACTIVE or STOP -> PAUSE
elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]:
self.wanted_status = to
self._set_status(STATUS.WAIT)
self.__proxy_starter(to)
def _stats_updater(self,filter:Filter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)
async def update_port(self):
async with self.lock:
self._update_port_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
next_status = self.status if self.status != STATUS.WAIT else self.wanted_status
await self._next(STATUS.STOP)
await self._next(next_status)
def _set_status(self,status):
self.status = status
self.__update_status_db(status)
async def update_filters(self):
async with self.lock:
self._update_filters_from_db()
if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
await self.proxy.reload()
def __proxy_starter(self,to):
async def func():
try:
while True:
if check_port_is_open(self.proxy.public_port):
self._set_status(to)
await socketio_emit(["regexproxy"])
await self.proxy.start(in_pause=(to==STATUS.PAUSE))
self._set_status(STATUS.STOP)
await socketio_emit(["regexproxy"])
return
else:
await asyncio.sleep(.5)
except asyncio.CancelledError:
self._set_status(STATUS.STOP)
await self.proxy.stop()
self.starter = asyncio.create_task(func())
class ProxyManager:
def __init__(self, db:SQLite):
self.db = db
self.proxy_table: dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.proxy_table.keys()):
await self.remove(key)
async def remove(self,id):
async with self.lock:
if id in self.proxy_table:
await self.proxy_table[id].next(STATUS.STOP)
del self.proxy_table[id]
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT service_id, status FROM services;'):
srv_id, req_status = srv["service_id"], srv["status"]
if srv_id in self.proxy_table:
continue
self.proxy_table[srv_id] = ServiceManager(srv_id,self.db)
await self.proxy_table[srv_id].next(req_status)
def get(self,id) -> ServiceManager:
if id in self.proxy_table:
return self.proxy_table[id]
else:
raise ServiceNotFoundException()
def check_port_is_open(port):
try:
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('0.0.0.0',port))
sock.close()
return True
except Exception:
return False
def gen_service_id(db):
while True:
res = secrets.token_hex(8)
if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
break
return res
def gen_internal_port(db):
while True:
res = random.randint(30000, 45000)
if len(db.query('SELECT 1 FROM services WHERE internal_port = ?;', res)) == 0:
break
return res

View File

@@ -5,7 +5,7 @@ from utils import ip_parse, ip_family, socketio_emit
from utils.models import ResetRequest, StatusMessageModel
from modules.firewall.nftables import FiregexTables
from modules.firewall.firewall import FirewallManager
from modules.firewall.models import *
from modules.firewall.models import FirewallSettings, RuleInfo, RuleModel, RuleFormAdd, Mode, Table
db = SQLite('db/firewall-rules.db', {
'rules': {

View File

@@ -28,7 +28,6 @@ class RegexModel(BaseModel):
mode:str
id:int
service_id:str
is_blacklist: bool
n_packets:int
is_case_sensitive:bool
active:bool
@@ -38,7 +37,6 @@ class RegexAddForm(BaseModel):
regex: str
mode: str
active: bool|None = None
is_blacklist: bool
is_case_sensitive: bool
class ServiceAddForm(BaseModel):
@@ -66,7 +64,6 @@ db = SQLite('db/nft-regex.db', {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("C", "S", "B"))', # C = to the client, S = to the server, B = both
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
@@ -75,7 +72,7 @@ db = SQLite('db/nft-regex.db', {
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_services ON services (port, ip_int, proto);",
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,mode,is_case_sensitive);"
]
})
@@ -92,12 +89,18 @@ async def reset(params: ResetRequest):
db.init()
else:
db.restore()
await firewall.init()
try:
await firewall.init()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
async def startup():
db.init()
await firewall.init()
try:
await firewall.init()
except Exception as e:
print("WARNING cannot start firewall:", e)
async def shutdown():
db.backup()
@@ -147,7 +150,8 @@ async def get_service_by_id(service_id: str):
FROM services s LEFT JOIN regexes r ON s.service_id = r.service_id
WHERE s.service_id = ? GROUP BY s.service_id;
""", service_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
if len(res) == 0:
raise HTTPException(status_code=400, detail="This service does not exists!")
return res[0]
@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
@@ -177,7 +181,8 @@ async def service_delete(service_id: str):
async def service_rename(service_id: str, form: RenameForm):
"""Request to change the name of a specific service"""
form.name = refactor_name(form.name)
if not form.name: raise HTTPException(status_code=400, detail="The name cannot be empty!")
if not form.name:
raise HTTPException(status_code=400, detail="The name cannot be empty!")
try:
db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
except sqlite3.IntegrityError:
@@ -188,10 +193,11 @@ async def service_rename(service_id: str, form: RenameForm):
@app.get('/service/{service_id}/regexes', response_model=list[RegexModel])
async def get_service_regexe_list(service_id: str):
"""Get the list of the regexes of a service"""
if not db.query("SELECT 1 FROM services s WHERE s.service_id = ?;", service_id): raise HTTPException(status_code=400, detail="This service does not exists!")
if not db.query("SELECT 1 FROM services s WHERE s.service_id = ?;", service_id):
raise HTTPException(status_code=400, detail="This service does not exists!")
return db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
regex, mode, regex_id `id`, service_id,
blocked_packets n_packets, is_case_sensitive, active
FROM regexes WHERE service_id = ?;
""", service_id)
@@ -201,11 +207,12 @@ async def get_regex_by_id(regex_id: int):
"""Get regex info using his id"""
res = db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
regex, mode, regex_id `id`, service_id,
blocked_packets n_packets, is_case_sensitive, active
FROM regexes WHERE `id` = ?;
""", regex_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This regex does not exists!")
if len(res) == 0:
raise HTTPException(status_code=400, detail="This regex does not exists!")
return res[0]
@app.get('/regex/{regex_id}/delete', response_model=StatusMessageModel)
@@ -247,8 +254,8 @@ async def add_new_regex(form: RegexAddForm):
except Exception:
raise HTTPException(status_code=400, detail="Invalid regex")
try:
db.query("INSERT INTO regexes (service_id, regex, is_blacklist, mode, is_case_sensitive, active ) VALUES (?, ?, ?, ?, ?, ?);",
form.service_id, form.regex, form.is_blacklist, form.mode, form.is_case_sensitive, True if form.active is None else form.active )
db.query("INSERT INTO regexes (service_id, regex, mode, is_case_sensitive, active ) VALUES (?, ?, ?, ?, ?);",
form.service_id, form.regex, form.mode, form.is_case_sensitive, True if form.active is None else form.active )
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="An identical regex already exists")

View File

@@ -96,7 +96,8 @@ async def get_service_list():
async def get_service_by_id(service_id: str):
"""Get info about a specific service using his id"""
res = db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_src, ip_dst FROM services WHERE service_id = ?;", service_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
if len(res) == 0:
raise HTTPException(status_code=400, detail="This service does not exists!")
return res[0]
@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
@@ -125,7 +126,8 @@ async def service_delete(service_id: str):
async def service_rename(service_id: str, form: RenameForm):
"""Request to change the name of a specific service"""
form.name = refactor_name(form.name)
if not form.name: raise HTTPException(status_code=400, detail="The name cannot be empty!")
if not form.name:
raise HTTPException(status_code=400, detail="The name cannot be empty!")
try:
db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
except sqlite3.IntegrityError:

View File

@@ -1,311 +0,0 @@
from base64 import b64decode
import sqlite3, re
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from modules.regexproxy.utils import STATUS, ProxyManager, gen_internal_port, gen_service_id
from utils.sqlite import SQLite
from utils.models import ResetRequest, StatusMessageModel
from utils import refactor_name, socketio_emit, PortType
app = APIRouter()
db = SQLite("db/regextcpproxy.db",{
'services': {
'status': 'VARCHAR(100) NOT NULL',
'service_id': 'VARCHAR(100) PRIMARY KEY',
'internal_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536)',
'public_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE',
'name': 'VARCHAR(100) NOT NULL UNIQUE'
},
'regexes': {
'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL',
'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
'regex_id': 'INTEGER PRIMARY KEY',
'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
'active' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1)) DEFAULT 1',
'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
]
})
firewall = ProxyManager(db)
async def reset(params: ResetRequest):
if not params.delete:
db.backup()
await firewall.close()
if params.delete:
db.delete()
db.init()
else:
db.restore()
await firewall.reload()
async def startup():
db.init()
await firewall.reload()
async def shutdown():
db.backup()
await firewall.close()
db.disconnect()
db.restore()
async def refresh_frontend(additional:list[str]=[]):
await socketio_emit(["regexproxy"]+additional)
class GeneralStatModel(BaseModel):
closed:int
regexes: int
services: int
@app.get('/stats', response_model=GeneralStatModel)
async def get_general_stats():
"""Get firegex general status about services"""
return db.query("""
SELECT
(SELECT COALESCE(SUM(blocked_packets),0) FROM regexes) closed,
(SELECT COUNT(*) FROM regexes) regexes,
(SELECT COUNT(*) FROM services) services
""")[0]
class ServiceModel(BaseModel):
id:str
status: str
public_port: PortType
internal_port: PortType
name: str
n_regex: int
n_packets: int
@app.get('/services', response_model=list[ServiceModel])
async def get_service_list():
"""Get the list of existent firegex services"""
return db.query("""
SELECT
s.service_id `id`,
s.status status,
s.public_port public_port,
s.internal_port internal_port,
s.name name,
COUNT(r.regex_id) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id
GROUP BY s.service_id;
""")
@app.get('/service/{service_id}', response_model=ServiceModel)
async def get_service_by_id(service_id: str):
"""Get info about a specific service using his id"""
res = db.query("""
SELECT
s.service_id `id`,
s.status status,
s.public_port public_port,
s.internal_port internal_port,
s.name name,
COUNT(r.regex_id) n_regex,
COALESCE(SUM(r.blocked_packets),0) n_packets
FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ?
GROUP BY s.service_id;
""", service_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
return res[0]
@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
async def service_stop(service_id: str):
"""Request the stop of a specific service"""
await firewall.get(service_id).next(STATUS.STOP)
await refresh_frontend()
return {'status': 'ok'}
@app.get('/service/{service_id}/pause', response_model=StatusMessageModel)
async def service_pause(service_id: str):
"""Request the pause of a specific service"""
await firewall.get(service_id).next(STATUS.PAUSE)
await refresh_frontend()
return {'status': 'ok'}
@app.get('/service/{service_id}/start', response_model=StatusMessageModel)
async def service_start(service_id: str):
"""Request the start of a specific service"""
await firewall.get(service_id).next(STATUS.ACTIVE)
await refresh_frontend()
return {'status': 'ok'}
@app.get('/service/{service_id}/delete', response_model=StatusMessageModel)
async def service_delete(service_id: str):
"""Request the deletion of a specific service"""
db.query('DELETE FROM services WHERE service_id = ?;', service_id)
db.query('DELETE FROM regexes WHERE service_id = ?;', service_id)
await firewall.remove(service_id)
await refresh_frontend()
return {'status': 'ok'}
@app.get('/service/{service_id}/regen-port', response_model=StatusMessageModel)
async def regen_service_port(service_id: str):
"""Request the regeneration of a the internal proxy port of a specific service"""
db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id)
await firewall.get(service_id).update_port()
await refresh_frontend()
return {'status': 'ok'}
class ChangePortForm(BaseModel):
port: int|None = None
internalPort: int|None = None
@app.post('/service/{service_id}/change-ports', response_model=StatusMessageModel)
async def change_service_ports(service_id: str, change_port:ChangePortForm):
"""Choose and change the ports of the service"""
if change_port.port is None and change_port.internalPort is None:
raise HTTPException(status_code=400, detail="Invalid Request!")
try:
sql_inj = ""
query:list[str|int] = []
if not change_port.port is None:
sql_inj+=" public_port = ? "
query.append(change_port.port)
if not change_port.port is None and not change_port.internalPort is None:
sql_inj += ","
if not change_port.internalPort is None:
sql_inj+=" internal_port = ? "
query.append(change_port.internalPort)
query.append(service_id)
db.query(f'UPDATE services SET {sql_inj} WHERE service_id = ?;', *query)
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="Port of the service has been already assigned to another service")
await firewall.get(service_id).update_port()
await refresh_frontend()
return {'status': 'ok'}
class RegexModel(BaseModel):
regex:str
mode:str
id:int
service_id:str
is_blacklist: bool
n_packets:int
is_case_sensitive:bool
active:bool
@app.get('/service/{service_id}/regexes', response_model=list[RegexModel])
async def get_service_regexe_list(service_id: str):
"""Get the list of the regexes of a service"""
if not db.query("SELECT 1 FROM services s WHERE s.service_id = ?;", service_id): raise HTTPException(status_code=400, detail="This service does not exists!")
return db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
blocked_packets n_packets, is_case_sensitive, active
FROM regexes WHERE service_id = ?;
""", service_id)
@app.get('/regex/{regex_id}', response_model=RegexModel)
async def get_regex_by_id(regex_id: int):
"""Get regex info using his id"""
res = db.query("""
SELECT
regex, mode, regex_id `id`, service_id, is_blacklist,
blocked_packets n_packets, is_case_sensitive, active
FROM regexes WHERE `id` = ?;
""", regex_id)
if len(res) == 0: raise HTTPException(status_code=400, detail="This regex does not exists!")
return res[0]
@app.get('/regex/{regex_id}/delete', response_model=StatusMessageModel)
async def regex_delete(regex_id: int):
"""Delete a regex using his id"""
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend()
return {'status': 'ok'}
@app.get('/regex/{regex_id}/enable', response_model=StatusMessageModel)
async def regex_enable(regex_id: int):
"""Request the enabling of a regex"""
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('UPDATE regexes SET active=1 WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend()
return {'status': 'ok'}
@app.get('/regex/{regex_id}/disable', response_model=StatusMessageModel)
async def regex_disable(regex_id: int):
"""Request the deactivation of a regex"""
res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
if len(res) != 0:
db.query('UPDATE regexes SET active=0 WHERE regex_id = ?;', regex_id)
await firewall.get(res[0]["service_id"]).update_filters()
await refresh_frontend()
return {'status': 'ok'}
class RegexAddForm(BaseModel):
service_id: str
regex: str
mode: str
active: bool|None = None
is_blacklist: bool
is_case_sensitive: bool
@app.post('/regexes/add', response_model=StatusMessageModel)
async def add_new_regex(form: RegexAddForm):
"""Add a new regex"""
try:
re.compile(b64decode(form.regex))
except Exception:
raise HTTPException(status_code=400, detail="Invalid regex")
try:
db.query("INSERT INTO regexes (service_id, regex, is_blacklist, mode, is_case_sensitive, active ) VALUES (?, ?, ?, ?, ?, ?);",
form.service_id, form.regex, form.is_blacklist, form.mode, form.is_case_sensitive, True if form.active is None else form.active )
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="An identical regex already exists")
await firewall.get(form.service_id).update_filters()
await refresh_frontend()
return {'status': 'ok'}
class ServiceAddForm(BaseModel):
name: str
port: PortType
internalPort: int|None = None
class ServiceAddStatus(BaseModel):
status:str
id: str|None = None
class RenameForm(BaseModel):
name:str
@app.post('/service/{service_id}/rename', response_model=StatusMessageModel)
async def service_rename(service_id: str, form: RenameForm):
"""Request to change the name of a specific service"""
form.name = refactor_name(form.name)
if not form.name: raise HTTPException(status_code=400, detail="The name cannot be empty!")
try:
db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="The name is already used!")
await refresh_frontend()
return {'status': 'ok'}
@app.post('/services/add', response_model=ServiceAddStatus)
async def add_new_service(form: ServiceAddForm):
"""Add a new service"""
serv_id = gen_service_id(db)
form.name = refactor_name(form.name)
try:
internal_port = form.internalPort if form.internalPort else gen_internal_port(db)
db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)",
form.name, serv_id, internal_port, form.port, 'stop')
except sqlite3.IntegrityError:
raise HTTPException(status_code=400, detail="Name or/and ports of the service has been already assigned to another service")
await firewall.reload()
await refresh_frontend()
return {'status': 'ok', "id": serv_id }

View File

@@ -1,10 +1,13 @@
import asyncio
from ipaddress import ip_address, ip_interface
import os, socket, psutil, sys, nftables
import os
import socket
import psutil
import sys
import nftables
from fastapi_socketio import SocketManager
from fastapi import Path
from typing import Annotated
import json
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
@@ -16,7 +19,7 @@ ON_DOCKER = "DOCKER" in sys.argv
DEBUG = "DEBUG" in sys.argv
FIREGEX_PORT = int(os.getenv("PORT","4444"))
JWT_ALGORITHM: str = "HS256"
API_VERSION = "2.2.0"
API_VERSION = "3.0.0"
PortType = Annotated[int, Path(gt=0, lt=65536)]
@@ -31,7 +34,8 @@ async def socketio_emit(elements:list[str]):
def refactor_name(name:str):
name = name.strip()
while " " in name: name = name.replace(" "," ")
while " " in name:
name = name.replace(" "," ")
return name
class SysctlManager:
@@ -125,8 +129,10 @@ class NFTableManager(Singleton):
def cmd(self, *cmds):
code, out, err = self.raw_cmd(*cmds)
if code == 0: return out
else: raise Exception(err)
if code == 0:
return out
else:
raise Exception(err)
def init(self):
self.reset()
@@ -138,8 +144,10 @@ class NFTableManager(Singleton):
def list_rules(self, tables = None, chains = None):
for filter in [ele["rule"] for ele in self.raw_list() if "rule" in ele ]:
if tables and filter["table"] not in tables: continue
if chains and filter["chain"] not in chains: continue
if tables and filter["table"] not in tables:
continue
if chains and filter["chain"] not in chains:
continue
yield filter
def raw_list(self):

View File

@@ -1,5 +1,6 @@
import os, httpx
import os
import httpx
from typing import Callable
from fastapi import APIRouter
from starlette.responses import StreamingResponse
@@ -31,7 +32,8 @@ def frontend_deploy(app):
return await frontend_debug_proxy(full_path)
except Exception:
return {"details":"Frontend not started at "+f"http://127.0.0.1:{os.getenv('F_PORT','5173')}"}
else: return await react_deploy(full_path)
else:
return await react_deploy(full_path)
def list_routers():
return [ele[:-3] for ele in list_files(ROUTERS_DIR) if ele != "__init__.py" and " " not in ele and ele.endswith(".py")]
@@ -79,9 +81,12 @@ def load_routers(app):
if router.shutdown:
shutdowns.append(router.shutdown)
async def reset(reset_option:ResetRequest):
for func in resets: await run_func(func, reset_option)
for func in resets:
await run_func(func, reset_option)
async def startup():
for func in startups: await run_func(func)
for func in startups:
await run_func(func)
async def shutdown():
for func in shutdowns: await run_func(func)
for func in shutdowns:
await run_func(func)
return reset, startup, shutdown

View File

@@ -1,4 +1,6 @@
import json, sqlite3, os
import json
import sqlite3
import os
from hashlib import md5
class SQLite():
@@ -15,8 +17,10 @@ class SQLite():
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
except Exception:
path_name = os.path.dirname(self.db_name)
if not os.path.exists(path_name): os.makedirs(path_name)
with open(self.db_name, 'x'): pass
if not os.path.exists(path_name):
os.makedirs(path_name)
with open(self.db_name, 'x'):
pass
self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
def dict_factory(cursor, row):
d = {}
@@ -36,13 +40,15 @@ class SQLite():
with open(self.db_name, "wb") as f:
f.write(self.__backup)
self.__backup = None
if were_active: self.connect()
if were_active:
self.connect()
def delete_backup(self):
self.__backup = None
def disconnect(self) -> None:
if self.conn: self.conn.close()
if self.conn:
self.conn.close()
self.conn = None
def create_schema(self, tables = {}) -> None:
@@ -50,9 +56,11 @@ class SQLite():
cur = self.conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);")
for t in tables:
if t == "QUERY": continue
if t == "QUERY":
continue
cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
if "QUERY" in tables:
[cur.execute(qry) for qry in tables["QUERY"]]
cur.close()
def query(self, query, *values):
@@ -82,8 +90,10 @@ class SQLite():
raise e
finally:
cur.close()
try: self.conn.commit()
except Exception: pass
try:
self.conn.commit()
except Exception:
pass
def delete(self):
self.disconnect()
@@ -92,7 +102,8 @@ class SQLite():
def init(self):
self.connect()
try:
if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct")
if self.get('DB_VERSION') != self.DB_VER:
raise Exception("DB_VERSION is not correct")
except Exception:
self.delete()
self.connect()