Snake World Model - Pixel-Space EDM v2
A DIAMOND-style pixel-space EDM (Elucidating the Design Space of Diffusion-Based Generative Models) implementation for world modeling in the Snake game environment. This model predicts future game frames conditioned on previous frames and player actions.
See the GitHub repository for training and play code: https://github.com/roastedpotato66/snake-world-modeling
Model Details
Architecture
The model uses a DIAMOND-style UNet architecture working directly in pixel space:
- Input: (B, 15, 64, 64) - noisy target frame (3 channels) + 4 context frames (12 channels)
- Output: (B, 3, 64, 64) - denoised next frame prediction
- Base dimensions: 128 channels, 512 condition dimension
- Resolution: 64×64 RGB images
Key Components
UNet Encoder-Decoder
- Encoder: 64×64 → 32×32 → 16×16 → 8×8 (3 downsampling blocks)
- Bottleneck: Self-attention at 8×8 resolution for global reasoning
- Decoder: 8×8 → 16×16 → 32×32 → 64×64 (3 upsampling blocks)
- Skip connections between encoder and decoder
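The resolution schedule above can be traced with a few lines of Python. This is only an illustrative sketch: `resolution_schedule` is a hypothetical helper, not a function from the repository, and it assumes each downsampling block halves the spatial size.

```python
def resolution_schedule(input_size=64, num_down=3):
    """Return the spatial sizes seen by the encoder and the decoder."""
    # Assumption: every downsampling block halves height and width.
    encoder = [input_size // (2 ** i) for i in range(num_down + 1)]
    decoder = encoder[::-1]  # the decoder mirrors the encoder
    return encoder, decoder


encoder, decoder = resolution_schedule()
bottleneck_tokens = encoder[-1] ** 2  # positions attended over at 8x8
full_res_tokens = encoder[0] ** 2     # positions if attention ran at 64x64
print(encoder, decoder, bottleneck_tokens, full_res_tokens)
```

Self-attention cost grows quadratically with the number of positions, which is one plausible reason to restrict it to the 8×8 bottleneck (64 positions) rather than the full 64×64 grid (4096 positions).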
Adaptive Group Normalization
- Conditions normalization on combined action + noise level embeddings
- Enables strong action conditioning throughout the network
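As a rough illustration of the idea, adaptive group normalization can be sketched as a plain group norm followed by a per-channel scale and shift predicted from the condition vector. The sketch below uses NumPy to stay self-contained rather than the framework the repository uses. The function name, the linear projection (`w`, `b`), the group count of 32, and the `1 + scale` convention are all assumptions, not details confirmed by the repository.

```python
import numpy as np


def adaptive_group_norm(x, cond, w, b, num_groups=32, eps=1e-5):
    """Group-normalize x, then modulate it with a condition-dependent scale/shift.

    x:    (B, C, H, W) feature map
    cond: (B, D) combined action + noise-level embedding
    w, b: hypothetical linear projection from D to 2*C, shapes (D, 2C) and (2C,)
    """
    B, C, H, W = x.shape
    # Normalize each group of channels to zero mean and unit variance.
    grouped = x.reshape(B, num_groups, -1)
    mean = grouped.mean(axis=2, keepdims=True)
    var = grouped.var(axis=2, keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + eps)).reshape(B, C, H, W)
    # Project the condition to one scale and one shift per channel.
    scale, shift = np.split(cond @ w + b, 2, axis=1)
    return normed * (1.0 + scale[:, :, None, None]) + shift[:, :, None, None]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 128, 8, 8))
cond = rng.standard_normal((2, 512))
w = 0.01 * rng.standard_normal((512, 256))
out = adaptive_group_norm(x, cond, w, np.zeros(256))
print(out.shape)
```

Because every normalization layer receives the same condition vector, the action can influence features at every depth of the network instead of only at the input.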
EDM Preconditioning
- Preconditioned network output: c_skip * x_noisy + c_out * network(x)
- Stable training with very few denoising steps (only 3 steps needed)
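For reference, the coefficients in the preconditioned output can be written out using the formulas from the EDM paper. This is a minimal sketch under stated assumptions: `sigma_data=0.5` is the EDM paper's default and may differ from this model's setting, the input scaling `c_in` and noise embedding `c_noise` come from the paper rather than from this card, and `network` is a placeholder callable. Concatenating the context frames to the network input is omitted here.

```python
import math


def edm_coefficients(sigma, sigma_data=0.5):
    """EDM preconditioning coefficients for noise level sigma (Karras et al.)."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total                # weight on the noisy input
    c_out = sigma * sigma_data / math.sqrt(total)   # weight on the network output
    c_in = 1.0 / math.sqrt(total)                   # scales the network input
    c_noise = math.log(sigma) / 4.0                 # noise-level conditioning value
    return c_skip, c_out, c_in, c_noise


def precondition(network, x_noisy, sigma, sigma_data=0.5):
    """Denoised estimate: c_skip * x_noisy + c_out * network(...)."""
    c_skip, c_out, c_in, c_noise = edm_coefficients(sigma, sigma_data)
    return c_skip * x_noisy + c_out * network(c_in * x_noisy, c_noise)


for sigma in (0.01, 0.5, 10.0):
    print(sigma, edm_coefficients(sigma))
```

At low noise `c_skip` approaches 1, so the output is mostly the input itself. At high noise `c_skip` approaches 0 and the network predicts the clean frame almost directly, which keeps the training target well scaled at every noise level.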
Frame Stacking
- 4 previous frames concatenated channel-wise (12 channels total)
- Provides temporal context for prediction
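A minimal NumPy sketch of how the 15-channel network input could be assembled from the noisy target and the stacked context. The helper name `build_network_input` is hypothetical, and the channel order (noisy target first, then context frames oldest to newest) is an assumption based on the order in which the input is described above.

```python
import numpy as np


def build_network_input(noisy_target, context_frames):
    """Concatenate the noisy target with the context frames along channels.

    noisy_target:   (B, 3, H, W)
    context_frames: (B, T, 3, H, W), with T = 4 for this model
    returns:        (B, 3 + 3*T, H, W)
    """
    B, T, C, H, W = context_frames.shape
    # Fold the time axis into the channel axis: (B, T, 3, H, W) -> (B, 3T, H, W).
    stacked = context_frames.reshape(B, T * C, H, W)
    return np.concatenate([noisy_target, stacked], axis=1)


noisy = np.zeros((2, 3, 64, 64), dtype=np.float32)
context = np.ones((2, 4, 3, 64, 64), dtype=np.float32)
net_in = build_network_input(noisy, context)
print(net_in.shape)
```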
Training Metrics
- Best Epoch: 34
- Best Validation MSE: 0.000137
- Training Loss (final): 0.000798
- CFG Difference: 0.003674
Usage
The simplest way is to download model.pt directly, create an output/ folder, and move the file into it. Before playing, you will need to generate some data (1k is enough) for initialization. Follow the instructions in the GitHub repository's README.
Interactive Play
Use the play script provided in the GitHub repository:
python scripts/play_pixel_edm.py \
--model_path model.pt \
--data_dir data/images \
--cfg_scale 2.0 \
--steps 3
Controls:
- WASD or Arrow Keys - Move snake
- R - Reset with new random seed from data
- ESC - Quit
Training Details
Dataset
- Format: 64×64 RGB images from Snake game
- Context: 4 consecutive frames
- Actions: One-hot encoded [UP, DOWN, LEFT, RIGHT]
- Special events: Death and eating events weighted 5× for balanced training
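The action encoding and event weighting described above might look roughly like this in NumPy. The function names are hypothetical, and drawing indices with replacement in proportion to the weights is an assumption about how the 5× weighting is applied, not the repository's confirmed method.

```python
import numpy as np

ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT"]  # order taken from the list above


def one_hot_action(name):
    """Encode an action name as a length-4 one-hot vector."""
    vec = np.zeros(len(ACTIONS), dtype=np.float32)
    vec[ACTIONS.index(name)] = 1.0
    return vec


def sample_weights(is_special, special_weight=5.0):
    """Weight death/eating transitions 5x, everything else 1x."""
    return np.where(np.asarray(is_special, dtype=bool), special_weight, 1.0)


def draw_indices(weights, n, rng):
    """Draw n dataset indices with probability proportional to weight."""
    return rng.choice(len(weights), size=n, replace=True, p=weights / weights.sum())


weights = sample_weights([False, False, True, False])
indices = draw_indices(weights, 8, np.random.default_rng(0))
print(one_hot_action("LEFT"), weights, indices)
```

In this four-sample toy dataset the special transition is drawn with probability 5/8, compared with 1/4 under uniform sampling.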
Training Configuration
- Loss: Weighted MSE with EDM loss weighting
- Optimizer: AdamW (lr=1e-4, weight_decay=1e-4)
- Scheduler: Cosine annealing over 40 epochs
- Mixed precision: BF16 AMP
- EMA: Exponential moving average (decay=0.9999)
- Weighted sampling: 5× weight for death/eating events
- CFG dropout: 30% for classifier-free guidance
- Batch size: 512
- Denoising steps: 3 (Euler method)
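To show how the 3-step Euler sampler and classifier-free guidance fit together, here is a self-contained NumPy sketch. Several parts are assumptions: the noise schedule uses the EDM paper's defaults (`sigma_min=0.002`, `sigma_max=80`, `rho=7`), which may not match this model; `denoiser` is a placeholder for the preconditioned network; and passing `action=None` stands in for whatever null action the CFG dropout uses during training.

```python
import numpy as np


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, plus a final 0."""
    ramp = np.linspace(0.0, 1.0, n_steps)
    min_inv, max_inv = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho
    return np.append(sigmas, 0.0)


def cfg_denoise(denoiser, x, sigma, context, action, cfg_scale):
    """Extrapolate from the unconditional toward the conditional prediction."""
    cond = denoiser(x, sigma, context, action)
    uncond = denoiser(x, sigma, context, None)  # None = dropped action (assumed)
    return uncond + cfg_scale * (cond - uncond)


def euler_sample(denoiser, context, action, n_steps=3, cfg_scale=2.0, rng=None):
    """Generate the next frame with n_steps Euler steps."""
    if rng is None:
        rng = np.random.default_rng()
    sigmas = karras_sigmas(n_steps)
    shape = (context.shape[0], 3) + context.shape[-2:]
    x = rng.standard_normal(shape) * sigmas[0]  # start from pure noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = cfg_denoise(denoiser, x, sigma, context, action, cfg_scale)
        d = (x - denoised) / sigma          # slope of the probability-flow ODE
        x = x + (sigma_next - sigma) * d    # Euler step to the next noise level
    return x


def toy_denoiser(x, sigma, context, action):
    # Stand-in model: predicts ones when conditioned, zeros when not.
    return np.ones_like(x) if action is not None else np.zeros_like(x)


context = np.zeros((2, 12, 64, 64))
action = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
frame = euler_sample(toy_denoiser, context, action, rng=np.random.default_rng(0))
print(frame.shape)
```

Because the last step moves to a noise level of 0, the final frame equals the guided denoiser output from that step. With `cfg_scale` set to 1.0 the guidance reduces to the plain conditional prediction.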