Add microservices, web UI, and replay tooling
Some checks failed
ci / tests (push) Has been cancelled

This commit is contained in:
dan
2025-12-25 03:28:40 +03:00
commit 46a07f548b
72 changed files with 9142 additions and 0 deletions

185
services/ai/app.py Normal file
View File

@@ -0,0 +1,185 @@
from __future__ import annotations
import random
from typing import Any, Dict, List
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from torch.serialization import add_safe_globals
from catan.data import Resource
from catan.ml.encoding import encode_action, encode_observation
from catan.game import GameConfig
from catan.ml.selfplay import ActionScoringNetwork, PPOConfig
from catan.sdk import Action, ActionType, parse_resource
from services.common.schemas import AIRequest, AIResponse, ActionSchema
from services.common.settings import settings
app = FastAPI(title="Catan AI Service")
add_safe_globals([PPOConfig, GameConfig])
class ModelRegistry:
def __init__(self) -> None:
self._cache: Dict[str, ActionScoringNetwork] = {}
def list_models(self) -> List[str]:
path = settings.models_dir
import os
if not os.path.isdir(path):
return []
return sorted([name for name in os.listdir(path) if name.endswith(".pt")])
def load(self, name: str) -> ActionScoringNetwork:
if name in self._cache:
return self._cache[name]
path = f"{settings.models_dir}/{name}"
state = torch.load(path, map_location="cpu", weights_only=False)
cfg = state.get("config")
if cfg is None:
raise ValueError("Invalid model config")
input_dim = state["actor"]["network.0.weight"].shape[1]
action_dim = encode_action(Action(ActionType.END_TURN, {})).shape[0]
obs_dim = input_dim - action_dim
actor = ActionScoringNetwork(obs_dim, action_dim, cfg.hidden_sizes)
actor.load_state_dict(state["actor"])
actor.eval()
self._cache[name] = actor
return actor
registry = ModelRegistry()
def _deserialize_resources(resources: Dict[str, int]) -> Dict[Resource, int]:
return {parse_resource(k): int(v) for k, v in resources.items()}
def _deserialize_observation(obs: Dict[str, Any]) -> Dict[str, Any]:
game = dict(obs["game"])
players = {}
for name, info in game["players"].items():
data = dict(info)
if isinstance(data.get("resources"), dict) and "hidden" not in data["resources"]:
data["resources"] = _deserialize_resources(data["resources"])
players[name] = data
game["players"] = players
if "bank" in game and isinstance(game["bank"], dict):
game["bank"] = {parse_resource(k): int(v) for k, v in game["bank"].items()}
return {"game": game, "board": obs["board"]}
def _finalize_action(template: ActionSchema, rng: random.Random) -> ActionSchema:
payload = dict(template.payload or {})
action_type = ActionType(template.type)
if action_type in {ActionType.MOVE_ROBBER, ActionType.PLAY_KNIGHT}:
options = payload.get("options", [])
if not options:
return ActionSchema(type=ActionType.END_TURN.value, payload={})
choice = rng.choice(options)
new_payload: Dict[str, Any] = {"hex": choice["hex"]}
victims = choice.get("victims") or []
if victims:
new_payload["victim"] = rng.choice(victims)
return ActionSchema(type=action_type.value, payload=new_payload)
if action_type == ActionType.PLAY_ROAD_BUILDING:
edges = payload.get("edges", [])
if not edges:
return ActionSchema(type=ActionType.END_TURN.value, payload={})
picks = rng.sample(edges, k=min(2, len(edges)))
while len(picks) < 2:
picks.append(rng.choice(edges))
return ActionSchema(type=action_type.value, payload={"edges": picks[:2]})
if action_type == ActionType.PLAY_YEAR_OF_PLENTY:
bank = payload.get("bank", {})
available = [res for res, amount in bank.items() if amount > 0]
if not available:
available = list(bank.keys())
if not available:
return ActionSchema(type=ActionType.END_TURN.value, payload={})
pick = rng.choice(available)
return ActionSchema(type=action_type.value, payload={"resources": [pick, pick]})
if action_type == ActionType.PLAY_MONOPOLY:
choices = payload.get("resources") or [res.value for res in Resource if res != Resource.DESERT]
if not choices:
return ActionSchema(type=ActionType.END_TURN.value, payload={})
return ActionSchema(type=action_type.value, payload={"resource": rng.choice(choices)})
if action_type == ActionType.DISCARD:
required = payload.get("required")
resources = payload.get("resources") or {}
if not isinstance(required, int):
return ActionSchema(type=ActionType.END_TURN.value, payload={})
pool = []
for res, count in resources.items():
if res == "desert" or count <= 0:
continue
pool.extend([res] * int(count))
rng.shuffle(pool)
cards: Dict[str, int] = {}
for res in pool[:required]:
cards[res] = cards.get(res, 0) + 1
return ActionSchema(type=action_type.value, payload={"player": payload.get("player"), "cards": cards})
return ActionSchema(type=action_type.value, payload=payload)
def _choose_action(obs: Dict[str, Any], legal_actions: List[ActionSchema], agent: Dict[str, Any], debug: bool) -> AIResponse:
rng = random.Random()
kind = agent.get("kind", "random")
if not legal_actions:
return AIResponse(action=ActionSchema(type=ActionType.END_TURN.value, payload={}))
if kind == "random":
template = rng.choice(legal_actions)
return AIResponse(action=_finalize_action(template, rng))
if kind == "model":
model_name = agent.get("model")
if not model_name:
raise HTTPException(status_code=400, detail="Model name required")
actor = registry.load(model_name)
obs_vec = encode_observation(_deserialize_observation(obs))
actions_vec = np.stack([encode_action(Action(ActionType(a.type), a.payload)) for a in legal_actions])
obs_tensor = torch.tensor(obs_vec, dtype=torch.float32)
action_tensor = torch.tensor(actions_vec, dtype=torch.float32)
logits = actor(obs_tensor, action_tensor)
probs = torch.softmax(logits, dim=0)
if agent.get("stochastic", True):
dist = torch.distributions.Categorical(probs=probs)
idx = dist.sample().item()
else:
idx = torch.argmax(probs).item()
selected = legal_actions[idx]
finalized = _finalize_action(selected, rng)
debug_payload = {}
if debug:
debug_payload = {
"logits": logits.detach().cpu().tolist(),
"probs": probs.detach().cpu().tolist(),
"index": idx,
"model": model_name,
}
return AIResponse(action=finalized, debug=debug_payload)
raise HTTPException(status_code=400, detail="Unknown agent kind")
@app.get("/health")
def health() -> Dict[str, str]:
return {"status": "ok"}
@app.get("/models")
def list_models() -> Dict[str, List[str]]:
return {"models": registry.list_models()}
@app.post("/act")
def act(payload: AIRequest) -> Dict[str, Any]:
response = _choose_action(payload.observation, payload.legal_actions, payload.agent, payload.debug)
return response.model_dump()