adding firewall function to firegex!

This commit is contained in:
Domingo Dirutigliano
2023-09-22 20:46:50 +02:00
parent 4b8b145b68
commit 7fda371dcb
20 changed files with 890 additions and 145 deletions

View File

@@ -1,6 +1,5 @@
import uvicorn, secrets, utils import uvicorn, secrets, utils
import os, asyncio import os, asyncio
from typing import List
from fastapi import FastAPI, HTTPException, Depends, APIRouter from fastapi import FastAPI, HTTPException, Depends, APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt from jose import jwt
@@ -108,7 +107,7 @@ async def change_password(form: PasswordChangeForm):
return {"status":"ok", "access_token": create_access_token({"logged_in": True})} return {"status":"ok", "access_token": create_access_token({"logged_in": True})}
@api.get('/interfaces', response_model=List[IpInterface]) @api.get('/interfaces', response_model=list[IpInterface])
async def get_ip_interfaces(): async def get_ip_interfaces():
"""Get a list of ip and ip6 interfaces""" """Get a list of ip and ip6 interfaces"""
return get_interfaces() return get_interfaces()

View File

@@ -2,6 +2,11 @@ use std::env;
use std::collections::HashMap; use std::collections::HashMap;
#[macro_use]
extern crate hyperscan;
use hyperscan::prelude::*;
#[derive(Hash, Eq, PartialEq, Debug)] #[derive(Hash, Eq, PartialEq, Debug)]
struct ConnectionFlux { struct ConnectionFlux {
src_ip: String, src_ip: String,
@@ -24,13 +29,434 @@ fn main() {
n_of_threads = 1; n_of_threads = 1;
} }
let _connections = HashMap::from([ let _connections = HashMap<ConnectionFlux, >::from([
(ConnectionFlux::new("127.0.0.1", 1337, "127.0.0.1", 1337), 25), (ConnectionFlux::new("127.0.0.1", 1337, "127.0.0.1", 1337), 25),
]); ]);
eprintln!("[info][main] Using {} threads", n_of_threads) eprintln!("[info][main] Using {} threads", n_of_threads)
} }
// Hyperscan example program 2: pcapscan
use std::collections::HashMap;
use std::fs;
use std::io;
use std::iter;
use std::net::SocketAddrV4;
use std::path::{Path, PathBuf};
use std::process::exit;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use byteorder::{BigEndian, ReadBytesExt};
use pnet::packet::{
ethernet::{EtherTypes, EthernetPacket},
ip::IpNextHeaderProtocols,
ipv4::Ipv4Packet,
udp::UdpPacket,
Packet, PrimitiveValues,
};
use structopt::StructOpt;
use hyperscan::prelude::*;
/**
* This function will read in the file with the specified name, with an
* expression per line, ignoring lines starting with '#' and build a Hyperscan
* database for it.
*/
fn init_db<P: AsRef<Path>>(path: P) -> Result<(StreamingDatabase)> {
// do the actual file reading and string handling
let patterns: Patterns = fs::read_to_string(path)?.parse()?;
println!("Compiling Hyperscan databases with {} patterns.", patterns.len());
Ok((build_database(&patterns)?))
}
fn build_database<B: Builder<Err = hyperscan::Error>, T: Mode>(builder: &B) -> Result<Database<T>> {
let now = Instant::now();
let db = builder.build::<T>()?;
println!(
"compile `{}` mode database in {} ms",
T::NAME,
now.elapsed().as_millis()
);
Ok(db)
}
// Key for identifying a stream in our pcap input data, using data from its IP
// headers.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct Session {
src: SocketAddrV4,
dst: SocketAddrV4,
}
impl Session {
fn new(ipv4: &Ipv4Packet) -> Session {
let mut c = io::Cursor::new(ipv4.payload());
let src_port = c.read_u16::<BigEndian>().unwrap();
let dst_port = c.read_u16::<BigEndian>().unwrap();
Session {
src: SocketAddrV4::new(ipv4.get_source(), src_port),
dst: SocketAddrV4::new(ipv4.get_destination(), dst_port),
}
}
}
const IP_FLAG_MF: u8 = 1;
struct Benchmark {
/// Map used to construct stream_ids
sessions: HashMap<Session, Vec<Stream>>,
/// Hyperscan compiled database (streaming mode)
streaming_db: StreamingDatabase,
/// Hyperscan temporary scratch space (used in both modes)
scratch: Scratch,
// Count of matches found during scanning
match_count: AtomicUsize,
}
impl Benchmark {
fn new(streaming_db: StreamingDatabase) -> Result<Benchmark> {
let mut s = streaming_db.alloc_scratch()?;
block_db.realloc_scratch(&mut s)?;
Ok(Benchmark {
sessions: HashMap::new(),
streaming_db: streaming_db,
scratch: s,
match_count: AtomicUsize::new(0),
})
}
fn decode_packet(packet: &pcap::Packet) -> Option<(Session, Vec<u8>)> {
let ether = EthernetPacket::new(&packet.data).unwrap();
if ether.get_ethertype() != EtherTypes::Ipv4 {
return None;
}
let ipv4 = Ipv4Packet::new(&ether.payload()).unwrap();
if ipv4.get_version() != 4 {
return None;
}
if (ipv4.get_flags() & IP_FLAG_MF) == IP_FLAG_MF || ipv4.get_fragment_offset() != 0 {
return None;
}
match ipv4.get_next_level_protocol() {
IpNextHeaderProtocols::Tcp => {
let payload = ipv4.payload();
let data_off = ((payload[12] >> 4) * 4) as usize;
Some((Session::new(&ipv4), Vec::from(&payload[data_off..])))
}
IpNextHeaderProtocols::Udp => {
let udp = UdpPacket::new(&ipv4.payload()).unwrap();
Some((Session::new(&ipv4), Vec::from(udp.payload())))
}
_ => None,
}
}
fn read_streams<P: AsRef<Path>>(&mut self, path: P) -> Result<(), pcap::Error> {
let mut capture = pcap::Capture::from_file(path)?;
while let Ok(ref packet) = capture.next_packet() {
if let Some((key, payload)) = Self::decode_packet(&packet) {
if payload.len() > 0 {
let stream_id = match self.sessions.get(&key) {
Some(&id) => id,
None => {
let id = self.sessions.len();
assert!(self.sessions.insert(key, id).is_none());
id
}
};
self.stream_ids.push(stream_id);
self.packets.push(Box::new(payload));
}
}
}
println!(
"read {} packets in {} sessions",
self.packets.len(),
self.stream_ids.len(),
);
Ok(())
}
// Return the number of bytes scanned
fn bytes(&self) -> usize {
self.packets.iter().fold(0, |bytes, p| bytes + p.len())
}
// Return the number of matches found.
fn matches(&self) -> usize {
self.match_count.load(Ordering::Relaxed)
}
// Clear the number of matches found.
fn clear_matches(&mut self) {
self.match_count.store(0, Ordering::Relaxed);
}
// Open a Hyperscan stream for each stream in stream_ids
fn open_streams(&mut self) -> Result<()> {
self.streams = iter::repeat_with(|| self.streaming_db.open_stream())
.take(self.sessions.len())
.collect::<hyperscan::Result<Vec<_>>>()?;
Ok(())
}
// Close all open Hyperscan streams (potentially generating any end-anchored matches)
fn close_streams(&mut self) -> Result<()> {
for stream in self.streams.drain(..) {
let match_count = &self.match_count;
stream
.close(&self.scratch, |_, _, _, _| {
match_count.fetch_add(1, Ordering::Relaxed);
Matching::Continue
})
.with_context(|| "close stream")?;
}
Ok(())
}
fn reset_streams(&mut self) -> Result<()> {
for ref stream in &self.streams {
stream
.reset(&self.scratch, |_, _, _, _| {
self.match_count.fetch_add(1, Ordering::Relaxed);
Matching::Continue
})
.with_context(|| "reset stream")?;
}
Ok(())
}
// Scan each packet (in the ordering given in the PCAP file)
// through Hyperscan using the streaming interface.
fn scan_streams(&mut self) -> Result<()> {
for (i, ref packet) in self.packets.iter().enumerate() {
let ref stream = self.streams[self.stream_ids[i]];
stream
.scan(packet.as_ref().as_slice(), &self.scratch, |_, _, _, _| {
self.match_count.fetch_add(1, Ordering::Relaxed);
Matching::Continue
})
.with_context(|| "scan packet")?;
}
Ok(())
}
// Scan each packet (in the ordering given in the PCAP file)
// through Hyperscan using the block-mode interface.
fn scan_block(&mut self) -> Result<()> {
for ref packet in &self.packets {
self.block_db
.scan(packet.as_ref().as_slice(), &self.scratch, |_, _, _, _| {
self.match_count.fetch_add(1, Ordering::Relaxed);
Matching::Continue
})
.with_context(|| "scan packet")?;
}
Ok(())
}
// Display some information about the compiled database and scanned data.
fn display_stats(&self) -> Result<()> {
let num_packets = self.packets.len();
let num_streams = self.sessions.len();
let num_bytes = self.bytes();
println!(
"{} packets in {} streams, totalling {} bytes.",
num_packets, num_streams, num_bytes
);
println!(
"Average packet length: {} bytes.",
num_bytes / if num_packets > 0 { num_packets } else { 1 }
);
println!(
"Average stream length: {} bytes.",
num_bytes / if num_streams > 0 { num_streams } else { 1 }
);
println!("");
println!(
"Streaming mode Hyperscan database size : {} bytes.",
self.streaming_db.size()?
);
println!(
"Block mode Hyperscan database size : {} bytes.",
self.block_db.size()?
);
println!(
"Streaming mode Hyperscan stream state size: {} bytes (per stream).",
self.streaming_db.stream_size()?
);
Ok(())
}
}
#[derive(Debug, StructOpt)]
#[structopt(name = "simplegrep", about = "An example search a given input file for a pattern.")]
struct Opt {
/// repeat times
#[structopt(short = "n", default_value = "1")]
repeats: usize,
/// pattern file
#[structopt(parse(from_os_str))]
pattern_file: PathBuf,
/// pcap file
#[structopt(parse(from_os_str))]
pcap_file: PathBuf,
}
// Main entry point.
fn main() -> Result<()> {
let Opt {
repeats,
pattern_file,
pcap_file,
} = Opt::from_args();
// Read our pattern set in and build Hyperscan databases from it.
println!("Pattern file: {:?}", pattern_file);
let (streaming_db, block_db) = match read_databases(pattern_file) {
Ok((streaming_db, block_db)) => (streaming_db, block_db),
Err(err) => {
eprintln!("ERROR: Unable to parse and compile patterns: {}\n", err);
exit(-1);
}
};
// Read our input PCAP file in
let mut bench = Benchmark::new(streaming_db, block_db)?;
println!("PCAP input file: {:?}", pcap_file);
if let Err(err) = bench.read_streams(pcap_file) {
eprintln!("Unable to read packets from PCAP file. Exiting. {}\n", err);
exit(-1);
}
if repeats != 1 {
println!("Repeating PCAP scan {} times.", repeats);
}
bench.display_stats()?;
// Streaming mode scans.
let mut streaming_scan = Duration::from_secs(0);
let mut streaming_open_close = Duration::from_secs(0);
for i in 0..repeats {
if i == 0 {
// Open streams.
let now = Instant::now();
bench.open_streams()?;
streaming_open_close = streaming_open_close + now.elapsed();
} else {
// Reset streams.
let now = Instant::now();
bench.reset_streams()?;
streaming_open_close = streaming_open_close + now.elapsed();
}
// Scan all our packets in streaming mode.
let now = Instant::now();
bench.scan_streams()?;
streaming_scan = streaming_scan + now.elapsed();
}
// Close streams.
let now = Instant::now();
bench.close_streams()?;
streaming_open_close = streaming_open_close + now.elapsed();
// Collect data from streaming mode scans.
let bytes = bench.bytes();
let total_bytes = (bytes * 8 * repeats) as f64;
let tput_stream_scanning = total_bytes * 1000.0 / streaming_scan.as_millis() as f64;
let tput_stream_overhead = total_bytes * 1000.0 / (streaming_scan + streaming_open_close).as_millis() as f64;
let matches_stream = bench.matches();
let match_rate_stream = (matches_stream as f64) / ((bytes * repeats) as f64 / 1024.0);
// Scan all our packets in block mode.
bench.clear_matches();
let now = Instant::now();
for _ in 0..repeats {
bench.scan_block()?;
}
let scan_block = now.elapsed();
// Collect data from block mode scans.
let tput_block_scanning = total_bytes * 1000.0 / scan_block.as_millis() as f64;
let matches_block = bench.matches();
let match_rate_block = (matches_block as f64) / ((bytes * repeats) as f64 / 1024.0);
println!("\nStreaming mode:\n");
println!(" Total matches: {}", matches_stream);
println!(" Match rate: {:.4} matches/kilobyte", match_rate_stream);
println!(
" Throughput (with stream overhead): {:.2} megabits/sec",
tput_stream_overhead / 1000000.0
);
println!(
" Throughput (no stream overhead): {:.2} megabits/sec",
tput_stream_scanning / 1000000.0
);
println!("\nBlock mode:\n");
println!(" Total matches: {}", matches_block);
println!(" Match rate: {:.4} matches/kilobyte", match_rate_block);
println!(" Throughput: {:.2} megabits/sec", tput_block_scanning / 1000000.0);
if bytes < (2 * 1024 * 1024) {
println!(
"\nWARNING: Input PCAP file is less than 2MB in size.\n
This test may have been too short to calculate accurate results."
);
}
Ok(())
}
/* /*
shared_ptr<regex_rules> regex_config; shared_ptr<regex_rules> regex_config;

View File

View File

@@ -0,0 +1,24 @@
import asyncio
from modules.firewall.nftables import FiregexTables
from modules.firewall.models import Rule
from utils.sqlite import SQLite
nft = FiregexTables()
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.lock = asyncio.Lock()
async def close(self):
async with self.lock:
nft.reset()
async def init(self):
FiregexTables().init()
await self.reload()
async def reload(self):
async with self.lock:
nft.set(map(Rule.from_dict, self.db.query('SELECT * FROM rules WHERE active = 1 ORDER BY rule_id;')))

View File

@@ -0,0 +1,33 @@
class Rule:
def __init__(self, rule_id: int, name: str, active: bool, proto: str, ip_src:str, ip_dst:str, port_src_from:str, port_dst_from:str, port_src_to:str, port_dst_to:str, action:str, mode:str):
self.rule_id = rule_id
self.active = active
self.name = name
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
self.port_src_from = port_src_from
self.port_dst_from = port_dst_from
self.port_src_to = port_src_to
self.port_dst_to = port_dst_to
self.action = action
self.input_mode = mode in ["I"]
self.output_mode = mode in ["O"]
@classmethod
def from_dict(cls, var: dict):
return cls(
rule_id=var["rule_id"],
active=var["active"],
name=var["name"],
proto=var["proto"],
ip_src=var["ip_src"],
ip_dst=var["ip_dst"],
port_dst_from=var["port_dst_from"],
port_dst_to=var["port_dst_to"],
port_src_from=var["port_src_from"],
port_src_to=var["port_src_to"],
action=var["action"],
mode=var["mode"]
)

View File

@@ -0,0 +1,88 @@
from modules.firewall.models import Rule
from utils import nftables_int_to_json, ip_parse, ip_family, NFTableManager, nftables_json_to_int
class FiregexHijackRule():
def __init__(self, proto:str, ip_src:str, ip_dst:str, port_src_from:int, port_dst_from:int, port_src_to:int, port_dst_to:int, action:str, target:str, id:int):
self.id = id
self.target = target
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
self.port_src_from = min(port_src_from, port_src_to)
self.port_dst_from = min(port_dst_from, port_dst_to)
self.port_src_to = max(port_src_from, port_src_to)
self.port_dst_to = max(port_dst_from, port_dst_to)
self.action = action
def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexHijackRule) or isinstance(o, Rule):
return self.action == o.action and self.proto == o.proto and\
ip_parse(self.ip_src) == ip_parse(o.ip_src) and ip_parse(self.ip_dst) == ip_parse(o.ip_dst) and\
int(self.port_src_from) == int(o.port_src_from) and int(self.port_dst_from) == int(o.port_dst_from) and\
int(self.port_src_to) == int(o.port_src_to) and int(self.port_dst_to) == int(o.port_dst_to)
return False
class FiregexTables(NFTableManager):
rules_chain_in = "firewall_rules_in"
rules_chain_out = "firewall_rules_out"
def __init__(self):
super().__init__([
{"add":{"chain":{
"family":"inet",
"table":self.table_name,
"name":self.rules_chain_in,
"type":"filter",
"hook":"prerouting",
"prio":-300,
"policy":"accept"
}}},
{"add":{"chain":{
"family":"inet",
"table":self.table_name,
"name":self.rules_chain_out,
"type":"filter",
"hook":"postrouting",
"prio":-300,
"policy":"accept"
}}},
],[
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
{"delete":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
])
def delete_all(self):
self.cmd(
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_in}}},
{"flush":{"chain":{"table":self.table_name,"family":"inet", "name":self.rules_chain_out}}},
)
def set(self, srv:list[Rule]):
self.delete_all()
for ele in srv: self.add(ele)
def add(self, srv:Rule):
port_filters = []
if srv.proto != "any":
if srv.port_src_from != 1 or srv.port_src_to != 65535: #Any Port
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '>=', 'right': int(srv.port_src_from)}})
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '<=', 'right': int(srv.port_src_to)}})
if srv.port_dst_from != 1 or srv.port_dst_to != 65535: #Any Port
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '>=', 'right': int(srv.port_dst_from)}})
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'dport'}}, 'op': '<=', 'right': int(srv.port_dst_to)}})
if len(port_filters) == 0:
port_filters.append({'match': {'left': {'payload': {'protocol': str(srv.proto), 'field': 'sport'}}, 'op': '!=', 'right': 0}}) #filter the protocol if no port is specified
self.cmd({ "insert":{ "rule": {
"family": "inet",
"table": self.table_name,
"chain": self.rules_chain_out if srv.output_mode else self.rules_chain_in,
"expr": [
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_src), 'field': 'saddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_src)}},
{'match': {'left': {'payload': {'protocol': ip_family(srv.ip_dst), 'field': 'daddr'}}, 'op': '==', 'right': nftables_int_to_json(srv.ip_dst)}},
] + port_filters + [{'accept': None} if srv.action == "accept" else {'reject': {}} if srv.action == "reject" else {'drop': None}]
}}})

View File

@@ -1,4 +1,3 @@
from typing import Dict, List, Set
from modules.nfregex.nftables import FiregexTables from modules.nfregex.nftables import FiregexTables
from utils import ip_parse, run_func from utils import ip_parse, run_func
from modules.nfregex.models import Service, Regex from modules.nfregex.models import Service, Regex
@@ -56,8 +55,8 @@ class FiregexInterceptor:
def __init__(self): def __init__(self):
self.srv:Service self.srv:Service
self.filter_map_lock:asyncio.Lock self.filter_map_lock:asyncio.Lock
self.filter_map: Dict[str, RegexFilter] self.filter_map: dict[str, RegexFilter]
self.regex_filters: Set[RegexFilter] self.regex_filters: set[RegexFilter]
self.update_config_lock:asyncio.Lock self.update_config_lock:asyncio.Lock
self.process:asyncio.subprocess.Process self.process:asyncio.subprocess.Process
self.update_task: asyncio.Task self.update_task: asyncio.Task
@@ -118,7 +117,7 @@ class FiregexInterceptor:
self.process.stdin.write((" ".join(filters_codes)+"\n").encode()) self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
await self.process.stdin.drain() await self.process.stdin.drain()
async def reload(self, filters:List[RegexFilter]): async def reload(self, filters:list[RegexFilter]):
async with self.filter_map_lock: async with self.filter_map_lock:
self.filter_map = self.compile_filters(filters) self.filter_map = self.compile_filters(filters)
filters_codes = self.get_filter_codes() filters_codes = self.get_filter_codes()
@@ -129,7 +128,7 @@ class FiregexInterceptor:
filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True) filters_codes.sort(key=lambda a: self.filter_map[a].blocked, reverse=True)
return filters_codes return filters_codes
def compile_filters(self, filters:List[RegexFilter]): def compile_filters(self, filters:list[RegexFilter]):
res = {} res = {}
for filter_obj in filters: for filter_obj in filters:
try: try:

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
from typing import Dict
from modules.nfregex.firegex import FiregexInterceptor, RegexFilter from modules.nfregex.firegex import FiregexInterceptor, RegexFilter
from modules.nfregex.nftables import FiregexTables, FiregexFilter from modules.nfregex.nftables import FiregexTables, FiregexFilter
from modules.nfregex.models import Regex, Service from modules.nfregex.models import Regex, Service
@@ -11,49 +10,13 @@ class STATUS:
nft = FiregexTables() nft = FiregexTables()
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.service_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.service_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.service_table:
await self.service_table[srv_id].next(STATUS.STOP)
del self.service_table[srv_id]
async def init(self):
nft.init()
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.id in self.service_table:
continue
self.service_table[srv.id] = ServiceManager(srv, self.db)
await self.service_table[srv.id].next(srv.status)
def get(self,srv_id):
if srv_id in self.service_table:
return self.service_table[srv_id]
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass
class ServiceManager: class ServiceManager:
def __init__(self, srv: Service, db): def __init__(self, srv: Service, db):
self.srv = srv self.srv = srv
self.db = db self.db = db
self.status = STATUS.STOP self.status = STATUS.STOP
self.filters: Dict[int, FiregexFilter] = {} self.filters: dict[int, FiregexFilter] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.interceptor = None self.interceptor = None
@@ -114,4 +77,41 @@ class ServiceManager:
async def update_filters(self): async def update_filters(self):
async with self.lock: async with self.lock:
await self._update_filters_from_db() await self._update_filters_from_db()
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.service_table: dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.service_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.service_table:
await self.service_table[srv_id].next(STATUS.STOP)
del self.service_table[srv_id]
async def init(self):
nft.init()
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.id in self.service_table:
continue
self.service_table[srv.id] = ServiceManager(srv, self.db)
await self.service_table[srv.id].next(srv.status)
def get(self,srv_id) -> ServiceManager:
if srv_id in self.service_table:
return self.service_table[srv_id]
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass

View File

@@ -1,4 +1,3 @@
from typing import List
from modules.nfregex.models import Service from modules.nfregex.models import Service
from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json from utils import ip_parse, ip_family, NFTableManager, nftables_int_to_json
@@ -11,9 +10,7 @@ class FiregexFilter:
self.ip_int = str(ip_int) self.ip_int = str(ip_int)
def __eq__(self, o: object) -> bool: def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexFilter): if isinstance(o, FiregexFilter) or isinstance(o, Service):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
elif isinstance(o, Service):
return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int) return self.port == o.port and self.proto == o.proto and ip_parse(self.ip_int) == ip_parse(o.ip_int)
return False return False
@@ -80,7 +77,7 @@ class FiregexTables(NFTableManager):
}}}) }}})
def get(self) -> List[FiregexFilter]: def get(self) -> list[FiregexFilter]:
res = [] res = []
for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]): for filter in self.list_rules(tables=[self.table_name], chains=[self.input_chain,self.output_chain]):
ip_int = None ip_int = None

View File

@@ -1,47 +1,10 @@
import asyncio import asyncio
from typing import Dict
from modules.porthijack.nftables import FiregexTables from modules.porthijack.nftables import FiregexTables
from modules.porthijack.models import Service from modules.porthijack.models import Service
from utils.sqlite import SQLite from utils.sqlite import SQLite
nft = FiregexTables() nft = FiregexTables()
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.service_table: Dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.service_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.service_table:
await self.service_table[srv_id].disable()
del self.service_table[srv_id]
async def init(self):
FiregexTables().init()
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.service_id in self.service_table:
continue
self.service_table[srv.service_id] = ServiceManager(srv, self.db)
if srv.active:
await self.service_table[srv.service_id].enable()
def get(self,srv_id):
if srv_id in self.service_table:
return self.service_table[srv_id]
else:
raise ServiceNotFoundException()
class ServiceNotFoundException(Exception): pass class ServiceNotFoundException(Exception): pass
class ServiceManager: class ServiceManager:
@@ -74,4 +37,41 @@ class ServiceManager:
async def restart(self): async def restart(self):
await self.disable() await self.disable()
await self.enable() await self.enable()
class FirewallManager:
def __init__(self, db:SQLite):
self.db = db
self.service_table: dict[str, ServiceManager] = {}
self.lock = asyncio.Lock()
async def close(self):
for key in list(self.service_table.keys()):
await self.remove(key)
async def remove(self,srv_id):
async with self.lock:
if srv_id in self.service_table:
await self.service_table[srv_id].disable()
del self.service_table[srv_id]
async def init(self):
FiregexTables().init()
await self.reload()
async def reload(self):
async with self.lock:
for srv in self.db.query('SELECT * FROM services;'):
srv = Service.from_dict(srv)
if srv.service_id in self.service_table:
continue
self.service_table[srv.service_id] = ServiceManager(srv, self.db)
if srv.active:
await self.service_table[srv.service_id].enable()
def get(self,srv_id) -> ServiceManager:
if srv_id in self.service_table:
return self.service_table[srv_id]
else:
raise ServiceNotFoundException()

View File

@@ -1,4 +1,3 @@
from typing import List
from modules.porthijack.models import Service from modules.porthijack.models import Service
from utils import addr_parse, ip_parse, ip_family, NFTableManager, nftables_json_to_int from utils import addr_parse, ip_parse, ip_family, NFTableManager, nftables_json_to_int
@@ -13,9 +12,7 @@ class FiregexHijackRule():
self.ip_dst = str(ip_dst) self.ip_dst = str(ip_dst)
def __eq__(self, o: object) -> bool: def __eq__(self, o: object) -> bool:
if isinstance(o, FiregexHijackRule): if isinstance(o, FiregexHijackRule) or isinstance(o, Service):
return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src)
elif isinstance(o, Service):
return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src) return self.public_port == o.public_port and self.proto == o.proto and ip_parse(self.ip_src) == ip_parse(o.ip_src)
return False return False
@@ -79,10 +76,9 @@ class FiregexTables(NFTableManager):
}}}) }}})
def get(self) -> List[FiregexHijackRule]: def get(self) -> list[FiregexHijackRule]:
res = [] res = []
for filter in self.list_rules(tables=[self.table_name], chains=[self.prerouting_porthijack,self.postrouting_porthijack]): for filter in self.list_rules(tables=[self.table_name], chains=[self.prerouting_porthijack,self.postrouting_porthijack]):
filter["expr"][0]["match"]["right"]
res.append(FiregexHijackRule( res.append(FiregexHijackRule(
target=filter["chain"], target=filter["chain"],
id=int(filter["handle"]), id=int(filter["handle"]),

View File

@@ -145,7 +145,7 @@ class ServiceManager:
class ProxyManager: class ProxyManager:
def __init__(self, db:SQLite): def __init__(self, db:SQLite):
self.db = db self.db = db
self.proxy_table:dict = {} self.proxy_table: dict[str, ServiceManager] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
async def close(self): async def close(self):
@@ -168,7 +168,7 @@ class ProxyManager:
self.proxy_table[srv_id] = ServiceManager(srv_id,self.db) self.proxy_table[srv_id] = ServiceManager(srv_id,self.db)
await self.proxy_table[srv_id].next(req_status) await self.proxy_table[srv_id].next(req_status)
def get(self,id): def get(self,id) -> ServiceManager:
if id in self.proxy_table: if id in self.proxy_table:
return self.proxy_table[id] return self.proxy_table[id]
else: else:

171
backend/routers/firewall.py Normal file
View File

@@ -0,0 +1,171 @@
import sqlite3
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from utils.sqlite import SQLite
from utils import ip_parse, ip_family, refactor_name, refresh_frontend, PortType
from utils.models import ResetRequest, StatusMessageModel
from modules.firewall.nftables import FiregexTables
from modules.firewall.firewall import FirewallManager
class RuleModel(BaseModel):
active: bool
name: str
proto: str
ip_src: str
ip_dst: str
port_src_from: PortType
port_dst_from: PortType
port_src_to: PortType
port_dst_to: PortType
action: str
mode:str
class RuleAddResponse(BaseModel):
status:str|list[dict]
class RenameForm(BaseModel):
name:str
class GeneralStatModel(BaseModel):
rules: int
app = APIRouter()
db = SQLite('db/firewall-rules.db', {
'rules': {
'rule_id': 'INT PRIMARY KEY CHECK (rule_id >= 0)',
'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("O", "I"))', # O = out, I = in, B = both
'name': 'VARCHAR(100) NOT NULL',
'active' : 'BOOLEAN NOT NULL CHECK (active IN (0, 1))',
'proto': 'VARCHAR(3) NOT NULL CHECK (proto IN ("tcp", "udp", "any"))',
'ip_src': 'VARCHAR(100) NOT NULL',
'port_src_from': 'INT CHECK(port_src_from > 0 and port_src_from < 65536)',
'port_src_to': 'INT CHECK(port_src_to > 0 and port_src_to < 65536 and port_src_from <= port_src_to)',
'ip_dst': 'VARCHAR(100) NOT NULL',
'port_dst_from': 'INT CHECK(port_dst_from > 0 and port_dst_from < 65536)',
'port_dst_to': 'INT CHECK(port_dst_to > 0 and port_dst_to < 65536 and port_dst_from <= port_dst_to)',
'action': 'VARCHAR(10) NOT NULL CHECK (action IN ("accept", "drop", "reject"))',
},
'QUERY':[
"CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, ip_src, ip_dst, port_src_from, port_src_to, port_dst_from, port_dst_to, action);"
]
})
async def reset(params: ResetRequest):
if not params.delete:
db.backup()
await firewall.close()
FiregexTables().reset()
if params.delete:
db.delete()
db.init()
else:
db.restore()
await firewall.init()
async def startup():
db.init()
await firewall.init()
async def shutdown():
db.backup()
await firewall.close()
db.disconnect()
db.restore()
async def apply_changes():
await firewall.reload()
await refresh_frontend()
return {'status': 'ok'}
firewall = FirewallManager(db)
@app.get('/stats', response_model=GeneralStatModel)
async def get_general_stats():
"""Get firegex general status about rules"""
return db.query("SELECT (SELECT COUNT(*) FROM rules) rules")[0]
@app.get('/rules', response_model=list[RuleModel])
async def get_rule_list():
"""Get the list of existent firegex rules"""
return db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;")
@app.get('/rule/{rule_id}/disable', response_model=StatusMessageModel)
async def service_disable(rule_id: str):
"""Request disabling a specific rule"""
if len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) == 0:
return {'status': 'Rule not found'}
db.query('UPDATE rules SET active = 0 WHERE rule_id = ?;', rule_id)
return await apply_changes()
@app.get('/rule/{rule_id}/enable', response_model=StatusMessageModel)
async def service_start(rule_id: str):
"""Request the enabling a specific rule"""
if len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) == 0:
return {'status': 'Rule not found'}
db.query('UPDATE rules SET active = 1 WHERE rule_id = ?;', rule_id)
return await apply_changes()
@app.post('/service/{rule_id}/rename', response_model=StatusMessageModel)
async def service_rename(rule_id: str, form: RenameForm):
"""Request to change the name of a specific service"""
if len(db.query('SELECT 1 FROM rules WHERE rule_id = ?;', rule_id)) == 0:
return {'status': 'Rule not found'}
form.name = refactor_name(form.name)
if not form.name: return {'status': 'The name cannot be empty!'}
try:
db.query('UPDATE rules SET name=? WHERE rule_id = ?;', form.name, rule_id)
except sqlite3.IntegrityError:
return {'status': 'This name is already used'}
await refresh_frontend()
return {'status': 'ok'}
def parse_and_check_rule(rule:RuleModel):
try:
rule.ip_src = ip_parse(rule.ip_src)
rule.ip_dst = ip_parse(rule.ip_dst)
except ValueError:
return {"status":"Invalid address"}
rule.port_dst_from, rule.port_dst_to = min(rule.port_dst_from, rule.port_dst_to), max(rule.port_dst_from, rule.port_dst_to)
rule.port_src_from, rule.port_src_to = min(rule.port_src_from, rule.port_src_to), max(rule.port_src_from, rule.port_src_to)
if ip_family(rule.ip_dst) != ip_family(rule.ip_src):
return {"status":"Destination and source addresses must be of the same family"}
if rule.proto not in ["tcp", "udp", "any"]:
return {"status":"Invalid protocol"}
if rule.action not in ["accept", "drop", "reject"]:
return {"status":"Invalid action"}
return rule
@app.post('/rules/set', response_model=RuleAddResponse)
async def add_new_service(form: list[RuleModel]):
"""Add a new service"""
form = [parse_and_check_rule(ele) for ele in form]
errors = [({"rule":i} | ele) for i, ele in enumerate(form) if isinstance(ele, dict)]
if len(errors) > 0:
return {'status': errors}
try:
db.queries(["DELETE FROM rules"]+
[("""
INSERT INTO rules (
rule_id, active, name,
proto,
ip_src, ip_dst,
port_src_from, port_dst_from,
port_src_to, port_dst_to,
action, mode
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?, ?)""",
rid, ele.active, ele.name,
ele.proto,
ele.ip_src, ele.ip_dst,
ele.port_src_from, ele.port_dst_from,
ele.port_src_to, ele.port_dst_to,
ele.action, ele.mode
) for rid, ele in enumerate(form)]
)
except sqlite3.IntegrityError:
return {'status': 'Error saving the rules: maybe there are duplicated rules'}
return await apply_changes()

View File

@@ -2,13 +2,12 @@ from base64 import b64decode
import re import re
import secrets import secrets
import sqlite3 import sqlite3
from typing import List, Union
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from modules.nfregex.nftables import FiregexTables from modules.nfregex.nftables import FiregexTables
from modules.nfregex.firewall import STATUS, FirewallManager from modules.nfregex.firewall import STATUS, FirewallManager
from utils.sqlite import SQLite from utils.sqlite import SQLite
from utils import ip_parse, refactor_name, refresh_frontend from utils import ip_parse, refactor_name, refresh_frontend, PortType
from utils.models import ResetRequest, StatusMessageModel from utils.models import ResetRequest, StatusMessageModel
class GeneralStatModel(BaseModel): class GeneralStatModel(BaseModel):
@@ -19,7 +18,7 @@ class GeneralStatModel(BaseModel):
class ServiceModel(BaseModel): class ServiceModel(BaseModel):
status: str status: str
service_id: str service_id: str
port: int port: PortType
name: str name: str
proto: str proto: str
ip_int: str ip_int: str
@@ -43,19 +42,19 @@ class RegexAddForm(BaseModel):
service_id: str service_id: str
regex: str regex: str
mode: str mode: str
active: Union[bool,None] active: bool|None
is_blacklist: bool is_blacklist: bool
is_case_sensitive: bool is_case_sensitive: bool
class ServiceAddForm(BaseModel): class ServiceAddForm(BaseModel):
name: str name: str
port: int port: PortType
proto: str proto: str
ip_int: str ip_int: str
class ServiceAddResponse(BaseModel): class ServiceAddResponse(BaseModel):
status:str status:str
service_id: Union[None,str] service_id: str|None
app = APIRouter() app = APIRouter()
@@ -70,7 +69,7 @@ db = SQLite('db/nft-regex.db', {
}, },
'regexes': { 'regexes': {
'regex': 'TEXT NOT NULL', 'regex': 'TEXT NOT NULL',
'mode': 'VARCHAR(1) NOT NULL', 'mode': 'VARCHAR(1) NOT NULL CHECK (mode IN ("C", "S", "B"))', # C = to the client, S = to the server, B = both
'service_id': 'VARCHAR(100) NOT NULL', 'service_id': 'VARCHAR(100) NOT NULL',
'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))', 'is_blacklist': 'BOOLEAN NOT NULL CHECK (is_blacklist IN (0, 1))',
'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0', 'blocked_packets': 'INTEGER UNSIGNED NOT NULL DEFAULT 0',
@@ -127,7 +126,7 @@ async def get_general_stats():
(SELECT COUNT(*) FROM services) services (SELECT COUNT(*) FROM services) services
""")[0] """)[0]
@app.get('/services', response_model=List[ServiceModel]) @app.get('/services', response_model=list[ServiceModel])
async def get_service_list(): async def get_service_list():
"""Get the list of existent firegex services""" """Get the list of existent firegex services"""
return db.query(""" return db.query("""
@@ -198,7 +197,7 @@ async def service_rename(service_id: str, form: RenameForm):
await refresh_frontend() await refresh_frontend()
return {'status': 'ok'} return {'status': 'ok'}
@app.get('/service/{service_id}/regexes', response_model=List[RegexModel]) @app.get('/service/{service_id}/regexes', response_model=list[RegexModel])
async def get_service_regexe_list(service_id: str): async def get_service_regexe_list(service_id: str):
"""Get the list of the regexes of a service""" """Get the list of the regexes of a service"""
return db.query(""" return db.query("""

View File

@@ -1,11 +1,10 @@
import secrets import secrets
import sqlite3 import sqlite3
from typing import List, Union
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from modules.porthijack.models import Service from modules.porthijack.models import Service
from utils.sqlite import SQLite from utils.sqlite import SQLite
from utils import addr_parse, ip_family, refactor_name, refresh_frontend from utils import addr_parse, ip_family, refactor_name, refresh_frontend, PortType
from utils.models import ResetRequest, StatusMessageModel from utils.models import ResetRequest, StatusMessageModel
from modules.porthijack.nftables import FiregexTables from modules.porthijack.nftables import FiregexTables
from modules.porthijack.firewall import FirewallManager from modules.porthijack.firewall import FirewallManager
@@ -13,8 +12,8 @@ from modules.porthijack.firewall import FirewallManager
class ServiceModel(BaseModel): class ServiceModel(BaseModel):
service_id: str service_id: str
active: bool active: bool
public_port: int public_port: PortType
proxy_port: int proxy_port: PortType
name: str name: str
proto: str proto: str
ip_src: str ip_src: str
@@ -25,15 +24,15 @@ class RenameForm(BaseModel):
class ServiceAddForm(BaseModel): class ServiceAddForm(BaseModel):
name: str name: str
public_port: int public_port: PortType
proxy_port: int proxy_port: PortType
proto: str proto: str
ip_src: str ip_src: str
ip_dst: str ip_dst: str
class ServiceAddResponse(BaseModel): class ServiceAddResponse(BaseModel):
status:str status:str
service_id: Union[None,str] service_id: str|None
class GeneralStatModel(BaseModel): class GeneralStatModel(BaseModel):
services: int services: int
@@ -96,7 +95,7 @@ async def get_general_stats():
(SELECT COUNT(*) FROM services) services (SELECT COUNT(*) FROM services) services
""")[0] """)[0]
@app.get('/services', response_model=List[ServiceModel]) @app.get('/services', response_model=list[ServiceModel])
async def get_service_list(): async def get_service_list():
"""Get the list of existent firegex services""" """Get the list of existent firegex services"""
return db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_src, ip_dst FROM services;") return db.query("SELECT service_id, active, public_port, proxy_port, name, proto, ip_src, ip_dst FROM services;")
@@ -144,7 +143,7 @@ async def service_rename(service_id: str, form: RenameForm):
class ChangeDestination(BaseModel): class ChangeDestination(BaseModel):
ip_dst: str ip_dst: str
proxy_port: int proxy_port: PortType
@app.post('/service/{service_id}/change-destination', response_model=StatusMessageModel) @app.post('/service/{service_id}/change-destination', response_model=StatusMessageModel)
async def service_change_destination(service_id: str, form: ChangeDestination): async def service_change_destination(service_id: str, form: ChangeDestination):

View File

@@ -1,12 +1,11 @@
from base64 import b64decode from base64 import b64decode
import sqlite3, re import sqlite3, re
from typing import List, Union
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from modules.regexproxy.utils import STATUS, ProxyManager, gen_internal_port, gen_service_id from modules.regexproxy.utils import STATUS, ProxyManager, gen_internal_port, gen_service_id
from utils.sqlite import SQLite from utils.sqlite import SQLite
from utils.models import ResetRequest, StatusMessageModel from utils.models import ResetRequest, StatusMessageModel
from utils import refactor_name, refresh_frontend from utils import refactor_name, refresh_frontend, PortType
app = APIRouter() app = APIRouter()
db = SQLite("db/regextcpproxy.db",{ db = SQLite("db/regextcpproxy.db",{
@@ -77,13 +76,13 @@ async def get_general_stats():
class ServiceModel(BaseModel): class ServiceModel(BaseModel):
id:str id:str
status: str status: str
public_port: int public_port: PortType
internal_port: int internal_port: PortType
name: str name: str
n_regex: int n_regex: int
n_packets: int n_packets: int
@app.get('/services', response_model=List[ServiceModel]) @app.get('/services', response_model=list[ServiceModel])
async def get_service_list(): async def get_service_list():
"""Get the list of existent firegex services""" """Get the list of existent firegex services"""
return db.query(""" return db.query("""
@@ -157,8 +156,8 @@ async def regen_service_port(service_id: str):
return {'status': 'ok'} return {'status': 'ok'}
class ChangePortForm(BaseModel): class ChangePortForm(BaseModel):
port: Union[int, None] port: int|None
internalPort: Union[int, None] internalPort: int|None
@app.post('/service/{service_id}/change-ports', response_model=StatusMessageModel) @app.post('/service/{service_id}/change-ports', response_model=StatusMessageModel)
async def change_service_ports(service_id: str, change_port:ChangePortForm): async def change_service_ports(service_id: str, change_port:ChangePortForm):
@@ -167,7 +166,7 @@ async def change_service_ports(service_id: str, change_port:ChangePortForm):
return {'status': 'Invalid Request!'} return {'status': 'Invalid Request!'}
try: try:
sql_inj = "" sql_inj = ""
query:List[Union[str,int]] = [] query:list[str|int] = []
if not change_port.port is None: if not change_port.port is None:
sql_inj+=" public_port = ? " sql_inj+=" public_port = ? "
query.append(change_port.port) query.append(change_port.port)
@@ -194,7 +193,7 @@ class RegexModel(BaseModel):
is_case_sensitive:bool is_case_sensitive:bool
active:bool active:bool
@app.get('/service/{service_id}/regexes', response_model=List[RegexModel]) @app.get('/service/{service_id}/regexes', response_model=list[RegexModel])
async def get_service_regexe_list(service_id: str): async def get_service_regexe_list(service_id: str):
"""Get the list of the regexes of a service""" """Get the list of the regexes of a service"""
return db.query(""" return db.query("""
@@ -250,7 +249,7 @@ class RegexAddForm(BaseModel):
service_id: str service_id: str
regex: str regex: str
mode: str mode: str
active: Union[bool,None] active: bool|None
is_blacklist: bool is_blacklist: bool
is_case_sensitive: bool is_case_sensitive: bool
@@ -272,12 +271,12 @@ async def add_new_regex(form: RegexAddForm):
class ServiceAddForm(BaseModel): class ServiceAddForm(BaseModel):
name: str name: str
port: int port: PortType
internalPort: Union[int, None] internalPort: int|None
class ServiceAddStatus(BaseModel): class ServiceAddStatus(BaseModel):
status:str status:str
id: Union[str,None] id: str|None
class RenameForm(BaseModel): class RenameForm(BaseModel):
name:str name:str

View File

@@ -2,6 +2,8 @@ import asyncio
from ipaddress import ip_address, ip_interface from ipaddress import ip_address, ip_interface
import os, socket, psutil, sys, nftables import os, socket, psutil, sys, nftables
from fastapi_socketio import SocketManager from fastapi_socketio import SocketManager
from fastapi import Path
from typing import Annotated
LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1")) LOCALHOST_IP = socket.gethostbyname(os.getenv("LOCALHOST_IP","127.0.0.1"))
@@ -15,6 +17,8 @@ FIREGEX_PORT = int(os.getenv("PORT","4444"))
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256"
API_VERSION = "2.0.0" API_VERSION = "2.0.0"
PortType = Annotated[int, Path(gt=0, lt=65536)]
async def run_func(func, *args, **kwargs): async def run_func(func, *args, **kwargs):
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -133,4 +137,4 @@ class NFTableManager(Singleton):
def raw_list(self): def raw_list(self):
return self.cmd({"list": {"ruleset": None}})["nftables"] return self.cmd({"list": {"ruleset": None}})["nftables"]

View File

@@ -1,7 +1,6 @@
import os, httpx, websockets import os, httpx
from sys import prefix from typing import Callable
from typing import Callable, List, Union
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import asyncio import asyncio
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
@@ -49,10 +48,10 @@ def list_routers():
return [ele[:-3] for ele in list_files(ROUTERS_DIR) if ele != "__init__.py" and " " not in ele and ele.endswith(".py")] return [ele[:-3] for ele in list_files(ROUTERS_DIR) if ele != "__init__.py" and " " not in ele and ele.endswith(".py")]
class RouterModule(): class RouterModule():
router: Union[None, APIRouter] router: None|APIRouter
reset: Union[None, Callable] reset: None|Callable
startup: Union[None, Callable] startup: None|Callable
shutdown: Union[None, Callable] shutdown: None|Callable
name: str name: str
def __init__(self, router: APIRouter, reset: Callable, startup: Callable, shutdown: Callable, name:str): def __init__(self, router: APIRouter, reset: Callable, startup: Callable, shutdown: Callable, name:str):
@@ -66,7 +65,7 @@ class RouterModule():
return f"RouterModule(router={self.router}, reset={self.reset}, startup={self.startup}, shutdown={self.shutdown})" return f"RouterModule(router={self.router}, reset={self.reset}, startup={self.startup}, shutdown={self.shutdown})"
def get_router_modules(): def get_router_modules():
res: List[RouterModule] = [] res: list[RouterModule] = []
for route in list_routers(): for route in list_routers():
module = getattr(__import__(f"routers.{route}"), route, None) module = getattr(__import__(f"routers.{route}"), route, None)
if module: if module:

View File

@@ -1,4 +1,3 @@
from typing import Union
from pydantic import BaseModel from pydantic import BaseModel
class StatusMessageModel(BaseModel): class StatusMessageModel(BaseModel):
@@ -18,7 +17,7 @@ class PasswordChangeForm(BaseModel):
class ChangePasswordModel(BaseModel): class ChangePasswordModel(BaseModel):
status: str status: str
access_token: Union[str,None] access_token: str|None
class IpInterface(BaseModel): class IpInterface(BaseModel):
addr: str addr: str

View File

@@ -1,11 +1,9 @@
from typing import Union
import json, sqlite3, os import json, sqlite3, os
from hashlib import md5 from hashlib import md5
import base64
class SQLite(): class SQLite():
def __init__(self, db_name: str, schema:dict = None) -> None: def __init__(self, db_name: str, schema:dict = None) -> None:
self.conn: Union[None, sqlite3.Connection] = None self.conn: sqlite3.Connection|None = None
self.cur = None self.cur = None
self.db_name = db_name self.db_name = db_name
self.__backup = None self.__backup = None
@@ -58,10 +56,25 @@ class SQLite():
cur.close() cur.close()
def query(self, query, *values): def query(self, query, *values):
return self.queries([(query, *values)])[0]
def queries(self, queries: list[tuple[str, ...]]):
return list(self.queries_iter(queries))
def queries_iter(self, queries: list[tuple[str, ...]]):
cur = self.conn.cursor() cur = self.conn.cursor()
try: try:
cur.execute(query, values) for query_data in queries:
return cur.fetchall() values = []
str_query = None
if isinstance(query_data, str):
str_query = query_data
elif (isinstance(query_data, tuple) or isinstance(query_data, list)) and len(query_data) > 0 and isinstance(query_data[0], str):
str_query = query_data[0]
values = query_data[1:]
if str_query:
cur.execute(str_query, values)
yield cur.fetchall()
finally: finally:
cur.close() cur.close()
try: self.conn.commit() try: self.conn.commit()