Python integration with the C++ binary (not fully working yet)

This commit is contained in:
DomySh
2022-07-18 18:52:14 +02:00
parent 02fe8f0064
commit 2a5be65feb
12 changed files with 594 additions and 866 deletions

6
.gitignore vendored
View File

@@ -1,7 +1,9 @@
**/*.pyc **/*.pyc
**/__pycache__/ **/__pycache__/
**/.vscode/** **/.vscode/**
/.mypy_cache/** **/.vscode/
**/.mypy_cache/**
**/.mypy_cache/
**/node_modules **/node_modules
**/.pnp **/.pnp
@@ -12,7 +14,7 @@
/backend/db/firegex.db /backend/db/firegex.db
/backend/db/firegex.db-journal /backend/db/firegex.db-journal
/backend/nfqueue/nfqueue /backend/modules/cppqueue
docker-compose.yml docker-compose.yml
# misc # misc

View File

@@ -11,12 +11,12 @@ RUN git clone --branch release https://github.com/jpcre2/jpcre2
WORKDIR /tmp/jpcre2 WORKDIR /tmp/jpcre2
RUN ./configure; make; make install RUN ./configure; make; make install
RUN mkdir /execute/ RUN mkdir -p /execute/modules
WORKDIR /execute WORKDIR /execute
COPY ./backend/nfqueue /execute/nfqueue COPY ./backend/nfqueue /execute/nfqueue
RUN g++ nfqueue/nfqueue.cpp -o nfqueue/nfqueue -O3 -march=native -lnetfilter_queue -pthread -lpcre2-8 -ltins -lmnl -lnfnetlink RUN g++ nfqueue/nfqueue.cpp -o modules/cppqueue -std=c++20 -O3 -march=native -lnetfilter_queue -pthread -lpcre2-8 -ltins -lmnl -lnfnetlink
ADD ./backend/requirements.txt /execute/requirements.txt ADD ./backend/requirements.txt /execute/requirements.txt
RUN pip install --no-cache-dir -r /execute/requirements.txt RUN pip install --no-cache-dir -r /execute/requirements.txt

View File

@@ -48,7 +48,7 @@ async def updater(): pass
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
db.init() db.init()
await firewall.init(refresh_frontend) await firewall.init()
await refresh_frontend() await refresh_frontend()
if not JWT_SECRET(): db.put("secret", secrets.token_hex(32)) if not JWT_SECRET(): db.put("secret", secrets.token_hex(32))

View File

@@ -1,12 +1,11 @@
from typing import List from typing import Dict, List, Set
from pypacker import interceptor
from pypacker.layer3 import ip, ip6
from pypacker.layer4 import tcp, udp
from ipaddress import ip_interface from ipaddress import ip_interface
from modules.iptables import IPTables from modules.iptables import IPTables
import os, traceback
from modules.sqlite import Service from modules.sqlite import Service
import re, os, asyncio
import traceback
from modules.sqlite import Regex
class FilterTypes: class FilterTypes:
INPUT = "FIREGEX-INPUT" INPUT = "FIREGEX-INPUT"
@@ -15,14 +14,13 @@ class FilterTypes:
QUEUE_BASE_NUM = 1000 QUEUE_BASE_NUM = 1000
class FiregexFilter(): class FiregexFilter():
def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None): def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None):
self.target = target self.target = target
self.id = int(id) if id else None self.id = int(id) if id else None
self.queue = queue self.queue = queue
self.proto = proto self.proto = proto
self.port = int(port) self.port = int(port)
self.ip_int = str(ip_int) self.ip_int = str(ip_int)
self.func = func
def __eq__(self, o: object) -> bool: def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter): if isinstance(o, FiregexFilter):
@@ -35,16 +33,6 @@ class FiregexFilter():
def ipv4(self): def ipv4(self):
return ip_interface(self.ip_int).version == 4 return ip_interface(self.ip_int).version == 4
def input_func(self):
def none(pkt): return True
def wrap(pkt): return self.func(pkt, True)
return wrap if self.func else none
def output_func(self):
def none(pkt): return True
def wrap(pkt): return self.func(pkt, False)
return wrap if self.func else none
class FiregexTables(IPTables): class FiregexTables(IPTables):
def __init__(self, ipv6=False): def __init__(self, ipv6=False):
@@ -108,9 +96,9 @@ class FiregexTables(IPTables):
)) ))
return res return res
def add(self, filter:FiregexFilter): async def add(self, filter:FiregexFilter):
if filter in self.get(): return None if filter in self.get(): return None
return FiregexInterceptor( iptables=self, filter=filter, n_threads=int(os.getenv("N_THREADS_NFQUEUE","1"))) return await FiregexInterceptor.start( iptables=self, filter=filter, n_queues=int(os.getenv("N_THREADS_NFQUEUE","1")))
def delete_all(self): def delete_all(self):
for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]:
@@ -120,52 +108,143 @@ class FiregexTables(IPTables):
for filter in self.get(): for filter in self.get():
if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int): if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int):
self.delete_rule(filter.target, filter.id) self.delete_rule(filter.target, filter.id)
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None,
update_func = None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.update_func = update_func
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex, update_func = None):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"],
update_func = update_func
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
re.compile(self.regex) # raise re.error if it's invalid!
case_sensitive = "1" if self.is_case_sensitive else "0"
if self.input_mode:
yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex()
if self.output_mode:
yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex()
async def update(self):
if self.update_func:
if asyncio.iscoroutinefunction(self.update_func): await self.update_func(self)
else: self.update_func(self)
class FiregexInterceptor: class FiregexInterceptor:
def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1):
def __init__(self):
self.filter:FiregexFilter
self.ipv6:bool
self.filter_map_lock:asyncio.Lock
self.filter_map: Dict[str, RegexFilter]
self.regex_filters: Set[RegexFilter]
self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process
self.n_queues:int
self.update_task: asyncio.Task
self.iptables:FiregexTables
@classmethod
async def start(cls, iptables: FiregexTables, filter: FiregexFilter, n_queues:int = 1):
self = cls()
self.filter = filter self.filter = filter
self.n_queues = n_queues
self.iptables = iptables
self.ipv6 = self.filter.ipv6() self.ipv6 = self.filter.ipv6()
self.itor_input, codes = self._start_queue(filter.input_func(), n_threads) self.filter_map_lock = asyncio.Lock()
iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) self.update_config_lock = asyncio.Lock()
self.itor_output, codes = self._start_queue(filter.output_func(), n_threads) input_range, output_range = await self._start_binary()
iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) self.update_task = asyncio.create_task(self.update_blocked())
self.iptables.add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
self.iptables.add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int)
return self
async def _start_binary(self):
proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./cppqueue")
self.process = await asyncio.create_subprocess_exec(
proxy_binary_path, str(self.n_queues),
stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE
)
line_fut = self.process.stdout.readuntil()
try:
line_fut = await asyncio.wait_for(line_fut, timeout=1)
except asyncio.TimeoutError:
self.process.kill()
raise Exception("Invalid binary output")
line = line_fut.decode()
if line.startswith("QUEUES "):
params = line.split()
return (int(params[2]), int(params[3])), (int(params[5]), int(params[6]))
else:
self.process.kill()
raise Exception("Invalid binary output")
def _start_queue(self,func,n_threads): async def update_blocked(self):
def func_wrap(ll_data, ll_proto_id, data, ctx, *args): try:
pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data) while True:
try: line = (await self.process.stdout.readuntil()).decode()
pkt_data = None if line.startswith("BLOCKED"):
if not pkt_parsed[tcp.TCP] is None: regex_id = line.split()[1]
pkt_data = pkt_parsed[tcp.TCP].body_bytes async with self.filter_map_lock:
elif not pkt_parsed[udp.UDP] is None: if regex_id in self.filter_map:
pkt_data = pkt_parsed[udp.UDP].body_bytes self.filter_map[regex_id].blocked+=1
if pkt_data: await self.filter_map[regex_id].update()
if func(pkt_data): except asyncio.CancelledError: pass
return data, interceptor.NF_ACCEPT except asyncio.IncompleteReadError: pass
elif pkt_parsed[tcp.TCP]: except Exception:
pkt_parsed[tcp.TCP].flags &= 0x00 traceback.print_exc()
pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK
pkt_parsed[tcp.TCP].body_bytes = b""
return pkt_parsed.bin(), interceptor.NF_ACCEPT
else: return b"", interceptor.NF_DROP
else: return data, interceptor.NF_ACCEPT
except Exception:
traceback.print_exc()
return data, interceptor.NF_ACCEPT
ictor = interceptor.Interceptor()
starts = QUEUE_BASE_NUM
while True:
if starts >= 65536:
raise Exception("Netfilter queue is full!")
queue_ids = list(range(starts,starts+n_threads))
try:
ictor.start(func_wrap, queue_ids=queue_ids)
break
except interceptor.UnableToBindException as e:
starts = e.queue_id + 1
return ictor, (starts, starts+n_threads-1)
def stop(self): async def stop(self):
self.itor_input.stop() self.update_task.cancel()
self.itor_output.stop() self.process.kill()
async def _update_config(self, filters_codes):
async with self.update_config_lock:
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain()
async def reload(self, filters:List[RegexFilter]):
async with self.filter_map_lock:
self.filter_map = self.compile_filters(filters)
filters_codes = self.get_filter_codes()
await self._update_config(filters_codes)
def get_filter_codes(self):
filters_codes = list(self.filter_map.keys())
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes
def compile_filters(self, filters:List[RegexFilter]):
res = {}
for filter_obj in filters:
try:
raw_filters = filter_obj.compile()
for filter in raw_filters:
res[filter] = filter_obj
except Exception: pass
return res

