yogkul2000 committed
Commit ee17189 · verified · 1 Parent(s): e9514e6

Update README.md

Files changed (1):
1. README.md +214 -278

README.md CHANGED
The previous revision — a checkpoint overview plus a LLaVA-NeXT-based inference and batch evaluation script — is removed by this commit; the updated README follows.
 
---
license: apache-2.0
---

This repository contains the weights for the **VideoSAVi (Self-Aligned Video Language Model)** introduced in the paper [VideoSAVi: Self-Aligned Video Language Models without Human Supervision](https://arxiv.org/abs/2412.00624).

- **Project Page:** [https://people-robots.github.io/VideoSAVi/](https://people-robots.github.io/VideoSAVi/)

## Usage Instructions

We provide sample inference code below.
```python
#!/usr/bin/env python3

import argparse
import json

import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer

# ImageNet normalization statistics applied to every frame tile.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    # Per-tile preprocessing: convert to RGB, resize, and normalize.
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    return transform


def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    # Uniformly sample num_segments frame indices, optionally restricted to a [start, end] time bound.
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
    return frame_indices


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = orig_width * orig_height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio

    # calculate the target width and height
    target_width = image_size * best_ratio[0]
    target_height = image_size * best_ratio[1]
    blocks = best_ratio[0] * best_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_video(video_path, bound=None, input_size=448, max_num=12, num_segments=8):
    # Decode the video, sample frames, tile each frame, and stack the preprocessed tensors.
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())

    pixel_values_list, num_patches_list = [], []
    transform = build_transform(input_size=input_size)
    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)

    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(tile) for tile in img]
        pixel_values = torch.stack(pixel_values)
        num_patches_list.append(pixel_values.shape[0])
        pixel_values_list.append(pixel_values)

    pixel_values = torch.cat(pixel_values_list)
    return pixel_values, num_patches_list


def parse_args():
    parser = argparse.ArgumentParser(description="Inference Script")
    parser.add_argument("--video_path", type=str, required=True, help="Path to the input video file")
    parser.add_argument("--model_path", type=str, default="yogkul2000/VideoSAVi", help="Path to the VideoSAVi model")
    parser.add_argument("--num_segments", type=int, default=8, help="Number of video segments to sample (default: 8)")
    parser.add_argument("--max_patches", type=int, default=12, help="Maximum patches per frame (default: 12)")
    parser.add_argument("--input_size", type=int, default=448, help="Input image size (default: 448)")
    parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate (default: 1024)")
    parser.add_argument("--do_sample", action="store_true", default=False, help="Whether to use sampling for generation")
    parser.add_argument("--temperature", type=float, default=0, help="Sampling temperature (default: 0)")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling parameter (default: 1.0)")
    parser.add_argument("--question", type=str, default="What is happening in this video?", help="Question to ask about the video")
    parser.add_argument("--output_file", type=str, default=None, help="Optional output file to save results")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference (default: cuda)")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"], help="Torch dtype for model (default: bfloat16)")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--no_follow_up", action="store_true", help="Skip follow-up question")
    return parser.parse_args()


def main():
    args = parse_args()

    if args.verbose:
        print(f"Loading model from: {args.model_path}")
        print(f"Processing video: {args.video_path}")
        print(f"Video segments: {args.num_segments}")
        print(f"Max patches per frame: {args.max_patches}")

    torch_dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    torch_dtype = torch_dtype_map[args.torch_dtype]

    try:
        model = AutoModel.from_pretrained(args.model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_flash_attn=True, trust_remote_code=True).eval()

        if args.device == "cuda" and torch.cuda.is_available():
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, use_fast=False)

        if args.verbose:
            print("Model and tokenizer loaded successfully!")

    except Exception as e:
        print(f"Error loading model: {e}")
        return

    try:
        if args.verbose:
            print("Loading and processing video...")

        pixel_values, num_patches_list = load_video(args.video_path, num_segments=args.num_segments, max_num=args.max_patches, input_size=args.input_size)

        pixel_values = pixel_values.to(torch_dtype)
        if args.device == "cuda" and torch.cuda.is_available():
            pixel_values = pixel_values.cuda()

        if args.verbose:
            print(f"Video processed: {len(num_patches_list)} segments, {pixel_values.shape[0]} total patches")

    except Exception as e:
        print(f"Error processing video: {e}")
        return

    # Create video prefix for frames
    video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])

    # Generation config
    generation_config = {"max_new_tokens": args.max_new_tokens, "do_sample": args.do_sample, "temperature": args.temperature, "top_p": args.top_p}

    results = {}

    try:
        question = video_prefix + args.question

        if args.verbose:
            print(f"\nAsking question: {args.question}")

        response, history = model.chat(tokenizer, pixel_values, question, generation_config, num_patches_list=num_patches_list, history=None, return_history=True)

        print(f"\nUser: {args.question}")
        print(f"VideoSAVi: {response}")

        results["question_1"] = {"question": args.question, "response": response}

        # Clear GPU cache
        if args.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error during first inference: {e}")
        return

    # Save results if output file specified
    if args.output_file:
        try:
            results["video_path"] = args.video_path
            results["model_path"] = args.model_path
            results["config"] = {"num_segments": args.num_segments, "max_patches": args.max_patches, "input_size": args.input_size, "generation_config": generation_config}

            with open(args.output_file, "w") as f:
                json.dump(results, f, indent=2)

            print(f"\nResults saved to: {args.output_file}")

        except Exception as e:
            print(f"Error saving results: {e}")

    if args.verbose:
        print("\nInference completed successfully!")


if __name__ == "__main__":
    main()
```
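
The script above is meant to be run from the command line, e.g. `python videosavi_inference.py --video_path demo.mp4 --question "What is happening in this video?"` (the file name is whatever you saved the script as). If you prefer to call the model programmatically, the following minimal sketch reproduces the same flow; it assumes the `load_video` helper above is importable from a hypothetical `videosavi_inference.py`, that a CUDA GPU is available, and that the checkpoint exposes the `model.chat(...)` interface used by the script.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical module name: the inference script above saved as videosavi_inference.py.
from videosavi_inference import load_video

MODEL_PATH = "yogkul2000/VideoSAVi"  # default checkpoint path used by the script

# Load model and tokenizer; trust_remote_code supplies the chat() method.
model = AutoModel.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, use_fast=False)

# Sample 8 segments and tile each frame into 448x448 patches ("demo.mp4" is a placeholder path).
pixel_values, num_patches_list = load_video("demo.mp4", num_segments=8, max_num=12)
pixel_values = pixel_values.to(torch.bfloat16).cuda()

# One <image> placeholder per sampled frame, exactly as in the script above.
video_prefix = "".join(f"Frame{i + 1}: <image>\n" for i in range(len(num_patches_list)))
question = video_prefix + "What is happening in this video?"

response, history = model.chat(
    tokenizer,
    pixel_values,
    question,
    {"max_new_tokens": 1024, "do_sample": False},
    num_patches_list=num_patches_list,
    history=None,
    return_history=True,
)
print(response)
```

Greedy decoding (`do_sample=False`) mirrors the script's defaults; pass `do_sample=True` together with a `temperature` and `top_p` of your choice to sample instead.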