Update main.py
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Scoreboard Injector for ForcAD
|
Scoreboard Injector for ADPlatf/ForcAD
|
||||||
Monitors Socket.IO events for attacks and alerts on critical situations
|
Monitors scoreboard events for attacks and alerts on critical situations
|
||||||
|
Supports selecting scoreboard platform via configuration.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -11,6 +12,10 @@ import socketio
|
|||||||
from fastapi import FastAPI, HTTPException, Depends, Header
|
from fastapi import FastAPI, HTTPException, Depends, Header
|
||||||
import asyncpg
|
import asyncpg
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from dotenv import load_dotenv, find_dotenv
|
||||||
|
|
||||||
|
# Load environment variables from .env if present
|
||||||
|
load_dotenv(find_dotenv(), override=False)
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://adctrl:adctrl@postgres:5432/adctrl")
|
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://adctrl:adctrl@postgres:5432/adctrl")
|
||||||
@@ -20,6 +25,28 @@ OUR_TEAM_ID = int(os.getenv("OUR_TEAM_ID", "1"))
|
|||||||
ALERT_THRESHOLD_POINTS = float(os.getenv("ALERT_THRESHOLD_POINTS", "5"))
|
ALERT_THRESHOLD_POINTS = float(os.getenv("ALERT_THRESHOLD_POINTS", "5"))
|
||||||
TELEGRAM_API_URL = os.getenv("TELEGRAM_API_URL", "http://tg-bot:8003/send")
|
TELEGRAM_API_URL = os.getenv("TELEGRAM_API_URL", "http://tg-bot:8003/send")
|
||||||
|
|
||||||
|
# Platform selection: 'adplatf' or 'forcad'
|
||||||
|
SCOREBOARD_PLATFORM = os.getenv("SCOREBOARD_PLATFORM", "adplatf").lower()
|
||||||
|
|
||||||
|
# Platform-specific defaults (overridable via env)
|
||||||
|
# ForcAD
|
||||||
|
FORCAD_NAMESPACE = os.getenv("FORCAD_NAMESPACE", "/live_events")
|
||||||
|
FORCAD_TASKS_PATH = os.getenv("FORCAD_TASKS_PATH", "/api/client/tasks/")
|
||||||
|
|
||||||
|
# ADPlatf
|
||||||
|
ADPLATF_NAMESPACE = os.getenv("ADPLATF_NAMESPACE", "/events")
|
||||||
|
ADPLATF_TASKS_PATH = os.getenv("ADPLATF_TASKS_PATH", "/api/client/tasks/")
|
||||||
|
|
||||||
|
def _tasks_endpoint() -> str:
|
||||||
|
"""Return full tasks endpoint URL based on platform."""
|
||||||
|
if SCOREBOARD_PLATFORM == "forcad":
|
||||||
|
path = FORCAD_TASKS_PATH
|
||||||
|
else:
|
||||||
|
path = ADPLATF_TASKS_PATH
|
||||||
|
if path.startswith("http://") or path.startswith("https://"):
|
||||||
|
return path
|
||||||
|
return f"{SCOREBOARD_URL}{path}"
|
||||||
|
|
||||||
# Database pool
|
# Database pool
|
||||||
db_pool = None
|
db_pool = None
|
||||||
ws_task = None
|
ws_task = None
|
||||||
@@ -63,20 +90,28 @@ async def send_telegram_alert(message: str, service_id: int = None, service_name
|
|||||||
print(f"Error sending telegram alert: {e}")
|
print(f"Error sending telegram alert: {e}")
|
||||||
|
|
||||||
async def fetch_task_names():
|
async def fetch_task_names():
|
||||||
"""Fetch task names from scoreboard API"""
|
"""Fetch task names from scoreboard API (platform-aware)."""
|
||||||
|
url = _tasks_endpoint()
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(f"{SCOREBOARD_URL}/api/client/tasks/") as resp:
|
async with session.get(url) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
tasks = await resp.json()
|
tasks = await resp.json()
|
||||||
return {task['id']: task['name'] for task in tasks}
|
# Support both list and dict formats
|
||||||
|
if isinstance(tasks, list):
|
||||||
|
return {task.get('id'): task.get('name') for task in tasks}
|
||||||
|
if isinstance(tasks, dict):
|
||||||
|
# Some APIs might return {id: name}
|
||||||
|
return {int(k): v for k, v in tasks.items()}
|
||||||
|
else:
|
||||||
|
print(f"Task names fetch failed: {resp.status} at {url}")
|
||||||
return {}
|
return {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error fetching task names: {e}")
|
print(f"Error fetching task names from {url}: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def socketio_listener():
|
async def forcad_socketio_listener():
|
||||||
"""Listen to ForcAD scoreboard using Socket.IO"""
|
"""Listen to ForcAD scoreboard using Socket.IO (namespace /live_events)."""
|
||||||
sio = socketio.AsyncClient(logger=False, engineio_logger=False)
|
sio = socketio.AsyncClient(logger=False, engineio_logger=False)
|
||||||
|
|
||||||
# Cache for task and team names
|
# Cache for task and team names
|
||||||
@@ -86,7 +121,7 @@ async def socketio_listener():
|
|||||||
# Fetch task names on startup
|
# Fetch task names on startup
|
||||||
task_names.update(await fetch_task_names())
|
task_names.update(await fetch_task_names())
|
||||||
|
|
||||||
@sio.on('*', namespace='/live_events')
|
@sio.on('*', namespace=FORCAD_NAMESPACE)
|
||||||
async def catch_all(event, data):
|
async def catch_all(event, data):
|
||||||
"""Catch all events from live_events namespace"""
|
"""Catch all events from live_events namespace"""
|
||||||
if isinstance(data, list) and len(data) >= 2:
|
if isinstance(data, list) and len(data) >= 2:
|
||||||
@@ -179,7 +214,7 @@ async def socketio_listener():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing flag_stolen event: {e}")
|
print(f"Error processing flag_stolen event: {e}")
|
||||||
|
|
||||||
@sio.event(namespace='/live_events')
|
@sio.event(namespace=FORCAD_NAMESPACE)
|
||||||
async def update_scoreboard(data):
|
async def update_scoreboard(data):
|
||||||
"""Handle scoreboard update - compare with previous state to detect NEW attacks"""
|
"""Handle scoreboard update - compare with previous state to detect NEW attacks"""
|
||||||
try:
|
try:
|
||||||
@@ -319,18 +354,134 @@ async def socketio_listener():
|
|||||||
|
|
||||||
@sio.event
|
@sio.event
|
||||||
async def connect():
|
async def connect():
|
||||||
print(f"✅ Connected to scoreboard at {SCOREBOARD_URL}")
|
print(f"✅ Connected to ForcAD scoreboard at {SCOREBOARD_URL} (ns {FORCAD_NAMESPACE})")
|
||||||
|
|
||||||
@sio.event
|
@sio.event
|
||||||
async def disconnect():
|
async def disconnect():
|
||||||
print(f"❌ Disconnected from scoreboard")
|
print(f"❌ Disconnected from ForcAD scoreboard")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
print(f"Connecting to {SCOREBOARD_URL}...")
|
print(f"Connecting to ForcAD at {SCOREBOARD_URL}...")
|
||||||
await sio.connect(
|
await sio.connect(
|
||||||
SCOREBOARD_URL,
|
SCOREBOARD_URL,
|
||||||
namespaces=['/live_events'],
|
namespaces=[FORCAD_NAMESPACE],
|
||||||
|
transports=['websocket']
|
||||||
|
)
|
||||||
|
await sio.wait()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Connection error: {e}")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
async def adplatf_socketio_listener():
|
||||||
|
"""Listen to ADPlatf scoreboard using Socket.IO (namespace configurable)."""
|
||||||
|
sio = socketio.AsyncClient(logger=False, engineio_logger=False)
|
||||||
|
|
||||||
|
task_names = {}
|
||||||
|
team_names = {}
|
||||||
|
|
||||||
|
task_names.update(await fetch_task_names())
|
||||||
|
|
||||||
|
@sio.on('*', namespace=ADPLATF_NAMESPACE)
|
||||||
|
async def catch_all(event, data):
|
||||||
|
"""Catch all events from ADPlatf namespace and normalize."""
|
||||||
|
# Normalize common payload shapes
|
||||||
|
payload = None
|
||||||
|
if isinstance(data, list) and len(data) >= 2:
|
||||||
|
event_type = data[0]
|
||||||
|
payload = data[1].get('data', {}) if isinstance(data[1], dict) else {}
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
payload = data.get('data', data)
|
||||||
|
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try multiple key patterns to detect flag events
|
||||||
|
keys = payload.keys()
|
||||||
|
attacker_id = payload.get('attacker_id') or payload.get('attacker') or payload.get('team_attacker_id')
|
||||||
|
victim_id = payload.get('victim_id') or payload.get('victim') or payload.get('team_victim_id')
|
||||||
|
task_id = payload.get('task_id') or payload.get('service_id')
|
||||||
|
attacker_delta = payload.get('attacker_delta') or payload.get('points') or payload.get('fp_delta') or 0
|
||||||
|
|
||||||
|
if attacker_id is not None and victim_id is not None:
|
||||||
|
await process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names)
|
||||||
|
|
||||||
|
async def process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names_local):
|
||||||
|
try:
|
||||||
|
service_name = task_names_local.get(task_id, f"task_{task_id}")
|
||||||
|
timestamp = datetime.utcnow()
|
||||||
|
is_our_attack = attacker_id == OUR_TEAM_ID
|
||||||
|
is_attack_to_us = victim_id == OUR_TEAM_ID
|
||||||
|
|
||||||
|
print(f"[ADPlatf] Flag event: attacker={attacker_id}, victim={victim_id}, service={service_name}, points={float(attacker_delta):.2f}")
|
||||||
|
if is_our_attack or is_attack_to_us:
|
||||||
|
conn = await db_pool.acquire()
|
||||||
|
try:
|
||||||
|
attack_id = f"flag_{attacker_id}_{victim_id}_{task_id}_{int(timestamp.timestamp())}"
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO attacks (attack_id, attacker_team_id, victim_team_id, service_name, timestamp, points, is_our_attack, is_attack_to_us)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
ON CONFLICT (attack_id) DO NOTHING
|
||||||
|
""",
|
||||||
|
attack_id, attacker_id, victim_id, service_name, timestamp, float(attacker_delta), is_our_attack, is_attack_to_us,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_attack_to_us and float(attacker_delta) >= ALERT_THRESHOLD_POINTS:
|
||||||
|
alert_message = f"🚨 ATTACK DETECTED!\nTeam {attacker_id} stole flag from {service_name}\nPoints lost: {float(attacker_delta):.2f} FP"
|
||||||
|
# Lookup optional service_id
|
||||||
|
service_id = None
|
||||||
|
try:
|
||||||
|
service_row = await conn.fetchrow(
|
||||||
|
"SELECT id FROM services WHERE name = $1 LIMIT 1",
|
||||||
|
service_name,
|
||||||
|
)
|
||||||
|
if not service_row:
|
||||||
|
service_row = await conn.fetchrow(
|
||||||
|
"SELECT id FROM services WHERE alias = $1 LIMIT 1",
|
||||||
|
service_name,
|
||||||
|
)
|
||||||
|
if service_row:
|
||||||
|
service_id = service_row['id']
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error looking up service_id: {e}")
|
||||||
|
|
||||||
|
alert_id = await conn.fetchval(
|
||||||
|
"""
|
||||||
|
INSERT INTO attack_alerts (attack_id, alert_type, severity, message)
|
||||||
|
VALUES (
|
||||||
|
(SELECT id FROM attacks WHERE attack_id = $1),
|
||||||
|
'flag_stolen',
|
||||||
|
'high',
|
||||||
|
$2
|
||||||
|
)
|
||||||
|
RETURNING id
|
||||||
|
""",
|
||||||
|
attack_id,
|
||||||
|
alert_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
await send_telegram_alert(alert_message, service_id=service_id, service_name=service_name)
|
||||||
|
await conn.execute("UPDATE attack_alerts SET notified = true WHERE id = $1", alert_id)
|
||||||
|
finally:
|
||||||
|
await db_pool.release(conn)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing ADPlatf flag event: {e}")
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def connect():
|
||||||
|
print(f"✅ Connected to ADPlatf scoreboard at {SCOREBOARD_URL} (ns {ADPLATF_NAMESPACE})")
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def disconnect():
|
||||||
|
print(f"❌ Disconnected from ADPlatf scoreboard")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print(f"Connecting to ADPlatf at {SCOREBOARD_URL}...")
|
||||||
|
await sio.connect(
|
||||||
|
SCOREBOARD_URL,
|
||||||
|
namespaces=[ADPLATF_NAMESPACE],
|
||||||
transports=['websocket']
|
transports=['websocket']
|
||||||
)
|
)
|
||||||
await sio.wait()
|
await sio.wait()
|
||||||
@@ -343,7 +494,11 @@ async def socketio_listener():
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global db_pool, ws_task
|
global db_pool, ws_task
|
||||||
db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||||
ws_task = asyncio.create_task(socketio_listener())
|
# Start platform-specific listener
|
||||||
|
if SCOREBOARD_PLATFORM == "forcad":
|
||||||
|
ws_task = asyncio.create_task(forcad_socketio_listener())
|
||||||
|
else:
|
||||||
|
ws_task = asyncio.create_task(adplatf_socketio_listener())
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -366,6 +521,7 @@ async def health_check():
|
|||||||
"timestamp": datetime.utcnow().isoformat(),
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
"team_id": OUR_TEAM_ID,
|
"team_id": OUR_TEAM_ID,
|
||||||
"mode": "socketio",
|
"mode": "socketio",
|
||||||
|
"platform": SCOREBOARD_PLATFORM,
|
||||||
"scoreboard_url": SCOREBOARD_URL
|
"scoreboard_url": SCOREBOARD_URL
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -478,7 +634,8 @@ async def get_attacks_by_service():
|
|||||||
COUNT(*) FILTER (WHERE is_our_attack = true) as our_attacks,
|
COUNT(*) FILTER (WHERE is_our_attack = true) as our_attacks,
|
||||||
COUNT(*) FILTER (WHERE is_attack_to_us = true) as attacks_to_us,
|
COUNT(*) FILTER (WHERE is_attack_to_us = true) as attacks_to_us,
|
||||||
COALESCE(SUM(points) FILTER (WHERE is_our_attack = true), 0) as points_gained,
|
COALESCE(SUM(points) FILTER (WHERE is_our_attack = true), 0) as points_gained,
|
||||||
COALESCE(SUM(points) FILTER (WHERE is_attack_to_us = true), 0) as points_lost
|
"mode": "socketio",
|
||||||
|
"platform": SCOREBOARD_PLATFORM,
|
||||||
FROM attacks
|
FROM attacks
|
||||||
GROUP BY service_name
|
GROUP BY service_name
|
||||||
ORDER BY total_attacks DESC
|
ORDER BY total_attacks DESC
|
||||||
@@ -489,7 +646,7 @@ async def get_attacks_by_service():
|
|||||||
|
|
||||||
@app.post("/settings/team-id", dependencies=[Depends(verify_token)])
|
@app.post("/settings/team-id", dependencies=[Depends(verify_token)])
|
||||||
async def set_team_id(team_id: int):
|
async def set_team_id(team_id: int):
|
||||||
"""Update our team ID"""
|
socketio_url = f"{SCOREBOARD_URL}/socket.io/?EIO=4&transport=polling"
|
||||||
global OUR_TEAM_ID
|
global OUR_TEAM_ID
|
||||||
OUR_TEAM_ID = team_id
|
OUR_TEAM_ID = team_id
|
||||||
|
|
||||||
@@ -521,6 +678,29 @@ async def inject_test_attack(attacker_id: int, victim_id: int, service: str = "t
|
|||||||
"time": datetime.utcnow().isoformat(),
|
"time": datetime.utcnow().isoformat(),
|
||||||
"round": 1
|
"round": 1
|
||||||
}
|
}
|
||||||
|
tasks_url = _tasks_endpoint()
|
||||||
|
try:
|
||||||
|
async with session.get(tasks_url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||||
|
result = {
|
||||||
|
"url": tasks_url,
|
||||||
|
"status": resp.status,
|
||||||
|
"reachable": resp.status == 200,
|
||||||
|
"content_type": resp.headers.get('Content-Type', ''),
|
||||||
|
}
|
||||||
|
if resp.status == 200 and 'application/json' in resp.headers.get('Content-Type', ''):
|
||||||
|
data = await resp.json()
|
||||||
|
if isinstance(data, list):
|
||||||
|
result["count"] = len(data)
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
result["count"] = len(list(data.keys()))
|
||||||
|
results["endpoints_tested"].append(result)
|
||||||
|
except Exception as e:
|
||||||
|
results["endpoints_tested"].append({
|
||||||
|
"url": tasks_url,
|
||||||
|
"reachable": False,
|
||||||
|
"error": str(e)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
await process_attack_event(test_event)
|
await process_attack_event(test_event)
|
||||||
return {"status": "injected", "event": test_event}
|
return {"status": "injected", "event": test_event}
|
||||||
|
|||||||
Reference in New Issue
Block a user