View File

@@ -1,6 +1,6 @@
import traceback, asyncio, pcre import traceback, asyncio
from typing import Dict from typing import Dict
from modules.firegex import FiregexFilter, FiregexTables from modules.firegex import FiregexFilter, FiregexTables, RegexFilter
from modules.sqlite import Regex, SQLite, Service from modules.sqlite import Regex, SQLite, Service
class STATUS: class STATUS:
@@ -12,17 +12,8 @@ class FirewallManager:
self.db = db self.db = db
self.proxy_table: Dict[str, ServiceManager] = {} self.proxy_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.updater_task = None
def init_updater(self, callback = None):
if not self.updater_task:
self.updater_task = asyncio.create_task(self._stats_updater(callback))
def close_updater(self):
if self.updater_task: self.updater_task.cancel()
async def close(self): async def close(self):
self.close_updater()
if self.updater_task: self.updater_task.cancel() if self.updater_task: self.updater_task.cancel()
for key in list(self.proxy_table.keys()): for key in list(self.proxy_table.keys()):
await self.remove(key) await self.remove(key)
@@ -33,8 +24,7 @@ class FirewallManager:
await self.proxy_table[srv_id].next(STATUS.STOP) await self.proxy_table[srv_id].next(STATUS.STOP)
del self.proxy_table[srv_id] del self.proxy_table[srv_id]
async def init(self, callback = None): async def init(self):
self.init_updater(callback)
await self.reload() await self.reload()
async def reload(self): async def reload(self):
@@ -43,7 +33,6 @@ class FirewallManager:
srv = Service.from_dict(srv) srv = Service.from_dict(srv)
if srv.id in self.proxy_table: if srv.id in self.proxy_table:
continue continue
self.proxy_table[srv.id] = ServiceManager(srv, self.db) self.proxy_table[srv.id] = ServiceManager(srv, self.db)
await self.proxy_table[srv.id].next(srv.status) await self.proxy_table[srv.id].next(srv.status)
@@ -71,42 +60,6 @@ class FirewallManager:
class ServiceNotFoundException(Exception): pass class ServiceNotFoundException(Exception): pass
class RegexFilter:
def __init__(
self, regex,
is_case_sensitive=True,
is_blacklist=True,
input_mode=False,
output_mode=False,
blocked_packets=0,
id=None
):
self.regex = regex
self.is_case_sensitive = is_case_sensitive
self.is_blacklist = is_blacklist
if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True)
self.input_mode = input_mode
self.output_mode = output_mode
self.blocked = blocked_packets
self.id = id
self.compiled_regex = self.compile()
@classmethod
def from_regex(cls, regex:Regex):
return cls(
id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive,
is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets,
input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"]
)
def compile(self):
if isinstance(self.regex, str): self.regex = self.regex.encode()
if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether")
return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex)
def check(self, data):
return True if self.compiled_regex.search(data) else False
class ServiceManager: class ServiceManager:
def __init__(self, srv: Service, db): def __init__(self, srv: Service, db):
self.srv = srv self.srv = srv
@@ -114,12 +67,10 @@ class ServiceManager:
self.firegextable = FiregexTables(self.srv.ipv6) self.firegextable = FiregexTables(self.srv.ipv6)
self.status = STATUS.STOP self.status = STATUS.STOP
self.filters: Dict[int, FiregexFilter] = {} self.filters: Dict[int, FiregexFilter] = {}
self._update_filters_from_db()
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.interceptor = None self.interceptor = None
# TODO I don't like so much this method async def _update_filters_from_db(self):
def _update_filters_from_db(self):
regexes = [ regexes = [
Regex.from_dict(ele) for ele in Regex.from_dict(ele) for ele in
self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id) self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id)
@@ -127,17 +78,16 @@ class ServiceManager:
#Filter check #Filter check
old_filters = set(self.filters.keys()) old_filters = set(self.filters.keys())
new_filters = set([f.id for f in regexes]) new_filters = set([f.id for f in regexes])
#remove old filters #remove old filters
for f in old_filters: for f in old_filters:
if not f in new_filters: if not f in new_filters:
del self.filters[f] del self.filters[f]
#add new filters #add new filters
for f in new_filters: for f in new_filters:
if not f in old_filters: if not f in old_filters:
filter = [ele for ele in regexes if ele.id == f][0] filter = [ele for ele in regexes if ele.id == f][0]
self.filters[f] = RegexFilter.from_regex(filter) self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater)
if self.interceptor: await self.interceptor.reload(self.filters.values())
def __update_status_db(self, status): def __update_status_db(self, status):
self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id) self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id)
@@ -145,49 +95,36 @@ class ServiceManager:
async def next(self,to): async def next(self,to):
async with self.lock: async with self.lock:
if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP):
self.stop() await self.stop()
self._set_status(to) self._set_status(to)
# PAUSE -> ACTIVE # PAUSE -> ACTIVE
elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE):
self.restart() await self.restart()
def _stats_updater(self,filter:RegexFilter): def _stats_updater(self,filter:RegexFilter):
self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id) self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id)
def update_stats(self):
for ele in self.filters.values():
self._stats_updater(ele)
def _set_status(self,status): def _set_status(self,status):
self.status = status self.status = status
self.__update_status_db(status) self.__update_status_db(status)
def start(self): async def start(self):
if not self.interceptor: if not self.interceptor:
self.firegextable.delete_by_srv(self.srv) self.firegextable.delete_by_srv(self.srv)
def regex_filter(pkt, by_client): self.interceptor = await self.firegextable.add(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int))
try: await self._update_filters_from_db()
for filter in self.filters.values():
if (by_client and filter.input_mode) or (not by_client and filter.output_mode):
match = filter.check(pkt)
if (filter.is_blacklist and match) or (not filter.is_blacklist and not match):
filter.blocked+=1
return False
except IndexError: pass
return True
self.interceptor = self.firegextable.add(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int, func=regex_filter))
self._set_status(STATUS.ACTIVE) self._set_status(STATUS.ACTIVE)
def stop(self): async def stop(self):
self.firegextable.delete_by_srv(self.srv) self.firegextable.delete_by_srv(self.srv)
if self.interceptor: if self.interceptor:
self.interceptor.stop() await self.interceptor.stop()
self.interceptor = None self.interceptor = None
def restart(self): async def restart(self):
self.stop() await self.stop()
self.start() await self.start()
async def update_filters(self): async def update_filters(self):
async with self.lock: async with self.lock:
self._update_filters_from_db() await self._update_filters_from_db()

