drbh commited on
Commit
906ccdf
·
1 Parent(s): cf66620

fix: improve example output and allow pushing build

Browse files
Files changed (2) hide show
  1. .gitignore +0 -1
  2. gpt_oss_match.py +13 -5
.gitignore CHANGED
@@ -15,5 +15,4 @@ csrc/batch_mm.cu
15
  torch-ext/yamoe/*.abi3.so
16
 
17
  build-ext
18
- build
19
  exploration
 
15
  torch-ext/yamoe/*.abi3.so
16
 
17
  build-ext
 
18
  exploration
gpt_oss_match.py CHANGED
@@ -49,8 +49,15 @@ def main():
49
  ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
50
  new_moe_cls = yamoe.Yamoe
51
 
52
- batch_size, seq_len, hidden_dim = 4, 1024, 2880
53
- num_experts, top_k = 8, 2
 
 
 
 
 
 
 
54
 
55
  config = type("Config", (), {})()
56
  config.hidden_size = hidden_dim
@@ -59,6 +66,7 @@ def main():
59
  config.num_experts_per_tok = top_k
60
  ref_moe = ref_moe_cls(config)
61
 
 
62
  print(ref_moe)
63
 
64
  for p in ref_moe.parameters():
@@ -91,8 +99,8 @@ def main():
91
 
92
  benchmark_forward(ref_moe, x, tag="reference", warmup=10, iters=20)
93
 
94
- # Switch to YAMOE-backed forward
95
- print("\nYAMOE-backed Implementation")
96
  ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
97
  ref_moe._routing_weights_buffer = None
98
  ref_moe._batch_indices_buffer = None
@@ -117,7 +125,7 @@ def main():
117
  f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
118
  )
119
 
120
- benchmark_forward(ref_moe, x, tag="yamoe-backed", warmup=10, iters=20)
121
 
122
 
123
  if __name__ == "__main__":
 
49
  ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
50
  new_moe_cls = yamoe.Yamoe
51
 
52
+ batch_size, seq_len, hidden_dim = 1, 1024, 2880
53
+ num_experts, top_k = 32, 4
54
+
55
+ print("\nInput parameters:")
56
+ print(f" Batch size: {batch_size}")
57
+ print(f" Seq len: {seq_len}")
58
+ print(f" Hidden dim: {hidden_dim}")
59
+ print(f" Num experts: {num_experts}")
60
+ print(f" Top-k: {top_k}")
61
 
62
  config = type("Config", (), {})()
63
  config.hidden_size = hidden_dim
 
66
  config.num_experts_per_tok = top_k
67
  ref_moe = ref_moe_cls(config)
68
 
69
+ print("\nModel:")
70
  print(ref_moe)
71
 
72
  for p in ref_moe.parameters():
 
99
 
100
  benchmark_forward(ref_moe, x, tag="reference", warmup=10, iters=20)
101
 
102
+ # Switch to YAMOE forward
103
+ print("\nYAMOE Implementation")
104
  ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
105
  ref_moe._routing_weights_buffer = None
106
  ref_moe._batch_indices_buffer = None
 
125
  f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
126
  )
127
 
128
+ benchmark_forward(ref_moe, x, tag="yamoe", warmup=10, iters=20)
129
 
130
 
131
  if __name__ == "__main__":