# Load SaProt and ChemBERTa models and inspect their input/output formats.
import torch
from transformers import EsmTokenizer, EsmForMaskedLM
from transformers import AutoTokenizer, AutoModelForMaskedLM


def generate(model, tokenizer, input_seq, device=None):
    """Run one masked-LM forward pass and print the token-level round trip.

    Args:
        model: a HuggingFace masked-LM model (already moved to its device).
        tokenizer: the tokenizer matching ``model``.
        input_seq: raw input string (structure-aware protein sequence for
            SaProt, or a SMILES string for ChemBERTa).
        device: optional device for the input tensors; defaults to the
            model's own device, so the function no longer depends on a
            global ``device`` variable defined in a later cell.

    Returns:
        (logits, decoded_sequence): logits of shape
        [1, seq_len, vocab_size] and the argmax-decoded string with
        special tokens stripped.
    """
    if device is None:
        # Read the device off the model itself instead of a notebook global.
        device = next(model.parameters()).device

    tokens = tokenizer.tokenize(input_seq)

    inputs = tokenizer(input_seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: skip autograd bookkeeping (matters for a 650M model).
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits  # shape: [1, seq_len, vocab_size]
    # NOTE: the first/last positions are special tokens (e.g. <cls>/<eos>);
    # they are intentionally kept so the printout shows what the model sees.
    # logits = logits[:, 1:-1, :]
    predictions = torch.argmax(logits, dim=-1)  # shape: [1, seq_len]
    decoded_sequence = tokenizer.decode(predictions[0], skip_special_tokens=True)

    print('='*100)
    print(f'Input Token({len(tokens)}): ', tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))
    print(f'Output Token: ', tokenizer.tokenize(decoded_sequence))
    print('Output Logits Shape: ', outputs.logits.shape)

    print('-'*100)
    print('Input Seq: ', input_seq)
    print('Predict Seq: ', decoded_sequence)
    return logits, decoded_sequence


# Fall back to CPU when no GPU is present (the original hardcoded "cuda"
# and crashed on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- SaProt (structure-aware protein LM, ESM architecture) ---
saprot_model_path = "./SaProt_650M_AF2"

saprot_tokenizer = EsmTokenizer.from_pretrained(saprot_model_path)
saprot_model = EsmForMaskedLM.from_pretrained(saprot_model_path).to(device)
saprot_model.eval()  # disable dropout for deterministic inference

print("--- Running SaProt_650M_AF2 Inference ---")
logits, decoded_sequence = generate(saprot_model, saprot_tokenizer, "M#EvVpQpL#VyQdYaKv")
logits, decoded_sequence = generate(saprot_model, saprot_tokenizer, "M#")

# --- ChemBERTa (SMILES language model) ---
chem_model_path = "./ChemBERTa-zinc-base-v1"

chem_tokenizer = AutoTokenizer.from_pretrained(chem_model_path)
chem_model = AutoModelForMaskedLM.from_pretrained(chem_model_path).to(device)
chem_model.eval()  # disable dropout for deterministic inference

smiles_seq = "CC(=O)OC1=CC=CC=C1C(=O)O"

print("--- Running ChemBERTa Inference ---")
logits, decoded_sequence = generate(chem_model, chem_tokenizer, smiles_seq)