View File

@@ -0,0 +1,294 @@
#include <linux/netfilter/nfnetlink_queue.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <tins/tins.h>
#include <libmnl/libmnl.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/types.h>
#include <stdexcept>
#include <thread>
#ifndef NETFILTER_CLASSES_HPP
#define NETFILTER_CLASSES_HPP
typedef bool NetFilterQueueCallback(const uint8_t*,uint32_t);
// Walk a parsed packet's PDU chain and return the first TCP or UDP layer,
// or NULL when the packet carries no transport layer.
Tins::PDU * find_transport_layer(Tins::PDU* pkt){
	for (Tins::PDU *layer = pkt; layer != NULL; layer = layer->inner_pdu()){
		Tins::PDU::PDUType kind = layer->pdu_type();
		if (kind == Tins::PDU::TCP || kind == Tins::PDU::UDP){
			return layer;
		}
	}
	// Chain exhausted without finding TCP/UDP.
	return NULL;
}
// Wraps one netfilter queue bound over a raw libmnl netlink socket.
// Every packet delivered to the queue is parsed with libtins and its
// transport-layer payload is passed to callback_func; the callback's
// boolean return decides the verdict (true = ACCEPT, false = reject:
// TCP payloads are stripped and FIN|ACK'd, everything else is DROPped).
template <NetFilterQueueCallback callback_func>
class NetfilterQueue {
public:
	size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2);
	char *buf = NULL;          // receive buffer, heap-allocated in the ctor
	unsigned int portid;       // netlink port id of our socket
	u_int16_t queue_num;       // NFQUEUE number this object is bound to
	struct mnl_socket* nl = NULL;

	// Opens the netlink socket and binds it to queue_num.
	// Throws std::runtime_error on socket/send failures and
	// std::invalid_argument when the queue id is already taken.
	NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) {
		nl = mnl_socket_open(NETLINK_NETFILTER);
		if (nl == NULL) { throw std::runtime_error( "mnl_socket_open" );}
		if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
			mnl_socket_close(nl);
			throw std::runtime_error( "mnl_socket_bind" );
		}
		portid = mnl_socket_get_portid(nl);
		buf = (char*) malloc(BUF_SIZE);
		if (!buf) {
			mnl_socket_close(nl);
			throw std::runtime_error( "allocate receive buffer" );
		}
		if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) {
			_clear();
			throw std::runtime_error( "mnl_socket_send" );
		}
		// Test if BIND was successful
		if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // send a NONE command to generate an error message
			_clear();
			throw std::runtime_error( "mnl_socket_send" );
		}
		if (recv_packet() == -1) { // receive the error message
			_clear();
			throw std::runtime_error( "mnl_socket_recvfrom" );
		}
		struct nlmsghdr *nlh = (struct nlmsghdr *) buf;
		if (nlh->nlmsg_type != NLMSG_ERROR) {
			_clear();
			throw std::runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" );
		}
		//nfqnl_msg_config_cmd
		nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh);
		// error code taken from the linux kernel:
		// https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27
		#ifndef ENOTSUPP
		#define ENOTSUPP 524 /* Operation is not supported */
		#endif
		// A NONE command on a successfully bound queue is answered with
		// -ENOTSUPP; any other code means the bind itself failed.
		if (error_msg->error != -ENOTSUPP) {
			_clear();
			throw std::invalid_argument( "queueid is already busy" );
		}
		//END TESTING BIND
		nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
		nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff);
		mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO));
		mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO));
		if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
			_clear();
			throw std::runtime_error( "mnl_socket_send" );
		}
	}

	// Blocking receive loop: dispatches every queued packet to queue_cb.
	// Only returns by throwing on an unrecoverable socket error.
	void run(){
		/*
		 * ENOBUFS is signalled to userspace when packets were lost
		 * on kernel side.  In most cases, userspace isn't interested
		 * in this information, so turn it off.
		 */
		int ret = 1;
		mnl_socket_setsockopt(nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int));
		for (;;) {
			ret = recv_packet();
			if (ret == -1) {
				throw std::runtime_error( "mnl_socket_recvfrom" );
			}
			ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, nl);
			if (ret < 0){
				throw std::runtime_error( "mnl_cb_run" );
			}
		}
	}

	~NetfilterQueue() {
		// Best-effort unbind; errors are ignored during teardown.
		send_config_cmd(NFQNL_CFG_CMD_UNBIND);
		_clear();
	}

private:
	// Sends a single NFQNL_MSG_CONFIG command for our queue number.
	ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){
		struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
		nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd);
		return mnl_socket_sendto(nl, nlh, nlh->nlmsg_len);
	}

	ssize_t recv_packet(){
		return mnl_socket_recvfrom(nl, buf, BUF_SIZE);
	}

	// Frees the receive buffer (idempotent) and closes the socket.
	void _clear(){
		if (buf != NULL) {
			free(buf);
			buf = NULL;
		}
		mnl_socket_close(nl);
	}

	// libmnl callback: parses one queued packet, runs the filter callback
	// and sends the verdict back to the kernel.
	static int queue_cb(const struct nlmsghdr *nlh, void *data)
	{
		struct mnl_socket* nl = (struct mnl_socket*)data;
		//Extract attributes from the nlmsghdr
		struct nlattr *attr[NFQA_MAX+1] = {};
		if (nfq_nlmsg_parse(nlh, attr) < 0) {
			perror("problems parsing");
			return MNL_CB_ERROR;
		}
		if (attr[NFQA_PACKET_HDR] == NULL) {
			fputs("metaheader not set\n", stderr);
			return MNL_CB_ERROR;
		}
		// Guard: without a payload attribute mnl_attr_get_payload would
		// dereference NULL below.
		if (attr[NFQA_PAYLOAD] == NULL) {
			fputs("payload not set\n", stderr);
			return MNL_CB_ERROR;
		}
		//Get Payload
		uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
		void *payload = mnl_attr_get_payload(attr[NFQA_PAYLOAD]);
		//Return result to the kernel
		struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
		struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh);
		char buf[MNL_SOCKET_BUFFER_SIZE];
		struct nlmsghdr *nlh_verdict;
		struct nlattr *nest;
		nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id));
		/*
			This define allow to avoid to allocate new heap memory for each packet.
			The code under this comment is replicated for ipv6 and ip
			Better solutions are welcome. :)
			NOTE: the NULL check must run BEFORE inner_pdu() is called on the
			pointer (the original order dereferenced first).
		*/
		#define PKT_HANDLE \
		Tins::PDU *transport_layer = find_transport_layer(&packet); \
		if(transport_layer == nullptr || transport_layer->inner_pdu() == nullptr){ \
			nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
		}else{ \
			int size = transport_layer->inner_pdu()->size(); \
			if(callback_func((const uint8_t*)payload+plen - size, size)){ \
				nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
			} else{ \
				if (transport_layer->pdu_type() == Tins::PDU::TCP){ \
					((Tins::TCP *)transport_layer)->release_inner_pdu(); \
					((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::FIN,1); \
					((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::ACK,1); \
					((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::SYN,0); \
					nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); \
					nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
				}else{ \
					nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP ); \
				} \
			} \
		}
		// Check IP protocol version (high nibble of the first payload byte)
		if ( (((uint8_t*)payload)[0] & 0xf0) == 0x40 ){
			Tins::IP packet = Tins::IP((uint8_t*)payload,plen);
			PKT_HANDLE
		}else{
			Tins::IPv6 packet = Tins::IPv6((uint8_t*)payload,plen);
			PKT_HANDLE
		}
		/* example to set the connmark. First, start NFQA_CT section: */
		nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT);
		/* then, add the connmark attribute: */
		mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42));
		/* more conntrack attributes, e.g. CTA_LABELS could be set here */
		/* end conntrack section */
		mnl_attr_nest_end(nlh_verdict, nest);
		if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) {
			throw std::runtime_error( "mnl_socket_send" );
		}
		return MNL_CB_OK;
	}
};
// Allocates seq_len consecutive netfilter queues, probing forward from
// QUEUE_BASE_NUM until a free contiguous range is found, then runs each
// queue's receive loop on its own thread.
template <NetFilterQueueCallback func>
class NFQueueSequence{
	private:
		std::vector<NetfilterQueue<func> *> nfq;
		uint16_t _init;   // first queue id of the bound range
		uint16_t _end;    // last queue id of the bound range
		std::vector<std::thread> threads;
	public:
		static const int QUEUE_BASE_NUM = 1000;

