479 lines
17 KiB
Python
479 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
import datetime as dt
|
|
import random
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
from fastapi import Depends, FastAPI, HTTPException
|
|
from sqlmodel import SQLModel
|
|
|
|
from catan.data import Resource
|
|
from catan.game import GameConfig
|
|
from catan.sdk import Action, ActionType, CatanEnv
|
|
from services.common.db import engine
|
|
from services.common.schemas import (
|
|
ActionRequest,
|
|
ActionSchema,
|
|
AddAIRequest,
|
|
CreateGameRequest,
|
|
GameStateSchema,
|
|
GameSummarySchema,
|
|
JoinGameRequest,
|
|
TradeOfferRequest,
|
|
TradeOfferSchema,
|
|
TradeRespondRequest,
|
|
)
|
|
from services.common.settings import settings
|
|
from services.game.models import Game, GameEvent, TradeOffer
|
|
from services.game.runtime import manager
|
|
|
|
# ASGI application; every HTTP endpoint of the game service is registered on it via @app decorators.
app = FastAPI(title="Catan Game Service")
|
|
|
|
|
|
@app.on_event("startup")
def _startup() -> None:
    """Create the database tables for all registered SQLModel models at app startup.

    ``create_all`` only issues CREATE statements for tables that do not exist yet.
    """
    table_registry = SQLModel.metadata
    table_registry.create_all(engine)
|
|
|
|
|
|
def _serialize_resources(resources: Dict[Resource, int]) -> Dict[str, int]:
|
|
return {res.value if isinstance(res, Resource) else str(res): int(val) for res, val in resources.items()}
|
|
|
|
|
|
def _serialize_player(player: Dict[str, Any]) -> Dict[str, Any]:
|
|
data = dict(player)
|
|
resources = data.get("resources", {})
|
|
if resources:
|
|
data["resources"] = _serialize_resources(resources)
|
|
return data
|
|
|
|
|
|
def _serialize_game_observation(observation: Dict[str, Any]) -> Dict[str, Any]:
|
|
game = dict(observation["game"])
|
|
players = {name: _serialize_player(info) for name, info in game["players"].items()}
|
|
game["players"] = players
|
|
bank = game.get("bank", {})
|
|
if bank:
|
|
game["bank"] = _serialize_resources(bank)
|
|
return {"game": game, "board": observation["board"]}
|
|
|
|
|
|
def _serialize_action(action: Action) -> ActionSchema:
    """Convert an engine ``Action`` into an ``ActionSchema`` (enum type replaced by its value)."""
    kind = action.type.value
    return ActionSchema(type=kind, payload=action.payload)
|
|
|
|
|
|
def _serialize_legal_actions(actions: List[Action]) -> List[ActionSchema]:
|
|
return [_serialize_action(action) for action in actions]
|
|
|
|
|
|
def _slots_to_schema(game: Game) -> List[Dict[str, Any]]:
|
|
return game.slots.get("slots", [])
|
|
|
|
|
|
def _to_game_summary(game: Game) -> GameSummarySchema:
    """Project a ``Game`` record onto the summary schema used by the lobby endpoints.

    ``players`` is a new list containing the seat entries from ``game.slots``.
    """
    fields: Dict[str, Any] = {
        "id": game.id,
        "name": game.name,
        "status": game.status,
        "max_players": game.max_players,
        "created_by": game.created_by,
        "created_at": game.created_at,
        "players": list(_slots_to_schema(game)),
    }
    return GameSummarySchema(**fields)
|
|
|
|
|
|
def _trade_to_schema(trade: TradeOffer) -> TradeOfferSchema:
    """Copy the public fields of a persisted ``TradeOffer`` into its API schema."""
    copied_fields = ("id", "from_player", "to_player", "offer", "request", "status", "created_at")
    return TradeOfferSchema(**{attr: getattr(trade, attr) for attr in copied_fields})
