"""
Evaluation script for the Chess Challenge.
This script evaluates a trained chess model by playing games against
Stockfish and computing ELO ratings.
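
Example usage (illustrative paths; run from the repository root):

    python -m src.evaluate --model_path ./checkpoints/chess-model \
        --mode both --stockfish_level 1 --n_games 50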
"""
from __future__ import annotations
import argparse
import math
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
@dataclass
class GameResult:
"""Result of a single game."""
moves: List[str]
result: str # "1-0", "0-1", or "1/2-1/2"
model_color: str # "white" or "black"
termination: str # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
illegal_move_count: int
class ChessEvaluator:
"""
Evaluator for chess models.
This class handles playing games between a trained model and Stockfish,
tracking results, and computing ELO ratings.
"""
def __init__(
self,
model,
tokenizer,
stockfish_path: Optional[str] = None,
stockfish_level: int = 1,
max_retries: int = 3,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the evaluator.
Args:
model: The trained chess model.
tokenizer: The chess tokenizer.
stockfish_path: Path to Stockfish executable.
stockfish_level: Stockfish skill level (0-20).
max_retries: Maximum retries for illegal moves.
device: Device to run the model on.
"""
self.model = model.to(device)
self.tokenizer = tokenizer
self.max_retries = max_retries
self.device = device
# Initialize Stockfish
try:
import chess
import chess.engine
self.chess = chess
if stockfish_path is None:
# Try common paths
import shutil
stockfish_path = shutil.which("stockfish")
if stockfish_path:
self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
self.engine.configure({"Skill Level": stockfish_level})
else:
print("WARNING: Stockfish not found. Install it for full evaluation.")
self.engine = None
except ImportError:
raise ImportError(
"python-chess is required for evaluation. "
"Install it with: pip install python-chess"
)
def __del__(self):
"""Clean up Stockfish engine."""
if hasattr(self, 'engine') and self.engine:
self.engine.quit()
def _convert_board_to_moves(self, board) -> str:
"""Convert board move history to model input format."""
moves = []
temp_board = self.chess.Board()
for move in board.move_stack:
# Get piece and color
color = "W" if temp_board.turn == self.chess.WHITE else "B"
piece = temp_board.piece_at(move.from_square)
piece_letter = piece.symbol().upper() if piece else "P"
# Get squares
from_sq = self.chess.square_name(move.from_square)
to_sq = self.chess.square_name(move.to_square)
move_str = f"{color}{piece_letter}{from_sq}{to_sq}"
# Add promotion
if move.promotion:
move_str += f"={self.chess.piece_symbol(move.promotion).upper()}"
# Add capture suffix
if temp_board.is_capture(move):
move_str += "(x)"
# Add check/checkmate suffix
temp_board.push(move)
if temp_board.is_checkmate():
move_str = move_str.replace("(x)", "(x+*)") if "(x)" in move_str else move_str + "(+*)"
elif temp_board.is_check():
move_str = move_str.replace("(x)", "(x+)") if "(x)" in move_str else move_str + "(+)"
# Handle castling
if piece_letter == "K" and abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
if to_sq[0] == 'g': # Kingside
move_str = move_str.split("(")[0] + "(o)"
else: # Queenside
move_str = move_str.split("(")[0] + "(O)"
moves.append(move_str)
return " ".join(moves)
def _is_separator_token(self, token_str: str) -> bool:
"""
Check if a token represents a separator (whitespace, EOS, etc.).
This allows the evaluator to work with different tokenization strategies:
- Move-level tokenizers: each move is one token, no separators generated
- Character-level tokenizers: space character marks end of move
- BPE/subword tokenizers: may generate partial moves
Args:
token_str: The decoded token string.
Returns:
True if this token indicates end of a move.
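
        Examples (assuming a simple whitespace/EOS separator scheme):
            " "      -> True   (pure whitespace)
            "e4 "    -> True   (token with trailing space)
            "WPe2e4" -> False  (complete move token, no separator)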
"""
# Check for EOS token
if hasattr(self.tokenizer, 'eos_token') and token_str == self.tokenizer.eos_token:
return True
# Check for whitespace (space, newline, etc.)
if token_str.strip() == "" and len(token_str) > 0:
return True
# Check if the token ends with whitespace (some tokenizers include trailing space)
if token_str != token_str.rstrip():
return True
return False
def _generate_move_tokens(
self,
input_ids: torch.Tensor,
temperature: float = 0.7,
top_k: int = 10,
max_tokens: int = 20,
) -> str:
"""
Generate tokens until a separator (whitespace/EOS) is encountered.
This method supports different tokenization strategies:
- For move-level tokenizers: generates one token (the full move)
- For character/subword tokenizers: generates until whitespace
Args:
input_ids: The input token IDs.
temperature: Sampling temperature.
top_k: Top-k filtering parameter.
max_tokens: Maximum tokens to generate for a single move.
Returns:
The generated move string (without trailing separator).
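
        Example (illustrative, character-level tokenizer): the model might
        sample "W", "P", "e", "2", "e", "4", then a space; generation stops
        at the space and the method returns "WPe2e4".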
"""
generated_tokens = []
current_ids = input_ids.clone()
for _ in range(max_tokens):
with torch.no_grad():
outputs = self.model(input_ids=current_ids)
logits = outputs.logits[:, -1, :] / temperature
# Apply top-k filtering
if top_k > 0:
top_k_values = torch.topk(logits, min(top_k, logits.size(-1)))[0]
indices_to_remove = logits < top_k_values[..., -1, None]
logits[indices_to_remove] = float("-inf")
# Sample
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # Shape: [1, 1]
# Decode the token
token_str = self.tokenizer.decode(next_token[0])
# Check if this is a separator token
if self._is_separator_token(token_str):
break
generated_tokens.append(next_token[0]) # Store [1] tensor
# Append to input for next iteration (next_token is already [1, 1])
current_ids = torch.cat([current_ids, next_token], dim=-1)
# For move-level tokenizers, a single non-separator token is the full move
# We can detect this by checking if the token looks like a complete move
# (starts with W or B, has enough characters for a move)
if len(token_str) >= 6 and token_str[0] in "WB":
break
# Decode all generated tokens together
if generated_tokens:
all_tokens = torch.cat(generated_tokens, dim=0)
move_str = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
return move_str.strip()
return ""
def _get_model_move(
self,
board,
temperature: float = 0.7,
top_k: int = 10,
) -> Tuple[Optional[str], int]:
"""
Get the model's next move prediction.
This method generates tokens until a separator (whitespace/EOS) is produced,
allowing it to work with different tokenization strategies:
- Move-level tokenizers: each move is a single token
- Character-level tokenizers: moves are generated character by character
- BPE/subword tokenizers: moves may be split into subwords
Returns:
Tuple of (UCI move string, number of retries used).
"""
self.model.eval()
# Convert board to input format
moves_str = self._convert_board_to_moves(board)
        # Always prepend BOS; append the move history when there is one
if not moves_str:
input_text = self.tokenizer.bos_token
else:
input_text = self.tokenizer.bos_token + " " + moves_str
# Tokenize
inputs = self.tokenizer(
input_text,
return_tensors="pt",
truncation=True,
max_length=self.model.config.n_ctx - 10, # Leave room for generated tokens
).to(self.device)
# Try to generate a legal move
for retry in range(self.max_retries):
# Generate tokens until separator
move_token = self._generate_move_tokens(
inputs["input_ids"],
temperature=temperature,
top_k=top_k,
)
# Convert to UCI
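            # e.g. "WNg1f3(+)" -> "g1f3"; "WPe7e8=Q" -> "e7e8" + "q" = "e7e8q"
            # (token shapes follow the encoding in _convert_board_to_moves)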
            if len(move_token) >= 6:
                uci_move = move_token[2:6]  # from-square + to-square
# Handle promotion
if "=" in move_token:
promo_idx = move_token.index("=")
uci_move += move_token[promo_idx + 1].lower()
try:
move = self.chess.Move.from_uci(uci_move)
if move in board.legal_moves:
return uci_move, retry
except (ValueError, self.chess.InvalidMoveError):
pass
return None, self.max_retries
def _get_stockfish_move(self, board, time_limit: float = 0.1) -> str:
"""Get Stockfish's move."""
if self.engine is None:
raise RuntimeError("Stockfish engine not initialized")
result = self.engine.play(board, self.chess.engine.Limit(time=time_limit))
return result.move.uci()
def play_game(
self,
model_color: str = "white",
max_moves: int = 200,
temperature: float = 0.7,
) -> GameResult:
"""
Play a single game between the model and Stockfish.
Args:
model_color: "white" or "black".
            max_moves: Maximum number of plies (half-moves) before the game
                is scored as a draw.
temperature: Sampling temperature for model.
Returns:
GameResult with the game details.
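
        Example (illustrative, given an initialized evaluator):
            game = evaluator.play_game(model_color="black")
            print(game.result, game.termination, len(game.moves))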
"""
board = self.chess.Board()
moves = []
illegal_move_count = 0
model_is_white = model_color == "white"
while not board.is_game_over() and len(moves) < max_moves:
is_model_turn = (board.turn == self.chess.WHITE) == model_is_white
if is_model_turn:
# Model's turn
uci_move, retries = self._get_model_move(board, temperature)
illegal_move_count += retries
if uci_move is None:
# Model couldn't find a legal move
return GameResult(
moves=moves,
result="0-1" if model_is_white else "1-0",
model_color=model_color,
termination="illegal_move",
illegal_move_count=illegal_move_count + 1,
)
move = self.chess.Move.from_uci(uci_move)
else:
# Stockfish's turn
if self.engine:
uci_move = self._get_stockfish_move(board)
move = self.chess.Move.from_uci(uci_move)
else:
# Random move if no engine
move = random.choice(list(board.legal_moves))
board.push(move)
moves.append(move.uci())
# Determine result
if board.is_checkmate():
if board.turn == self.chess.WHITE:
result = "0-1" # Black wins
else:
result = "1-0" # White wins
termination = "checkmate"
elif board.is_stalemate():
result = "1/2-1/2"
termination = "stalemate"
elif board.is_insufficient_material():
result = "1/2-1/2"
termination = "insufficient_material"
elif board.can_claim_draw():
result = "1/2-1/2"
termination = "draw_claim"
elif len(moves) >= max_moves:
result = "1/2-1/2"
termination = "max_moves"
else:
result = "1/2-1/2"
termination = "unknown"
return GameResult(
moves=moves,
result=result,
model_color=model_color,
termination=termination,
illegal_move_count=illegal_move_count,
)
def evaluate_legal_moves(
self,
n_positions: int = 1000,
temperature: float = 0.7,
verbose: bool = True,
) -> dict:
"""
Evaluate the model's ability to generate legal moves.
This evaluation only checks if the model generates legal moves,
without playing full games. Useful as a first-pass evaluation.
Args:
n_positions: Number of positions to test.
temperature: Sampling temperature.
verbose: Whether to print progress.
Returns:
Dictionary with legal move statistics.
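
        Example (illustrative):
            stats = evaluator.evaluate_legal_moves(n_positions=100)
            print(f"legal within retries: {stats['legal_rate_with_retry']:.1%}")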
"""
results = {
"total_positions": 0,
"legal_first_try": 0,
"legal_with_retry": 0,
"illegal_all_retries": 0,
"positions": [],
}
# Generate random positions by playing random moves
for i in range(n_positions):
board = self.chess.Board()
# Play random number of moves (5-40) to get varied positions
n_random_moves = random.randint(5, 40)
for _ in range(n_random_moves):
if board.is_game_over():
break
move = random.choice(list(board.legal_moves))
board.push(move)
if board.is_game_over():
continue # Skip terminal positions
results["total_positions"] += 1
# Test model's move generation
uci_move, retries = self._get_model_move(board, temperature)
position_result = {
"fen": board.fen(),
"move_number": len(board.move_stack),
"legal": uci_move is not None,
"retries": retries,
}
results["positions"].append(position_result)
if uci_move is not None:
if retries == 0:
results["legal_first_try"] += 1
else:
results["legal_with_retry"] += 1
else:
results["illegal_all_retries"] += 1
if verbose and (i + 1) % 100 == 0:
legal_rate = (results["legal_first_try"] + results["legal_with_retry"]) / results["total_positions"]
print(f" Positions: {i + 1}/{n_positions} | Legal rate: {legal_rate:.1%}")
# Calculate statistics
total = results["total_positions"]
if total > 0:
results["legal_rate_first_try"] = results["legal_first_try"] / total
results["legal_rate_with_retry"] = (results["legal_first_try"] + results["legal_with_retry"]) / total
results["illegal_rate"] = results["illegal_all_retries"] / total
else:
results["legal_rate_first_try"] = 0
results["legal_rate_with_retry"] = 0
results["illegal_rate"] = 1
return results
def evaluate(
self,
n_games: int = 100,
temperature: float = 0.7,
verbose: bool = True,
) -> dict:
"""
Run a full win-rate evaluation of the model against Stockfish.
Args:
n_games: Number of games to play.
temperature: Sampling temperature.
verbose: Whether to print progress.
Returns:
Dictionary with evaluation metrics.
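
        Example (illustrative; requires a working Stockfish engine):
            results = evaluator.evaluate(n_games=20)
            print(results["win_rate"], results["estimated_elo"])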
"""
results = {
"wins": 0,
"losses": 0,
"draws": 0,
"illegal_moves": 0,
"total_moves": 0,
"games": [],
}
for i in range(n_games):
# Alternate colors
model_color = "white" if i % 2 == 0 else "black"
game = self.play_game(
model_color=model_color,
temperature=temperature,
)
results["games"].append(game)
results["total_moves"] += len(game.moves)
results["illegal_moves"] += game.illegal_move_count
# Count result
if game.result == "1/2-1/2":
results["draws"] += 1
elif (game.result == "1-0" and model_color == "white") or \
(game.result == "0-1" and model_color == "black"):
results["wins"] += 1
else:
results["losses"] += 1
if verbose and (i + 1) % 10 == 0:
print(f" Games: {i + 1}/{n_games} | "
f"W: {results['wins']} L: {results['losses']} D: {results['draws']}")
# Calculate statistics
total = results["wins"] + results["losses"] + results["draws"]
results["win_rate"] = results["wins"] / total if total > 0 else 0
results["draw_rate"] = results["draws"] / total if total > 0 else 0
results["loss_rate"] = results["losses"] / total if total > 0 else 0
total_attempts = results["total_moves"] + results["illegal_moves"]
# Average length counts both legal moves and illegal attempts so early illegal terminations
# don't show as near-zero length games.
results["avg_game_length"] = total_attempts / total if total > 0 else 0
# Illegal move rate: illegal attempts over total attempts
results["illegal_move_rate"] = results["illegal_moves"] / total_attempts if total_attempts > 0 else 0
        # Estimate ELO from the logistic expectation model: expected score
        # s = 1 / (1 + 10^(-d/400)), so d = -400 * log10(1/s - 1).
        # Stockfish skill level 1 is roughly 1350 ELO.
        stockfish_elo = 1350
        if total > 0:
            score = results["wins"] + 0.5 * results["draws"]
            score_ratio = score / total
            if 0 < score_ratio < 1:
                elo_diff = -400 * math.log10(1 / score_ratio - 1)
                results["estimated_elo"] = stockfish_elo + elo_diff
            else:
                # Perfect or zero score: cap the rating difference at +/-400
                results["estimated_elo"] = stockfish_elo + (400 if score_ratio >= 1 else -400)
        else:
            results["estimated_elo"] = None
return results
def load_model_from_hub(model_id: str, device: str = "auto"):
"""
Load a model from the Hugging Face Hub.
Args:
model_id: Model ID on Hugging Face Hub.
device: Device to load the model on.
Returns:
Tuple of (model, tokenizer).
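
    Example (illustrative model ID):
        model, tokenizer = load_model_from_hub("username/chess-model")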
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
# Import to register custom classes
from src.model import ChessConfig, ChessForCausalLM
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
device_map=device,
)
return model, tokenizer
def main():
"""Main evaluation function."""
parser = argparse.ArgumentParser(description="Evaluate a chess model")
parser.add_argument(
"--model_path", type=str, required=True,
help="Path to the model or Hugging Face model ID"
)
parser.add_argument(
"--mode", type=str, default="legal", choices=["legal", "winrate", "both"],
help="Evaluation mode: 'legal' for legal move rate, 'winrate' for games, 'both' for both"
)
parser.add_argument(
"--stockfish_path", type=str, default=None,
help="Path to Stockfish executable"
)
parser.add_argument(
"--stockfish_level", type=int, default=1,
help="Stockfish skill level (0-20)"
)
parser.add_argument(
"--n_positions", type=int, default=500,
help="Number of positions for legal move evaluation"
)
parser.add_argument(
"--n_games", type=int, default=100,
help="Number of games to play for win rate evaluation"
)
parser.add_argument(
"--temperature", type=float, default=0.7,
help="Sampling temperature"
)
args = parser.parse_args()
print("=" * 60)
print("CHESS CHALLENGE - EVALUATION")
print("=" * 60)
# Load model
print(f"\nLoading model from: {args.model_path}")
if "/" in args.model_path and not args.model_path.startswith("."):
# Assume Hugging Face model ID
model, tokenizer = load_model_from_hub(args.model_path)
else:
# Local path
from transformers import AutoModelForCausalLM
from src.tokenizer import ChessTokenizer
from src.model import ChessConfig, ChessForCausalLM
tokenizer = ChessTokenizer.from_pretrained(args.model_path)
model = AutoModelForCausalLM.from_pretrained(args.model_path)
# Create evaluator
print(f"\nSetting up evaluator...")
evaluator = ChessEvaluator(
model=model,
tokenizer=tokenizer,
stockfish_path=args.stockfish_path,
stockfish_level=args.stockfish_level,
)
# Run legal move evaluation
if args.mode in ["legal", "both"]:
print(f"\n" + "=" * 60)
print("PHASE 1: LEGAL MOVE EVALUATION")
print("=" * 60)
print(f"Testing {args.n_positions} random positions...")
legal_results = evaluator.evaluate_legal_moves(
n_positions=args.n_positions,
temperature=args.temperature,
verbose=True,
)
print("\n" + "-" * 40)
print("LEGAL MOVE RESULTS")
print("-" * 40)
print(f" Positions tested: {legal_results['total_positions']}")
print(f" Legal (1st try): {legal_results['legal_first_try']} ({legal_results['legal_rate_first_try']:.1%})")
print(f" Legal (with retry): {legal_results['legal_first_try'] + legal_results['legal_with_retry']} ({legal_results['legal_rate_with_retry']:.1%})")
print(f" Always illegal: {legal_results['illegal_all_retries']} ({legal_results['illegal_rate']:.1%})")
# Run win rate evaluation
if args.mode in ["winrate", "both"]:
print(f"\n" + "=" * 60)
print("PHASE 2: WIN RATE EVALUATION")
print("=" * 60)
print(f"Playing {args.n_games} games against Stockfish (Level {args.stockfish_level})...")
winrate_results = evaluator.evaluate(
n_games=args.n_games,
temperature=args.temperature,
verbose=True,
)
print("\n" + "-" * 40)
print("WIN RATE RESULTS")
print("-" * 40)
print(f" Wins: {winrate_results['wins']}")
print(f" Losses: {winrate_results['losses']}")
print(f" Draws: {winrate_results['draws']}")
print(f"\n Win Rate: {winrate_results['win_rate']:.1%}")
print(f" Draw Rate: {winrate_results['draw_rate']:.1%}")
print(f" Loss Rate: {winrate_results['loss_rate']:.1%}")
print(f"\n Avg Game Length: {winrate_results['avg_game_length']:.1f} moves")
print(f" Illegal Move Rate: {winrate_results['illegal_move_rate']:.2%}")
        if winrate_results["estimated_elo"] is not None:
print(f"\n Estimated ELO: {winrate_results['estimated_elo']:.0f}")
print("\n" + "=" * 60)
print("EVALUATION COMPLETE")
print("=" * 60)
if __name__ == "__main__":
main()