Update main.py

This commit is contained in:
Ilya Starchak
2025-12-12 18:48:26 +03:00
parent 412fd99f05
commit c827c7d35c

View File

@@ -1,6 +1,7 @@
""" """
Scoreboard Injector for ForcAD Scoreboard Injector for ADPlatf/ForcAD
Monitors Socket.IO events for attacks and alerts on critical situations Monitors scoreboard events for attacks and alerts on critical situations
Supports selecting scoreboard platform via configuration.
""" """
import os import os
import asyncio import asyncio
@@ -11,6 +12,10 @@ import socketio
from fastapi import FastAPI, HTTPException, Depends, Header from fastapi import FastAPI, HTTPException, Depends, Header
import asyncpg import asyncpg
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dotenv import load_dotenv, find_dotenv
# Load environment variables from .env if present
load_dotenv(find_dotenv(), override=False)
# Configuration # Configuration
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://adctrl:adctrl@postgres:5432/adctrl") DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://adctrl:adctrl@postgres:5432/adctrl")
@@ -20,6 +25,28 @@ OUR_TEAM_ID = int(os.getenv("OUR_TEAM_ID", "1"))
ALERT_THRESHOLD_POINTS = float(os.getenv("ALERT_THRESHOLD_POINTS", "5")) ALERT_THRESHOLD_POINTS = float(os.getenv("ALERT_THRESHOLD_POINTS", "5"))
TELEGRAM_API_URL = os.getenv("TELEGRAM_API_URL", "http://tg-bot:8003/send") TELEGRAM_API_URL = os.getenv("TELEGRAM_API_URL", "http://tg-bot:8003/send")
# Platform selection: 'adplatf' or 'forcad'
SCOREBOARD_PLATFORM = os.getenv("SCOREBOARD_PLATFORM", "adplatf").lower()
# Platform-specific defaults (overridable via env)
# ForcAD
FORCAD_NAMESPACE = os.getenv("FORCAD_NAMESPACE", "/live_events")
FORCAD_TASKS_PATH = os.getenv("FORCAD_TASKS_PATH", "/api/client/tasks/")
# ADPlatf
ADPLATF_NAMESPACE = os.getenv("ADPLATF_NAMESPACE", "/events")
ADPLATF_TASKS_PATH = os.getenv("ADPLATF_TASKS_PATH", "/api/client/tasks/")
def _tasks_endpoint() -> str:
"""Return full tasks endpoint URL based on platform."""
if SCOREBOARD_PLATFORM == "forcad":
path = FORCAD_TASKS_PATH
else:
path = ADPLATF_TASKS_PATH
if path.startswith("http://") or path.startswith("https://"):
return path
return f"{SCOREBOARD_URL}{path}"
# Database pool # Database pool
db_pool = None db_pool = None
ws_task = None ws_task = None
@@ -63,20 +90,28 @@ async def send_telegram_alert(message: str, service_id: int = None, service_name
print(f"Error sending telegram alert: {e}") print(f"Error sending telegram alert: {e}")
async def fetch_task_names(): async def fetch_task_names():
"""Fetch task names from scoreboard API""" """Fetch task names from scoreboard API (platform-aware)."""
url = _tasks_endpoint()
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(f"{SCOREBOARD_URL}/api/client/tasks/") as resp: async with session.get(url) as resp:
if resp.status == 200: if resp.status == 200:
tasks = await resp.json() tasks = await resp.json()
return {task['id']: task['name'] for task in tasks} # Support both list and dict formats
if isinstance(tasks, list):
return {task.get('id'): task.get('name') for task in tasks}
if isinstance(tasks, dict):
# Some APIs might return {id: name}
return {int(k): v for k, v in tasks.items()}
else:
print(f"Task names fetch failed: {resp.status} at {url}")
return {} return {}
except Exception as e: except Exception as e:
print(f"Error fetching task names: {e}") print(f"Error fetching task names from {url}: {e}")
return {} return {}
async def socketio_listener(): async def forcad_socketio_listener():
"""Listen to ForcAD scoreboard using Socket.IO""" """Listen to ForcAD scoreboard using Socket.IO (namespace /live_events)."""
sio = socketio.AsyncClient(logger=False, engineio_logger=False) sio = socketio.AsyncClient(logger=False, engineio_logger=False)
# Cache for task and team names # Cache for task and team names
@@ -86,7 +121,7 @@ async def socketio_listener():
# Fetch task names on startup # Fetch task names on startup
task_names.update(await fetch_task_names()) task_names.update(await fetch_task_names())
@sio.on('*', namespace='/live_events') @sio.on('*', namespace=FORCAD_NAMESPACE)
async def catch_all(event, data): async def catch_all(event, data):
"""Catch all events from live_events namespace""" """Catch all events from live_events namespace"""
if isinstance(data, list) and len(data) >= 2: if isinstance(data, list) and len(data) >= 2:
@@ -179,7 +214,7 @@ async def socketio_listener():
except Exception as e: except Exception as e:
print(f"Error processing flag_stolen event: {e}") print(f"Error processing flag_stolen event: {e}")
@sio.event(namespace='/live_events') @sio.event(namespace=FORCAD_NAMESPACE)
async def update_scoreboard(data): async def update_scoreboard(data):
"""Handle scoreboard update - compare with previous state to detect NEW attacks""" """Handle scoreboard update - compare with previous state to detect NEW attacks"""
try: try:
@@ -319,18 +354,134 @@ async def socketio_listener():
@sio.event @sio.event
async def connect(): async def connect():
print(f"✅ Connected to scoreboard at {SCOREBOARD_URL}") print(f"✅ Connected to ForcAD scoreboard at {SCOREBOARD_URL} (ns {FORCAD_NAMESPACE})")
@sio.event @sio.event
async def disconnect(): async def disconnect():
print(f"❌ Disconnected from scoreboard") print(f"❌ Disconnected from ForcAD scoreboard")
while True: while True:
try: try:
print(f"Connecting to {SCOREBOARD_URL}...") print(f"Connecting to ForcAD at {SCOREBOARD_URL}...")
await sio.connect( await sio.connect(
SCOREBOARD_URL, SCOREBOARD_URL,
namespaces=['/live_events'], namespaces=[FORCAD_NAMESPACE],
transports=['websocket']
)
await sio.wait()
except Exception as e:
print(f"Connection error: {e}")
await asyncio.sleep(5)
async def adplatf_socketio_listener():
"""Listen to ADPlatf scoreboard using Socket.IO (namespace configurable)."""
sio = socketio.AsyncClient(logger=False, engineio_logger=False)
task_names = {}
team_names = {}
task_names.update(await fetch_task_names())
@sio.on('*', namespace=ADPLATF_NAMESPACE)
async def catch_all(event, data):
"""Catch all events from ADPlatf namespace and normalize."""
# Normalize common payload shapes
payload = None
if isinstance(data, list) and len(data) >= 2:
event_type = data[0]
payload = data[1].get('data', {}) if isinstance(data[1], dict) else {}
elif isinstance(data, dict):
payload = data.get('data', data)
if not isinstance(payload, dict):
return
# Try multiple key patterns to detect flag events
keys = payload.keys()
attacker_id = payload.get('attacker_id') or payload.get('attacker') or payload.get('team_attacker_id')
victim_id = payload.get('victim_id') or payload.get('victim') or payload.get('team_victim_id')
task_id = payload.get('task_id') or payload.get('service_id')
attacker_delta = payload.get('attacker_delta') or payload.get('points') or payload.get('fp_delta') or 0
if attacker_id is not None and victim_id is not None:
await process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names)
async def process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names_local):
try:
service_name = task_names_local.get(task_id, f"task_{task_id}")
timestamp = datetime.utcnow()
is_our_attack = attacker_id == OUR_TEAM_ID
is_attack_to_us = victim_id == OUR_TEAM_ID
print(f"[ADPlatf] Flag event: attacker={attacker_id}, victim={victim_id}, service={service_name}, points={float(attacker_delta):.2f}")
if is_our_attack or is_attack_to_us:
conn = await db_pool.acquire()
try:
attack_id = f"flag_{attacker_id}_{victim_id}_{task_id}_{int(timestamp.timestamp())}"
await conn.execute(
"""
INSERT INTO attacks (attack_id, attacker_team_id, victim_team_id, service_name, timestamp, points, is_our_attack, is_attack_to_us)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (attack_id) DO NOTHING
""",
attack_id, attacker_id, victim_id, service_name, timestamp, float(attacker_delta), is_our_attack, is_attack_to_us,
)
if is_attack_to_us and float(attacker_delta) >= ALERT_THRESHOLD_POINTS:
alert_message = f"🚨 ATTACK DETECTED!\nTeam {attacker_id} stole flag from {service_name}\nPoints lost: {float(attacker_delta):.2f} FP"
# Lookup optional service_id
service_id = None
try:
service_row = await conn.fetchrow(
"SELECT id FROM services WHERE name = $1 LIMIT 1",
service_name,
)
if not service_row:
service_row = await conn.fetchrow(
"SELECT id FROM services WHERE alias = $1 LIMIT 1",
service_name,
)
if service_row:
service_id = service_row['id']
except Exception as e:
print(f" Error looking up service_id: {e}")
alert_id = await conn.fetchval(
"""
INSERT INTO attack_alerts (attack_id, alert_type, severity, message)
VALUES (
(SELECT id FROM attacks WHERE attack_id = $1),
'flag_stolen',
'high',
$2
)
RETURNING id
""",
attack_id,
alert_message,
)
await send_telegram_alert(alert_message, service_id=service_id, service_name=service_name)
await conn.execute("UPDATE attack_alerts SET notified = true WHERE id = $1", alert_id)
finally:
await db_pool.release(conn)
except Exception as e:
print(f"Error processing ADPlatf flag event: {e}")
@sio.event
async def connect():
print(f"✅ Connected to ADPlatf scoreboard at {SCOREBOARD_URL} (ns {ADPLATF_NAMESPACE})")
@sio.event
async def disconnect():
print(f"❌ Disconnected from ADPlatf scoreboard")
while True:
try:
print(f"Connecting to ADPlatf at {SCOREBOARD_URL}...")
await sio.connect(
SCOREBOARD_URL,
namespaces=[ADPLATF_NAMESPACE],
transports=['websocket'] transports=['websocket']
) )
await sio.wait() await sio.wait()
@@ -343,7 +494,11 @@ async def socketio_listener():
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global db_pool, ws_task global db_pool, ws_task
db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
ws_task = asyncio.create_task(socketio_listener()) # Start platform-specific listener
if SCOREBOARD_PLATFORM == "forcad":
ws_task = asyncio.create_task(forcad_socketio_listener())
else:
ws_task = asyncio.create_task(adplatf_socketio_listener())
yield yield
@@ -366,6 +521,7 @@ async def health_check():
"timestamp": datetime.utcnow().isoformat(), "timestamp": datetime.utcnow().isoformat(),
"team_id": OUR_TEAM_ID, "team_id": OUR_TEAM_ID,
"mode": "socketio", "mode": "socketio",
"platform": SCOREBOARD_PLATFORM,
"scoreboard_url": SCOREBOARD_URL "scoreboard_url": SCOREBOARD_URL
} }
@@ -478,7 +634,8 @@ async def get_attacks_by_service():
COUNT(*) FILTER (WHERE is_our_attack = true) as our_attacks, COUNT(*) FILTER (WHERE is_our_attack = true) as our_attacks,
COUNT(*) FILTER (WHERE is_attack_to_us = true) as attacks_to_us, COUNT(*) FILTER (WHERE is_attack_to_us = true) as attacks_to_us,
COALESCE(SUM(points) FILTER (WHERE is_our_attack = true), 0) as points_gained, COALESCE(SUM(points) FILTER (WHERE is_our_attack = true), 0) as points_gained,
COALESCE(SUM(points) FILTER (WHERE is_attack_to_us = true), 0) as points_lost "mode": "socketio",
"platform": SCOREBOARD_PLATFORM,
FROM attacks FROM attacks
GROUP BY service_name GROUP BY service_name
ORDER BY total_attacks DESC ORDER BY total_attacks DESC
@@ -489,7 +646,7 @@ async def get_attacks_by_service():
@app.post("/settings/team-id", dependencies=[Depends(verify_token)]) @app.post("/settings/team-id", dependencies=[Depends(verify_token)])
async def set_team_id(team_id: int): async def set_team_id(team_id: int):
"""Update our team ID""" socketio_url = f"{SCOREBOARD_URL}/socket.io/?EIO=4&transport=polling"
global OUR_TEAM_ID global OUR_TEAM_ID
OUR_TEAM_ID = team_id OUR_TEAM_ID = team_id
@@ -521,6 +678,29 @@ async def inject_test_attack(attacker_id: int, victim_id: int, service: str = "t
"time": datetime.utcnow().isoformat(), "time": datetime.utcnow().isoformat(),
"round": 1 "round": 1
} }
tasks_url = _tasks_endpoint()
try:
async with session.get(tasks_url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
result = {
"url": tasks_url,
"status": resp.status,
"reachable": resp.status == 200,
"content_type": resp.headers.get('Content-Type', ''),
}
if resp.status == 200 and 'application/json' in resp.headers.get('Content-Type', ''):
data = await resp.json()
if isinstance(data, list):
result["count"] = len(data)
elif isinstance(data, dict):
result["count"] = len(list(data.keys()))
results["endpoints_tested"].append(result)
except Exception as e:
results["endpoints_tested"].append({
"url": tasks_url,
"reachable": False,
"error": str(e)
})
await process_attack_event(test_event) await process_attack_event(test_event)
return {"status": "injected", "event": test_event} return {"status": "injected", "event": test_event}