|
|
|
|
|
|
def _build_state(game: Game, runtime) -> GameStateSchema:
    """Assemble the client-facing state for *game*.

    The summary fields (id, name, status, seats, ...) are always present. Once the
    game is running or finished and its runtime has an engine attached
    (``runtime.env``, created by ``start_game``), the serialized game observation,
    board, open trade offers and event history are added as well. This lets
    clients still see the final board after the game ends. Legal actions are
    only listed while the game is running.

    Args:
        game: Persisted game record.
        runtime: In-memory runtime from ``manager.get`` (may be ``None``).

    Returns:
        A populated ``GameStateSchema``.
    """
    summary_fields: Dict[str, Any] = {
        "id": game.id,
        "name": game.name,
        "status": game.status,
        "max_players": game.max_players,
        "created_by": game.created_by,
        "created_at": game.created_at,
        "players": list(_slots_to_schema(game)),
    }
    # getattr(None, ...) also yields None, so this covers a missing runtime too.
    env = getattr(runtime, "env", None)
    if game.status not in ("running", "finished") or env is None:
        return GameStateSchema(**summary_fields)

    obs = _serialize_game_observation(env.observe())
    # Nothing is playable once the game has finished, so only list moves while running.
    legal_actions = _serialize_legal_actions(env.legal_actions()) if game.status == "running" else []
    trades = manager.list_trade_offers(game.id, status="open")
    events = manager.list_events(game.id)
    return GameStateSchema(
        **summary_fields,
        game=obs["game"],
        board=obs["board"],
        legal_actions=legal_actions,
        pending_trades=[_trade_to_schema(trade) for trade in trades],
        history=[
            {
                "idx": event.idx,
                "ts": event.ts,
                "actor": event.actor,
                "action": {"type": event.action_type, "payload": event.payload},
                "applied": event.applied,
                "meta": event.debug_payload or {},
            }
            for event in events
        ],
    )
|
|
|
|
|
|
def _ensure_game(game_id: str) -> Game:
    """Return the ``Game`` record held by the runtime for *game_id*.

    Raises:
        HTTPException: 404 when ``manager.get`` does not know *game_id* (KeyError).
    """
    try:
        runtime = manager.get(game_id)
    except KeyError:
        raise HTTPException(status_code=404, detail="Game not found")
    return runtime.game
|
|
|
|
|
|
def _get_runtime(game_id: str):
    """Look up the in-memory runtime for *game_id* via ``manager.get``.

    Raises:
        HTTPException: 404 when the manager raises KeyError for *game_id*.
    """
    try:
        found = manager.get(game_id)
    except KeyError:
        raise HTTPException(status_code=404, detail="Game not found")
    else:
        return found
|
|
|
|
|
|
def _find_slot(game: Game, predicate) -> Optional[Dict[str, Any]]:
|
|
for slot in game.slots.get("slots", []):
|
|
if predicate(slot):
|
|
return slot
|
|
return None
|
|
|
|
|
|
def _ai_slots(game: Game) -> Dict[str, Dict[str, Any]]:
|
|
return {
|
|
slot["name"]: slot
|
|
for slot in game.slots.get("slots", [])
|
|
if slot.get("is_ai") and slot.get("name")
|
|
}
|
|
|
|
|
|
def _ai_trade_decision(resources: Dict[str, int], offer: Dict[str, int], request: Dict[str, int]) -> bool:
|
|
if any(resources.get(res, 0) < amount for res, amount in request.items()):
|
|
return False
|
|
offer_value = sum(offer.values())
|
|
request_value = sum(request.values())
|
|
if offer_value >= request_value:
|
|
return True
|
|
return random.random() < 0.35
|
|
|
|
|
|
async def _request_ai_action(agent: Dict[str, Any], observation: Dict[str, Any], legal_actions: List[ActionSchema]) -> Dict[str, Any]:
    """POST the current decision context to the AI service's ``/act`` endpoint.

    The JSON body carries the serialized observation, the legal actions (as
    dicts), the agent configuration and the service-wide ``settings.debug`` flag.
    The request uses a 30-second timeout.

    Returns:
        The decoded JSON reply.

    Raises:
        httpx.HTTPStatusError: the service answered with a 4xx/5xx status.
        Transport errors and timeouts raised by httpx propagate unchanged.
    """
    request_body = {
        "observation": observation,
        "legal_actions": [schema.model_dump() for schema in legal_actions],
        "agent": agent,
        "debug": settings.debug,
    }
    endpoint = f"{settings.ai_service_url}/act"
    async with httpx.AsyncClient(timeout=30.0) as client:
        reply = await client.post(endpoint, json=request_body)
        reply.raise_for_status()
        return reply.json()
