diff --git a/backend/app.py b/backend/app.py index 7bb478e..2ea341b 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,6 +1,5 @@ import uvicorn, secrets, utils import os, asyncio -from typing import List from fastapi import FastAPI, HTTPException, Depends, APIRouter from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import jwt @@ -108,7 +107,7 @@ async def change_password(form: PasswordChangeForm): return {"status":"ok", "access_token": create_access_token({"logged_in": True})} -@api.get('/interfaces', response_model=List[IpInterface]) +@api.get('/interfaces', response_model=list[IpInterface]) async def get_ip_interfaces(): """Get a list of ip and ip6 interfaces""" return get_interfaces() diff --git a/backend/binsrc/nfqueue_regex/src/main.rs b/backend/binsrc/nfqueue_regex/src/main.rs index 6e3b673..d12f973 100644 --- a/backend/binsrc/nfqueue_regex/src/main.rs +++ b/backend/binsrc/nfqueue_regex/src/main.rs @@ -2,6 +2,11 @@ use std::env; use std::collections::HashMap; +#[macro_use] +extern crate hyperscan; + +use hyperscan::prelude::*; + #[derive(Hash, Eq, PartialEq, Debug)] struct ConnectionFlux { src_ip: String, @@ -24,13 +29,434 @@ fn main() { n_of_threads = 1; } - let _connections = HashMap::from([ + let _connections = HashMap::from([ (ConnectionFlux::new("127.0.0.1", 1337, "127.0.0.1", 1337), 25), ]); eprintln!("[info][main] Using {} threads", n_of_threads) } + +// Hyperscan example program 2: pcapscan + +use std::collections::HashMap; +use std::fs; +use std::io; +use std::iter; +use std::net::SocketAddrV4; +use std::path::{Path, PathBuf}; +use std::process::exit; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::{Duration, Instant}; + +use anyhow::{Context, Result}; +use byteorder::{BigEndian, ReadBytesExt}; +use pnet::packet::{ + ethernet::{EtherTypes, EthernetPacket}, + ip::IpNextHeaderProtocols, + ipv4::Ipv4Packet, + udp::UdpPacket, + Packet, PrimitiveValues, +}; +use structopt::StructOpt; + +use hyperscan::prelude::*; + +/** + * This function will read in the file with the specified name, with an + * expression per line, ignoring lines starting with '#' and build a Hyperscan + * database for it. + */ +fn init_db>(path: P) -> Result<(StreamingDatabase)> { + // do the actual file reading and string handling + let patterns: Patterns = fs::read_to_string(path)?.parse()?; + + println!("Compiling Hyperscan databases with {} patterns.", patterns.len()); + + Ok((build_database(&patterns)?)) +} + +fn build_database, T: Mode>(builder: &B) -> Result> { + let now = Instant::now(); + + let db = builder.build::()?; + + println!( + "compile `{}` mode database in {} ms", + T::NAME, + now.elapsed().as_millis() + ); + + Ok(db) +} + +// Key for identifying a stream in our pcap input data, using data from its IP +// headers. +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +struct Session { + src: SocketAddrV4, + dst: SocketAddrV4, +} + +impl Session { + fn new(ipv4: &Ipv4Packet) -> Session { + let mut c = io::Cursor::new(ipv4.payload()); + let src_port = c.read_u16::().unwrap(); + let dst_port = c.read_u16::().unwrap(); + + Session { + src: SocketAddrV4::new(ipv4.get_source(), src_port), + dst: SocketAddrV4::new(ipv4.get_destination(), dst_port), + } + } +} + +const IP_FLAG_MF: u8 = 1; + +struct Benchmark { + /// Map used to construct stream_ids + sessions: HashMap>, + + /// Hyperscan compiled database (streaming mode) + streaming_db: StreamingDatabase, + + /// Hyperscan temporary scratch space (used in both modes) + scratch: Scratch, + + // Count of matches found during scanning + match_count: AtomicUsize, +} + +impl Benchmark { + fn new(streaming_db: StreamingDatabase) -> Result { + let mut s = streaming_db.alloc_scratch()?; + + block_db.realloc_scratch(&mut s)?; + + Ok(Benchmark { + sessions: HashMap::new(), + streaming_db: streaming_db, + scratch: s, + match_count: AtomicUsize::new(0), + }) + } + + fn decode_packet(packet: &pcap::Packet) -> Option<(Session, Vec)> { + let ether = EthernetPacket::new(&packet.data).unwrap(); + + if ether.get_ethertype() != EtherTypes::Ipv4 { + return None; + } + + let ipv4 = Ipv4Packet::new(ðer.payload()).unwrap(); + + if ipv4.get_version() != 4 { + return None; + } + + if (ipv4.get_flags() & IP_FLAG_MF) == IP_FLAG_MF || ipv4.get_fragment_offset() != 0 { + return None; + } + + match ipv4.get_next_level_protocol() { + IpNextHeaderProtocols::Tcp => { + let payload = ipv4.payload(); + let data_off = ((payload[12] >> 4) * 4) as usize; + + Some((Session::new(&ipv4), Vec::from(&payload[data_off..]))) + } + + IpNextHeaderProtocols::Udp => { + let udp = UdpPacket::new(&ipv4.payload()).unwrap(); + + Some((Session::new(&ipv4), Vec::from(udp.payload()))) + } + _ => None, + } + } + + fn read_streams>(&mut self, path: P) -> Result<(), pcap::Error> { + let mut capture = pcap::Capture::from_file(path)?; + + while let Ok(ref packet) = capture.next_packet() { + if let Some((key, payload)) = Self::decode_packet(&packet) { + if payload.len() > 0 { + let stream_id = match self.sessions.get(&key) { + Some(&id) => id, + None => { + let id = self.sessions.len(); + + assert!(self.sessions.insert(key, id).is_none()); + + id + } + }; + + self.stream_ids.push(stream_id); + self.packets.push(Box::new(payload)); + } + } + } + + println!( + "read {} packets in {} sessions", + self.packets.len(), + self.stream_ids.len(), + ); + + Ok(()) + } + + // Return the number of bytes scanned + fn bytes(&self) -> usize { + self.packets.iter().fold(0, |bytes, p| bytes + p.len()) + } + + // Return the number of matches found. + fn matches(&self) -> usize { + self.match_count.load(Ordering::Relaxed) + } + + // Clear the number of matches found. + fn clear_matches(&mut self) { + self.match_count.store(0, Ordering::Relaxed); + } + + // Open a Hyperscan stream for each stream in stream_ids + fn open_streams(&mut self) -> Result<()> { + self.streams = iter::repeat_with(|| self.streaming_db.open_stream()) + .take(self.sessions.len()) + .collect::>>()?; + + Ok(()) + } + + // Close all open Hyperscan streams (potentially generating any end-anchored matches) + fn close_streams(&mut self) -> Result<()> { + for stream in self.streams.drain(..) { + let match_count = &self.match_count; + stream + .close(&self.scratch, |_, _, _, _| { + match_count.fetch_add(1, Ordering::Relaxed); + + Matching::Continue + }) + .with_context(|| "close stream")?; + } + + Ok(()) + } + + fn reset_streams(&mut self) -> Result<()> { + for ref stream in &self.streams { + stream + .reset(&self.scratch, |_, _, _, _| { + self.match_count.fetch_add(1, Ordering::Relaxed); + + Matching::Continue + }) + .with_context(|| "reset stream")?; + } + + Ok(()) + } + + // Scan each packet (in the ordering given in the PCAP file) + // through Hyperscan using the streaming interface. + fn scan_streams(&mut self) -> Result<()> { + for (i, ref packet) in self.packets.iter().enumerate() { + let ref stream = self.streams[self.stream_ids[i]]; + + stream + .scan(packet.as_ref().as_slice(), &self.scratch, |_, _, _, _| { + self.match_count.fetch_add(1, Ordering::Relaxed); + + Matching::Continue + }) + .with_context(|| "scan packet")?; + } + + Ok(()) + } + + // Scan each packet (in the ordering given in the PCAP file) + // through Hyperscan using the block-mode interface. + fn scan_block(&mut self) -> Result<()> { + for ref packet in &self.packets { + self.block_db + .scan(packet.as_ref().as_slice(), &self.scratch, |_, _, _, _| { + self.match_count.fetch_add(1, Ordering::Relaxed); + + Matching::Continue + }) + .with_context(|| "scan packet")?; + } + + Ok(()) + } + + // Display some information about the compiled database and scanned data. + fn display_stats(&self) -> Result<()> { + let num_packets = self.packets.len(); + let num_streams = self.sessions.len(); + let num_bytes = self.bytes(); + + println!( + "{} packets in {} streams, totalling {} bytes.", + num_packets, num_streams, num_bytes + ); + println!( + "Average packet length: {} bytes.", + num_bytes / if num_packets > 0 { num_packets } else { 1 } + ); + println!( + "Average stream length: {} bytes.", + num_bytes / if num_streams > 0 { num_streams } else { 1 } + ); + println!(""); + println!( + "Streaming mode Hyperscan database size : {} bytes.", + self.streaming_db.size()? + ); + println!( + "Block mode Hyperscan database size : {} bytes.", + self.block_db.size()? + ); + println!( + "Streaming mode Hyperscan stream state size: {} bytes (per stream).", + self.streaming_db.stream_size()? + ); + + Ok(()) + } +} + +#[derive(Debug, StructOpt)] +#[structopt(name = "simplegrep", about = "An example search a given input file for a pattern.")] +struct Opt { + /// repeat times + #[structopt(short = "n", default_value = "1")] + repeats: usize, + + /// pattern file + #[structopt(parse(from_os_str))] + pattern_file: PathBuf, + + /// pcap file + #[structopt(parse(from_os_str))] + pcap_file: PathBuf, +} + +// Main entry point. +fn main() -> Result<()> { + let Opt { + repeats, + pattern_file, + pcap_file, + } = Opt::from_args(); + + // Read our pattern set in and build Hyperscan databases from it. + println!("Pattern file: {:?}", pattern_file); + + let (streaming_db, block_db) = match read_databases(pattern_file) { + Ok((streaming_db, block_db)) => (streaming_db, block_db), + Err(err) => { + eprintln!("ERROR: Unable to parse and compile patterns: {}\n", err); + exit(-1); + } + }; + + // Read our input PCAP file in + let mut bench = Benchmark::new(streaming_db, block_db)?; + + println!("PCAP input file: {:?}", pcap_file); + + if let Err(err) = bench.read_streams(pcap_file) { + eprintln!("Unable to read packets from PCAP file. Exiting. {}\n", err); + exit(-1); + } + + if repeats != 1 { + println!("Repeating PCAP scan {} times.", repeats); + } + + bench.display_stats()?; + + // Streaming mode scans. + let mut streaming_scan = Duration::from_secs(0); + let mut streaming_open_close = Duration::from_secs(0); + + for i in 0..repeats { + if i == 0 { + // Open streams. + let now = Instant::now(); + bench.open_streams()?; + streaming_open_close = streaming_open_close + now.elapsed(); + } else { + // Reset streams. + let now = Instant::now(); + bench.reset_streams()?; + streaming_open_close = streaming_open_close + now.elapsed(); + } + + // Scan all our packets in streaming mode. + let now = Instant::now(); + bench.scan_streams()?; + streaming_scan = streaming_scan + now.elapsed(); + } + + // Close streams. + let now = Instant::now(); + bench.close_streams()?; + streaming_open_close = streaming_open_close + now.elapsed(); + + // Collect data from streaming mode scans. + let bytes = bench.bytes(); + let total_bytes = (bytes * 8 * repeats) as f64; + let tput_stream_scanning = total_bytes * 1000.0 / streaming_scan.as_millis() as f64; + let tput_stream_overhead = total_bytes * 1000.0 / (streaming_scan + streaming_open_close).as_millis() as f64; + let matches_stream = bench.matches(); + let match_rate_stream = (matches_stream as f64) / ((bytes * repeats) as f64 / 1024.0); + + // Scan all our packets in block mode. + bench.clear_matches(); + let now = Instant::now(); + for _ in 0..repeats { + bench.scan_block()?; + } + let scan_block = now.elapsed(); + + // Collect data from block mode scans. + let tput_block_scanning = total_bytes * 1000.0 / scan_block.as_millis() as f64; + let matches_block = bench.matches(); + let match_rate_block = (matches_block as f64) / ((bytes * repeats) as f64 / 1024.0); + + println!("\nStreaming mode:\n"); + println!(" Total matches: {}", matches_stream); + println!(" Match rate: {:.4} matches/kilobyte", match_rate_stream); + println!( + " Throughput (with stream overhead): {:.2} megabits/sec", + tput_stream_overhead / 1000000.0 + ); + println!( + " Throughput (no stream overhead): {:.2} megabits/sec", + tput_stream_scanning / 1000000.0 + ); + + println!("\nBlock mode:\n"); + println!(" Total matches: {}", matches_block); + println!(" Match rate: {:.4} matches/kilobyte", match_rate_block); + println!(" Throughput: {:.2} megabits/sec", tput_block_scanning / 1000000.0); + + if bytes < (2 * 1024 * 1024) { + println!( + "\nWARNING: Input PCAP file is less than 2MB in size.\n + This test may have been too short to calculate accurate results." + ); + } + + Ok(()) +} /* shared_ptr regex_config; diff --git a/backend/modules/firewall/__init__.py b/backend/modules/firewall/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/modules/firewall/firewall.py b/backend/modules/firewall/firewall.py new file mode 100644 index 0000000..2ea594c --- /dev/null +++ b/backend/modules/firewall/firewall.py @@ -0,0 +1,24 @@ +import asyncio +from modules.firewall.nftables import FiregexTables +from modules.firewall.models import Rule +from utils.sqlite import SQLite + +nft = FiregexTables() + +class FirewallManager: + def __init__(self, db:SQLite): + self.db = db + self.lock = asyncio.Lock() + + async def close(self): + async with self.lock: + nft.reset() + + async def init(self): + FiregexTables().init() + await self.reload() + + async def reload(self): + async with self.lock: + nft.set(map(Rule.from_dict, self.db.query('SELECT * FROM rules WHERE active = 1 ORDER BY rule_id;'))) + diff --git a/backend/modules/firewall/models.py b/backend/modules/firewall/models.py new file mode 100644 index 0000000..b7a2aa0 --- /dev/null +++ b/backend/modules/firewall/models.py @@ -0,0 +1,33 @@ +class Rule: + def __init__(self, rule_id: int, name: str, active: bool, proto: str, ip_src:str, ip_dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str): + self.rule_id = rule_id + self.active = active + self.name = name + self.proto = proto + self.ip_src = ip_src + self.ip_dst = ip_dst + self.port_src_from = port_src_from + self.port_dst_from = port_dst_from + self.port_src_to = port_src_to + self.port_dst_to = port_dst_to + self.action = action + self.input_mode = mode in ["I"] + self.output_mode = mode in ["O"] + + + @classmethod + def from_dict(cls, var: dict): + return cls( + rule_id=var["rule_id"], + active=var["active"], + name=var["name"], + proto=var["proto"], + ip_src=var["ip_src"], + ip_dst=var["ip_dst"], + port_dst_from=var["port_dst_from"], + port_dst_to=var["port_dst_to"], + port_src_from=var["port_src_from"], + port_src_to=var["port_src_to"], + action=var["action"], + mode=var["mode"] + ) \ No newline at end of file diff --git a/backend/modules/firewall/nftables.py b/backend/modules/firewall/nftables.py new file mode 100644 index 0000000..6acf10a --- /dev/null +++ b/backend/modules/firewall/nftables.py @@ -0,0 +1,88 @@ +from modules.firewall.models import Rule +from utils import nftables_int_to_json, ip_parse, ip_family, NFTableManager, nftables_json_to_int + + +class FiregexHijackRule(): + def __init__(self, proto:str, ip_src:str, ip_dst:str, port_src_from:int, port_dst_from:int, port_src_to:int, port_dst_to:int, action:str, target:str, id:int): + self.id = id + self.target = target + self.proto = proto + self.ip_src = ip_src + self.ip_dst = ip_dst + self.port_src_from = min(port_src_from, port_src_to) + self.port_dst_from = min(port_dst_from, port_dst_to) + self.port_src_to = max(port_src_from, port_src_to) + self.port_dst_to = max(port_dst_from, port_dst_to) + self.action = action + + def __eq__(self, o: object) -> bool: + if isinstance(o, FiregexHijackRule) or isinstance(o, Rule): + return self.action == o.action and self.proto == o.proto and\ + ip_parse(self.ip_src) == ip_parse(o.ip_src) and ip_parse(self.ip_dst) == ip_parse(o.ip_dst) and\ + int(self.port_src_from) == int(o.port_src_from) and int(self.port_dst_from) == int(o.port_dst_from) and\ + int(self.port_src_to) == int(o.port_src_to) and int(self.port_dst_to) == int(o.port_dst_to) + return False + + +class FiregexTables(NFTableManager): + rules_chain_in = "firewall_rules_in" + rules_chain_out = "firewall_rules_out" + + def __init__(self): + super().__init__([ + {"add":{"chain":{ + "family":"inet", + "table":self.table_name, + "name":self.rules_chain_in, + "type":"filter", + "hook":"prerouting", + "prio":-300, + "policy":"accept" + }}}, + {"add":{"chain":{ + "family":"inet", + "table":self.table_name, + "name":self.rules_chain_out, + "type":"filter", + "hook":"postrouting", + "prio":-300, + "policy":"accept" + }}}, + ],[ + {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}}, + {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}}, + {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}}, + {"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}}, + ]) + + def delete_all(self): + self.cmd( + {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}}, + {"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}}, + ) + + def set(self, srv:list[Rule]): + self.delete_all() + for ele in srv: self.add(ele) + + def add(self, srv:Rule): + port_filters = [] + if srv.proto != "any": + if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port + port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}}) + port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}}) + if srv.port_dst_from != 1 or srv.port_dst_to != 65535: #Any Port + port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '>=', 'right': int(srv.port_dst_from)}}) + port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '<=', 'right': int(srv.port_dst_to)}}) + if len(port_filters) == 0: + port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '!=', 'right': 0}}) #filter the protocol if no port is specified + + self.cmd({ "insert":{ "rule": { + "family": "inet", + "table": self.table_name, + "chain": self.rules_chain_out if srv.output_mode else self.rules_chain_in, + "expr": [ + {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_src), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_src)}}, + {'match': {'left': {'payload': {'protocol': ip_family(srv.ip_dst), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_dst)}}, + ] + port_filters + [{'accept': None} if srv.action == "accept" else {'reject': {}} if srv.action == "reject" else {'drop': None}] + }}}) \ No newline at end of file diff --git a/backend/modules/nfregex/firegex.py b/backend/modules/nfregex/firegex.py index 0753e0f..8d57d5d 100644 --- a/backend/modules/nfregex/firegex.py +++ b/backend/modules/nfregex/firegex.py @@ -1,4 +1,3 @@ -from typing import Dict, List, Set from modules.nfregex.nftables import FiregexTables from utils import ip_parse, run_func from modules.nfregex.models import Service, Regex @@ -56,8 +55,8 @@ class FiregexInterceptor: def __init__(self): self.srv:Service self.filter_map_lock:asyncio.Lock - self.filter_map: Dict[str, RegexFilter] - self.regex_filters: Set[RegexFilter] + self.filter_map: dict[str, RegexFilter] + self.regex_filters: set[RegexFilter] self.update_config_lock:asyncio.Lock self.process:asyncio.subprocess.Process self.update_task: asyncio.Task @@ -118,7 +117,7 @@ class FiregexInterceptor: self.process.stdin.write((" ".join(filters_codes)+"\n").encode()) await self.process.stdin.drain() - async def reload(self, filters:List[RegexFilter]): + async def reload(self, filters:list[RegexFilter]): async with self.filter_map_lock: self.filter_map = self.compile_filters(filters) filters_codes = self.get_filter_codes() @@ -129,7 +128,7 @@ class FiregexInterceptor: filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True) return filters_codes - def compile_filters(self, filters:List[RegexFilter]): + def compile_filters(self, filters:list[RegexFilter]): res = {} for filter_obj in filters: try: diff --git a/backend/modules/nfregex/firewall.py b/backend/modules/nfregex/firewall.py index 18544f6..9516f63 100644 --- a/backend/modules/nfregex/firewall.py +++ b/backend/modules/nfregex/firewall.py @@ -1,5 +1,4 @@ import asyncio -from typing import Dict from modules.nfregex.firegex import FiregexInterceptor, RegexFilter from modules.nfregex.nftables import FiregexTables, FiregexFilter from modules.nfregex.models import Regex, Service @@ -11,49 +10,13 @@ class STATUS: nft = FiregexTables() -class FirewallManager: - def __init__(self, db:SQLite): - self.db = db - self.service_table: Dict[str, ServiceManager] = {} - self.lock = asyncio.Lock() - - async def close(self): - for key in list(self.service_table.keys()): - await self.remove(key) - - async def remove(self,srv_id): - async with self.lock: - if srv_id in self.service_table: - await self.service_table[srv_id].next(STATUS.STOP) - del self.service_table[srv_id] - - async def init(self): - nft.init() - await self.reload() - - async def reload(self): - async with self.lock: - for srv in self.db.query('SELECT * FROM services;'): - srv = Service.from_dict(srv) - if srv.id in self.service_table: - continue - self.service_table[srv.id] = ServiceManager(srv, self.db) - await self.service_table[srv.id].next(srv.status) - - def get(self,srv_id): - if srv_id in self.service_table: - return self.service_table[srv_id] - else: - raise ServiceNotFoundException() - -class ServiceNotFoundException(Exception): pass class ServiceManager: def __init__(self, srv: Service, db): self.srv = srv self.db = db self.status = STATUS.STOP - self.filters: Dict[int, FiregexFilter] = {} + self.filters: dict[int, FiregexFilter] = {} self.lock = asyncio.Lock() self.interceptor = None @@ -114,4 +77,41 @@ class ServiceManager: async def update_filters(self): async with self.lock: - await self._update_filters_from_db() \ No newline at end of file + await self._update_filters_from_db() + +class FirewallManager: + def __init__(self, db:SQLite): + self.db = db + self.service_table: dict[str, ServiceManager] = {} + self.lock = asyncio.Lock() + + async def close(self): + for key in list(self.service_table.keys()): + await self.remove(key) + + async def remove(self,srv_id): + async with self.lock: + if srv_id in self.service_table: + await self.service_table[srv_id].next(STATUS.STOP) + del self.service_table[srv_id] + + async def init(self): + nft.init() + await self.reload() + + async def reload(self): + async with self.lock: + for srv in self.db.query('SELECT * FROM services;'): + srv = Service.from_dict(srv) + if srv.id in self.service_table: + continue + self.service_table[srv.id] = ServiceManager(srv, self.db) + await self.service_table[srv.id].next(srv.status) + + def get(self,srv_id) -> ServiceManager: + if srv_id in self.service_table: + return self.service_table[srv_id] + else: + raise ServiceNotFoundException() + +class ServiceNotFoundException(Exception): pass diff --git a/backend/modules/nfregex/nftables.py b/backend/modules/nfregex/nftables.py index 47b77c2..a0bc917 100644 --- a/backend/modules/nfregex/nftables.py +++ b/backend/modules/nfregex/nftables.py @@ -1,4 +1,3 @@ -from typing import List from modules.nfregex.models import Service from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json @@ -11,9 +10,7 @@ class FiregexFilter: self.ip_int = str(ip_int) def __eq__(self, o: object) -> bool: - if isinstance(o, FiregexFilter): - return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) - elif isinstance(o, Service): + if isinstance(o, FiregexFilter) or isinstance(o, Service): return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return False @@ -80,7 +77,7 @@ class FiregexTables(NFTableManager): }}}) - def get(self) -> List[FiregexFilter]: + def get(self) -> list[FiregexFilter]: res = [] for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]): ip_int = None diff --git a/backend/modules/porthijack/firewall.py b/backend/modules/porthijack/firewall.py index b1d21a4..ebb4cfc 100644 --- a/backend/modules/porthijack/firewall.py +++ b/backend/modules/porthijack/firewall.py @@ -1,47 +1,10 @@ import asyncio -from typing import Dict from modules.porthijack.nftables import FiregexTables from modules.porthijack.models import Service from utils.sqlite import SQLite nft = FiregexTables() -class FirewallManager: - def __init__(self, db:SQLite): - self.db = db - self.service_table: Dict[str, ServiceManager] = {} - self.lock = asyncio.Lock() - - async def close(self): - for key in list(self.service_table.keys()): - await self.remove(key) - - async def remove(self,srv_id): - async with self.lock: - if srv_id in self.service_table: - await self.service_table[srv_id].disable() - del self.service_table[srv_id] - - async def init(self): - FiregexTables().init() - await self.reload() - - async def reload(self): - async with self.lock: - for srv in self.db.query('SELECT * FROM services;'): - srv = Service.from_dict(srv) - if srv.service_id in self.service_table: - continue - self.service_table[srv.service_id] = ServiceManager(srv, self.db) - if srv.active: - await self.service_table[srv.service_id].enable() - - def get(self,srv_id): - if srv_id in self.service_table: - return self.service_table[srv_id] - else: - raise ServiceNotFoundException() - class ServiceNotFoundException(Exception): pass class ServiceManager: @@ -74,4 +37,41 @@ class ServiceManager: async def restart(self): await self.disable() - await self.enable() \ No newline at end of file + await self.enable() + +class FirewallManager: + def __init__(self, db:SQLite): + self.db = db + self.service_table: dict[str, ServiceManager] = {} + self.lock = asyncio.Lock() + + async def close(self): + for key in list(self.service_table.keys()): + await self.remove(key) + + async def remove(self,srv_id): + async with self.lock: + if srv_id in self.service_table: + await self.service_table[srv_id].disable() + del self.service_table[srv_id] + + async def init(self): + FiregexTables().init() + await self.reload() + + async def reload(self): + async with self.lock: + for srv in self.db.query('SELECT * FROM services;'): + srv = Service.from_dict(srv) + if srv.service_id in self.service_table: + continue + self.service_table[srv.service_id] = ServiceManager(srv, self.db) + if srv.active: + await self.service_table[srv.service_id].enable() + + def get(self,srv_id) -> ServiceManager: + if srv_id in self.service_table: + return self.service_table[srv_id] + else: + raise ServiceNotFoundException() + diff --git a/backend/modules/porthijack/nftables.py b/backend/modules/porthijack/nftables.py index 92255f3..34e0669 100644 --- a/backend/modules/porthijack/nftables.py +++ b/backend/modules/porthijack/nftables.py @@ -1,4 +1,3 @@ -from typing import List from modules.porthijack.models import Service from utils import addr_parse, ip_parse, ip_family, NFTableManager, nftables_json_to_int @@ -13,9 +12,7 @@ class FiregexHijackRule(): self.ip_dst = str(ip_dst) def __eq__(self, o: object) -> bool: - if isinstance(o, FiregexHijackRule): - return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src) - elif isinstance(o, Service): + if isinstance(o, FiregexHijackRule) or isinstance(o, Service): return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src) return False @@ -79,10 +76,9 @@ class FiregexTables(NFTableManager): }}}) - def get(self) -> List[FiregexHijackRule]: + def get(self) -> list[FiregexHijackRule]: res = [] for filter in self.list_rules(tables=[self.table_name], chains=[self.prerouting_porthijack,self.postrouting_porthijack]): - filter["expr"][0]["match"]["right"] res.append(FiregexHijackRule( target=filter["chain"], id=int(filter["handle"]), diff --git a/backend/modules/regexproxy/utils.py b/backend/modules/regexproxy/utils.py index 5ad5597..eae7895 100644 --- a/backend/modules/regexproxy/utils.py +++ b/backend/modules/regexproxy/utils.py @@ -145,7 +145,7 @@ class ServiceManager: class ProxyManager: def __init__(self, db:SQLite): self.db = db - self.proxy_table:dict = {} + self.proxy_table: dict[str, ServiceManager] = {} self.lock = asyncio.Lock() async def close(self): @@ -168,7 +168,7 @@ class ProxyManager: self.proxy_table[srv_id] = ServiceManager(srv_id,self.db) await self.proxy_table[srv_id].next(req_status) - def get(self,id): + def get(self,id) -> ServiceManager: if id in self.proxy_table: return self.proxy_table[id] else: diff --git a/backend/routers/firewall.py b/backend/routers/firewall.py new file mode 100644 index 0000000..2efae98 --- /dev/null +++ b/backend/routers/firewall.py @@ -0,0 +1,171 @@ +import sqlite3 +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from utils.sqlite import SQLite +from utils import ip_parse, ip_family, refactor_name, refresh_frontend, PortType +from utils.models import ResetRequest, StatusMessageModel +from modules.firewall.nftables import FiregexTables +from modules.firewall.firewall import FirewallManager + +class RuleModel(BaseModel): + active: bool + name: str + proto: str + ip_src: str + ip_dst: str + port_src_from: PortType + port_dst_from: PortType + port_src_to: PortType + port_dst_to: PortType + action: str + mode:str + +class RuleAddResponse(BaseModel): + status:str|list[dict] + +class RenameForm(BaseModel): + name:str + +class GeneralStatModel(BaseModel): + rules: int + +app = APIRouter() + +db = SQLite('db/firewall-rules.db', { + 'rules': { + 'rule_id': 'INT PRIMARY KEY CHECK (rule_id >= 0)', + 'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("O", "I"))', # O = out, I = in, B = both + 'name': 'VARCHAR(100) NOT NULL', + 'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1))', + 'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp", "any"))', + 'ip_src': 'VARCHAR(100) NOT NULL', + 'port_src_from': 'INT CHECK(port_src_from > 0 and port_src_from < 65536)', + 'port_src_to': 'INT CHECK(port_src_to > 0 and port_src_to < 65536 and port_src_from <= port_src_to)', + 'ip_dst': 'VARCHAR(100) NOT NULL', + 'port_dst_from': 'INT CHECK(port_dst_from > 0 and port_dst_from < 65536)', + 'port_dst_to': 'INT CHECK(port_dst_to > 0 and port_dst_to < 65536 and port_dst_from <= port_dst_to)', + 'action': 'VARCHAR(10) NOT NULL CHECK (action IN ("accept", "drop", "reject"))', + }, + 'QUERY':[ + "CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, ip_src, ip_dst, port_src_from, port_src_to, port_dst_from, port_dst_to, action);" + ] +}) + +async def reset(params: ResetRequest): + if not params.delete: + db.backup() + await firewall.close() + FiregexTables().reset() + if params.delete: + db.delete() + db.init() + else: + db.restore() + await firewall.init() + + +async def startup(): + db.init() + await firewall.init() + +async def shutdown(): + db.backup() + await firewall.close() + db.disconnect() + db.restore() + +async def apply_changes(): + await firewall.reload() + await refresh_frontend() + return {'status': 'ok'} + +firewall = FirewallManager(db) + +@app.get('/stats', response_model=GeneralStatModel) +async def get_general_stats(): + """Get firegex general status about rules""" + return db.query("SELECT (SELECT COUNT(*) FROM rules) rules")[0] + +@app.get('/rules', response_model=list[RuleModel]) +async def get_rule_list(): + """Get the list of existent firegex rules""" + return db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;") + +@app.get('/rule/{rule_id}/disable', response_model=StatusMessageModel) +async def service_disable(rule_id: str): + """Request disabling a specific rule""" + if len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) == 0: + return {'status': 'Rule not found'} + db.query('UPDATE rules SET active = 0 WHERE rule_id = ?;', rule_id) + return await apply_changes() + +@app.get('/rule/{rule_id}/enable', response_model=StatusMessageModel) +async def service_start(rule_id: str): + """Request the enabling a specific rule""" + if len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) == 0: + return {'status': 'Rule not found'} + db.query('UPDATE rules SET active = 1 WHERE rule_id = ?;', rule_id) + return await apply_changes() + +@app.post('/service/{rule_id}/rename', response_model=StatusMessageModel) +async def service_rename(rule_id: str, form: RenameForm): + """Request to change the name of a specific service""" + if len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) == 0: + return {'status': 'Rule not found'} + form.name = refactor_name(form.name) + if not form.name: return {'status': 'The name cannot be empty!'} + try: + db.query('UPDATE rules SET name=? WHERE rule_id = ?;', form.name, rule_id) + except sqlite3.IntegrityError: + return {'status': 'This name is already used'} + await refresh_frontend() + return {'status': 'ok'} + +def parse_and_check_rule(rule:RuleModel): + try: + rule.ip_src = ip_parse(rule.ip_src) + rule.ip_dst = ip_parse(rule.ip_dst) + except ValueError: + return {"status":"Invalid address"} + + rule.port_dst_from, rule.port_dst_to = min(rule.port_dst_from, rule.port_dst_to), max(rule.port_dst_from, rule.port_dst_to) + rule.port_src_from, rule.port_src_to = min(rule.port_src_from, rule.port_src_to), max(rule.port_src_from, rule.port_src_to) + + if ip_family(rule.ip_dst) != ip_family(rule.ip_src): + return {"status":"Destination and source addresses must be of the same family"} + if rule.proto not in ["tcp", "udp", "any"]: + return {"status":"Invalid protocol"} + if rule.action not in ["accept", "drop", "reject"]: + return {"status":"Invalid action"} + return rule + + +@app.post('/rules/set', response_model=RuleAddResponse) +async def add_new_service(form: list[RuleModel]): + """Add a new service""" + form = [parse_and_check_rule(ele) for ele in form] + errors = [({"rule":i} | ele) for i, ele in enumerate(form) if isinstance(ele, dict)] + if len(errors) > 0: + return {'status': errors} + try: + db.queries(["DELETE FROM rules"]+ + [(""" + INSERT INTO rules ( + rule_id, active, name, + proto, + ip_src, ip_dst, + port_src_from, port_dst_from, + port_src_to, port_dst_to, + action, mode + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?, ?)""", + rid, ele.active, ele.name, + ele.proto, + ele.ip_src, ele.ip_dst, + ele.port_src_from, ele.port_dst_from, + ele.port_src_to, ele.port_dst_to, + ele.action, ele.mode + ) for rid, ele in enumerate(form)] + ) + except sqlite3.IntegrityError: + return {'status': 'Error saving the rules: maybe there are duplicated rules'} + return await apply_changes() diff --git a/backend/routers/nfregex.py b/backend/routers/nfregex.py index 91131b1..ef86980 100644 --- a/backend/routers/nfregex.py +++ b/backend/routers/nfregex.py @@ -2,13 +2,12 @@ from base64 import b64decode import re import secrets import sqlite3 -from typing import List, Union from fastapi import APIRouter, HTTPException from pydantic import BaseModel from modules.nfregex.nftables import FiregexTables from modules.nfregex.firewall import STATUS, FirewallManager from utils.sqlite import SQLite -from utils import ip_parse, refactor_name, refresh_frontend +from utils import ip_parse, refactor_name, refresh_frontend, PortType from utils.models import ResetRequest, StatusMessageModel class GeneralStatModel(BaseModel): @@ -19,7 +18,7 @@ class GeneralStatModel(BaseModel): class ServiceModel(BaseModel): status: str service_id: str - port: int + port: PortType name: str proto: str ip_int: str @@ -43,19 +42,19 @@ class RegexAddForm(BaseModel): service_id: str regex: str mode: str - active: Union[bool,None] + active: bool|None is_blacklist: bool is_case_sensitive: bool class ServiceAddForm(BaseModel): name: str - port: int + port: PortType proto: str ip_int: str class ServiceAddResponse(BaseModel): status:str - service_id: Union[None,str] + service_id: str|None app = APIRouter() @@ -70,7 +69,7 @@ db = SQLite('db/nft-regex.db', { }, 'regexes': { 'regex': 'TEXT NOT NULL', - 'mode': 'VARCHAR(1) NOT NULL', + 'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("C", "S", "B"))', # C = to the client, S = to the server, B = both 'service_id': 'VARCHAR(100) NOT NULL', 'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))', 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', @@ -127,7 +126,7 @@ async def get_general_stats(): (SELECT COUNT(*) FROM services) services """)[0] -@app.get('/services', response_model=List[ServiceModel]) +@app.get('/services', response_model=list[ServiceModel]) async def get_service_list(): """Get the list of existent firegex services""" return db.query(""" @@ -198,7 +197,7 @@ async def service_rename(service_id: str, form: RenameForm): await refresh_frontend() return {'status': 'ok'} -@app.get('/service/{service_id}/regexes', response_model=List[RegexModel]) +@app.get('/service/{service_id}/regexes', response_model=list[RegexModel]) async def get_service_regexe_list(service_id: str): """Get the list of the regexes of a service""" return db.query(""" diff --git a/backend/routers/porthijack.py b/backend/routers/porthijack.py index 16b92ba..54c1314 100644 --- a/backend/routers/porthijack.py +++ b/backend/routers/porthijack.py @@ -1,11 +1,10 @@ import secrets import sqlite3 -from typing import List, Union from fastapi import APIRouter, HTTPException from pydantic import BaseModel from modules.porthijack.models import Service from utils.sqlite import SQLite -from utils import addr_parse, ip_family, refactor_name, refresh_frontend +from utils import addr_parse, ip_family, refactor_name, refresh_frontend, PortType from utils.models import ResetRequest, StatusMessageModel from modules.porthijack.nftables import FiregexTables from modules.porthijack.firewall import FirewallManager @@ -13,8 +12,8 @@ from modules.porthijack.firewall import FirewallManager class ServiceModel(BaseModel): service_id: str active: bool - public_port: int - proxy_port: int + public_port: PortType + proxy_port: PortType name: str proto: str ip_src: str @@ -25,15 +24,15 @@ class RenameForm(BaseModel): class ServiceAddForm(BaseModel): name: str - public_port: int - proxy_port: int + public_port: PortType + proxy_port: PortType proto: str ip_src: str ip_dst: str class ServiceAddResponse(BaseModel): status:str - service_id: Union[None,str] + service_id: str|None class GeneralStatModel(BaseModel): services: int @@ -96,7 +95,7 @@ async def get_general_stats(): (SELECT COUNT(*) FROM services) services """)[0] -@app.get('/services', response_model=List[ServiceModel]) +@app.get('/services', response_model=list[ServiceModel]) async def get_service_list(): """Get the list of existent firegex services""" return db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_src, ip_dst FROM services;") @@ -144,7 +143,7 @@ async def service_rename(service_id: str, form: RenameForm): class ChangeDestination(BaseModel): ip_dst: str - proxy_port: int + proxy_port: PortType @app.post('/service/{service_id}/change-destination', response_model=StatusMessageModel) async def service_change_destination(service_id: str, form: ChangeDestination): diff --git a/backend/routers/regexproxy.py b/backend/routers/regexproxy.py index c7d565e..eec4834 100644 --- a/backend/routers/regexproxy.py +++ b/backend/routers/regexproxy.py @@ -1,12 +1,11 @@ from base64 import b64decode import sqlite3, re -from typing import List, Union from fastapi import APIRouter, HTTPException from pydantic import BaseModel from modules.regexproxy.utils import STATUS, ProxyManager, gen_internal_port, gen_service_id from utils.sqlite import SQLite from utils.models import ResetRequest, StatusMessageModel -from utils import refactor_name, refresh_frontend +from utils import refactor_name, refresh_frontend, PortType app = APIRouter() db = SQLite("db/regextcpproxy.db",{ @@ -77,13 +76,13 @@ async def get_general_stats(): class ServiceModel(BaseModel): id:str status: str - public_port: int - internal_port: int + public_port: PortType + internal_port: PortType name: str n_regex: int n_packets: int -@app.get('/services', response_model=List[ServiceModel]) +@app.get('/services', response_model=list[ServiceModel]) async def get_service_list(): """Get the list of existent firegex services""" return db.query(""" @@ -157,8 +156,8 @@ async def regen_service_port(service_id: str): return {'status': 'ok'} class ChangePortForm(BaseModel): - port: Union[int, None] - internalPort: Union[int, None] + port: int|None + internalPort: int|None @app.post('/service/{service_id}/change-ports', response_model=StatusMessageModel) async def change_service_ports(service_id: str, change_port:ChangePortForm): @@ -167,7 +166,7 @@ async def change_service_ports(service_id: str, change_port:ChangePortForm): return {'status': 'Invalid Request!'} try: sql_inj = "" - query:List[Union[str,int]] = [] + query:list[str|int] = [] if not change_port.port is None: sql_inj+=" public_port = ? " query.append(change_port.port) @@ -194,7 +193,7 @@ class RegexModel(BaseModel): is_case_sensitive:bool active:bool -@app.get('/service/{service_id}/regexes', response_model=List[RegexModel]) +@app.get('/service/{service_id}/regexes', response_model=list[RegexModel]) async def get_service_regexe_list(service_id: str): """Get the list of the regexes of a service""" return db.query(""" @@ -250,7 +249,7 @@ class RegexAddForm(BaseModel): service_id: str regex: str mode: str - active: Union[bool,None] + active: bool|None is_blacklist: bool is_case_sensitive: bool @@ -272,12 +271,12 @@ async def add_new_regex(form: RegexAddForm): class ServiceAddForm(BaseModel): name: str - port: int - internalPort: Union[int, None] + port: PortType + internalPort: int|None class ServiceAddStatus(BaseModel): status:str - id: Union[str,None] + id: str|None class RenameForm(BaseModel): name:str diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py index dfd8ca5..788ed11 100644 --- a/backend/utils/__init__.py +++ b/backend/utils/__init__.py @@ -2,6 +2,8 @@ import asyncio from ipaddress import ip_address, ip_interface import os, socket, psutil, sys, nftables from fastapi_socketio import SocketManager +from fastapi import Path +from typing import Annotated LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) @@ -15,6 +17,8 @@ FIREGEX_PORT = int(os.getenv("PORT","4444")) JWT_ALGORITHM: str = "HS256" API_VERSION = "2.0.0" +PortType = Annotated[int, Path(gt=0, lt=65536)] + async def run_func(func, *args, **kwargs): if asyncio.iscoroutinefunction(func): return await func(*args, **kwargs) @@ -133,4 +137,4 @@ class NFTableManager(Singleton): def raw_list(self): return self.cmd({"list": {"ruleset": None}})["nftables"] - + diff --git a/backend/utils/loader.py b/backend/utils/loader.py index b9b2a80..c37327a 100644 --- a/backend/utils/loader.py +++ b/backend/utils/loader.py @@ -1,7 +1,6 @@ -import os, httpx, websockets -from sys import prefix -from typing import Callable, List, Union +import os, httpx +from typing import Callable from fastapi import APIRouter, WebSocket import asyncio from starlette.responses import StreamingResponse @@ -49,10 +48,10 @@ def list_routers(): return [ele[:-3] for ele in list_files(ROUTERS_DIR) if ele != "__init__.py" and " " not in ele and ele.endswith(".py")] class RouterModule(): - router: Union[None, APIRouter] - reset: Union[None, Callable] - startup: Union[None, Callable] - shutdown: Union[None, Callable] + router: None|APIRouter + reset: None|Callable + startup: None|Callable + shutdown: None|Callable name: str def __init__(self, router: APIRouter, reset: Callable, startup: Callable, shutdown: Callable, name:str): @@ -66,7 +65,7 @@ class RouterModule(): return f"RouterModule(router={self.router}, reset={self.reset}, startup={self.startup}, shutdown={self.shutdown})" def get_router_modules(): - res: List[RouterModule] = [] + res: list[RouterModule] = [] for route in list_routers(): module = getattr(__import__(f"routers.{route}"), route, None) if module: diff --git a/backend/utils/models.py b/backend/utils/models.py index e589685..46128e7 100644 --- a/backend/utils/models.py +++ b/backend/utils/models.py @@ -1,4 +1,3 @@ -from typing import Union from pydantic import BaseModel class StatusMessageModel(BaseModel): @@ -18,7 +17,7 @@ class PasswordChangeForm(BaseModel): class ChangePasswordModel(BaseModel): status: str - access_token: Union[str,None] + access_token: str|None class IpInterface(BaseModel): addr: str diff --git a/backend/utils/sqlite.py b/backend/utils/sqlite.py index 12ced61..c22c2d0 100644 --- a/backend/utils/sqlite.py +++ b/backend/utils/sqlite.py @@ -1,11 +1,9 @@ -from typing import Union import json, sqlite3, os from hashlib import md5 -import base64 class SQLite(): def __init__(self, db_name: str, schema:dict = None) -> None: - self.conn: Union[None, sqlite3.Connection] = None + self.conn: sqlite3.Connection|None = None self.cur = None self.db_name = db_name self.__backup = None @@ -58,10 +56,25 @@ class SQLite(): cur.close() def query(self, query, *values): + return self.queries([(query, *values)])[0] + + def queries(self, queries: list[tuple[str, ...]]): + return list(self.queries_iter(queries)) + + def queries_iter(self, queries: list[tuple[str, ...]]): cur = self.conn.cursor() try: - cur.execute(query, values) - return cur.fetchall() + for query_data in queries: + values = [] + str_query = None + if isinstance(query_data, str): + str_query = query_data + elif (isinstance(query_data, tuple) or isinstance(query_data, list)) and len(query_data) > 0 and isinstance(query_data[0], str): + str_query = query_data[0] + values = query_data[1:] + if str_query: + cur.execute(str_query, values) + yield cur.fetchall() finally: cur.close() try: self.conn.commit()