drbh committed
Commit b992f14 · 1 Parent(s): 281d8ba

feat: adjust reference impl

build/torch27-cxx11-cu118-x86_64-linux/yamoe/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from ._ops import ops
+ from . import reference
+
+ gather = ops.gather
+ scatter = ops.scatter
+ sort = ops.sort
+ bincount_cumsum = ops.bincount_cumsum
+ batch_mm = ops.batch_mm
+ experts = ops.experts
+
+ __all__ = [
+     "gather",
+     "scatter",
+     "sort",
+     "bincount_cumsum",
+     "batch_mm",
+     "experts",
+     # Export the reference implementation
+     "reference",
+ ]
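
As a quick orientation (a minimal sketch, not part of the commit itself): the module surface bound above can be smoke-tested as below, assuming the built yamoe wheel for this torch/CUDA variant is importable. It only touches names visible in this file.

import yamoe

# Kernel ops re-exported from the compiled extension (yamoe._ops):
for name in ("gather", "scatter", "sort", "bincount_cumsum", "batch_mm", "experts"):
    assert hasattr(yamoe, name), name

# The pure-PyTorch reference implementation ships alongside the kernels:
from yamoe import reference
assert hasattr(reference, "GptOssExperts")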
build/torch28-cxx11-cu129-x86_64-linux/yamoe/reference.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ import torch.nn as nn
+
+ class GptOssExperts(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.intermediate_size = config.intermediate_size
+         self.num_experts = config.num_local_experts
+         self.hidden_size = config.hidden_size
+         self.expert_dim = self.intermediate_size
+         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+         self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+         self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+         self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+         self.alpha = 1.702
+         self.limit = 7.0
+
+     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+         """
+         When training, it is more efficient to just loop over the experts and compute the output for each expert
+         separately, as otherwise memory use would explode.
+
+         For inference we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+
+         Args:
+             hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+             router_indices (torch.Tensor): (batch_size * token_num, top_k)
+             routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+         Returns:
+             torch.Tensor
+         """
+         batch_size = hidden_states.shape[0]
+         hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+         num_experts = routing_weights.shape[1]
+         if self.training:
+             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+             with torch.no_grad():
+                 expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+                 expert_mask = expert_mask.permute(2, 1, 0)
+                 # sum over the top_k and sequence length dims to find which experts
+                 # are hit this time around
+                 expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+             for expert_idx in expert_hitted[:]:
+                 with torch.no_grad():
+                     _, token_idx = torch.where(expert_mask[expert_idx[0]])
+                 current_state = hidden_states[token_idx]
+                 gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+                 gate = gate.clamp(min=None, max=self.limit)
+                 up = up.clamp(min=-self.limit, max=self.limit)
+                 glu = gate * torch.sigmoid(gate * self.alpha)
+                 gated_output = (up + 1) * glu
+                 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+                 weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+                 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+             next_states = next_states.view(batch_size, -1, self.hidden_size)
+         else:
+             hidden_states = hidden_states.repeat(num_experts, 1)
+             hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+             gate = gate.clamp(min=None, max=self.limit)
+             up = up.clamp(min=-self.limit, max=self.limit)
+             glu = gate * torch.sigmoid(gate * self.alpha)
+             next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+             next_states = next_states + self.down_proj_bias[..., None, :]
+             next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+             next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+             next_states = next_states.sum(dim=0)
+         return next_states
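
For context, a hedged usage sketch of the reference module on the inference path; the SimpleNamespace stands in for a real model config and the random initialization is purely illustrative. Only the three attribute names read by __init__ (hidden_size, intermediate_size, num_local_experts) are fixed by the code above.

import torch
from types import SimpleNamespace
from yamoe.reference import GptOssExperts

# Hypothetical stand-in config; __init__ only reads these three attributes.
config = SimpleNamespace(hidden_size=64, intermediate_size=128, num_local_experts=4)

experts = GptOssExperts(config)
# Parameters are allocated with torch.empty, so initialize them before use.
for p in experts.parameters():
    torch.nn.init.normal_(p, std=0.02)
experts.eval()  # select the batched torch.bmm inference path

batch_size, seq_len = 2, 5
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
# Dense per-token weights over all experts; the eval path ignores router_indices.
routing_weights = torch.softmax(
    torch.randn(batch_size * seq_len, config.num_local_experts), dim=-1
)

with torch.no_grad():
    out = experts(hidden_states, routing_weights=routing_weights)
assert out.shape == (batch_size, seq_len, config.hidden_size)

The training path instead expects router_indices of shape (num_tokens, top_k), loops only over the experts that were actually hit, and accumulates each expert's weighted output with index_add_.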