Update main.py
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Scoreboard Injector for ForcAD
|
||||
Monitors Socket.IO events for attacks and alerts on critical situations
|
||||
Scoreboard Injector for ADPlatf/ForcAD
|
||||
Monitors scoreboard events for attacks and alerts on critical situations
|
||||
Supports selecting scoreboard platform via configuration.
|
||||
"""
|
||||
import os
|
||||
import asyncio
|
||||
@@ -11,6 +12,10 @@ import socketio
|
||||
from fastapi import FastAPI, HTTPException, Depends, Header
|
||||
import asyncpg
|
||||
from contextlib import asynccontextmanager
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
# Load environment variables from .env if present
|
||||
load_dotenv(find_dotenv(), override=False)
|
||||
|
||||
# Configuration
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://adctrl:adctrl@postgres:5432/adctrl")
|
||||
@@ -20,6 +25,28 @@ OUR_TEAM_ID = int(os.getenv("OUR_TEAM_ID", "1"))
|
||||
ALERT_THRESHOLD_POINTS = float(os.getenv("ALERT_THRESHOLD_POINTS", "5"))
|
||||
TELEGRAM_API_URL = os.getenv("TELEGRAM_API_URL", "http://tg-bot:8003/send")
|
||||
|
||||
# Platform selection: 'adplatf' or 'forcad'
|
||||
SCOREBOARD_PLATFORM = os.getenv("SCOREBOARD_PLATFORM", "adplatf").lower()
|
||||
|
||||
# Platform-specific defaults (overridable via env)
|
||||
# ForcAD
|
||||
FORCAD_NAMESPACE = os.getenv("FORCAD_NAMESPACE", "/live_events")
|
||||
FORCAD_TASKS_PATH = os.getenv("FORCAD_TASKS_PATH", "/api/client/tasks/")
|
||||
|
||||
# ADPlatf
|
||||
ADPLATF_NAMESPACE = os.getenv("ADPLATF_NAMESPACE", "/events")
|
||||
ADPLATF_TASKS_PATH = os.getenv("ADPLATF_TASKS_PATH", "/api/client/tasks/")
|
||||
|
||||
def _tasks_endpoint() -> str:
|
||||
"""Return full tasks endpoint URL based on platform."""
|
||||
if SCOREBOARD_PLATFORM == "forcad":
|
||||
path = FORCAD_TASKS_PATH
|
||||
else:
|
||||
path = ADPLATF_TASKS_PATH
|
||||
if path.startswith("http://") or path.startswith("https://"):
|
||||
return path
|
||||
return f"{SCOREBOARD_URL}{path}"
|
||||
|
||||
# Database pool
|
||||
db_pool = None
|
||||
ws_task = None
|
||||
@@ -63,20 +90,28 @@ async def send_telegram_alert(message: str, service_id: int = None, service_name
|
||||
print(f"Error sending telegram alert: {e}")
|
||||
|
||||
async def fetch_task_names():
|
||||
"""Fetch task names from scoreboard API"""
|
||||
"""Fetch task names from scoreboard API (platform-aware)."""
|
||||
url = _tasks_endpoint()
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{SCOREBOARD_URL}/api/client/tasks/") as resp:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
tasks = await resp.json()
|
||||
return {task['id']: task['name'] for task in tasks}
|
||||
# Support both list and dict formats
|
||||
if isinstance(tasks, list):
|
||||
return {task.get('id'): task.get('name') for task in tasks}
|
||||
if isinstance(tasks, dict):
|
||||
# Some APIs might return {id: name}
|
||||
return {int(k): v for k, v in tasks.items()}
|
||||
else:
|
||||
print(f"Task names fetch failed: {resp.status} at {url}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
print(f"Error fetching task names: {e}")
|
||||
print(f"Error fetching task names from {url}: {e}")
|
||||
return {}
|
||||
|
||||
async def socketio_listener():
|
||||
"""Listen to ForcAD scoreboard using Socket.IO"""
|
||||
async def forcad_socketio_listener():
|
||||
"""Listen to ForcAD scoreboard using Socket.IO (namespace /live_events)."""
|
||||
sio = socketio.AsyncClient(logger=False, engineio_logger=False)
|
||||
|
||||
# Cache for task and team names
|
||||
@@ -86,7 +121,7 @@ async def socketio_listener():
|
||||
# Fetch task names on startup
|
||||
task_names.update(await fetch_task_names())
|
||||
|
||||
@sio.on('*', namespace='/live_events')
|
||||
@sio.on('*', namespace=FORCAD_NAMESPACE)
|
||||
async def catch_all(event, data):
|
||||
"""Catch all events from live_events namespace"""
|
||||
if isinstance(data, list) and len(data) >= 2:
|
||||
@@ -179,7 +214,7 @@ async def socketio_listener():
|
||||
except Exception as e:
|
||||
print(f"Error processing flag_stolen event: {e}")
|
||||
|
||||
@sio.event(namespace='/live_events')
|
||||
@sio.event(namespace=FORCAD_NAMESPACE)
|
||||
async def update_scoreboard(data):
|
||||
"""Handle scoreboard update - compare with previous state to detect NEW attacks"""
|
||||
try:
|
||||
@@ -319,18 +354,134 @@ async def socketio_listener():
|
||||
|
||||
@sio.event
|
||||
async def connect():
|
||||
print(f"✅ Connected to scoreboard at {SCOREBOARD_URL}")
|
||||
print(f"✅ Connected to ForcAD scoreboard at {SCOREBOARD_URL} (ns {FORCAD_NAMESPACE})")
|
||||
|
||||
@sio.event
|
||||
async def disconnect():
|
||||
print(f"❌ Disconnected from scoreboard")
|
||||
print(f"❌ Disconnected from ForcAD scoreboard")
|
||||
|
||||
while True:
|
||||
try:
|
||||
print(f"Connecting to {SCOREBOARD_URL}...")
|
||||
print(f"Connecting to ForcAD at {SCOREBOARD_URL}...")
|
||||
await sio.connect(
|
||||
SCOREBOARD_URL,
|
||||
namespaces=['/live_events'],
|
||||
namespaces=[FORCAD_NAMESPACE],
|
||||
transports=['websocket']
|
||||
)
|
||||
await sio.wait()
|
||||
except Exception as e:
|
||||
print(f"Connection error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def adplatf_socketio_listener():
|
||||
"""Listen to ADPlatf scoreboard using Socket.IO (namespace configurable)."""
|
||||
sio = socketio.AsyncClient(logger=False, engineio_logger=False)
|
||||
|
||||
task_names = {}
|
||||
team_names = {}
|
||||
|
||||
task_names.update(await fetch_task_names())
|
||||
|
||||
@sio.on('*', namespace=ADPLATF_NAMESPACE)
|
||||
async def catch_all(event, data):
|
||||
"""Catch all events from ADPlatf namespace and normalize."""
|
||||
# Normalize common payload shapes
|
||||
payload = None
|
||||
if isinstance(data, list) and len(data) >= 2:
|
||||
event_type = data[0]
|
||||
payload = data[1].get('data', {}) if isinstance(data[1], dict) else {}
|
||||
elif isinstance(data, dict):
|
||||
payload = data.get('data', data)
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
|
||||
# Try multiple key patterns to detect flag events
|
||||
keys = payload.keys()
|
||||
attacker_id = payload.get('attacker_id') or payload.get('attacker') or payload.get('team_attacker_id')
|
||||
victim_id = payload.get('victim_id') or payload.get('victim') or payload.get('team_victim_id')
|
||||
task_id = payload.get('task_id') or payload.get('service_id')
|
||||
attacker_delta = payload.get('attacker_delta') or payload.get('points') or payload.get('fp_delta') or 0
|
||||
|
||||
if attacker_id is not None and victim_id is not None:
|
||||
await process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names)
|
||||
|
||||
async def process_flag_event_normalized(attacker_id, victim_id, task_id, attacker_delta, task_names_local):
|
||||
try:
|
||||
service_name = task_names_local.get(task_id, f"task_{task_id}")
|
||||
timestamp = datetime.utcnow()
|
||||
is_our_attack = attacker_id == OUR_TEAM_ID
|
||||
is_attack_to_us = victim_id == OUR_TEAM_ID
|
||||
|
||||
print(f"[ADPlatf] Flag event: attacker={attacker_id}, victim={victim_id}, service={service_name}, points={float(attacker_delta):.2f}")
|
||||
if is_our_attack or is_attack_to_us:
|
||||
conn = await db_pool.acquire()
|
||||
try:
|
||||
attack_id = f"flag_{attacker_id}_{victim_id}_{task_id}_{int(timestamp.timestamp())}"
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO attacks (attack_id, attacker_team_id, victim_team_id, service_name, timestamp, points, is_our_attack, is_attack_to_us)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (attack_id) DO NOTHING
|
||||
""",
|
||||
attack_id, attacker_id, victim_id, service_name, timestamp, float(attacker_delta), is_our_attack, is_attack_to_us,
|
||||
)
|
||||
|
||||
if is_attack_to_us and float(attacker_delta) >= ALERT_THRESHOLD_POINTS:
|
||||
alert_message = f"🚨 ATTACK DETECTED!\nTeam {attacker_id} stole flag from {service_name}\nPoints lost: {float(attacker_delta):.2f} FP"
|
||||
# Lookup optional service_id
|
||||
service_id = None
|
||||
try:
|
||||
service_row = await conn.fetchrow(
|
||||
"SELECT id FROM services WHERE name = $1 LIMIT 1",
|
||||
service_name,
|
||||
)
|
||||
if not service_row:
|
||||
service_row = await conn.fetchrow(
|
||||
"SELECT id FROM services WHERE alias = $1 LIMIT 1",
|
||||
service_name,
|
||||
)
|
||||
if service_row:
|
||||
service_id = service_row['id']
|
||||
except Exception as e:
|
||||
print(f" Error looking up service_id: {e}")
|
||||
|
||||
alert_id = await conn.fetchval(
|
||||
"""
|
||||
INSERT INTO attack_alerts (attack_id, alert_type, severity, message)
|
||||
VALUES (
|
||||
(SELECT id FROM attacks WHERE attack_id = $1),
|
||||
'flag_stolen',
|
||||
'high',
|
||||
$2
|
||||
)
|
||||
RETURNING id
|
||||
""",
|
||||
attack_id,
|
||||
alert_message,
|
||||
)
|
||||
|
||||
await send_telegram_alert(alert_message, service_id=service_id, service_name=service_name)
|
||||
await conn.execute("UPDATE attack_alerts SET notified = true WHERE id = $1", alert_id)
|
||||
finally:
|
||||
await db_pool.release(conn)
|
||||
except Exception as e:
|
||||
print(f"Error processing ADPlatf flag event: {e}")
|
||||
|
||||
@sio.event
|
||||
async def connect():
|
||||
print(f"✅ Connected to ADPlatf scoreboard at {SCOREBOARD_URL} (ns {ADPLATF_NAMESPACE})")
|
||||
|
||||
@sio.event
|
||||
async def disconnect():
|
||||
print(f"❌ Disconnected from ADPlatf scoreboard")
|
||||
|
||||
while True:
|
||||
try:
|
||||
print(f"Connecting to ADPlatf at {SCOREBOARD_URL}...")
|
||||
await sio.connect(
|
||||
SCOREBOARD_URL,
|
||||
namespaces=[ADPLATF_NAMESPACE],
|
||||
transports=['websocket']
|
||||
)
|
||||
await sio.wait()
|
||||
@@ -343,7 +494,11 @@ async def socketio_listener():
|
||||
async def lifespan(app: FastAPI):
|
||||
global db_pool, ws_task
|
||||
db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
ws_task = asyncio.create_task(socketio_listener())
|
||||
# Start platform-specific listener
|
||||
if SCOREBOARD_PLATFORM == "forcad":
|
||||
ws_task = asyncio.create_task(forcad_socketio_listener())
|
||||
else:
|
||||
ws_task = asyncio.create_task(adplatf_socketio_listener())
|
||||
|
||||
yield
|
||||
|
||||
@@ -366,6 +521,7 @@ async def health_check():
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"team_id": OUR_TEAM_ID,
|
||||
"mode": "socketio",
|
||||
"platform": SCOREBOARD_PLATFORM,
|
||||
"scoreboard_url": SCOREBOARD_URL
|
||||
}
|
||||
|
||||
@@ -478,7 +634,8 @@ async def get_attacks_by_service():
|
||||
COUNT(*) FILTER (WHERE is_our_attack = true) as our_attacks,
|
||||
COUNT(*) FILTER (WHERE is_attack_to_us = true) as attacks_to_us,
|
||||
COALESCE(SUM(points) FILTER (WHERE is_our_attack = true), 0) as points_gained,
|
||||
COALESCE(SUM(points) FILTER (WHERE is_attack_to_us = true), 0) as points_lost
|
||||
"mode": "socketio",
|
||||
"platform": SCOREBOARD_PLATFORM,
|
||||
FROM attacks
|
||||
GROUP BY service_name
|
||||
ORDER BY total_attacks DESC
|
||||
@@ -489,7 +646,7 @@ async def get_attacks_by_service():
|
||||
|
||||
@app.post("/settings/team-id", dependencies=[Depends(verify_token)])
|
||||
async def set_team_id(team_id: int):
|
||||
"""Update our team ID"""
|
||||
socketio_url = f"{SCOREBOARD_URL}/socket.io/?EIO=4&transport=polling"
|
||||
global OUR_TEAM_ID
|
||||
OUR_TEAM_ID = team_id
|
||||
|
||||
@@ -521,6 +678,29 @@ async def inject_test_attack(attacker_id: int, victim_id: int, service: str = "t
|
||||
"time": datetime.utcnow().isoformat(),
|
||||
"round": 1
|
||||
}
|
||||
tasks_url = _tasks_endpoint()
|
||||
try:
|
||||
async with session.get(tasks_url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||
result = {
|
||||
"url": tasks_url,
|
||||
"status": resp.status,
|
||||
"reachable": resp.status == 200,
|
||||
"content_type": resp.headers.get('Content-Type', ''),
|
||||
}
|
||||
if resp.status == 200 and 'application/json' in resp.headers.get('Content-Type', ''):
|
||||
data = await resp.json()
|
||||
if isinstance(data, list):
|
||||
result["count"] = len(data)
|
||||
elif isinstance(data, dict):
|
||||
result["count"] = len(list(data.keys()))
|
||||
results["endpoints_tested"].append(result)
|
||||
except Exception as e:
|
||||
results["endpoints_tested"].append({
|
||||
"url": tasks_url,
|
||||
"reachable": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
|
||||
await process_attack_event(test_event)
|
||||
return {"status": "injected", "event": test_event}
|
||||
|
||||
Reference in New Issue
Block a user