# api.py
from __future__ import annotations

import os
import json
import logging
import time
import shutil
from typing import List, Optional

from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv

from models import OptimizeRequest, QARequest, AutotuneRequest

# Load environment
load_dotenv()

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ragmint_mcp_server")

# FastAPI app
app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directories
DEFAULT_DATA_DIR = "data/docs"
LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl"
os.makedirs(DEFAULT_DATA_DIR, exist_ok=True)
os.makedirs("experiments", exist_ok=True)

# Try importing ragmint modules
try:
    from ragmint.autotuner import AutoRAGTuner
    from ragmint.qa_generator import generate_validation_qa
    from ragmint.explainer import explain_results
    from ragmint.leaderboard import Leaderboard
    from ragmint.tuner import RAGMint
except Exception as e:
    AutoRAGTuner = None
    generate_validation_qa = None
    explain_results = None
    Leaderboard = None
    RAGMint = None
    _import_error = e
else:
    _import_error = None


# NOTE: route paths on the decorators below are assumptions inferred from the handler names.
@app.get("/health")
def health():
    return {
        "status": "ok",
        "ragmint_imported": _import_error is None,
        "import_error": str(_import_error) if _import_error else None,
    }


@app.post("/upload_docs")
async def upload_docs(
    docs_path: str = Form(...),
    files: List[UploadFile] = File(...),
):
    os.makedirs(docs_path, exist_ok=True)
    saved_files = []
    for file in files:
        file_path = os.path.join(docs_path, file.filename)
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        saved_files.append(file.filename)
    return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path}


def handle_validation_choice(docs_path: str, validation_choice: Optional[str], llm_model: str) -> Optional[str]:
    """Resolve which validation QA set to use.

    An empty choice falls back to <docs_path>/validation_qa.json when it exists,
    "generate" builds a new set with generate_validation_qa, and any other value
    is treated as a path to an existing dataset.
    """
    validation_choice = (validation_choice or "").strip()
    default_path = os.path.join(docs_path, "validation_qa.json")
    if not validation_choice:
        if os.path.exists(default_path):
            logger.info("Using default validation QA: %s", default_path)
            return default_path
        return None
    if validation_choice.lower() == "generate":
        generate_validation_qa(
            docs_path=docs_path,
            output_path=default_path,
            llm_model=llm_model,
        )
        logger.info("Generated validation QA at: %s", default_path)
        return default_path
    if os.path.exists(validation_choice) or "/" in validation_choice:
        logger.info("Using specified validation dataset: %s", validation_choice)
        return validation_choice
    logger.warning("Validation choice provided but not found: %s", validation_choice)
    return None


@app.post("/optimize_rag")
def optimize_rag(req: OptimizeRequest):
    logger.info("Received optimize_rag request: %s", req.json())
    if RAGMint is None:
        raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
    try:
        rag = RAGMint(
            docs_path=docs_path,
            retrievers=req.retriever,
            embeddings=req.embedding_model,
            rerankers=req.rerankers or ["mmr"],
            chunk_sizes=req.chunk_sizes,
            overlaps=req.overlaps,
            strategies=req.strategy,
        )
        validation_set = handle_validation_choice(
            docs_path, req.validation_choice, getattr(req, "llm_model", "gemini-2.5-flash-lite")
        )
        start_time = time.time()
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            trials=req.trials,
            search_type=req.search_type,
        )
        elapsed = time.time() - start_time
        run_id = f"opt_{int(time.time())}"
        corpus_stats = {
            "num_docs": len(rag.documents),
            "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
            "corpus_size": sum(len(d) for d in rag.documents),
        }
        if Leaderboard:
            lb = Leaderboard()
            lb.upload(
                run_id=run_id,
                best_config=best,
                best_score=best.get("faithfulness", best.get("score", 0.0)),
                all_results=results,
                documents=os.listdir(docs_path),
                model=best.get("embedding_model", req.embedding_model),
                corpus_stats=corpus_stats,
            )
        return {
            "status": "finished",
            "run_id": run_id,
            "elapsed_seconds": elapsed,
            "best_config": best,
            "results": results,
            "corpus_stats": corpus_stats,
        }
    except Exception as exc:
        logger.exception("optimize_rag failed")
        raise HTTPException(status_code=500, detail=str(exc))
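

# Illustrative client payload for optimize_rag: field names are inferred from how
# OptimizeRequest is used above and may not match models.py exactly; the URL, route
# path, and example values (retriever, embedding model, metric, etc.) are assumptions.
# This helper is never called by the server itself.
def _example_optimize_rag():
    import requests  # assumed to be available on the client side

    payload = {
        "docs_path": DEFAULT_DATA_DIR,
        "retriever": ["faiss"],
        "embedding_model": ["all-MiniLM-L6-v2"],
        "rerankers": ["mmr"],
        "chunk_sizes": [256, 512],
        "overlaps": [32, 64],
        "strategy": ["fixed"],
        "validation_choice": "generate",
        "metric": "faithfulness",
        "trials": 10,
        "search_type": "random",
    }
    return requests.post("http://localhost:8000/optimize_rag", json=payload, timeout=600)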


@app.post("/autotune_rag")
def autotune_rag(req: AutotuneRequest):
    logger.info("Received autotune_rag request: %s", req.json())
    if AutoRAGTuner is None or RAGMint is None:
        raise HTTPException(status_code=500, detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}")
    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
    try:
        start_time = time.time()
        tuner = AutoRAGTuner(docs_path=docs_path)
        rec = tuner.recommend(embedding_model=req.embedding_model, num_chunk_pairs=req.num_chunk_pairs)
        chunk_candidates = tuner.suggest_chunk_sizes(
            model_name=rec.get("embedding_model"),
            num_pairs=int(req.num_chunk_pairs),
            step=20,
        )
        chunk_sizes = sorted({c for c, _ in chunk_candidates})
        overlaps = sorted({o for _, o in chunk_candidates})
        rag = RAGMint(
            docs_path=docs_path,
            retrievers=[rec["retriever"]],
            embeddings=[rec["embedding_model"]],
            rerankers=["mmr"],
            chunk_sizes=chunk_sizes,
            overlaps=overlaps,
            strategies=[rec["strategy"]],
        )
        validation_set = handle_validation_choice(
            docs_path, req.validation_choice, getattr(req, "llm_model", "gemini-2.5-flash-lite")
        )
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            search_type=req.search_type,
            trials=req.trials,
        )
        elapsed = time.time() - start_time
        run_id = f"autotune_{int(time.time())}"
        corpus_stats = {
            "num_docs": len(rag.documents),
            "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
            "corpus_size": sum(len(d) for d in rag.documents),
        }
        if Leaderboard:
            lb = Leaderboard()
            lb.upload(
                run_id=run_id,
                best_config=best,
                best_score=best.get("faithfulness", best.get("score", 0.0)),
                all_results=results,
                documents=os.listdir(docs_path),
                model=best.get("embedding_model", rec.get("embedding_model")),
                corpus_stats=corpus_stats,
            )
        return {
            "status": "finished",
            "run_id": run_id,
            "elapsed_seconds": elapsed,
            "recommendation": rec,
            "chunk_candidates": chunk_candidates,
            "best_config": best,
            "results": results,
            "corpus_stats": corpus_stats,
        }
    except Exception as exc:
        logger.exception("autotune_rag failed")
        raise HTTPException(status_code=500, detail=str(exc))


@app.post("/generate_validation_qa")
def generate_validation_qa_endpoint(req: QARequest):
    logger.info("Received generate_validation_qa request: %s", req.json())
    if generate_validation_qa is None:
        raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
    try:
        docs_path = req.docs_path or DEFAULT_DATA_DIR
        out_path = os.path.join(docs_path, "validation_qa.json")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        generate_validation_qa(
            docs_path=docs_path,
            output_path=out_path,
            llm_model=req.llm_model,
            batch_size=req.batch_size,
            min_q=req.min_q,
            max_q=req.max_q,
        )
        with open(out_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {
            "status": "finished",
            "output_path": out_path,
            "preview_count": len(data),
            "sample": data[:5],
        }
    except Exception as exc:
        logger.exception("generate_validation_qa failed")
        raise HTTPException(status_code=500, detail=str(exc))


@app.post("/clear_cache")
async def clear_cache(docs_path: str = Form(DEFAULT_DATA_DIR)):
    """
    Delete all files inside docs_path but keep the directory itself.
    Useful for resetting uploaded documents between RAG runs.
    """
    if not os.path.exists(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
    removed = []
    for root, dirs, files in os.walk(docs_path, topdown=False):
        for name in files:
            file_path = os.path.join(root, name)
            try:
                os.remove(file_path)
                removed.append(name)
            except Exception as e:
                logger.error(f"Failed to remove {file_path}: {e}")
        for name in dirs:
            dir_path = os.path.join(root, name)
            try:
                shutil.rmtree(dir_path)
                removed.append(f"{name}/")
            except Exception as e:
                logger.error(f"Failed to remove {dir_path}: {e}")
    return {
        "status": "cleared",
        "docs_path": docs_path,
        "removed_items": removed,
        "total_removed": len(removed),
    }


def start_api():
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
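

# Allow running the module directly; the hosting environment may instead launch the
# app via the uvicorn CLI (e.g. `uvicorn api:app`), so this guard is an optional convenience.
if __name__ == "__main__":
    start_api()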