Commit 7515eca · 1 parent: c8028be · "big update"

modeling_img2html.py CHANGED (+10 −12)
```diff
@@ -162,7 +162,7 @@ def expand_inputs_for_generation(
     input_ids = input_ids.index_select(0, expanded_return_idx)
     model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
     model_kwargs["image_hidden_states"] = model_kwargs.get("image_hidden_states", None)
-    model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
+    # model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
 
     if "token_type_ids" in model_kwargs:
         token_type_ids = model_kwargs["token_type_ids"]
```
```diff
@@ -180,9 +180,7 @@ def expand_inputs_for_generation(
         model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
 
     elif model_kwargs["image_hidden_states"] is not None:
-        model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(
-            0, expanded_return_idx
-        )
+        model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(0, expanded_return_idx)
 
     return input_ids, model_kwargs
 
```
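For context, `expand_inputs_for_generation` follows the standard `transformers` pattern: every per-example tensor is repeated along the batch dimension with `index_select` so each input row appears once per beam or return sequence, which is why `pixel_values` and `image_hidden_states` get the same treatment as `input_ids`. A minimal sketch of the pattern; the construction of `expanded_return_idx` is assumed here, not taken from this file:

```python
import torch

batch_size, expand_size = 2, 3  # expand_size would be num_beams or num_return_sequences
input_ids = torch.tensor([[101, 7], [101, 9]])

# Each batch index repeated expand_size times: [0, 0, 0, 1, 1, 1]
expanded_return_idx = (
    torch.arange(batch_size).view(-1, 1).repeat(1, expand_size).view(-1)
)

# Every row is duplicated expand_size times along dim 0.
expanded = input_ids.index_select(0, expanded_return_idx)
print(expanded.shape)  # torch.Size([6, 2])
```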
```diff
@@ -205,10 +203,10 @@ def update_model_kwargs_for_generation(outputs, model_kwargs):
         model_kwargs["attention_mask"] = torch.cat(
             [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
         )
-        if "image_attention_mask" in model_kwargs:
-            image_attention_mask = model_kwargs["image_attention_mask"]
-            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-            model_kwargs["image_attention_mask"] = last_mask
+        # if "image_attention_mask" in model_kwargs:
+        #     image_attention_mask = model_kwargs["image_attention_mask"]
+        #     last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+        #     model_kwargs["image_attention_mask"] = last_mask
 
     # Get the precomputed image_hidden_states
     model_kwargs["image_hidden_states"] = outputs.image_hidden_states
```
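The `attention_mask` update left untouched above is the usual one-token-per-step growth during decoding: each step appends a column of ones so the token just generated is attended to at the next step. A self-contained illustration:

```python
import torch

attention_mask = torch.ones(2, 5, dtype=torch.long)  # 2 sequences, 5 tokens seen so far

# Append one attended position per sequence for the newly generated token.
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
print(attention_mask.shape)  # torch.Size([2, 6])
```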
```diff
@@ -236,7 +234,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
 
     pixel_values = kwargs.get("pixel_values", None)
     image_hidden_states = kwargs.get("image_hidden_states", None)
-    image_attention_mask = kwargs.get("image_attention_mask", None)
+    # image_attention_mask = kwargs.get("image_attention_mask", None)
 
     return {
         "input_ids": input_ids,
```
```diff
@@ -247,7 +245,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
         "token_type_ids": token_type_ids,
         "pixel_values": pixel_values,
         "image_hidden_states": image_hidden_states,
-        "image_attention_mask": image_attention_mask,
+        # "image_attention_mask": image_attention_mask,
     }
 
 
```
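Note the `kwargs.get(..., None)` defaulting: the image inputs are optional, so text-only generation still yields a complete dict for the model's forward pass. A toy sketch of that defaulting pattern (the `prepare` helper is illustrative, not this file's function):

```python
def prepare(**kwargs):
    # Optional tensors fall back to None instead of raising KeyError,
    # mirroring prepare_inputs_for_generation above.
    return {
        "pixel_values": kwargs.get("pixel_values", None),
        "image_hidden_states": kwargs.get("image_hidden_states", None),
    }

print(prepare())                       # {'pixel_values': None, 'image_hidden_states': None}
print(prepare(pixel_values="images"))  # {'pixel_values': 'images', 'image_hidden_states': None}
```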
```diff
@@ -1373,7 +1371,6 @@ class VMistralModel(VMistralPreTrainedModel):
         input_ids: torch.LongTensor = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         image_hidden_states: Optional[torch.Tensor] = None,
-        num_images: Optional[int] = None,
     ):
         """
         This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
```
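Per the docstring, this method splices image features into the token embedding sequence. A minimal, hypothetical sketch of the general idea — embeddings at positions holding `image_token_id` are overwritten with image hidden states; this is not the file's actual implementation:

```python
import torch

image_token_id = 32001                    # placeholder id, assumed for illustration
input_ids = torch.tensor([[5, 32001, 32001, 8]])
text_embeds = torch.randn(1, 4, 16)       # (batch, seq_len, hidden)
image_hidden_states = torch.randn(2, 16)  # one feature vector per image position

merged = text_embeds.clone()
mask = input_ids == image_token_id  # locate image placeholder positions
merged[mask] = image_hidden_states  # overwrite them with image features
print(merged.shape)  # torch.Size([1, 4, 16]) - same sequence, image slots replaced
```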
```diff
@@ -1496,6 +1493,8 @@ class VMistralModel(VMistralPreTrainedModel):
 
         if self.config.use_resampler:
             image_hidden_states = self.perceiver_resampler(image_hidden_states)
+        elif image_hidden_states is not None:
+            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
 
         if past_key_values is None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
```
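The new `elif` branch covers `image_hidden_states` supplied directly (no resampler pass): they may arrive with a dtype or device that does not match the running model, e.g. float32 activations cached from an earlier step, so they are cast to the model's dtype and the inputs' device before merging. A small illustration, with `model_dtype` standing in for `self.dtype`:

```python
import torch

model_dtype = torch.float16                  # stand-in for self.dtype
input_ids = torch.tensor([[1, 2, 3]])        # its device is the target device
image_hidden_states = torch.randn(1, 4, 16)  # float32, e.g. cached from a prior step

# Same cast as the new elif branch: match model dtype and input device.
image_hidden_states = image_hidden_states.to(dtype=model_dtype, device=input_ids.device)
print(image_hidden_states.dtype)  # torch.float16
```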
```diff
@@ -1504,7 +1503,6 @@ class VMistralModel(VMistralPreTrainedModel):
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 image_hidden_states=image_hidden_states,
-                num_images=num_images,
             )
             inputs_embeds = new_inp["inputs_embeds"]
 
```