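"""Back-of-the-envelope training-memory estimators for a single transformer
layer: attention, dense MLP, and MoE MLP. Each function counts model
parameters, gradients, Adam optimizer states and activations, and returns the
result in bytes; `precession` is the number of bytes per element (e.g. 2 for
fp16/bf16, 4 for fp32)."""

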
def memory_for_attention_layer(precession: int,
                               seq_len: int,
                               batch_size: int,
                               hidden_size: int,
                               num_heads: int):
    """Estimate the training memory of one attention layer, in bytes.

    head_dim = hidden_size // num_heads

    Model parameters (assuming num_key_value_heads == num_heads, i.e. standard MHA):
        q_proj: (hidden_size, num_heads * head_dim)
        k_proj: (hidden_size, num_key_value_heads * head_dim)
        v_proj: (hidden_size, num_key_value_heads * head_dim)
        o_proj: (hidden_size, hidden_size)
        total parameters = 3 * hidden_size * num_heads * head_dim + hidden_size^2

    Gradients:
        Same size as the model parameters.

    Optimizer states:
        Adam keeps two states per parameter (momentum and variance), so
        2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2).

    Activations (q_len == seq_len):
        query_states: (batch_size, num_heads, q_len, head_dim)
        key_states:   (batch_size, num_key_value_heads, q_len, head_dim)
        value_states: (batch_size, num_key_value_heads, q_len, head_dim)
        attn_weights: (batch_size, num_heads, q_len, q_len)
        attn_output:  (batch_size, q_len, hidden_size)
        total activations = batch_size * (num_heads * q_len * head_dim
                            + 2 * num_key_value_heads * q_len * head_dim
                            + num_heads * q_len^2
                            + q_len * hidden_size)

    Temporary memory:
        Roughly 20% extra can be budgeted for intermediate computations and
        buffers; the value returned below is the base estimate without that
        factor.

    total_memory = (model parameters + gradients + optimizer states + activations)
                   * precession  (bytes per element)

    Args:
        precession: bytes per element (e.g. 2 for fp16/bf16, 4 for fp32).
        seq_len: sequence length.
        batch_size: batch size.
        hidden_size: model hidden dimension.
        num_heads: number of attention heads.

    Returns:
        Estimated memory in bytes.
    """
    head_dim = hidden_size // num_heads
    # Model parameters: q/k/v projections plus the output projection.
    model_memory = 3 * hidden_size * num_heads * head_dim + hidden_size ** 2
    # Gradients have the same size as the model parameters.
    gradients = model_memory
    # Adam optimizer: two states (momentum and variance) per parameter.
    optimizer = 2 * model_memory
    # Activations, with num_key_value_heads == num_heads.
    activation = batch_size * (3 * num_heads * seq_len * head_dim +
                               num_heads * seq_len ** 2 +
                               seq_len * hidden_size)
    total_memory = (model_memory + gradients + optimizer + activation) * precession
    return total_memory
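
# Illustrative usage sketch: a hypothetical Llama-2-7B-like attention layer
# (hidden_size=4096, 32 heads) at seq_len=2048, batch 1, fp16 (2 bytes per
# element). The dimensions are assumed for illustration.
def example_attention_memory() -> float:
    """Illustrative only: attention-layer estimate in GiB for an assumed config."""
    total_bytes = memory_for_attention_layer(precession=2,
                                             seq_len=2048,
                                             batch_size=1,
                                             hidden_size=4096,
                                             num_heads=32)
    return total_bytes / 2 ** 30
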
def memory_mlp_layer(precession: int,
                     seq_len: int,
                     batch_size: int,
                     hidden_size: int,
                     intermediate_size: int):
    """Estimate the training memory of one dense MLP layer, in bytes.

    Model parameters:
        gate_proj: (hidden_size, intermediate_size)
        up_proj:   (hidden_size, intermediate_size)
        down_proj: (intermediate_size, hidden_size)
        model memory = 3 * hidden_size * intermediate_size

    Gradients:
        gradient = model_memory

    Optimizer states (Adam, two states per parameter):
        optimizer = 2 * model_memory

    Activations:
        activations = batch_size * seq_len * (hidden_size + 2 * intermediate_size)

    Total (in elements):
        total_memory = 12 * hidden_size * intermediate_size
                       + batch_size * seq_len * (2 * intermediate_size + hidden_size)
        multiplied by `precession` (bytes per element) to get bytes.

    Args:
        precession: bytes per element (e.g. 2 for fp16/bf16, 4 for fp32).
        seq_len: sequence length.
        batch_size: batch size.
        hidden_size: model hidden dimension.
        intermediate_size: MLP intermediate (feed-forward) dimension.

    Returns:
        Estimated memory in bytes.
    """
    # Model parameters: gate, up and down projections.
    model_memory = 3 * (hidden_size * intermediate_size)
    # Gradients have the same size as the model parameters.
    gradient = model_memory
    # Adam optimizer: two states per parameter.
    optimizer = 2 * model_memory
    # Activations: the layer input plus the two intermediate projections.
    activation = batch_size * seq_len * (2 * intermediate_size + hidden_size)
    total_memory = (model_memory + gradient + optimizer + activation) * precession
    return total_memory
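
# Illustrative usage sketch: a hypothetical Llama-2-7B-like dense MLP layer
# (hidden_size=4096, intermediate_size=11008) at seq_len=2048, batch 1, fp16.
# The dimensions are assumed for illustration.
def example_mlp_memory() -> float:
    """Illustrative only: dense-MLP estimate in GiB for an assumed config."""
    total_bytes = memory_mlp_layer(precession=2,
                                   seq_len=2048,
                                   batch_size=1,
                                   hidden_size=4096,
                                   intermediate_size=11008)
    return total_bytes / 2 ** 30
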
def memory_moe_mlp(precession: int,
                   seq_len: int,
                   batch_size: int,
                   hidden_size: int,
                   intermediate_size: int,
                   num_expert: int,
                   top_k: int):
    """Estimate the training memory of one MoE MLP layer, in bytes (worst case)."""
    # Router (gate) weights: (hidden_size, num_expert). Together with their
    # gradients and the two Adam states this is 4x the weight memory, in bytes.
    gate_memory = hidden_size * num_expert * precession
    gate_total = 4 * gate_memory
    # Each expert is a dense MLP; memory_mlp_layer already accounts for its
    # weights, gradients, optimizer states and activations, result in bytes.
    experts_total = memory_mlp_layer(precession, seq_len, batch_size,
                                     hidden_size, intermediate_size) * num_expert
    # Routing / dispatch activations, result in bytes.
    max_memory_activation = (
        (batch_size * seq_len * num_expert * precession) +   # Router logits
        (batch_size * seq_len * top_k * precession) +        # Routing weights
        (batch_size * seq_len * top_k * precession) +        # Selected experts
        (batch_size * seq_len * hidden_size * precession) +  # Final hidden states
        (batch_size * seq_len * hidden_size * precession) +  # Current state (worst-case)
        (batch_size * seq_len * hidden_size * precession)    # Current hidden states (worst-case)
    )
    total_memory = gate_total + experts_total + max_memory_activation
    return total_memory
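
# Illustrative usage sketch: a hypothetical Mixtral-like MoE MLP layer
# (hidden_size=4096, intermediate_size=14336, 8 experts, top-2 routing) at
# seq_len=2048, batch 1, fp16. The dimensions are assumed for illustration.
def example_moe_mlp_memory() -> float:
    """Illustrative only: MoE-MLP estimate in GiB for an assumed config."""
    total_bytes = memory_moe_mlp(precession=2,
                                 seq_len=2048,
                                 batch_size=1,
                                 hidden_size=4096,
                                 intermediate_size=14336,
                                 num_expert=8,
                                 top_k=2)
    return total_bytes / 2 ** 30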