| from transformers import BertConfig, BertModel |
| import torch |
| import re |
| from torch.utils.data import DataLoader, Dataset |
| from sklearn.model_selection import train_test_split, cross_validate |
| import pytorch_lightning as pl |
| import pandas as pd |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
| from torch.optim import AdamW |
| from sklearn.metrics import f1_score |
|
|
# Maximum encoded sequence length; also used as BERT's position-embedding limit.
MAX_LEN = 96
# Index reserved for the "<pad>" token in the character vocabulary.
PAD_ID = 0


# Small character-level BERT configuration.
# NOTE(review): vocab_size=40 must be >= len(char2idx) built at runtime — confirm
# against the actual character set produced by get_char2idx.
config = BertConfig(
    vocab_size=40,
    hidden_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=256,
    max_position_embeddings=MAX_LEN,
    # Segment ids 0..3 are emitted by the get_dataset* builders.
    type_vocab_size=4
)
|
|
|
|
|
|
class MyDataset(Dataset):
    """Torch Dataset over the encoded examples produced by get_dataset3.

    In training mode each item is a (features, label) tuple; otherwise it is
    just the features dict.
    """

    def __init__(self, df, char2idx, label2idx, is_train=True):
        super().__init__()
        print(char2idx)
        print(label2idx)
        self.is_train = is_train
        self.dataset = get_dataset3(df, char2idx, label2idx, is_train=is_train)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def collate_fn(self, batch):
        """Stack a list of items into batched tensors (plus labels when training)."""
        # In training mode each batch element is (features, label); strip the
        # label off before stacking the feature lists.
        feats = [item[0] for item in batch] if self.is_train else list(batch)
        collated = {
            "input_ids": torch.IntTensor([f["input_ids"] for f in feats]),
            "attention_mask": torch.Tensor([f["attention_mask"] for f in feats]),
            "token_type_ids": torch.IntTensor([f["token_type_ids"] for f in feats]),
        }
        if self.is_train:
            return collated, torch.IntTensor([item[1] for item in batch])
        return collated
|
|
|
|
def get_preprocessed_dfs(folder):
    """Load and normalize the train/test CSVs, splitting each Tag into a label list.

    Steps: collapse all causative tags except CAUS_1 into CAUS_2, lower-case and
    strip every value, then decompose the concatenated tag string into a list of
    known category names (longest match first). The test frame has its Tag
    column removed.

    Returns {"train": DataFrame, "test": DataFrame} with lower-cased column names.
    """
    df = pd.read_csv(f"{folder}/train_data.csv").drop_duplicates()
    # Collapse every causative variant other than CAUS_1 into a single CAUS_2 label.
    df.loc[:, "Tag"] = df.Tag.apply(lambda x: "CAUS_2" if x.startswith("CAUS_") and x != "CAUS_1" else x)

    cats = ['FUT_INDF_3PLF', 'FUT_INDF_NEG', 'PST_INDF_PS', 'PCP_FUT_NEG', 'PCP_FUT_DEF', 'PRES_CONT', 'PRES_2SGF', 'POSS_2SGF', 'POSS_2PLF', 'NUM_APPR3', 'NUM_APPR2', 'NUM_APPR1', 'ADVV_CONT', 'ADJECTIVE', 'PST_ITER', 'PST_INDF', 'PST_EVID', 'PRES_PST', 'POSS_3SG', 'POSS_3PL', 'POSS_2SG', 'POSS_2PL', 'POSS_1SG', 'POSS_1PL', 'NUM_COLL', 'FUT_INDF', 'ADVV_SUC', 'ADVV_NEG', 'ADVV_INT', 'ADVV_ACC', 'PST_DEF', 'NUM_ORD', 'NUMERAL', 'IMP_SGF', 'IMP_PLF', 'FUT_DEF', 'PREC_1', 'PCP_PS', 'PCP_PR', 'JUS_SG', 'JUS_PL', 'IMP_SG', 'IMP_PL', 'HOR_SG', 'HOR_PL', 'DESIDE', 'CAUS_2', 'CAUS_1', 'INF_5', 'INF_4', 'INF_3', 'INF_2', 'INF_1', 'VERB', 'REFL', 'RECP', 'PRES', 'PREM', 'PERS', 'PASS', 'COND', 'COMP', '2SGF', '2PLF', 'SUC', 'OPT', 'NOM', 'NEG', 'NEG', 'LOC', 'INT', 'GEN', 'DAT', 'ACT', 'ACC', 'ABL', '3SG', '3PL', '2SG', '2PL', '1SG', '1PL', 'SG', 'PL']
    # FIX: deduplicate (the literal above lists 'NEG' twice) and sort longest
    # first so long category names are matched and removed before their shorter
    # substrings (e.g. "pst_indf" before "pst").
    cats = sorted({x.lower() for x in cats}, key=lambda x: (len(x), x), reverse=True)

    # Normalize every value to a stripped, lower-case string.
    for col in df.columns:
        df.loc[:, col] = df[col].apply(lambda x: x.strip().lower())

    def tag2list(t):
        # Greedily extract known categories from the concatenated tag string.
        res = []
        for c in cats:
            if c in t:
                res.append(c)
                t = t.replace(c, "")
        return res

    df.loc[:, "Tag"] = df.Tag.apply(tag2list)

    tdf = pd.read_csv(f"{folder}/test_data.csv")
    tdf.pop("Tag")  # test labels are intentionally dropped
    for col in tdf.columns:
        tdf.loc[:, col] = tdf[col].apply(lambda x: x.strip().lower())

    return {"train": df.rename(columns={x: x.lower() for x in df.columns}), "test": tdf.rename(columns={x: x.lower() for x in tdf.columns})}
