Update modeling_mmMamba_embedding.py
Browse files
modeling_mmMamba_embedding.py
CHANGED
|
@@ -14,52 +14,30 @@
|
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
import math
|
| 17 |
-
import
|
| 18 |
-
import threading
|
| 19 |
-
import warnings
|
| 20 |
-
from typing import List, Optional, Tuple, Union
|
| 21 |
-
from functools import partial
|
| 22 |
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
| 25 |
import torch.utils.checkpoint
|
| 26 |
from einops import rearrange
|
| 27 |
from torch import nn
|
| 28 |
-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
from transformers.activations import ACT2FN
|
| 30 |
-
|
| 31 |
-
BaseModelOutputWithPast,
|
| 32 |
-
CausalLMOutputWithPast,
|
| 33 |
-
SequenceClassifierOutputWithPast,
|
| 34 |
-
)
|
| 35 |
from transformers.modeling_utils import PreTrainedModel
|
| 36 |
-
from transformers.
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
logging,
|
| 41 |
-
replace_return_docstrings,
|
| 42 |
-
)
|
| 43 |
-
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
|
| 44 |
-
import copy
|
| 45 |
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 46 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 47 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 48 |
-
|
| 49 |
-
import time
|
| 50 |
from timm.models.layers import DropPath
|
| 51 |
|
| 52 |
compute_ARank = False # [ARank] Set this to True to compute attention rank
|
| 53 |
|
| 54 |
-
try:
|
| 55 |
-
from transformers.generation.streamers import BaseStreamer
|
| 56 |
-
except: # noqa # pylint: disable=bare-except
|
| 57 |
-
BaseStreamer = None
|
| 58 |
-
|
| 59 |
from .configuration_mmMamba_embedding import mmMambaEmbeddingConfig
|
| 60 |
|
| 61 |
-
import time
|
| 62 |
-
|
| 63 |
from .configuration_mmMamba import mmMambaConfig
|
| 64 |
|
| 65 |
try:
|
|
|
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
import math
|
| 17 |
+
from typing import Optional, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
import torch
|
| 20 |
import torch.nn.functional as F
|
| 21 |
import torch.utils.checkpoint
|
| 22 |
from einops import rearrange
|
| 23 |
from torch import nn
|
|
|
|
| 24 |
from transformers.activations import ACT2FN
|
| 25 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from transformers.modeling_utils import PreTrainedModel
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
from fused_norm_gate import FusedRMSNormSwishGate
|
| 30 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 32 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 33 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 34 |
+
|
|
|
|
| 35 |
from timm.models.layers import DropPath
|
| 36 |
|
| 37 |
compute_ARank = False # [ARank] Set this to True to compute attention rank
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
from .configuration_mmMamba_embedding import mmMambaEmbeddingConfig
|
| 40 |
|
|
|
|
|
|
|
| 41 |
from .configuration_mmMamba import mmMambaConfig
|
| 42 |
|
| 43 |
try:
|