|
|
|
|
|
|
def _resolve_ai_trades(runtime, ai_slots: Dict[str, Dict[str, Any]], current: str) -> None:
    """Let AI players answer the open trade offers made by the current player.

    An offer addressed to an AI seat is answered by that AI. An open-to-all offer
    (``to_player is None``) is answered by the first AI seat other than the
    offerer. Offers aimed at humans are left untouched. Processing stops after
    the first trade that is actually executed; any remaining offers stay open
    for a later pass.
    """
    game_id = runtime.game.id
    for trade in manager.list_trade_offers(game_id, status="open"):
        if trade.from_player != current:
            continue
        responder = trade.to_player
        if responder is None:
            responder = next((name for name in ai_slots if name != trade.from_player), None)
        if not responder or responder not in ai_slots:
            continue
        hand = runtime.env.game.player_by_name(responder).resources
        hand_by_name = {res.value: count for res, count in hand.items()}
        accept = _ai_trade_decision(hand_by_name, trade.offer, trade.request)
        executed: Optional[Action] = None
        if accept:
            # NOTE(review): mirrors respond_trade, which also passes the offering
            # player as "target"; confirm against CatanEnv's TRADE_PLAYER contract.
            candidate = Action(
                ActionType.TRADE_PLAYER,
                {"target": trade.from_player, "offer": trade.offer, "request": trade.request},
            )
            _, _, _, info = runtime.env.step(candidate)
            if info.get("invalid"):
                # The engine refused the exchange; report it as declined instead
                # of leaving an "accepted" trade that never happened.
                accept = False
            else:
                executed = candidate
        trade.status = "accepted" if accept else "declined"
        manager.update_trade_offer(trade)
        # Record the AI's answer exactly like respond_trade records a human's
        # (previously a decline was logged as an END_TURN event).
        manager.record_event(
            game_id,
            responder,
            Action(ActionType.TRADE_PLAYER, {"trade_id": trade.id, "accept": accept}),
            applied=False,
        )
        if executed is not None:
            manager.record_event(game_id, trade.from_player, executed, applied=True)
            break


def _parse_ai_action(response: Dict[str, Any]) -> Action:
    """Turn the AI service's ``{"action": {"type": ..., "payload": ...}}`` reply into an ``Action``.

    Raises:
        HTTPException: 502 when the reply has no action or an unknown action type.
    """
    try:
        action_data = response.get("action")
        return Action(ActionType(action_data["type"]), action_data.get("payload") or {})
    except (AttributeError, KeyError, TypeError, ValueError) as exc:
        raise HTTPException(status_code=502, detail="AI service returned an invalid action") from exc


async def _run_ai_turns(runtime, max_steps: int = 200) -> None:
    """Advance the game while AI players hold the turn.

    Each pass first lets AIs answer open trade offers from the current player.
    Then, if the current player is an AI, it asks the AI service for one action
    and applies it. The loop stops when a human is on turn, the engine rejects
    an AI action, the game ends, or *max_steps* passes have run. The step limit
    guards against an AI that never ends its turn.

    Args:
        runtime: Game runtime from ``manager.get``; ``runtime.env`` must be set
            when the game is running.
        max_steps: Upper bound on loop passes (200 by default).

    Raises:
        HTTPException: 502 when the AI service replies without a usable action.
        httpx errors raised by ``_request_ai_action`` propagate unchanged.
    """
    game = runtime.game
    if game.status != "running":
        return
    ai_slots = _ai_slots(game)
    for _ in range(max_steps):
        current = runtime.env.game.current_player.name
        _resolve_ai_trades(runtime, ai_slots, current)
        if current not in ai_slots:
            break

        slot = ai_slots[current]
        observation = _serialize_game_observation(runtime.env.observe())
        legal_actions = _serialize_legal_actions(runtime.env.legal_actions())
        agent_cfg = {
            "kind": slot.get("ai_kind", "random"),
            "model": slot.get("ai_model"),
            "stochastic": True,
        }
        response = await _request_ai_action(agent_cfg, observation, legal_actions)
        action = _parse_ai_action(response)
        debug = response.get("debug") or {}
        if settings.debug:
            debug = {
                **debug,
                "observation": observation,
                "legal_actions": [schema.model_dump() for schema in legal_actions],
            }

        _, _, done, info = runtime.env.step(action)
        invalid = bool(info.get("invalid"))
        # Only mark the event as applied when the engine actually accepted it.
        manager.record_event(game.id, current, action, applied=not invalid, debug=debug)
        if invalid:
            break
        if action.type == ActionType.END_TURN:
            _expire_trades(game.id)
        if done:
            game.status = "finished"
            game.winner = runtime.env.game.winner
            manager.save_game(game)
            break