|
|
def get_preprocessed_dfs2(folder):
    """Load the train/test CSVs and normalize values to stripped lower-case strings.

    Unlike get_preprocessed_dfs, the Tag column of the train frame stays a
    single string (only causative variants are collapsed). The test frame has
    its Tag column removed.

    Returns {"train": DataFrame, "test": DataFrame} with lower-cased column names.
    """
    train_df = pd.read_csv(f"{folder}/train_data.csv").drop_duplicates()
    # Collapse every causative tag other than CAUS_1 into CAUS_2.
    train_df.loc[:, "Tag"] = train_df.Tag.apply(
        lambda t: "CAUS_2" if t.startswith("CAUS_") and t != "CAUS_1" else t
    )
    for column in train_df.columns:
        train_df.loc[:, column] = train_df[column].apply(lambda v: v.strip().lower())

    test_df = pd.read_csv(f"{folder}/test_data.csv")
    test_df.pop("Tag")  # test labels are intentionally dropped
    for column in test_df.columns:
        test_df.loc[:, column] = test_df[column].apply(lambda v: v.strip().lower())

    return {
        "train": train_df.rename(columns={c: c.lower() for c in train_df.columns}),
        "test": test_df.rename(columns={c: c.lower() for c in test_df.columns}),
    }
|
|
def get_splits(df, test_size=0.2):
    """Split df into train/validation frames by root, so no root appears in both.

    The split is deterministic (fixed random_state) and operates on the unique
    root values rather than on rows, preventing leakage of a root across splits.
    """
    unique_roots = df.root.unique()
    print("unique roots", len(unique_roots))
    train_roots, val_roots = train_test_split(unique_roots, test_size=test_size, random_state=2023)
    print("unique train roots", len(train_roots))
    print("unique validation roots", len(val_roots))
    return df[df.root.isin(train_roots)], df[df.root.isin(val_roots)]
|
|
def get_char2idx(all_splits, special_chars=("<pad>", "<s>", "</s>")):
    """Build a character -> index vocabulary over root+affix of every split.

    Special tokens come first (in the given order), followed by the remaining
    characters in sorted order.
    """
    chars = set()
    for frame in all_splits.values():
        # Collect every character appearing in the concatenated root+affix.
        for combined in frame.apply(lambda row: row.root + row.affix, axis=1):
            chars.update(combined)
    vocab = list(special_chars) + sorted(chars)
    return {ch: i for i, ch in enumerate(vocab)}
|
|
def get_dataset(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True):
    """Encode each row as BERT-style features for multi-label classification.

    Layout: <s>, pos(word), pos(root) [segment 0], word chars [segment 1],
    root chars [segment 2], affix chars + </s> [segment 3], truncated and
    padded to max_len.

    Returns a list of feature dicts, or (features, multi-hot label list)
    tuples when is_train is True (r.tag is expected to be a list of labels,
    as produced by get_preprocessed_dfs).
    """
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}

    result = []

    for r in split.itertuples():
        # Segment 0: sentinel + part-of-speech ids for the word and its root.
        input_ids = [char2idx["<s>"], pos2idx[r.pos_word], pos2idx[r.pos_root]]
        attention_mask = [1, 1, 1]
        token_type_ids = [0, 0, 0]

        # Segment 1: the full word, character by character.
        for c in r.word:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(1)

        # Segment 2: the root.
        for c in r.root:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(2)

        # Segment 3: the affix, closed with </s>.
        for c in r.affix:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(3)
        input_ids.append(char2idx["</s>"])
        attention_mask.append(1)
        token_type_ids.append(3)

        # FIX: truncate/pad to the max_len PARAMETER — the original used the
        # module constant MAX_LEN here, silently ignoring the argument.
        input_ids = input_ids[:max_len]
        attention_mask = attention_mask[:max_len]
        token_type_ids = token_type_ids[:max_len]
        while len(input_ids) < max_len:
            input_ids.append(char2idx["<pad>"])
            attention_mask.append(0)
            token_type_ids.append(3)

        result.append(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )

        if is_train:
            # Multi-hot label vector over label2idx.
            labels = [0 for _ in range(len(label2idx))]
            for tag in r.tag:
                labels[label2idx[tag]] = 1
            result[-1] = (result[-1], labels)

    return result
