Improvement in the display of the graph axes labels. Generalization of rankSent class. Minor fixes.

Files changed:
- modules/module_BiasExplorer.py  +22 -10
- modules/module_connection.py    +12 -10
- modules/module_rankSents.py     +29 -25
- modules/utils.py                +64 -3
modules/module_BiasExplorer.py CHANGED

@@ -5,7 +5,7 @@ import seaborn as sns
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA
 from typing import List, Dict, Tuple, Optional, Any
-from modules.utils import normalize, cosine_similarity, project_params, take_two_sides_extreme_sorted
+from modules.utils import normalize, cosine_similarity, project_params, take_two_sides_extreme_sorted, axes_labels_format
 
 __all__ = ['WordBiasExplorer', 'WEBiasExplorer2Spaces', 'WEBiasExplorer4Spaces']
 
@@ -371,9 +371,14 @@ class WEBiasExplorer2Spaces(WordBiasExplorer):
         plt.xticks(np.arange(-most_extream_projection,
                              most_extream_projection + axis_projection_step,
                              axis_projection_step))
-        xlabel = '← {} {} {} →'.format(self.negative_end,
-                                       ' ' * 20,
-                                       self.positive_end)
+
+
+        xlabel = axes_labels_format(
+            left=self.negative_end,
+            right=self.positive_end,
+            sep=' ' * 20,
+            word_wrap=3
+        )
 
         plt.xlabel(xlabel)
         plt.ylabel('Words')
@@ -515,13 +520,20 @@ class WEBiasExplorer4Spaces(WordBiasExplorer):
         for _, row in (projections_df.iterrows()):
             ax.annotate(
                 row['word'], (row['projection_x'], row['projection_y']))
-        x_label = '← {} {} {} →'.format(name_left,
-                                        ' ' * 20,
-                                        name_right)
 
-        y_label = '← {} {} {} →'.format(name_top,
-                                        ' ' * 20,
-                                        name_bottom)
+
+        x_label = axes_labels_format(
+            left=name_left,
+            right=name_right,
+            sep=' ' * 20,
+            word_wrap=3
+        )
+        y_label = axes_labels_format(
+            left=name_top,
+            right=name_bottom,
+            sep=' ' * 20,
+            word_wrap=3
+        )
 
         plt.xlabel(x_label)
         ax.xaxis.set_label_position('bottom')
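With this change the plots no longer hard-code a single '← left … right →' label string; they delegate to the shared axes_labels_format helper (added in modules/utils.py below), which wraps long comma-separated word lists across rows. A minimal sketch of the call, with illustrative word lists standing in for self.negative_end / self.positive_end:

from modules.utils import axes_labels_format

# Illustrative inputs; the explorer passes its own word lists here.
xlabel = axes_labels_format(
    left="mujer, femenino",
    right="hombre, masculino",
    sep=" " * 20,
    word_wrap=3,
)
print(xlabel)  # first row framed by '← … →'; overflow words wrap onto new rows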
modules/module_connection.py CHANGED

@@ -422,11 +422,12 @@ class PhraseBiasExplorerConnector(Connector):
     def rank_sentence_options(
         self,
         sent: str,
-        …
+        interest_word_list: str,
         banned_word_list: str,
-        …
-        …
-        …
+        exclude_articles: bool,
+        exclude_prepositions: bool,
+        exclude_conjunctions: bool,
+        n_predictions: int=5
     ) -> Tuple:
 
         sent = " ".join(sent.strip().replace("*"," * ").split())
@@ -435,7 +436,7 @@ class PhraseBiasExplorerConnector(Connector):
         if err:
             return err, "", ""
 
-        …
+        interest_word_list = self.parse_words(interest_word_list)
         banned_word_list = self.parse_words(banned_word_list)
 
         # Save inputs in logs file
@@ -443,16 +444,17 @@ class PhraseBiasExplorerConnector(Connector):
             self.logs_file_name,
             self.headers,
             sent,
-            …
+            interest_word_list
         )
 
         all_plls_scores = self.phrase_bias_explorer.rank(
             sent,
-            …
+            interest_word_list,
             banned_word_list,
-            …
-            …
-            …
+            exclude_articles,
+            exclude_prepositions,
+            exclude_conjunctions,
+            n_predictions
         )
 
         all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
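For reference, a hedged sketch of calling the widened connector entry point; `connector` stands for an already-constructed PhraseBiasExplorerConnector, the sentence and word lists are illustrative, and the names of the unpacked results are assumptions (only the three-element shape is implied by the error path `return err, "", ""` above):

msg, scores, plot = connector.rank_sentence_options(
    sent="La persona trabaja como *",
    interest_word_list="",    # empty: candidates come from the model
    banned_word_list="",
    exclude_articles=True,
    exclude_prepositions=True,
    exclude_conjunctions=True,
    n_predictions=5,
)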
modules/module_rankSents.py CHANGED

@@ -66,13 +66,14 @@ class RankSents:
 
         return self.errorManager.process(out_msj)
 
-    def …(
+    def getTopPredictions(
         self,
+        n: int,
         sent: str,
-        …
-        …
-        …
-        …
+        banned_word_list: List[str],
+        exclude_articles: bool,
+        exclude_prepositions: bool,
+        exclude_conjunctions: bool,
     ) -> List[str]:
 
         sent_masked = sent.replace("*", self.tokenizer.mask_token)
@@ -80,7 +81,8 @@ class RankSents:
             sent_masked,
             add_special_tokens=True,
             return_tensors='pt',
-            return_attention_mask=True,
+            return_attention_mask=True,
+            truncation=True
         )
 
         tk_position_mask = torch.where(inputs['input_ids'][0] == self.tokenizer.mask_token_id)[0].item()
@@ -94,26 +96,26 @@ class RankSents:
         probabilities = outputs[tk_position_mask]
         first_tk_id = torch.argsort(probabilities, descending=True)
 
-        …
+        top_tks_pred = []
         for tk_id in first_tk_id:
             tk_string = self.tokenizer.decode([tk_id])
 
-            tk_is_banned = tk_string in …
+            tk_is_banned = tk_string in banned_word_list
             tk_is_punctuation = not tk_string.isalnum()
             tk_is_substring = tk_string.startswith("##")
             tk_is_special = (tk_string in self.tokenizer.all_special_tokens)
 
-            if …:
+            if exclude_articles:
                 tk_is_article = tk_string in self.articles
             else:
                 tk_is_article = False
 
-            if …:
+            if exclude_prepositions:
                 tk_is_prepositions = tk_string in self.prepositions
             else:
                 tk_is_prepositions = False
 
-            if …:
+            if exclude_conjunctions:
                 tk_is_conjunctions = tk_string in self.conjunctions
             else:
                 tk_is_conjunctions = False
@@ -128,39 +130,41 @@ class RankSents:
                 tk_is_conjunctions
             ])
 
-            if predictions_is_dessire and len(…
-                …
+            if predictions_is_dessire and len(top_tks_pred) < n:
+                top_tks_pred.append(tk_string)
 
-            elif len(…
+            elif len(top_tks_pred) >= n:
                 break
 
-        return …
+        return top_tks_pred
 
     def rank(self,
         sent: str,
-        …
+        interest_word_list: List[str]=[],
         banned_word_list: List[str]=[],
-        …
-        …
-        …
+        exclude_articles: bool=False,
+        exclude_prepositions: bool=False,
+        exclude_conjunctions: bool=False,
+        n_predictions: int=5
    ) -> Dict[str, float]:
 
         err = self.errorChecking(sent)
         if err:
             raise Exception(err)
 
-        if not …:
-            …
+        if not interest_word_list:
+            interest_word_list = self.getTopPredictions(
+                n_predictions,
                 sent,
                 banned_word_list,
-                …
-                …
-                …
+                exclude_articles,
+                exclude_prepositions,
+                exclude_conjunctions
             )
 
         sent_list = []
         sent_list2print = []
-        for word in …:
+        for word in interest_word_list:
             sent_list.append(sent.replace("*", "<"+word+">"))
             sent_list2print.append(sent.replace("*", "<"+word+">"))
 
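The generalization in practice: when interest_word_list is empty, rank() falls back to getTopPredictions(), which walks the model's most probable fillers for the masked position and skips banned words, punctuation, subword pieces ("##…"), special tokens and, when the matching exclude_* flag is set, articles, prepositions and conjunctions, until n candidates are collected. A hedged usage sketch, with `ranker` standing for a constructed RankSents instance and an illustrative sentence:

scores = ranker.rank(
    "El médico le dijo a la enfermera que *",  # '*' marks the slot to fill
    interest_word_list=[],       # empty -> filled via getTopPredictions
    banned_word_list=["muy"],
    exclude_articles=True,
    exclude_prepositions=True,
    exclude_conjunctions=True,
    n_predictions=5,
)
# scores: Dict[str, float], keyed by the sentence with '*' replaced by '<word>'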
modules/utils.py CHANGED

@@ -1,13 +1,15 @@
 import numpy as np
 import pandas as pd
-from datetime import datetime
 import pytz
+from datetime import datetime
+from typing import List
+
 
 
 class DateLogs:
     def __init__(
         self,
-        zone: str="America/Argentina/Cordoba"
+        zone: str = "America/Argentina/Cordoba"
     ) -> None:
 
         self.time_zone = pytz.timezone(zone)
@@ -80,4 +82,63 @@ def cosine_similarity(
     v_norm = np.linalg.norm(v)
     u_norm = np.linalg.norm(u)
     similarity = v @ u / (v_norm * u_norm)
-    return similarity
+    return similarity
+
+
+def axes_labels_format(
+    left: str,
+    right: str,
+    sep: str,
+    word_wrap: int = 4
+) -> str:
+
+    def sparse(
+        word: str,
+        max_len: int
+    ) -> str:
+
+        diff = max_len-len(word)
+        rest = diff if diff > 0 else 0
+        return word+" "*rest
+
+    def gen_block(
+        list_: List[str],
+        n_rows:int,
+        n_cols:int
+    ) -> List[str]:
+
+        block = []
+        block_row = []
+        for r in range(n_rows):
+            for c in range(n_cols):
+                i = r * n_cols + c
+                w = list_[i] if i <= len(list_) - 1 else ""
+                block_row.append(w)
+                if (i+1) % n_cols == 0:
+                    block.append(block_row)
+                    block_row = []
+        return block
+
+    # Transform 'string' to list of string
+    l_list = [word.strip() for word in left.split(",") if word.strip() != ""]
+    r_list = [word.strip() for word in right.split(",") if word.strip() != ""]
+
+    # Get longest word, and longest list
+    longest_list = max(len(l_list), len(r_list))
+    longest_word = len(max( max(l_list, key=len), max(r_list, key=len)))
+
+    # Creation of word blocks for each list
+    n_rows = (longest_list // word_wrap) if longest_list % word_wrap == 0 else (longest_list // word_wrap) + 1
+    n_cols = word_wrap
+
+    l_block = gen_block(l_list, n_rows, n_cols)
+    r_block = gen_block(r_list, n_rows, n_cols)
+
+    # Transform list of list to sparse string
+    labels = ""
+    for i,(l,r) in enumerate(zip(l_block, r_block)):
+        line = ' '.join([sparse(w, longest_word) for w in l]) + sep + \
+               ' '.join([sparse(w, longest_word) for w in r])
+        labels += f"← {line} →\n" if i==0 else f"  {line}  \n"
+
+    return labels
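A quick worked example of the new helper with illustrative inputs. With four words per side and word_wrap=3, gen_block yields two rows of three columns, padding the short row with empty strings; only the first row is framed by the arrows:

from modules.utils import axes_labels_format

label = axes_labels_format(
    left="alto, bajo, gordo, flaco",
    right="lindo, feo, joven, viejo",
    sep=" " * 20,
    word_wrap=3,
)
print(label)
# Roughly:
# ← alto  bajo  gordo                     lindo feo   joven →
#   flaco                                 viejo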