|
|
|
|
|
|
def _expire_trades(game_id: str) -> None:
    """Mark every still-open trade offer of *game_id* as ``"expired"`` and persist each one."""
    for stale in manager.list_trade_offers(game_id, status="open"):
        stale.status = "expired"
        manager.update_trade_offer(stale)
|
|
|
|
|
|
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
|
|
|
|
|
|
@app.get("/games")
def list_games() -> Dict[str, Any]:
    """Return ``{"games": [...]}`` with a serialized summary of every game the manager knows."""
    summaries = [_to_game_summary(known).model_dump() for known in manager.list_games()]
    return {"games": summaries}
|
|
|
|
|
|
@app.post("/games")
def create_game(payload: CreateGameRequest) -> Dict[str, Any]:
    """Create a game through the manager and return its summary.

    ``created_by`` falls back to ``"host"`` when the request leaves it empty.

    Raises:
        HTTPException: 400 when ``max_players`` is outside 2-4.
    """
    if not 2 <= payload.max_players <= 4:
        raise HTTPException(status_code=400, detail="max_players must be 2-4")
    owner = payload.created_by or "host"
    game = manager.create_game(payload.name, payload.max_players, created_by=owner)
    return _to_game_summary(game).model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/join")
def join_game(game_id: str, payload: JoinGameRequest) -> Dict[str, Any]:
    """Seat a human user in the first free slot of a lobby game.

    Joining is idempotent: a user already seated (matched by ``user_id``) gets
    the current summary back unchanged, even after the game has started.

    Raises:
        HTTPException: 404 for an unknown game. 400 when the game has left the
            lobby, the username is empty or already used by another seat, or
            every slot is occupied.
    """
    runtime = _get_runtime(game_id)
    game = runtime.game
    if _find_slot(game, lambda s: s.get("user_id") == payload.user_id):
        return _to_game_summary(game).model_dump()
    # The engine's player list is fixed by start_game; a seat filled afterwards
    # would show a player who is not actually in the game.
    if game.status != "lobby":
        raise HTTPException(status_code=400, detail="Game already started")
    # A None name leaves the seat looking free, and an empty name is not counted
    # as a player by start_game.
    if not payload.username:
        raise HTTPException(status_code=400, detail="Username required")
    # Turns, actors and AI seats are all matched by name, so names must be unique.
    if _find_slot(game, lambda s: s.get("name") == payload.username):
        raise HTTPException(status_code=400, detail="Username already taken")
    open_slot = _find_slot(game, lambda s: s.get("name") is None)
    if not open_slot:
        raise HTTPException(status_code=400, detail="No available slots")
    open_slot.update({
        "name": payload.username,
        "user_id": payload.user_id,
        "ready": True,
    })
    manager.save_game(game)
    return _to_game_summary(game).model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/leave")
def leave_game(game_id: str, payload: JoinGameRequest) -> Dict[str, Any]:
    """Free the seat held by ``payload.user_id`` and return the game summary.

    Every seat field is reset to its empty value. If the user holds no seat,
    nothing is saved. No game-status check is performed.
    """
    game = _get_runtime(game_id).game
    seat = _find_slot(game, lambda s: s.get("user_id") == payload.user_id)
    if seat:
        vacated = {
            "name": None,
            "user_id": None,
            "ready": False,
            "is_ai": False,
            "ai_kind": None,
            "ai_model": None,
            "color": None,
        }
        seat.update(vacated)
        manager.save_game(game)
    return _to_game_summary(game).model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/add_ai")
