nfqueue to hyperscan and stream match, removed proxyregex
@@ -1,6 +1,10 @@
-import uvicorn, secrets, utils
-import os, asyncio, logging
-from fastapi import FastAPI, HTTPException, Depends, APIRouter, Request
+import uvicorn
+import secrets
+import utils
+import os
+import asyncio
+import logging
+from fastapi import FastAPI, HTTPException, Depends, APIRouter
 from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
 from jose import jwt
 from passlib.context import CryptContext
@@ -94,7 +98,8 @@ async def get_app_status(auth: bool = Depends(check_login)):
 @app.post("/api/login")
 async def login_api(form: OAuth2PasswordRequestForm = Depends()):
     """Get a login token to use the firegex api"""
-    if APP_STATUS() != "run": raise HTTPException(status_code=400)
+    if APP_STATUS() != "run":
+        raise HTTPException(status_code=400)
     if form.password == "":
         return {"status":"Cannot insert an empty password!"}
     await asyncio.sleep(0.3) # No bruteforce :)
@@ -105,7 +110,8 @@ async def login_api(form: OAuth2PasswordRequestForm = Depends()):
 @app.post('/api/set-password', response_model=ChangePasswordModel)
 async def set_password(form: PasswordForm):
     """Set the password of firegex"""
-    if APP_STATUS() != "init": raise HTTPException(status_code=400)
+    if APP_STATUS() != "init":
+        raise HTTPException(status_code=400)
     if form.password == "":
         return {"status":"Cannot insert an empty password!"}
     set_psw(form.password)
@@ -115,7 +121,8 @@ async def set_password(form: PasswordForm):
 @api.post('/change-password', response_model=ChangePasswordModel)
 async def change_password(form: PasswordChangeForm):
     """Change the password of firegex"""
-    if APP_STATUS() != "run": raise HTTPException(status_code=400)
+    if APP_STATUS() != "run":
+        raise HTTPException(status_code=400)
 
     if form.password == "":
         return {"status":"Cannot insert an empty password!"}
@@ -144,7 +151,8 @@ async def startup_main():
     except Exception as e:
         logging.error(f"Error setting sysctls: {e}")
     await startup()
-    if not JWT_SECRET(): db.put("secret", secrets.token_hex(32))
+    if not JWT_SECRET():
+        db.put("secret", secrets.token_hex(32))
     await refresh_frontend()
 
 async def shutdown_main():
485 backend/binsrc/classes/netfilter.cpp (Normal file)
@@ -0,0 +1,485 @@
#include <linux/netfilter/nfnetlink_queue.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <tins/tins.h>
#include <tins/tcp_ip/stream_follower.h>
#include <libmnl/libmnl.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/types.h>
#include <stdexcept>
#include <thread>
#include <hs.h>
#include <iostream>

using Tins::TCPIP::Stream;
using Tins::TCPIP::StreamFollower;
using namespace std;

#ifndef NETFILTER_CLASSES_HPP
#define NETFILTER_CLASSES_HPP

string inline client_endpoint(const Stream& stream) {
    ostringstream output;
    // Use the IPv4 or IPv6 address depending on which protocol the
    // connection uses
    if (stream.is_v6()) {
        output << stream.client_addr_v6();
    }
    else {
        output << stream.client_addr_v4();
    }
    output << ":" << stream.client_port();
    return output.str();
}

// Convert the server endpoint to a readable string
string inline server_endpoint(const Stream& stream) {
    ostringstream output;
    if (stream.is_v6()) {
        output << stream.server_addr_v6();
    }
    else {
        output << stream.server_addr_v4();
    }
    output << ":" << stream.server_port();
    return output.str();
}

// Concat both endpoints to get a readable stream identifier
string inline stream_identifier(const Stream& stream) {
    ostringstream output;
    output << client_endpoint(stream) << " - " << server_endpoint(stream);
    return output.str();
}

typedef unordered_map<string, hs_stream_t*> matching_map;

struct packet_info;

struct tcp_stream_tmp {
    bool matching_has_been_called = false;
    bool result;
    packet_info *pkt_info;
};

struct stream_ctx {
    matching_map in_hs_streams;
    matching_map out_hs_streams;
    hs_scratch_t* in_scratch = nullptr;
    hs_scratch_t* out_scratch = nullptr;
    u_int16_t latest_config_ver = 0;
    StreamFollower follower;
    mnl_socket* nl;
    tcp_stream_tmp tcp_match_util;

    void clean_scratches(){
        if (out_scratch != nullptr){
            hs_free_scratch(out_scratch);
            out_scratch = nullptr;
        }
        if (in_scratch != nullptr){
            hs_free_scratch(in_scratch);
            in_scratch = nullptr;
        }
    }
};

struct packet_info {
    string packet;
    string payload;
    string stream_id;
    bool is_input;
    bool is_tcp;
    stream_ctx* sctx;
};

typedef bool NetFilterQueueCallback(packet_info &);

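A callback plugged into this template only needs to honor the contract above: it receives the parsed packet_info and returns true to accept the traffic or false to reject it. A minimal sketch, assuming nothing beyond the declarations above (the function name is hypothetical, not part of the commit):

// Hypothetical callback: accept everything, logging TCP stream traffic.
bool log_and_accept(packet_info &pkt) {
    if (pkt.is_tcp) {
        cerr << "[debug] " << pkt.stream_id
             << (pkt.is_input ? " client->server " : " server->client ")
             << pkt.payload.size() << " payload bytes" << endl;
    }
    return true; // true = accept, false = reject
}
// Usable below as NetfilterQueue<log_and_accept> or NFQueueSequence<log_and_accept>.
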
Tins::PDU * find_transport_layer(Tins::PDU* pkt){
    while(pkt != nullptr){
        if (pkt->pdu_type() == Tins::PDU::TCP || pkt->pdu_type() == Tins::PDU::UDP) {
            return pkt;
        }
        pkt = pkt->inner_pdu();
    }
    return nullptr;
}

template <NetFilterQueueCallback callback_func>
class NetfilterQueue {
    public:

    size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2);
    char *buf = nullptr;
    unsigned int portid;
    u_int16_t queue_num;
    stream_ctx sctx;

    NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) {
        sctx.nl = mnl_socket_open(NETLINK_NETFILTER);

        if (sctx.nl == nullptr) { throw runtime_error( "mnl_socket_open" );}

        if (mnl_socket_bind(sctx.nl, 0, MNL_SOCKET_AUTOPID) < 0) {
            mnl_socket_close(sctx.nl);
            throw runtime_error( "mnl_socket_bind" );
        }
        portid = mnl_socket_get_portid(sctx.nl);

        buf = (char*) malloc(BUF_SIZE);

        if (!buf) {
            mnl_socket_close(sctx.nl);
            throw runtime_error( "allocate receive buffer" );
        }

        if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) {
            _clear();
            throw runtime_error( "mnl_socket_send" );
        }
        //TEST if BIND was successful
        if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND a NONE command to generate an error message
            _clear();
            throw runtime_error( "mnl_socket_send" );
        }
        if (recv_packet() == -1) { //RECV the error message
            _clear();
            throw runtime_error( "mnl_socket_recvfrom" );
        }

        struct nlmsghdr *nlh = (struct nlmsghdr *) buf;

        if (nlh->nlmsg_type != NLMSG_ERROR) {
            _clear();
            throw runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" );
        }
        //nfqnl_msg_config_cmd
        nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh);

        // error code taken from the linux kernel:
        // https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27
        #define ENOTSUPP 524 /* Operation is not supported */

        if (error_msg->error != -ENOTSUPP) {
            _clear();
            throw invalid_argument( "queueid is already busy" );
        }

        //END TESTING BIND
        nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
        nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff);

        mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO));
        mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO));

        if (mnl_socket_sendto(sctx.nl, nlh, nlh->nlmsg_len) < 0) {
            _clear();
            throw runtime_error( "mnl_socket_send" );
        }
    }

    //Input data filtering
    void on_client_data(Stream& stream) {
        string data(stream.client_payload().begin(), stream.client_payload().end());
        string stream_id = stream_identifier(stream);
        this->sctx.tcp_match_util.pkt_info->is_input = true;
        this->sctx.tcp_match_util.pkt_info->stream_id = stream_id;
        this->sctx.tcp_match_util.matching_has_been_called = true;
        bool result = callback_func(*sctx.tcp_match_util.pkt_info);
        if (result){
            this->clean_stream_by_id(stream_id);
            stream.ignore_client_data();
            stream.ignore_server_data();
        }
        this->sctx.tcp_match_util.result = result;
    }

    //Server data filtering
    void on_server_data(Stream& stream) {
        string data(stream.server_payload().begin(), stream.server_payload().end());
        string stream_id = stream_identifier(stream);
        this->sctx.tcp_match_util.pkt_info->is_input = false;
        this->sctx.tcp_match_util.pkt_info->stream_id = stream_id;
        this->sctx.tcp_match_util.matching_has_been_called = true;
        bool result = callback_func(*sctx.tcp_match_util.pkt_info);
        if (result){
            this->clean_stream_by_id(stream_id);
            stream.ignore_client_data();
            stream.ignore_server_data();
        }
        this->sctx.tcp_match_util.result = result;
    }

    void on_new_stream(Stream& stream) {
        string stream_id = stream_identifier(stream);
        if (stream.is_partial_stream()) {
            return;
        }
        cout << "[+] New connection " << stream_id << endl;
        stream.auto_cleanup_payloads(true);
        stream.client_data_callback(
            [&](auto a){this->on_client_data(a);}
        );
        stream.server_data_callback(
            [&](auto a){this->on_server_data(a);}
        );
    }

    void clean_stream_by_id(string stream_id){
        auto stream_search = this->sctx.in_hs_streams.find(stream_id);
        hs_stream_t* stream_match;
        if (stream_search != this->sctx.in_hs_streams.end()){
            stream_match = stream_search->second;
            if (hs_close_stream(stream_match, sctx.in_scratch, nullptr, nullptr) != HS_SUCCESS) {
                cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
                throw invalid_argument("Cannot close stream match on hyperscan");
            }
            this->sctx.in_hs_streams.erase(stream_search);
        }

        stream_search = this->sctx.out_hs_streams.find(stream_id);
        if (stream_search != this->sctx.out_hs_streams.end()){
            stream_match = stream_search->second;
            if (hs_close_stream(stream_match, sctx.out_scratch, nullptr, nullptr) != HS_SUCCESS) {
                cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
                throw invalid_argument("Cannot close stream match on hyperscan");
            }
            this->sctx.out_hs_streams.erase(stream_search);
        }
    }

    // A stream was terminated. The second argument is the reason why it was terminated
    void on_stream_terminated(Stream& stream, StreamFollower::TerminationReason reason) {
        string stream_id = stream_identifier(stream);
        cout << "[+] Connection closed: " << stream_id << endl;
        this->clean_stream_by_id(stream_id);
    }

    void run(){
        /*
         * ENOBUFS is signalled to userspace when packets were lost
         * on kernel side. In most cases, userspace isn't interested
         * in this information, so turn it off.
         */
        int ret = 1;
        mnl_socket_setsockopt(sctx.nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int));
        sctx.follower.new_stream_callback(
            [&](auto a){this->on_new_stream(a);}
        );
        sctx.follower.stream_termination_callback(
            [&](auto a, auto b){this->on_stream_terminated(a, b);}
        );
        for (;;) {
            ret = recv_packet();
            if (ret == -1) {
                throw runtime_error( "mnl_socket_recvfrom" );
            }

            ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, &sctx);
            if (ret < 0){
                throw runtime_error( "mnl_cb_run" );
            }
        }
    }

    ~NetfilterQueue() {
        send_config_cmd(NFQNL_CFG_CMD_UNBIND);
        _clear();
    }
    private:

    ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){
        struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
        nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd);
        return mnl_socket_sendto(sctx.nl, nlh, nlh->nlmsg_len);
    }

    ssize_t recv_packet(){
        return mnl_socket_recvfrom(sctx.nl, buf, BUF_SIZE);
    }

    void _clear(){
        if (buf != nullptr) {
            free(buf);
            buf = nullptr;
        }
        mnl_socket_close(sctx.nl);
        sctx.nl = nullptr;
        sctx.clean_scratches();

        for(auto ele: sctx.in_hs_streams){
            if (hs_close_stream(ele.second, sctx.in_scratch, nullptr, nullptr) != HS_SUCCESS) {
                cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
                throw invalid_argument("Cannot close stream match on hyperscan");
            }
        }
        sctx.in_hs_streams.clear();
        for(auto ele: sctx.out_hs_streams){
            if (hs_close_stream(ele.second, sctx.out_scratch, nullptr, nullptr) != HS_SUCCESS) {
                cerr << "[error] [NetfilterQueue.clean_stream_by_id] Error closing the stream matcher (hs)" << endl;
                throw invalid_argument("Cannot close stream match on hyperscan");
            }
        }
        sctx.out_hs_streams.clear();
    }

    static int queue_cb(const nlmsghdr *nlh, void *data_ptr)
    {
        stream_ctx* sctx = (stream_ctx*)data_ptr;

        //Extract attributes from the nlmsghdr
        nlattr *attr[NFQA_MAX+1] = {};

        if (nfq_nlmsg_parse(nlh, attr) < 0) {
            perror("problems parsing");
            return MNL_CB_ERROR;
        }
        if (attr[NFQA_PACKET_HDR] == nullptr) {
            fputs("metaheader not set\n", stderr);
            return MNL_CB_ERROR;
        }
        //Get Payload
        uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
        uint8_t *payload = (uint8_t *)mnl_attr_get_payload(attr[NFQA_PAYLOAD]);

        //Return result to the kernel
        struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
        struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh);
        char buf[MNL_SOCKET_BUFFER_SIZE];
        struct nlmsghdr *nlh_verdict;
        struct nlattr *nest;

        nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id));

        // Check IP protocol version
        Tins::PDU *packet;
        if ( ((payload)[0] & 0xf0) == 0x40 ){
            Tins::IP parsed = Tins::IP(payload, plen);
            packet = &parsed;
        }else{
            Tins::IPv6 parsed = Tins::IPv6(payload, plen);
            packet = &parsed;
        }
        Tins::PDU *transport_layer = find_transport_layer(packet);
        if(transport_layer == nullptr || transport_layer->inner_pdu() == nullptr){
            nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT );
        }else{
            bool is_tcp = transport_layer->pdu_type() == Tins::PDU::TCP;
            int size = transport_layer->inner_pdu()->size();
            packet_info pktinfo{
                packet: string(payload, payload+plen),
                payload: string(payload+plen - size, payload+plen),
                stream_id: "", // TODO We need to calculate this
                is_input: true, // TODO We need to detect this
                is_tcp: is_tcp,
                sctx: sctx,
            };
            if (is_tcp){
                sctx->tcp_match_util.matching_has_been_called = false;
                sctx->tcp_match_util.pkt_info = &pktinfo;
                sctx->follower.process_packet(*packet);
                if (sctx->tcp_match_util.matching_has_been_called && !sctx->tcp_match_util.result){
                    auto tcp_layer = (Tins::TCP *)transport_layer;
                    tcp_layer->release_inner_pdu();
                    tcp_layer->set_flag(Tins::TCP::FIN,1);
                    tcp_layer->set_flag(Tins::TCP::ACK,1);
                    tcp_layer->set_flag(Tins::TCP::SYN,0);
                    nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet->serialize().data(), packet->size());
                    nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT );
                    delete tcp_layer;
                }
            }else if(callback_func(pktinfo)){
                nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT );
            } else{
                nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP );
            }
        }

        /* example to set the connmark. First, start NFQA_CT section: */
        nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT);

        /* then, add the connmark attribute: */
        mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42));
        /* more conntrack attributes, e.g. CTA_LABELS could be set here */

        /* end conntrack section */
        mnl_attr_nest_end(nlh_verdict, nest);

        if (mnl_socket_sendto(sctx->nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) {
            throw runtime_error( "mnl_socket_send" );
        }

        return MNL_CB_OK;
    }

};

template <NetFilterQueueCallback func>
class NFQueueSequence{
    private:
    vector<NetfilterQueue<func> *> nfq;
    uint16_t _init;
    uint16_t _end;
    vector<thread> threads;
    public:
    static const int QUEUE_BASE_NUM = 1000;

    NFQueueSequence(uint16_t seq_len){
        if (seq_len <= 0) throw invalid_argument("seq_len <= 0");
        nfq = vector<NetfilterQueue<func>*>(seq_len);
        _init = QUEUE_BASE_NUM;
        while(nfq[0] == nullptr){
            if (_init+seq_len-1 >= 65536){
                throw runtime_error("NFQueueSequence: too many queues!");
            }
            for (int i=0;i<seq_len;i++){
                try{
                    nfq[i] = new NetfilterQueue<func>(_init+i);
                }catch(const invalid_argument e){
                    for(int j = 0; j < i; j++) {
                        delete nfq[j];
                        nfq[j] = nullptr;
                    }
                    _init += seq_len - i;
                    break;
                }
            }
        }
        _end = _init + seq_len - 1;
    }

    void start(){
        if (threads.size() != 0) throw runtime_error("NFQueueSequence: already started!");
        for (int i=0;i<nfq.size();i++){
            threads.push_back(thread(&NetfilterQueue<func>::run, nfq[i]));
        }
    }

    void join(){
        for (int i=0;i<nfq.size();i++){
            threads[i].join();
        }
        threads.clear();
    }

    uint16_t init(){
        return _init;
    }
    uint16_t end(){
        return _end;
    }

    ~NFQueueSequence(){
        for (int i=0;i<nfq.size();i++){
            delete nfq[i];
        }
    }
};

#endif // NETFILTER_CLASSES_HPP
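For orientation, a minimal usage sketch of the classes above, assuming the hypothetical log_and_accept callback from the earlier sketch; the printed queue range must then be targeted from netfilter (for instance an iptables rule with the NFQUEUE target balancing over that range):

// Hedged sketch, not part of the commit: run four queues, one thread each.
int demo_run() {
    NFQueueSequence<log_and_accept> queues(4);
    queues.start();
    cout << "QUEUES " << queues.init() << " " << queues.end() << endl;
    queues.join(); // blocks while the worker threads apply verdicts
    return 0;
}
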
@@ -1,294 +0,0 @@
#include <linux/netfilter/nfnetlink_queue.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <tins/tins.h>
#include <libmnl/libmnl.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/types.h>
#include <stdexcept>
#include <thread>

#ifndef NETFILTER_CLASSES_HPP
#define NETFILTER_CLASSES_HPP

typedef bool NetFilterQueueCallback(const uint8_t*,uint32_t);

Tins::PDU * find_transport_layer(Tins::PDU* pkt){
    while(pkt != NULL){
        if (pkt->pdu_type() == Tins::PDU::TCP || pkt->pdu_type() == Tins::PDU::UDP) {
            return pkt;
        }
        pkt = pkt->inner_pdu();
    }
    return pkt;
}

template <NetFilterQueueCallback callback_func>
class NetfilterQueue {
    public:
    size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2);
    char *buf = NULL;
    unsigned int portid;
    u_int16_t queue_num;
    struct mnl_socket* nl = NULL;

    NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) {

        nl = mnl_socket_open(NETLINK_NETFILTER);

        if (nl == NULL) { throw std::runtime_error( "mnl_socket_open" );}

        if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
            mnl_socket_close(nl);
            throw std::runtime_error( "mnl_socket_bind" );
        }
        portid = mnl_socket_get_portid(nl);

        buf = (char*) malloc(BUF_SIZE);

        if (!buf) {
            mnl_socket_close(nl);
            throw std::runtime_error( "allocate receive buffer" );
        }

        if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) {
            _clear();
            throw std::runtime_error( "mnl_socket_send" );
        }
        //TEST if BIND was successful
        if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND a NONE command to generate an error message
            _clear();
            throw std::runtime_error( "mnl_socket_send" );
        }
        if (recv_packet() == -1) { //RECV the error message
            _clear();
            throw std::runtime_error( "mnl_socket_recvfrom" );
        }

        struct nlmsghdr *nlh = (struct nlmsghdr *) buf;

        if (nlh->nlmsg_type != NLMSG_ERROR) {
            _clear();
            throw std::runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" );
        }
        //nfqnl_msg_config_cmd
        nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh);

        // error code taken from the linux kernel:
        // https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27
        #define ENOTSUPP 524 /* Operation is not supported */

        if (error_msg->error != -ENOTSUPP) {
            _clear();
            throw std::invalid_argument( "queueid is already busy" );
        }

        //END TESTING BIND
        nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
        nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff);

        mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO));
        mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO));

        if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
            _clear();
            throw std::runtime_error( "mnl_socket_send" );
        }

    }

    void run(){
        /*
         * ENOBUFS is signalled to userspace when packets were lost
         * on kernel side. In most cases, userspace isn't interested
         * in this information, so turn it off.
         */
        int ret = 1;
        mnl_socket_setsockopt(nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int));

        for (;;) {
            ret = recv_packet();
            if (ret == -1) {
                throw std::runtime_error( "mnl_socket_recvfrom" );
            }

            ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, nl);
            if (ret < 0){
                throw std::runtime_error( "mnl_cb_run" );
            }
        }
    }

    ~NetfilterQueue() {
        send_config_cmd(NFQNL_CFG_CMD_UNBIND);
        _clear();
    }
    private:

    ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){
        struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
        nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd);
        return mnl_socket_sendto(nl, nlh, nlh->nlmsg_len);
    }

    ssize_t recv_packet(){
        return mnl_socket_recvfrom(nl, buf, BUF_SIZE);
    }

    void _clear(){
        if (buf != NULL) {
            free(buf);
            buf = NULL;
        }
        mnl_socket_close(nl);
    }

    static int queue_cb(const struct nlmsghdr *nlh, void *data)
    {
        struct mnl_socket* nl = (struct mnl_socket*)data;
        //Extract attributes from the nlmsghdr
        struct nlattr *attr[NFQA_MAX+1] = {};

        if (nfq_nlmsg_parse(nlh, attr) < 0) {
            perror("problems parsing");
            return MNL_CB_ERROR;
        }
        if (attr[NFQA_PACKET_HDR] == NULL) {
            fputs("metaheader not set\n", stderr);
            return MNL_CB_ERROR;
        }
        //Get Payload
        uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
        void *payload = mnl_attr_get_payload(attr[NFQA_PAYLOAD]);

        //Return result to the kernel
        struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
        struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh);
        char buf[MNL_SOCKET_BUFFER_SIZE];
        struct nlmsghdr *nlh_verdict;
        struct nlattr *nest;

        nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id));

        /*
            This define allows avoiding a new heap allocation for each packet.
            The code under this comment is replicated for ipv6 and ip.
            Better solutions are welcome. :)
        */
        #define PKT_HANDLE \
        Tins::PDU *transport_layer = find_transport_layer(&packet); \
        if(transport_layer->inner_pdu() == nullptr || transport_layer == nullptr){ \
            nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
        }else{ \
            int size = transport_layer->inner_pdu()->size(); \
            if(callback_func((const uint8_t*)payload+plen - size, size)){ \
                nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
            } else{ \
                if (transport_layer->pdu_type() == Tins::PDU::TCP){ \
                    ((Tins::TCP *)transport_layer)->release_inner_pdu(); \
                    ((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::FIN,1); \
                    ((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::ACK,1); \
                    ((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::SYN,0); \
                    nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); \
                    nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
                }else{ \
                    nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP ); \
                } \
            } \
        }

        // Check IP protocol version
        if ( (((uint8_t*)payload)[0] & 0xf0) == 0x40 ){
            Tins::IP packet = Tins::IP((uint8_t*)payload,plen);
            PKT_HANDLE
        }else{
            Tins::IPv6 packet = Tins::IPv6((uint8_t*)payload,plen);
            PKT_HANDLE
        }

        /* example to set the connmark. First, start NFQA_CT section: */
        nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT);

        /* then, add the connmark attribute: */
        mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42));
        /* more conntrack attributes, e.g. CTA_LABELS could be set here */

        /* end conntrack section */
        mnl_attr_nest_end(nlh_verdict, nest);

        if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) {
            throw std::runtime_error( "mnl_socket_send" );
        }

        return MNL_CB_OK;
    }

};

template <NetFilterQueueCallback func>
class NFQueueSequence{
    private:
    std::vector<NetfilterQueue<func> *> nfq;
    uint16_t _init;
    uint16_t _end;
    std::vector<std::thread> threads;
    public:
    static const int QUEUE_BASE_NUM = 1000;

    NFQueueSequence(uint16_t seq_len){
        if (seq_len <= 0) throw std::invalid_argument("seq_len <= 0");
        nfq = std::vector<NetfilterQueue<func>*>(seq_len);
        _init = QUEUE_BASE_NUM;
        while(nfq[0] == NULL){
            if (_init+seq_len-1 >= 65536){
                throw std::runtime_error("NFQueueSequence: too many queues!");
            }
            for (int i=0;i<seq_len;i++){
                try{
                    nfq[i] = new NetfilterQueue<func>(_init+i);
                }catch(const std::invalid_argument e){
                    for(int j = 0; j < i; j++) {
                        delete nfq[j];
                        nfq[j] = nullptr;
                    }
                    _init += seq_len - i;
                    break;
                }
            }
        }
        _end = _init + seq_len - 1;
    }

    void start(){
        if (threads.size() != 0) throw std::runtime_error("NFQueueSequence: already started!");
        for (int i=0;i<nfq.size();i++){
            threads.push_back(std::thread(&NetfilterQueue<func>::run, nfq[i]));
        }
    }

    void join(){
        for (int i=0;i<nfq.size();i++){
            threads[i].join();
        }
        threads.clear();
    }

    uint16_t init(){
        return _init;
    }
    uint16_t end(){
        return _end;
    }

    ~NFQueueSequence(){
        for (int i=0;i<nfq.size();i++){
            delete nfq[i];
        }
    }
};

#endif // NETFILTER_CLASSES_HPP
@@ -1,95 +0,0 @@
#include <iostream>
#include <cstring>
#include <jpcre2.hpp>
#include <sstream>
#include "../utils.hpp"

#ifndef REGEX_FILTER_HPP
#define REGEX_FILTER_HPP

typedef jpcre2::select<char> jp;
typedef std::pair<std::string,jp::Regex> regex_rule_pair;
typedef std::vector<regex_rule_pair> regex_rule_vector;
struct regex_rules{
    regex_rule_vector output_whitelist, input_whitelist, output_blacklist, input_blacklist;

    regex_rule_vector* getByCode(char code){
        switch(code){
            case 'C': // Client to server Blacklist
                return &input_blacklist; break;
            case 'c': // Client to server Whitelist
                return &input_whitelist; break;
            case 'S': // Server to client Blacklist
                return &output_blacklist; break;
            case 's': // Server to client Whitelist
                return &output_whitelist; break;
        }
        throw std::invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
    }

    int add(const char* arg){
        //Integrity checks
        size_t arg_len = strlen(arg);
        if (arg_len < 2 || arg_len%2 != 0){
            std::cerr << "[warning] [regex_rules.add] invalid arg passed (" << arg << "), skipping..." << std::endl;
            return -1;
        }
        if (arg[0] != '0' && arg[0] != '1'){
            std::cerr << "[warning] [regex_rules.add] invalid is_case_sensitive (" << arg[0] << ") in '" << arg << "', must be '1' or '0', skipping..." << std::endl;
            return -1;
        }
        if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's'){
            std::cerr << "[warning] [regex_rules.add] invalid filter_type (" << arg[1] << ") in '" << arg << "', must be 'C', 'c', 'S' or 's', skipping..." << std::endl;
            return -1;
        }
        std::string hex(arg+2), expr;
        if (!unhexlify(hex, expr)){
            std::cerr << "[warning] [regex_rules.add] invalid hex regex value (" << hex << "), skipping..." << std::endl;
            return -1;
        }
        //Push regex
        jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
        if (regex){
            std::cerr << "[info] [regex_rules.add] adding new regex filter: '" << expr << "'" << std::endl;
            getByCode(arg[1])->push_back(std::make_pair(std::string(arg), regex));
        } else {
            std::cerr << "[warning] [regex_rules.add] compiling of '" << expr << "' regex failed, skipping..." << std::endl;
            return -1;
        }
        return 0;
    }

    bool check(unsigned char* data, const size_t& bytes_transferred, const bool in_input){
        std::string str_data((char *) data, bytes_transferred);
        for (regex_rule_pair ele:(in_input?input_blacklist:output_blacklist)){
            try{
                if(ele.second.match(str_data)){
                    std::stringstream msg;
                    msg << "BLOCKED " << ele.first << "\n";
                    std::cout << msg.str() << std::flush;
                    return false;
                }
            } catch(...){
                std::cerr << "[info] [regex_rules.check] Error while matching blacklist regex: " << ele.first << std::endl;
            }
        }
        for (regex_rule_pair ele:(in_input?input_whitelist:output_whitelist)){
            try{
                std::cerr << "[debug] [regex_rules.check] regex whitelist match " << ele.second.getPattern() << std::endl;
                if(!ele.second.match(str_data)){
                    std::stringstream msg;
                    msg << "BLOCKED " << ele.first << "\n";
                    std::cout << msg.str() << std::flush;
                    return false;
                }
            } catch(...){
                std::cerr << "[info] [regex_rules.check] Error while matching whitelist regex: " << ele.first << std::endl;
            }
        }
        return true;
    }

};

#endif // REGEX_FILTER_HPP
161 backend/binsrc/classes/regex_rules.cpp (Normal file)
@@ -0,0 +1,161 @@
#include <iostream>
#include <cstring>
#include <sstream>
#include "../utils.hpp"
#include <vector>
#include <hs.h>

using namespace std;

#ifndef REGEX_FILTER_HPP
#define REGEX_FILTER_HPP

enum FilterDirection{ CTOS, STOC };

struct decoded_regex {
    string regex;
    FilterDirection direction;
    bool is_case_sensitive;
};

struct regex_ruleset {
    hs_database_t* hs_db;
    char** regexes;
};

decoded_regex decode_regex(string regex){

    size_t arg_len = regex.size();
    if (arg_len < 2 || arg_len%2 != 0){
        cerr << "[warning] [decode_regex] invalid arg passed (" << regex << "), skipping..." << endl;
        throw runtime_error( "Invalid expression len (too small)" );
    }
    if (regex[0] != '0' && regex[0] != '1'){
        cerr << "[warning] [decode_regex] invalid is_case_sensitive (" << regex[0] << ") in '" << regex << "', must be '1' or '0', skipping..." << endl;
        throw runtime_error( "Invalid is_case_sensitive" );
    }
    if (regex[1] != 'C' && regex[1] != 'S'){
        cerr << "[warning] [decode_regex] invalid filter_direction (" << regex[1] << ") in '" << regex << "', must be 'C' or 'S', skipping..." << endl;
        throw runtime_error( "Invalid filter_direction" );
    }
    string hex(regex.c_str()+2), expr;
    if (!unhexlify(hex, expr)){
        cerr << "[warning] [decode_regex] invalid hex regex value (" << hex << "), skipping..." << endl;
        throw runtime_error( "Invalid hex regex encoded value" );
    }
    decoded_regex ruleset{
        regex: expr,
        direction: regex[1] == 'C'? CTOS : STOC,
        is_case_sensitive: regex[0] == '1'
    };
    return ruleset;
}

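Each rule on the wire is therefore a two-character header (case-sensitivity flag, then direction) followed by the hex-encoded pattern. A sketch of the inverse encoding, useful when testing (this helper is illustrative and not part of the commit; it assumes <iomanip> is included):

// Illustrative encoder producing the format decode_regex() parses.
string encode_regex(const string &expr, bool case_sensitive, bool client_to_server) {
    ostringstream out;
    out << (case_sensitive ? '1' : '0') << (client_to_server ? 'C' : 'S');
    for (unsigned char c : expr) {
        out << hex << setw(2) << setfill('0') << (int)c; // two hex digits per byte
    }
    return out.str();
}
// encode_regex("flag{", true, true) == "1C666c61677b"
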
class RegexRules{
    public:
    regex_ruleset output_ruleset, input_ruleset;

    private:
    static inline u_int16_t glob_seq = 0;
    u_int16_t version;
    vector<pair<string, decoded_regex>> decoded_input_rules;
    vector<pair<string, decoded_regex>> decoded_output_rules;
    bool is_stream = true;

    void free_dbs(){
        if (output_ruleset.hs_db != nullptr){
            hs_free_database(output_ruleset.hs_db);
        }
        if (input_ruleset.hs_db != nullptr){
            hs_free_database(input_ruleset.hs_db);
        }
    }

    void fill_ruleset(vector<pair<string, decoded_regex>> & decoded, regex_ruleset & ruleset){
        size_t n_of_regex = decoded.size();
        if (n_of_regex == 0){
            return;
        }
        const char* regex_match_rules[n_of_regex];
        unsigned int regex_array_ids[n_of_regex];
        unsigned int regex_flags[n_of_regex];
        for(int i = 0; i < n_of_regex; i++){
            regex_match_rules[i] = decoded[i].second.regex.c_str();
            regex_array_ids[i] = i;
            regex_flags[i] = HS_FLAG_SINGLEMATCH | HS_FLAG_ALLOWEMPTY;
            if (!decoded[i].second.is_case_sensitive){
                regex_flags[i] |= HS_FLAG_CASELESS;
            }
        }

        hs_database_t* rebuilt_db;
        hs_compile_error_t *compile_err;
        if (
            hs_compile_multi(
                regex_match_rules,
                regex_flags,
                regex_array_ids,
                n_of_regex,
                is_stream?HS_MODE_STREAM:HS_MODE_BLOCK,
                nullptr, &rebuilt_db, &compile_err
            ) != HS_SUCCESS
        ) {
            cerr << "[warning] [RegexRules.fill_ruleset] hs_db failed to compile: '" << compile_err->message << "' skipping..." << endl;
            hs_free_compile_error(compile_err);
            throw runtime_error( "Failed to compile hyperscan db" );
        }
        ruleset.hs_db = rebuilt_db;
    }

    public:
    RegexRules(vector<string> raw_rules, bool is_stream){
        this->is_stream = is_stream;
        this->version = ++glob_seq; // version 0 is an invalid version (useful for some logic)
        for(string ele : raw_rules){
            try{
                decoded_regex rule = decode_regex(ele);
                if (rule.direction == FilterDirection::CTOS){
                    decoded_input_rules.push_back(make_pair(ele, rule));
                }else{
                    decoded_output_rules.push_back(make_pair(ele, rule));
                }
            }catch(...){
                throw current_exception();
            }
        }
        fill_ruleset(decoded_input_rules, input_ruleset);
        try{
            fill_ruleset(decoded_output_rules, output_ruleset);
        }catch(...){
            free_dbs();
            throw current_exception();
        }
    }

    u_int16_t ver(){
        return version;
    }

    RegexRules(bool is_stream){
        vector<string> no_rules;
        RegexRules(no_rules, is_stream);
    }

    bool stream_mode(){
        return is_stream;
    }

    RegexRules(){
        RegexRules(true);
    }

    ~RegexRules(){
        free_dbs();
    }
};

#endif // REGEX_FILTER_HPP
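For reference, the Hyperscan calls this class wraps, reduced to a self-contained block-mode sketch (the pattern, function names, and handler are illustrative, not part of the commit):

// Compile one pattern, scan a buffer once, and report the matched id.
static int on_match(unsigned int id, unsigned long long from,
                    unsigned long long to, unsigned int flags, void *ctx) {
    *(unsigned int *)ctx = id;
    return 1; // non-zero stops the scan, as filter_callback does below
}

int hs_demo() {
    hs_database_t *db = nullptr;
    hs_compile_error_t *compile_err = nullptr;
    if (hs_compile("flag\\{", HS_FLAG_SINGLEMATCH, HS_MODE_BLOCK,
                   nullptr, &db, &compile_err) != HS_SUCCESS) {
        hs_free_compile_error(compile_err);
        return -1;
    }
    hs_scratch_t *scratch = nullptr;
    if (hs_alloc_scratch(db, &scratch) != HS_SUCCESS) {
        hs_free_database(db);
        return -1;
    }
    unsigned int matched_id = 0;
    const char data[] = "some flag{...} here";
    hs_scan(db, data, sizeof(data) - 1, 0, scratch, on_match, &matched_id);
    hs_free_scratch(scratch);
    hs_free_database(db);
    return (int)matched_id;
}

Streaming mode (HS_MODE_STREAM, selected above when is_stream is set) keeps the same handler but scans through hs_open_stream / hs_scan_stream / hs_close_stream, which is what lets matches span packet boundaries.
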
BIN backend/binsrc/cppqueue (Executable file)
Binary file not shown.
@@ -1,11 +1,11 @@
-#include "classes/regex_filter.hpp"
-#include "classes/netfilter.hpp"
+#include "classes/regex_rules.cpp"
+#include "classes/netfilter.cpp"
 #include "utils.hpp"
 #include <iostream>
 
 using namespace std;
 
-shared_ptr<regex_rules> regex_config;
+shared_ptr<RegexRules> regex_config;
 
 void config_updater (){
     string line;
@@ -21,44 +21,116 @@ void config_updater (){
         }
         cerr << "[info] [updater] Updating configuration with line " << line << endl;
         istringstream config_stream(line);
-        regex_rules *regex_new_config = new regex_rules();
+        vector<string> raw_rules;
 
         while(!config_stream.eof()){
             string data;
             config_stream >> data;
             if (data != "" && data != "\n"){
-                regex_new_config->add(data.c_str());
+                raw_rules.push_back(data);
             }
         }
-        regex_config.reset(regex_new_config);
-        cerr << "[info] [updater] Config update done" << endl;
+
+        try{
+            regex_config.reset(new RegexRules(raw_rules, regex_config->stream_mode()));
+            cerr << "[info] [updater] Config update done" << endl;
+        }catch(...){
+            cerr << "[error] [updater] Failed to build new configuration!" << endl;
+            // TODO send a row on stdout for this error
+        }
     }
 
 }
 
-template <bool is_input>
-bool filter_callback(const uint8_t *data, uint32_t len){
-    shared_ptr<regex_rules> current_config = regex_config;
-    return current_config->check((unsigned char *)data, len, is_input);
+void inline scratch_setup(regex_ruleset &conf, hs_scratch_t* & scratch){
+    if (scratch == nullptr){
+        if (hs_alloc_scratch(conf.hs_db, &scratch) != HS_SUCCESS) {
+            throw invalid_argument("Cannot alloc scratch");
+        }
+    }
 }
 
-int main(int argc, char *argv[])
-{
+struct matched_data{
+    unsigned int matched = 0;
+    bool has_matched = false;
+};
+
+bool filter_callback(packet_info & info){
+    shared_ptr<RegexRules> conf = regex_config;
+    if (conf->ver() != info.sctx->latest_config_ver){
+        info.sctx->clean_scratches();
+    }
+    scratch_setup(conf->input_ruleset, info.sctx->in_scratch);
+    scratch_setup(conf->output_ruleset, info.sctx->out_scratch);
+
+    hs_database_t* regex_matcher = info.is_input ? conf->input_ruleset.hs_db : conf->output_ruleset.hs_db;
+    if (regex_matcher == nullptr){
+        return true;
+    }
+    matched_data match_res;
+    hs_error_t err;
+    hs_scratch_t* scratch_space = info.is_input ? info.sctx->in_scratch : info.sctx->out_scratch;
+    auto match_func = [](unsigned int id, auto from, auto to, auto flags, auto ctx){
+        auto res = (matched_data*)ctx;
+        res->has_matched = true;
+        res->matched = id;
+        return 1; // Stop matching
+    };
+    if (conf->stream_mode()){
+        matching_map match_map = info.is_input ? info.sctx->in_hs_streams : info.sctx->out_hs_streams;
+        auto stream_search = match_map.find(info.stream_id);
+        hs_stream_t* stream_match;
+        if (stream_search == match_map.end()){
+            if (hs_open_stream(regex_matcher, 0, &stream_match) != HS_SUCCESS) {
+                cerr << "[error] [filter_callback] Error opening the stream matcher (hs)" << endl;
+                throw invalid_argument("Cannot open stream match on hyperscan");
+            }
+            match_map[info.stream_id] = stream_match;
+        }else{
+            stream_match = stream_search->second;
+        }
+        err = hs_scan_stream(
+            stream_match, info.payload.c_str(), info.payload.length(),
+            0, scratch_space, match_func, &match_res
+        );
+    }else{
+        err = hs_scan(
+            regex_matcher, info.payload.c_str(), info.payload.length(),
+            0, scratch_space, match_func, &match_res
+        );
+    }
+    if (err != HS_SUCCESS) {
+        cerr << "[error] [filter_callback] Error while matching the stream (hs)" << endl;
+        throw invalid_argument("Error while matching the stream with hyperscan");
+    }
+    if (match_res.has_matched){
+        auto rules_vector = info.is_input ? conf->input_ruleset.regexes : conf->output_ruleset.regexes;
+        stringstream msg;
+        msg << "BLOCKED " << rules_vector[match_res.matched] << "\n";
+        cout << msg.str() << flush;
+        return false;
+    }
+    return true;
+}
+
+int main(int argc, char *argv[]){
     int n_of_threads = 1;
     char * n_threads_str = getenv("NTHREADS");
-    if (n_threads_str != NULL) n_of_threads = ::atoi(n_threads_str);
+    if (n_threads_str != nullptr) n_of_threads = ::atoi(n_threads_str);
     if(n_of_threads <= 0) n_of_threads = 1;
-    if (n_of_threads % 2 != 0 ) n_of_threads++;
-    cerr << "[info] [main] Using " << n_of_threads << " threads" << endl;
-    regex_config.reset(new regex_rules());
-    NFQueueSequence<filter_callback<true>> input_queues(n_of_threads/2);
-    input_queues.start();
-    NFQueueSequence<filter_callback<false>> output_queues(n_of_threads/2);
-    output_queues.start();
-
-    cout << "QUEUES INPUT " << input_queues.init() << " " << input_queues.end() << " OUTPUT " << output_queues.init() << " " << output_queues.end() << endl;
-    cerr << "[info] [main] Input queues: " << input_queues.init() << ":" << input_queues.end() << " threads assigned: " << n_of_threads/2 << endl;
-    cerr << "[info] [main] Output queues: " << output_queues.init() << ":" << output_queues.end() << " threads assigned: " << n_of_threads/2 << endl;
+    char * matchmode = getenv("MATCH_MODE");
+    bool stream_mode = true;
+    if (matchmode != nullptr && strcmp(matchmode, "block") == 0){
+        stream_mode = false;
+    }
+    cerr << "[info] [main] Using " << n_of_threads << " threads" << endl;
+    regex_config.reset(new RegexRules(stream_mode));
+
+    NFQueueSequence<filter_callback> queues(n_of_threads);
+    queues.start();
+
+    cout << "QUEUES " << queues.init() << " " << queues.end() << endl;
+    cerr << "[info] [main] Queues: " << queues.init() << ":" << queues.end() << " threads assigned: " << n_of_threads << endl;
 
     config_updater();
 }
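The resulting binary speaks a line-based protocol: it prints the bound queue range on stdout at startup, reads whitespace-separated encoded rules from stdin on every update, and emits "BLOCKED <rule>" lines when a match drops traffic. A hedged driver sketch (the binary path and rule value are illustrative):

// Feed one rule update to the filter's stdin (POSIX popen, write-only pipe).
// "1C666c61677b" = case-sensitive C->S rule for "flag{" (see encode_regex sketch).
FILE *proc = popen("./cppqueue", "w");
if (proc != nullptr) {
    fputs("1C666c61677b\n", proc);
    pclose(proc);
}
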
32 backend/binsrc/nfqueue_regex/Cargo.lock (generated)
@@ -1,32 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3

[[package]]
name = "atomic_refcell"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c"

[[package]]
name = "libc"
version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"

[[package]]
name = "nfq"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9c8f4c88952507d9df9400a6a2e48640fb460e21dcb2b4716eb3ff156d6db9e"
dependencies = [
    "libc",
]

[[package]]
name = "nfqueue_regex"
version = "0.1.0"
dependencies = [
    "atomic_refcell",
    "nfq",
]
@@ -1,11 +0,0 @@
[package]
name = "nfqueue_regex"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
atomic_refcell = "0.1.13"
nfq = "0.2.5"
#hyperscan = "0.3.2"
@@ -1,150 +0,0 @@
use atomic_refcell::AtomicRefCell;
use nfq::{Queue, Verdict};
use std::cell::{Cell, RefCell};
use std::env;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::{AtomicPtr, AtomicU32};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::Arc;
use std::thread::{self, sleep, sleep_ms, JoinHandle};

enum WorkerMessage {
    Error(String),
    Dropped(usize),
}

impl ToString for WorkerMessage {
    fn to_string(&self) -> String {
        match self {
            WorkerMessage::Error(e) => format!("E{}", e),
            WorkerMessage::Dropped(d) => format!("D{}", d),
        }
    }
}
struct Pool {
    _workers: Vec<Worker>,
    pub start: u16,
    pub end: u16,
}

const QUEUE_BASE_NUM: u16 = 1000;
impl Pool {
    fn new(threads: u16, tx: Sender<WorkerMessage>, db: RefCell<&str>) -> Self {
        // Find free queues
        let mut start = QUEUE_BASE_NUM;
        let mut queues: Vec<(Queue, u16)> = vec![];
        while queues.len() != threads.into() {
            for queue_num in
                (start..start.checked_add(threads + 1).expect("No more queues left")).rev()
            {
                let mut queue = Queue::open().unwrap();
                if queue.bind(queue_num).is_err() {
                    start = queue_num;
                    while let Some((mut q, num)) = queues.pop() {
                        let _ = q.unbind(num);
                    }
                    break;
                };
                queues.push((queue, queue_num));
            }
        }

        Pool {
            _workers: queues
                .into_iter()
                .map(|(queue, queue_num)| Worker::new(queue, queue_num, tx.clone()))
                .collect(),
            start,
            end: (start + threads),
        }
    }

    // fn join(self) {
    //     for worker in self._workers {
    //         let _ = worker.join();
    //     }
    // }
}

struct Worker {
    _inner: JoinHandle<()>,
}

impl Worker {
    fn new(mut queue: Queue, _queue_num: u16, tx: Sender<WorkerMessage>) -> Self {
        Worker {
            _inner: thread::spawn(move || loop {
                let mut msg = queue.recv().unwrap_or_else(|_| {
                    let _ = tx.send(WorkerMessage::Error("Fuck".to_string()));
                    panic!("");
                });

                msg.set_verdict(Verdict::Accept);
                queue.verdict(msg).unwrap();
            }),
        }
    }
}
struct InputOuputPools {
    pub output_queue: Pool,
    pub input_queue: Pool,
    rx: Receiver<WorkerMessage>,
}
impl InputOuputPools {
    fn new(threads: u16) -> InputOuputPools {
        let (tx, rx) = mpsc::channel();
        InputOuputPools {
            output_queue: Pool::new(threads / 2, tx.clone(), RefCell::new("ciao")),
            input_queue: Pool::new(threads / 2, tx, RefCell::new("miao")),
            rx,
        }
    }

    fn poll_events(&self) {
        loop {
            let event = self.rx.recv().expect("Channel has hung up");
            println!("{}", event.to_string());
        }
    }
}

static mut DB: AtomicPtr<Arc<u32>> = AtomicPtr::new(std::ptr::null_mut() as *mut Arc<u32>);

fn main() -> std::io::Result<()> {
    let mut my_x: Arc<u32> = Arc::new(0);
    let my_x_ptr: *mut Arc<u32> = std::ptr::addr_of_mut!(my_x);

    unsafe { DB.store(my_x_ptr, std::sync::atomic::Ordering::SeqCst) };

    thread::spawn(|| loop {
        let x_ptr = unsafe { DB.load(std::sync::atomic::Ordering::SeqCst) };
        let x = unsafe { (*x_ptr).clone() };
        dbg!(x);
        //sleep_ms(1000);
    });

    for i in 0..1000000000 {
        let mut my_x: Arc<u32> = Arc::new(i);
        let my_x_ptr: *mut Arc<u32> = std::ptr::addr_of_mut!(my_x);
        unsafe { DB.store(my_x_ptr, std::sync::atomic::Ordering::SeqCst) };
        //sleep_ms(100);
    }

    let mut threads = env::var("NPROCS").unwrap_or_default().parse().unwrap_or(2);
    if threads % 2 != 0 {
        threads += 1;
    }

    let in_out_pools = InputOuputPools::new(threads);
    eprintln!(
        "[info] [main] Input queues: {}:{}",
        in_out_pools.input_queue.start, in_out_pools.input_queue.end
    );
    eprintln!(
        "[info] [main] Output queues: {}:{}",
        in_out_pools.output_queue.start, in_out_pools.output_queue.end
    );
    in_out_pools.poll_events();
    Ok(())
}

@@ -1 +0,0 @@
@@ -1,493 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2007 Arash Partow (http://www.partow.net)
|
||||
URL: http://www.partow.net/programming/tcpproxy/index.html
|
||||
Modified and adapted by Pwnzer0tt1
|
||||
*/
|
||||
#include <cstdlib>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
|
||||
#include <boost/thread.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/enable_shared_from_this.hpp>
|
||||
#include <boost/bind/bind.hpp>
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/thread/mutex.hpp>
|
||||
#include <jpcre2.hpp>
|
||||
|
||||
typedef jpcre2::select<char> jp;
|
||||
using namespace std;
|
||||
|
||||
bool unhexlify(string const &hex, string &newString) {
|
||||
try{
|
||||
int len = hex.length();
|
||||
for(int i=0; i< len; i+=2)
|
||||
{
|
||||
std::string byte = hex.substr(i,2);
|
||||
char chr = (char) (int)strtol(byte.c_str(), NULL, 16);
|
||||
newString.push_back(chr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
catch (...){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
typedef pair<string,jp::Regex> regex_rule_pair;
|
||||
typedef vector<regex_rule_pair> regex_rule_vector;
|
||||
struct regex_rules{
|
||||
regex_rule_vector regex_s_c_w, regex_c_s_w, regex_s_c_b, regex_c_s_b;
|
||||
|
||||
regex_rule_vector* getByCode(char code){
|
||||
switch(code){
|
||||
case 'C': // Client to server Blacklist
|
||||
return ®ex_c_s_b; break;
|
||||
case 'c': // Client to server Whitelist
|
||||
return ®ex_c_s_w; break;
|
||||
case 'S': // Server to client Blacklist
|
||||
return ®ex_s_c_b; break;
|
||||
case 's': // Server to client Whitelist
|
||||
return ®ex_s_c_w; break;
|
||||
}
|
||||
throw invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
|
||||
}
|
||||
|
||||
void add(const char* arg){
|
||||
|
||||
//Integrity checks
|
||||
size_t arg_len = strlen(arg);
|
||||
if (arg_len < 2 || arg_len%2 != 0) return;
|
||||
if (arg[0] != '0' && arg[0] != '1') return;
|
||||
if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's') return;
|
||||
string hex(arg+2), expr;
|
||||
if (!unhexlify(hex, expr)) return;
|
||||
//Push regex
|
||||
jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
|
||||
if (regex){
|
||||
#ifdef DEBUG
|
||||
cerr << "Added regex " << expr << " " << arg << endl;
|
||||
#endif
|
||||
getByCode(arg[1])->push_back(make_pair(string(arg), regex));
|
||||
} else {
|
||||
cerr << "Regex " << arg << " was not compiled successfully" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
shared_ptr<regex_rules> regex_config;
|
||||
|
||||
mutex update_mutex;
|
||||
|
||||
bool filter_data(unsigned char* data, const size_t& bytes_transferred, regex_rule_vector const &blacklist, regex_rule_vector const &whitelist){
|
||||
#ifdef DEBUG_PACKET
|
||||
cerr << "---------------- Packet ----------------" << endl;
|
||||
for(int i=0;i<bytes_transferred;i++) cerr << data[i];
|
||||
cerr << endl;
|
||||
for(int i=0;i<bytes_transferred;i++) fprintf(stderr, "%x", data[i]);
|
||||
cerr << endl;
|
||||
cerr << "---------------- End Packet ----------------" << endl;
|
||||
#endif
|
||||
string str_data((char *) data, bytes_transferred);
|
||||
for (regex_rule_pair ele:blacklist){
|
||||
try{
|
||||
if(ele.second.match(str_data)){
|
||||
stringstream msg;
|
||||
msg << "BLOCKED " << ele.first << endl;
|
||||
cout << msg.str() << std::flush;
|
||||
return false;
|
||||
}
|
||||
} catch(...){
|
||||
cerr << "Error while matching regex: " << ele.first << endl;
|
||||
}
|
||||
}
|
||||
for (regex_rule_pair ele:whitelist){
|
||||
try{
|
||||
if(!ele.second.match(str_data)){
|
||||
stringstream msg;
|
||||
msg << "BLOCKED " << ele.first << endl;
|
||||
cout << msg.str() << std::flush;
|
||||
return false;
|
||||
}
|
||||
} catch(...){
|
||||
cerr << "Error while matching regex: " << ele.first << endl;
|
||||
}
|
||||
}
|
||||
#ifdef DEBUG
|
||||
cerr << "Packet Accepted!" << endl;
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace tcp_proxy
|
||||
{
|
||||
namespace ip = boost::asio::ip;
|
||||
|
||||
class bridge : public boost::enable_shared_from_this<bridge>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef ip::tcp::socket socket_type;
|
||||
typedef boost::shared_ptr<bridge> ptr_type;
|
||||
|
||||
bridge(boost::asio::io_context& ios)
|
||||
: downstream_socket_(ios),
|
||||
upstream_socket_ (ios),
|
||||
thread_safety(ios)
|
||||
{}
|
||||
|
||||
socket_type& downstream_socket()
|
||||
{
|
||||
// Client socket
|
||||
return downstream_socket_;
|
||||
}
|
||||
|
||||
socket_type& upstream_socket()
|
||||
{
|
||||
// Remote server socket
|
||||
return upstream_socket_;
|
||||
}
|
||||
|
||||
void start(const string& upstream_host, unsigned short upstream_port)
|
||||
{
|
||||
// Attempt connection to remote server (upstream side)
|
||||
upstream_socket_.async_connect(
|
||||
ip::tcp::endpoint(
|
||||
boost::asio::ip::address::from_string(upstream_host),
|
||||
upstream_port),
|
||||
boost::asio::bind_executor(thread_safety,
|
||||
boost::bind(
|
||||
&bridge::handle_upstream_connect,
|
||||
shared_from_this(),
|
||||
boost::asio::placeholders::error)));
|
||||
}
|
||||
|
||||
void handle_upstream_connect(const boost::system::error_code& error)
|
||||
{
|
||||
if (!error)
|
||||
{
|
||||
// Setup async read from remote server (upstream)
|
||||
|
||||
upstream_socket_.async_read_some(
|
||||
boost::asio::buffer(upstream_data_,max_data_length),
|
||||
boost::asio::bind_executor(thread_safety,
|
||||
boost::bind(&bridge::handle_upstream_read,
|
||||
shared_from_this(),
|
||||
boost::asio::placeholders::error,
|
||||
boost::asio::placeholders::bytes_transferred)));
|
||||
|
||||
// Setup async read from client (downstream)
|
||||
downstream_socket_.async_read_some(
|
||||
boost::asio::buffer(downstream_data_,max_data_length),
|
||||
boost::asio::bind_executor(thread_safety,
|
||||
boost::bind(&bridge::handle_downstream_read,
|
||||
shared_from_this(),
|
||||
boost::asio::placeholders::error,
|
||||
boost::asio::placeholders::bytes_transferred)));
|
||||
}
|
||||
else
|
||||
close();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/*
|
||||
Section A: Remote Server --> Proxy --> Client
|
||||
Process data recieved from remote sever then send to client.
|
||||
*/
|
||||
|
||||
// Read from remote server complete, now send data to client
|
||||
void handle_upstream_read(const boost::system::error_code& error,
|
||||
const size_t& bytes_transferred) // Da Server a Client
|
||||
{
|
||||
if (!error)
|
||||
{
                shared_ptr<regex_rules> regex_old_config = regex_config;
                if (filter_data(upstream_data_, bytes_transferred, regex_old_config->regex_s_c_b, regex_old_config->regex_s_c_w)){
                    async_write(downstream_socket_,
                        boost::asio::buffer(upstream_data_,bytes_transferred),
                        boost::asio::bind_executor(thread_safety,
                            boost::bind(&bridge::handle_downstream_write,
                                shared_from_this(),
                                boost::asio::placeholders::error)));
                }else{
                    close();
                }
            }
            else
                close();
        }

        // Write to client complete, async read from remote server
        void handle_downstream_write(const boost::system::error_code& error)
        {
            if (!error)
            {

                upstream_socket_.async_read_some(
                    boost::asio::buffer(upstream_data_,max_data_length),
                    boost::asio::bind_executor(thread_safety,
                        boost::bind(&bridge::handle_upstream_read,
                            shared_from_this(),
                            boost::asio::placeholders::error,
                            boost::asio::placeholders::bytes_transferred)));
            }
            else
                close();
        }
        // *** End Of Section A ***


        /*
            Section B: Client --> Proxy --> Remote Server
            Process data received from the client, then write it to the remote server.
        */

        // Read from client complete, now send data to remote server
        void handle_downstream_read(const boost::system::error_code& error,
                                    const size_t& bytes_transferred) // Client to server
        {
            if (!error)
            {
                shared_ptr<regex_rules> regex_old_config = regex_config;
                if (filter_data(downstream_data_, bytes_transferred, regex_old_config->regex_c_s_b, regex_old_config->regex_c_s_w)){
                    async_write(upstream_socket_,
                        boost::asio::buffer(downstream_data_,bytes_transferred),
                        boost::asio::bind_executor(thread_safety,
                            boost::bind(&bridge::handle_upstream_write,
                                shared_from_this(),
                                boost::asio::placeholders::error)));
                }else{
                    close();
                }
            }
            else
                close();
        }

        // Write to remote server complete, async read from client
        void handle_upstream_write(const boost::system::error_code& error)
        {
            if (!error)
            {
                downstream_socket_.async_read_some(
                    boost::asio::buffer(downstream_data_,max_data_length),
                    boost::asio::bind_executor(thread_safety,
                        boost::bind(&bridge::handle_downstream_read,
                            shared_from_this(),
                            boost::asio::placeholders::error,
                            boost::asio::placeholders::bytes_transferred)));
            }
            else
                close();
        }
        // *** End Of Section B ***

        void close()
        {
            boost::mutex::scoped_lock lock(mutex_);

            if (downstream_socket_.is_open())
            {
                downstream_socket_.close();
            }

            if (upstream_socket_.is_open())
            {
                upstream_socket_.close();
            }
        }
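
        // close() can be reached from the handlers of both directions; the
        // mutex serialises the two paths so the second caller simply finds
        // the sockets already closed.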

        socket_type downstream_socket_;
        socket_type upstream_socket_;

        enum { max_data_length = 8192 }; // 8KB
        unsigned char downstream_data_[max_data_length];
        unsigned char upstream_data_ [max_data_length];
        boost::asio::io_context::strand thread_safety;
        boost::mutex mutex_;
    public:

        class acceptor
        {
        public:

            acceptor(boost::asio::io_context& io_context,
                     const string& local_host, unsigned short local_port,
                     const string& upstream_host, unsigned short upstream_port)
                : io_context_(io_context),
                  localhost_address(boost::asio::ip::address_v4::from_string(local_host)),
                  acceptor_(io_context_,ip::tcp::endpoint(localhost_address,local_port)),
                  upstream_port_(upstream_port),
                  upstream_host_(upstream_host)
            {}

            bool accept_connections()
            {
                try
                {
                    session_ = boost::shared_ptr<bridge>(new bridge(io_context_));

                    acceptor_.async_accept(session_->downstream_socket(),
                        boost::asio::bind_executor(session_->thread_safety,
                            boost::bind(&acceptor::handle_accept,
                                this,
                                boost::asio::placeholders::error)));
                }
                catch(exception& e)
                {
                    cerr << "acceptor exception: " << e.what() << endl;
                    return false;
                }

                return true;
            }

        private:

            void handle_accept(const boost::system::error_code& error)
            {
                if (!error)
                {
                    session_->start(upstream_host_,upstream_port_);

                    if (!accept_connections())
                    {
                        cerr << "Failure during call to accept." << endl;
                    }
                }
                else
                {
                    cerr << "Error: " << error.message() << endl;
                }
            }
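
            // Note: handle_accept() above starts the new bridge and immediately
            // re-arms accept_connections(), so one pending accept is always
            // outstanding while established sessions run.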

            boost::asio::io_context& io_context_;
            ip::address_v4 localhost_address;
            ip::tcp::acceptor acceptor_;
            ptr_type session_;
            unsigned short upstream_port_;
            string upstream_host_;
        };

    };
}

void update_config (boost::asio::streambuf &input_buffer){
    #ifdef DEBUG
    cerr << "Updating configuration" << endl;
    #endif
    std::istream config_stream(&input_buffer);
    std::unique_lock<std::mutex> lck(update_mutex);
    regex_rules *regex_new_config = new regex_rules();
    string data;
    while(true){
        config_stream >> data;
        if (config_stream.eof()) break;
        regex_new_config->add(data.c_str());
    }
    regex_config.reset(regex_new_config);
}
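
// Each line received on stdin is parsed as a whitespace-separated list of
// rule tokens and compiled into a fresh regex_rules object that replaces the
// active configuration; handlers that already took their shared_ptr snapshot
// keep matching against the old rules until they finish.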

class async_updater
{
public:
    async_updater(boost::asio::io_context& io_context) : input_(io_context, ::dup(STDIN_FILENO)), thread_safety(io_context)
    {

        boost::asio::async_read_until(input_, input_buffer_, '\n',
            boost::asio::bind_executor(thread_safety,
                boost::bind(&async_updater::on_update, this,
                    boost::asio::placeholders::error,
                    boost::asio::placeholders::bytes_transferred)));
    }

    void on_update(const boost::system::error_code& error, std::size_t length)
    {
        if (!error)
        {
            update_config(input_buffer_);
            boost::asio::async_read_until(input_, input_buffer_, '\n',
                boost::asio::bind_executor(thread_safety,
                    boost::bind(&async_updater::on_update, this,
                        boost::asio::placeholders::error,
                        boost::asio::placeholders::bytes_transferred)));
        }
        else
        {
            close();
        }
    }

    void close()
    {
        input_.close();
    }

private:
    boost::asio::posix::stream_descriptor input_;
    boost::asio::io_context::strand thread_safety;
    boost::asio::streambuf input_buffer_;
};


int main(int argc, char* argv[])
{
    if (argc < 5)
    {
        cerr << "usage: tcpproxy_server <local host ip> <local port> <forward host ip> <forward port>" << endl;
        return 1;
    }

    const unsigned short local_port = static_cast<unsigned short>(::atoi(argv[2]));
    const unsigned short forward_port = static_cast<unsigned short>(::atoi(argv[4]));
    const string local_host = argv[1];
    const string forward_host = argv[3];

    int threads = 1;
    char * n_threads_str = getenv("NTHREADS");
    if (n_threads_str != NULL) threads = ::atoi(n_threads_str);

    boost::asio::io_context ios;

    boost::asio::streambuf buf;
    boost::asio::posix::stream_descriptor cin_in(ios, ::dup(STDIN_FILENO));
    boost::asio::read_until(cin_in, buf,'\n');
    update_config(buf);

    async_updater updater(ios);

    #ifdef DEBUG
    cerr << "Starting Proxy" << endl;
    #endif
    try
    {
        tcp_proxy::bridge::acceptor acceptor(ios,
            local_host, local_port,
            forward_host, forward_port);

        acceptor.accept_connections();

        if (threads > 1){
            boost::thread_group tg;
            for (int i = 0; i < threads; ++i)
                tg.create_thread(boost::bind(&boost::asio::io_context::run, &ios));

            tg.join_all();
        }else{
            ios.run();
        }
    }
    catch(exception& e)
    {
        cerr << "Error: " << e.what() << endl;
        return 1;
    }
    #ifdef DEBUG
    cerr << "Proxy stopped!" << endl;
    #endif

    return 0;
}
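
For reference, the proxy reads an initial rule line from stdin before it starts accepting connections, so a minimal (hypothetical) invocation with an empty rule set would look like:

echo "" | NTHREADS=4 ./proxy 0.0.0.0 8080 127.0.0.1 9090

where NTHREADS controls how many threads call io_context::run() (the Python side spawns the binary as ../proxy and feeds rule updates on its stdin).
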
@@ -10,7 +10,7 @@ bool unhexlify(std::string const &hex, std::string &newString) {
    for(int i=0; i< len; i+=2)
    {
        std::string byte = hex.substr(i,2);
        char chr = (char) (int)strtol(byte.c_str(), NULL, 16);
        char chr = (char) (int)strtol(byte.c_str(), nullptr, 16);
        newString.push_back(chr);
    }
    return true;

@@ -1,6 +1,6 @@
import asyncio
from modules.firewall.nftables import FiregexTables
from modules.firewall.models import *
from modules.firewall.models import Rule, FirewallSettings
from utils.sqlite import SQLite
from modules.firewall.models import Action

@@ -131,5 +131,5 @@ class FirewallManager:
        return self.db.get("allow_dhcp", "1") == "1"

    @drop_invalid.setter
    def allow_dhcp(self, value):
    def allow_dhcp_set(self, value):
        self.db.set("allow_dhcp", "1" if value else "0")

@@ -1,7 +1,9 @@
from modules.nfregex.nftables import FiregexTables
from utils import ip_parse, run_func
from utils import run_func
from modules.nfregex.models import Service, Regex
import re, os, asyncio
import re
import os
import asyncio
import traceback

nft = FiregexTables()
@@ -20,7 +22,8 @@ class RegexFilter:
        self.regex = regex
        self.is_case_sensitive = is_case_sensitive
        self.is_blacklist = is_blacklist
        if input_mode == output_mode: input_mode = output_mode = True # (False, False) behaves like (True, True): filter both directions
        if input_mode == output_mode:
            input_mode = output_mode = True # (False, False) behaves like (True, True): filter both directions
        self.input_mode = input_mode
        self.output_mode = output_mode
        self.blocked = blocked_packets
@@ -37,8 +40,10 @@ class RegexFilter:
            update_func = update_func
        )
    def compile(self):
        if isinstance(self.regex, str): self.regex = self.regex.encode()
        if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
        if isinstance(self.regex, str):
            self.regex = self.regex.encode()
        if not isinstance(self.regex, bytes):
            raise Exception("Invalid Regex Paramether")
        re.compile(self.regex) # raises re.error if it's invalid!
        case_sensitive = "1" if self.is_case_sensitive else "0"
        if self.input_mode:
@@ -67,9 +72,9 @@ class FiregexInterceptor:
        self.srv = srv
        self.filter_map_lock = asyncio.Lock()
        self.update_config_lock = asyncio.Lock()
        input_range, output_range = await self._start_binary()
        queue_range = await self._start_binary()
        self.update_task = asyncio.create_task(self.update_blocked())
        nft.add(self.srv, input_range, output_range)
        nft.add(self.srv, queue_range)
        return self

    async def _start_binary(self):
@@ -87,7 +92,7 @@ class FiregexInterceptor:
        line = line_fut.decode()
        if line.startswith("QUEUES "):
            params = line.split()
            return (int(params[2]), int(params[3])), (int(params[5]), int(params[6]))
            return (int(params[1]), int(params[2]))
        else:
            self.process.kill()
            raise Exception("Invalid binary output")
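
For context: the filtering binary is now expected to announce a single queue range on stdout, e.g. a line like "QUEUES 1000 1003", whose two integers are parsed as the start and end of the NFQUEUE id range (the previous protocol announced two ranges, one for input and one for output).
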
@@ -102,8 +107,10 @@ class FiregexInterceptor:
            if regex_id in self.filter_map:
                self.filter_map[regex_id].blocked+=1
                await self.filter_map[regex_id].update()
        except asyncio.CancelledError: pass
        except asyncio.IncompleteReadError: pass
        except asyncio.CancelledError:
            pass
        except asyncio.IncompleteReadError:
            pass
        except Exception:
            traceback.print_exc()

@@ -135,6 +142,7 @@ class FiregexInterceptor:
                raw_filters = filter_obj.compile()
                for filter in raw_filters:
                    res[filter] = filter_obj
            except Exception: pass
            except Exception:
                pass
        return res


@@ -30,14 +30,15 @@ class ServiceManager:
        new_filters = set([f.id for f in regexes])
        #remove old filters
        for f in old_filters:
            if not f in new_filters:
            if f not in new_filters:
                del self.filters[f]
        #add new filters
        for f in new_filters:
            if not f in old_filters:
            if f not in old_filters:
                filter = [ele for ele in regexes if ele.id == f][0]
                self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
        if self.interceptor: await self.interceptor.reload(self.filters.values())
        if self.interceptor:
            await self.interceptor.reload(self.filters.values())

    def __update_status_db(self, status):
        self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id)
@@ -114,4 +115,5 @@ class FirewallManager:
        else:
            raise ServiceNotFoundException()

class ServiceNotFoundException(Exception): pass
class ServiceNotFoundException(Exception):
    pass

@@ -45,36 +45,35 @@ class FiregexTables(NFTableManager):
            {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.output_chain}}},
        ])

    def add(self, srv:Service, queue_range_input, queue_range_output):
    def add(self, srv:Service, queue_range):

        for ele in self.get():
            if ele.__eq__(srv): return

        init, end = queue_range_output
        init, end = queue_range
        if init > end: init, end = end, init
        self.cmd({ "insert":{ "rule": {
            "family": "inet",
            "table": self.table_name,
            "chain": self.output_chain,
            "expr": [
                {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
                {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
                {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
        self.cmd(
            { "insert":{ "rule": {
                "family": "inet",
                "table": self.table_name,
                "chain": self.output_chain,
                "expr": [
                    {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
                    {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "sport"}}, "op": "==", "right": int(srv.port)}},
                    {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
            ]
        }}})

        init, end = queue_range_input
        if init > end: init, end = end, init
        self.cmd({"insert":{"rule":{
            "family": "inet",
            "table": self.table_name,
            "chain": self.input_chain,
            "expr": [
                {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
                {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
                {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
            ]
        }}})
            }}},
            {"insert":{"rule":{
                "family": "inet",
                "table": self.table_name,
                "chain": self.input_chain,
                "expr": [
                    {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_int), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_int)}},
                    {'match': {"left": { "payload": {"protocol": str(srv.proto), "field": "dport"}}, "op": "==", "right": int(srv.port)}},
                    {"queue": {"num": str(init) if init == end else {"range":[init, end] }, "flags": ["bypass"]}}
                ]
            }}}
        )
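
Loosely translated into nft's CLI syntax (a sketch, modulo the JSON-to-CLI mapping done by libnftables), each inserted rule reads roughly as "insert rule inet <table> <chain> ip saddr <ip> tcp sport <port> queue num <start>-<end> bypass", with the mirrored daddr/dport variant on the input chain; the bypass flag makes nftables accept packets instead of dropping them when no userspace program is attached to the queue.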


    def get(self) -> list[FiregexFilter]:

@@ -5,7 +5,8 @@ from utils.sqlite import SQLite

nft = FiregexTables()

class ServiceNotFoundException(Exception): pass
class ServiceNotFoundException(Exception):
    pass

class ServiceManager:
    def __init__(self, srv: Service, db):
@@ -29,7 +30,8 @@ class ServiceManager:

    async def refresh(self, srv:Service):
        self.srv = srv
        if self.active: await self.restart()
        if self.active:
            await self.restart()

    def _set_status(self,active):
        self.active = active

@@ -50,7 +50,8 @@ class FiregexTables(NFTableManager):
    def add(self, srv:Service):

        for ele in self.get():
            if ele.__eq__(srv): return
            if ele.__eq__(srv):
                return

        self.cmd({ "insert":{ "rule": {
            "family": "inet",

@@ -1,116 +0,0 @@
import re, os, asyncio

class Filter:
    def __init__(self, regex, is_case_sensitive=True, is_blacklist=True, c_to_s=False, s_to_c=False, blocked_packets=0, code=None):
        self.regex = regex
        self.is_case_sensitive = is_case_sensitive
        self.is_blacklist = is_blacklist
        if c_to_s == s_to_c: c_to_s = s_to_c = True # (False, False) behaves like (True, True): filter both directions
        self.c_to_s = c_to_s
        self.s_to_c = s_to_c
        self.blocked = blocked_packets
        self.code = code

    def compile(self):
        if isinstance(self.regex, str): self.regex = self.regex.encode()
        if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
        re.compile(self.regex) # raises re.error if it's invalid!
        case_sensitive = "1" if self.is_case_sensitive else "0"
        if self.c_to_s:
            yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
        if self.s_to_c:
            yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
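        # As the two yields above imply, each compiled rule is encoded as
        # "<case_flag><direction><hex_regex>": case_flag is "1"/"0",
        # direction C/S marks client->server / server->client traffic, and
        # the letter is uppercase for blacklist rules, lowercase for whitelists.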

class Proxy:
    def __init__(self, internal_port=0, public_port=0, callback_blocked_update=None, filters=None, public_host="0.0.0.0", internal_host="127.0.0.1"):
        self.filter_map = {}
        self.filter_map_lock = asyncio.Lock()
        self.update_config_lock = asyncio.Lock()
        self.status_change = asyncio.Lock()
        self.public_host = public_host
        self.public_port = public_port
        self.internal_host = internal_host
        self.internal_port = internal_port
        self.filters = set(filters) if filters else set([])
        self.process = None
        self.callback_blocked_update = callback_blocked_update

    async def start(self, in_pause=False):
        await self.status_change.acquire()
        if not self.isactive():
            try:
                self.filter_map = self.compile_filters()
                filters_codes = self.get_filter_codes() if not in_pause else []
                proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"../proxy")

                self.process = await asyncio.create_subprocess_exec(
                    proxy_binary_path, str(self.public_host), str(self.public_port), str(self.internal_host), str(self.internal_port),
                    stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
                )
                await self.update_config(filters_codes)
            finally:
                self.status_change.release()
            try:
                while True:
                    buff = await self.process.stdout.readuntil()
                    stdout_line = buff.decode()
                    if stdout_line.startswith("BLOCKED"):
                        regex_id = stdout_line.split()[1]
                        async with self.filter_map_lock:
                            if regex_id in self.filter_map:
                                self.filter_map[regex_id].blocked+=1
                                if self.callback_blocked_update: self.callback_blocked_update(self.filter_map[regex_id])
            except Exception:
                return await self.process.wait()
        else:
            self.status_change.release()


    async def stop(self):
        async with self.status_change:
            if self.isactive():
                self.process.kill()
                return False
        return True

    async def restart(self, in_pause=False):
        status = await self.stop()
        await self.start(in_pause=in_pause)
        return status

    async def update_config(self, filters_codes):
        async with self.update_config_lock:
            if (self.isactive()):
                self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
                await self.process.stdin.drain()

    async def reload(self):
        if self.isactive():
            async with self.filter_map_lock:
                self.filter_map = self.compile_filters()
                filters_codes = self.get_filter_codes()
                await self.update_config(filters_codes)

    def get_filter_codes(self):
        filters_codes = list(self.filter_map.keys())
        filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
        return filters_codes

    def isactive(self):
        return self.process and self.process.returncode is None

    async def pause(self):
        if self.isactive():
            await self.update_config([])
        else:
            await self.start(in_pause=True)

    def compile_filters(self):
        res = {}
        for filter_obj in self.filters:
            try:
                raw_filters = filter_obj.compile()
                for filter in raw_filters:
                    res[filter] = filter_obj
            except Exception: pass
        return res
@@ -1,199 +0,0 @@
import secrets
from modules.regexproxy.proxy import Filter, Proxy
import random, socket, asyncio
from base64 import b64decode
from utils.sqlite import SQLite
from utils import socketio_emit

class STATUS:
    WAIT = "wait"
    STOP = "stop"
    PAUSE = "pause"
    ACTIVE = "active"

class ServiceNotFoundException(Exception): pass

class ServiceManager:
    def __init__(self, id, db):
        self.id = id
        self.db = db
        self.proxy = Proxy(
            internal_host="127.0.0.1",
            callback_blocked_update=self._stats_updater
        )
        self.status = STATUS.STOP
        self.wanted_status = STATUS.STOP
        self.filters = {}
        self._update_port_from_db()
        self._update_filters_from_db()
        self.lock = asyncio.Lock()
        self.starter = None

    def _update_port_from_db(self):
        res = self.db.query("""
            SELECT
                public_port,
                internal_port
            FROM services WHERE service_id = ?;
        """, self.id)
        if len(res) == 0: raise ServiceNotFoundException()
        self.proxy.internal_port = res[0]["internal_port"]
        self.proxy.public_port = res[0]["public_port"]

    def _update_filters_from_db(self):
        res = self.db.query("""
            SELECT
                regex, mode, regex_id `id`, is_blacklist,
                blocked_packets n_packets, is_case_sensitive
            FROM regexes WHERE service_id = ? AND active=1;
        """, self.id)

        #Filter check
        old_filters = set(self.filters.keys())
        new_filters = set([f["id"] for f in res])

        #remove old filters
        for f in old_filters:
            if not f in new_filters:
                del self.filters[f]

        for f in new_filters:
            if not f in old_filters:
                filter_info = [ele for ele in res if ele["id"] == f][0]
                self.filters[f] = Filter(
                    is_case_sensitive=filter_info["is_case_sensitive"],
                    c_to_s=filter_info["mode"] in ["C","B"],
                    s_to_c=filter_info["mode"] in ["S","B"],
                    is_blacklist=filter_info["is_blacklist"],
                    regex=b64decode(filter_info["regex"]),
                    blocked_packets=filter_info["n_packets"],
                    code=f
                )
        self.proxy.filters = list(self.filters.values())

    def __update_status_db(self, status):
        self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.id)

    async def next(self,to):
        async with self.lock:
            return await self._next(to)

    async def _next(self, to):
        if self.status != to:
            # ACTIVE -> PAUSE or PAUSE -> ACTIVE
            if (self.status, to) in [(STATUS.ACTIVE, STATUS.PAUSE)]:
                await self.proxy.pause()
                self._set_status(to)

            elif (self.status, to) in [(STATUS.PAUSE, STATUS.ACTIVE)]:
                await self.proxy.reload()
                self._set_status(to)

            # ACTIVE -> STOP
            elif (self.status,to) in [(STATUS.ACTIVE, STATUS.STOP), (STATUS.WAIT, STATUS.STOP), (STATUS.PAUSE, STATUS.STOP)]: #Stop proxy
                if self.starter: self.starter.cancel()
                await self.proxy.stop()
                self._set_status(to)

            # STOP -> ACTIVE or STOP -> PAUSE
            elif (self.status, to) in [(STATUS.STOP, STATUS.ACTIVE), (STATUS.STOP, STATUS.PAUSE)]:
                self.wanted_status = to
                self._set_status(STATUS.WAIT)
                self.__proxy_starter(to)


    def _stats_updater(self,filter:Filter):
        self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.code)

    async def update_port(self):
        async with self.lock:
            self._update_port_from_db()
            if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
                next_status = self.status if self.status != STATUS.WAIT else self.wanted_status
                await self._next(STATUS.STOP)
                await self._next(next_status)

    def _set_status(self,status):
        self.status = status
        self.__update_status_db(status)


    async def update_filters(self):
        async with self.lock:
            self._update_filters_from_db()
            if self.status in [STATUS.PAUSE, STATUS.ACTIVE]:
                await self.proxy.reload()

    def __proxy_starter(self,to):
        async def func():
            try:
                while True:
                    if check_port_is_open(self.proxy.public_port):
                        self._set_status(to)
                        await socketio_emit(["regexproxy"])
                        await self.proxy.start(in_pause=(to==STATUS.PAUSE))
                        self._set_status(STATUS.STOP)
                        await socketio_emit(["regexproxy"])
                        return
                    else:
                        await asyncio.sleep(.5)
            except asyncio.CancelledError:
                self._set_status(STATUS.STOP)
                await self.proxy.stop()
        self.starter = asyncio.create_task(func())

class ProxyManager:
    def __init__(self, db:SQLite):
        self.db = db
        self.proxy_table: dict[str, ServiceManager] = {}
        self.lock = asyncio.Lock()

    async def close(self):
        for key in list(self.proxy_table.keys()):
            await self.remove(key)

    async def remove(self,id):
        async with self.lock:
            if id in self.proxy_table:
                await self.proxy_table[id].next(STATUS.STOP)
                del self.proxy_table[id]

    async def reload(self):
        async with self.lock:
            for srv in self.db.query('SELECT service_id, status FROM services;'):
                srv_id, req_status = srv["service_id"], srv["status"]
                if srv_id in self.proxy_table:
                    continue

                self.proxy_table[srv_id] = ServiceManager(srv_id,self.db)
                await self.proxy_table[srv_id].next(req_status)

    def get(self,id) -> ServiceManager:
        if id in self.proxy_table:
            return self.proxy_table[id]
        else:
            raise ServiceNotFoundException()

def check_port_is_open(port):
    try:
        sock = socket.socket()
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('0.0.0.0',port))
        sock.close()
        return True
    except Exception:
        return False

def gen_service_id(db):
    while True:
        res = secrets.token_hex(8)
        if len(db.query('SELECT 1 FROM services WHERE service_id = ?;', res)) == 0:
            break
    return res

def gen_internal_port(db):
    while True:
        res = random.randint(30000, 45000)
        if len(db.query('SELECT 1 FROM services WHERE internal_port = ?;', res)) == 0:
            break
    return res
@@ -5,7 +5,7 @@ from utils import ip_parse, ip_family, socketio_emit
from utils.models import ResetRequest, StatusMessageModel
from modules.firewall.nftables import FiregexTables
from modules.firewall.firewall import FirewallManager
from modules.firewall.models import *
from modules.firewall.models import FirewallSettings, RuleInfo, RuleModel, RuleFormAdd, Mode, Table

db = SQLite('db/firewall-rules.db', {
    'rules': {

@@ -147,7 +147,8 @@ async def get_service_by_id(service_id: str):
        FROM services s LEFT JOIN regexes r ON s.service_id = r.service_id
        WHERE s.service_id = ? GROUP BY s.service_id;
    """, service_id)
    if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
    if len(res) == 0:
        raise HTTPException(status_code=400, detail="This service does not exists!")
    return res[0]

@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
@@ -177,7 +178,8 @@ async def service_delete(service_id: str):
async def service_rename(service_id: str, form: RenameForm):
    """Request to change the name of a specific service"""
    form.name = refactor_name(form.name)
    if not form.name: raise HTTPException(status_code=400, detail="The name cannot be empty!")
    if not form.name:
        raise HTTPException(status_code=400, detail="The name cannot be empty!")
    try:
        db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
    except sqlite3.IntegrityError:
@@ -188,7 +190,8 @@ async def service_rename(service_id: str, form: RenameForm):
@app.get('/service/{service_id}/regexes', response_model=list[RegexModel])
async def get_service_regexe_list(service_id: str):
    """Get the list of the regexes of a service"""
    if not db.query("SELECT 1 FROM services s WHERE s.service_id = ?;", service_id): raise HTTPException(status_code=400, detail="This service does not exists!")
    if not db.query("SELECT 1 FROM services s WHERE s.service_id = ?;", service_id):
        raise HTTPException(status_code=400, detail="This service does not exists!")
    return db.query("""
        SELECT
            regex, mode, regex_id `id`, service_id, is_blacklist,
@@ -205,7 +208,8 @@ async def get_regex_by_id(regex_id: int):
            blocked_packets n_packets, is_case_sensitive, active
        FROM regexes WHERE `id` = ?;
    """, regex_id)
    if len(res) == 0: raise HTTPException(status_code=400, detail="This regex does not exists!")
    if len(res) == 0:
        raise HTTPException(status_code=400, detail="This regex does not exists!")
    return res[0]

@app.get('/regex/{regex_id}/delete', response_model=StatusMessageModel)

@@ -96,7 +96,8 @@ async def get_service_list():
async def get_service_by_id(service_id: str):
    """Get info about a specific service using its id"""
    res = db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_src, ip_dst FROM services WHERE service_id = ?;", service_id)
    if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
    if len(res) == 0:
        raise HTTPException(status_code=400, detail="This service does not exists!")
    return res[0]

@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
@@ -125,7 +126,8 @@ async def service_delete(service_id: str):
async def service_rename(service_id: str, form: RenameForm):
    """Request to change the name of a specific service"""
    form.name = refactor_name(form.name)
    if not form.name: raise HTTPException(status_code=400, detail="The name cannot be empty!")
    if not form.name:
        raise HTTPException(status_code=400, detail="The name cannot be empty!")
    try:
        db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
    except sqlite3.IntegrityError:

@@ -1,311 +0,0 @@
from base64 import b64decode
import sqlite3, re
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from modules.regexproxy.utils import STATUS, ProxyManager, gen_internal_port, gen_service_id
from utils.sqlite import SQLite
from utils.models import ResetRequest, StatusMessageModel
from utils import refactor_name, socketio_emit, PortType

app = APIRouter()
db = SQLite("db/regextcpproxy.db",{
    'services': {
        'status': 'VARCHAR(100) NOT NULL',
        'service_id': 'VARCHAR(100) PRIMARY KEY',
        'internal_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536)',
        'public_port': 'INT NOT NULL CHECK(internal_port > 0 and internal_port < 65536) UNIQUE',
        'name': 'VARCHAR(100) NOT NULL UNIQUE'
    },
    'regexes': {
        'regex': 'TEXT NOT NULL',
        'mode': 'VARCHAR(1) NOT NULL',
        'service_id': 'VARCHAR(100) NOT NULL',
        'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
        'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
        'regex_id': 'INTEGER PRIMARY KEY',
        'is_case_sensitive' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1))',
        'active' : 'BOOLEAN NOT NULL CHECK (is_case_sensitive IN (0, 1)) DEFAULT 1',
        'FOREIGN KEY (service_id)':'REFERENCES services (service_id)',
    },
    'QUERY':[
        "CREATE UNIQUE INDEX IF NOT EXISTS unique_regex_service ON regexes (regex,service_id,is_blacklist,mode,is_case_sensitive);"
    ]
})

firewall = ProxyManager(db)

async def reset(params: ResetRequest):
    if not params.delete:
        db.backup()
    await firewall.close()
    if params.delete:
        db.delete()
        db.init()
    else:
        db.restore()
    await firewall.reload()


async def startup():
    db.init()
    await firewall.reload()

async def shutdown():
    db.backup()
    await firewall.close()
    db.disconnect()
    db.restore()

async def refresh_frontend(additional:list[str]=[]):
    await socketio_emit(["regexproxy"]+additional)

class GeneralStatModel(BaseModel):
    closed:int
    regexes: int
    services: int

@app.get('/stats', response_model=GeneralStatModel)
async def get_general_stats():
    """Get general firegex stats about services"""
    return db.query("""
        SELECT
            (SELECT COALESCE(SUM(blocked_packets),0) FROM regexes) closed,
            (SELECT COUNT(*) FROM regexes) regexes,
            (SELECT COUNT(*) FROM services) services
    """)[0]

class ServiceModel(BaseModel):
    id:str
    status: str
    public_port: PortType
    internal_port: PortType
    name: str
    n_regex: int
    n_packets: int

@app.get('/services', response_model=list[ServiceModel])
async def get_service_list():
    """Get the list of existing firegex services"""
    return db.query("""
        SELECT
            s.service_id `id`,
            s.status status,
            s.public_port public_port,
            s.internal_port internal_port,
            s.name name,
            COUNT(r.regex_id) n_regex,
            COALESCE(SUM(r.blocked_packets),0) n_packets
        FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id
        GROUP BY s.service_id;
    """)

@app.get('/service/{service_id}', response_model=ServiceModel)
async def get_service_by_id(service_id: str):
    """Get info about a specific service using its id"""
    res = db.query("""
        SELECT
            s.service_id `id`,
            s.status status,
            s.public_port public_port,
            s.internal_port internal_port,
            s.name name,
            COUNT(r.regex_id) n_regex,
            COALESCE(SUM(r.blocked_packets),0) n_packets
        FROM services s LEFT JOIN regexes r ON r.service_id = s.service_id WHERE s.service_id = ?
        GROUP BY s.service_id;
    """, service_id)
    if len(res) == 0: raise HTTPException(status_code=400, detail="This service does not exists!")
    return res[0]

@app.get('/service/{service_id}/stop', response_model=StatusMessageModel)
async def service_stop(service_id: str):
    """Request the stop of a specific service"""
    await firewall.get(service_id).next(STATUS.STOP)
    await refresh_frontend()
    return {'status': 'ok'}

@app.get('/service/{service_id}/pause', response_model=StatusMessageModel)
async def service_pause(service_id: str):
    """Request the pause of a specific service"""
    await firewall.get(service_id).next(STATUS.PAUSE)
    await refresh_frontend()
    return {'status': 'ok'}

@app.get('/service/{service_id}/start', response_model=StatusMessageModel)
async def service_start(service_id: str):
    """Request the start of a specific service"""
    await firewall.get(service_id).next(STATUS.ACTIVE)
    await refresh_frontend()
    return {'status': 'ok'}

@app.get('/service/{service_id}/delete', response_model=StatusMessageModel)
async def service_delete(service_id: str):
    """Request the deletion of a specific service"""
    db.query('DELETE FROM services WHERE service_id = ?;', service_id)
    db.query('DELETE FROM regexes WHERE service_id = ?;', service_id)
    await firewall.remove(service_id)
    await refresh_frontend()
    return {'status': 'ok'}


@app.get('/service/{service_id}/regen-port', response_model=StatusMessageModel)
async def regen_service_port(service_id: str):
    """Request the regeneration of the internal proxy port of a specific service"""
    db.query('UPDATE services SET internal_port = ? WHERE service_id = ?;', gen_internal_port(db), service_id)
    await firewall.get(service_id).update_port()
    await refresh_frontend()
    return {'status': 'ok'}

class ChangePortForm(BaseModel):
    port: int|None = None
    internalPort: int|None = None

@app.post('/service/{service_id}/change-ports', response_model=StatusMessageModel)
async def change_service_ports(service_id: str, change_port:ChangePortForm):
    """Choose and change the ports of the service"""
    if change_port.port is None and change_port.internalPort is None:
        raise HTTPException(status_code=400, detail="Invalid Request!")
    try:
        sql_inj = ""
        query:list[str|int] = []
        if not change_port.port is None:
            sql_inj+=" public_port = ? "
            query.append(change_port.port)
        if not change_port.port is None and not change_port.internalPort is None:
            sql_inj += ","
        if not change_port.internalPort is None:
            sql_inj+=" internal_port = ? "
            query.append(change_port.internalPort)
        query.append(service_id)
        db.query(f'UPDATE services SET {sql_inj} WHERE service_id = ?;', *query)
    except sqlite3.IntegrityError:
        raise HTTPException(status_code=400, detail="Port of the service has been already assigned to another service")
    await firewall.get(service_id).update_port()
    await refresh_frontend()
    return {'status': 'ok'}

class RegexModel(BaseModel):
    regex:str
    mode:str
    id:int
    service_id:str
    is_blacklist: bool
    n_packets:int
    is_case_sensitive:bool
    active:bool

@app.get('/service/{service_id}/regexes', response_model=list[RegexModel])
async def get_service_regexe_list(service_id: str):
    """Get the list of the regexes of a service"""
    if not db.query("SELECT 1 FROM services s WHERE s.service_id = ?;", service_id): raise HTTPException(status_code=400, detail="This service does not exists!")
    return db.query("""
        SELECT
            regex, mode, regex_id `id`, service_id, is_blacklist,
            blocked_packets n_packets, is_case_sensitive, active
        FROM regexes WHERE service_id = ?;
    """, service_id)

@app.get('/regex/{regex_id}', response_model=RegexModel)
async def get_regex_by_id(regex_id: int):
    """Get regex info using its id"""
    res = db.query("""
        SELECT
            regex, mode, regex_id `id`, service_id, is_blacklist,
            blocked_packets n_packets, is_case_sensitive, active
        FROM regexes WHERE `id` = ?;
    """, regex_id)
    if len(res) == 0: raise HTTPException(status_code=400, detail="This regex does not exists!")
    return res[0]

@app.get('/regex/{regex_id}/delete', response_model=StatusMessageModel)
async def regex_delete(regex_id: int):
    """Delete a regex using its id"""
    res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
    if len(res) != 0:
        db.query('DELETE FROM regexes WHERE regex_id = ?;', regex_id)
        await firewall.get(res[0]["service_id"]).update_filters()
        await refresh_frontend()
    return {'status': 'ok'}

@app.get('/regex/{regex_id}/enable', response_model=StatusMessageModel)
async def regex_enable(regex_id: int):
    """Request the enabling of a regex"""
    res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
    if len(res) != 0:
        db.query('UPDATE regexes SET active=1 WHERE regex_id = ?;', regex_id)
        await firewall.get(res[0]["service_id"]).update_filters()
        await refresh_frontend()
    return {'status': 'ok'}

@app.get('/regex/{regex_id}/disable', response_model=StatusMessageModel)
async def regex_disable(regex_id: int):
    """Request the deactivation of a regex"""
    res = db.query('SELECT * FROM regexes WHERE regex_id = ?;', regex_id)
    if len(res) != 0:
        db.query('UPDATE regexes SET active=0 WHERE regex_id = ?;', regex_id)
        await firewall.get(res[0]["service_id"]).update_filters()
        await refresh_frontend()
    return {'status': 'ok'}

class RegexAddForm(BaseModel):
    service_id: str
    regex: str
    mode: str
    active: bool|None = None
    is_blacklist: bool
    is_case_sensitive: bool

@app.post('/regexes/add', response_model=StatusMessageModel)
async def add_new_regex(form: RegexAddForm):
    """Add a new regex"""
    try:
        re.compile(b64decode(form.regex))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid regex")
    try:
        db.query("INSERT INTO regexes (service_id, regex, is_blacklist, mode, is_case_sensitive, active ) VALUES (?, ?, ?, ?, ?, ?);",
            form.service_id, form.regex, form.is_blacklist, form.mode, form.is_case_sensitive, True if form.active is None else form.active )
    except sqlite3.IntegrityError:
        raise HTTPException(status_code=400, detail="An identical regex already exists")
    await firewall.get(form.service_id).update_filters()
    await refresh_frontend()
    return {'status': 'ok'}

class ServiceAddForm(BaseModel):
    name: str
    port: PortType
    internalPort: int|None = None

class ServiceAddStatus(BaseModel):
    status:str
    id: str|None = None

class RenameForm(BaseModel):
    name:str

@app.post('/service/{service_id}/rename', response_model=StatusMessageModel)
async def service_rename(service_id: str, form: RenameForm):
    """Request to change the name of a specific service"""
    form.name = refactor_name(form.name)
    if not form.name: raise HTTPException(status_code=400, detail="The name cannot be empty!")
    try:
        db.query('UPDATE services SET name=? WHERE service_id = ?;', form.name, service_id)
    except sqlite3.IntegrityError:
        raise HTTPException(status_code=400, detail="The name is already used!")
    await refresh_frontend()
    return {'status': 'ok'}

@app.post('/services/add', response_model=ServiceAddStatus)
async def add_new_service(form: ServiceAddForm):
    """Add a new service"""
    serv_id = gen_service_id(db)
    form.name = refactor_name(form.name)
    try:
        internal_port = form.internalPort if form.internalPort else gen_internal_port(db)
        db.query("INSERT INTO services (name, service_id, internal_port, public_port, status) VALUES (?, ?, ?, ?, ?)",
            form.name, serv_id, internal_port, form.port, 'stop')
    except sqlite3.IntegrityError:
        raise HTTPException(status_code=400, detail="Name or/and ports of the service has been already assigned to another service")
    await firewall.reload()
    await refresh_frontend()
    return {'status': 'ok', "id": serv_id }
@@ -1,10 +1,13 @@
import asyncio
from ipaddress import ip_address, ip_interface
import os, socket, psutil, sys, nftables
import os
import socket
import psutil
import sys
import nftables
from fastapi_socketio import SocketManager
from fastapi import Path
from typing import Annotated
import json

LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))

@@ -31,7 +34,8 @@ async def socketio_emit(elements:list[str]):

def refactor_name(name:str):
    name = name.strip()
    while "  " in name: name = name.replace("  "," ")
    while "  " in name:
        name = name.replace("  "," ")
    return name

class SysctlManager:
@@ -125,8 +129,10 @@ class NFTableManager(Singleton):

    def cmd(self, *cmds):
        code, out, err = self.raw_cmd(*cmds)
        if code == 0: return out
        else: raise Exception(err)
        if code == 0:
            return out
        else:
            raise Exception(err)

    def init(self):
        self.reset()
@@ -138,8 +144,10 @@ class NFTableManager(Singleton):

    def list_rules(self, tables = None, chains = None):
        for filter in [ele["rule"] for ele in self.raw_list() if "rule" in ele ]:
            if tables and filter["table"] not in tables: continue
            if chains and filter["chain"] not in chains: continue
            if tables and filter["table"] not in tables:
                continue
            if chains and filter["chain"] not in chains:
                continue
            yield filter

    def raw_list(self):

@@ -1,5 +1,6 @@

import os, httpx
import os
import httpx
from typing import Callable
from fastapi import APIRouter
from starlette.responses import StreamingResponse
@@ -31,7 +32,8 @@ def frontend_deploy(app):
            return await frontend_debug_proxy(full_path)
        except Exception:
            return {"details":"Frontend not started at "+f"http://127.0.0.1:{os.getenv('F_PORT','5173')}"}
    else: return await react_deploy(full_path)
    else:
        return await react_deploy(full_path)

def list_routers():
    return [ele[:-3] for ele in list_files(ROUTERS_DIR) if ele != "__init__.py" and " " not in ele and ele.endswith(".py")]
@@ -79,9 +81,12 @@ def load_routers(app):
        if router.shutdown:
            shutdowns.append(router.shutdown)
    async def reset(reset_option:ResetRequest):
        for func in resets: await run_func(func, reset_option)
        for func in resets:
            await run_func(func, reset_option)
    async def startup():
        for func in startups: await run_func(func)
        for func in startups:
            await run_func(func)
    async def shutdown():
        for func in shutdowns: await run_func(func)
        for func in shutdowns:
            await run_func(func)
    return reset, startup, shutdown

@@ -1,4 +1,6 @@
import json, sqlite3, os
import json
import sqlite3
import os
from hashlib import md5

class SQLite():
@@ -15,8 +17,10 @@ class SQLite():
            self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
        except Exception:
            path_name = os.path.dirname(self.db_name)
            if not os.path.exists(path_name): os.makedirs(path_name)
            with open(self.db_name, 'x'): pass
            if not os.path.exists(path_name):
                os.makedirs(path_name)
            with open(self.db_name, 'x'):
                pass
            self.conn = sqlite3.connect(self.db_name, check_same_thread = False)
    def dict_factory(cursor, row):
        d = {}
@@ -36,13 +40,15 @@ class SQLite():
            with open(self.db_name, "wb") as f:
                f.write(self.__backup)
            self.__backup = None
        if were_active: self.connect()
        if were_active:
            self.connect()

    def delete_backup(self):
        self.__backup = None

    def disconnect(self) -> None:
        if self.conn: self.conn.close()
        if self.conn:
            self.conn.close()
        self.conn = None

    def create_schema(self, tables = {}) -> None:
@@ -50,9 +56,11 @@ class SQLite():
        cur = self.conn.cursor()
        cur.execute("CREATE TABLE IF NOT EXISTS main.keys_values(key VARCHAR(100) PRIMARY KEY, value VARCHAR(100) NOT NULL);")
        for t in tables:
            if t == "QUERY": continue
            if t == "QUERY":
                continue
            cur.execute('CREATE TABLE IF NOT EXISTS main.{}({});'.format(t, ''.join([(c + ' ' + tables[t][c] + ', ') for c in tables[t]])[:-2]))
        if "QUERY" in tables: [cur.execute(qry) for qry in tables["QUERY"]]
        if "QUERY" in tables:
            [cur.execute(qry) for qry in tables["QUERY"]]
        cur.close()

    def query(self, query, *values):
@@ -82,8 +90,10 @@ class SQLite():
            raise e
        finally:
            cur.close()
            try: self.conn.commit()
            except Exception: pass
            try:
                self.conn.commit()
            except Exception:
                pass

    def delete(self):
        self.disconnect()
@@ -92,7 +102,8 @@ class SQLite():
    def init(self):
        self.connect()
        try:
            if self.get('DB_VERSION') != self.DB_VER: raise Exception("DB_VERSION is not correct")
            if self.get('DB_VERSION') != self.DB_VER:
                raise Exception("DB_VERSION is not correct")
        except Exception:
            self.delete()
            self.connect()