drbh
commited on
Commit
·
906ccdf
1
Parent(s):
cf66620
fix: improve example output and allow pushing build
Browse files- .gitignore +0 -1
- gpt_oss_match.py +13 -5
.gitignore
CHANGED
|
@@ -15,5 +15,4 @@ csrc/batch_mm.cu
|
|
| 15 |
torch-ext/yamoe/*.abi3.so
|
| 16 |
|
| 17 |
build-ext
|
| 18 |
-
build
|
| 19 |
exploration
|
|
|
|
| 15 |
torch-ext/yamoe/*.abi3.so
|
| 16 |
|
| 17 |
build-ext
|
|
|
|
| 18 |
exploration
|
gpt_oss_match.py
CHANGED
|
@@ -49,8 +49,15 @@ def main():
|
|
| 49 |
ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
|
| 50 |
new_moe_cls = yamoe.Yamoe
|
| 51 |
|
| 52 |
-
batch_size, seq_len, hidden_dim =
|
| 53 |
-
num_experts, top_k =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
config = type("Config", (), {})()
|
| 56 |
config.hidden_size = hidden_dim
|
|
@@ -59,6 +66,7 @@ def main():
|
|
| 59 |
config.num_experts_per_tok = top_k
|
| 60 |
ref_moe = ref_moe_cls(config)
|
| 61 |
|
|
|
|
| 62 |
print(ref_moe)
|
| 63 |
|
| 64 |
for p in ref_moe.parameters():
|
|
@@ -91,8 +99,8 @@ def main():
|
|
| 91 |
|
| 92 |
benchmark_forward(ref_moe, x, tag="reference", warmup=10, iters=20)
|
| 93 |
|
| 94 |
-
# Switch to YAMOE
|
| 95 |
-
print("\nYAMOE
|
| 96 |
ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
|
| 97 |
ref_moe._routing_weights_buffer = None
|
| 98 |
ref_moe._batch_indices_buffer = None
|
|
@@ -117,7 +125,7 @@ def main():
|
|
| 117 |
f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
|
| 118 |
)
|
| 119 |
|
| 120 |
-
benchmark_forward(ref_moe, x, tag="yamoe
|
| 121 |
|
| 122 |
|
| 123 |
if __name__ == "__main__":
|
|
|
|
| 49 |
ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
|
| 50 |
new_moe_cls = yamoe.Yamoe
|
| 51 |
|
| 52 |
+
batch_size, seq_len, hidden_dim = 1, 1024, 2880
|
| 53 |
+
num_experts, top_k = 32, 4
|
| 54 |
+
|
| 55 |
+
print("\nInput parameters:")
|
| 56 |
+
print(f" Batch size: {batch_size}")
|
| 57 |
+
print(f" Seq len: {seq_len}")
|
| 58 |
+
print(f" Hidden dim: {hidden_dim}")
|
| 59 |
+
print(f" Num experts: {num_experts}")
|
| 60 |
+
print(f" Top-k: {top_k}")
|
| 61 |
|
| 62 |
config = type("Config", (), {})()
|
| 63 |
config.hidden_size = hidden_dim
|
|
|
|
| 66 |
config.num_experts_per_tok = top_k
|
| 67 |
ref_moe = ref_moe_cls(config)
|
| 68 |
|
| 69 |
+
print("\nModel:")
|
| 70 |
print(ref_moe)
|
| 71 |
|
| 72 |
for p in ref_moe.parameters():
|
|
|
|
| 99 |
|
| 100 |
benchmark_forward(ref_moe, x, tag="reference", warmup=10, iters=20)
|
| 101 |
|
| 102 |
+
# Switch to YAMOE forward
|
| 103 |
+
print("\nYAMOE Implementation")
|
| 104 |
ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
|
| 105 |
ref_moe._routing_weights_buffer = None
|
| 106 |
ref_moe._batch_indices_buffer = None
|
|
|
|
| 125 |
f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
|
| 126 |
)
|
| 127 |
|
| 128 |
+
benchmark_forward(ref_moe, x, tag="yamoe", warmup=10, iters=20)
|
| 129 |
|
| 130 |
|
| 131 |
if __name__ == "__main__":
|