def add_ai(game_id: str, payload: AddAIRequest) -> Dict[str, Any]:
    """Fill the first free seat of a lobby game with an AI player.

    The AI is named ``AI-<n>`` for ``ai_type == "random"`` or ``Model-<n>`` for
    ``ai_type == "model"`` (case-insensitive), using the smallest ``n`` not
    already taken by a seat name.

    Raises:
        HTTPException: 404 for an unknown game. 400 when the game has already
            started, no seat is free, or the AI type is unknown.
    """
    runtime = _get_runtime(game_id)
    game = runtime.game
    # start_game fixes the engine's player list; a seat added later would never
    # be part of the turn order.
    if game.status != "lobby":
        raise HTTPException(status_code=400, detail="Game already started")
    open_slot = _find_slot(game, lambda s: s.get("name") is None)
    if not open_slot:
        raise HTTPException(status_code=400, detail="No available slots")
    ai_type = payload.ai_type.lower()
    if ai_type not in {"random", "model"}:
        raise HTTPException(status_code=400, detail="Unknown AI type")
    name_base = "AI" if ai_type == "random" else "Model"
    existing = {slot.get("name") for slot in game.slots.get("slots", []) if slot.get("name")}
    suffix = 1
    while f"{name_base}-{suffix}" in existing:
        suffix += 1
    open_slot.update({
        "name": f"{name_base}-{suffix}",
        "is_ai": True,
        "ai_kind": ai_type,
        "ai_model": payload.model_name,
        "ready": True,
    })
    manager.save_game(game)
    return _to_game_summary(game).model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/start")
async def start_game(game_id: str) -> Dict[str, Any]:
    """Move a lobby game into play and let AI seats take any opening turns.

    Seated players (slots with a truthy ``name``) are passed to the engine in
    seat order. Each seat gets the color the engine receives for that position,
    so seat colors and engine colors agree even when empty seats sit between
    players.

    Raises:
        HTTPException: 404 for an unknown game. 400 when the game is not in the
            lobby, fewer than two players are seated, or two seats share a name.
    """
    runtime = _get_runtime(game_id)
    game = runtime.game
    if game.status != "lobby":
        raise HTTPException(status_code=400, detail="Game already started")
    slots = game.slots.get("slots", [])
    seated = [slot for slot in slots if slot.get("name")]
    names = [slot["name"] for slot in seated]
    if len(names) < 2:
        raise HTTPException(status_code=400, detail="Not enough players")
    # Turn checks and AI lookup identify players by name, so names must be unique.
    if len(set(names)) != len(names):
        raise HTTPException(status_code=400, detail="Duplicate player names")
    colors = ["red", "blue", "orange", "white"][: len(names)]
    # Build the engine before touching the persisted game, so a failure here
    # leaves the lobby intact instead of producing a "running" game with no engine.
    runtime.env = CatanEnv(GameConfig(player_names=names, colors=colors, seed=game.seed))
    for slot, color in zip(seated, colors):
        slot["color"] = color
    game.slots["slots"] = slots
    game.status = "running"
    manager.save_game(game)
    await _run_ai_turns(runtime)
    return _build_state(game, runtime).model_dump()
|
|
|
|
|
|
@app.get("/games/{game_id}")
def game_state(game_id: str) -> Dict[str, Any]:
    """Return the serialized state of one game (404 if unknown) as built by ``_build_state``."""
    runtime = _get_runtime(game_id)
    snapshot = _build_state(runtime.game, runtime)
    return snapshot.model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/action")
async def apply_action(game_id: str, payload: ActionRequest) -> Dict[str, Any]:
    """Apply one action for a human player, then let AI seats act.

    ``DISCARD`` skips the turn check but may only name the acting player. Every
    other action requires the actor to be the current player. On END_TURN all
    open trade offers expire. When the engine reports the game done, the game
    is marked finished with its winner.

    Raises:
        HTTPException: 404 for an unknown game. 400 when the game is not
            running, the action type is unknown, or the engine rejects the
            action. 403 for an out-of-turn action or a discard for another
            player.
    """
    runtime = _get_runtime(game_id)
    game = runtime.game
    if game.status != "running":
        raise HTTPException(status_code=400, detail="Game not running")
    try:
        action_type = ActionType(payload.action.type)
    except ValueError:
        # Client input error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Unknown action type: {payload.action.type}") from None
    action = Action(type=action_type, payload=payload.action.payload or {})
    actor = payload.actor
    current = runtime.env.game.current_player.name
    if action.type == ActionType.DISCARD:
        # Discards bypass the turn check (presumably because they can be owed by
        # non-current players), but only for oneself.
        if action.payload.get("player") != actor:
            raise HTTPException(status_code=403, detail="Discard only for self")
    elif actor != current:
        raise HTTPException(status_code=403, detail="Not your turn")
    _, _, done, info = runtime.env.step(action)
    if info.get("invalid"):
        raise HTTPException(status_code=400, detail=info.get("error", "Invalid action"))
    manager.record_event(game.id, actor, action, applied=True)
    if action.type == ActionType.END_TURN:
        _expire_trades(game.id)
    if done:
        game.status = "finished"
        game.winner = runtime.env.game.winner
        manager.save_game(game)
    await _run_ai_turns(runtime)
    return _build_state(game, runtime).model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/trade/offer")
