Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from transformers import GPT2LMHeadModel | |
def load_model(model_name: str = "snoop2head/Gomoku-GPT2") -> GPT2LMHeadModel:
    """Load a GPT-2 LM-head model from a pretrained checkpoint.

    Args:
        model_name: Checkpoint identifier handed straight to
            ``GPT2LMHeadModel.from_pretrained``.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    return GPT2LMHeadModel.from_pretrained(model_name)
# Special token ids. PAD and EOS are forwarded to ``model.generate`` by
# ``generate_gpt2``; BOS is presumably the sequence-start marker used by
# callers -- verify against the tokenizer/vocabulary of the checkpoint.
BOS_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID = 401, 402, 403
def generate_gpt2(model: GPT2LMHeadModel, input_ids: torch.LongTensor) -> list:
    """Run ``model.generate`` on ``input_ids`` and return the ids as a list.

    Generation uses ``num_beams=5``, a ``max_length`` of 128 and the
    module-level ``PAD_TOKEN_ID`` / ``EOS_TOKEN_ID``.

    Args:
        model: Model whose ``generate`` method is invoked.
        input_ids: Prompt token ids, ``[batch_size, seq_len]``.

    Returns:
        For a batch of one, a flat ``[seq_len]`` list of token ids. For a
        larger batch the batch axis is kept, so the result is one inner list
        per sequence.
    """
    output_ids = model.generate(
        input_ids,
        max_length=128,
        num_beams=5,
        # NOTE(review): no ``do_sample=True`` is passed here; the transformers
        # docs describe ``temperature`` as a sampling-mode setting, so it
        # likely has no effect on this call -- confirm whether sampling was
        # intended.
        temperature=0.7,
        pad_token_id=PAD_TOKEN_ID,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Drop only the leading batch axis. A bare ``squeeze()`` would also remove
    # a length-1 sequence axis and make ``tolist()`` return a plain int
    # instead of a list.
    return output_ids.squeeze(0).tolist()
def change_to_1d_coordinate(board: np.ndarray, x: int, y: int) -> int:
    """Flatten a 2D board position into a single row-major index.

    Args:
        board: Array whose second dimension gives the row width.
        x: Index along the first axis.
        y: Index along the second axis.

    Returns:
        ``x * board.shape[1] + y``.
    """
    row_width = board.shape[1]
    return x * row_width + y
def change_to_2d_coordinate(board: np.ndarray, coordinate: int) -> tuple:
    """Expand a row-major index back into a 2D board position.

    Args:
        board: Array whose second dimension gives the row width.
        coordinate: Flattened index.

    Returns:
        ``(coordinate // board.shape[1], coordinate % board.shape[1])``.
    """
    row_width = board.shape[1]
    # divmod yields the same (quotient, remainder) pair as // and %.
    return divmod(coordinate, row_width)