"""FastAPI service that picks Catan actions for random and model-backed agents."""

from __future__ import annotations

import os
import random
from typing import Any, Dict, List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from torch.serialization import add_safe_globals

from catan.data import Resource
from catan.game import GameConfig
from catan.ml.encoding import encode_action, encode_observation
from catan.ml.selfplay import ActionScoringNetwork, PPOConfig
from catan.sdk import Action, ActionType, parse_resource
from services.common.schemas import AIRequest, AIResponse, ActionSchema
from services.common.settings import settings

app = FastAPI(title="Catan AI Service")

# Allow the custom config classes stored inside checkpoints to be unpickled.
add_safe_globals([PPOConfig, GameConfig])


class ModelRegistry:
    """Loads action-scoring networks from disk and caches them by file name."""

    def __init__(self) -> None:
        self._cache: Dict[str, ActionScoringNetwork] = {}

    def list_models(self) -> List[str]:
        path = settings.models_dir
        if not os.path.isdir(path):
            return []
        return sorted(name for name in os.listdir(path) if name.endswith(".pt"))

    def load(self, name: str) -> ActionScoringNetwork:
        if name in self._cache:
            return self._cache[name]
        path = f"{settings.models_dir}/{name}"
        state = torch.load(path, map_location="cpu", weights_only=False)
        cfg = state.get("config")
        if cfg is None:
            raise ValueError("Invalid model config")
        # Recover the observation dimension from the first layer's weight shape:
        # the network consumes the concatenation of observation and action encodings.
        input_dim = state["actor"]["network.0.weight"].shape[1]
        action_dim = encode_action(Action(ActionType.END_TURN, {})).shape[0]
        obs_dim = input_dim - action_dim
        actor = ActionScoringNetwork(obs_dim, action_dim, cfg.hidden_sizes)
        actor.load_state_dict(state["actor"])
        actor.eval()
        self._cache[name] = actor
        return actor


registry = ModelRegistry()


def _deserialize_resources(resources: Dict[str, int]) -> Dict[Resource, int]:
    return {parse_resource(k): int(v) for k, v in resources.items()}


def _deserialize_observation(obs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a JSON observation back into the resource-keyed form the encoder expects."""
    game = dict(obs["game"])
    players = {}
    for name, info in game["players"].items():
        data = dict(info)
        # Opponent hands arrive as {"hidden": n}; only fully visible hands are re-keyed.
        if isinstance(data.get("resources"), dict) and "hidden" not in data["resources"]:
            data["resources"] = _deserialize_resources(data["resources"])
        players[name] = data
    game["players"] = players
    if "bank" in game and isinstance(game["bank"], dict):
        game["bank"] = {parse_resource(k): int(v) for k, v in game["bank"].items()}
    return {"game": game, "board": obs["board"]}


def _finalize_action(template: ActionSchema, rng: random.Random) -> ActionSchema:
    """Turn an action template with open choices into a concrete, fully specified action."""
    payload = dict(template.payload or {})
    action_type = ActionType(template.type)

    if action_type in {ActionType.MOVE_ROBBER, ActionType.PLAY_KNIGHT}:
        options = payload.get("options", [])
        if not options:
            return ActionSchema(type=ActionType.END_TURN.value, payload={})
        choice = rng.choice(options)
        new_payload: Dict[str, Any] = {"hex": choice["hex"]}
        victims = choice.get("victims") or []
        if victims:
            new_payload["victim"] = rng.choice(victims)
        return ActionSchema(type=action_type.value, payload=new_payload)

    if action_type == ActionType.PLAY_ROAD_BUILDING:
        edges = payload.get("edges", [])
        if not edges:
            return ActionSchema(type=ActionType.END_TURN.value, payload={})
        # Pick two distinct edges when possible; repeat one if only a single edge is legal.
        picks = rng.sample(edges, k=min(2, len(edges)))
        while len(picks) < 2:
            picks.append(rng.choice(edges))
        return ActionSchema(type=action_type.value, payload={"edges": picks[:2]})

    if action_type == ActionType.PLAY_YEAR_OF_PLENTY:
        bank = payload.get("bank", {})
        available = [res for res, amount in bank.items() if amount > 0]
        if not available:
            available = list(bank.keys())
        if not available:
            return ActionSchema(type=ActionType.END_TURN.value, payload={})
        pick = rng.choice(available)
        return ActionSchema(type=action_type.value, payload={"resources": [pick, pick]})

    if action_type == ActionType.PLAY_MONOPOLY:
        choices = payload.get("resources") or [res.value for res in Resource if res != Resource.DESERT]
        if not choices:
            return ActionSchema(type=ActionType.END_TURN.value, payload={})
        return ActionSchema(type=action_type.value, payload={"resource": rng.choice(choices)})

    if action_type == ActionType.DISCARD:
        required = payload.get("required")
        resources = payload.get("resources") or {}
        if not isinstance(required, int):
            return ActionSchema(type=ActionType.END_TURN.value, payload={})
        # Expand the hand into individual cards, shuffle, and discard the first `required`.
        pool: List[str] = []
        for res, count in resources.items():
            if res == "desert" or count <= 0:
                continue
            pool.extend([res] * int(count))
        rng.shuffle(pool)
        cards: Dict[str, int] = {}
        for res in pool[:required]:
            cards[res] = cards.get(res, 0) + 1
        return ActionSchema(type=action_type.value, payload={"player": payload.get("player"), "cards": cards})

    return ActionSchema(type=action_type.value, payload=payload)
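# Illustrative example (hypothetical hex id and player name): a MOVE_ROBBER template
# whose payload is {"options": [{"hex": 7, "victims": ["alice"]}]} would be finalized
# above to {"hex": 7, "victim": "alice"}; a template with no options falls back to END_TURN.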
def _choose_action(
    obs: Dict[str, Any],
    legal_actions: List[ActionSchema],
    agent: Dict[str, Any],
    debug: bool,
) -> AIResponse:
    rng = random.Random()
    kind = agent.get("kind", "random")
    if not legal_actions:
        return AIResponse(action=ActionSchema(type=ActionType.END_TURN.value, payload={}))

    if kind == "random":
        template = rng.choice(legal_actions)
        return AIResponse(action=_finalize_action(template, rng))

    if kind == "model":
        model_name = agent.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Model name required")
        actor = registry.load(model_name)
        # Score every legal action against the current observation, then sample
        # from the softmax distribution (stochastic) or take the argmax (greedy).
        obs_vec = encode_observation(_deserialize_observation(obs))
        actions_vec = np.stack([encode_action(Action(ActionType(a.type), a.payload)) for a in legal_actions])
        obs_tensor = torch.tensor(obs_vec, dtype=torch.float32)
        action_tensor = torch.tensor(actions_vec, dtype=torch.float32)
        logits = actor(obs_tensor, action_tensor)
        probs = torch.softmax(logits, dim=0)
        if agent.get("stochastic", True):
            dist = torch.distributions.Categorical(probs=probs)
            idx = dist.sample().item()
        else:
            idx = torch.argmax(probs).item()
        selected = legal_actions[idx]
        finalized = _finalize_action(selected, rng)
        debug_payload: Dict[str, Any] = {}
        if debug:
            debug_payload = {
                "logits": logits.detach().cpu().tolist(),
                "probs": probs.detach().cpu().tolist(),
                "index": idx,
                "model": model_name,
            }
        return AIResponse(action=finalized, debug=debug_payload)

    raise HTTPException(status_code=400, detail="Unknown agent kind")


@app.get("/health")
def health() -> Dict[str, str]:
    return {"status": "ok"}


@app.get("/models")
def list_models() -> Dict[str, List[str]]:
    return {"models": registry.list_models()}


@app.post("/act")
def act(payload: AIRequest) -> Dict[str, Any]:
    response = _choose_action(payload.observation, payload.legal_actions, payload.agent, payload.debug)
    return response.model_dump()
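
# A minimal local smoke-test sketch (not part of the service itself): it exercises the
# random-agent path through FastAPI's TestClient. The observation and legal-action
# contents below are illustrative placeholders; the real field shapes are defined by
# AIRequest and the game engine, and the random agent never inspects the observation.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    print(client.get("/health").json())
    print(client.get("/models").json())
    request_body = {
        "observation": {"game": {"players": {}}, "board": {}},  # placeholder observation
        "legal_actions": [{"type": ActionType.END_TURN.value, "payload": {}}],
        "agent": {"kind": "random"},
        "debug": False,
    }
    print(client.post("/act", json=request_body).json())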