import json import re from typing import List, Optional, Tuple import numpy as np import gradio as gr import spaces import torch from PIL import Image from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info # Qwen2.5-VL 모델 ID MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct" def _extract_assistant_content(decoded: str) -> str: """어시스턴트 응답 추출""" if "<|im_start|>assistant" in decoded: content = decoded.split("<|im_start|>assistant")[-1] content = content.replace("<|im_end|>", "").strip() return content return decoded.strip() def _extract_json_block(text: str) -> Optional[str]: """JSON 블록 추출""" match = re.search(r"\{.*\}", text, re.DOTALL) if not match: return None return match.group(0) @spaces.GPU(duration=180) def analyze_medication_image(image: Image.Image) -> Tuple[str, str]: """이미지에서 OCR 추출 후 약 정보 분석""" try: # Qwen2.5-VL 모델 로드 model = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID, torch_dtype="auto", device_map="auto" ) processor = AutoProcessor.from_pretrained(MODEL_ID) # Step 1: OCR - 이미지에서 텍스트 추출 ocr_messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": "이 이미지에 있는 모든 텍스트를 정확하게 추출해주세요. 텍스트만 출력하고 다른 설명은 필요 없습니다."}, ], } ] text = processor.apply_chat_template(ocr_messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(ocr_messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) inputs = inputs.to(model.device) with torch.no_grad(): generated_ids = model.generate(**inputs, max_new_tokens=2048) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] ocr_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] if not ocr_text or ocr_text.strip() == "": return "텍스트를 찾을 수 없습니다.", "" # Step 2: 약 정보 분석 - OCR 텍스트를 LLM에게 전달 analysis_messages = [ { "role": "user", "content": [ {"type": "text", "text": f"""다음은 약 봉투나 처방전에서 추출한 텍스트입니다: {ocr_text} 위 텍스트에서 약 이름을 찾아서, 각 약에 대해 다음 정보를 **노인과 어린이 모두 쉽게 이해할 수 있도록** 재미있고 친근하게 설명해주세요: 1. **약 이름**: 정확한 약 이름 2. **효능**: 이 약이 무엇을 치료하고 어떻게 도움이 되는지 3. **부작용**: 주의해야 할 부작용들 각 약마다 이모지를 사용하고, 쉬운 단어로 설명해주세요. 할머니 할아버지나 초등학생도 이해할 수 있게 작성해주세요. 마크다운 형식으로 작성해주세요."""}, ], } ] text = processor.apply_chat_template(analysis_messages, tokenize=False, add_generation_prompt=True) inputs = processor( text=[text], images=None, videos=None, padding=True, return_tensors="pt", ) inputs = inputs.to(model.device) with torch.no_grad(): generated_ids = model.generate(**inputs, max_new_tokens=3072, temperature=0.7) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] analysis_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] return ocr_text.strip(), analysis_text.strip() except Exception as e: raise Exception(f"분석 오류: {str(e)}") def extract_medications_from_text(text: str) -> List[str]: """Stage 2: Qwen2.5로 텍스트에서 약 이름만 추출""" try: messages = [ { "role": "system", "content": "You are a medical text analyzer. Extract only medication names from the given text and return them as a JSON array. Return ONLY valid JSON format." }, { "role": "user", "content": f"Extract all medication names from this text:\n\n{text}\n\nReturn format: {{\"medications\": [\"name1\", \"name2\"]}}" } ] prompt = LLM_TOKENIZER.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = LLM_TOKENIZER(prompt, return_tensors="pt").to(LLM_MODEL.device) with torch.no_grad(): outputs = LLM_MODEL.generate( **inputs, max_new_tokens=512, temperature=0.3, top_p=0.9, do_sample=True, pad_token_id=LLM_TOKENIZER.eos_token_id, ) response = LLM_TOKENIZER.decode(outputs[0], skip_special_tokens=True) # Extract assistant response (Qwen format) if "<|im_start|>assistant" in response: response = response.split("<|im_start|>assistant")[-1] response = response.replace("<|im_end|>", "").strip() # Parse JSON json_match = re.search(r'\{.*?\}', response, re.DOTALL) if json_match: data = json.loads(json_match.group(0)) medications = data.get("medications", []) if isinstance(medications, list) and medications: return [str(m).strip() for m in medications if str(m).strip()] return ["약 이름을 찾지 못했습니다."] except Exception as e: raise Exception(f"LLM 분석 오류: {str(e)}") @spaces.GPU(duration=120) def extract_medication_names(image: Image.Image) -> Tuple[str, List[str]]: """2단계 파이프라인: OCR → LLM 분석""" try: # Stage 1: OCR로 텍스트 추출 extracted_text = extract_text_from_image(image) if not extracted_text: return "", ["텍스트를 추출하지 못했습니다."] # Stage 2: LLM으로 약 이름 추출 medications = extract_medications_from_text(extracted_text) return extracted_text, medications except Exception as e: return "", [f"오류 발생: {str(e)}"] def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, str]: """결과를 포맷팅""" # 추출된 전체 텍스트 text_output = f"### 📄 추출된 텍스트\n\n```\n{extracted_text}\n```" # 약 이름 리스트 if not medications or medications[0].startswith("오류") or medications[0].startswith("약 이름을 찾지") or medications[0].startswith("텍스트를"): med_output = f"### ⚠️ {medications[0] if medications else '약 이름을 찾지 못했습니다.'}" else: med_output = f"### 💊 검출된 약물 ({len(medications)}개)\n\n" for idx, med_name in enumerate(medications, 1): med_output += f"{idx}. **{med_name}**\n" return text_output, med_output def run_analysis(image: Optional[Image.Image], progress=gr.Progress()): """메인 분석 파이프라인: OCR + 약 정보 분석""" if image is None: return "📷 약 봉투나 처방전 사진을 업로드해주세요.", "" progress(0.3, desc="📸 1단계: OCR 텍스트 추출 중...") progress(0.6, desc="🤖 2단계: 약 정보 분석 중...") try: ocr_text, analysis = analyze_medication_image(image) progress(1.0, desc="✅ 완료!") ocr_output = f"### 📄 추출된 텍스트\n\n```\n{ocr_text}\n```" analysis_output = f"### 💊 약 정보 설명\n\n{analysis}" return ocr_output, analysis_output except Exception as e: return f"### ⚠️ 오류 발생\n\n{str(e)}", "" # 심플한 CSS CUSTOM_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap'); :root { --primary: #6366f1; --secondary: #8b5cf6; } body { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; } .gradio-container { max-width: 900px !important; margin: auto; background: rgba(255, 255, 255, 0.98); border-radius: 24px; box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.3); padding: 40px; } .hero { text-align: center; padding: 30px 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 20px; color: white; margin-bottom: 30px; } .hero h1 { font-size: 2.5rem; font-weight: 700; margin-bottom: 10px; } .hero p { font-size: 1.1rem; opacity: 0.95; } .upload-section { background: white; border-radius: 16px; padding: 30px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07); margin-bottom: 20px; } .result-section { background: white; border-radius: 16px; padding: 30px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07); min-height: 200px; } .analyze-btn button { background: linear-gradient(135deg, var(--primary), var(--secondary)) !important; color: white !important; font-weight: 600 !important; font-size: 1.1rem !important; padding: 18px 40px !important; border-radius: 12px !important; border: none !important; box-shadow: 0 10px 20px -5px rgba(99, 102, 241, 0.5) !important; transition: all 0.3s ease !important; } .analyze-btn button:hover { transform: translateY(-2px) !important; box-shadow: 0 15px 30px -5px rgba(99, 102, 241, 0.6) !important; } .gr-image { border-radius: 12px !important; } """ HERO_HTML = """