		// Throws std::invalid_argument for seq_len == 0 and
		// std::runtime_error when no free contiguous range exists.
		NFQueueSequence(uint16_t seq_len){
			if (seq_len <= 0) throw std::invalid_argument("seq_len <= 0");
			nfq = std::vector<NetfilterQueue<func>*>(seq_len);
			_init = QUEUE_BASE_NUM;
			while(nfq[0] == NULL){
				if (_init+seq_len-1 >= 65536){
					throw std::runtime_error("NFQueueSequence: too many queues!");
				}
				for (size_t i=0;i<seq_len;i++){
					try{
						nfq[i] = new NetfilterQueue<func>(_init+i);
					// Catch by const reference (the original caught by value,
					// copying the exception object).
					}catch(const std::invalid_argument&){
						// Queue id busy: release the queues bound so far and
						// restart the probe from a later base id.
						for(size_t j = 0; j < i; j++) {
							delete nfq[j];
							nfq[j] = nullptr;
						}
						_init += seq_len - i;
						break;
					}
				}
			}
			_end = _init + seq_len - 1;
		}

		// Spawns one thread per queue; callable only once per join cycle.
		void start(){
			if (threads.size() != 0) throw std::runtime_error("NFQueueSequence: already started!");
			for (size_t i=0;i<nfq.size();i++){
				threads.push_back(std::thread(&NetfilterQueue<func>::run, nfq[i]));
			}
		}

		// Blocks until every queue thread exits, then allows a restart.
		void join(){
			for (size_t i=0;i<threads.size();i++){
				threads[i].join();
			}
			threads.clear();
		}

		uint16_t init(){
			return _init;
		}
		uint16_t end(){
			return _end;
		}

		~NFQueueSequence(){
			// NOTE(review): running threads are not joined here; callers are
			// expected to join() before destruction — confirm at call sites.
			for (size_t i=0;i<nfq.size();i++){
				delete nfq[i];
			}
		}
};
#endif // NETFILTER_CLASSES_HPP

View File

@@ -0,0 +1,95 @@
#include <iostream>
#include <cstring>
#include <jpcre2.hpp>
#include <sstream>
#include "../utils.hpp"
#ifndef REGEX_FILTER_HPP
#define REGEX_FILTER_HPP
typedef jpcre2::select<char> jp;
typedef std::pair<std::string,jp::Regex> regex_rule_pair;
typedef std::vector<regex_rule_pair> regex_rule_vector;
struct regex_rules{
regex_rule_vector output_whitelist, input_whitelist, output_blacklist, input_blacklist;
regex_rule_vector* getByCode(char code){
switch(code){
case 'C': // Client to server Blacklist
return &input_blacklist; break;
case 'c': // Client to server Whitelist
return &input_whitelist; break;
case 'S': // Server to client Blacklist
return &output_blacklist; break;
case 's': // Server to client Whitelist
return &output_whitelist; break;
}
throw std::invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
}
int add(const char* arg){
//Integrity checks
size_t arg_len = strlen(arg);
if (arg_len < 2 || arg_len%2 != 0){
std::cerr << "[warning] [regex_rules.add] invalid arg passed (" << arg << "), skipping..." << std::endl;
return -1;
}
if (arg[0] != '0' && arg[0] != '1'){
std::cerr << "[warning] [regex_rules.add] invalid is_case_sensitive (" << arg[0] << ") in '" << arg << "', must be '1' or '0', skipping..." << std::endl;
return -1;
}
if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's'){
std::cerr << "[warning] [regex_rules.add] invalid filter_type (" << arg[1] << ") in '" << arg << "', must be 'C', 'c', 'S' or 's', skipping..." << std::endl;
return -1;
}
std::string hex(arg+2), expr;
if (!unhexlify(hex, expr)){
std::cerr << "[warning] [regex_rules.add] invalid hex regex value (" << hex << "), skipping..." << std::endl;
return -1;
}
//Push regex
jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
if (regex){
std::cerr << "[info] [regex_rules.add] adding new regex filter: '" << expr << "'" << std::endl;
getByCode(arg[1])->push_back(std::make_pair(std::string(arg), regex));
} else {
std::cerr << "[warning] [regex_rules.add] compiling of '" << expr << "' regex failed, skipping..." << std::endl;
return -1;
}
return 0;
}
bool check(unsigned char* data, const size_t& bytes_transferred, const bool in_input){
std::string str_data((char *) data, bytes_transferred);
for (regex_rule_pair ele:(in_input?input_blacklist:output_blacklist)){
try{
if(ele.second.match(str_data)){
std::stringstream msg;
msg << "BLOCKED " << ele.first << "\n";
std::cout << msg.str() << std::flush;
return false;
}
} catch(...){
std::cerr << "[info] [regex_rules.check] Error while matching blacklist regex: " << ele.first << std::endl;
}
}
for (regex_rule_pair ele:(in_input?input_whitelist:output_whitelist)){
try{
std::cerr << "[debug] [regex_rules.check] regex whitelist match " << ele.second.getPattern() << std::endl;
if(!ele.second.match(str_data)){
std::stringstream msg;
msg << "BLOCKED " << ele.first << "\n";
std::cout << msg.str() << std::flush;
return false;
}
} catch(...){
std::cerr << "[info] [regex_rules.check] Error while matching whitelist regex: " << ele.first << std::endl;
}
}
return true;
}
};
#endif // REGEX_FILTER_HPP

View File

@@ -1,10 +0,0 @@
module main
go 1.18
require github.com/DomySh/go-netfilter-queue v0.0.0-20220713124014-7261f0df2c15
require (
github.com/Jemmic/go-pcre2 v0.0.0-20190111114109-bd52ad5f7098 // indirect
github.com/google/gopacket v1.1.19 // indirect
)

View File

