Spaces:
Runtime error
Runtime error
def clean_output(decoded_list):
    """Strip each decoded string, discard blank ones and drop repeats.

    The first occurrence of every distinct stripped string is kept, in the
    order it appeared in ``decoded_list``.
    """
    unique = {}
    for candidate in decoded_list:
        text = candidate.strip()
        if not text:
            continue
        # setdefault keeps the first insertion, so order of first sighting wins.
        unique.setdefault(text, None)
    return list(unique)
def preprocess_context(context):
    """Return ``context`` with surrounding whitespace removed, prefixed by the task tag.

    The result has the form ``"generate question: <stripped context>"``.
    """
    task_prefix = "generate question: "
    return task_prefix + context.strip()
def get_shap_values(tokenizer, model, prompt, *, max_length=64, num_beams=2):
    """Attribute the model's generated output to the prompt's tokens with SHAP.

    The model first generates a reference output from the unmasked prompt.
    SHAP then masks subsets of prompt tokens (replacing them with the pad
    token) and scores each masked prompt by the mean log-probability the
    model assigns to that same reference output under teacher forcing. A
    token's SHAP value is therefore its contribution to the likelihood of
    the originally generated output.

    NOTE(review): assumes an encoder-decoder model (e.g. T5) whose
    ``generate`` output begins with the decoder start token and whose
    forward pass accepts ``decoder_input_ids`` and returns ``.logits`` --
    confirm before using with other model types.

    Args:
        tokenizer: tokenizer used to encode ``prompt``; its ``pad_token_id``
            (or 0 if it has none) is used as the "masked" token.
        model: model exposing ``generate``, ``device`` and a forward pass.
        prompt: input text to explain.
        max_length: maximum length of the reference generation.
        num_beams: beam width of the reference generation.

    Returns:
        ``(values, tokens)``: ``values`` is the SHAP array of the first (only)
        prompt, one row per token and one output column; ``tokens`` are the
        matching token strings.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    device = model.device

    # Generate the reference output once, rather than running (and then
    # discarding) a beam search for every masked sample SHAP evaluates.
    with torch.no_grad():
        reference = model.generate(
            input_ids=input_ids.to(device),
            max_length=max_length,
            do_sample=False,
            num_beams=num_beams,
        )
    # Teacher forcing: feed reference[:-1] to the decoder, score reference[1:].
    decoder_inputs = reference[:, :-1]
    decoder_targets = reference[:, 1:]

    def f(x):
        """Score each (masked) prompt row in ``x``; returns shape (rows, 1)."""
        # SHAP's masker may hand back floats; token ids must be LongTensor.
        batch = torch.as_tensor(np.asarray(x), device=device).long()
        n_rows = batch.shape[0]
        targets = decoder_targets.repeat(n_rows, 1)
        with torch.no_grad():
            logits = model(
                input_ids=batch,
                decoder_input_ids=decoder_inputs.repeat(n_rows, 1),
            ).logits
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # One output column per row so the SHAP values keep an (n, 1) layout.
        return token_log_probs.mean(dim=-1).cpu().numpy().reshape(-1, 1)

    ids = input_ids.numpy()
    # Masked features are replaced by the background value. Using the
    # explained input itself as background would make masking a no-op (all
    # SHAP values zero), so the pad token serves as the neutral replacement.
    pad_id = tokenizer.pad_token_id
    background = np.full_like(ids, pad_id if pad_id is not None else 0)

    explainer = shap.Explainer(f, background)
    shap_values = explainer(ids)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens