diff --git a/scoreboard_injector/main.py b/scoreboard_injector/main.py index 53df4f0..b61f4b5 100644 --- a/scoreboard_injector/main.py +++ b/scoreboard_injector/main.py @@ -1,6 +1,7 @@ """ -Scoreboard Injector for ForcAD -Monitors Socket.IO events for attacks and alerts on critical situations +Scoreboard Injector for ADPlatf/ForcAD +Monitors scoreboard events for attacks and alerts on critical situations +Supports selecting scoreboard platform via configuration. """ import os import asyncio @@ -11,6 +12,10 @@ import socketio from fastapi import FastAPI, HTTPException, Depends, Header import asyncpg from contextlib import asynccontextmanager +from dotenv import load_dotenv, find_dotenv + +# Load environment variables from .env if present +load_dotenv(find_dotenv(), override=False) # Configuration DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://adctrl:adctrl@postgres:5432/adctrl") @@ -20,6 +25,28 @@ OUR_TEAM_ID = int(os.getenv("OUR_TEAM_ID", "1")) ALERT_THRESHOLD_POINTS = float(os.getenv("ALERT_THRESHOLD_POINTS", "5")) TELEGRAM_API_URL = os.getenv("TELEGRAM_API_URL", "http://tg-bot:8003/send") +# Platform selection: 'adplatf' or 'forcad' +SCOREBOARD_PLATFORM = os.getenv("SCOREBOARD_PLATFORM", "adplatf").lower() + +# Platform-specific defaults (overridable via env) +# ForcAD +FORCAD_NAMESPACE = os.getenv("FORCAD_NAMESPACE", "/live_events") +FORCAD_TASKS_PATH = os.getenv("FORCAD_TASKS_PATH", "/api/client/tasks/") + +# ADPlatf +ADPLATF_NAMESPACE = os.getenv("ADPLATF_NAMESPACE", "/events") +ADPLATF_TASKS_PATH = os.getenv("ADPLATF_TASKS_PATH", "/api/client/tasks/") + +def _tasks_endpoint() -> str: + """Return full tasks endpoint URL based on platform.""" + if SCOREBOARD_PLATFORM == "forcad": + path = FORCAD_TASKS_PATH + else: + path = ADPLATF_TASKS_PATH + if path.startswith("http://") or path.startswith("https://"): + return path + return f"{SCOREBOARD_URL}{path}" + # Database pool db_pool = None ws_task = None @@ -63,20 +90,28 @@ async def send_telegram_alert(message: str, service_id: int = None, service_name print(f"Error sending telegram alert: {e}") async def fetch_task_names(): - """Fetch task names from scoreboard API""" + """Fetch task names from scoreboard API (platform-aware).""" + url = _tasks_endpoint() try: async with aiohttp.ClientSession() as session: - async with session.get(f"{SCOREBOARD_URL}/api/client/tasks/") as resp: + async with session.get(url) as resp: if resp.status == 200: tasks = await resp.json() - return {task['id']: task['name'] for task in tasks} + # Support both list and dict formats + if isinstance(tasks, list): + return {task.get('id'): task.get('name') for task in tasks} + if isinstance(tasks, dict): + # Some APIs might return {id: name} + return {int(k): v for k, v in tasks.items()} + else: + print(f"Task names fetch failed: {resp.status} at {url}") return {} except Exception as e: - print(f"Error fetching task names: {e}") + print(f"Error fetching task names from {url}: {e}") return {} -async def socketio_listener(): - """Listen to ForcAD scoreboard using Socket.IO""" +async def forcad_socketio_listener(): + """Listen to ForcAD scoreboard using Socket.IO (namespace /live_events).""" sio = socketio.AsyncClient(logger=False, engineio_logger=False) # Cache for task and team names @@ -86,7 +121,7 @@ async def socketio_listener(): # Fetch task names on startup task_names.update(await fetch_task_names()) - @sio.on('*', namespace='/live_events') + @sio.on('*', namespace=FORCAD_NAMESPACE) async def catch_all(event, data): """Catch all events from live_events namespace""" if isinstance(data, list) and len(data) >= 2: @@ -179,7 +214,7 @@ async def socketio_listener(): except Exception as e: print(f"Error processing flag_stolen event: {e}") - @sio.event(namespace='/live_events') + @sio.event(namespace=FORCAD_NAMESPACE) async def update_scoreboard(data): """Handle scoreboard update - compare with previous state to detect NEW attacks""" try: @@ -319,18 +354,134 @@ async def socketio_listener(): @sio.event async def connect(): - print(f"✅ Connected to scoreboard at {SCOREBOARD_URL}") + print(f"✅ Connected to ForcAD scoreboard at {SCOREBOARD_URL} (ns {FORCAD_NAMESPACE})") @sio.event async def disconnect(): - print(f"❌ Disconnected from scoreboard") + print(f"❌ Disconnected from ForcAD scoreboard") while True: try: - print(f"Connecting to {SCOREBOARD_URL}...") + print(f"Connecting to ForcAD at {SCOREBOARD_URL}...") await sio.connect( SCOREBOARD_URL, - namespaces=['/live_events'], + namespaces=[FORCAD_NAMESPACE], + transports=['websocket'] + ) + await sio.wait() + except Exception as e: + print(f"Connection error: {e}") + await asyncio.sleep(5) + +async def adplatf_socketio_listener(): + """Listen to ADPlatf scoreboard using Socket.IO (namespace configurable).""" + sio = socketio.AsyncClient(logger=False, engineio_logger=False) + + task_names = {} + team_names = {} + + task_names.update(await fetch_task_names()) + + @sio.on('*', namespace=ADPLATF_NAMESPACE) + async def catch_all(event, data): + """Catch all events from ADPlatf namespace and normalize.""" + # Normalize common payload shapes + payload = None + if isinstance(data, list) and len(data) >= 2: + event_type = data[0] + payload = data[1].get('data', {}) if isinstance(data[1], dict) else {} + elif isinstance(data, dict): + payload = data.get('data', data) + + if not isinstance(payload, dict): + return + + # Try multiple key patterns to detect flag events + keys = payload.keys() + attacker_id = payload.get('attacker_id') or payload.get('attacker') or payload.get('team_attacker_id') + victim_id = payload.get('victim_id') or payload.get('victim') or payload.get('team_victim_id') + task_id = payload.get('task_id') or payload.get('service_id') + attacker_delta = payload.get('attacker_delta') or payload.get('points') or payload.get('fp_delta') or 0 + + if attacker_id is not None and victim_id is not None: + await process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names) + + async def process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names_local): + try: + service_name = task_names_local.get(task_id, f"task_{task_id}") + timestamp = datetime.utcnow() + is_our_attack = attacker_id == OUR_TEAM_ID + is_attack_to_us = victim_id == OUR_TEAM_ID + + print(f"[ADPlatf] Flag event: attacker={attacker_id}, victim={victim_id}, service={service_name}, points={float(attacker_delta):.2f}") + if is_our_attack or is_attack_to_us: + conn = await db_pool.acquire() + try: + attack_id = f"flag_{attacker_id}_{victim_id}_{task_id}_{int(timestamp.timestamp())}" + await conn.execute( + """ + INSERT INTO attacks (attack_id, attacker_team_id, victim_team_id, service_name, timestamp, points, is_our_attack, is_attack_to_us) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (attack_id) DO NOTHING + """, + attack_id, attacker_id, victim_id, service_name, timestamp, float(attacker_delta), is_our_attack, is_attack_to_us, + ) + + if is_attack_to_us and float(attacker_delta) >= ALERT_THRESHOLD_POINTS: + alert_message = f"🚨 ATTACK DETECTED!\nTeam {attacker_id} stole flag from {service_name}\nPoints lost: {float(attacker_delta):.2f} FP" + # Lookup optional service_id + service_id = None + try: + service_row = await conn.fetchrow( + "SELECT id FROM services WHERE name = $1 LIMIT 1", + service_name, + ) + if not service_row: + service_row = await conn.fetchrow( + "SELECT id FROM services WHERE alias = $1 LIMIT 1", + service_name, + ) + if service_row: + service_id = service_row['id'] + except Exception as e: + print(f" Error looking up service_id: {e}") + + alert_id = await conn.fetchval( + """ + INSERT INTO attack_alerts (attack_id, alert_type, severity, message) + VALUES ( + (SELECT id FROM attacks WHERE attack_id = $1), + 'flag_stolen', + 'high', + $2 + ) + RETURNING id + """, + attack_id, + alert_message, + ) + + await send_telegram_alert(alert_message, service_id=service_id, service_name=service_name) + await conn.execute("UPDATE attack_alerts SET notified = true WHERE id = $1", alert_id) + finally: + await db_pool.release(conn) + except Exception as e: + print(f"Error processing ADPlatf flag event: {e}") + + @sio.event + async def connect(): + print(f"✅ Connected to ADPlatf scoreboard at {SCOREBOARD_URL} (ns {ADPLATF_NAMESPACE})") + + @sio.event + async def disconnect(): + print(f"❌ Disconnected from ADPlatf scoreboard") + + while True: + try: + print(f"Connecting to ADPlatf at {SCOREBOARD_URL}...") + await sio.connect( + SCOREBOARD_URL, + namespaces=[ADPLATF_NAMESPACE], transports=['websocket'] ) await sio.wait() @@ -343,7 +494,11 @@ async def socketio_listener(): async def lifespan(app: FastAPI): global db_pool, ws_task db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) - ws_task = asyncio.create_task(socketio_listener()) + # Start platform-specific listener + if SCOREBOARD_PLATFORM == "forcad": + ws_task = asyncio.create_task(forcad_socketio_listener()) + else: + ws_task = asyncio.create_task(adplatf_socketio_listener()) yield @@ -366,6 +521,7 @@ async def health_check(): "timestamp": datetime.utcnow().isoformat(), "team_id": OUR_TEAM_ID, "mode": "socketio", + "platform": SCOREBOARD_PLATFORM, "scoreboard_url": SCOREBOARD_URL } @@ -478,7 +634,8 @@ async def get_attacks_by_service(): COUNT(*) FILTER (WHERE is_our_attack = true) as our_attacks, COUNT(*) FILTER (WHERE is_attack_to_us = true) as attacks_to_us, COALESCE(SUM(points) FILTER (WHERE is_our_attack = true), 0) as points_gained, - COALESCE(SUM(points) FILTER (WHERE is_attack_to_us = true), 0) as points_lost + "mode": "socketio", + "platform": SCOREBOARD_PLATFORM, FROM attacks GROUP BY service_name ORDER BY total_attacks DESC @@ -489,7 +646,7 @@ async def get_attacks_by_service(): @app.post("/settings/team-id", dependencies=[Depends(verify_token)]) async def set_team_id(team_id: int): - """Update our team ID""" + socketio_url = f"{SCOREBOARD_URL}/socket.io/?EIO=4&transport=polling" global OUR_TEAM_ID OUR_TEAM_ID = team_id @@ -521,6 +678,29 @@ async def inject_test_attack(attacker_id: int, victim_id: int, service: str = "t "time": datetime.utcnow().isoformat(), "round": 1 } + tasks_url = _tasks_endpoint() + try: + async with session.get(tasks_url, timeout=aiohttp.ClientTimeout(total=5)) as resp: + result = { + "url": tasks_url, + "status": resp.status, + "reachable": resp.status == 200, + "content_type": resp.headers.get('Content-Type', ''), + } + if resp.status == 200 and 'application/json' in resp.headers.get('Content-Type', ''): + data = await resp.json() + if isinstance(data, list): + result["count"] = len(data) + elif isinstance(data, dict): + result["count"] = len(list(data.keys())) + results["endpoints_tested"].append(result) + except Exception as e: + results["endpoints_tested"].append({ + "url": tasks_url, + "reachable": False, + "error": str(e) + }) + await process_attack_event(test_event) return {"status": "injected", "event": test_event}