catan/catan/ml/trainers.py

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import List, Sequence

import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from ..sdk import Action, ActionType, CatanEnv
from .agents import PolicyAgent, finalize_action
from .encoding import encode_action, encode_observation
from .policies import PolicyNetwork

def _ensure_tensor(data, device: str = "cpu") -> torch.Tensor:
    return torch.tensor(data, dtype=torch.float32, device=device)

class ReinforcementLearningTrainer:
    """Trains a PolicyNetwork on CatanEnv with the REINFORCE policy gradient."""

    def __init__(
        self,
        env: CatanEnv,
        hidden_layers: Sequence[int] = (256, 256),
        lr: float = 3e-4,
        gamma: float = 0.99,
        device: str = "cpu",
    ) -> None:
        self.env = env
        self.gamma = gamma
        self.device = device
        # Infer input/output sizes from a sample observation and a dummy action.
        obs_dim = encode_observation(env.observe()).shape[0]
        dummy_action = Action(ActionType.END_TURN, {})
        action_dim = encode_action(dummy_action).shape[0]
        self.policy = PolicyNetwork(obs_dim, action_dim, hidden_layers).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def select_action(self, observation, legal_actions: List[Action]) -> tuple[Action, torch.Tensor]:
        obs_tensor = _ensure_tensor(observation, self.device)
        action_vectors = torch.stack(
            [_ensure_tensor(encode_action(action), self.device) for action in legal_actions]
        )
        # Score every legal action, then sample from the softmax distribution
        # so exploration is driven by the policy itself.
        logits = self.policy(obs_tensor, action_vectors)
        probs = torch.softmax(logits, dim=0)
        dist = torch.distributions.Categorical(probs=probs)
        idx = dist.sample()
        action = legal_actions[idx.item()]
        log_prob = dist.log_prob(idx)
        return finalize_action(self.env, action, None), log_prob

    def run_episode(self) -> tuple[List[torch.Tensor], List[float]]:
        observation = self.env.reset()
        log_probs: List[torch.Tensor] = []
        rewards: List[float] = []
        done = False
        while not done:
            legal_actions = self.env.legal_actions()
            action, log_prob = self.select_action(observation, legal_actions)
            observation, reward, done, _ = self.env.step(action)
            log_probs.append(log_prob)
            rewards.append(float(reward))
        return log_probs, rewards

    def train(self, episodes: int = 50) -> List[float]:
        history: List[float] = []
        for _ in range(episodes):
            log_probs, rewards = self.run_episode()
            returns = self._discounted_returns(rewards)
            # REINFORCE loss: minimize -sum(log_prob * normalized_return).
            loss = 0.0
            for log_prob, ret in zip(log_probs, returns):
                loss = loss - log_prob * ret
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            history.append(sum(rewards))
        return history

    def _discounted_returns(self, rewards: List[float]) -> List[float]:
        accumulator = 0.0
        returns: List[float] = []
        for reward in reversed(rewards):
            accumulator = reward + self.gamma * accumulator
            returns.append(accumulator)
        returns.reverse()
        # Normalize returns to zero mean and unit variance to reduce
        # the variance of the policy-gradient estimate.
        mean = np.mean(returns) if returns else 0.0
        std = np.std(returns) if returns else 1.0
        std = std if std > 1e-6 else 1.0
        return [(ret - mean) / std for ret in returns]

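# A minimal usage sketch (not from the original file): it assumes CatanEnv() can be
# constructed without arguments and that rewards are per-step floats, as the code
# above suggests. The episode count and learning rate are illustrative only.
#
#     env = CatanEnv()
#     trainer = ReinforcementLearningTrainer(env, lr=1e-4, gamma=0.99)
#     episode_rewards = trainer.train(episodes=100)
#     print("mean reward over last 10 episodes:",
#           sum(episode_rewards[-10:]) / 10)
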
@dataclass
class EvolutionConfig:
    population_size: int = 20
    elite_fraction: float = 0.2
    mutation_scale: float = 0.1
    episodes_per_candidate: int = 1

class EvolutionStrategyTrainer:
    """Trains a PolicyNetwork on CatanEnv with a simple elitist evolution strategy."""

    def __init__(
        self,
        env: CatanEnv,
        hidden_layers: Sequence[int] = (256, 256),
        device: str = "cpu",
        config: EvolutionConfig | None = None,
    ) -> None:
        self.env = env
        self.device = device
        obs_dim = encode_observation(env.observe()).shape[0]
        action_dim = encode_action(Action(ActionType.END_TURN, {})).shape[0]
        self.policy = PolicyNetwork(obs_dim, action_dim, hidden_layers).to(device)
        self.config = config or EvolutionConfig()
        # Candidates are flat parameter vectors of this length.
        self.vector_length = len(parameters_to_vector(self.policy.parameters()))

    def evaluate(self, weights: torch.Tensor) -> float:
        # Load the candidate's weights into the shared policy, then roll out
        # greedily (stochastic=False) over a fixed set of seeded episodes so
        # candidates are compared on the same boards.
        vector_to_parameters(weights, self.policy.parameters())
        agent = PolicyAgent(self.policy, device=self.device, stochastic=False)
        total_reward = 0.0
        for episode in range(self.config.episodes_per_candidate):
            observation = self.env.reset(seed=episode)
            done = False
            while not done:
                legal_actions = self.env.legal_actions()
                action = agent.choose_action(self.env, legal_actions)
                observation, reward, done, _ = self.env.step(action)
                total_reward += reward
        return total_reward

    def evolve(self, generations: int = 20) -> torch.Tensor:
        # Seed the population with small random parameter vectors.
        population = [
            torch.randn(self.vector_length, device=self.device) * 0.1
            for _ in range(self.config.population_size)
        ]
        best_vector = population[0]
        best_score = float("-inf")
        elite_size = max(1, int(self.config.population_size * self.config.elite_fraction))
        for _ in range(generations):
            scores = []
            for candidate in population:
                score = self.evaluate(candidate.clone())
                scores.append((score, candidate))
                if score > best_score:
                    best_score = score
                    best_vector = candidate.clone()
            # Keep the top elite_size candidates, then refill the population
            # with Gaussian-mutated copies of randomly chosen elites.
            scores.sort(key=lambda item: item[0], reverse=True)
            elites = [candidate for _, candidate in scores[:elite_size]]
            new_population = elites.copy()
            while len(new_population) < self.config.population_size:
                parent = random.choice(elites)
                child = parent + torch.randn_like(parent) * self.config.mutation_scale
                new_population.append(child)
            population = new_population
        # Leave the best-scoring weights loaded in self.policy before returning.
        vector_to_parameters(best_vector, self.policy.parameters())
        return best_vector
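
# A minimal usage sketch (not from the original file): assumes CatanEnv() constructs
# without arguments; the config values are illustrative, not tuned. Because evolve()
# loads the best weights back into self.policy, the trainer's policy can be wrapped
# in a PolicyAgent directly afterwards.
#
#     env = CatanEnv()
#     config = EvolutionConfig(population_size=50, mutation_scale=0.05)
#     trainer = EvolutionStrategyTrainer(env, config=config)
#     best_weights = trainer.evolve(generations=30)
#     agent = PolicyAgent(trainer.policy, device="cpu", stochastic=False)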