Spaces:
Sleeping
Sleeping
| import os | |
| import jsonlines | |
| import argparse | |
| from tqdm import tqdm | |
| import logging | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, matthews_corrcoef | |
| from backend.section_infer_helper.local_llm_helper import local_llm_helper | |
| from backend.section_infer_helper.online_llm_helper import online_llm_helper | |
# Runtime configuration.  INCLUDE_MSG defaults to "no" and may be
# overridden through the INCLUDE_MSG environment variable.
INCLUDE_MSG = os.environ.get("INCLUDE_MSG", "no")
BATCH_SIZE = 4

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
def main(args):
    """Run section-level relevance inference and write annotated JSONL.

    Loads either a local or an online LLM helper (chosen by ``args.type``),
    runs inference over every section of every item in ``args.data``, and
    writes the items — augmented with prompt, raw output, boolean prediction
    and confidence — to ``args.output``.  If ``args.output`` already exists,
    the (commit_id, file_name) pairs recorded there are treated as finished
    and are not re-inferred (simple resume support).

    Args:
        args: argparse.Namespace with attributes ``type``, ``model``,
            ``peft``, ``url``, ``key``, ``data`` and ``output``.

    Raises:
        ValueError: if ``args.type`` is neither "local" nor "online".
    """
    if args.type == "local":
        helper = local_llm_helper
        helper.load_model(args.model, args.peft)
    elif args.type == "online":
        helper = online_llm_helper
        helper.load_model(args.model, args.url, args.key)
    else:
        # argparse `choices` should make this unreachable; fail loudly
        # instead of hitting an UnboundLocalError on `helper` below.
        raise ValueError(f"Unknown model type: {args.type}")

    labels = []
    predicts = []
    input_prompts = []
    output_text = []
    output_probs = []
    inputs = []

    with jsonlines.open(args.data, "r") as reader:
        test_data = list(reader)

    # Resume: re-load previously finished items so their results are kept.
    # NOTE(review): `test_data[i] = item` assumes the rows of the output
    # file correspond positionally to the first rows of the data file —
    # confirm this holds for partially written outputs.
    finished_item = []
    if os.path.exists(args.output):
        with jsonlines.open(args.output, "r") as reader:
            for i, item in enumerate(reader):
                finished_item.append((item["commit_id"], item["file_name"]))
                test_data[i] = item
                for section in item["sections"]:
                    labels.append(section["related"])
                    predicts.append(section["predict"])
                    input_prompts.append(section["input_prompt"])
                    output_text.append(section["output_text"])
                    output_probs.append(section["conf"])

    # Collect inference inputs for the not-yet-finished items.
    for item in test_data:
        file_name = item["file_name"]
        patch = item["patch"]
        if (item["commit_id"], item["file_name"]) in finished_item:
            print(f"Skip {item['commit_id']}, {item['file_name']}")
            continue
        # The commit message is only fed to the model when explicitly enabled.
        commit_message = item["commit_message"] if INCLUDE_MSG == "yes" else ""
        for section in item["sections"]:
            section_content = section["section"]
            inputs.append(helper.InputData(file_name, patch, section_content, commit_message))
            labels.append(section["related"])
    # Was: `assert len(labels) == 4088` — a hard-coded dataset-size check
    # that crashed on any other dataset (and is stripped under -O).
    logging.info("Collected %d labels, %d pending inputs", len(labels), len(inputs))

    # Run inference.  On failure, fall back to empty result lists so the
    # already-finished items are still written back below (the run can be
    # resumed later) instead of crashing on undefined locals.
    this_input_prompts, this_output_text, this_output_probs = [], [], []
    try:
        this_input_prompts, this_output_text, this_output_probs = helper.do_infer(inputs, BATCH_SIZE)
    except Exception as e:
        logging.exception("Inference failed: %s", e)
    input_prompts.extend(this_input_prompts)
    output_text.extend(this_output_text)
    output_probs.extend(this_output_probs)
    # Derive predictions only for the NEW outputs: resumed items already
    # contributed their "predict" field above.  (Iterating all of
    # `output_text` here would duplicate resumed predictions and
    # desynchronize the parallel lists consumed by the writer loop.)
    for result in this_output_text:
        predicts.append("yes" in result.lower())

    # Write results back, consuming the parallel lists front-to-front.
    # Items without enough outputs (failed/partial inference) are dropped
    # so the output file can serve as a resume checkpoint.
    with jsonlines.open(args.output, "w") as writer:
        for item in test_data:
            if len(output_text) < len(item["sections"]):
                logging.info("Not enough output")
                break
            for section in item["sections"]:
                section["input_prompt"] = input_prompts.pop(0)
                section["output_text"] = output_text.pop(0)
                section["predict"] = bool(predicts.pop(0))
                section["conf"] = output_probs.pop(0)
            writer.write(item)
if __name__ == "__main__":
    # Command-line entry point: parse options and hand off to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-d", "--data", type=str, required=True, help="Path to the data file")
    arg_parser.add_argument("-t", "--type", type=str, required=True, help="Type of the model", choices=["local", "online"])
    arg_parser.add_argument("-m", "--model", type=str, required=True)
    arg_parser.add_argument("-p", "--peft", type=str, help="Path to the PEFT file")
    arg_parser.add_argument("-u", "--url", type=str, help="URL of the model")
    arg_parser.add_argument("-k", "--key", type=str, help="API key")
    arg_parser.add_argument("-o", "--output", type=str, required=True, help="Path to the output file")
    main(arg_parser.parse_args())