Add microservices, web UI, and replay tooling
Some checks failed
ci / tests (push) Has been cancelled
Some checks failed
ci / tests (push) Has been cancelled
This commit is contained in:
185
services/ai/app.py
Normal file
185
services/ai/app.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations

import os
import random
from typing import Any, Dict, List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from torch.serialization import add_safe_globals

from catan.data import Resource
from catan.game import GameConfig
from catan.ml.encoding import encode_action, encode_observation
from catan.ml.selfplay import ActionScoringNetwork, PPOConfig
from catan.sdk import Action, ActionType, parse_resource
from services.common.schemas import AIRequest, AIResponse, ActionSchema
from services.common.settings import settings
|
||||
|
||||
# FastAPI application exposing the AI move-selection service.
app = FastAPI(title="Catan AI Service")

# Allow-list the config dataclasses embedded in checkpoints so that
# torch.load's safe-serialization mechanism can unpickle them.
add_safe_globals([PPOConfig, GameConfig])
class ModelRegistry:
    """Loads and caches PPO actor networks from ``settings.models_dir``.

    NOTE(review): the cache is a plain dict with no locking — assumed safe
    under FastAPI's default sync handling; confirm before moving request
    handlers to threads or async concurrency.
    """

    def __init__(self) -> None:
        # checkpoint name -> loaded, eval-mode actor network
        self._cache: Dict[str, ActionScoringNetwork] = {}

    def list_models(self) -> List[str]:
        """Return the sorted ``*.pt`` checkpoint names available on disk."""
        path = settings.models_dir
        if not os.path.isdir(path):
            return []
        return sorted(name for name in os.listdir(path) if name.endswith(".pt"))

    def load(self, name: str) -> ActionScoringNetwork:
        """Load checkpoint *name*, constructing and caching the actor network.

        Raises:
            ValueError: if the checkpoint has no ``config`` entry.
        """
        if name in self._cache:
            return self._cache[name]
        path = os.path.join(settings.models_dir, name)
        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # file — only ever load checkpoints from the trusted models directory.
        state = torch.load(path, map_location="cpu", weights_only=False)
        cfg = state.get("config")
        if cfg is None:
            raise ValueError("Invalid model config")
        # Recover the observation size from the first linear layer: its input
        # width is obs_dim + action_dim, and action_dim is fixed by the encoder.
        input_dim = state["actor"]["network.0.weight"].shape[1]
        action_dim = encode_action(Action(ActionType.END_TURN, {})).shape[0]
        obs_dim = input_dim - action_dim
        actor = ActionScoringNetwork(obs_dim, action_dim, cfg.hidden_sizes)
        actor.load_state_dict(state["actor"])
        actor.eval()
        self._cache[name] = actor
        return actor
|
||||
|
||||
# Module-level singleton shared by all request handlers.
registry = ModelRegistry()
||||
def _deserialize_resources(resources: Dict[str, int]) -> Dict[Resource, int]:
    """Convert a JSON string-keyed resource-count map to ``Resource`` keys."""
    converted: Dict[Resource, int] = {}
    for key, amount in resources.items():
        converted[parse_resource(key)] = int(amount)
    return converted
