drop stream on udp (due to missing method to keep stream) + ack on reload config
This commit is contained in:
@@ -33,9 +33,10 @@ void config_updater (){
|
|||||||
try{
|
try{
|
||||||
regex_config.reset(new RegexRules(raw_rules, regex_config->stream_mode()));
|
regex_config.reset(new RegexRules(raw_rules, regex_config->stream_mode()));
|
||||||
cerr << "[info] [updater] Config update done to ver "<< regex_config->ver() << endl;
|
cerr << "[info] [updater] Config update done to ver "<< regex_config->ver() << endl;
|
||||||
}catch(...){
|
cout << "ACK OK" << endl;
|
||||||
|
}catch(const std::exception& e){
|
||||||
cerr << "[error] [updater] Failed to build new configuration!" << endl;
|
cerr << "[error] [updater] Failed to build new configuration!" << endl;
|
||||||
// TODO send a row on stdout for this error
|
cout << "ACK FAIL " << e.what() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +111,9 @@ bool filter_callback(packet_info& info){
|
|||||||
cerr << "[error] [filter_callback] Error opening the stream matcher (hs)" << endl;
|
cerr << "[error] [filter_callback] Error opening the stream matcher (hs)" << endl;
|
||||||
throw invalid_argument("Cannot open stream match on hyperscan");
|
throw invalid_argument("Cannot open stream match on hyperscan");
|
||||||
}
|
}
|
||||||
|
if (info.is_tcp){
|
||||||
match_map->insert_or_assign(info.sid, stream_match);
|
match_map->insert_or_assign(info.sid, stream_match);
|
||||||
|
}
|
||||||
}else{
|
}else{
|
||||||
stream_match = stream_search->second;
|
stream_match = stream_search->second;
|
||||||
}
|
}
|
||||||
@@ -130,6 +133,13 @@ bool filter_callback(packet_info& info){
|
|||||||
0, scratch_space, match_func, &match_res
|
0, scratch_space, match_func, &match_res
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
|
!info.is_tcp && conf->stream_mode() &&
|
||||||
|
hs_close_stream(stream_match, scratch_space, nullptr, nullptr) != HS_SUCCESS
|
||||||
|
){
|
||||||
|
cerr << "[error] [filter_callback] Error closing the stream matcher (hs)" << endl;
|
||||||
|
throw invalid_argument("Cannot close stream match on hyperscan");
|
||||||
|
}
|
||||||
if (err != HS_SUCCESS && err != HS_SCAN_TERMINATED) {
|
if (err != HS_SUCCESS && err != HS_SCAN_TERMINATED) {
|
||||||
cerr << "[error] [filter_callback] Error while matching the stream (hs)" << endl;
|
cerr << "[error] [filter_callback] Error while matching the stream (hs)" << endl;
|
||||||
throw invalid_argument("Error while matching the stream with hyperscan");
|
throw invalid_argument("Error while matching the stream with hyperscan");
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from utils import DEBUG
|
from utils import DEBUG
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
nft = FiregexTables()
|
nft = FiregexTables()
|
||||||
|
|
||||||
@@ -64,6 +65,10 @@ class FiregexInterceptor:
|
|||||||
self.update_config_lock:asyncio.Lock
|
self.update_config_lock:asyncio.Lock
|
||||||
self.process:asyncio.subprocess.Process
|
self.process:asyncio.subprocess.Process
|
||||||
self.update_task: asyncio.Task
|
self.update_task: asyncio.Task
|
||||||
|
self.ack_arrived = False
|
||||||
|
self.ack_status = None
|
||||||
|
self.ack_fail_what = ""
|
||||||
|
self.ack_lock = asyncio.Lock()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def start(cls, srv: Service):
|
async def start(cls, srv: Service):
|
||||||
@@ -74,6 +79,8 @@ class FiregexInterceptor:
|
|||||||
queue_range = await self._start_binary()
|
queue_range = await self._start_binary()
|
||||||
self.update_task = asyncio.create_task(self.update_blocked())
|
self.update_task = asyncio.create_task(self.update_blocked())
|
||||||
nft.add(self.srv, queue_range)
|
nft.add(self.srv, queue_range)
|
||||||
|
if not self.ack_lock.locked():
|
||||||
|
await self.ack_lock.acquire()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def _start_binary(self):
|
async def _start_binary(self):
|
||||||
@@ -109,6 +116,12 @@ class FiregexInterceptor:
|
|||||||
if regex_id in self.filter_map:
|
if regex_id in self.filter_map:
|
||||||
self.filter_map[regex_id].blocked+=1
|
self.filter_map[regex_id].blocked+=1
|
||||||
await self.filter_map[regex_id].update()
|
await self.filter_map[regex_id].update()
|
||||||
|
if line.startswith("ACK "):
|
||||||
|
self.ack_arrived = True
|
||||||
|
self.ack_status = line.split()[1].upper() == "OK"
|
||||||
|
if not self.ack_status:
|
||||||
|
self.ack_fail_what = " ".join(line.split()[2:])
|
||||||
|
self.ack_lock.release()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except asyncio.IncompleteReadError:
|
except asyncio.IncompleteReadError:
|
||||||
@@ -125,6 +138,14 @@ class FiregexInterceptor:
|
|||||||
async with self.update_config_lock:
|
async with self.update_config_lock:
|
||||||
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
|
self.process.stdin.write((" ".join(filters_codes)+"\n").encode())
|
||||||
await self.process.stdin.drain()
|
await self.process.stdin.drain()
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(3):
|
||||||
|
await self.ack_lock.acquire()
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
if not self.ack_arrived or not self.ack_status:
|
||||||
|
raise HTTPException(status_code=500, detail=f"NFQ error: {self.ack_fail_what}")
|
||||||
|
|
||||||
|
|
||||||
async def reload(self, filters:list[RegexFilter]):
|
async def reload(self, filters:list[RegexFilter]):
|
||||||
async with self.filter_map_lock:
|
async with self.filter_map_lock:
|
||||||
|
|||||||
@@ -89,12 +89,18 @@ async def reset(params: ResetRequest):
|
|||||||
db.init()
|
db.init()
|
||||||
else:
|
else:
|
||||||
db.restore()
|
db.restore()
|
||||||
|
try:
|
||||||
await firewall.init()
|
await firewall.init()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
async def startup():
|
async def startup():
|
||||||
db.init()
|
db.init()
|
||||||
|
try:
|
||||||
await firewall.init()
|
await firewall.init()
|
||||||
|
except Exception as e:
|
||||||
|
print("WARNING cannot start firewall:", e)
|
||||||
|
|
||||||
async def shutdown():
|
async def shutdown():
|
||||||
db.backup()
|
db.backup()
|
||||||
|
|||||||
Reference in New Issue
Block a user