async def offer_trade(game_id: str, payload: TradeOfferRequest) -> Dict[str, Any]:
    """Create a player-to-player trade offer from the current player.

    The offer is persisted and logged as an unapplied TRADE_PLAYER event. The AI
    loop is then run, which may answer it.
    NOTE(review): the returned schema is built from the ``trade`` object created
    here; whether it reflects an AI's answer depends on the manager returning
    shared instances. Confirm.

    Raises:
        HTTPException: 404 for an unknown game. 400 when the game is not running
            or the dice have not been rolled. 403 when the offerer is not the
            current player.
    """
    runtime = _get_runtime(game_id)
    game = runtime.game
    if game.status != "running":
        raise HTTPException(status_code=400, detail="Game not running")
    engine_game = runtime.env.game
    if payload.from_player != engine_game.current_player.name:
        raise HTTPException(status_code=403, detail="Only current player can offer trades")
    if not engine_game.has_rolled:
        raise HTTPException(status_code=400, detail="Roll dice before trading")
    trade = manager.create_trade_offer(
        game.id,
        payload.from_player,
        payload.to_player,
        payload.offer,
        payload.request,
    )
    offer_event = Action(
        ActionType.TRADE_PLAYER,
        {"trade_id": trade.id, "offer": payload.offer, "request": payload.request},
    )
    manager.record_event(game.id, payload.from_player, offer_event, applied=False)
    await _run_ai_turns(runtime)
    return _trade_to_schema(trade).model_dump()
|
|
|
|
|
|
@app.post("/games/{game_id}/trade/{trade_id}/respond")
def respond_trade(game_id: str, trade_id: str, payload: TradeRespondRequest) -> Dict[str, Any]:
    """Accept or decline an open trade offer on behalf of ``payload.player``.

    On acceptance the exchange is executed in the engine *before* the offer is
    marked accepted. If the engine rejects the exchange, the offer stays open and
    nothing is recorded.

    Returns:
        ``{"status": "accepted"}`` or ``{"status": "declined"}``.

    Raises:
        HTTPException: 404 for an unknown game or no open trade with *trade_id*.
            400 when the game is not running, the responder is the offerer, or
            the engine rejects the exchange. 403 when the offer is addressed to
            another player.
    """
    runtime = _get_runtime(game_id)
    if runtime.game.status != "running":
        raise HTTPException(status_code=400, detail="Game not running")
    trade = next((t for t in manager.list_trade_offers(game_id, status="open") if t.id == trade_id), None)
    if not trade:
        raise HTTPException(status_code=404, detail="Trade not found")
    if trade.to_player and trade.to_player != payload.player:
        raise HTTPException(status_code=403, detail="Not target player")
    if payload.player == trade.from_player:
        raise HTTPException(status_code=400, detail="Cannot accept own trade")

    exchange: Optional[Action] = None
    if payload.accept:
        exchange = Action(ActionType.TRADE_PLAYER, {
            "target": trade.from_player,
            "offer": trade.offer,
            "request": trade.request,
        })
        _, _, _, info = runtime.env.step(exchange)
        if info.get("invalid"):
            raise HTTPException(status_code=400, detail=info.get("error", "Invalid trade"))

    trade.status = "accepted" if payload.accept else "declined"
    manager.update_trade_offer(trade)
    manager.record_event(
        game_id,
        payload.player,
        Action(ActionType.TRADE_PLAYER, {"trade_id": trade.id, "accept": payload.accept}),
        applied=False,
    )
    if exchange is not None:
        manager.record_event(game_id, trade.from_player, exchange, applied=True)
    return {"status": trade.status}
|
|
|
|
|
|
@app.post("/games/{game_id}/advance")
async def advance_ai(game_id: str) -> Dict[str, Any]:
    """Run the AI loop (``_run_ai_turns``) for this game and return the resulting state."""
    runtime = _get_runtime(game_id)
    await _run_ai_turns(runtime)
    snapshot = _build_state(runtime.game, runtime)
    return snapshot.model_dump()
|