||||
|
||||
|
||||
def _deserialize_observation(obs: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild enum-keyed resource maps inside a JSON-decoded observation."""
    game = dict(obs["game"])

    rebuilt_players: Dict[str, Any] = {}
    for player_name, raw_info in game["players"].items():
        info = dict(raw_info)
        resources = info.get("resources")
        # Opponent hands arrive as {"hidden": n}; leave those untouched.
        if isinstance(resources, dict) and "hidden" not in resources:
            info["resources"] = _deserialize_resources(resources)
        rebuilt_players[player_name] = info
    game["players"] = rebuilt_players

    bank = game.get("bank")
    if isinstance(bank, dict):
        game["bank"] = {parse_resource(key): int(count) for key, count in bank.items()}

    return {"game": game, "board": obs["board"]}
|
||||
|
||||
|
||||
def _end_turn() -> ActionSchema:
    """Fresh END_TURN action used as the fallback for unfillable templates."""
    return ActionSchema(type=ActionType.END_TURN.value, payload={})


def _finalize_robber(action_type: ActionType, payload: Dict[str, Any], rng: random.Random) -> ActionSchema:
    """Choose a robber hex at random and, when victims exist, a victim too."""
    options = payload.get("options", [])
    if not options:
        return _end_turn()
    choice = rng.choice(options)
    new_payload: Dict[str, Any] = {"hex": choice["hex"]}
    victims = choice.get("victims") or []
    if victims:
        new_payload["victim"] = rng.choice(victims)
    return ActionSchema(type=action_type.value, payload=new_payload)


def _finalize_road_building(action_type: ActionType, payload: Dict[str, Any], rng: random.Random) -> ActionSchema:
    """Pick two road edges; duplicates are allowed when fewer than two exist."""
    edges = payload.get("edges", [])
    if not edges:
        return _end_turn()
    picks = rng.sample(edges, k=min(2, len(edges)))
    while len(picks) < 2:
        picks.append(rng.choice(edges))
    return ActionSchema(type=action_type.value, payload={"edges": picks[:2]})


def _finalize_year_of_plenty(action_type: ActionType, payload: Dict[str, Any], rng: random.Random) -> ActionSchema:
    """Draw two copies of one randomly chosen in-stock resource."""
    bank = payload.get("bank", {})
    available = [res for res, amount in bank.items() if amount > 0]
    if not available:
        available = list(bank.keys())
    if not available:
        return _end_turn()
    pick = rng.choice(available)
    # NOTE(review): always requests the SAME resource twice, even when the
    # bank holds only one copy — confirm the rules engine tolerates this.
    return ActionSchema(type=action_type.value, payload={"resources": [pick, pick]})


def _finalize_monopoly(action_type: ActionType, payload: Dict[str, Any], rng: random.Random) -> ActionSchema:
    """Name one resource to monopolize, defaulting to any non-desert resource."""
    choices = payload.get("resources") or [res.value for res in Resource if res != Resource.DESERT]
    if not choices:
        return _end_turn()
    return ActionSchema(type=action_type.value, payload={"resource": rng.choice(choices)})


def _finalize_discard(action_type: ActionType, payload: Dict[str, Any], rng: random.Random) -> ActionSchema:
    """Discard ``required`` cards drawn uniformly from the player's hand."""
    required = payload.get("required")
    resources = payload.get("resources") or {}
    if not isinstance(required, int):
        return _end_turn()
    # Expand the hand into one entry per card, shuffle, and take a prefix.
    pool: List[str] = []
    for res, count in resources.items():
        if res == "desert" or count <= 0:
            continue
        pool.extend([res] * int(count))
    rng.shuffle(pool)
    cards: Dict[str, int] = {}
    for res in pool[:required]:
        cards[res] = cards.get(res, 0) + 1
    return ActionSchema(type=action_type.value, payload={"player": payload.get("player"), "cards": cards})


def _finalize_action(template: ActionSchema, rng: random.Random) -> ActionSchema:
    """Complete a legal-action template by filling its open choices randomly.

    The rules engine sends templates whose payload lists the options still to
    be decided (robber hexes, road edges, resources to draw/discard). Each
    branch resolves those options with *rng*; when a template offers no valid
    option the agent falls back to ending the turn.

    Args:
        template: legal action whose payload may still contain open choices.
        rng: random source used for every choice.

    Returns:
        A fully specified ``ActionSchema`` ready to submit to the engine.
    """
    payload = dict(template.payload or {})
    action_type = ActionType(template.type)

    if action_type in {ActionType.MOVE_ROBBER, ActionType.PLAY_KNIGHT}:
        return _finalize_robber(action_type, payload, rng)
    if action_type == ActionType.PLAY_ROAD_BUILDING:
        return _finalize_road_building(action_type, payload, rng)
    if action_type == ActionType.PLAY_YEAR_OF_PLENTY:
        return _finalize_year_of_plenty(action_type, payload, rng)
    if action_type == ActionType.PLAY_MONOPOLY:
        return _finalize_monopoly(action_type, payload, rng)
    if action_type == ActionType.DISCARD:
        return _finalize_discard(action_type, payload, rng)
    # Already fully specified — pass the payload through untouched.
    return ActionSchema(type=action_type.value, payload=payload)
def _choose_action(obs: Dict[str, Any], legal_actions: List[ActionSchema], agent: Dict[str, Any], debug: bool) -> AIResponse:
    """Select one legal action according to the requested agent kind.

    Args:
        obs: JSON-decoded observation (see ``_deserialize_observation``).
        legal_actions: action templates the rules engine currently allows.
        agent: agent spec; ``kind`` is ``"random"`` (default) or ``"model"``.
            A model agent also needs ``model`` (checkpoint name) and may set
            ``stochastic`` (default True) to sample instead of argmax.
        debug: when True, the model branch attaches logits/probs/index to the
            response for inspection.

    Returns:
        An ``AIResponse`` carrying the finalized action.

    Raises:
        HTTPException: 400 for an unknown agent kind or a missing model name.
    """
    rng = random.Random()
    kind = agent.get("kind", "random")
    if not legal_actions:
        # Nothing legal: fall back to ending the turn.
        return AIResponse(action=ActionSchema(type=ActionType.END_TURN.value, payload={}))
    if kind == "random":
        template = rng.choice(legal_actions)
        return AIResponse(action=_finalize_action(template, rng))
    if kind == "model":
        model_name = agent.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Model name required")
        actor = registry.load(model_name)
        obs_vec = encode_observation(_deserialize_observation(obs))
        actions_vec = np.stack([encode_action(Action(ActionType(a.type), a.payload)) for a in legal_actions])
        # Inference only: disable autograd so no computation graph is built.
        with torch.no_grad():
            obs_tensor = torch.tensor(obs_vec, dtype=torch.float32)
            action_tensor = torch.tensor(actions_vec, dtype=torch.float32)
            logits = actor(obs_tensor, action_tensor)
            probs = torch.softmax(logits, dim=0)
            if agent.get("stochastic", True):
                # Sample an index proportionally to the policy's scores.
                idx = torch.distributions.Categorical(probs=probs).sample().item()
            else:
                idx = torch.argmax(probs).item()
        finalized = _finalize_action(legal_actions[idx], rng)
        debug_payload = {}
        if debug:
            debug_payload = {
                "logits": logits.cpu().tolist(),
                "probs": probs.cpu().tolist(),
                "index": idx,
                "model": model_name,
            }
        return AIResponse(action=finalized, debug=debug_payload)
    raise HTTPException(status_code=400, detail="Unknown agent kind")
|
||||
|
||||
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
|
||||
|
||||
@app.get("/models")
def list_models() -> Dict[str, List[str]]:
    """List the model checkpoints the registry can currently serve."""
    available = registry.list_models()
    return {"models": available}
|
||||
|
||||
@app.post("/act")
def act(payload: AIRequest) -> Dict[str, Any]:
    """Choose an action for the given observation and return it as JSON."""
    result = _choose_action(
        payload.observation,
        payload.legal_actions,
        payload.agent,
        payload.debug,
    )
    return result.model_dump()
Reference in New Issue
Block a user