From 2a5be65feb9c6aad8a680bd064dcd140b9aa6e69 Mon Sep 17 00:00:00 2001 From: DomySh Date: Mon, 18 Jul 2022 18:52:14 +0200 Subject: [PATCH] Python integration with c++ binary (not totally working yet) --- .gitignore | 6 +- Dockerfile | 4 +- backend/app.py | 2 +- backend/modules/firegex.py | 205 +++++++---- backend/modules/firewall.py | 97 +---- backend/nfqueue/classes/netfilter.hpp | 294 +++++++++++++++ backend/nfqueue/classes/regex_filter.hpp | 95 +++++ backend/nfqueue/go.mod | 10 - backend/nfqueue/go.sum | 18 - backend/nfqueue/main.go | 264 -------------- backend/nfqueue/nfqueue.cpp | 437 +---------------------- backend/nfqueue/utils.hpp | 28 ++ 12 files changed, 594 insertions(+), 866 deletions(-) create mode 100644 backend/nfqueue/classes/netfilter.hpp create mode 100644 backend/nfqueue/classes/regex_filter.hpp delete mode 100644 backend/nfqueue/go.mod delete mode 100644 backend/nfqueue/go.sum delete mode 100644 backend/nfqueue/main.go create mode 100644 backend/nfqueue/utils.hpp diff --git a/.gitignore b/.gitignore index dbef8d4..0eb1867 100755 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ **/*.pyc **/__pycache__/ **/.vscode/** -/.mypy_cache/** +**/.vscode/ +**/.mypy_cache/** +**/.mypy_cache/ **/node_modules **/.pnp @@ -12,7 +14,7 @@ /backend/db/firegex.db /backend/db/firegex.db-journal -/backend/nfqueue/nfqueue +/backend/modules/cppqueue docker-compose.yml # misc diff --git a/Dockerfile b/Dockerfile index 8f0f113..fba2f7c 100755 --- a/Dockerfile +++ b/Dockerfile @@ -11,12 +11,12 @@ RUN git clone --branch release https://github.com/jpcre2/jpcre2 WORKDIR /tmp/jpcre2 RUN ./configure; make; make install -RUN mkdir /execute/ +RUN mkdir -p /execute/modules WORKDIR /execute COPY ./backend/nfqueue /execute/nfqueue -RUN g++ nfqueue/nfqueue.cpp -o nfqueue/nfqueue -O3 -march=native -lnetfilter_queue -pthread -lpcre2-8 -ltins -lmnl -lnfnetlink +RUN g++ nfqueue/nfqueue.cpp -o modules/cppqueue -std=c++20 -O3 -march=native -lnetfilter_queue -pthread -lpcre2-8 -ltins -lmnl -lnfnetlink ADD ./backend/requirements.txt /execute/requirements.txt RUN pip install --no-cache-dir -r /execute/requirements.txt diff --git a/backend/app.py b/backend/app.py index eaea55e..1a4f2f0 100644 --- a/backend/app.py +++ b/backend/app.py @@ -48,7 +48,7 @@ async def updater(): pass @app.on_event("startup") async def startup_event(): db.init() - await firewall.init(refresh_frontend) + await firewall.init() await refresh_frontend() if not JWT_SECRET(): db.put("secret", secrets.token_hex(32)) diff --git a/backend/modules/firegex.py b/backend/modules/firegex.py index 64453df..ebff9f1 100644 --- a/backend/modules/firegex.py +++ b/backend/modules/firegex.py @@ -1,12 +1,11 @@ -from typing import List -from pypacker import interceptor -from pypacker.layer3 import ip, ip6 -from pypacker.layer4 import tcp, udp +from typing import Dict, List, Set from ipaddress import ip_interface from modules.iptables import IPTables -import os, traceback - from modules.sqlite import Service +import re, os, asyncio +import traceback + +from modules.sqlite import Regex class FilterTypes: INPUT = "FIREGEX-INPUT" @@ -15,14 +14,13 @@ class FilterTypes: QUEUE_BASE_NUM = 1000 class FiregexFilter(): - def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None, func=None): + def __init__(self, proto:str, port:int, ip_int:str, queue=None, target=None, id=None): self.target = target self.id = int(id) if id else None self.queue = queue self.proto = proto self.port = int(port) self.ip_int = str(ip_int) - self.func = func def __eq__(self, o: object) -> bool: if isinstance(o, FiregexFilter): @@ -35,16 +33,6 @@ class FiregexFilter(): def ipv4(self): return ip_interface(self.ip_int).version == 4 - def input_func(self): - def none(pkt): return True - def wrap(pkt): return self.func(pkt, True) - return wrap if self.func else none - - def output_func(self): - def none(pkt): return True - def wrap(pkt): return self.func(pkt, False) - return wrap if self.func else none - class FiregexTables(IPTables): def __init__(self, ipv6=False): @@ -108,9 +96,9 @@ class FiregexTables(IPTables): )) return res - def add(self, filter:FiregexFilter): + async def add(self, filter:FiregexFilter): if filter in self.get(): return None - return FiregexInterceptor( iptables=self, filter=filter, n_threads=int(os.getenv("N_THREADS_NFQUEUE","1"))) + return await FiregexInterceptor.start( iptables=self, filter=filter, n_queues=int(os.getenv("N_THREADS_NFQUEUE","1"))) def delete_all(self): for filter_type in [FilterTypes.INPUT, FilterTypes.OUTPUT]: @@ -120,52 +108,143 @@ class FiregexTables(IPTables): for filter in self.get(): if filter.port == srv.port and filter.proto == srv.proto and ip_interface(filter.ip_int) == ip_interface(srv.ip_int): self.delete_rule(filter.target, filter.id) + + +class RegexFilter: + def __init__( + self, regex, + is_case_sensitive=True, + is_blacklist=True, + input_mode=False, + output_mode=False, + blocked_packets=0, + id=None, + update_func = None + ): + self.regex = regex + self.is_case_sensitive = is_case_sensitive + self.is_blacklist = is_blacklist + if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True) + self.input_mode = input_mode + self.output_mode = output_mode + self.blocked = blocked_packets + self.id = id + self.update_func = update_func + self.compiled_regex = self.compile() + + @classmethod + def from_regex(cls, regex:Regex, update_func = None): + return cls( + id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive, + is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets, + input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"], + update_func = update_func + ) + def compile(self): + if isinstance(self.regex, str): self.regex = self.regex.encode() + if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether") + re.compile(self.regex) # raise re.error if it's invalid! + case_sensitive = "1" if self.is_case_sensitive else "0" + if self.input_mode: + yield case_sensitive + "C" + self.regex.hex() if self.is_blacklist else case_sensitive + "c"+ self.regex.hex() + if self.output_mode: + yield case_sensitive + "S" + self.regex.hex() if self.is_blacklist else case_sensitive + "s"+ self.regex.hex() + + async def update(self): + if self.update_func: + if asyncio.iscoroutinefunction(self.update_func): await self.update_func(self) + else: self.update_func(self) class FiregexInterceptor: - def __init__(self, iptables: FiregexTables, filter: FiregexFilter, n_threads:int = 1): + + def __init__(self): + self.filter:FiregexFilter + self.ipv6:bool + self.filter_map_lock:asyncio.Lock + self.filter_map: Dict[str, RegexFilter] + self.regex_filters: Set[RegexFilter] + self.update_config_lock:asyncio.Lock + self.process:asyncio.subprocess.Process + self.n_queues:int + self.update_task: asyncio.Task + self.iptables:FiregexTables + + @classmethod + async def start(cls, iptables: FiregexTables, filter: FiregexFilter, n_queues:int = 1): + self = cls() self.filter = filter + self.n_queues = n_queues + self.iptables = iptables self.ipv6 = self.filter.ipv6() - self.itor_input, codes = self._start_queue(filter.input_func(), n_threads) - iptables.add_input(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) - self.itor_output, codes = self._start_queue(filter.output_func(), n_threads) - iptables.add_output(queue_range=codes, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) + self.filter_map_lock = asyncio.Lock() + self.update_config_lock = asyncio.Lock() + input_range, output_range = await self._start_binary() + self.update_task = asyncio.create_task(self.update_blocked()) + self.iptables.add_input(queue_range=input_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) + self.iptables.add_output(queue_range=output_range, proto=self.filter.proto, port=self.filter.port, ip_int=self.filter.ip_int) + return self + + async def _start_binary(self): + proxy_binary_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),"./cppqueue") + self.process = await asyncio.create_subprocess_exec( + proxy_binary_path, str(self.n_queues), + stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE + ) + line_fut = self.process.stdout.readuntil() + try: + line_fut = await asyncio.wait_for(line_fut, timeout=1) + except asyncio.TimeoutError: + self.process.kill() + raise Exception("Invalid binary output") + line = line_fut.decode() + if line.startswith("QUEUES "): + params = line.split() + return (int(params[2]), int(params[3])), (int(params[5]), int(params[6])) + else: + self.process.kill() + raise Exception("Invalid binary output") - def _start_queue(self,func,n_threads): - def func_wrap(ll_data, ll_proto_id, data, ctx, *args): - pkt_parsed = ip6.IP6(data) if self.ipv6 else ip.IP(data) - try: - pkt_data = None - if not pkt_parsed[tcp.TCP] is None: - pkt_data = pkt_parsed[tcp.TCP].body_bytes - elif not pkt_parsed[udp.UDP] is None: - pkt_data = pkt_parsed[udp.UDP].body_bytes - if pkt_data: - if func(pkt_data): - return data, interceptor.NF_ACCEPT - elif pkt_parsed[tcp.TCP]: - pkt_parsed[tcp.TCP].flags &= 0x00 - pkt_parsed[tcp.TCP].flags |= tcp.TH_FIN | tcp.TH_ACK - pkt_parsed[tcp.TCP].body_bytes = b"" - return pkt_parsed.bin(), interceptor.NF_ACCEPT - else: return b"", interceptor.NF_DROP - else: return data, interceptor.NF_ACCEPT - except Exception: - traceback.print_exc() - return data, interceptor.NF_ACCEPT - - ictor = interceptor.Interceptor() - starts = QUEUE_BASE_NUM - while True: - if starts >= 65536: - raise Exception("Netfilter queue is full!") - queue_ids = list(range(starts,starts+n_threads)) - try: - ictor.start(func_wrap, queue_ids=queue_ids) - break - except interceptor.UnableToBindException as e: - starts = e.queue_id + 1 - return ictor, (starts, starts+n_threads-1) + async def update_blocked(self): + try: + while True: + line = (await self.process.stdout.readuntil()).decode() + if line.startswith("BLOCKED"): + regex_id = line.split()[1] + async with self.filter_map_lock: + if regex_id in self.filter_map: + self.filter_map[regex_id].blocked+=1 + await self.filter_map[regex_id].update() + except asyncio.CancelledError: pass + except asyncio.IncompleteReadError: pass + except Exception: + traceback.print_exc() - def stop(self): - self.itor_input.stop() - self.itor_output.stop() \ No newline at end of file + async def stop(self): + self.update_task.cancel() + self.process.kill() + + async def _update_config(self, filters_codes): + async with self.update_config_lock: + self.process.stdin.write((" ".join(filters_codes)+"\n").encode()) + await self.process.stdin.drain() + + async def reload(self, filters:List[RegexFilter]): + async with self.filter_map_lock: + self.filter_map = self.compile_filters(filters) + filters_codes = self.get_filter_codes() + await self._update_config(filters_codes) + + def get_filter_codes(self): + filters_codes = list(self.filter_map.keys()) + filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True) + return filters_codes + + def compile_filters(self, filters:List[RegexFilter]): + res = {} + for filter_obj in filters: + try: + raw_filters = filter_obj.compile() + for filter in raw_filters: + res[filter] = filter_obj + except Exception: pass + return res \ No newline at end of file diff --git a/backend/modules/firewall.py b/backend/modules/firewall.py index 3a76673..582b159 100644 --- a/backend/modules/firewall.py +++ b/backend/modules/firewall.py @@ -1,6 +1,6 @@ -import traceback, asyncio, pcre +import traceback, asyncio from typing import Dict -from modules.firegex import FiregexFilter, FiregexTables +from modules.firegex import FiregexFilter, FiregexTables, RegexFilter from modules.sqlite import Regex, SQLite, Service class STATUS: @@ -12,17 +12,8 @@ class FirewallManager: self.db = db self.proxy_table: Dict[str, ServiceManager] = {} self.lock = asyncio.Lock() - self.updater_task = None - - def init_updater(self, callback = None): - if not self.updater_task: - self.updater_task = asyncio.create_task(self._stats_updater(callback)) - - def close_updater(self): - if self.updater_task: self.updater_task.cancel() async def close(self): - self.close_updater() if self.updater_task: self.updater_task.cancel() for key in list(self.proxy_table.keys()): await self.remove(key) @@ -33,8 +24,7 @@ class FirewallManager: await self.proxy_table[srv_id].next(STATUS.STOP) del self.proxy_table[srv_id] - async def init(self, callback = None): - self.init_updater(callback) + async def init(self): await self.reload() async def reload(self): @@ -43,7 +33,6 @@ class FirewallManager: srv = Service.from_dict(srv) if srv.id in self.proxy_table: continue - self.proxy_table[srv.id] = ServiceManager(srv, self.db) await self.proxy_table[srv.id].next(srv.status) @@ -71,42 +60,6 @@ class FirewallManager: class ServiceNotFoundException(Exception): pass -class RegexFilter: - def __init__( - self, regex, - is_case_sensitive=True, - is_blacklist=True, - input_mode=False, - output_mode=False, - blocked_packets=0, - id=None - ): - self.regex = regex - self.is_case_sensitive = is_case_sensitive - self.is_blacklist = is_blacklist - if input_mode == output_mode: input_mode = output_mode = True # (False, False) == (True, True) - self.input_mode = input_mode - self.output_mode = output_mode - self.blocked = blocked_packets - self.id = id - self.compiled_regex = self.compile() - - @classmethod - def from_regex(cls, regex:Regex): - return cls( - id=regex.id, regex=regex.regex, is_case_sensitive=regex.is_case_sensitive, - is_blacklist=regex.is_blacklist, blocked_packets=regex.blocked_packets, - input_mode = regex.mode in ["C","B"], output_mode=regex.mode in ["S","B"] - ) - - def compile(self): - if isinstance(self.regex, str): self.regex = self.regex.encode() - if not isinstance(self.regex, bytes): raise Exception("Invalid Regex Paramether") - return pcre.compile(self.regex if self.is_case_sensitive else b"(?i)"+self.regex) - - def check(self, data): - return True if self.compiled_regex.search(data) else False - class ServiceManager: def __init__(self, srv: Service, db): self.srv = srv @@ -114,12 +67,10 @@ class ServiceManager: self.firegextable = FiregexTables(self.srv.ipv6) self.status = STATUS.STOP self.filters: Dict[int, FiregexFilter] = {} - self._update_filters_from_db() self.lock = asyncio.Lock() self.interceptor = None - # TODO I don't like so much this method - def _update_filters_from_db(self): + async def _update_filters_from_db(self): regexes = [ Regex.from_dict(ele) for ele in self.db.query("SELECT * FROM regexes WHERE service_id = ? AND active=1;", self.srv.id) @@ -127,17 +78,16 @@ class ServiceManager: #Filter check old_filters = set(self.filters.keys()) new_filters = set([f.id for f in regexes]) - #remove old filters for f in old_filters: if not f in new_filters: del self.filters[f] - #add new filters for f in new_filters: if not f in old_filters: filter = [ele for ele in regexes if ele.id == f][0] - self.filters[f] = RegexFilter.from_regex(filter) + self.filters[f] = RegexFilter.from_regex(filter, self._stats_updater) + if self.interceptor: await self.interceptor.reload(self.filters.values()) def __update_status_db(self, status): self.db.query("UPDATE services SET status = ? WHERE service_id = ?;", status, self.srv.id) @@ -145,49 +95,36 @@ class ServiceManager: async def next(self,to): async with self.lock: if (self.status, to) == (STATUS.ACTIVE, STATUS.STOP): - self.stop() + await self.stop() self._set_status(to) # PAUSE -> ACTIVE elif (self.status, to) == (STATUS.STOP, STATUS.ACTIVE): - self.restart() + await self.restart() def _stats_updater(self,filter:RegexFilter): self.db.query("UPDATE regexes SET blocked_packets = ? WHERE regex_id = ?;", filter.blocked, filter.id) - - def update_stats(self): - for ele in self.filters.values(): - self._stats_updater(ele) def _set_status(self,status): self.status = status self.__update_status_db(status) - def start(self): + async def start(self): if not self.interceptor: self.firegextable.delete_by_srv(self.srv) - def regex_filter(pkt, by_client): - try: - for filter in self.filters.values(): - if (by_client and filter.input_mode) or (not by_client and filter.output_mode): - match = filter.check(pkt) - if (filter.is_blacklist and match) or (not filter.is_blacklist and not match): - filter.blocked+=1 - return False - except IndexError: pass - return True - self.interceptor = self.firegextable.add(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int, func=regex_filter)) + self.interceptor = await self.firegextable.add(FiregexFilter(self.srv.proto,self.srv.port, self.srv.ip_int)) + await self._update_filters_from_db() self._set_status(STATUS.ACTIVE) - def stop(self): + async def stop(self): self.firegextable.delete_by_srv(self.srv) if self.interceptor: - self.interceptor.stop() + await self.interceptor.stop() self.interceptor = None - def restart(self): - self.stop() - self.start() + async def restart(self): + await self.stop() + await self.start() async def update_filters(self): async with self.lock: - self._update_filters_from_db() \ No newline at end of file + await self._update_filters_from_db() \ No newline at end of file diff --git a/backend/nfqueue/classes/netfilter.hpp b/backend/nfqueue/classes/netfilter.hpp new file mode 100644 index 0000000..f485e55 --- /dev/null +++ b/backend/nfqueue/classes/netfilter.hpp @@ -0,0 +1,294 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef NETFILTER_CLASSES_HPP +#define NETFILTER_CLASSES_HPP + +typedef bool NetFilterQueueCallback(const uint8_t*,uint32_t); + +Tins::PDU * find_transport_layer(Tins::PDU* pkt){ + while(pkt != NULL){ + if (pkt->pdu_type() == Tins::PDU::TCP || pkt->pdu_type() == Tins::PDU::UDP) { + return pkt; + } + pkt = pkt->inner_pdu(); + } + return pkt; +} + +template +class NetfilterQueue { + public: + size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2); + char *buf = NULL; + unsigned int portid; + u_int16_t queue_num; + struct mnl_socket* nl = NULL; + + NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) { + + nl = mnl_socket_open(NETLINK_NETFILTER); + + if (nl == NULL) { throw std::runtime_error( "mnl_socket_open" );} + + if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) { + mnl_socket_close(nl); + throw std::runtime_error( "mnl_socket_bind" ); + } + portid = mnl_socket_get_portid(nl); + + buf = (char*) malloc(BUF_SIZE); + + if (!buf) { + mnl_socket_close(nl); + throw std::runtime_error( "allocate receive buffer" ); + } + + if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) { + _clear(); + throw std::runtime_error( "mnl_socket_send" ); + } + //TEST if BIND was successful + if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND A NONE cmmand to generate an error meessage + _clear(); + throw std::runtime_error( "mnl_socket_send" ); + } + if (recv_packet() == -1) { //RECV the error message + _clear(); + throw std::runtime_error( "mnl_socket_recvfrom" ); + } + + struct nlmsghdr *nlh = (struct nlmsghdr *) buf; + + if (nlh->nlmsg_type != NLMSG_ERROR) { + _clear(); + throw std::runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" ); + } + //nfqnl_msg_config_cmd + nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh); + + // error code taken from the linux kernel: + // https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27 + #define ENOTSUPP 524 /* Operation is not supported */ + + if (error_msg->error != -ENOTSUPP) { + _clear(); + throw std::invalid_argument( "queueid is already busy" ); + } + + //END TESTING BIND + nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num); + nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff); + + + mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO)); + mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO)); + + if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) { + _clear(); + throw std::runtime_error( "mnl_socket_send" ); + } + + } + + + + void run(){ + /* + * ENOBUFS is signalled to userspace when packets were lost + * on kernel side. In most cases, userspace isn't interested + * in this information, so turn it off. + */ + int ret = 1; + mnl_socket_setsockopt(nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int)); + + for (;;) { + ret = recv_packet(); + if (ret == -1) { + throw std::runtime_error( "mnl_socket_recvfrom" ); + } + + ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, nl); + if (ret < 0){ + throw std::runtime_error( "mnl_cb_run" ); + } + } + } + + ~NetfilterQueue() { + send_config_cmd(NFQNL_CFG_CMD_UNBIND); + _clear(); + } + private: + + ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){ + struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num); + nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd); + return mnl_socket_sendto(nl, nlh, nlh->nlmsg_len); + } + + ssize_t recv_packet(){ + return mnl_socket_recvfrom(nl, buf, BUF_SIZE); + } + + void _clear(){ + if (buf != NULL) { + free(buf); + buf = NULL; + } + mnl_socket_close(nl); + } + + static int queue_cb(const struct nlmsghdr *nlh, void *data) + { + struct mnl_socket* nl = (struct mnl_socket*)data; + //Extract attributes from the nlmsghdr + struct nlattr *attr[NFQA_MAX+1] = {}; + + if (nfq_nlmsg_parse(nlh, attr) < 0) { + perror("problems parsing"); + return MNL_CB_ERROR; + } + if (attr[NFQA_PACKET_HDR] == NULL) { + fputs("metaheader not set\n", stderr); + return MNL_CB_ERROR; + } + //Get Payload + uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]); + void *payload = mnl_attr_get_payload(attr[NFQA_PAYLOAD]); + + //Return result to the kernel + struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]); + struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh); + char buf[MNL_SOCKET_BUFFER_SIZE]; + struct nlmsghdr *nlh_verdict; + struct nlattr *nest; + + nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id)); + + /* + This define allow to avoid to allocate new heap memory for each packet. + The code under this comment is replicated for ipv6 and ip + Better solutions are welcome. :) + */ + #define PKT_HANDLE \ + Tins::PDU *transport_layer = find_transport_layer(&packet); \ + if(transport_layer->inner_pdu() == nullptr || transport_layer == nullptr){ \ + nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \ + }else{ \ + int size = transport_layer->inner_pdu()->size(); \ + if(callback_func((const uint8_t*)payload+plen - size, size)){ \ + nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \ + } else{ \ + if (transport_layer->pdu_type() == Tins::PDU::TCP){ \ + ((Tins::TCP *)transport_layer)->release_inner_pdu(); \ + ((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::FIN,1); \ + ((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::ACK,1); \ + ((Tins::TCP *)transport_layer)->set_flag(Tins::TCP::SYN,0); \ + nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); \ + nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \ + }else{ \ + nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP ); \ + } \ + } \ + } + + // Check IP protocol version + if ( (((uint8_t*)payload)[0] & 0xf0) == 0x40 ){ + Tins::IP packet = Tins::IP((uint8_t*)payload,plen); + PKT_HANDLE + }else{ + Tins::IPv6 packet = Tins::IPv6((uint8_t*)payload,plen); + PKT_HANDLE + } + + /* example to set the connmark. First, start NFQA_CT section: */ + nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT); + + /* then, add the connmark attribute: */ + mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42)); + /* more conntrack attributes, e.g. CTA_LABELS could be set here */ + + /* end conntrack section */ + mnl_attr_nest_end(nlh_verdict, nest); + + if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) { + throw std::runtime_error( "mnl_socket_send" ); + } + + return MNL_CB_OK; + } + +}; + +template +class NFQueueSequence{ + private: + std::vector *> nfq; + uint16_t _init; + uint16_t _end; + std::vector threads; + public: + static const int QUEUE_BASE_NUM = 1000; + + NFQueueSequence(uint16_t seq_len){ + if (seq_len <= 0) throw std::invalid_argument("seq_len <= 0"); + nfq = std::vector*>(seq_len); + _init = QUEUE_BASE_NUM; + while(nfq[0] == NULL){ + if (_init+seq_len-1 >= 65536){ + throw std::runtime_error("NFQueueSequence: too many queues!"); + } + for (int i=0;i(_init+i); + }catch(const std::invalid_argument e){ + for(int j = 0; j < i; j++) { + delete nfq[j]; + nfq[j] = nullptr; + } + _init += seq_len - i; + break; + } + } + } + _end = _init + seq_len - 1; + } + + void start(){ + if (threads.size() != 0) throw std::runtime_error("NFQueueSequence: already started!"); + for (int i=0;i::run, nfq[i])); + } + } + + void join(){ + for (int i=0;i +#include +#include +#include +#include "../utils.hpp" + + +#ifndef REGEX_FILTER_HPP +#define REGEX_FILTER_HPP + +typedef jpcre2::select jp; +typedef std::pair regex_rule_pair; +typedef std::vector regex_rule_vector; +struct regex_rules{ + regex_rule_vector output_whitelist, input_whitelist, output_blacklist, input_blacklist; + + regex_rule_vector* getByCode(char code){ + switch(code){ + case 'C': // Client to server Blacklist + return &input_blacklist; break; + case 'c': // Client to server Whitelist + return &input_whitelist; break; + case 'S': // Server to client Blacklist + return &output_blacklist; break; + case 's': // Server to client Whitelist + return &output_whitelist; break; + } + throw std::invalid_argument( "Expected 'C' 'c' 'S' or 's'" ); + } + + int add(const char* arg){ + //Integrity checks + size_t arg_len = strlen(arg); + if (arg_len < 2 || arg_len%2 != 0){ + std::cerr << "[warning] [regex_rules.add] invalid arg passed (" << arg << "), skipping..." << std::endl; + return -1; + } + if (arg[0] != '0' && arg[0] != '1'){ + std::cerr << "[warning] [regex_rules.add] invalid is_case_sensitive (" << arg[0] << ") in '" << arg << "', must be '1' or '0', skipping..." << std::endl; + return -1; + } + if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's'){ + std::cerr << "[warning] [regex_rules.add] invalid filter_type (" << arg[1] << ") in '" << arg << "', must be 'C', 'c', 'S' or 's', skipping..." << std::endl; + return -1; + } + std::string hex(arg+2), expr; + if (!unhexlify(hex, expr)){ + std::cerr << "[warning] [regex_rules.add] invalid hex regex value (" << hex << "), skipping..." << std::endl; + return -1; + } + //Push regex + jp::Regex regex(expr,arg[0] == '1'?"gS":"giS"); + if (regex){ + std::cerr << "[info] [regex_rules.add] adding new regex filter: '" << expr << "'" << std::endl; + getByCode(arg[1])->push_back(std::make_pair(std::string(arg), regex)); + } else { + std::cerr << "[warning] [regex_rules.add] compiling of '" << expr << "' regex failed, skipping..." << std::endl; + return -1; + } + return 0; + } + + bool check(unsigned char* data, const size_t& bytes_transferred, const bool in_input){ + std::string str_data((char *) data, bytes_transferred); + for (regex_rule_pair ele:(in_input?input_blacklist:output_blacklist)){ + try{ + if(ele.second.match(str_data)){ + std::stringstream msg; + msg << "BLOCKED " << ele.first << "\n"; + std::cout << msg.str() << std::flush; + return false; + } + } catch(...){ + std::cerr << "[info] [regex_rules.check] Error while matching blacklist regex: " << ele.first << std::endl; + } + } + for (regex_rule_pair ele:(in_input?input_whitelist:output_whitelist)){ + try{ + std::cerr << "[debug] [regex_rules.check] regex whitelist match " << ele.second.getPattern() << std::endl; + if(!ele.second.match(str_data)){ + std::stringstream msg; + msg << "BLOCKED " << ele.first << "\n"; + std::cout << msg.str() << std::flush; + return false; + } + } catch(...){ + std::cerr << "[info] [regex_rules.check] Error while matching whitelist regex: " << ele.first << std::endl; + } + } + return true; + } + +}; + +#endif // REGEX_FILTER_HPP \ No newline at end of file diff --git a/backend/nfqueue/go.mod b/backend/nfqueue/go.mod deleted file mode 100644 index 57db091..0000000 --- a/backend/nfqueue/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module main - -go 1.18 - -require github.com/DomySh/go-netfilter-queue v0.0.0-20220713124014-7261f0df2c15 - -require ( - github.com/Jemmic/go-pcre2 v0.0.0-20190111114109-bd52ad5f7098 // indirect - github.com/google/gopacket v1.1.19 // indirect -) diff --git a/backend/nfqueue/go.sum b/backend/nfqueue/go.sum deleted file mode 100644 index fa84d31..0000000 --- a/backend/nfqueue/go.sum +++ /dev/null @@ -1,18 +0,0 @@ -github.com/DomySh/go-netfilter-queue v0.0.0-20220713124014-7261f0df2c15 h1:6v9D8bG3oR0dJFMuEeEAg8Xwn436Ziv+P7QWS04wAG8= -github.com/DomySh/go-netfilter-queue v0.0.0-20220713124014-7261f0df2c15/go.mod h1:VdJ6kqHln0XlrhuxQM6eBjRIHCzvAMgcZDAtyD/GU5s= -github.com/Jemmic/go-pcre2 v0.0.0-20190111114109-bd52ad5f7098 h1:ZwFIi+5jGJWVrB2V4NvrEhIUy6uDkfnTtBsgj3HAImI= -github.com/Jemmic/go-pcre2 v0.0.0-20190111114109-bd52ad5f7098/go.mod h1:c+8WT1L7lfohb4xMaa3yAV7nlYNepqc2ZV09/CU8R/U= -github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/backend/nfqueue/main.go b/backend/nfqueue/main.go deleted file mode 100644 index 66356da..0000000 --- a/backend/nfqueue/main.go +++ /dev/null @@ -1,264 +0,0 @@ -package main - -import ( - "bufio" - "encoding/hex" - "fmt" - "log" - "os" - "os/user" - "strconv" - "strings" - - "github.com/DomySh/go-netfilter-queue" - "github.com/Jemmic/go-pcre2" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -const QUEUE_BASE_NUM = 1000 -const MAX_PACKET_IN_QUEUE = 100 - -type regex_pair struct { - regex string - matcher *pcre2.Matcher -} - -type regex_filters struct { - input_whitelist []regex_pair - input_blacklist []regex_pair - output_whitelist []regex_pair - output_blacklist []regex_pair - regexes []*pcre2.Regexp -} - -func NewRegexFilter() *regex_filters { - res := new(regex_filters) - res.input_blacklist = make([]regex_pair, 0) - res.input_whitelist = make([]regex_pair, 0) - res.output_blacklist = make([]regex_pair, 0) - res.output_whitelist = make([]regex_pair, 0) - res.regexes = make([]*pcre2.Regexp, 0) - return res -} - -func (self *regex_filters) add(raw_regex string) { - filter_type := strings.ToLower(raw_regex[0:2]) - - decoded_regex, err := hex.DecodeString(raw_regex[2:]) - if err != nil { - log.Printf("[add] Unable to decode regex '%s': %s", raw_regex, err) - return - } - - regex, err := pcre2.Compile(string(decoded_regex), 0) - if err != nil { - log.Printf("[add] Unable to compile regex '%s': %s", string(decoded_regex), err) - return - } - self.regexes = append(self.regexes, regex) - if filter_type[0] == 'i' { - if filter_type[1] == '1' { - self.input_whitelist = append(self.input_whitelist, regex_pair{raw_regex, regex.NewMatcher()}) - } else { - self.input_blacklist = append(self.input_blacklist, regex_pair{raw_regex, regex.NewMatcher()}) - } - } else { - if filter_type[1] == '1' { - self.output_whitelist = append(self.output_whitelist, regex_pair{raw_regex, regex.NewMatcher()}) - } else { - self.output_blacklist = append(self.output_blacklist, regex_pair{raw_regex, regex.NewMatcher()}) - } - } -} - -func (self *regex_filters) check(data []byte, is_input bool) bool { - if is_input { - for _, rgx := range self.input_blacklist { - if rgx.matcher.Match(data, 0) { - fmt.Printf("BLOCKED %s\n", rgx.regex) - return false - } - } - for _, rgx := range self.input_whitelist { - if !rgx.matcher.Match(data, 0) { - fmt.Printf("BLOCKED %s\n", rgx.regex) - return false - } - } - } else { - for _, rgx := range self.output_blacklist { - if rgx.matcher.Match(data, 0) { - fmt.Printf("BLOCKED %s\n", rgx.regex) - return false - } - } - for _, rgx := range self.output_whitelist { - if !rgx.matcher.Match(data, 0) { - fmt.Printf("BLOCKED %s\n", rgx.regex) - return false - } - } - } - return true -} - -func (self *regex_filters) clear() { - for _, rgx := range self.input_whitelist { - rgx.matcher.Free() - } - for _, rgx := range self.input_blacklist { - rgx.matcher.Free() - } - for _, rgx := range self.output_whitelist { - rgx.matcher.Free() - } - for _, rgx := range self.output_blacklist { - rgx.matcher.Free() - } - for _, regex := range self.regexes { - regex.Free() - } -} - -func handle_packets(packets <-chan netfilter.NFPacket, filter_table_channel chan regex_filters, is_input bool) { - filter_table := regex_filters{} - for true { - filter := filter_table - select { - case ft := <-filter_table_channel: - { - filter_table = ft - } - case p := <-packets: - { - p.SetVerdict(netfilter.NF_ACCEPT) - break - transport_layer := p.Packet.TransportLayer() - data := transport_layer.LayerPayload() - if len(data) > 0 { - if filter.check(data, is_input) { - p.SetVerdict(netfilter.NF_ACCEPT) - } else { - if transport_layer.LayerType() == layers.LayerTypeTCP { - *p.Packet.ApplicationLayer().(*gopacket.Payload) = []byte{} - transport_layer.(*layers.TCP).Payload = []byte{} - transport_layer.(*layers.TCP).FIN = true - transport_layer.(*layers.TCP).SYN = false - transport_layer.(*layers.TCP).RST = false - transport_layer.(*layers.TCP).ACK = true - transport_layer.(*layers.TCP).SetNetworkLayerForChecksum(p.Packet.NetworkLayer()) - buffer := gopacket.NewSerializeBuffer() - options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} - if err := gopacket.SerializePacket(buffer, options, p.Packet); err != nil { - p.SetVerdict(netfilter.NF_DROP) - } - p.SetVerdictWithPacket(netfilter.NF_ACCEPT, buffer.Bytes()) - } else { - p.SetVerdict(netfilter.NF_DROP) - } - } - } else { - p.SetVerdict(netfilter.NF_ACCEPT) - } - } - } - } -} - -func isRoot() bool { - currentUser, err := user.Current() - if err != nil { - log.Fatalf("[isRoot] Unable to get current user: %s", err) - } - return currentUser.Username == "root" -} - -func create_queue_seq(num int) ([]*netfilter.NFQueue, int, int) { - var queue_list = make([]*netfilter.NFQueue, num) - var err error - starts := QUEUE_BASE_NUM - for queue_list[0] == nil { - if starts+num-1 >= 65536 { - log.Fatalf("Netfilter queue is full!") - } - for i := 0; i < len(queue_list); i++ { - queue_list[i], err = netfilter.NewNFQueue(uint16(starts+num-1-i), MAX_PACKET_IN_QUEUE, netfilter.NF_DEFAULT_PACKET_SIZE) - if err != nil { - for j := 0; j < i; j++ { - queue_list[j].Close() - queue_list[j] = nil - } - starts = starts + num - i - break - } - } - - } - return queue_list, starts, starts + num - 1 -} - -func main() { - log.SetOutput(os.Stderr) - if !isRoot() { - log.Fatalf("[main] You must be root to run this program") - } - - number_of_queues := 1 - - if len(os.Args) >= 2 { - var err error - number_of_queues, err = strconv.Atoi(os.Args[1]) - if err != nil { - log.Fatalf("[main] Invalid number of queues: %s", err) - } - } - var filter_channels []chan regex_filters - // Start the queue list - queue_list, starts_input, end_input := create_queue_seq(number_of_queues) - for _, queue := range queue_list { - defer queue.Close() - ch := make(chan regex_filters) - filter_channels = append(filter_channels, ch) - go handle_packets(queue.GetPackets(), ch, true) - } - - queue_list, starts_output, end_output := create_queue_seq(number_of_queues) - for _, queue := range queue_list { - defer queue.Close() - ch := make(chan regex_filters) - filter_channels = append(filter_channels, ch) - go handle_packets(queue.GetPackets(), ch, false) - } - - fmt.Println("QUEUE INPUT", starts_input, end_input, "OUTPUT", starts_output, end_output) - - //Reading for new configuration - reader := bufio.NewReader(os.Stdin) - old_filter_table := NewRegexFilter() - for true { - text, err := reader.ReadString('\n') - log.Printf("[main] Regex rule updating...") - if err != nil { - log.Fatalf("[main] Unable to read from stdin: %s", err) - } - text = strings.Trim(text, "\n") - regexes := strings.Split(text, " ") - - new_filters := NewRegexFilter() - for _, regex := range regexes { - regex = strings.Trim(regex, " ") - if len(regex) < 2 { - continue - } - new_filters.add(regex) - } - for _, ch := range filter_channels { - ch <- *new_filters - } - old_filter_table.clear() - old_filter_table = new_filters - log.Printf("[main] Regex filter rules updated!") - } - -} diff --git a/backend/nfqueue/nfqueue.cpp b/backend/nfqueue/nfqueue.cpp index 16d6e6a..373ed4f 100644 --- a/backend/nfqueue/nfqueue.cpp +++ b/backend/nfqueue/nfqueue.cpp @@ -1,421 +1,22 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "classes/regex_filter.hpp" +#include "classes/netfilter.hpp" +#include "utils.hpp" #include -#include -#include -#include -#include -#include -#include -#include - using namespace std; -using namespace Tins; -typedef jpcre2::select jp; -mutex stdout_mutex; - - -bool unhexlify(string const &hex, string &newString) { - try{ - int len = hex.length(); - for(int i=0; i< len; i+=2) - { - std::string byte = hex.substr(i,2); - char chr = (char) (int)strtol(byte.c_str(), NULL, 16); - newString.push_back(chr); - } - return true; - } - catch (...){ - return false; - } -} - - -typedef pair regex_rule_pair; -typedef vector regex_rule_vector; -struct regex_rules{ - regex_rule_vector output_whitelist, input_whitelist, output_blacklist, input_blacklist; - - regex_rule_vector* getByCode(char code){ - switch(code){ - case 'C': // Client to server Blacklist - return &input_blacklist; break; - case 'c': // Client to server Whitelist - return &input_whitelist; break; - case 'S': // Server to client Blacklist - return &output_blacklist; break; - case 's': // Server to client Whitelist - return &output_whitelist; break; - } - throw invalid_argument( "Expected 'C' 'c' 'S' or 's'" ); - } - - int add(const char* arg){ - //Integrity checks - size_t arg_len = strlen(arg); - if (arg_len < 2 || arg_len%2 != 0){ - cerr << "[warning] [regex_rules.add] invalid arg passed (" << arg << "), skipping..." << endl; - return -1; - } - if (arg[0] != '0' && arg[0] != '1'){ - cerr << "[warning] [regex_rules.add] invalid is_case_sensitive (" << arg[0] << ") in '" << arg << "', must be '1' or '0', skipping..." << endl; - return -1; - } - if (arg[1] != 'C' && arg[1] != 'c' && arg[1] != 'S' && arg[1] != 's'){ - cerr << "[warning] [regex_rules.add] invalid filter_type (" << arg[1] << ") in '" << arg << "', must be 'C', 'c', 'S' or 's', skipping..." << endl; - return -1; - } - string hex(arg+2), expr; - if (!unhexlify(hex, expr)){ - cerr << "[warning] [regex_rules.add] invalid hex regex value (" << hex << "), skipping..." << endl; - return -1; - } - //Push regex - jp::Regex regex(expr,arg[0] == '1'?"gS":"giS"); - if (regex){ - cerr << "[info] [regex_rules.add] adding new regex filter: '" << expr << "'" << endl; - getByCode(arg[1])->push_back(make_pair(string(arg), regex)); - } else { - cerr << "[warning] [regex_rules.add] compiling of '" << expr << "' regex failed, skipping..." << endl; - return -1; - } - return 0; - } - - bool check(unsigned char* data, const size_t& bytes_transferred, const bool in_input){ - string str_data((char *) data, bytes_transferred); - for (regex_rule_pair ele:(in_input?input_blacklist:output_blacklist)){ - try{ - if(ele.second.match(str_data)){ - unique_lock lck(stdout_mutex); - cout << "BLOCKED " << ele.first << endl; - return false; - } - } catch(...){ - cerr << "[info] [regex_rules.check] Error while matching blacklist regex: " << ele.first << endl; - } - } - for (regex_rule_pair ele:(in_input?input_whitelist:output_whitelist)){ - try{ - cerr << "[debug] [regex_rules.check] regex whitelist match " << ele.second.getPattern() << endl; - if(!ele.second.match(str_data)){ - unique_lock lck(stdout_mutex); - cout << "BLOCKED " << ele.first << endl; - return false; - } - } catch(...){ - cerr << "[info] [regex_rules.check] Error while matching whitelist regex: " << ele.first << endl; - } - } - return true; - } - -}; - - shared_ptr regex_config; -typedef bool NetFilterQueueCallback(const uint8_t*,uint32_t); - -PDU * find_transport_layer(PDU* pkt){ - while(pkt != NULL){ - if (pkt->pdu_type() == PDU::TCP || pkt->pdu_type() == PDU::UDP) { - return pkt; - } - pkt = pkt->inner_pdu(); - } - return pkt; -} - -template -class NetfilterQueue { - public: - size_t BUF_SIZE = 0xffff + (MNL_SOCKET_BUFFER_SIZE/2); - char *buf = NULL; - unsigned int portid; - u_int16_t queue_num; - struct mnl_socket* nl = NULL; - - NetfilterQueue(u_int16_t queue_num): queue_num(queue_num) { - - nl = mnl_socket_open(NETLINK_NETFILTER); - - if (nl == NULL) { throw runtime_error( "mnl_socket_open" );} - - if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) { - mnl_socket_close(nl); - throw runtime_error( "mnl_socket_bind" ); - } - portid = mnl_socket_get_portid(nl); - - buf = (char*) malloc(BUF_SIZE); - - if (!buf) { - mnl_socket_close(nl); - throw runtime_error( "allocate receive buffer" ); - } - - if (send_config_cmd(NFQNL_CFG_CMD_BIND) < 0) { - _clear(); - throw runtime_error( "mnl_socket_send" ); - } - //TEST if BIND was successful - if (send_config_cmd(NFQNL_CFG_CMD_NONE) < 0) { // SEND A NONE cmmand to generate an error meessage - _clear(); - throw runtime_error( "mnl_socket_send" ); - } - if (recv_packet() == -1) { //RECV the error message - _clear(); - throw std::runtime_error( "mnl_socket_recvfrom" ); - } - - struct nlmsghdr *nlh = (struct nlmsghdr *) buf; - - if (nlh->nlmsg_type != NLMSG_ERROR) { - _clear(); - throw runtime_error( "unexpected packet from kernel (expected NLMSG_ERROR packet)" ); - } - //nfqnl_msg_config_cmd - nlmsgerr* error_msg = (nlmsgerr *)mnl_nlmsg_get_payload(nlh); - - // error code taken from the linux kernel: - // https://elixir.bootlin.com/linux/v5.18.12/source/include/linux/errno.h#L27 - #define ENOTSUPP 524 /* Operation is not supported */ - - if (error_msg->error != -ENOTSUPP) { - _clear(); - throw std::invalid_argument( "queueid is already busy" ); - } - - //END TESTING BIND - nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num); - nfq_nlmsg_cfg_put_params(nlh, NFQNL_COPY_PACKET, 0xffff); - - - mnl_attr_put_u32(nlh, NFQA_CFG_FLAGS, htonl(NFQA_CFG_F_GSO)); - mnl_attr_put_u32(nlh, NFQA_CFG_MASK, htonl(NFQA_CFG_F_GSO)); - - if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) { - _clear(); - throw runtime_error( "mnl_socket_send" ); - } - - } - - - - void run(){ - /* - * ENOBUFS is signalled to userspace when packets were lost - * on kernel side. In most cases, userspace isn't interested - * in this information, so turn it off. - */ - int ret = 1; - mnl_socket_setsockopt(nl, NETLINK_NO_ENOBUFS, &ret, sizeof(int)); - - for (;;) { - ret = recv_packet(); - if (ret == -1) { - throw std::runtime_error( "mnl_socket_recvfrom" ); - } - - ret = mnl_cb_run(buf, ret, 0, portid, queue_cb, nl); - if (ret < 0){ - throw std::runtime_error( "mnl_cb_run" ); - } - } - } - - ~NetfilterQueue() { - send_config_cmd(NFQNL_CFG_CMD_UNBIND); - _clear(); - } - private: - - ssize_t send_config_cmd(nfqnl_msg_config_cmds cmd){ - struct nlmsghdr *nlh = nfq_nlmsg_put(buf, NFQNL_MSG_CONFIG, queue_num); - nfq_nlmsg_cfg_put_cmd(nlh, AF_INET, cmd); - return mnl_socket_sendto(nl, nlh, nlh->nlmsg_len); - } - - ssize_t recv_packet(){ - return mnl_socket_recvfrom(nl, buf, BUF_SIZE); - } - - void _clear(){ - if (buf != NULL) { - free(buf); - buf = NULL; - } - mnl_socket_close(nl); - } - - static int queue_cb(const struct nlmsghdr *nlh, void *data) - { - struct mnl_socket* nl = (struct mnl_socket*)data; - //Extract attributes from the nlmsghdr - struct nlattr *attr[NFQA_MAX+1] = {}; - - if (nfq_nlmsg_parse(nlh, attr) < 0) { - perror("problems parsing"); - return MNL_CB_ERROR; - } - if (attr[NFQA_PACKET_HDR] == NULL) { - fputs("metaheader not set\n", stderr); - return MNL_CB_ERROR; - } - //Get Payload - uint16_t plen = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]); - void *payload = mnl_attr_get_payload(attr[NFQA_PAYLOAD]); - - //Return result to the kernel - struct nfqnl_msg_packet_hdr *ph = (nfqnl_msg_packet_hdr*) mnl_attr_get_payload(attr[NFQA_PACKET_HDR]); - struct nfgenmsg *nfg = (nfgenmsg *)mnl_nlmsg_get_payload(nlh); - char buf[MNL_SOCKET_BUFFER_SIZE]; - struct nlmsghdr *nlh_verdict; - struct nlattr *nest; - - nlh_verdict = nfq_nlmsg_put(buf, NFQNL_MSG_VERDICT, ntohs(nfg->res_id)); - - /* - This define allow to avoid to allocate new heap memory for each packet. - The code under this comment is replicated for ipv6 and ip - Better solutions are welcome. :) - */ - #define PKT_HANDLE \ - PDU *transport_layer = find_transport_layer(&packet); \ - if(transport_layer->inner_pdu() == nullptr || transport_layer == nullptr){ \ - nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \ - }else{ \ - int size = transport_layer->inner_pdu()->size(); \ - if(callback_func((const uint8_t*)payload+plen - size, size)){ \ - nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \ - } else{ \ - if (transport_layer->pdu_type() == PDU::TCP){ \ - ((TCP *)transport_layer)->release_inner_pdu(); \ - ((TCP *)transport_layer)->set_flag(TCP::FIN,1); \ - ((TCP *)transport_layer)->set_flag(TCP::ACK,1); \ - ((TCP *)transport_layer)->set_flag(TCP::SYN,0); \ - nfq_nlmsg_verdict_put_pkt(nlh_verdict, packet.serialize().data(), packet.size()); \ - nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_ACCEPT ); \ - }else{ \ - nfq_nlmsg_verdict_put(nlh_verdict, ntohl(ph->packet_id), NF_DROP ); \ - } \ - } \ - } - - // Check IP protocol version - if ( (((uint8_t*)payload)[0] & 0xf0) == 0x40 ){ - IP packet = IP((uint8_t*)payload,plen); - PKT_HANDLE - }else{ - IPv6 packet = IPv6((uint8_t*)payload,plen); - PKT_HANDLE - } - - /* example to set the connmark. First, start NFQA_CT section: */ - nest = mnl_attr_nest_start(nlh_verdict, NFQA_CT); - - /* then, add the connmark attribute: */ - mnl_attr_put_u32(nlh_verdict, CTA_MARK, htonl(42)); - /* more conntrack attributes, e.g. CTA_LABELS could be set here */ - - /* end conntrack section */ - mnl_attr_nest_end(nlh_verdict, nest); - - if (mnl_socket_sendto(nl, nlh_verdict, nlh_verdict->nlmsg_len) < 0) { - throw std::runtime_error( "mnl_socket_send" ); - } - - return MNL_CB_OK; - } - -}; - -template -class NFQueueSequence{ - private: - vector *> nfq; - uint16_t _init; - uint16_t _end; - vector threads; - public: - static const int QUEUE_BASE_NUM = 1000; - - NFQueueSequence(uint16_t seq_len){ - if (seq_len <= 0) throw invalid_argument("seq_len <= 0"); - nfq = vector*>(seq_len); - _init = QUEUE_BASE_NUM; - while(nfq[0] == NULL){ - if (_init+seq_len-1 >= 65536){ - throw runtime_error("NFQueueSequence: too many queues!"); - } - for (int i=0;i(_init+i); - }catch(const invalid_argument e){ - for(int j = 0; j < i; j++) { - delete nfq[j]; - nfq[j] = nullptr; - } - _init += seq_len - i; - break; - } - } - } - _end = _init + seq_len - 1; - } - - void start(){ - if (threads.size() != 0) throw runtime_error("NFQueueSequence: already started!"); - for (int i=0;i::run, nfq[i])); - } - } - - void join(){ - for (int i=0;i> data; - regex_new_config->add(data.c_str()); + if (data != "" && data != "\n"){ + regex_new_config->add(data.c_str()); + } } regex_config.reset(regex_new_config); cerr << "[info] [updater] Config update done" << endl; @@ -456,21 +59,3 @@ int main(int argc, char *argv[]) config_updater(); } - - -/* - -libpcre2-dev -libnetfilter-queue-dev -libtins-dev -libmnl-dev - -c++ nfqueue.cpp -o nfqueue -pthread -lpcre2-8 -ltins -lnetfilter_queue -lmnl - -WORKDIR /tmp/ -RUN git clone --branch release https://github.com/jpcre2/jpcre2 -WORKDIR /tmp/jpcre2 -RUN ./configure; make; make install -WORKDIR / - -*/ \ No newline at end of file diff --git a/backend/nfqueue/utils.hpp b/backend/nfqueue/utils.hpp new file mode 100644 index 0000000..d7a092a --- /dev/null +++ b/backend/nfqueue/utils.hpp @@ -0,0 +1,28 @@ +#include +#include + +#ifndef UTILS_HPP +#define UTILS_HPP + +bool unhexlify(std::string const &hex, std::string &newString) { + try{ + int len = hex.length(); + for(int i=0; i< len; i+=2) + { + std::string byte = hex.substr(i,2); + char chr = (char) (int)strtol(byte.c_str(), NULL, 16); + newString.push_back(chr); + } + return true; + } + catch (...){ + return false; + } +} + + +bool is_sudo(){ + return getuid() == 0; +} + +#endif \ No newline at end of file