@@ -1,18 +0,0 @@
github.com/DomySh/go-netfilter-queue v0.0.0-20220713124014-7261f0df2c15 h1:6v9D8bG3oR0dJFMuEeEAg8Xwn436Ziv+P7QWS04wAG8=
github.com/DomySh/go-netfilter-queue v0.0.0-20220713124014-7261f0df2c15/go.mod h1:VdJ6kqHln0XlrhuxQM6eBjRIHCzvAMgcZDAtyD/GU5s=
github.com/Jemmic/go-pcre2 v0.0.0-20190111114109-bd52ad5f7098 h1:ZwFIi+5jGJWVrB2V4NvrEhIUy6uDkfnTtBsgj3HAImI=
github.com/Jemmic/go-pcre2 v0.0.0-20190111114109-bd52ad5f7098/go.mod h1:c+8WT1L7lfohb4xMaa3yAV7nlYNepqc2ZV09/CU8R/U=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -1,264 +0,0 @@
package main
import (
"bufio"
"encoding/hex"
"fmt"
"log"
"os"
"os/user"
"strconv"
"strings"
"github.com/DomySh/go-netfilter-queue"
"github.com/Jemmic/go-pcre2"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// First netfilter queue id this program tries to bind.
const QUEUE_BASE_NUM = 1000

// NOTE(review): not referenced in the visible code — presumably the
// intended per-queue packet buffer size; confirm before relying on it.
const MAX_PACKET_IN_QUEUE = 100

// regex_pair couples a rule's raw encoded string (used when reporting a
// block on stdout) with its compiled PCRE2 matcher.
type regex_pair struct {
	regex   string
	matcher *pcre2.Matcher
}

// regex_filters holds the four rule sets (input/output × whitelist/blacklist)
// plus every compiled regex, which keeps the pcre2.Regexp objects alive for
// the lifetime of their matchers.
type regex_filters struct {
	input_whitelist  []regex_pair
	input_blacklist  []regex_pair
	output_whitelist []regex_pair
	output_blacklist []regex_pair
	regexes          []*pcre2.Regexp
}
// NewRegexFilter returns an empty rule set with every list initialized
// (never nil), ready for add() calls.
func NewRegexFilter() *regex_filters {
	return &regex_filters{
		input_whitelist:  make([]regex_pair, 0),
		input_blacklist:  make([]regex_pair, 0),
		output_whitelist: make([]regex_pair, 0),
		output_blacklist: make([]regex_pair, 0),
		regexes:          make([]*pcre2.Regexp, 0),
	}
}
// add parses one rule token of the form "<dir><mode><hexpattern>" where
// byte 0 is 'i' (input) or anything else (output), byte 1 is '1'
// (whitelist) or anything else (blacklist), and the remainder is the
// hex-encoded PCRE2 pattern. Malformed or non-compiling rules are logged
// and skipped.
func (self *regex_filters) add(raw_regex string) {
	// Guard: tokens shorter than the 2-byte prefix would make the slice
	// expressions below panic (main filters these out, but add must not
	// rely on its callers for memory safety).
	if len(raw_regex) < 2 {
		log.Printf("[add] Rule '%s' is too short, skipping", raw_regex)
		return
	}
	filter_type := strings.ToLower(raw_regex[0:2])
	decoded_regex, err := hex.DecodeString(raw_regex[2:])
	if err != nil {
		log.Printf("[add] Unable to decode regex '%s': %s", raw_regex, err)
		return
	}
	regex, err := pcre2.Compile(string(decoded_regex), 0)
	if err != nil {
		log.Printf("[add] Unable to compile regex '%s': %s", string(decoded_regex), err)
		return
	}
	// Track the compiled pattern so clear() can free it later.
	self.regexes = append(self.regexes, regex)
	// Dispatch to the right list instead of four near-identical branches.
	var target *[]regex_pair
	is_input := filter_type[0] == 'i'
	is_whitelist := filter_type[1] == '1'
	switch {
	case is_input && is_whitelist:
		target = &self.input_whitelist
	case is_input:
		target = &self.input_blacklist
	case is_whitelist:
		target = &self.output_whitelist
	default:
		target = &self.output_blacklist
	}
	*target = append(*target, regex_pair{raw_regex, regex.NewMatcher()})
}
// check reports whether a payload may pass in the given direction:
// false (blocked) when any blacklist rule matches, or when any whitelist
// rule fails to match; true otherwise. Each blocking rule is reported on
// stdout as "BLOCKED <rule>".
func (self *regex_filters) check(data []byte, is_input bool) bool {
	// Select the direction's lists once instead of duplicating both loops.
	blacklist, whitelist := self.output_blacklist, self.output_whitelist
	if is_input {
		blacklist, whitelist = self.input_blacklist, self.input_whitelist
	}
	for _, rule := range blacklist {
		if rule.matcher.Match(data, 0) {
			fmt.Printf("BLOCKED %s\n", rule.regex)
			return false
		}
	}
	for _, rule := range whitelist {
		if !rule.matcher.Match(data, 0) {
			fmt.Printf("BLOCKED %s\n", rule.regex)
			return false
		}
	}
	return true
}
// clear releases every PCRE2 matcher and compiled pattern held by this
// filter table. The table must not be used for matching afterwards.
func (self *regex_filters) clear() {
	lists := [][]regex_pair{
		self.input_whitelist,
		self.input_blacklist,
		self.output_whitelist,
		self.output_blacklist,
	}
	for _, list := range lists {
		for _, rule := range list {
			rule.matcher.Free()
		}
	}
	for _, compiled := range self.regexes {
		compiled.Free()
	}
}
// handle_packets consumes packets from one netfilter queue and applies the
// current regex filter table to each packet's transport payload. A new
// table is delivered through filter_table_channel whenever the rules are
// updated; is_input selects which direction's rule lists are consulted.
//
// Fixes vs previous revision: removed the debug leftover that issued an
// unconditional NF_ACCEPT and `break`, making all filtering code
// unreachable; and on serialization failure the packet is now only
// dropped, instead of receiving a second, contradictory ACCEPT verdict.
func handle_packets(packets <-chan netfilter.NFPacket, filter_table_channel chan regex_filters, is_input bool) {
	filter_table := regex_filters{}
	for {
		filter := filter_table
		select {
		case ft := <-filter_table_channel:
			filter_table = ft
		case p := <-packets:
			transport_layer := p.Packet.TransportLayer()
			if transport_layer == nil {
				// Not TCP/UDP: nothing we can match, let it through.
				p.SetVerdict(netfilter.NF_ACCEPT)
				continue
			}
			data := transport_layer.LayerPayload()
			if len(data) == 0 {
				// Empty payload (handshakes, pure ACKs): nothing to check.
				p.SetVerdict(netfilter.NF_ACCEPT)
				continue
			}
			if filter.check(data, is_input) {
				p.SetVerdict(netfilter.NF_ACCEPT)
				continue
			}
			if transport_layer.LayerType() != layers.LayerTypeTCP {
				// Non-TCP cannot be closed gracefully: just drop it.
				p.SetVerdict(netfilter.NF_DROP)
				continue
			}
			// Blocked TCP: strip the payload and rewrite the segment as
			// FIN+ACK so the connection is torn down instead of stalling.
			*p.Packet.ApplicationLayer().(*gopacket.Payload) = []byte{}
			tcp := transport_layer.(*layers.TCP)
			tcp.Payload = []byte{}
			tcp.FIN = true
			tcp.SYN = false
			tcp.RST = false
			tcp.ACK = true
			tcp.SetNetworkLayerForChecksum(p.Packet.NetworkLayer())
			buffer := gopacket.NewSerializeBuffer()
			options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
			if err := gopacket.SerializePacket(buffer, options, p.Packet); err != nil {
				// Could not rebuild the packet: fall back to dropping it.
				p.SetVerdict(netfilter.NF_DROP)
				continue
			}
			p.SetVerdictWithPacket(netfilter.NF_ACCEPT, buffer.Bytes())
		}
	}
}
// isRoot reports whether the process is running as the "root" user;
// aborts the program if the current user cannot be determined.
func isRoot() bool {
	currentUser, err := user.Current()
	if err != nil {
		log.Fatalf("[isRoot] Unable to get current user: %s", err)
	}
	return currentUser.Username == "root"
}
// create_queue_seq binds `num` netfilter queues with consecutive queue ids,
// scanning upwards from QUEUE_BASE_NUM and retrying at a higher base
// whenever an id in the candidate range is already in use. Returns the open
// queues plus the first and last id of the bound range; aborts the program
// when no free range exists below 65536.
func create_queue_seq(num int) ([]*netfilter.NFQueue, int, int) {
	var queue_list = make([]*netfilter.NFQueue, num)
	var err error
	starts := QUEUE_BASE_NUM
	// Ids are bound from the top of the range downwards, so queue_list[0]
	// becomes non-nil only once a full run of `num` ids has succeeded —
	// that is the loop's exit condition.
	for queue_list[0] == nil {
		if starts+num-1 >= 65536 {
			log.Fatalf("Netfilter queue is full!")
		}
		for i := 0; i < len(queue_list); i++ {
			// i-th bind targets id starts+num-1-i (highest first).
			queue_list[i], err = netfilter.NewNFQueue(uint16(starts+num-1-i), MAX_PACKET_IN_QUEUE, netfilter.NF_DEFAULT_PACKET_SIZE)
			if err != nil {
				// Busy id: roll back every queue bound in this attempt and
				// restart the scan just above the id that failed.
				for j := 0; j < i; j++ {
					queue_list[j].Close()
					queue_list[j] = nil
				}
				starts = starts + num - i
				break
			}
		}
	}
	return queue_list, starts, starts + num - 1
}
// main binds argv[1] (default 1) input queues and the same number of output
// queues, announces the bound id ranges on stdout for the controlling
// process, then loops forever: each stdin line is a space-separated list of
// rule tokens that replaces the whole filter table in every worker.
func main() {
	log.SetOutput(os.Stderr)
	if !isRoot() {
		log.Fatalf("[main] You must be root to run this program")
	}
	number_of_queues := 1
	if len(os.Args) >= 2 {
		var err error
		number_of_queues, err = strconv.Atoi(os.Args[1])
		if err != nil {
			log.Fatalf("[main] Invalid number of queues: %s", err)
		}
	}
	var filter_channels []chan regex_filters
	// Start the queue list: one goroutine per input queue.
	queue_list, starts_input, end_input := create_queue_seq(number_of_queues)
	for _, queue := range queue_list {
		defer queue.Close()
		ch := make(chan regex_filters)
		filter_channels = append(filter_channels, ch)
		go handle_packets(queue.GetPackets(), ch, true)
	}
	// Same again for the output direction.
	queue_list, starts_output, end_output := create_queue_seq(number_of_queues)
	for _, queue := range queue_list {
		defer queue.Close()
		ch := make(chan regex_filters)
		filter_channels = append(filter_channels, ch)
		go handle_packets(queue.GetPackets(), ch, false)
	}
	// The controlling process parses this exact line to learn the ids.
	fmt.Println("QUEUE INPUT", starts_input, end_input, "OUTPUT", starts_output, end_output)
	//Reading for new configuration
	reader := bufio.NewReader(os.Stdin)
	old_filter_table := NewRegexFilter()
	for true {
		text, err := reader.ReadString('\n')
		log.Printf("[main] Regex rule updating...")
		if err != nil {
			// Includes EOF: the parent closing stdin shuts us down.
			log.Fatalf("[main] Unable to read from stdin: %s", err)
		}
		text = strings.Trim(text, "\n")
		regexes := strings.Split(text, " ")
		new_filters := NewRegexFilter()
		for _, regex := range regexes {
			regex = strings.Trim(regex, " ")
			if len(regex) < 2 {
				continue
			}
			new_filters.add(regex)
		}
		// Broadcast the new table to every worker goroutine.
		for _, ch := range filter_channels {
			ch <- *new_filters
		}
		// NOTE(review): clear() frees the old table's PCRE2 objects, but a
		// worker mid-match may still hold the old matchers — possible
		// use-after-free window. Confirm against go-pcre2 Free semantics.
		old_filter_table.clear()
		old_filter_table = new_filters
		log.Printf("[main] Regex filter rules updated!")
	}
}

View File

@@ -1,421 +1,22 @@
#include <arpa/inet.h> #include "classes/regex_filter.hpp"
#include <type_traits> #include "classes/netfilter.hpp"
#include <tins/tins.h> #include "utils.hpp"
#include <libmnl/libmnl.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/types.h>
#include <linux/netfilter/nfnetlink_queue.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <stdexcept>
#include <iostream> #include <iostream>
#include <cstring>
#include <cstdlib>
#include <cerrno>
#include <sstream>
#include <thread>
#include <mutex>
#include <jpcre2.hpp>
using namespace std; using namespace std;
using namespace Tins;
typedef jpcre2::select<char> jp;
mutex stdout_mutex;
// Decode a hex string ("414243" -> "ABC"), appending the bytes to
// newString. Returns false for malformed input: odd length or any non-hex
// digit. (The previous try/catch could never fire — substr/strtol do not
// throw here — so garbage was silently accepted as zero bytes.) On failure
// newString may already contain a partial decode.
bool unhexlify(std::string const &hex, std::string &newString) {
    const size_t len = hex.length();
    if (len % 2 != 0) return false; // hex bytes come in digit pairs
    for (size_t i = 0; i < len; i += 2) {
        const std::string byte = hex.substr(i, 2);
        char *end = NULL;
        const long value = strtol(byte.c_str(), &end, 16);
        // strtol must consume exactly both characters, otherwise at least
        // one of them was not a hex digit.
        if (end != byte.c_str() + 2) return false;
        newString.push_back((char) value);
    }
    return true;
}
// A rule is stored as (raw argument string, compiled regex) so that block
// reports can echo the exact rule token that caused the block.
typedef pair<string,jp::Regex> regex_rule_pair;
typedef vector<regex_rule_pair> regex_rule_vector;

// Active filtering rules, split by direction (input = client->server,
// output = server->client) and semantics (blacklist: block on match;
// whitelist: block unless it matches).
struct regex_rules{
    regex_rule_vector output_whitelist, input_whitelist, output_blacklist, input_blacklist;

    // Maps a rule-type code to its list:
    //   'C'/'c' client->server black/whitelist,
    //   'S'/'s' server->client black/whitelist.
    // Throws invalid_argument for any other code.
    regex_rule_vector* getByCode(char code){
        switch(code){
            case 'C': // Client to server Blacklist
                return &input_blacklist; break;
            case 'c': // Client to server Whitelist
                return &input_whitelist; break;
            case 'S': // Server to client Blacklist
                return &output_blacklist; break;
            case 's': // Server to client Whitelist
                return &output_whitelist; break;
        }
        throw invalid_argument( "Expected 'C' 'c' 'S' or 's'" );
    }

    // Parses one rule token "<case><type><hexpattern>": byte 0 is '1'
    // (case-sensitive) or '0', byte 1 one of C/c/S/s, the rest the
    // hex-encoded pattern. Invalid tokens are logged and skipped (-1);
    // returns 0 on success.
    int add(const char* arg){
        //Integrity checks
        size_t arg_len = strlen(arg);
        if (arg_len < 2 || arg_len%2 != 0){
            cerr << "[warning] [regex_rules.add] invalid arg passed (" << arg << "), skipping..." << endl;
            return -1;
        }
        if (arg[0] != '0' && arg[0] != '1'){
            cerr << "[warning] [regex_rules.add] invalid is_case_sensitive (" << arg[0] << ") in '" << arg << "', must be '1' or '0', skipping..." << endl;
            return -1;
        }
        if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's'){
            cerr << "[warning] [regex_rules.add] invalid filter_type (" << arg[1] << ") in '" << arg << "', must be 'C', 'c', 'S' or 's', skipping..." << endl;
            return -1;
        }
        string hex(arg+2), expr;
        if (!unhexlify(hex, expr)){
            cerr << "[warning] [regex_rules.add] invalid hex regex value (" << hex << "), skipping..." << endl;
            return -1;
        }
        //Push regex ("g" global, "S" study/optimize, "i" case-insensitive)
        jp::Regex regex(expr,arg[0] == '1'?"gS":"giS");
        if (regex){
            cerr << "[info] [regex_rules.add] adding new regex filter: '" << expr << "'" << endl;
            getByCode(arg[1])->push_back(make_pair(string(arg), regex));
        } else {
            cerr << "[warning] [regex_rules.add] compiling of '" << expr << "' regex failed, skipping..." << endl;
            return -1;
        }
        return 0;
    }

    // Returns false (block) when a blacklist rule matches, or when a
    // whitelist rule fails to match, for the given direction; true lets
    // the payload pass. Blocking rules are reported on stdout (guarded by
    // stdout_mutex because several queue threads may report concurrently).
    bool check(unsigned char* data, const size_t& bytes_transferred, const bool in_input){
        string str_data((char *) data, bytes_transferred);
        for (regex_rule_pair ele:(in_input?input_blacklist:output_blacklist)){
            try{
                if(ele.second.match(str_data)){
                    unique_lock<mutex> lck(stdout_mutex);
                    cout << "BLOCKED " << ele.first << endl;
                    return false;
                }
            } catch(...){
                // A failing match is treated as "no match": best-effort.
                cerr << "[info] [regex_rules.check] Error while matching blacklist regex: " << ele.first << endl;
            }
        }
        for (regex_rule_pair ele:(in_input?input_whitelist:output_whitelist)){
            try{
                cerr << "[debug] [regex_rules.check] regex whitelist match " << ele.second.getPattern() << endl;
                if(!ele.second.match(str_data)){
                    unique_lock<mutex> lck(stdout_mutex);
                    cout << "BLOCKED " << ele.first << endl;
                    return false;
                }
            } catch(...){
                cerr << "[info] [regex_rules.check] Error while matching whitelist regex: " << ele.first << endl;
            }
        }
        return true;
    }
};
shared_ptr<regex_rules> regex_config; shared_ptr<regex_rules> regex_config;
// Verdict callback signature: (payload, length) -> accept?
typedef bool NetFilterQueueCallback(const uint8_t*,uint32_t);

