|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import sys |
|
|
import time |
|
|
from kernels import get_kernel, get_local_kernel |
|
|
from pathlib import Path |
|
|
|
|
|
# How to obtain the yamoe kernel module:
#   1 -> import the in-tree extension package from ./torch-ext
#   2 -> load a locally built kernel from the ./result directory
#   3 -> fetch the published kernel "drbh/yamoe" (pinned to v0.2.0) via get_kernel
load_method = 3

if load_method == 1:
    sys.path.insert(0, "./torch-ext")
    import yamoe
elif load_method == 2:
    yamoe = get_local_kernel(Path("result"), "yamoe")
elif load_method == 3:
    yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

# Fixed seed so parameter init and the random input are reproducible across runs.
torch.manual_seed(42)
|
|
|
|
|
|
|
|
def zero_grads(m):
    """Detach accumulated gradients from every parameter of *m*.

    Gradients are set to ``None`` (rather than zero-filled) so the
    buffers can be reclaimed between iterations.
    """
    for param in m.parameters():
        if param.grad is None:
            continue
        param.grad = None
|
|
|
|
|
|
|
|
def benchmark_backward(model, x, tag: str, iters: int = 10, warmup: int = 10):
    """Time a full training step (forward + loss + backward) of *model*.

    NOTE(review): despite the name and the printed label, the timed region
    covers the forward pass, the mean-loss reduction, AND backward — the
    reported number is per-step latency, not backward-only time.

    Args:
        model: module to benchmark (may return a tensor or a tuple whose
            first element is the output tensor).
        x: input tensor; it is detached and cloned, so the caller's tensor
           and its autograd state are untouched.
        tag: label used in the printed report.
        iters: number of timed iterations (mean is reported).
        warmup: number of untimed iterations run first.

    Returns:
        Mean wall-clock milliseconds per iteration.
    """
    # Fresh leaf tensor so gradient state cannot leak between benchmark runs.
    x_local = x.detach().clone().requires_grad_(True)
    x_local.retain_grad()

    # Untimed warmup: lets kernel caches/autotuning and the CUDA allocator settle.
    for _ in range(warmup):
        out = model(x_local)
        out = out[0] if isinstance(out, tuple) else out
        loss = out.mean()
        zero_grads(model)
        if x_local.grad is not None:
            x_local.grad = None
        loss.backward()

    # Track peak memory over the timed region only; sync so the clock starts
    # after all warmup CUDA work has drained.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        out = model(x_local)
        out = out[0] if isinstance(out, tuple) else out
        loss = out.mean()
        zero_grads(model)
        if x_local.grad is not None:
            x_local.grad = None
        loss.backward()
    # Wait for queued CUDA work before stopping the clock.
    torch.cuda.synchronize()
    bwd_ms = (time.perf_counter() - start) * 1e3 / iters
    peak_mem = torch.cuda.max_memory_allocated() / 1024**3  # bytes -> GiB

    print(f"[{tag}] backward: {bwd_ms:.2f} ms | peak mem: {peak_mem:.2f} GB")
    return bwd_ms
|
|
|
|
|
|
|
|
def print_grad_norms(m, x, tag): |
|
|
print(f"\n[{tag}] Gradient norms:") |
|
|
xg = x.grad.norm().item() if x.grad is not None else 0.0 |
|
|
print(f" input grad: {xg:.6f}") |
|
|
for name, p in m.named_parameters(): |
|
|
if p.grad is None: |
|
|
print(f" {name}: None") |
|
|
else: |
|
|
print(f" {name}: {p.grad.norm().item():.6f}") |
|
|
|
|
|
|
|
|
def _exercise_backward(model, x, tag):
    """Warm up *model*, run one instrumented forward/backward, then benchmark it.

    This is the sequence previously duplicated verbatim for the reference and
    yamoe-backed sections: 3 untimed warmup steps, one probe step that prints
    shapes/output stats/gradient norms, then the timed benchmark.
    """
    print(" Warming up...")
    x_warmup = x.detach().requires_grad_(True)
    for _ in range(3):
        out = model(x_warmup)
        out = out[0] if isinstance(out, tuple) else out
        loss = out.mean()
        zero_grads(model)
        if x_warmup.grad is not None:
            x_warmup.grad = None
        loss.backward()
    torch.cuda.synchronize()

    # Single instrumented step on a fresh leaf tensor.
    x_probe = x.detach().requires_grad_(True)
    x_probe.retain_grad()
    output = model(x_probe)
    out = output[0] if isinstance(output, tuple) else output
    print(f" Input shape: {x_probe.shape}")
    print(f" Output shape: {out.shape}")
    print(
        f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
    )

    loss = out.mean()
    zero_grads(model)
    loss.backward()
    print_grad_norms(model, x_probe, tag)

    benchmark_backward(model, x, tag=tag, warmup=10, iters=20)


def main():
    """Benchmark the backward pass of the reference GPT-OSS MoE MLP, then of
    the same module with its forward monkey-patched to the Yamoe kernel, and
    print gradient norms for both so numerics can be compared by eye.
    """
    ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
    new_moe_cls = yamoe.Yamoe

    batch_size, seq_len, hidden_dim = 4, 1024, 2880
    num_experts, top_k = 8, 2

    # Minimal stand-in config carrying only the fields the MLP constructor reads.
    config = type("Config", (), {})()
    config.hidden_size = hidden_dim
    config.intermediate_size = hidden_dim
    config.num_local_experts = num_experts
    config.num_experts_per_tok = top_k
    ref_moe = ref_moe_cls(config)

    print(ref_moe)

    # Deterministic re-init: xavier-uniform for weight matrices, zeros for
    # 1-D parameters (biases / router vectors).
    for p in ref_moe.parameters():
        if p.dim() > 1:
            torch.nn.init.xavier_uniform_(p)
        else:
            torch.nn.init.zeros_(p)

    x = torch.randn(batch_size, seq_len, hidden_dim, device="cuda")
    ref_moe = ref_moe.cuda()

    print("\nReference Implementation Backward")
    _exercise_backward(ref_moe, x, "reference")

    print("\nYAMOE-backed Implementation Backward")
    # Rebind Yamoe.forward onto the SAME module instance so both runs share
    # identical weights, then attach the extra state Yamoe.forward expects.
    ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
    ref_moe._routing_weights_buffer = None
    ref_moe._batch_indices_buffer = None
    ref_moe._last_batch_seq = None
    ref_moe._last_num_experts = None
    ref_moe.enable_router_grads = True
    ref_moe.num_experts = num_experts
    ref_moe.top_k = top_k

    _exercise_backward(ref_moe, x, "yamoe-backed")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()
|
|
|