#!/usr/bin/env python3
"""
Maaza Nano-Orchestrator 9.6M - Custom Transformer Architecture

TRUE 9.6M parameters from scratch.

Architecture:
    vocab_size: 8000
    hidden_size: 320
    num_layers: 7
    num_heads: 8
    intermediate_size: 620
    max_position: 512

Param breakdown:
    Embeddings:    8000 × 320 = 2.56M
    Per layer:     ~1.0M × 7  = 7.04M
    Output (tied): 0
    Total:         ~9.60M ✓
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple


# ============================================================================
# MODEL CONFIGURATION
# ============================================================================

@dataclass
class MaazaNanoConfig:
    """Configuration for Maaza Nano 9.6M model.

    Param breakdown for 9.6M target:
        Embeddings:    8000 × 320 = 2.56M
        Per layer:     ~1.0M × 7  = 7.04M
        Output (tied): 0
        Total:         ~9.60M ✓
    """
    vocab_size: int = 8000
    hidden_size: int = 320
    num_layers: int = 7
    num_heads: int = 8            # 320 / 8 = 40 dim per head
    intermediate_size: int = 620  # tuned for the 9.6M target
    max_position_embeddings: int = 512
    dropout: float = 0.1
    layer_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    tie_word_embeddings: bool = True

    def __post_init__(self):
        assert self.hidden_size % self.num_heads == 0
        self.head_dim = self.hidden_size // self.num_heads


# ============================================================================
# ROTARY POSITIONAL EMBEDDING (RoPE)
# ============================================================================

class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - efficient positional encoding."""

    def __init__(self, dim: int, max_position: int = 512, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position = max_position
        self.theta = theta

        # Precompute inverse frequencies
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute cos/sin for all positions
        self._build_cache(max_position)

    def _build_cache(self, seq_len: int):
        # Build positions on the same device as inv_freq so the einsum works
        # after the module has been moved to GPU
        positions = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())
        # Remember the new cache size so we don't rebuild on every forward
        self.max_position = seq_len

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if seq_len > self.max_position:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors."""
    # Expand cos/sin for batch and heads
    cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
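
# ----------------------------------------------------------------------------
# Sanity-check sketch (an illustrative addition, not part of the original
# model code): RoPE only rotates pairs of dimensions, so it must preserve the
# per-position L2 norm of q and k. The helper below is a minimal check using
# the functions above; its name and shapes ([batch, heads, seq, head_dim],
# mirroring MaazaAttention) are assumptions for illustration.
def _rope_sanity_check(seq_len: int = 16, head_dim: int = 40) -> None:
    rope = RotaryEmbedding(dim=head_dim, max_position=seq_len)
    q = torch.randn(1, 1, seq_len, head_dim)
    k = torch.randn(1, 1, seq_len, head_dim)
    cos, sin = rope(seq_len)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    # A rotation preserves length: norms should match up to float tolerance
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
    assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-4)
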
# ============================================================================
# ATTENTION
# ============================================================================

class MaazaAttention(nn.Module):
    """Multi-head attention with RoPE."""

    def __init__(self, config: MaazaNanoConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_position=config.max_position_embeddings,
            theta=config.rope_theta
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE
        cos, sin = self.rotary_emb(seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Attention scores
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply causal mask (always needed for autoregressive generation)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device),
            diagonal=1
        )
        attn_weights = attn_weights.masked_fill(causal_mask, float("-inf"))

        # Also apply padding mask if provided
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax and dropout
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output


# ============================================================================
# FEEDFORWARD
# ============================================================================

class MaazaMLP(nn.Module):
    """Feedforward network with SwiGLU activation."""

    def __init__(self, config: MaazaNanoConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU activation: silu(gate(x)) * up(x), then project back down
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))


# ============================================================================
# TRANSFORMER LAYER
# ============================================================================

class MaazaLayer(nn.Module):
    """Single transformer layer with pre-norm."""

    def __init__(self, config: MaazaNanoConfig):
        super().__init__()
        self.attention = MaazaAttention(config)
        self.mlp = MaazaMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Pre-norm attention with residual
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.attention(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        # Pre-norm MLP with residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
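
# ----------------------------------------------------------------------------
# Back-of-the-envelope per-layer parameter count (illustrative, derived from
# the module definitions above with the default 320/620 config):
#
#   attention:    4 × (320 × 320)               = 409,600
#   SwiGLU MLP:   2 × (320 × 620) + (620 × 320) = 595,200
#   2 LayerNorms: 2 × (320 + 320)               =   1,280
#   per layer                                   = 1,006,080  (~1.0M × 7 = 7.04M)
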
# ============================================================================
# FULL MODEL
# ============================================================================

class MaazaNanoModel(nn.Module):
    """Maaza Nano 9.6M - Tool Routing Transformer."""

    def __init__(self, config: MaazaNanoConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([
            MaazaLayer(config) for _ in range(config.num_layers)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Output projection (tied with embeddings if configured)
        if config.tie_word_embeddings:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        # Get embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Convert a 0/1 padding mask to an additive mask (0 keeps, large
        # negative drops)
        if attention_mask is not None:
            attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(hidden_states.dtype).min

        # Pass through layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Final norm
        hidden_states = self.norm(hidden_states)

        # Compute logits
        if self.lm_head is not None:
            logits = self.lm_head(hidden_states)
        else:
            # Tied embeddings
            logits = F.linear(hidden_states, self.embed_tokens.weight)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )

        return {"loss": loss, "logits": logits, "hidden_states": hidden_states}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 128,
        temperature: float = 0.3,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        eos_token_id: int = 3,  # <|eos|>
    ) -> torch.Tensor:
        """Generate tokens autoregressively."""
        self.eval()

        for _ in range(max_new_tokens):
            # Forward pass
            outputs = self(input_ids)
            logits = outputs["logits"][:, -1, :]  # Last-token logits

            # Apply repetition penalty, handling both signs: plain division
            # would *boost* tokens with negative logits instead of penalizing
            if repetition_penalty != 1.0:
                for i in range(input_ids.size(0)):
                    for token_id in set(input_ids[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            # Apply temperature
            logits = logits / temperature

            # Top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold,
            # shifting right so the first token past top_p is still kept
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            for i in range(logits.size(0)):
                indices_to_remove = sorted_indices[i, sorted_indices_to_remove[i]]
                logits[i, indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Check for EOS
            if (next_token == eos_token_id).all():
                break

        return input_ids

    def count_parameters(self) -> dict:
        """Count parameters by component."""
        counts = {}

        # Embeddings
        counts["embeddings"] = sum(p.numel() for p in self.embed_tokens.parameters())

        # Layers
        counts["layers"] = sum(p.numel() for layer in self.layers for p in layer.parameters())

        # Final norm
        counts["norm"] = sum(p.numel() for p in self.norm.parameters())

        # LM head (if not tied)
        if self.lm_head is not None:
            counts["lm_head"] = sum(p.numel() for p in self.lm_head.parameters())
        else:
            counts["lm_head"] = 0  # Tied with embeddings

        counts["total"] = sum(counts.values())
        return counts
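
# ----------------------------------------------------------------------------
# Illustrative training-step sketch (an assumption about how the model would
# be driven, not the project's actual train.py): forward() shifts labels by
# one position internally, so passing labels=input_ids yields a standard
# next-token language-modeling loss. Optimizer and lr are placeholders.
def _example_train_step() -> float:
    model = MaazaNanoModel(MaazaNanoConfig())
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    input_ids = torch.randint(0, 8000, (2, 32))
    outputs = model(input_ids, labels=input_ids)
    outputs["loss"].backward()  # backprop through all 7 layers
    optimizer.step()
    optimizer.zero_grad()
    return outputs["loss"].item()
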
def create_model(vocab_size: int = 8000) -> MaazaNanoModel:
    """Create Maaza Nano 9.6M model."""
    config = MaazaNanoConfig(vocab_size=vocab_size)
    model = MaazaNanoModel(config)
    return model


if __name__ == "__main__":
    print("=" * 60)
    print("Maaza Nano-Orchestrator 9.6M - Architecture Verification")
    print("=" * 60)

    # Create model
    model = create_model()

    # Count parameters
    param_counts = model.count_parameters()
    print("\nParameter counts:")
    for name, count in param_counts.items():
        print(f"  {name:20s}: {count:,} ({count/1e6:.2f}M)")

    # Target verification
    total = param_counts["total"]
    target = 9.6e6
    diff = abs(total - target) / target * 100
    print("\nTarget: 9.6M")
    print(f"Actual: {total/1e6:.2f}M")
    print(f"Diff:   {diff:.1f}%")

    if diff < 10:
        print("\n✓ Model architecture verified!")
    else:
        print(f"\n✗ Model size off by {diff:.1f}% - adjust config")

    # Test forward pass
    print("\n" + "=" * 60)
    print("Testing forward pass...")
    batch = torch.randint(0, 8000, (2, 64))  # Batch of 2, seq len 64
    outputs = model(batch)
    print(f"  Input shape:  {batch.shape}")
    print(f"  Output shape: {outputs['logits'].shape}")
    print(f"  Hidden shape: {outputs['hidden_states'].shape}")

    # Test generation
    print("\nTesting generation...")
    prompt = torch.randint(0, 8000, (1, 10))
    generated = model.generate(prompt, max_new_tokens=20)
    print(f"  Prompt length:    {prompt.shape[1]}")
    print(f"  Generated length: {generated.shape[1]}")

    # Memory estimates
    print("\n" + "=" * 60)
    print("Memory estimates:")
    fp32_bytes = total * 4
    fp16_bytes = total * 2
    int8_bytes = total * 1
    print(f"  FP32: {fp32_bytes / 1e6:.1f} MB")
    print(f"  FP16: {fp16_bytes / 1e6:.1f} MB")
    print(f"  INT8: {int8_bytes / 1e6:.1f} MB (quantized)")

    print("\n✓ Model ready for training!")
    print("Next step: python train.py --dataset dataset.jsonl --tokenizer tokenizer.json")
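
# ----------------------------------------------------------------------------
# Example session (illustrative; the parameter total is arithmetic from the
# default config, while generation length varies with the random seed and can
# be shorter than prompt + max_new_tokens if <|eos|> is sampled early):
#
#   >>> model = create_model()
#   >>> model.count_parameters()["total"]  # 2.56M emb + 7.04M layers + 640 norm
#   9603200
#   >>> prompt = torch.randint(0, 8000, (1, 10))
#   >>> model.generate(prompt, max_new_tokens=8).shape  # up to 10 + 8 tokens
#   torch.Size([1, 18])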