// Walks the parsed PDU chain and returns the TCP or UDP layer, or NULL
// when the packet has no such transport layer.
PDU * find_transport_layer(PDU* pkt){
    for (PDU *layer = pkt; layer != NULL; layer = layer->inner_pdu()) {
        const PDU::PDUType kind = layer->pdu_type();
        if (kind == PDU::TCP || kind == PDU::UDP) {
            return layer;
        }
    }
    return NULL;
}
// One kernel NFQUEUE binding. Each instance owns its own netlink socket
// and receive buffer; run() blocks forever handing every packet's
// transport payload to callback_func(data, len), whose boolean return
// decides ACCEPT (true) or reject (false).
template <NetFilterQueueCallback callback_func>
class NetfilterQueue {
    public:
    // Large enough for a max-size (64 KiB) packet plus netlink overhead.
    size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2);
    char *buf = NULL;
    unsigned int portid;
    u_int16_t queue_num;
    struct mnl_socket* nl = NULL;

    // Binds queue `queue_num`. Throws invalid_argument when the id is
    // already taken, runtime_error for socket/allocation failures.
    NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) {
        nl = mnl_socket_open(NETLINK_NETFILTER);
        if (nl == NULL) { throw runtime_error( "mnl_socket_open" );}
        if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
            mnl_socket_close(nl);
            throw runtime_error( "mnl_socket_bind" );
        }
        portid = mnl_socket_get_portid(nl);
        buf = (char*) malloc(BUF_SIZE);
        if (!buf) {
            mnl_socket_close(nl);
            throw runtime_error( "allocate receive buffer" );
        }
        if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) {
            _clear();
            throw runtime_error( "mnl_socket_send" );
        }
        //TEST if BIND was successful
        if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND A NONE cmmand to generate an error meessage
            _clear();
            throw runtime_error( "mnl_socket_send" );
        }
        if (recv_packet() == -1) { //RECV the error message
            _clear();
            throw std::runtime_error( "mnl_socket_recvfrom" );
        }
        struct nlmsghdr *nlh = (struct nlmsghdr *) buf;
        if (nlh->nlmsg_type != NLMSG_ERROR) {
            _clear();
            throw runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" );
        }
        //nfqnl_msg_config_cmd
        nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh);
        // error code taken from the linux kernel:
        // https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27
        #define ENOTSUPP 524 /* Operation is not supported */
        // -ENOTSUPP here means the BIND itself succeeded (only the probe
        // NONE command is unsupported); any other errno implies the id is
        // already bound by someone else.
        if (error_msg->error != -ENOTSUPP) {
            _clear();
            throw std::invalid_argument( "queueid is already busy" );
        }
        //END TESTING BIND
        nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
        nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff);
        mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO));
        mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO));
        if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
            _clear();
            throw runtime_error( "mnl_socket_send" );
        }
    }

    // Receive loop: never returns normally; throws on socket errors.
    void run(){
        /*
        * ENOBUFS is signalled to userspace when packets were lost
        * on kernel side. In most cases, userspace isn't interested
        * in this information, so turn it off.
        */
        int ret = 1;
        mnl_socket_setsockopt(nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int));
        for (;;) {
            ret = recv_packet();
            if (ret == -1) {
                throw std::runtime_error( "mnl_socket_recvfrom" );
            }
            ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, nl);
            if (ret < 0){
                throw std::runtime_error( "mnl_cb_run" );
            }
        }
    }

    ~NetfilterQueue() {
        send_config_cmd(NFQNL_CFG_CMD_UNBIND);
        _clear();
    }

    private:
    ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){
        struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num);
        nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd);
        return mnl_socket_sendto(nl, nlh, nlh->nlmsg_len);
    }

    ssize_t recv_packet(){
        return mnl_socket_recvfrom(nl, buf, BUF_SIZE);
    }

    // Frees the receive buffer (idempotent) and closes the netlink socket.
    void _clear(){
        if (buf != NULL) {
            free(buf);
            buf = NULL;
        }
        mnl_socket_close(nl);
    }

    // mnl_cb_run callback: parses one queued packet, runs callback_func on
    // its transport payload and sends the verdict back to the kernel.
    static int queue_cb(const struct nlmsghdr *nlh, void *data)
    {
        struct mnl_socket* nl = (struct mnl_socket*)data;
        //Extract attributes from the nlmsghdr
        struct nlattr *attr[NFQA_MAX+1] = {};
        if (nfq_nlmsg_parse(nlh, attr) < 0) {
            perror("problems parsing");
            return MNL_CB_ERROR;
        }
        if (attr[NFQA_PACKET_HDR] == NULL) {
            fputs("metaheader not set\n", stderr);
            return MNL_CB_ERROR;
        }
        //Get Payload
        uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
        void *payload = mnl_attr_get_payload(attr[NFQA_PAYLOAD]);
        //Return result to the kernel
        struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]);
        struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh);
        char buf[MNL_SOCKET_BUFFER_SIZE];
        struct nlmsghdr *nlh_verdict;
        struct nlattr *nest;
        nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id));
        /*
        This define allow to avoid to allocate new heap memory for each packet.
        The code under this comment is replicated for ipv6 and ip
        Better solutions are welcome. :)
        */
        // NOTE(review): the first condition in PKT_HANDLE dereferences
        // transport_layer (inner_pdu()) BEFORE the nullptr check on
        // transport_layer itself; the two operands of the || should be
        // swapped. Left byte-identical in this documentation-only pass.
        #define PKT_HANDLE \
        PDU *transport_layer = find_transport_layer(&packet); \
        if(transport_layer->inner_pdu() == nullptr || transport_layer == nullptr){ \
            nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
        }else{ \
            int size = transport_layer->inner_pdu()->size(); \
            if(callback_func((const uint8_t*)payload+plen - size, size)){ \
                nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
            } else{ \
                if (transport_layer->pdu_type() == PDU::TCP){ \
                    ((TCP *)transport_layer)->release_inner_pdu(); \
                    ((TCP *)transport_layer)->set_flag(TCP::FIN,1); \
                    ((TCP *)transport_layer)->set_flag(TCP::ACK,1); \
                    ((TCP *)transport_layer)->set_flag(TCP::SYN,0); \
                    nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); \
                    nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \
                }else{ \
                    nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP ); \
                } \
            } \
        }
        // Check IP protocol version (high nibble of the first byte).
        if ( (((uint8_t*)payload)[0] & 0xf0) == 0x40 ){
            IP packet = IP((uint8_t*)payload,plen);
            PKT_HANDLE
        }else{
            IPv6 packet = IPv6((uint8_t*)payload,plen);
            PKT_HANDLE
        }
        /* example to set the connmark. First, start NFQA_CT section: */
        nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT);
        /* then, add the connmark attribute: */
        mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42));
        /* more conntrack attributes, e.g. CTA_LABELS could be set here */
        /* end conntrack section */
        mnl_attr_nest_end(nlh_verdict, nest);
        if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) {
            throw std::runtime_error( "mnl_socket_send" );
        }
        return MNL_CB_OK;
    }
};
// Allocates seq_len NetfilterQueue instances bound to consecutive queue
// ids, scanning upwards from QUEUE_BASE_NUM until a fully free range is
// found, then runs each queue's receive loop on a dedicated thread.
//
// Fixes vs previous revision: invalid_argument is caught by const
// reference instead of by value (which copied/sliced the exception), and
// join() iterates the threads vector itself, so calling join() on a
// never-started sequence is a safe no-op instead of out-of-bounds access.
template <NetFilterQueueCallback func>
class NFQueueSequence{
    private:
        vector<NetfilterQueue<func> *> nfq;
        uint16_t _init; // first queue id of the bound range
        uint16_t _end;  // last queue id of the bound range
        vector<thread> threads;
    public:
        static const int QUEUE_BASE_NUM = 1000;

