Commit
·
fb106d4
1
Parent(s):
9ec9988
up
Browse files- check_gradients_pt_flax.py +179 -0
check_gradients_pt_flax.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from transformers import SpeechEncoderDecoderModel, FlaxSpeechEncoderDecoderModel
|
| 3 |
+
import tempfile
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import optax
|
| 8 |
+
import jax
|
| 9 |
+
from flax.training.common_utils import onehot
|
| 10 |
+
from flax.traverse_util import flatten_dict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ids_tensor(shape, vocab_size, rng=None):
|
| 14 |
+
"""Creates a random int32 tensor of the shape within the vocab size."""
|
| 15 |
+
if rng is None:
|
| 16 |
+
rng = random.Random()
|
| 17 |
+
|
| 18 |
+
total_dims = 1
|
| 19 |
+
for dim in shape:
|
| 20 |
+
total_dims *= dim
|
| 21 |
+
|
| 22 |
+
values = []
|
| 23 |
+
for _ in range(total_dims):
|
| 24 |
+
values.append(rng.randint(0, vocab_size - 1))
|
| 25 |
+
|
| 26 |
+
output = np.array(values).reshape(shape)
|
| 27 |
+
|
| 28 |
+
return output
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def random_attention_mask(shape, rng=None):
|
| 32 |
+
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
|
| 33 |
+
# make sure that at least one token is attended to for each batch
|
| 34 |
+
attn_mask[:, -1] = 1
|
| 35 |
+
return attn_mask
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def floats_tensor(shape, scale=1.0, rng=None):
|
| 39 |
+
"""Creates a random float32 tensor"""
|
| 40 |
+
if rng is None:
|
| 41 |
+
rng = random.Random()
|
| 42 |
+
|
| 43 |
+
total_dims = 1
|
| 44 |
+
for dim in shape:
|
| 45 |
+
total_dims *= dim
|
| 46 |
+
|
| 47 |
+
values = []
|
| 48 |
+
for _ in range(total_dims):
|
| 49 |
+
values.append(rng.random() * scale)
|
| 50 |
+
|
| 51 |
+
return np.array(values, dtype=np.float32).reshape(shape)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
|
| 55 |
+
"""
|
| 56 |
+
Shift input ids one token to the right.
|
| 57 |
+
"""
|
| 58 |
+
shifted_input_ids = np.zeros_like(input_ids)
|
| 59 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
| 60 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
| 61 |
+
|
| 62 |
+
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
| 63 |
+
return shifted_input_ids
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 4e-2):
|
| 67 |
+
diff = np.abs((a - b)).max()
|
| 68 |
+
if diff < tol:
|
| 69 |
+
print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
|
| 70 |
+
else:
|
| 71 |
+
print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def assert_dict_equal(a: dict, b: dict, tol: float = 4e-2):
|
| 75 |
+
if a.keys() != b.keys():
|
| 76 |
+
print("❌ Dictionary keys for PyTorch and Flax do not match")
|
| 77 |
+
for k in a:
|
| 78 |
+
diff = np.abs((a[k] - b[k])).max()
|
| 79 |
+
if diff < tol:
|
| 80 |
+
print(f"✅ Layer {k} diff is {diff} < {tol}).")
|
| 81 |
+
else:
|
| 82 |
+
print(f"❌ Layer {k} diff is {diff} (>= {tol}).")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def main():
|
| 86 |
+
encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
|
| 87 |
+
decoder_id = "hf-internal-testing/tiny-random-bart"
|
| 88 |
+
|
| 89 |
+
use_decoder_attention_mask = False
|
| 90 |
+
freeze_feature_encoder = False
|
| 91 |
+
|
| 92 |
+
pt_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id,
|
| 93 |
+
encoder_add_adapter=True)
|
| 94 |
+
|
| 95 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
| 96 |
+
pt_model.save_pretrained(tmpdirname)
|
| 97 |
+
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
| 98 |
+
|
| 99 |
+
batch_size = 13
|
| 100 |
+
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
|
| 101 |
+
attention_mask = random_attention_mask([batch_size, 512])
|
| 102 |
+
label_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
| 103 |
+
decoder_input_ids = shift_tokens_right(input_ids=label_ids, pad_token_id=fx_model.config.decoder.pad_token_id,
|
| 104 |
+
decoder_start_token_id=fx_model.config.decoder.decoder_start_token_id)
|
| 105 |
+
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
| 106 |
+
|
| 107 |
+
fx_inputs = {
|
| 108 |
+
"inputs": input_values,
|
| 109 |
+
"attention_mask": attention_mask,
|
| 110 |
+
"decoder_input_ids": decoder_input_ids,
|
| 111 |
+
}
|
| 112 |
+
if use_decoder_attention_mask:
|
| 113 |
+
fx_inputs["decoder_attention_mask"] = decoder_attention_mask
|
| 114 |
+
|
| 115 |
+
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}
|
| 116 |
+
pt_inputs["labels"] = torch.tensor(label_ids.tolist())
|
| 117 |
+
|
| 118 |
+
fx_outputs = fx_model(**fx_inputs)
|
| 119 |
+
fx_logits = fx_outputs.logits
|
| 120 |
+
|
| 121 |
+
if freeze_feature_encoder:
|
| 122 |
+
pt_model.freeze_feature_encoder()
|
| 123 |
+
|
| 124 |
+
pt_outputs = pt_model(**pt_inputs)
|
| 125 |
+
pt_logits = pt_outputs.logits
|
| 126 |
+
pt_loss = pt_outputs.loss
|
| 127 |
+
|
| 128 |
+
print("--------------------------Checking logits match--------------------------")
|
| 129 |
+
print(f"Flax logits shape: {fx_logits.shape}, PyTorch logits shape: {pt_logits.shape}")
|
| 130 |
+
assert_almost_equals(fx_logits, pt_logits.detach().numpy())
|
| 131 |
+
|
| 132 |
+
def fx_train_step(fx_model, batch, freeze_feature_encoder=False):
|
| 133 |
+
def compute_loss(params):
|
| 134 |
+
label_ids = batch.pop('label_ids')
|
| 135 |
+
logits = fx_model(**batch, params=params,
|
| 136 |
+
freeze_feature_encoder=freeze_feature_encoder).logits
|
| 137 |
+
vocab_size = logits.shape[-1]
|
| 138 |
+
targets = onehot(label_ids, vocab_size)
|
| 139 |
+
loss = optax.softmax_cross_entropy(logits, targets)
|
| 140 |
+
return loss.mean()
|
| 141 |
+
|
| 142 |
+
grad_fn = jax.value_and_grad(compute_loss)
|
| 143 |
+
loss, grad = grad_fn(fx_model.params)
|
| 144 |
+
return loss, grad
|
| 145 |
+
|
| 146 |
+
fx_inputs["label_ids"] = label_ids
|
| 147 |
+
|
| 148 |
+
fx_loss, fx_grad = fx_train_step(fx_model, fx_inputs, freeze_feature_encoder=freeze_feature_encoder)
|
| 149 |
+
|
| 150 |
+
print("--------------------------Checking losses match--------------------------")
|
| 151 |
+
print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
|
| 152 |
+
assert_almost_equals(fx_loss, pt_loss.detach().numpy())
|
| 153 |
+
|
| 154 |
+
pt_loss.backward()
|
| 155 |
+
|
| 156 |
+
pt_grad_dict = {k: v.grad if v.grad is not None else torch.zeros_like(v) for k, v in pt_model.named_parameters()}
|
| 157 |
+
|
| 158 |
+
for k in pt_model.state_dict():
|
| 159 |
+
if k not in pt_grad_dict:
|
| 160 |
+
# set any unused parameters to zero in the grad-dict
|
| 161 |
+
# these won't be compared to the Flax model, but required for loading the PT model from state-dict
|
| 162 |
+
pt_grad_dict[k] = torch.zeros_like(pt_model.state_dict()[k])
|
| 163 |
+
pt_model.state_dict()[k] = pt_grad_dict[k]
|
| 164 |
+
|
| 165 |
+
pt_model.load_state_dict(pt_grad_dict)
|
| 166 |
+
|
| 167 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
| 168 |
+
pt_model.save_pretrained(tmpdirname)
|
| 169 |
+
pt_grad_model_to_fx = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
| 170 |
+
|
| 171 |
+
pt_grad_to_fx = pt_grad_model_to_fx.params
|
| 172 |
+
fx_grad = flatten_dict(fx_grad)
|
| 173 |
+
pt_grad_to_fx = flatten_dict(pt_grad_to_fx)
|
| 174 |
+
print("--------------------------Checking gradients match--------------------------")
|
| 175 |
+
assert_dict_equal(fx_grad, pt_grad_to_fx)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
main()
|