💊 우리 가족 약 도우미

약봉투/처방전 사진에서 약 정보를 쉽고 재미있게 알려드려요!

""" # Gradio 인터페이스 with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: gr.HTML(HERO_HTML) with gr.Column(elem_classes=["upload-section"]): gr.Markdown("### 📸 사진 업로드") image_input = gr.Image(type="pil", label="약봉투 또는 처방전 사진", height=350) analyze_button = gr.Button("🔍 약 정보 분석하기", elem_classes=["analyze-btn"], size="lg") with gr.Row(): with gr.Column(elem_classes=["result-section"]): gr.Markdown("### 📋 1단계: 추출된 텍스트") ocr_output = gr.Markdown("OCR로 추출된 텍스트가 여기 표시됩니다.") with gr.Column(elem_classes=["result-section"]): gr.Markdown("### 📋 2단계: 쉬운 약 설명") analysis_output = gr.Markdown("노인과 어린이도 이해하기 쉬운 약 정보가 여기 표시됩니다.") analyze_button.click( run_analysis, inputs=image_input, outputs=[ocr_output, analysis_output], ) gr.Markdown(""" --- **ℹ️ 사용 방법** 1. 약 봉투나 처방전 사진을 업로드하세요 2. '약 정보 분석하기' 버튼을 클릭하세요 3. 왼쪽에는 추출된 텍스트, 오른쪽에는 쉬운 설명이 나타납니다! **⚠️ 주의사항** - 이 앱은 참고용이며, 실제 복약은 반드시 의사나 약사의 지시를 따르세요 - AI가 생성한 정보이므로 정확하지 않을 수 있습니다 **🤖 기술 스택** - Qwen2.5-VL-7B-Instruct (OCR + 약 정보 분석) """) if __name__ == "__main__": demo.queue().launch()