| | from typing import Dict, List, Union |
| | import torch |
| | from transformers import AutoModel |
| | from custom_tokenizer import CustomPhobertTokenizer |
| |
|
| |
|
| | def mean_pooling(model_output, attention_mask): |
| | token_embeddings = model_output[ |
| | 0 |
| | ] |
| | input_mask_expanded = ( |
| | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() |
| | ) |
| | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( |
| | input_mask_expanded.sum(1), min=1e-9 |
| | ) |
| |
|
| |
|
| | class PreTrainedPipeline: |
| | def __init__(self, path="."): |
| | self.model = AutoModel.from_pretrained(path) |
| | self.tokenizer = CustomPhobertTokenizer.from_pretrained(path) |
| |
|
| | def __call__(self, inputs: Dict[str, Union[str, List[str]]]) -> List[float]: |
| | """ |
| | Args: |
| | inputs (Dict[str, Union[str, List[str]]]): |
| | a dictionary containing a query sentence and a list of key sentences |
| | """ |
| |
|
| | |
| | sentences = [inputs["source_sentence"]] + inputs["sentences"] |
| |
|
| | |
| | encoded_input = self.tokenizer( |
| | sentences, padding=True, truncation=True, return_tensors="pt" |
| | ) |
| |
|
| | |
| | with torch.no_grad(): |
| | model_output = self.model(**encoded_input) |
| |
|
| | |
| | sentence_embeddings = mean_pooling( |
| | model_output, encoded_input["attention_mask"] |
| | ) |
| |
|
| | |
| | query_embedding = sentence_embeddings[0] |
| | key_embeddings = sentence_embeddings[1:] |
| |
|
| | |
| | cosine_similarities = torch.nn.functional.cosine_similarity( |
| | query_embedding.unsqueeze(0), key_embeddings |
| | ) |
| |
|
| | |
| | scores = cosine_similarities.tolist() |
| |
|
| | return scores |
| |
|
| |
|
| | if __name__ == "__main__": |
| | inputs = { |
| | "source_sentence": "Anh ấy đang là sinh viên năm cuối", |
| | "sentences": [ |
| | "Anh ấy học tại Đại học Bách khoa Hà Nội, chuyên ngành Khoa học máy tính", |
| | "Anh ấy đang làm việc tại nhà máy sản xuất linh kiện điện tử", |
| | "Anh ấy chuẩn bị đi du học nước ngoài", |
| | "Anh ấy sắp mở cửa hàng bán mỹ phẩm", |
| | "Nhà anh ấy có rất nhiều cây cảnh", |
| | ], |
| | } |
| |
|
| | pipeline = PreTrainedPipeline() |
| | res = pipeline(inputs) |
| |
|