import gradio as gr
import torch
import transformers
import os
from PIL import Image
import spaces
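# Minimal local stand-in for the `process_vision_info` helper that the official
# Qwen2.5-VL examples import from the `qwen_vl_utils` package. This simplified
# sketch only collects image/video entries from user messages; it does not
# handle URLs, base64 inputs, or resizing like the upstream helper does.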
def process_vision_info(messages):
    image_inputs = []
    video_inputs = []
    for message in messages:
        if message["role"] == "user":
            content = message["content"]
            for item in content:
                if item["type"] == "image":
                    image_inputs.append(item["image"])
                elif item["type"] == "video":
                    video_inputs.append(item["video"])
    return image_inputs, video_inputs
| print("Loading text model (Qwen/Qwen2.5-7B)...") | |
| text_model_loaded = False | |
| text_model_error = "" | |
| try: | |
| text_model = transformers.AutoModelForCausalLM.from_pretrained( | |
| "Qwen/Qwen2.5-7B", | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto" | |
| ) | |
| text_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B") | |
| text_model_loaded = True | |
| print("Text model loaded successfully.") | |
| except Exception as e: | |
| text_model_error = str(e) | |
| print(f"Error loading text model: {text_model_error}") | |
| text_model, text_tokenizer = None, None | |
| print("Loading Vision-Language model (Qwen/Qwen2.5-VL-7B-Instruct)...") | |
| vl_model_loaded = False | |
| vl_model_error = "" | |
| try: | |
| vl_model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained( | |
| "Qwen/Qwen2.5-VL-7B-Instruct", | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto" | |
| ) | |
| vl_processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") | |
| vl_model_loaded = True | |
| print("Vision-Language model loaded successfully.") | |
| except Exception as e: | |
| vl_model_error = str(e) | |
| print(f"Error loading Vision-Language model: {vl_model_error}") | |
| vl_model, vl_processor = None, None | |
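# Text tab: a single forward pass over the full input scores every token via
# teacher forcing. For each position, the visualization shows the probability
# the model assigned to the token that actually follows, e.g. in "7 * 6 = 42"
# the score attached to "42" is P("42" | "7 * 6 = ").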
@spaces.GPU  # required on ZeroGPU Spaces so this function is given a GPU slice
def visualize_text_token_probabilities(text: str):
    if not text_model_loaded:
        return [(f"Text model failed to load: {text_model_error}", None)]
    if not text or not text.strip():
        return [("Please enter some text to analyze.", None)]
    try:
        inputs = text_tokenizer([text], return_tensors="pt").to(text_model.device)
        input_ids = inputs.input_ids
        if input_ids.shape[1] < 2:
            # A single token has no preceding context, so there is nothing to score.
            token = text_tokenizer.decode(input_ids[0])
            return [(token, None)]
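        # Standard next-token shift: logits at position i are the model's
        # prediction for token i+1, so the inputs drop the last token and the
        # targets drop the first. gather() then picks each realized token's
        # probability out of the softmax over the vocabulary.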
        inp = input_ids[:, :-1]
        outp = input_ids[:, 1:].unsqueeze(-1)
        with torch.no_grad():
            logits = text_model(inp).logits.float()
        all_probs = torch.softmax(logits, dim=-1)
        chosen_probs = torch.gather(all_probs, dim=2, index=outp).squeeze(-1).cpu().numpy()[0]
        highlighted_data = []
        outp_tokens = input_ids[0, 1:].cpu().tolist()
        # The first token has no context, so it gets no score (rendered uncolored).
        first_token_str = text_tokenizer.decode([input_ids[0, 0].item()])
        highlighted_data.append((first_token_str, None))
        for token_id, prob in zip(outp_tokens, chosen_probs):
            token_str = text_tokenizer.decode([token_id])
            highlighted_data.append((token_str, float(prob)))
        return highlighted_data
    except Exception as e:
        print(f"An error occurred during text processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]
@spaces.GPU  # required on ZeroGPU Spaces so this function is given a GPU slice
def generate_and_visualize_vl_probabilities(image, prompt: str):
    if not vl_model_loaded:
        return [(f"Vision-Language model failed to load: {vl_model_error}", None)]
    if image is None or not prompt or not prompt.strip():
        return [("Please upload an image and provide a text prompt.", None)]
    try:
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt.strip()},
            ],
        }]
        text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = vl_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(vl_model.device)
        with torch.no_grad():
            generated_ids = vl_model.generate(**inputs, max_new_tokens=512)
        input_token_len = inputs.input_ids.shape[1]
        if generated_ids.shape[1] <= input_token_len:
            return [("Model did not generate any new tokens.", None)]
        original_mask = inputs.attention_mask
        num_generated_tokens = generated_ids.shape[1] - input_token_len
        generated_mask = torch.ones(
            (1, num_generated_tokens),
            dtype=original_mask.dtype,
            device=original_mask.device
        )
        full_attention_mask = torch.cat([original_mask, generated_mask], dim=1)
        with torch.no_grad():
            outputs = vl_model(
                input_ids=generated_ids,
                pixel_values=inputs.get('pixel_values'),
                image_grid_thw=inputs.get('image_grid_thw'),
                attention_mask=full_attention_mask
            )
        logits = outputs.logits.float()
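        # Logits at position i predict token i+1, so the prediction for the
        # first generated token (at position input_token_len) lives at position
        # input_token_len - 1, and the last useful logit is at the second-to-last
        # position; hence the [input_token_len - 1 : -1] slice below.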
        logits_of_generated_part = logits[:, input_token_len - 1:-1, :]
        labels_of_generated_part = generated_ids[:, input_token_len:]
        all_probs = torch.softmax(logits_of_generated_part, dim=-1)
        chosen_probs = torch.gather(all_probs, 2, labels_of_generated_part.unsqueeze(-1)).squeeze(-1)
        generated_token_ids_only = generated_ids[0, input_token_len:]
        probs_list = chosen_probs[0].cpu().tolist()
        highlighted_data = []
        for token_id, prob in zip(generated_token_ids_only.tolist(), probs_list):
            token_str = vl_processor.decode([token_id])
            highlighted_data.append((token_str, float(prob)))
        if not highlighted_data:
            return [("Model did not generate any new tokens.", None)]
        return highlighted_data
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"An error occurred during VL processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]
text_en_example = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think> </think> and
<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>. User: What is 7 * 6? Assistant: <think> First, the user asked: "what is 7 * 6?" That's a multiplication problem. I need to calculate the product of 7 and 6.
I know my multiplication tables. 7 times 6 is 42. I can double-check: 7 × 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7. Let's add that up: 7+7=14, 14+7=21, 21+7=28, 28+7=35, 35+7=42. Yes, that's 42.
I think that's fine. </think> <answer> 7 multiplied by 6 equals **42**.
If you have any more math questions or need an explanation, feel free to ask! </answer>"""
with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2.5 Token Visualizer") as demo:
    gr.Markdown(
        """
        # Qwen2.5 Series Token Probability Visualizer
        This tool visualizes token probabilities for both text and vision-language models from the Qwen2.5 series.
        The color of each token represents its conditional probability.
        **<span style="color:red">Red</span> means high probability** (the model was confident), and **<span style="color:black">White</span> means low probability** (the model was surprised).
        """
    )
    with gr.Tabs():
        with gr.TabItem("Text Model (Qwen2.5-7B)"):
            gr.Markdown("### Analyze Probabilities of Given Text")
            with gr.Row():
                text_input = gr.Textbox(
                    label="Input Text", lines=15, value=text_en_example,
                    placeholder="Enter text here to analyze..."
                )
            with gr.Row():
                text_submit_btn = gr.Button("Visualize Probabilities", variant="primary")
            text_output_highlight = gr.HighlightedText(
                label="Token Probabilities (High: Red, Low: White)", show_legend=True,
                combine_adjacent=False,
            )
            gr.Examples(
                examples=[[text_en_example]], inputs=text_input, outputs=text_output_highlight,
                fn=visualize_text_token_probabilities, cache_examples=False
            )
            text_submit_btn.click(
                fn=visualize_text_token_probabilities, inputs=text_input, outputs=text_output_highlight,
                api_name="visualize_text"
            )
        with gr.TabItem("Vision-Language Model (Qwen2.5-VL-7B-Instruct)"):
            gr.Markdown("### Generate Text from Image and Visualize Probabilities")
            with gr.Row():
                with gr.Column():
                    vl_image_input = gr.Image(type="pil", label="Upload Image")
                    vl_text_input = gr.Textbox(label="Your Question", placeholder="e.g., Describe this image.")
                    vl_submit_btn = gr.Button("Generate and Visualize", variant="primary")
                with gr.Column():
                    vl_output_highlight = gr.HighlightedText(
                        label="Generated Token Probabilities (High: Red, Low: White)", show_legend=True,
                        combine_adjacent=False,
                    )
            gr.Examples(
                examples=[["demo.jpeg", "Describe this image in detail."]],
                inputs=[vl_image_input, vl_text_input],
                outputs=vl_output_highlight,
                fn=generate_and_visualize_vl_probabilities,
                cache_examples=False
            )
            vl_submit_btn.click(
                fn=generate_and_visualize_vl_probabilities, inputs=[vl_image_input, vl_text_input],
                outputs=vl_output_highlight, api_name="visualize_vl_generation"
            )
if __name__ == "__main__":
    # Create a placeholder example image if one isn't bundled with the Space.
    if not os.path.exists("demo.jpeg"):
        try:
            from PIL import ImageDraw, ImageFont
            img = Image.new('RGB', (400, 200), color=(73, 109, 137))
            d = ImageDraw.Draw(img)
            try:
                font = ImageFont.truetype("arial.ttf", 20)
            except IOError:
                font = ImageFont.load_default()
            d.text((10, 10), "This is a demo image for Gradio.", font=font, fill=(255, 255, 0))
            img.save("demo.jpeg")
            print("Created a dummy 'demo.jpeg' for the example.")
        except Exception as e:
            print(f"Could not create a dummy image: {e}")
    demo.queue().launch(share=True)