|
|
def get_dataset3(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True):
    """Encode each row as BERT-style features for single-label classification.

    Layout: <s>, pos(root) [segment 0], root chars [segment 1], affix chars +
    </s> [segment 2], truncated and padded to max_len.

    Returns a list of feature dicts, or (features, label index) tuples when
    is_train is True (r.tag is a single tag string, as produced by
    get_preprocessed_dfs2).
    """
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}

    result = []

    for xs, r in enumerate(split.itertuples()):
        # Segment 0: sentinel + part-of-speech id of the root.
        input_ids = [char2idx["<s>"], pos2idx[r.pos_root]]
        attention_mask = [1, 1]
        token_type_ids = [0, 0]

        # Segment 1: the root, character by character.
        for c in r.root:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(1)

        # Segment 2: the affix, closed with </s>.
        for c in r.affix:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(2)
        input_ids.append(char2idx["</s>"])
        attention_mask.append(1)
        token_type_ids.append(2)

        # FIX: truncate/pad to the max_len PARAMETER — the original used the
        # module constant MAX_LEN here, silently ignoring the argument.
        input_ids = input_ids[:max_len]
        attention_mask = attention_mask[:max_len]
        token_type_ids = token_type_ids[:max_len]
        while len(input_ids) < max_len:
            input_ids.append(char2idx["<pad>"])
            attention_mask.append(0)
            token_type_ids.append(2)

        result.append(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )

        if is_train:
            result[-1] = (result[-1], label2idx[r.tag])

        # FIX: the original condition was `xs + 1 % 1000 == 0`, which parses as
        # `xs + (1 % 1000) == 0` and never fires; parenthesize the addition so
        # a progress sample is printed every 1000 rows.
        if (xs + 1) % 1000 == 0:
            print(input_ids)
            print(attention_mask)
            print(token_type_ids)

    return result
|
|
def get_dataset2(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True):
    """Encode each row as BERT-style features for single-label classification.

    Layout: <s>, pos(word), pos(root) [segment 0], word chars [segment 1],
    root chars [segment 2], affix chars + </s> [segment 3], truncated and
    padded to max_len.

    Returns a list of feature dicts, or (features, label index) tuples when
    is_train is True (r.tag is a single tag string).
    """
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}

    result = []

    for xs, r in enumerate(split.itertuples()):
        # Segment 0: sentinel + part-of-speech ids for the word and its root.
        input_ids = [char2idx["<s>"], pos2idx[r.pos_word], pos2idx[r.pos_root]]
        attention_mask = [1, 1, 1]
        token_type_ids = [0, 0, 0]

        # Segment 1: the full word, character by character.
        for c in r.word:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(1)

        # Segment 2: the root.
        for c in r.root:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(2)

        # Segment 3: the affix, closed with </s>.
        for c in r.affix:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(3)
        input_ids.append(char2idx["</s>"])
        attention_mask.append(1)
        token_type_ids.append(3)

        # FIX: truncate/pad to the max_len PARAMETER — the original used the
        # module constant MAX_LEN here, silently ignoring the argument.
        input_ids = input_ids[:max_len]
        attention_mask = attention_mask[:max_len]
        token_type_ids = token_type_ids[:max_len]
        while len(input_ids) < max_len:
            input_ids.append(char2idx["<pad>"])
            attention_mask.append(0)
            token_type_ids.append(3)

        result.append(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )

        if is_train:
            result[-1] = (result[-1], label2idx[r.tag])

        # FIX: the original condition was `xs + 1 % 10000 == 0`, which parses
        # as `xs + (1 % 10000) == 0` and never fires; parenthesize so a
        # progress sample is printed every 10000 rows.
        if (xs + 1) % 10000 == 0:
            print(input_ids)
            print(attention_mask)
            print(token_type_ids)

    return result