        // Binds the range; throws invalid_argument for seq_len == 0 and
        // runtime_error when no free range exists below 65536.
        NFQueueSequence(uint16_t seq_len){
            if (seq_len <= 0) throw invalid_argument("seq_len <= 0");
            nfq = vector<NetfilterQueue<func>*>(seq_len);
            _init = QUEUE_BASE_NUM;
            // nfq[0] stays NULL until a whole run of seq_len ids succeeds.
            while(nfq[0] == NULL){
                if (_init+seq_len-1 >= 65536){
                    throw runtime_error("NFQueueSequence: too many queues!");
                }
                for (int i=0;i<seq_len;i++){
                    try{
                        nfq[i] = new NetfilterQueue<func>(_init+i);
                    }catch(const invalid_argument &e){
                        // Busy id: roll back this attempt and restart the
                        // scan past the id that failed.
                        for(int j = 0; j < i; j++) {
                            delete nfq[j];
                            nfq[j] = nullptr;
                        }
                        _init += seq_len - i;
                        break;
                    }
                }
            }
            _end = _init + seq_len - 1;
        }

        // Launches one receiver thread per queue; callable only once.
        void start(){
            if (threads.size() != 0) throw runtime_error("NFQueueSequence: already started!");
            for (size_t i=0;i<nfq.size();i++){
                threads.push_back(thread(&NetfilterQueue<func>::run, nfq[i]));
            }
        }

