Snake World Model - Pixel-Space EDM v2

A DIAMOND-style pixel-space EDM (Elucidating the Design Space of Diffusion-Based Generative Models) implementation for world modeling in the Snake game environment. This model predicts future game frames conditioned on previous frames and player actions.

See the GitHub repository for the training and play code: https://github.com/roastedpotato66/snake-world-modeling

Model Details

Architecture

The model uses a DIAMOND-style UNet architecture working directly in pixel space:

  • Input: (B, 15, 64, 64) - noisy target frame (3 channels) + 4 context frames (12 channels)
  • Output: (B, 3, 64, 64) - denoised next frame prediction
  • Base dimensions: 128 channels, 512 condition dimension
  • Resolution: 64×64 RGB images
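
The channel layout above can be illustrated without any tensor library. This is a minimal sketch, assuming the noisy target comes first and the context frames follow from oldest to newest; the ordering used in the repository may differ, and `make_frame` and `build_input` are hypothetical helper names.

```python
from collections import deque

CONTEXT_FRAMES = 4   # number of previous frames, as stated above
CHANNELS = 3         # RGB
SIZE = 64            # frames are 64x64

def make_frame(value):
    """Hypothetical helper: a constant RGB frame as CHANNELS planes of SIZE x SIZE."""
    return [[[value] * SIZE for _ in range(SIZE)] for _ in range(CHANNELS)]

def build_input(noisy_target, context):
    """Concatenate channel-wise: 3 target channels + 4 * 3 context channels.

    The ordering (target first, then context oldest to newest) is an assumption.
    """
    planes = list(noisy_target)
    for frame in context:
        planes.extend(frame)
    return planes

# A rolling buffer keeps only the most recent CONTEXT_FRAMES frames.
context = deque((make_frame(i) for i in range(CONTEXT_FRAMES)), maxlen=CONTEXT_FRAMES)
model_input = build_input(make_frame(0.5), context)
print(len(model_input), len(model_input[0]), len(model_input[0][0]))

# After a prediction, the new frame is appended and the oldest one drops out.
context.append(make_frame(99))
```

The per-sample shape is 15 planes of 64×64, matching the (B, 15, 64, 64) input once a batch dimension is added.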

Key Components

  1. UNet Encoder-Decoder

    • Encoder: 64×64 → 32×32 → 16×16 → 8×8 (3 downsampling blocks)
    • Bottleneck: Self-attention at 8×8 resolution for global reasoning
    • Decoder: 8×8 → 16×16 → 32×32 → 64×64 (3 upsampling blocks)
    • Skip connections between encoder and decoder
  2. Adaptive Group Normalization

    • Conditions normalization on combined action + noise level embeddings
    • Enables strong action conditioning throughout the network
  3. EDM Preconditioning

    • Preconditioned network output: c_skip * x_noisy + c_out * network(x)
    • Keeps training stable and allows sampling with very few denoising steps (3 in this model)
  4. Frame Stacking

    • 4 previous frames concatenated channel-wise (12 channels total)
    • Provides temporal context for prediction
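
The preconditioning formula listed under EDM Preconditioning can be sketched with the coefficient definitions from the EDM paper. This is an illustration only: it works on scalars rather than image tensors, `SIGMA_DATA = 0.5` is the paper's default and is not confirmed by this model card, and the input scaling `c_in` and noise embedding `c_noise` are part of the standard EDM formulation, assumed here because the card shows only `c_skip` and `c_out`.

```python
import math

SIGMA_DATA = 0.5  # assumption: EDM paper default, not confirmed by this model card

def edm_coefficients(sigma, sigma_data=SIGMA_DATA):
    """Return (c_skip, c_out, c_in, c_noise) for noise level sigma."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / math.sqrt(total)
    c_in = 1.0 / math.sqrt(total)
    c_noise = math.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise

def denoise(network, x_noisy, sigma):
    """Preconditioned output: c_skip * x_noisy + c_out * network(...)."""
    c_skip, c_out, c_in, c_noise = edm_coefficients(sigma)
    return c_skip * x_noisy + c_out * network(c_in * x_noisy, c_noise)

# At low noise the skip path dominates; at high noise the network output does.
for sigma in (0.01, 0.5, 50.0):
    c_skip, c_out, _, _ = edm_coefficients(sigma)
    print(f"sigma={sigma}: c_skip={c_skip:.4f}, c_out={c_out:.4f}")
```

Because `c_skip` approaches 1 as the noise level falls, the network only has to predict a small correction near the end of sampling, which is one reason a short schedule can work.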

Training Metrics

  • Best Epoch: 34
  • Best Validation MSE: 0.000137
  • Training Loss (final): 0.000798
  • CFG Difference: 0.003674

Usage

The simplest way to get started is to download model.pt, create an output/ folder, and move the file into it. Before playing, you will need to generate some data (1k is enough) for initialization. Follow the instructions in the GitHub repository's README.

Interactive Play

Use the play script provided in the GitHub repository:

python scripts/play_pixel_edm.py \
    --model_path model.pt \
    --data_dir data/images \
    --cfg_scale 2.0 \
    --steps 3
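
For intuition about `--cfg_scale`, the following is a minimal sketch of the usual classifier-free guidance combination. How the play script applies the flag internally is an assumption here, and `apply_cfg` is a hypothetical name; the sketch only shows the standard formula on flat lists of pixel values.

```python
def apply_cfg(cond, uncond, cfg_scale):
    """Standard classifier-free guidance: uncond + scale * (cond - uncond).

    cond   -- prediction made with the player action
    uncond -- prediction made with the action dropped
    """
    return [u + cfg_scale * (c - u) for c, u in zip(cond, uncond)]

cond = [0.8, 0.2]
uncond = [0.6, 0.4]
print(apply_cfg(cond, uncond, 1.0))  # scale 1 reproduces the conditional prediction
print(apply_cfg(cond, uncond, 2.0))  # scale 2 pushes further along the action's effect
```

A scale above 1 exaggerates the difference that the action makes, which tends to strengthen action following at the cost of some fidelity.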

Controls:

  • WASD or Arrow Keys - Move snake
  • R - Reset with new random seed from data
  • ESC - Quit

Training Details

Dataset

  • Format: 64×64 RGB images from Snake game
  • Context: 4 consecutive frames
  • Actions: One-hot encoded [UP, DOWN, LEFT, RIGHT]
  • Special events: Death and eating events weighted 5× for balanced training
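
The action encoding and event weighting above can be sketched in plain Python. The index order follows the list as written, [UP, DOWN, LEFT, RIGHT], and the helper names are hypothetical; the repository's data loader may implement this differently.

```python
import random

ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT"]  # order as listed above
EVENT_WEIGHT = 5.0                          # weight for death/eating samples

def one_hot(action):
    """Encode an action name as a 4-element one-hot vector."""
    return [1.0 if a == action else 0.0 for a in ACTIONS]

def sample_weights(is_special):
    """Give special-event samples 5x the weight of ordinary ones."""
    return [EVENT_WEIGHT if s else 1.0 for s in is_special]

def draw_indices(is_special, k, seed=0):
    """Draw k sample indices with replacement, proportionally to the weights."""
    rng = random.Random(seed)
    weights = sample_weights(is_special)
    return rng.choices(range(len(is_special)), weights=weights, k=k)

print(one_hot("LEFT"))
flags = [False, False, True, False]  # third sample is a death or eating event
print(draw_indices(flags, k=10))
```

In PyTorch, per-sample weights like these would typically be passed to `torch.utils.data.WeightedRandomSampler`.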

Training Configuration

  • Loss: Weighted MSE with EDM loss weighting
  • Optimizer: AdamW (lr=1e-4, weight_decay=1e-4)
  • Scheduler: Cosine annealing over 40 epochs
  • Mixed precision: BF16 AMP
  • EMA: Exponential moving average (decay=0.9999)
  • Weighted sampling: 5× weight for death/eating events
  • CFG dropout: 30% for classifier-free guidance
  • Batch size: 512
  • Denoising steps: 3 (Euler method)
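
A 3-step Euler sampler can be sketched as follows. This is an illustration under stated assumptions: the noise schedule uses the EDM paper defaults (`sigma_min=0.002`, `sigma_max=80.0`, `rho=7.0`), which this model card does not confirm, the state is a single scalar instead of an image, and the denoiser is a toy stand-in that always returns the same value.

```python
import math

def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, followed by 0.

    The default values are the EDM paper's, assumed here for illustration.
    """
    ramp = [i / max(n_steps - 1, 1) for i in range(n_steps)]
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return [(hi + t * (lo - hi)) ** rho for t in ramp] + [0.0]

def euler_sample(denoiser, x, sigmas):
    """One Euler step per adjacent pair of noise levels."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, sigma)) / sigma   # slope of the probability-flow ODE
        x = x + (sigma_next - sigma) * d       # step to the next noise level
    return x

sigmas = karras_sigmas(3)      # 3 denoising steps, so 4 noise levels including 0
x0 = sigmas[0] * 0.7           # stand-in for Gaussian noise scaled by sigma_max
result = euler_sample(lambda x, sigma: 0.25, x0, sigmas)
print(sigmas)
print(result)
```

Because the last noise level is 0, the final Euler step lands on the denoiser's own output, so with this toy denoiser the result is approximately 0.25.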