|
|
def train_model(epochs=100, batch_size=400, data_folder="../Downloads/"):
    """Train MyModel2 on the preprocessed data and return the best checkpoint.

    Loads and splits the data, fits with early stopping and top-k checkpointing
    on the "fmicro" validation metric, then reloads the best checkpoint's
    weights into the model.

    Returns (model, train_df, validation_df, test_df).
    """
    dfs = get_preprocessed_dfs2(data_folder)
    train, val = get_splits(dfs["train"])
    char2idx = get_char2idx(dfs)
    label2idx = {l: i for i, l in enumerate(dfs["train"].tag.unique())}

    model = MyModel2(config, label2idx, char2idx, 0.5)
    checkpoint_callback = ModelCheckpoint(
        dirpath="fmicro_weights",
        save_top_k=3,
        monitor="fmicro",
        mode="max",
        filename="{epoch}-{step}",
    )
    trainer = pl.Trainer(
        deterministic=True,
        max_epochs=epochs,
        callbacks=[EarlyStopping(monitor="fmicro", mode="max"), checkpoint_callback],
        log_every_n_steps=30,
    )

    train_dataset = MyDataset(train, char2idx, label2idx)
    validation_dataset = MyDataset(val, char2idx, label2idx)
    # FIX: honor the batch_size argument — the original hard-coded 400 in both
    # DataLoaders, silently ignoring the parameter.
    trainer.fit(
        model,
        DataLoader(train_dataset, batch_size=batch_size, collate_fn=train_dataset.collate_fn),
        DataLoader(validation_dataset, batch_size=batch_size, collate_fn=validation_dataset.collate_fn),
    )

    # Use the checkpoint callback we created directly instead of re-discovering
    # it by scanning trainer.callbacks.
    model.load_state_dict(torch.load(checkpoint_callback.best_model_path)["state_dict"])

    return model, train, val, dfs["test"]
|
|
|
|
class MyModel(pl.LightningModule):
    """BERT-based multi-label classifier trained with BCE-with-logits.

    Expects batches of (features_dict, multi-hot_labels) as produced by
    get_dataset / MyDataset.collate_fn.
    """

    def __init__(self, config, label2idx, threshold, *args, **kwargs):
        # FIX: the original did `self.char2idx = char2idx`, but `char2idx` is
        # neither a parameter nor a module global here — constructing the class
        # raised NameError. Accept it via kwargs to stay signature-compatible.
        char2idx = kwargs.pop("char2idx", None)
        super().__init__(*args, **kwargs)
        self.threshold = threshold  # decision threshold on sigmoid outputs
        self.char2idx = char2idx
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Run BERT on the feature dict and project the pooled output to logits."""
        X, _ = batch
        hidden = self.bert(**X)[1]  # pooled [CLS] representation
        return self.proj(self.dropout(hidden))

    def training_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.BCEWithLogitsLoss()(logits, batch[1].float())
        self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.BCEWithLogitsLoss()(logits, batch[1].float())
        self.log("loss", loss.mean(), prog_bar=True, on_epoch=True)
        return logits, loss

    def test_step(self, batch, batch_idx):
        # Test batches carry no labels; wrap in the (features, labels) shape
        # common_step expects.
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx):
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())
|
|
class MyModel2(pl.LightningModule):
    """BERT-based single-label classifier trained with cross-entropy.

    Expects batches of (features_dict, label_indices) as produced by
    get_dataset3 / MyDataset.collate_fn. Logs micro-averaged F1 as "fmicro"
    at the end of every validation epoch.
    """

    def __init__(self, config, label2idx, char2idx, threshold, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold  # kept for interface parity; unused here
        self.char2idx = char2idx
        self.fscore = 0.0  # last computed validation micro-F1
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Run BERT on the feature dict and project the pooled output to logits."""
        X, _ = batch
        hidden = self.bert(**X)[1]  # pooled [CLS] representation
        return self.proj(self.dropout(hidden))

    def training_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.CrossEntropyLoss()(logits.view(-1, len(self.label2idx)), batch[1].view(-1).long())
        self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.CrossEntropyLoss()(logits.view(-1, len(self.label2idx)), batch[1].view(-1).long())
        # Accumulate per-example predictions/targets for the epoch F1.
        for p in logits:
            self.predos.append(self.idx2label[p.argmax().cpu().item()])
        for t in batch[1]:
            self.trues.append(self.idx2label[t.cpu().item()])
        self.log("loss", loss.mean(), prog_bar=True, on_epoch=True)
        return logits, loss

    def on_validation_start(self):
        # Reset accumulators for the new validation epoch.
        self.predos = []
        self.trues = []

    def on_validation_epoch_end(self):
        # FIX: the original logged self.fscore inside validation_step and only
        # recomputed it in on_validation_end — so "fmicro" (monitored by
        # EarlyStopping/ModelCheckpoint) was always the PREVIOUS epoch's score,
        # and self.log in on_validation_end is too late to affect callbacks.
        # Compute and log the current epoch's micro-F1 here instead.
        self.fscore = f1_score(self.trues, self.predos, average="micro")
        self.log("fmicro", self.fscore, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Test batches carry no labels; wrap in the (features, labels) shape
        # common_step expects.
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx):
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())

    def predict(self, dataloader):
        # NOTE(review): unimplemented stub kept for interface compatibility.
        pass
|
|
|
|