        // Blocks until every receiver thread exits (normally they never do).
        void join(){
            for (size_t i=0;i<threads.size();i++){
                threads[i].join();
            }
            threads.clear();
        }

        uint16_t init(){
            return _init;
        }
        uint16_t end(){
            return _end;
        }

        // NOTE(review): does not stop or join running threads; destroying
        // a started sequence while its threads run is undefined behavior.
        ~NFQueueSequence(){
            for (size_t i=0;i<nfq.size();i++){
                delete nfq[i];
            }
        }
};
// True when the process runs with real uid 0 (root).
bool is_sudo(){
    return getuid() == 0;
}
void config_updater (){ void config_updater (){
string line, data; string line, data;
while (true){ while (true){
getline(cin, line); getline(cin, line);
if (cin.eof()){
cerr << "[fatal] [upfdater] cin.eof()" << endl;
exit(EXIT_FAILURE);
}
if (cin.bad()){ if (cin.bad()){
cerr << "[fatal] [upfdater] cin.bad() != 0" << endl; cerr << "[fatal] [upfdater] cin.bad()" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
cerr << "[info] [updater] Updating configuration with line " << line << endl; cerr << "[info] [updater] Updating configuration with line " << line << endl;
@@ -423,7 +24,9 @@ void config_updater (){
regex_rules *regex_new_config = new regex_rules(); regex_rules *regex_new_config = new regex_rules();
while(!config_stream.eof()){ while(!config_stream.eof()){
config_stream >> data; config_stream >> data;
regex_new_config->add(data.c_str()); if (data != "" && data != "\n"){
regex_new_config->add(data.c_str());
}
} }
regex_config.reset(regex_new_config); regex_config.reset(regex_new_config);
cerr << "[info] [updater] Config update done" << endl; cerr << "[info] [updater] Config update done" << endl;
@@ -456,21 +59,3 @@ int main(int argc, char *argv[])
config_updater(); config_updater();
} }
/*
libpcre2-dev
libnetfilter-queue-dev
libtins-dev
libmnl-dev
c++ nfqueue.cpp -o nfqueue -pthread -lpcre2-8 -ltins -lnetfilter_queue -lmnl
WORKDIR /tmp/
RUN git clone --branch release https://github.com/jpcre2/jpcre2
WORKDIR /tmp/jpcre2
RUN ./configure; make; make install
WORKDIR /
*/

28
backend/nfqueue/utils.hpp Normal file
View File

@@ -0,0 +1,28 @@
#include <string>
#include <unistd.h>
#ifndef UTILS_HPP
#define UTILS_HPP
// Decode a hex string ("414243" -> "ABC"), appending the bytes to
// newString. Returns false for malformed input: odd length or any non-hex
// digit. (The previous try/catch could never fire — substr/strtol do not
// throw here — so garbage was silently accepted.) On failure newString may
// already contain a partial decode.
// `inline` so this header can be included from several translation units
// without multiple-definition link errors.
inline bool unhexlify(std::string const &hex, std::string &newString) {
    const size_t len = hex.length();
    if (len % 2 != 0) return false; // hex bytes come in digit pairs
    for (size_t i = 0; i < len; i += 2) {
        const std::string byte = hex.substr(i, 2);
        char *end = NULL;
        const long value = strtol(byte.c_str(), &end, 16);
        // strtol must consume exactly both characters, otherwise at least
        // one of them was not a hex digit.
        if (end != byte.c_str() + 2) return false;
        newString.push_back((char) value);
    }
    return true;
}
// True when the process runs with real uid 0 (root).
// `inline` so this header-defined function does not violate the one-
// definition rule when the header is included from several translation
// units.
inline bool is_sudo(){
    return getuid() == 0;
}
#endif