Upload 44 files
Browse files- .gitattributes +6 -0
- src/binary_model/checkpoint-400/config.json +35 -0
- src/binary_model/checkpoint-400/model.safetensors +3 -0
- src/binary_model/checkpoint-400/optimizer.pt +3 -0
- src/binary_model/checkpoint-400/rng_state.pth +3 -0
- src/binary_model/checkpoint-400/scheduler.pt +3 -0
- src/binary_model/checkpoint-400/trainer_state.json +734 -0
- src/binary_model/checkpoint-400/training_args.bin +3 -0
- src/binary_model/runs/Nov19_19-55-03_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763571303.vadim-HP-Laptop-15s-eq1xxx.128002.0 +3 -0
- src/category_model/checkpoint-400/config.json +39 -0
- src/category_model/checkpoint-400/model.safetensors +3 -0
- src/category_model/checkpoint-400/optimizer.pt +3 -0
- src/category_model/checkpoint-400/rng_state.pth +3 -0
- src/category_model/checkpoint-400/scheduler.pt +3 -0
- src/category_model/checkpoint-400/trainer_state.json +734 -0
- src/category_model/checkpoint-400/training_args.bin +3 -0
- src/category_model/runs/Nov19_16-34-43_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763559283.vadim-HP-Laptop-15s-eq1xxx.119293.0 +3 -0
- src/category_model/runs/Nov19_16-46-52_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763560013.vadim-HP-Laptop-15s-eq1xxx.120060.0 +3 -0
- src/ml_binary.joblib +3 -0
- src/ml_category.joblib +3 -0
- src/ml_categorys.joblib +3 -0
- src/multilabel_model/checkpoint-700/config.json +39 -0
- src/multilabel_model/checkpoint-700/model.safetensors +3 -0
- src/multilabel_model/checkpoint-700/optimizer.pt +3 -0
- src/multilabel_model/checkpoint-700/rng_state.pth +3 -0
- src/multilabel_model/checkpoint-700/scheduler.pt +3 -0
- src/multilabel_model/checkpoint-700/trainer_state.json +41 -0
- src/multilabel_model/checkpoint-700/training_args.bin +3 -0
- src/multilabel_model/runs/Nov19_18-32-50_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763566371.vadim-HP-Laptop-15s-eq1xxx.124852.0 +3 -0
- src/multilabel_model/runs/Nov19_18-43-43_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763567023.vadim-HP-Laptop-15s-eq1xxx.125134.0 +3 -0
- src/multilabel_model/runs/Nov19_18-50-20_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763567421.vadim-HP-Laptop-15s-eq1xxx.125341.0 +3 -0
- src/multilabel_model/runs/Nov19_18-56-03_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763567764.vadim-HP-Laptop-15s-eq1xxx.125471.0 +3 -0
- src/multilabel_model/runs/Nov19_19-12-54_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763568775.vadim-HP-Laptop-15s-eq1xxx.125830.0 +3 -0
- src/multilabel_model/runs/Nov19_19-19-23_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763569164.vadim-HP-Laptop-15s-eq1xxx.126011.0 +3 -0
- src/multilabel_model/runs/Nov19_19-20-09_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763569210.vadim-HP-Laptop-15s-eq1xxx.126088.0 +3 -0
- src/nn_binary.keras +3 -0
- src/nn_category.keras +3 -0
- src/nn_categorys.keras +3 -0
- src/nn_vectorizer_binary.keras +3 -0
- src/nn_vectorizer_category.keras +3 -0
- src/nn_vectorizer_categorys.keras +3 -0
- src/streamlit_app.py +553 -38
- src/use_ml.py +76 -0
- src/use_nn.py +56 -0
- src/use_transformer.py +86 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
src/nn_binary.keras filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
src/nn_category.keras filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
src/nn_categorys.keras filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
src/nn_vectorizer_binary.keras filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
src/nn_vectorizer_category.keras filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
src/nn_vectorizer_categorys.keras filter=lfs diff=lfs merge=lfs -text
|
src/binary_model/checkpoint-400/config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"emb_size": 312,
|
| 9 |
+
"gradient_checkpointing": false,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"hidden_size": 312,
|
| 13 |
+
"id2label": {
|
| 14 |
+
"0": "negative",
|
| 15 |
+
"1": "positive"
|
| 16 |
+
},
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 600,
|
| 19 |
+
"label2id": {
|
| 20 |
+
"negative": 0,
|
| 21 |
+
"positive": 1
|
| 22 |
+
},
|
| 23 |
+
"layer_norm_eps": 1e-12,
|
| 24 |
+
"max_position_embeddings": 512,
|
| 25 |
+
"model_type": "bert",
|
| 26 |
+
"num_attention_heads": 12,
|
| 27 |
+
"num_hidden_layers": 3,
|
| 28 |
+
"pad_token_id": 0,
|
| 29 |
+
"position_embedding_type": "absolute",
|
| 30 |
+
"problem_type": "single_label_classification",
|
| 31 |
+
"transformers_version": "4.57.0",
|
| 32 |
+
"type_vocab_size": 2,
|
| 33 |
+
"use_cache": true,
|
| 34 |
+
"vocab_size": 29564
|
| 35 |
+
}
|
src/binary_model/checkpoint-400/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:452079fe712b0a898545d8ec8d36ae4dbca4e4f34a67ee61bce99be89efb0276
|
| 3 |
+
size 47145624
|
src/binary_model/checkpoint-400/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98b2f3d51ee4fb8db44f65ab5aa0dfcb6a3a42faac9a58970d809a7fac691ff5
|
| 3 |
+
size 94323147
|
src/binary_model/checkpoint-400/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:987e2565b4d0e8df1d1d7fabe9aae58ea75b005c720faa6553599a28da5eb789
|
| 3 |
+
size 14455
|
src/binary_model/checkpoint-400/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2a13aedfb10730f658c90eb03b45908f32db49971da0e67d91a64b15a963525
|
| 3 |
+
size 1465
|
src/binary_model/checkpoint-400/trainer_state.json
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 100.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 400,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 1.0,
|
| 14 |
+
"grad_norm": 0.7171610593795776,
|
| 15 |
+
"learning_rate": 1.985e-05,
|
| 16 |
+
"loss": 0.6936,
|
| 17 |
+
"step": 4
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 2.0,
|
| 21 |
+
"grad_norm": 0.931102991104126,
|
| 22 |
+
"learning_rate": 1.9650000000000003e-05,
|
| 23 |
+
"loss": 0.6903,
|
| 24 |
+
"step": 8
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 3.0,
|
| 28 |
+
"grad_norm": 0.9753431677818298,
|
| 29 |
+
"learning_rate": 1.9450000000000002e-05,
|
| 30 |
+
"loss": 0.691,
|
| 31 |
+
"step": 12
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 4.0,
|
| 35 |
+
"grad_norm": 0.7282153964042664,
|
| 36 |
+
"learning_rate": 1.925e-05,
|
| 37 |
+
"loss": 0.6909,
|
| 38 |
+
"step": 16
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 5.0,
|
| 42 |
+
"grad_norm": 0.7919716238975525,
|
| 43 |
+
"learning_rate": 1.9050000000000002e-05,
|
| 44 |
+
"loss": 0.6853,
|
| 45 |
+
"step": 20
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 6.0,
|
| 49 |
+
"grad_norm": 2.4427859783172607,
|
| 50 |
+
"learning_rate": 1.885e-05,
|
| 51 |
+
"loss": 0.6789,
|
| 52 |
+
"step": 24
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 7.0,
|
| 56 |
+
"grad_norm": 0.7891402840614319,
|
| 57 |
+
"learning_rate": 1.8650000000000003e-05,
|
| 58 |
+
"loss": 0.6753,
|
| 59 |
+
"step": 28
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 8.0,
|
| 63 |
+
"grad_norm": 0.6697407960891724,
|
| 64 |
+
"learning_rate": 1.845e-05,
|
| 65 |
+
"loss": 0.6762,
|
| 66 |
+
"step": 32
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 9.0,
|
| 70 |
+
"grad_norm": 0.7295678853988647,
|
| 71 |
+
"learning_rate": 1.825e-05,
|
| 72 |
+
"loss": 0.6714,
|
| 73 |
+
"step": 36
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 10.0,
|
| 77 |
+
"grad_norm": 0.7619945406913757,
|
| 78 |
+
"learning_rate": 1.805e-05,
|
| 79 |
+
"loss": 0.6697,
|
| 80 |
+
"step": 40
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 11.0,
|
| 84 |
+
"grad_norm": 1.8652944564819336,
|
| 85 |
+
"learning_rate": 1.785e-05,
|
| 86 |
+
"loss": 0.6609,
|
| 87 |
+
"step": 44
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 12.0,
|
| 91 |
+
"grad_norm": 0.9776553511619568,
|
| 92 |
+
"learning_rate": 1.7650000000000002e-05,
|
| 93 |
+
"loss": 0.6554,
|
| 94 |
+
"step": 48
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 13.0,
|
| 98 |
+
"grad_norm": 0.8226175308227539,
|
| 99 |
+
"learning_rate": 1.7450000000000004e-05,
|
| 100 |
+
"loss": 0.6505,
|
| 101 |
+
"step": 52
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 14.0,
|
| 105 |
+
"grad_norm": 1.9432940483093262,
|
| 106 |
+
"learning_rate": 1.7250000000000003e-05,
|
| 107 |
+
"loss": 0.6451,
|
| 108 |
+
"step": 56
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 15.0,
|
| 112 |
+
"grad_norm": 1.1705070734024048,
|
| 113 |
+
"learning_rate": 1.705e-05,
|
| 114 |
+
"loss": 0.6293,
|
| 115 |
+
"step": 60
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 16.0,
|
| 119 |
+
"grad_norm": 1.1913769245147705,
|
| 120 |
+
"learning_rate": 1.6850000000000003e-05,
|
| 121 |
+
"loss": 0.6219,
|
| 122 |
+
"step": 64
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 17.0,
|
| 126 |
+
"grad_norm": 1.1586858034133911,
|
| 127 |
+
"learning_rate": 1.665e-05,
|
| 128 |
+
"loss": 0.6151,
|
| 129 |
+
"step": 68
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 18.0,
|
| 133 |
+
"grad_norm": 1.3686275482177734,
|
| 134 |
+
"learning_rate": 1.645e-05,
|
| 135 |
+
"loss": 0.6057,
|
| 136 |
+
"step": 72
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 19.0,
|
| 140 |
+
"grad_norm": 1.2270820140838623,
|
| 141 |
+
"learning_rate": 1.6250000000000002e-05,
|
| 142 |
+
"loss": 0.5921,
|
| 143 |
+
"step": 76
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 20.0,
|
| 147 |
+
"grad_norm": 2.155693531036377,
|
| 148 |
+
"learning_rate": 1.605e-05,
|
| 149 |
+
"loss": 0.5771,
|
| 150 |
+
"step": 80
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 21.0,
|
| 154 |
+
"grad_norm": 1.8586078882217407,
|
| 155 |
+
"learning_rate": 1.5850000000000002e-05,
|
| 156 |
+
"loss": 0.562,
|
| 157 |
+
"step": 84
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 22.0,
|
| 161 |
+
"grad_norm": 2.844381809234619,
|
| 162 |
+
"learning_rate": 1.565e-05,
|
| 163 |
+
"loss": 0.5483,
|
| 164 |
+
"step": 88
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 23.0,
|
| 168 |
+
"grad_norm": 1.532677412033081,
|
| 169 |
+
"learning_rate": 1.545e-05,
|
| 170 |
+
"loss": 0.5323,
|
| 171 |
+
"step": 92
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 24.0,
|
| 175 |
+
"grad_norm": 2.501610040664673,
|
| 176 |
+
"learning_rate": 1.525e-05,
|
| 177 |
+
"loss": 0.5106,
|
| 178 |
+
"step": 96
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 25.0,
|
| 182 |
+
"grad_norm": 3.366448402404785,
|
| 183 |
+
"learning_rate": 1.505e-05,
|
| 184 |
+
"loss": 0.5028,
|
| 185 |
+
"step": 100
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 26.0,
|
| 189 |
+
"grad_norm": 2.540175199508667,
|
| 190 |
+
"learning_rate": 1.4850000000000002e-05,
|
| 191 |
+
"loss": 0.4743,
|
| 192 |
+
"step": 104
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 27.0,
|
| 196 |
+
"grad_norm": 2.1043853759765625,
|
| 197 |
+
"learning_rate": 1.4650000000000002e-05,
|
| 198 |
+
"loss": 0.4676,
|
| 199 |
+
"step": 108
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 28.0,
|
| 203 |
+
"grad_norm": 2.4121694564819336,
|
| 204 |
+
"learning_rate": 1.4450000000000002e-05,
|
| 205 |
+
"loss": 0.4418,
|
| 206 |
+
"step": 112
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"epoch": 29.0,
|
| 210 |
+
"grad_norm": 1.871059775352478,
|
| 211 |
+
"learning_rate": 1.425e-05,
|
| 212 |
+
"loss": 0.4188,
|
| 213 |
+
"step": 116
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"epoch": 30.0,
|
| 217 |
+
"grad_norm": 3.22082257270813,
|
| 218 |
+
"learning_rate": 1.4050000000000001e-05,
|
| 219 |
+
"loss": 0.3973,
|
| 220 |
+
"step": 120
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"epoch": 31.0,
|
| 224 |
+
"grad_norm": 2.0184738636016846,
|
| 225 |
+
"learning_rate": 1.3850000000000001e-05,
|
| 226 |
+
"loss": 0.3767,
|
| 227 |
+
"step": 124
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"epoch": 32.0,
|
| 231 |
+
"grad_norm": 1.8004070520401,
|
| 232 |
+
"learning_rate": 1.3650000000000001e-05,
|
| 233 |
+
"loss": 0.3687,
|
| 234 |
+
"step": 128
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"epoch": 33.0,
|
| 238 |
+
"grad_norm": 2.161533832550049,
|
| 239 |
+
"learning_rate": 1.3450000000000002e-05,
|
| 240 |
+
"loss": 0.3419,
|
| 241 |
+
"step": 132
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"epoch": 34.0,
|
| 245 |
+
"grad_norm": 2.215999126434326,
|
| 246 |
+
"learning_rate": 1.325e-05,
|
| 247 |
+
"loss": 0.3259,
|
| 248 |
+
"step": 136
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"epoch": 35.0,
|
| 252 |
+
"grad_norm": 1.8289316892623901,
|
| 253 |
+
"learning_rate": 1.305e-05,
|
| 254 |
+
"loss": 0.2965,
|
| 255 |
+
"step": 140
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"epoch": 36.0,
|
| 259 |
+
"grad_norm": 1.7603213787078857,
|
| 260 |
+
"learning_rate": 1.285e-05,
|
| 261 |
+
"loss": 0.2784,
|
| 262 |
+
"step": 144
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"epoch": 37.0,
|
| 266 |
+
"grad_norm": 1.9211527109146118,
|
| 267 |
+
"learning_rate": 1.2650000000000001e-05,
|
| 268 |
+
"loss": 0.2624,
|
| 269 |
+
"step": 148
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"epoch": 38.0,
|
| 273 |
+
"grad_norm": 1.7408591508865356,
|
| 274 |
+
"learning_rate": 1.2450000000000003e-05,
|
| 275 |
+
"loss": 0.2301,
|
| 276 |
+
"step": 152
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"epoch": 39.0,
|
| 280 |
+
"grad_norm": 1.8422377109527588,
|
| 281 |
+
"learning_rate": 1.2250000000000001e-05,
|
| 282 |
+
"loss": 0.2316,
|
| 283 |
+
"step": 156
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"epoch": 40.0,
|
| 287 |
+
"grad_norm": 2.905261754989624,
|
| 288 |
+
"learning_rate": 1.2050000000000002e-05,
|
| 289 |
+
"loss": 0.2066,
|
| 290 |
+
"step": 160
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"epoch": 41.0,
|
| 294 |
+
"grad_norm": 1.5432759523391724,
|
| 295 |
+
"learning_rate": 1.1850000000000002e-05,
|
| 296 |
+
"loss": 0.2084,
|
| 297 |
+
"step": 164
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"epoch": 42.0,
|
| 301 |
+
"grad_norm": 1.6602318286895752,
|
| 302 |
+
"learning_rate": 1.1650000000000002e-05,
|
| 303 |
+
"loss": 0.1901,
|
| 304 |
+
"step": 168
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"epoch": 43.0,
|
| 308 |
+
"grad_norm": 1.7276387214660645,
|
| 309 |
+
"learning_rate": 1.145e-05,
|
| 310 |
+
"loss": 0.1635,
|
| 311 |
+
"step": 172
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"epoch": 44.0,
|
| 315 |
+
"grad_norm": 3.0626723766326904,
|
| 316 |
+
"learning_rate": 1.125e-05,
|
| 317 |
+
"loss": 0.1493,
|
| 318 |
+
"step": 176
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"epoch": 45.0,
|
| 322 |
+
"grad_norm": 1.6950130462646484,
|
| 323 |
+
"learning_rate": 1.1050000000000001e-05,
|
| 324 |
+
"loss": 0.128,
|
| 325 |
+
"step": 180
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"epoch": 46.0,
|
| 329 |
+
"grad_norm": 1.41054105758667,
|
| 330 |
+
"learning_rate": 1.0850000000000001e-05,
|
| 331 |
+
"loss": 0.1241,
|
| 332 |
+
"step": 184
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"epoch": 47.0,
|
| 336 |
+
"grad_norm": 1.694176435470581,
|
| 337 |
+
"learning_rate": 1.065e-05,
|
| 338 |
+
"loss": 0.126,
|
| 339 |
+
"step": 188
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"epoch": 48.0,
|
| 343 |
+
"grad_norm": 1.3726774454116821,
|
| 344 |
+
"learning_rate": 1.045e-05,
|
| 345 |
+
"loss": 0.1127,
|
| 346 |
+
"step": 192
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"epoch": 49.0,
|
| 350 |
+
"grad_norm": 2.0337917804718018,
|
| 351 |
+
"learning_rate": 1.025e-05,
|
| 352 |
+
"loss": 0.1056,
|
| 353 |
+
"step": 196
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"epoch": 50.0,
|
| 357 |
+
"grad_norm": 1.3560911417007446,
|
| 358 |
+
"learning_rate": 1.005e-05,
|
| 359 |
+
"loss": 0.0995,
|
| 360 |
+
"step": 200
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"epoch": 51.0,
|
| 364 |
+
"grad_norm": 1.0479848384857178,
|
| 365 |
+
"learning_rate": 9.85e-06,
|
| 366 |
+
"loss": 0.0848,
|
| 367 |
+
"step": 204
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"epoch": 52.0,
|
| 371 |
+
"grad_norm": 0.9078042507171631,
|
| 372 |
+
"learning_rate": 9.65e-06,
|
| 373 |
+
"loss": 0.0789,
|
| 374 |
+
"step": 208
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"epoch": 53.0,
|
| 378 |
+
"grad_norm": 1.6278938055038452,
|
| 379 |
+
"learning_rate": 9.450000000000001e-06,
|
| 380 |
+
"loss": 0.077,
|
| 381 |
+
"step": 212
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"epoch": 54.0,
|
| 385 |
+
"grad_norm": 1.9590917825698853,
|
| 386 |
+
"learning_rate": 9.250000000000001e-06,
|
| 387 |
+
"loss": 0.0807,
|
| 388 |
+
"step": 216
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"epoch": 55.0,
|
| 392 |
+
"grad_norm": 1.2972891330718994,
|
| 393 |
+
"learning_rate": 9.050000000000001e-06,
|
| 394 |
+
"loss": 0.0614,
|
| 395 |
+
"step": 220
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"epoch": 56.0,
|
| 399 |
+
"grad_norm": 0.8540873527526855,
|
| 400 |
+
"learning_rate": 8.85e-06,
|
| 401 |
+
"loss": 0.0606,
|
| 402 |
+
"step": 224
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"epoch": 57.0,
|
| 406 |
+
"grad_norm": 0.6654326319694519,
|
| 407 |
+
"learning_rate": 8.65e-06,
|
| 408 |
+
"loss": 0.0551,
|
| 409 |
+
"step": 228
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"epoch": 58.0,
|
| 413 |
+
"grad_norm": 0.9245683550834656,
|
| 414 |
+
"learning_rate": 8.45e-06,
|
| 415 |
+
"loss": 0.054,
|
| 416 |
+
"step": 232
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"epoch": 59.0,
|
| 420 |
+
"grad_norm": 0.5625425577163696,
|
| 421 |
+
"learning_rate": 8.25e-06,
|
| 422 |
+
"loss": 0.0496,
|
| 423 |
+
"step": 236
|
| 424 |
+
},
|
| 425 |
+
{
|
| 426 |
+
"epoch": 60.0,
|
| 427 |
+
"grad_norm": 0.664634644985199,
|
| 428 |
+
"learning_rate": 8.050000000000001e-06,
|
| 429 |
+
"loss": 0.0493,
|
| 430 |
+
"step": 240
|
| 431 |
+
},
|
| 432 |
+
{
|
| 433 |
+
"epoch": 61.0,
|
| 434 |
+
"grad_norm": 0.5101817846298218,
|
| 435 |
+
"learning_rate": 7.850000000000001e-06,
|
| 436 |
+
"loss": 0.0442,
|
| 437 |
+
"step": 244
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"epoch": 62.0,
|
| 441 |
+
"grad_norm": 0.5927309393882751,
|
| 442 |
+
"learning_rate": 7.650000000000001e-06,
|
| 443 |
+
"loss": 0.0423,
|
| 444 |
+
"step": 248
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"epoch": 63.0,
|
| 448 |
+
"grad_norm": 0.7394993305206299,
|
| 449 |
+
"learning_rate": 7.450000000000001e-06,
|
| 450 |
+
"loss": 0.0434,
|
| 451 |
+
"step": 252
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"epoch": 64.0,
|
| 455 |
+
"grad_norm": 0.653026819229126,
|
| 456 |
+
"learning_rate": 7.25e-06,
|
| 457 |
+
"loss": 0.0373,
|
| 458 |
+
"step": 256
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"epoch": 65.0,
|
| 462 |
+
"grad_norm": 0.4957493543624878,
|
| 463 |
+
"learning_rate": 7.05e-06,
|
| 464 |
+
"loss": 0.0345,
|
| 465 |
+
"step": 260
|
| 466 |
+
},
|
| 467 |
+
{
|
| 468 |
+
"epoch": 66.0,
|
| 469 |
+
"grad_norm": 0.6404949426651001,
|
| 470 |
+
"learning_rate": 6.850000000000001e-06,
|
| 471 |
+
"loss": 0.0347,
|
| 472 |
+
"step": 264
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"epoch": 67.0,
|
| 476 |
+
"grad_norm": 0.4832979440689087,
|
| 477 |
+
"learning_rate": 6.650000000000001e-06,
|
| 478 |
+
"loss": 0.0318,
|
| 479 |
+
"step": 268
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"epoch": 68.0,
|
| 483 |
+
"grad_norm": 0.5346927046775818,
|
| 484 |
+
"learning_rate": 6.450000000000001e-06,
|
| 485 |
+
"loss": 0.0334,
|
| 486 |
+
"step": 272
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"epoch": 69.0,
|
| 490 |
+
"grad_norm": 0.46299833059310913,
|
| 491 |
+
"learning_rate": 6.25e-06,
|
| 492 |
+
"loss": 0.0329,
|
| 493 |
+
"step": 276
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"epoch": 70.0,
|
| 497 |
+
"grad_norm": 0.39228323101997375,
|
| 498 |
+
"learning_rate": 6.0500000000000005e-06,
|
| 499 |
+
"loss": 0.0299,
|
| 500 |
+
"step": 280
|
| 501 |
+
},
|
| 502 |
+
{
|
| 503 |
+
"epoch": 71.0,
|
| 504 |
+
"grad_norm": 0.4643970727920532,
|
| 505 |
+
"learning_rate": 5.85e-06,
|
| 506 |
+
"loss": 0.0297,
|
| 507 |
+
"step": 284
|
| 508 |
+
},
|
| 509 |
+
{
|
| 510 |
+
"epoch": 72.0,
|
| 511 |
+
"grad_norm": 0.4702988862991333,
|
| 512 |
+
"learning_rate": 5.65e-06,
|
| 513 |
+
"loss": 0.0292,
|
| 514 |
+
"step": 288
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"epoch": 73.0,
|
| 518 |
+
"grad_norm": 0.4042636752128601,
|
| 519 |
+
"learning_rate": 5.450000000000001e-06,
|
| 520 |
+
"loss": 0.0276,
|
| 521 |
+
"step": 292
|
| 522 |
+
},
|
| 523 |
+
{
|
| 524 |
+
"epoch": 74.0,
|
| 525 |
+
"grad_norm": 0.49854159355163574,
|
| 526 |
+
"learning_rate": 5.2500000000000006e-06,
|
| 527 |
+
"loss": 0.0285,
|
| 528 |
+
"step": 296
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"epoch": 75.0,
|
| 532 |
+
"grad_norm": 0.33747512102127075,
|
| 533 |
+
"learning_rate": 5.050000000000001e-06,
|
| 534 |
+
"loss": 0.0259,
|
| 535 |
+
"step": 300
|
| 536 |
+
},
|
| 537 |
+
{
|
| 538 |
+
"epoch": 76.0,
|
| 539 |
+
"grad_norm": 0.5222832560539246,
|
| 540 |
+
"learning_rate": 4.85e-06,
|
| 541 |
+
"loss": 0.027,
|
| 542 |
+
"step": 304
|
| 543 |
+
},
|
| 544 |
+
{
|
| 545 |
+
"epoch": 77.0,
|
| 546 |
+
"grad_norm": 0.3840760588645935,
|
| 547 |
+
"learning_rate": 4.65e-06,
|
| 548 |
+
"loss": 0.0257,
|
| 549 |
+
"step": 308
|
| 550 |
+
},
|
| 551 |
+
{
|
| 552 |
+
"epoch": 78.0,
|
| 553 |
+
"grad_norm": 0.3676559627056122,
|
| 554 |
+
"learning_rate": 4.450000000000001e-06,
|
| 555 |
+
"loss": 0.0253,
|
| 556 |
+
"step": 312
|
| 557 |
+
},
|
| 558 |
+
{
|
| 559 |
+
"epoch": 79.0,
|
| 560 |
+
"grad_norm": 0.3206919729709625,
|
| 561 |
+
"learning_rate": 4.25e-06,
|
| 562 |
+
"loss": 0.0245,
|
| 563 |
+
"step": 316
|
| 564 |
+
},
|
| 565 |
+
{
|
| 566 |
+
"epoch": 80.0,
|
| 567 |
+
"grad_norm": 0.38936108350753784,
|
| 568 |
+
"learning_rate": 4.05e-06,
|
| 569 |
+
"loss": 0.0246,
|
| 570 |
+
"step": 320
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"epoch": 81.0,
|
| 574 |
+
"grad_norm": 1.3330600261688232,
|
| 575 |
+
"learning_rate": 3.85e-06,
|
| 576 |
+
"loss": 0.0245,
|
| 577 |
+
"step": 324
|
| 578 |
+
},
|
| 579 |
+
{
|
| 580 |
+
"epoch": 82.0,
|
| 581 |
+
"grad_norm": 0.3317999839782715,
|
| 582 |
+
"learning_rate": 3.65e-06,
|
| 583 |
+
"loss": 0.0225,
|
| 584 |
+
"step": 328
|
| 585 |
+
},
|
| 586 |
+
{
|
| 587 |
+
"epoch": 83.0,
|
| 588 |
+
"grad_norm": 0.35797789692878723,
|
| 589 |
+
"learning_rate": 3.45e-06,
|
| 590 |
+
"loss": 0.0237,
|
| 591 |
+
"step": 332
|
| 592 |
+
},
|
| 593 |
+
{
|
| 594 |
+
"epoch": 84.0,
|
| 595 |
+
"grad_norm": 0.3166642189025879,
|
| 596 |
+
"learning_rate": 3.2500000000000002e-06,
|
| 597 |
+
"loss": 0.0233,
|
| 598 |
+
"step": 336
|
| 599 |
+
},
|
| 600 |
+
{
|
| 601 |
+
"epoch": 85.0,
|
| 602 |
+
"grad_norm": 0.3116203248500824,
|
| 603 |
+
"learning_rate": 3.05e-06,
|
| 604 |
+
"loss": 0.0235,
|
| 605 |
+
"step": 340
|
| 606 |
+
},
|
| 607 |
+
{
|
| 608 |
+
"epoch": 86.0,
|
| 609 |
+
"grad_norm": 0.3509286940097809,
|
| 610 |
+
"learning_rate": 2.85e-06,
|
| 611 |
+
"loss": 0.0221,
|
| 612 |
+
"step": 344
|
| 613 |
+
},
|
| 614 |
+
{
|
| 615 |
+
"epoch": 87.0,
|
| 616 |
+
"grad_norm": 0.33957698941230774,
|
| 617 |
+
"learning_rate": 2.6500000000000005e-06,
|
| 618 |
+
"loss": 0.0219,
|
| 619 |
+
"step": 348
|
| 620 |
+
},
|
| 621 |
+
{
|
| 622 |
+
"epoch": 88.0,
|
| 623 |
+
"grad_norm": 0.36599016189575195,
|
| 624 |
+
"learning_rate": 2.4500000000000003e-06,
|
| 625 |
+
"loss": 0.0219,
|
| 626 |
+
"step": 352
|
| 627 |
+
},
|
| 628 |
+
{
|
| 629 |
+
"epoch": 89.0,
|
| 630 |
+
"grad_norm": 0.30192670226097107,
|
| 631 |
+
"learning_rate": 2.25e-06,
|
| 632 |
+
"loss": 0.0215,
|
| 633 |
+
"step": 356
|
| 634 |
+
},
|
| 635 |
+
{
|
| 636 |
+
"epoch": 90.0,
|
| 637 |
+
"grad_norm": 0.4861908257007599,
|
| 638 |
+
"learning_rate": 2.05e-06,
|
| 639 |
+
"loss": 0.0216,
|
| 640 |
+
"step": 360
|
| 641 |
+
},
|
| 642 |
+
{
|
| 643 |
+
"epoch": 91.0,
|
| 644 |
+
"grad_norm": 0.43383175134658813,
|
| 645 |
+
"learning_rate": 1.85e-06,
|
| 646 |
+
"loss": 0.0211,
|
| 647 |
+
"step": 364
|
| 648 |
+
},
|
| 649 |
+
{
|
| 650 |
+
"epoch": 92.0,
|
| 651 |
+
"grad_norm": 0.32720497250556946,
|
| 652 |
+
"learning_rate": 1.6500000000000003e-06,
|
| 653 |
+
"loss": 0.0218,
|
| 654 |
+
"step": 368
|
| 655 |
+
},
|
| 656 |
+
{
|
| 657 |
+
"epoch": 93.0,
|
| 658 |
+
"grad_norm": 0.36105918884277344,
|
| 659 |
+
"learning_rate": 1.45e-06,
|
| 660 |
+
"loss": 0.0212,
|
| 661 |
+
"step": 372
|
| 662 |
+
},
|
| 663 |
+
{
|
| 664 |
+
"epoch": 94.0,
|
| 665 |
+
"grad_norm": 0.3829093277454376,
|
| 666 |
+
"learning_rate": 1.25e-06,
|
| 667 |
+
"loss": 0.0202,
|
| 668 |
+
"step": 376
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"epoch": 95.0,
|
| 672 |
+
"grad_norm": 0.3548564016819,
|
| 673 |
+
"learning_rate": 1.0500000000000001e-06,
|
| 674 |
+
"loss": 0.0216,
|
| 675 |
+
"step": 380
|
| 676 |
+
},
|
| 677 |
+
{
|
| 678 |
+
"epoch": 96.0,
|
| 679 |
+
"grad_norm": 0.52253657579422,
|
| 680 |
+
"learning_rate": 8.500000000000001e-07,
|
| 681 |
+
"loss": 0.0211,
|
| 682 |
+
"step": 384
|
| 683 |
+
},
|
| 684 |
+
{
|
| 685 |
+
"epoch": 97.0,
|
| 686 |
+
"grad_norm": 0.29113584756851196,
|
| 687 |
+
"learning_rate": 6.5e-07,
|
| 688 |
+
"loss": 0.0216,
|
| 689 |
+
"step": 388
|
| 690 |
+
},
|
| 691 |
+
{
|
| 692 |
+
"epoch": 98.0,
|
| 693 |
+
"grad_norm": 0.35965240001678467,
|
| 694 |
+
"learning_rate": 4.5000000000000003e-07,
|
| 695 |
+
"loss": 0.0209,
|
| 696 |
+
"step": 392
|
| 697 |
+
},
|
| 698 |
+
{
|
| 699 |
+
"epoch": 99.0,
|
| 700 |
+
"grad_norm": 0.2798146605491638,
|
| 701 |
+
"learning_rate": 2.5000000000000004e-07,
|
| 702 |
+
"loss": 0.0208,
|
| 703 |
+
"step": 396
|
| 704 |
+
},
|
| 705 |
+
{
|
| 706 |
+
"epoch": 100.0,
|
| 707 |
+
"grad_norm": 0.30020079016685486,
|
| 708 |
+
"learning_rate": 5.0000000000000004e-08,
|
| 709 |
+
"loss": 0.0208,
|
| 710 |
+
"step": 400
|
| 711 |
+
}
|
| 712 |
+
],
|
| 713 |
+
"logging_steps": 500,
|
| 714 |
+
"max_steps": 400,
|
| 715 |
+
"num_input_tokens_seen": 0,
|
| 716 |
+
"num_train_epochs": 100,
|
| 717 |
+
"save_steps": 500,
|
| 718 |
+
"stateful_callbacks": {
|
| 719 |
+
"TrainerControl": {
|
| 720 |
+
"args": {
|
| 721 |
+
"should_epoch_stop": false,
|
| 722 |
+
"should_evaluate": false,
|
| 723 |
+
"should_log": false,
|
| 724 |
+
"should_save": true,
|
| 725 |
+
"should_training_stop": true
|
| 726 |
+
},
|
| 727 |
+
"attributes": {}
|
| 728 |
+
}
|
| 729 |
+
},
|
| 730 |
+
"total_flos": 23228751974400.0,
|
| 731 |
+
"train_batch_size": 16,
|
| 732 |
+
"trial_name": null,
|
| 733 |
+
"trial_params": null
|
| 734 |
+
}
|
src/binary_model/checkpoint-400/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f896b886afa5b5445f6670dc5555187301baf60f22ffc418da5b22535c2d9b7
|
| 3 |
+
size 5841
|
src/binary_model/runs/Nov19_19-55-03_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763571303.vadim-HP-Laptop-15s-eq1xxx.128002.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:53658ae0abd1e72ffd984ba81c8e70ba0b924e0893117cf404b99e33198dea5b
|
| 3 |
+
size 26486
|
src/category_model/checkpoint-400/config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"emb_size": 312,
|
| 9 |
+
"gradient_checkpointing": false,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"hidden_size": 312,
|
| 13 |
+
"id2label": {
|
| 14 |
+
"0": "\u043a\u0443\u043b\u044c\u0442\u0443\u0440\u0430",
|
| 15 |
+
"1": "\u043f\u043e\u043b\u0438\u0442\u0438\u043a\u0430",
|
| 16 |
+
"2": "\u0441\u043f\u043e\u0440\u0442",
|
| 17 |
+
"3": "\u044d\u043a\u043e\u043d\u043e\u043c\u0438\u043a\u0430"
|
| 18 |
+
},
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 600,
|
| 21 |
+
"label2id": {
|
| 22 |
+
"\u043a\u0443\u043b\u044c\u0442\u0443\u0440\u0430": 0,
|
| 23 |
+
"\u043f\u043e\u043b\u0438\u0442\u0438\u043a\u0430": 1,
|
| 24 |
+
"\u0441\u043f\u043e\u0440\u0442": 2,
|
| 25 |
+
"\u044d\u043a\u043e\u043d\u043e\u043c\u0438\u043a\u0430": 3
|
| 26 |
+
},
|
| 27 |
+
"layer_norm_eps": 1e-12,
|
| 28 |
+
"max_position_embeddings": 512,
|
| 29 |
+
"model_type": "bert",
|
| 30 |
+
"num_attention_heads": 12,
|
| 31 |
+
"num_hidden_layers": 3,
|
| 32 |
+
"pad_token_id": 0,
|
| 33 |
+
"position_embedding_type": "absolute",
|
| 34 |
+
"problem_type": "single_label_classification",
|
| 35 |
+
"transformers_version": "4.57.0",
|
| 36 |
+
"type_vocab_size": 2,
|
| 37 |
+
"use_cache": true,
|
| 38 |
+
"vocab_size": 29564
|
| 39 |
+
}
|
src/category_model/checkpoint-400/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d171659daf0b9242a1355e1f748db5f1fcb61a51537acafa86dd40ee74ed82f
|
| 3 |
+
size 47148128
|
src/category_model/checkpoint-400/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:87a1c7b428390f59facee41bc1170aabd0e0fcd483460b5361b4b834c02b1f88
|
| 3 |
+
size 94328139
|
src/category_model/checkpoint-400/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:80f4b2228e39d06e9189b284c4d5fbf902b5c3de450a889d8a8ff7c84c225c15
|
| 3 |
+
size 14455
|
src/category_model/checkpoint-400/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2a13aedfb10730f658c90eb03b45908f32db49971da0e67d91a64b15a963525
|
| 3 |
+
size 1465
|
src/category_model/checkpoint-400/trainer_state.json
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 100.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 400,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 1.0,
|
| 14 |
+
"grad_norm": 3.455427408218384,
|
| 15 |
+
"learning_rate": 1.985e-05,
|
| 16 |
+
"loss": 1.372,
|
| 17 |
+
"step": 4
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 2.0,
|
| 21 |
+
"grad_norm": 2.609254837036133,
|
| 22 |
+
"learning_rate": 1.9650000000000003e-05,
|
| 23 |
+
"loss": 1.3635,
|
| 24 |
+
"step": 8
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 3.0,
|
| 28 |
+
"grad_norm": 3.7442009449005127,
|
| 29 |
+
"learning_rate": 1.9450000000000002e-05,
|
| 30 |
+
"loss": 1.3225,
|
| 31 |
+
"step": 12
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 4.0,
|
| 35 |
+
"grad_norm": 3.725454568862915,
|
| 36 |
+
"learning_rate": 1.925e-05,
|
| 37 |
+
"loss": 1.3125,
|
| 38 |
+
"step": 16
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 5.0,
|
| 42 |
+
"grad_norm": 3.681874990463257,
|
| 43 |
+
"learning_rate": 1.9050000000000002e-05,
|
| 44 |
+
"loss": 1.2952,
|
| 45 |
+
"step": 20
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 6.0,
|
| 49 |
+
"grad_norm": 4.666588306427002,
|
| 50 |
+
"learning_rate": 1.885e-05,
|
| 51 |
+
"loss": 1.3249,
|
| 52 |
+
"step": 24
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 7.0,
|
| 56 |
+
"grad_norm": 2.378511428833008,
|
| 57 |
+
"learning_rate": 1.8650000000000003e-05,
|
| 58 |
+
"loss": 1.28,
|
| 59 |
+
"step": 28
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 8.0,
|
| 63 |
+
"grad_norm": 4.4941725730896,
|
| 64 |
+
"learning_rate": 1.845e-05,
|
| 65 |
+
"loss": 1.2995,
|
| 66 |
+
"step": 32
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 9.0,
|
| 70 |
+
"grad_norm": 3.8662474155426025,
|
| 71 |
+
"learning_rate": 1.825e-05,
|
| 72 |
+
"loss": 1.2559,
|
| 73 |
+
"step": 36
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 10.0,
|
| 77 |
+
"grad_norm": 3.9468913078308105,
|
| 78 |
+
"learning_rate": 1.805e-05,
|
| 79 |
+
"loss": 1.2227,
|
| 80 |
+
"step": 40
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 11.0,
|
| 84 |
+
"grad_norm": 3.712695360183716,
|
| 85 |
+
"learning_rate": 1.785e-05,
|
| 86 |
+
"loss": 1.2797,
|
| 87 |
+
"step": 44
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 12.0,
|
| 91 |
+
"grad_norm": 2.5063178539276123,
|
| 92 |
+
"learning_rate": 1.7650000000000002e-05,
|
| 93 |
+
"loss": 1.232,
|
| 94 |
+
"step": 48
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 13.0,
|
| 98 |
+
"grad_norm": 5.289675712585449,
|
| 99 |
+
"learning_rate": 1.7450000000000004e-05,
|
| 100 |
+
"loss": 1.2657,
|
| 101 |
+
"step": 52
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 14.0,
|
| 105 |
+
"grad_norm": 3.8385415077209473,
|
| 106 |
+
"learning_rate": 1.7250000000000003e-05,
|
| 107 |
+
"loss": 1.179,
|
| 108 |
+
"step": 56
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 15.0,
|
| 112 |
+
"grad_norm": 4.603259086608887,
|
| 113 |
+
"learning_rate": 1.705e-05,
|
| 114 |
+
"loss": 1.2584,
|
| 115 |
+
"step": 60
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 16.0,
|
| 119 |
+
"grad_norm": 3.869927406311035,
|
| 120 |
+
"learning_rate": 1.6850000000000003e-05,
|
| 121 |
+
"loss": 1.1632,
|
| 122 |
+
"step": 64
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 17.0,
|
| 126 |
+
"grad_norm": 3.1474037170410156,
|
| 127 |
+
"learning_rate": 1.665e-05,
|
| 128 |
+
"loss": 1.1961,
|
| 129 |
+
"step": 68
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 18.0,
|
| 133 |
+
"grad_norm": 5.3139424324035645,
|
| 134 |
+
"learning_rate": 1.645e-05,
|
| 135 |
+
"loss": 1.1511,
|
| 136 |
+
"step": 72
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 19.0,
|
| 140 |
+
"grad_norm": 7.552919387817383,
|
| 141 |
+
"learning_rate": 1.6250000000000002e-05,
|
| 142 |
+
"loss": 1.1045,
|
| 143 |
+
"step": 76
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 20.0,
|
| 147 |
+
"grad_norm": 3.9046547412872314,
|
| 148 |
+
"learning_rate": 1.605e-05,
|
| 149 |
+
"loss": 1.1123,
|
| 150 |
+
"step": 80
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 21.0,
|
| 154 |
+
"grad_norm": 4.048532962799072,
|
| 155 |
+
"learning_rate": 1.5850000000000002e-05,
|
| 156 |
+
"loss": 1.0915,
|
| 157 |
+
"step": 84
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 22.0,
|
| 161 |
+
"grad_norm": 4.49423885345459,
|
| 162 |
+
"learning_rate": 1.565e-05,
|
| 163 |
+
"loss": 1.0988,
|
| 164 |
+
"step": 88
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 23.0,
|
| 168 |
+
"grad_norm": 7.544517517089844,
|
| 169 |
+
"learning_rate": 1.545e-05,
|
| 170 |
+
"loss": 1.1344,
|
| 171 |
+
"step": 92
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 24.0,
|
| 175 |
+
"grad_norm": 5.467407703399658,
|
| 176 |
+
"learning_rate": 1.525e-05,
|
| 177 |
+
"loss": 1.1133,
|
| 178 |
+
"step": 96
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 25.0,
|
| 182 |
+
"grad_norm": 3.8020665645599365,
|
| 183 |
+
"learning_rate": 1.505e-05,
|
| 184 |
+
"loss": 1.0526,
|
| 185 |
+
"step": 100
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 26.0,
|
| 189 |
+
"grad_norm": 4.590837478637695,
|
| 190 |
+
"learning_rate": 1.4850000000000002e-05,
|
| 191 |
+
"loss": 1.0252,
|
| 192 |
+
"step": 104
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 27.0,
|
| 196 |
+
"grad_norm": 5.339500427246094,
|
| 197 |
+
"learning_rate": 1.4650000000000002e-05,
|
| 198 |
+
"loss": 0.9814,
|
| 199 |
+
"step": 108
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 28.0,
|
| 203 |
+
"grad_norm": 3.674020528793335,
|
| 204 |
+
"learning_rate": 1.4450000000000002e-05,
|
| 205 |
+
"loss": 0.9983,
|
| 206 |
+
"step": 112
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"epoch": 29.0,
|
| 210 |
+
"grad_norm": 7.431573867797852,
|
| 211 |
+
"learning_rate": 1.425e-05,
|
| 212 |
+
"loss": 0.9252,
|
| 213 |
+
"step": 116
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"epoch": 30.0,
|
| 217 |
+
"grad_norm": 6.031501770019531,
|
| 218 |
+
"learning_rate": 1.4050000000000001e-05,
|
| 219 |
+
"loss": 1.0251,
|
| 220 |
+
"step": 120
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"epoch": 31.0,
|
| 224 |
+
"grad_norm": 6.9349799156188965,
|
| 225 |
+
"learning_rate": 1.3850000000000001e-05,
|
| 226 |
+
"loss": 0.8766,
|
| 227 |
+
"step": 124
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"epoch": 32.0,
|
| 231 |
+
"grad_norm": 3.9678494930267334,
|
| 232 |
+
"learning_rate": 1.3650000000000001e-05,
|
| 233 |
+
"loss": 0.9679,
|
| 234 |
+
"step": 128
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"epoch": 33.0,
|
| 238 |
+
"grad_norm": 7.587587356567383,
|
| 239 |
+
"learning_rate": 1.3450000000000002e-05,
|
| 240 |
+
"loss": 0.9791,
|
| 241 |
+
"step": 132
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"epoch": 34.0,
|
| 245 |
+
"grad_norm": 6.44096040725708,
|
| 246 |
+
"learning_rate": 1.325e-05,
|
| 247 |
+
"loss": 0.8987,
|
| 248 |
+
"step": 136
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"epoch": 35.0,
|
| 252 |
+
"grad_norm": 5.321375846862793,
|
| 253 |
+
"learning_rate": 1.305e-05,
|
| 254 |
+
"loss": 0.8359,
|
| 255 |
+
"step": 140
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"epoch": 36.0,
|
| 259 |
+
"grad_norm": 4.376260757446289,
|
| 260 |
+
"learning_rate": 1.285e-05,
|
| 261 |
+
"loss": 0.8474,
|
| 262 |
+
"step": 144
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"epoch": 37.0,
|
| 266 |
+
"grad_norm": 5.06814432144165,
|
| 267 |
+
"learning_rate": 1.2650000000000001e-05,
|
| 268 |
+
"loss": 0.8176,
|
| 269 |
+
"step": 148
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"epoch": 38.0,
|
| 273 |
+
"grad_norm": 4.7853899002075195,
|
| 274 |
+
"learning_rate": 1.2450000000000003e-05,
|
| 275 |
+
"loss": 0.8357,
|
| 276 |
+
"step": 152
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"epoch": 39.0,
|
| 280 |
+
"grad_norm": 3.8893511295318604,
|
| 281 |
+
"learning_rate": 1.2250000000000001e-05,
|
| 282 |
+
"loss": 0.7759,
|
| 283 |
+
"step": 156
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"epoch": 40.0,
|
| 287 |
+
"grad_norm": 4.117180824279785,
|
| 288 |
+
"learning_rate": 1.2050000000000002e-05,
|
| 289 |
+
"loss": 0.776,
|
| 290 |
+
"step": 160
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"epoch": 41.0,
|
| 294 |
+
"grad_norm": 3.4978015422821045,
|
| 295 |
+
"learning_rate": 1.1850000000000002e-05,
|
| 296 |
+
"loss": 0.8084,
|
| 297 |
+
"step": 164
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"epoch": 42.0,
|
| 301 |
+
"grad_norm": 3.2947819232940674,
|
| 302 |
+
"learning_rate": 1.1650000000000002e-05,
|
| 303 |
+
"loss": 0.7224,
|
| 304 |
+
"step": 168
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"epoch": 43.0,
|
| 308 |
+
"grad_norm": 7.838773250579834,
|
| 309 |
+
"learning_rate": 1.145e-05,
|
| 310 |
+
"loss": 0.7542,
|
| 311 |
+
"step": 172
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"epoch": 44.0,
|
| 315 |
+
"grad_norm": 5.788562297821045,
|
| 316 |
+
"learning_rate": 1.125e-05,
|
| 317 |
+
"loss": 0.733,
|
| 318 |
+
"step": 176
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"epoch": 45.0,
|
| 322 |
+
"grad_norm": 6.079436779022217,
|
| 323 |
+
"learning_rate": 1.1050000000000001e-05,
|
| 324 |
+
"loss": 0.7207,
|
| 325 |
+
"step": 180
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"epoch": 46.0,
|
| 329 |
+
"grad_norm": 5.334987640380859,
|
| 330 |
+
"learning_rate": 1.0850000000000001e-05,
|
| 331 |
+
"loss": 0.6308,
|
| 332 |
+
"step": 184
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"epoch": 47.0,
|
| 336 |
+
"grad_norm": 3.619874954223633,
|
| 337 |
+
"learning_rate": 1.065e-05,
|
| 338 |
+
"loss": 0.6654,
|
| 339 |
+
"step": 188
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"epoch": 48.0,
|
| 343 |
+
"grad_norm": 4.317775726318359,
|
| 344 |
+
"learning_rate": 1.045e-05,
|
| 345 |
+
"loss": 0.5905,
|
| 346 |
+
"step": 192
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"epoch": 49.0,
|
| 350 |
+
"grad_norm": 4.5337748527526855,
|
| 351 |
+
"learning_rate": 1.025e-05,
|
| 352 |
+
"loss": 0.6396,
|
| 353 |
+
"step": 196
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"epoch": 50.0,
|
| 357 |
+
"grad_norm": 3.27056884765625,
|
| 358 |
+
"learning_rate": 1.005e-05,
|
| 359 |
+
"loss": 0.5575,
|
| 360 |
+
"step": 200
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"epoch": 51.0,
|
| 364 |
+
"grad_norm": 5.867713928222656,
|
| 365 |
+
"learning_rate": 9.85e-06,
|
| 366 |
+
"loss": 0.5961,
|
| 367 |
+
"step": 204
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"epoch": 52.0,
|
| 371 |
+
"grad_norm": 4.240111351013184,
|
| 372 |
+
"learning_rate": 9.65e-06,
|
| 373 |
+
"loss": 0.5803,
|
| 374 |
+
"step": 208
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"epoch": 53.0,
|
| 378 |
+
"grad_norm": 8.360318183898926,
|
| 379 |
+
"learning_rate": 9.450000000000001e-06,
|
| 380 |
+
"loss": 0.6097,
|
| 381 |
+
"step": 212
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"epoch": 54.0,
|
| 385 |
+
"grad_norm": 5.395203590393066,
|
| 386 |
+
"learning_rate": 9.250000000000001e-06,
|
| 387 |
+
"loss": 0.5235,
|
| 388 |
+
"step": 216
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"epoch": 55.0,
|
| 392 |
+
"grad_norm": 8.306116104125977,
|
| 393 |
+
"learning_rate": 9.050000000000001e-06,
|
| 394 |
+
"loss": 0.6212,
|
| 395 |
+
"step": 220
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"epoch": 56.0,
|
| 399 |
+
"grad_norm": 4.548465728759766,
|
| 400 |
+
"learning_rate": 8.85e-06,
|
| 401 |
+
"loss": 0.4939,
|
| 402 |
+
"step": 224
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"epoch": 57.0,
|
| 406 |
+
"grad_norm": 5.0567755699157715,
|
| 407 |
+
"learning_rate": 8.65e-06,
|
| 408 |
+
"loss": 0.492,
|
| 409 |
+
"step": 228
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"epoch": 58.0,
|
| 413 |
+
"grad_norm": 3.3125669956207275,
|
| 414 |
+
"learning_rate": 8.45e-06,
|
| 415 |
+
"loss": 0.4867,
|
| 416 |
+
"step": 232
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"epoch": 59.0,
|
| 420 |
+
"grad_norm": 9.607614517211914,
|
| 421 |
+
"learning_rate": 8.25e-06,
|
| 422 |
+
"loss": 0.4979,
|
| 423 |
+
"step": 236
|
| 424 |
+
},
|
| 425 |
+
{
|
| 426 |
+
"epoch": 60.0,
|
| 427 |
+
"grad_norm": 4.669170379638672,
|
| 428 |
+
"learning_rate": 8.050000000000001e-06,
|
| 429 |
+
"loss": 0.5232,
|
| 430 |
+
"step": 240
|
| 431 |
+
},
|
| 432 |
+
{
|
| 433 |
+
"epoch": 61.0,
|
| 434 |
+
"grad_norm": 3.661278247833252,
|
| 435 |
+
"learning_rate": 7.850000000000001e-06,
|
| 436 |
+
"loss": 0.4184,
|
| 437 |
+
"step": 244
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"epoch": 62.0,
|
| 441 |
+
"grad_norm": 6.294672012329102,
|
| 442 |
+
"learning_rate": 7.650000000000001e-06,
|
| 443 |
+
"loss": 0.4472,
|
| 444 |
+
"step": 248
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"epoch": 63.0,
|
| 448 |
+
"grad_norm": 3.544436454772949,
|
| 449 |
+
"learning_rate": 7.450000000000001e-06,
|
| 450 |
+
"loss": 0.449,
|
| 451 |
+
"step": 252
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"epoch": 64.0,
|
| 455 |
+
"grad_norm": 5.267669200897217,
|
| 456 |
+
"learning_rate": 7.25e-06,
|
| 457 |
+
"loss": 0.4789,
|
| 458 |
+
"step": 256
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"epoch": 65.0,
|
| 462 |
+
"grad_norm": 7.3072333335876465,
|
| 463 |
+
"learning_rate": 7.05e-06,
|
| 464 |
+
"loss": 0.4661,
|
| 465 |
+
"step": 260
|
| 466 |
+
},
|
| 467 |
+
{
|
| 468 |
+
"epoch": 66.0,
|
| 469 |
+
"grad_norm": 2.6512272357940674,
|
| 470 |
+
"learning_rate": 6.850000000000001e-06,
|
| 471 |
+
"loss": 0.3777,
|
| 472 |
+
"step": 264
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"epoch": 67.0,
|
| 476 |
+
"grad_norm": 4.13808536529541,
|
| 477 |
+
"learning_rate": 6.650000000000001e-06,
|
| 478 |
+
"loss": 0.4238,
|
| 479 |
+
"step": 268
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"epoch": 68.0,
|
| 483 |
+
"grad_norm": 3.1775310039520264,
|
| 484 |
+
"learning_rate": 6.450000000000001e-06,
|
| 485 |
+
"loss": 0.3932,
|
| 486 |
+
"step": 272
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"epoch": 69.0,
|
| 490 |
+
"grad_norm": 3.4776253700256348,
|
| 491 |
+
"learning_rate": 6.25e-06,
|
| 492 |
+
"loss": 0.3601,
|
| 493 |
+
"step": 276
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"epoch": 70.0,
|
| 497 |
+
"grad_norm": 4.582927227020264,
|
| 498 |
+
"learning_rate": 6.0500000000000005e-06,
|
| 499 |
+
"loss": 0.4413,
|
| 500 |
+
"step": 280
|
| 501 |
+
},
|
| 502 |
+
{
|
| 503 |
+
"epoch": 71.0,
|
| 504 |
+
"grad_norm": 2.587031364440918,
|
| 505 |
+
"learning_rate": 5.85e-06,
|
| 506 |
+
"loss": 0.3916,
|
| 507 |
+
"step": 284
|
| 508 |
+
},
|
| 509 |
+
{
|
| 510 |
+
"epoch": 72.0,
|
| 511 |
+
"grad_norm": 3.7085821628570557,
|
| 512 |
+
"learning_rate": 5.65e-06,
|
| 513 |
+
"loss": 0.4055,
|
| 514 |
+
"step": 288
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"epoch": 73.0,
|
| 518 |
+
"grad_norm": 5.436678886413574,
|
| 519 |
+
"learning_rate": 5.450000000000001e-06,
|
| 520 |
+
"loss": 0.3487,
|
| 521 |
+
"step": 292
|
| 522 |
+
},
|
| 523 |
+
{
|
| 524 |
+
"epoch": 74.0,
|
| 525 |
+
"grad_norm": 5.039726734161377,
|
| 526 |
+
"learning_rate": 5.2500000000000006e-06,
|
| 527 |
+
"loss": 0.3582,
|
| 528 |
+
"step": 296
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"epoch": 75.0,
|
| 532 |
+
"grad_norm": 4.922318935394287,
|
| 533 |
+
"learning_rate": 5.050000000000001e-06,
|
| 534 |
+
"loss": 0.3563,
|
| 535 |
+
"step": 300
|
| 536 |
+
},
|
| 537 |
+
{
|
| 538 |
+
"epoch": 76.0,
|
| 539 |
+
"grad_norm": 4.511425971984863,
|
| 540 |
+
"learning_rate": 4.85e-06,
|
| 541 |
+
"loss": 0.3747,
|
| 542 |
+
"step": 304
|
| 543 |
+
},
|
| 544 |
+
{
|
| 545 |
+
"epoch": 77.0,
|
| 546 |
+
"grad_norm": 2.0960898399353027,
|
| 547 |
+
"learning_rate": 4.65e-06,
|
| 548 |
+
"loss": 0.3277,
|
| 549 |
+
"step": 308
|
| 550 |
+
},
|
| 551 |
+
{
|
| 552 |
+
"epoch": 78.0,
|
| 553 |
+
"grad_norm": 1.7806938886642456,
|
| 554 |
+
"learning_rate": 4.450000000000001e-06,
|
| 555 |
+
"loss": 0.317,
|
| 556 |
+
"step": 312
|
| 557 |
+
},
|
| 558 |
+
{
|
| 559 |
+
"epoch": 79.0,
|
| 560 |
+
"grad_norm": 3.6240742206573486,
|
| 561 |
+
"learning_rate": 4.25e-06,
|
| 562 |
+
"loss": 0.3506,
|
| 563 |
+
"step": 316
|
| 564 |
+
},
|
| 565 |
+
{
|
| 566 |
+
"epoch": 80.0,
|
| 567 |
+
"grad_norm": 3.0891218185424805,
|
| 568 |
+
"learning_rate": 4.05e-06,
|
| 569 |
+
"loss": 0.3342,
|
| 570 |
+
"step": 320
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"epoch": 81.0,
|
| 574 |
+
"grad_norm": 3.1899912357330322,
|
| 575 |
+
"learning_rate": 3.85e-06,
|
| 576 |
+
"loss": 0.3692,
|
| 577 |
+
"step": 324
|
| 578 |
+
},
|
| 579 |
+
{
|
| 580 |
+
"epoch": 82.0,
|
| 581 |
+
"grad_norm": 1.9796233177185059,
|
| 582 |
+
"learning_rate": 3.65e-06,
|
| 583 |
+
"loss": 0.3087,
|
| 584 |
+
"step": 328
|
| 585 |
+
},
|
| 586 |
+
{
|
| 587 |
+
"epoch": 83.0,
|
| 588 |
+
"grad_norm": 4.603359222412109,
|
| 589 |
+
"learning_rate": 3.45e-06,
|
| 590 |
+
"loss": 0.3533,
|
| 591 |
+
"step": 332
|
| 592 |
+
},
|
| 593 |
+
{
|
| 594 |
+
"epoch": 84.0,
|
| 595 |
+
"grad_norm": 5.730408668518066,
|
| 596 |
+
"learning_rate": 3.2500000000000002e-06,
|
| 597 |
+
"loss": 0.3498,
|
| 598 |
+
"step": 336
|
| 599 |
+
},
|
| 600 |
+
{
|
| 601 |
+
"epoch": 85.0,
|
| 602 |
+
"grad_norm": 6.595205783843994,
|
| 603 |
+
"learning_rate": 3.05e-06,
|
| 604 |
+
"loss": 0.3618,
|
| 605 |
+
"step": 340
|
| 606 |
+
},
|
| 607 |
+
{
|
| 608 |
+
"epoch": 86.0,
|
| 609 |
+
"grad_norm": 11.516875267028809,
|
| 610 |
+
"learning_rate": 2.85e-06,
|
| 611 |
+
"loss": 0.3932,
|
| 612 |
+
"step": 344
|
| 613 |
+
},
|
| 614 |
+
{
|
| 615 |
+
"epoch": 87.0,
|
| 616 |
+
"grad_norm": 3.7310776710510254,
|
| 617 |
+
"learning_rate": 2.6500000000000005e-06,
|
| 618 |
+
"loss": 0.329,
|
| 619 |
+
"step": 348
|
| 620 |
+
},
|
| 621 |
+
{
|
| 622 |
+
"epoch": 88.0,
|
| 623 |
+
"grad_norm": 2.2054193019866943,
|
| 624 |
+
"learning_rate": 2.4500000000000003e-06,
|
| 625 |
+
"loss": 0.3097,
|
| 626 |
+
"step": 352
|
| 627 |
+
},
|
| 628 |
+
{
|
| 629 |
+
"epoch": 89.0,
|
| 630 |
+
"grad_norm": 2.450695037841797,
|
| 631 |
+
"learning_rate": 2.25e-06,
|
| 632 |
+
"loss": 0.3052,
|
| 633 |
+
"step": 356
|
| 634 |
+
},
|
| 635 |
+
{
|
| 636 |
+
"epoch": 90.0,
|
| 637 |
+
"grad_norm": 2.2963459491729736,
|
| 638 |
+
"learning_rate": 2.05e-06,
|
| 639 |
+
"loss": 0.3126,
|
| 640 |
+
"step": 360
|
| 641 |
+
},
|
| 642 |
+
{
|
| 643 |
+
"epoch": 91.0,
|
| 644 |
+
"grad_norm": 3.7548775672912598,
|
| 645 |
+
"learning_rate": 1.85e-06,
|
| 646 |
+
"loss": 0.3343,
|
| 647 |
+
"step": 364
|
| 648 |
+
},
|
| 649 |
+
{
|
| 650 |
+
"epoch": 92.0,
|
| 651 |
+
"grad_norm": 1.9919285774230957,
|
| 652 |
+
"learning_rate": 1.6500000000000003e-06,
|
| 653 |
+
"loss": 0.2815,
|
| 654 |
+
"step": 368
|
| 655 |
+
},
|
| 656 |
+
{
|
| 657 |
+
"epoch": 93.0,
|
| 658 |
+
"grad_norm": 3.4772584438323975,
|
| 659 |
+
"learning_rate": 1.45e-06,
|
| 660 |
+
"loss": 0.3316,
|
| 661 |
+
"step": 372
|
| 662 |
+
},
|
| 663 |
+
{
|
| 664 |
+
"epoch": 94.0,
|
| 665 |
+
"grad_norm": 2.701188564300537,
|
| 666 |
+
"learning_rate": 1.25e-06,
|
| 667 |
+
"loss": 0.3175,
|
| 668 |
+
"step": 376
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"epoch": 95.0,
|
| 672 |
+
"grad_norm": 2.582921266555786,
|
| 673 |
+
"learning_rate": 1.0500000000000001e-06,
|
| 674 |
+
"loss": 0.2869,
|
| 675 |
+
"step": 380
|
| 676 |
+
},
|
| 677 |
+
{
|
| 678 |
+
"epoch": 96.0,
|
| 679 |
+
"grad_norm": 3.1191177368164062,
|
| 680 |
+
"learning_rate": 8.500000000000001e-07,
|
| 681 |
+
"loss": 0.3155,
|
| 682 |
+
"step": 384
|
| 683 |
+
},
|
| 684 |
+
{
|
| 685 |
+
"epoch": 97.0,
|
| 686 |
+
"grad_norm": 4.482059478759766,
|
| 687 |
+
"learning_rate": 6.5e-07,
|
| 688 |
+
"loss": 0.3525,
|
| 689 |
+
"step": 388
|
| 690 |
+
},
|
| 691 |
+
{
|
| 692 |
+
"epoch": 98.0,
|
| 693 |
+
"grad_norm": 3.904967784881592,
|
| 694 |
+
"learning_rate": 4.5000000000000003e-07,
|
| 695 |
+
"loss": 0.3339,
|
| 696 |
+
"step": 392
|
| 697 |
+
},
|
| 698 |
+
{
|
| 699 |
+
"epoch": 99.0,
|
| 700 |
+
"grad_norm": 5.207050323486328,
|
| 701 |
+
"learning_rate": 2.5000000000000004e-07,
|
| 702 |
+
"loss": 0.3153,
|
| 703 |
+
"step": 396
|
| 704 |
+
},
|
| 705 |
+
{
|
| 706 |
+
"epoch": 100.0,
|
| 707 |
+
"grad_norm": 4.7919745445251465,
|
| 708 |
+
"learning_rate": 5.0000000000000004e-08,
|
| 709 |
+
"loss": 0.3194,
|
| 710 |
+
"step": 400
|
| 711 |
+
}
|
| 712 |
+
],
|
| 713 |
+
"logging_steps": 500,
|
| 714 |
+
"max_steps": 400,
|
| 715 |
+
"num_input_tokens_seen": 0,
|
| 716 |
+
"num_train_epochs": 100,
|
| 717 |
+
"save_steps": 500,
|
| 718 |
+
"stateful_callbacks": {
|
| 719 |
+
"TrainerControl": {
|
| 720 |
+
"args": {
|
| 721 |
+
"should_epoch_stop": false,
|
| 722 |
+
"should_evaluate": false,
|
| 723 |
+
"should_log": false,
|
| 724 |
+
"should_save": true,
|
| 725 |
+
"should_training_stop": true
|
| 726 |
+
},
|
| 727 |
+
"attributes": {}
|
| 728 |
+
}
|
| 729 |
+
},
|
| 730 |
+
"total_flos": 18809131622400.0,
|
| 731 |
+
"train_batch_size": 16,
|
| 732 |
+
"trial_name": null,
|
| 733 |
+
"trial_params": null
|
| 734 |
+
}
|
src/category_model/checkpoint-400/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1c578e2b795ff2107b64fd5f8c207b5966e7b63550fcf63db2897f2e5d55fcc
|
| 3 |
+
size 5841
|
src/category_model/runs/Nov19_16-34-43_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763559283.vadim-HP-Laptop-15s-eq1xxx.119293.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5883fd5d4a434aa097bd4f10219658cde4fb327f41063da35f097c27473aede7
|
| 3 |
+
size 12093
|
src/category_model/runs/Nov19_16-46-52_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763560013.vadim-HP-Laptop-15s-eq1xxx.120060.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b68fcd71c4edc7c64c33b783e6abaf286da860db632393842f9229d750f5079
|
| 3 |
+
size 26866
|
src/ml_binary.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cbda2dff7c7d0f5b0e98a62c7ba63df2000e571d84347a232bbc98d2233c6fa4
|
| 3 |
+
size 2651188
|
src/ml_category.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5dda5f9be7a82fcce0467fb6cd222ddf0dc69edf4bfaa25910ec7285463fb3a
|
| 3 |
+
size 3561992
|
src/ml_categorys.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7cfee13464c870a68c591c2b5005b24530d4fb5ee538671e91d05ab73d5f61d1
|
| 3 |
+
size 3954215
|
src/multilabel_model/checkpoint-700/config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"emb_size": 312,
|
| 9 |
+
"gradient_checkpointing": false,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"hidden_size": 312,
|
| 13 |
+
"id2label": {
|
| 14 |
+
"0": "LABEL_0",
|
| 15 |
+
"1": "LABEL_1",
|
| 16 |
+
"2": "LABEL_2",
|
| 17 |
+
"3": "LABEL_3"
|
| 18 |
+
},
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 600,
|
| 21 |
+
"label2id": {
|
| 22 |
+
"LABEL_0": 0,
|
| 23 |
+
"LABEL_1": 1,
|
| 24 |
+
"LABEL_2": 2,
|
| 25 |
+
"LABEL_3": 3
|
| 26 |
+
},
|
| 27 |
+
"layer_norm_eps": 1e-12,
|
| 28 |
+
"max_position_embeddings": 512,
|
| 29 |
+
"model_type": "bert",
|
| 30 |
+
"num_attention_heads": 12,
|
| 31 |
+
"num_hidden_layers": 3,
|
| 32 |
+
"pad_token_id": 0,
|
| 33 |
+
"position_embedding_type": "absolute",
|
| 34 |
+
"problem_type": "multi_label_classification",
|
| 35 |
+
"transformers_version": "4.57.0",
|
| 36 |
+
"type_vocab_size": 2,
|
| 37 |
+
"use_cache": true,
|
| 38 |
+
"vocab_size": 29564
|
| 39 |
+
}
|
src/multilabel_model/checkpoint-700/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7cc2d851cd93e4c007e0b237673fb86cd555754dca253a379ec239b7acaea1df
|
| 3 |
+
size 47148128
|
src/multilabel_model/checkpoint-700/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e70436bc9452a2ffdf997eb531c55061f9aab4883a53385ee85478431dfeacb
|
| 3 |
+
size 94328139
|
src/multilabel_model/checkpoint-700/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22564cefda826550225bd0fd67a748fd7cb1eeafc2de68c4b5a5600037865e5c
|
| 3 |
+
size 14455
|
src/multilabel_model/checkpoint-700/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a7430e3783a2d51c132011d61c22c31c8fec694606fd4e0d6f02426d5c36b37
|
| 3 |
+
size 1465
|
src/multilabel_model/checkpoint-700/trainer_state.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 100.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 700,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 71.42857142857143,
|
| 14 |
+
"grad_norm": 0.5301845669746399,
|
| 15 |
+
"learning_rate": 0.0008614285714285715,
|
| 16 |
+
"loss": 0.6088,
|
| 17 |
+
"step": 500
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"logging_steps": 500,
|
| 21 |
+
"max_steps": 700,
|
| 22 |
+
"num_input_tokens_seen": 0,
|
| 23 |
+
"num_train_epochs": 100,
|
| 24 |
+
"save_steps": 500,
|
| 25 |
+
"stateful_callbacks": {
|
| 26 |
+
"TrainerControl": {
|
| 27 |
+
"args": {
|
| 28 |
+
"should_epoch_stop": false,
|
| 29 |
+
"should_evaluate": false,
|
| 30 |
+
"should_log": false,
|
| 31 |
+
"should_save": true,
|
| 32 |
+
"should_training_stop": true
|
| 33 |
+
},
|
| 34 |
+
"attributes": {}
|
| 35 |
+
}
|
| 36 |
+
},
|
| 37 |
+
"total_flos": 18440325120000.0,
|
| 38 |
+
"train_batch_size": 8,
|
| 39 |
+
"trial_name": null,
|
| 40 |
+
"trial_params": null
|
| 41 |
+
}
|
src/multilabel_model/checkpoint-700/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7c3a9d8bd2f33a889f206bdf109d0a17adc92a638dcd4560c2c2ee7c25e5b5dd
|
| 3 |
+
size 5841
|
src/multilabel_model/runs/Nov19_18-32-50_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763566371.vadim-HP-Laptop-15s-eq1xxx.124852.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d2783d509460017f1b7d648662ea982a1f5396f2e7c00f4c11dbb544a735cc2
|
| 3 |
+
size 5282
|
src/multilabel_model/runs/Nov19_18-43-43_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763567023.vadim-HP-Laptop-15s-eq1xxx.125134.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dafabd32f21f5da3b60a5c0761ef6534d2056e706f11a059954470db19eafe02
|
| 3 |
+
size 5282
|
src/multilabel_model/runs/Nov19_18-50-20_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763567421.vadim-HP-Laptop-15s-eq1xxx.125341.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:957d4a549ffd58187f71c12287765ee092001e9ab79184c982bed40b6f04684d
|
| 3 |
+
size 5282
|
src/multilabel_model/runs/Nov19_18-56-03_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763567764.vadim-HP-Laptop-15s-eq1xxx.125471.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d994adcc370d91dba8a168cd68b7187a71c43389d7a23a0374fe853cfb2fcae7
|
| 3 |
+
size 5282
|
src/multilabel_model/runs/Nov19_19-12-54_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763568775.vadim-HP-Laptop-15s-eq1xxx.125830.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7ef3f43e0de37cd4c14f85dbec9e4c9e366fb8843d1c3dafb4f10eb0b40bd512
|
| 3 |
+
size 5282
|
src/multilabel_model/runs/Nov19_19-19-23_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763569164.vadim-HP-Laptop-15s-eq1xxx.126011.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a2c9137dc4ed0b364ae44e5d8b5fd57e8e6b253a6b90b8ce64bddc29d7c4ee5
|
| 3 |
+
size 4184
|
src/multilabel_model/runs/Nov19_19-20-09_vadim-HP-Laptop-15s-eq1xxx/events.out.tfevents.1763569210.vadim-HP-Laptop-15s-eq1xxx.126088.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88ccd53331d477ce2088f7a0647754563f0b733b0ae7546ad04cfc3971b870f8
|
| 3 |
+
size 5847
|
src/nn_binary.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:56eaba127bef8116f6d3d52e7243dd644af3fb91e9fab9c492d0474b980d1452
|
| 3 |
+
size 78129542
|
src/nn_category.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:994e7f52d66de06e363fde235a31688b8eccbd106485d80b98683479fa2e6124
|
| 3 |
+
size 38731315
|
src/nn_categorys.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b1c11cc21b869caba3002bcdeaad3763b51491aa320fbffb24cfb9505562c13e
|
| 3 |
+
size 78239148
|
src/nn_vectorizer_binary.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6632ffa21bd1c3acd6835ff4b1c09e8d87db21046e7fe548424f06ab972b37d3
|
| 3 |
+
size 367823
|
src/nn_vectorizer_category.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b82e82c137dd417b3d36b936218147bf1ac1b100ad74cb32810a86ec86ab960b
|
| 3 |
+
size 391149
|
src/nn_vectorizer_categorys.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d6a999f2756bed0a9b7febd2e41d564578a2417ee07fbf30f9c24521dcf8941
|
| 3 |
+
size 402881
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,555 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve, precision_score, recall_score, f1_score
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
# Импортируем ваши модули
|
| 10 |
+
try:
|
| 11 |
+
from use_ml import predict_sentiment, predict_category, predict_categorys
|
| 12 |
+
except ImportError:
|
| 13 |
+
st.error("Модуль use_ml не найден")
|
| 14 |
+
try:
|
| 15 |
+
from use_nn import predict_sentiment as nn_predict_sentiment
|
| 16 |
+
from use_nn import predict_category as nn_predict_category
|
| 17 |
+
from use_nn import predict_categorys as nn_predict_categorys
|
| 18 |
+
except ImportError:
|
| 19 |
+
st.error("Модуль use_nn не найден")
|
| 20 |
+
try:
|
| 21 |
+
from use_transformer import predict_sentiment as tf_predict_sentiment
|
| 22 |
+
from use_transformer import predict_category as tf_predict_category
|
| 23 |
+
from use_transformer import predict_categorys as tf_predict_categorys
|
| 24 |
+
except ImportError:
|
| 25 |
+
st.error("Модуль use_transformer не найден")
|
| 26 |
+
|
| 27 |
+
# Настройка страницы
|
| 28 |
+
st.set_page_config(
|
| 29 |
+
page_title="Анализ классификаторов текста",
|
| 30 |
+
page_icon="📊",
|
| 31 |
+
layout="wide"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_models(task_type):
|
| 36 |
+
"""Загрузка моделей в зависимости от типа задачи"""
|
| 37 |
+
models = {}
|
| 38 |
+
|
| 39 |
+
if task_type == "Бинарная":
|
| 40 |
+
try:
|
| 41 |
+
models["Классическая ML"] = predict_sentiment()
|
| 42 |
+
except:
|
| 43 |
+
pass
|
| 44 |
+
try:
|
| 45 |
+
models["Нейросеть"] = nn_predict_sentiment()
|
| 46 |
+
except:
|
| 47 |
+
pass
|
| 48 |
+
try:
|
| 49 |
+
models["Трансформер"] = tf_predict_sentiment()
|
| 50 |
+
except:
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
elif task_type == "Многоклассовая":
|
| 54 |
+
try:
|
| 55 |
+
models["Классическая ML"] = predict_category()
|
| 56 |
+
except:
|
| 57 |
+
pass
|
| 58 |
+
try:
|
| 59 |
+
models["Нейросеть"] = nn_predict_category()
|
| 60 |
+
except:
|
| 61 |
+
pass
|
| 62 |
+
try:
|
| 63 |
+
models["Трансформер"] = tf_predict_category()
|
| 64 |
+
except:
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
elif task_type == "Многометочная":
|
| 68 |
+
try:
|
| 69 |
+
models["Классическая ML"] = predict_categorys()
|
| 70 |
+
except:
|
| 71 |
+
pass
|
| 72 |
+
try:
|
| 73 |
+
models["Нейросеть"] = nn_predict_categorys()
|
| 74 |
+
except:
|
| 75 |
+
pass
|
| 76 |
+
try:
|
| 77 |
+
models["Трансформер"] = tf_predict_categorys()
|
| 78 |
+
except:
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
return models
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def plot_probabilities(probs, labels, model_name):
|
| 85 |
+
"""Визуализация вероятностей"""
|
| 86 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 87 |
+
y_pos = np.arange(len(labels))
|
| 88 |
+
|
| 89 |
+
if isinstance(probs, (np.ndarray, list)) and len(probs) > 1:
|
| 90 |
+
# Многоклассовая или многометочная
|
| 91 |
+
ax.barh(y_pos, probs, align='center')
|
| 92 |
+
ax.set_yticks(y_pos)
|
| 93 |
+
ax.set_yticklabels(labels)
|
| 94 |
+
ax.set_xlabel('Вероятность')
|
| 95 |
+
ax.set_title(f'Вероятности классов - {model_name}')
|
| 96 |
+
else:
|
| 97 |
+
# Бинарная
|
| 98 |
+
binary_probs = [1 - probs, probs] if isinstance(probs, (int, float)) else [1 - probs[0], probs[0]]
|
| 99 |
+
binary_labels = ['Negative', 'Positive']
|
| 100 |
+
ax.barh([0, 1], binary_probs, align='center')
|
| 101 |
+
ax.set_yticks([0, 1])
|
| 102 |
+
ax.set_yticklabels(binary_labels)
|
| 103 |
+
ax.set_xlabel('Вероятность')
|
| 104 |
+
ax.set_title(f'Вероятности классов - {model_name}')
|
| 105 |
+
|
| 106 |
+
plt.tight_layout()
|
| 107 |
+
return fig
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def calculate_and_display_binary_metrics(true_labels, predictions):
|
| 111 |
+
"""Расчет и отображение метрик для бинарной классификации"""
|
| 112 |
+
# Преобразуем true_labels в числовой формат
|
| 113 |
+
y_true = [1 if label == 'positive' else 0 for label in true_labels]
|
| 114 |
+
y_pred = [1 if pred['probs'] >= 0.5 else 0 for pred in predictions]
|
| 115 |
+
y_scores = [pred['probs'] for pred in predictions]
|
| 116 |
+
|
| 117 |
+
# ROC curve
|
| 118 |
+
fpr, tpr, _ = roc_curve(y_true, y_scores)
|
| 119 |
+
roc_auc = auc(fpr, tpr)
|
| 120 |
+
|
| 121 |
+
# Precision-Recall curve
|
| 122 |
+
precision, recall, _ = precision_recall_curve(y_true, y_scores)
|
| 123 |
+
|
| 124 |
+
# Визуализация
|
| 125 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
| 126 |
+
|
| 127 |
+
# ROC curve
|
| 128 |
+
ax1.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
|
| 129 |
+
ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
| 130 |
+
ax1.set_xlim([0.0, 1.0])
|
| 131 |
+
ax1.set_ylim([0.0, 1.05])
|
| 132 |
+
ax1.set_xlabel('False Positive Rate')
|
| 133 |
+
ax1.set_ylabel('True Positive Rate')
|
| 134 |
+
ax1.set_title('ROC Curve')
|
| 135 |
+
ax1.legend(loc="lower right")
|
| 136 |
+
|
| 137 |
+
# Precision-Recall curve
|
| 138 |
+
ax2.plot(recall, precision, color='blue', lw=2)
|
| 139 |
+
ax2.set_xlim([0.0, 1.0])
|
| 140 |
+
ax2.set_ylim([0.0, 1.05])
|
| 141 |
+
ax2.set_xlabel('Recall')
|
| 142 |
+
ax2.set_ylabel('Precision')
|
| 143 |
+
ax2.set_title('Precision-Recall Curve')
|
| 144 |
+
|
| 145 |
+
st.pyplot(fig)
|
| 146 |
+
|
| 147 |
+
# Матрица ошибок
|
| 148 |
+
cm = confusion_matrix(y_true, y_pred)
|
| 149 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 150 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
|
| 151 |
+
xticklabels=['Negative', 'Positive'],
|
| 152 |
+
yticklabels=['Negative', 'Positive'])
|
| 153 |
+
ax.set_title('Confusion Matrix')
|
| 154 |
+
ax.set_xlabel('Predicted')
|
| 155 |
+
ax.set_ylabel('Actual')
|
| 156 |
+
st.pyplot(fig)
|
| 157 |
+
|
| 158 |
+
# Отчет классификации
|
| 159 |
+
st.subheader("Отчет классификации")
|
| 160 |
+
report = classification_report(y_true, y_pred, output_dict=True)
|
| 161 |
+
report_df = pd.DataFrame(report).transpose()
|
| 162 |
+
st.dataframe(report_df, use_container_width=True)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def calculate_and_display_multiclass_metrics(true_labels, predictions):
|
| 166 |
+
"""Расчет и отображение метрик для многоклассовой классификации"""
|
| 167 |
+
# Получаем все уникальные классы
|
| 168 |
+
all_classes = list(set(true_labels))
|
| 169 |
+
|
| 170 |
+
# Предсказанные классы (класс с максимальной вероятностью)
|
| 171 |
+
y_pred = [pred['labels'][np.argmax(pred['probs'])] for pred in predictions]
|
| 172 |
+
y_true = true_labels
|
| 173 |
+
|
| 174 |
+
# Матрица ошибок
|
| 175 |
+
cm = confusion_matrix(y_true, y_pred, labels=all_classes)
|
| 176 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 177 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
|
| 178 |
+
xticklabels=all_classes, yticklabels=all_classes)
|
| 179 |
+
ax.set_title('Confusion Matrix')
|
| 180 |
+
ax.set_xlabel('Predicted')
|
| 181 |
+
ax.set_ylabel('Actual')
|
| 182 |
+
plt.xticks(rotation=45)
|
| 183 |
+
plt.yticks(rotation=0)
|
| 184 |
+
st.pyplot(fig)
|
| 185 |
+
|
| 186 |
+
# Отчет классификации
|
| 187 |
+
st.subheader("Отчет классификации")
|
| 188 |
+
report = classification_report(y_true, y_pred, output_dict=True)
|
| 189 |
+
report_df = pd.DataFrame(report).transpose()
|
| 190 |
+
st.dataframe(report_df, use_container_width=True)
|
| 191 |
+
|
| 192 |
+
# Визуализация точности по классам
|
| 193 |
+
class_report = classification_report(y_true, y_pred, output_dict=True)
|
| 194 |
+
classes_metrics = {}
|
| 195 |
+
for class_name in all_classes:
|
| 196 |
+
if class_name in class_report:
|
| 197 |
+
classes_metrics[class_name] = {
|
| 198 |
+
'Precision': class_report[class_name]['precision'],
|
| 199 |
+
'Recall': class_report[class_name]['recall'],
|
| 200 |
+
'F1-Score': class_report[class_name]['f1-score']
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
metrics_df = pd.DataFrame(classes_metrics).T
|
| 204 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 205 |
+
metrics_df.plot(kind='bar', ax=ax)
|
| 206 |
+
ax.set_title('Метрики по классам')
|
| 207 |
+
ax.set_ylabel('Score')
|
| 208 |
+
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
| 209 |
+
plt.xticks(rotation=45)
|
| 210 |
+
st.pyplot(fig)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def calculate_and_display_multilabel_metrics(true_labels, predictions):
|
| 214 |
+
"""Расчет и отображение метрик для многометочной классификации"""
|
| 215 |
+
# Получаем все возможные метки из предсказаний
|
| 216 |
+
all_labels = predictions[0]['labels']
|
| 217 |
+
|
| 218 |
+
# Создаем бинарные матрицы для истинных и предсказанных меток
|
| 219 |
+
y_true_binary = np.zeros((len(true_labels), len(all_labels)))
|
| 220 |
+
y_pred_binary = np.zeros((len(predictions), len(all_labels)))
|
| 221 |
+
|
| 222 |
+
for i, (true_label_list, pred) in enumerate(zip(true_labels, predictions)):
|
| 223 |
+
for j, label in enumerate(all_labels):
|
| 224 |
+
# Истинные метки
|
| 225 |
+
if label in true_label_list:
|
| 226 |
+
y_true_binary[i, j] = 1
|
| 227 |
+
|
| 228 |
+
# Предсказанные метки (порог 0.5)
|
| 229 |
+
if pred['probs'][j] >= 0.5:
|
| 230 |
+
y_pred_binary[i, j] = 1
|
| 231 |
+
|
| 232 |
+
# Вычисляем метрики для каждой метки
|
| 233 |
+
metrics_per_label = {}
|
| 234 |
+
for j, label in enumerate(all_labels):
|
| 235 |
+
metrics_per_label[label] = {
|
| 236 |
+
'Precision': precision_score(y_true_binary[:, j], y_pred_binary[:, j]),
|
| 237 |
+
'Recall': recall_score(y_true_binary[:, j], y_pred_binary[:, j]),
|
| 238 |
+
'F1-Score': f1_score(y_true_binary[:, j], y_pred_binary[:, j]),
|
| 239 |
+
'Support': np.sum(y_true_binary[:, j])
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
# Сводная таблица метрик
|
| 243 |
+
st.subheader("Метрики по меткам")
|
| 244 |
+
metrics_df = pd.DataFrame(metrics_per_label).T
|
| 245 |
+
st.dataframe(metrics_df, use_container_width=True)
|
| 246 |
+
|
| 247 |
+
# Визуализация метрик
|
| 248 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
| 249 |
+
|
| 250 |
+
# Precision по меткам
|
| 251 |
+
axes[0, 0].barh(range(len(all_labels)), [metrics_per_label[label]['Precision'] for label in all_labels])
|
| 252 |
+
axes[0, 0].set_yticks(range(len(all_labels)))
|
| 253 |
+
axes[0, 0].set_yticklabels(all_labels)
|
| 254 |
+
axes[0, 0].set_title('Precision по меткам')
|
| 255 |
+
axes[0, 0].set_xlim(0, 1)
|
| 256 |
+
|
| 257 |
+
# Recall по меткам
|
| 258 |
+
axes[0, 1].barh(range(len(all_labels)), [metrics_per_label[label]['Recall'] for label in all_labels])
|
| 259 |
+
axes[0, 1].set_yticks(range(len(all_labels)))
|
| 260 |
+
axes[0, 1].set_yticklabels(all_labels)
|
| 261 |
+
axes[0, 1].set_title('Recall по меткам')
|
| 262 |
+
axes[0, 1].set_xlim(0, 1)
|
| 263 |
+
|
| 264 |
+
# F1-Score по меткам
|
| 265 |
+
axes[1, 0].barh(range(len(all_labels)), [metrics_per_label[label]['F1-Score'] for label in all_labels])
|
| 266 |
+
axes[1, 0].set_yticks(range(len(all_labels)))
|
| 267 |
+
axes[1, 0].set_yticklabels(all_labels)
|
| 268 |
+
axes[1, 0].set_title('F1-Score по меткам')
|
| 269 |
+
axes[1, 0].set_xlim(0, 1)
|
| 270 |
+
|
| 271 |
+
# Support по меткам
|
| 272 |
+
axes[1, 1].barh(range(len(all_labels)), [metrics_per_label[label]['Support'] for label in all_labels])
|
| 273 |
+
axes[1, 1].set_yticks(range(len(all_labels)))
|
| 274 |
+
axes[1, 1].set_yticklabels(all_labels)
|
| 275 |
+
axes[1, 1].set_title('Support (количество примеров) по меткам')
|
| 276 |
+
|
| 277 |
+
plt.tight_layout()
|
| 278 |
+
st.pyplot(fig)
|
| 279 |
+
|
| 280 |
+
# Примеры предсказаний
|
| 281 |
+
st.subheader("Примеры предсказаний")
|
| 282 |
+
sample_indices = np.random.choice(len(predictions), min(5, len(predictions)), replace=False)
|
| 283 |
+
|
| 284 |
+
for idx in sample_indices:
|
| 285 |
+
with st.expander(f"Пример {idx + 1}"):
|
| 286 |
+
col1, col2 = st.columns(2)
|
| 287 |
+
|
| 288 |
+
with col1:
|
| 289 |
+
st.write("**Истинные метки:**")
|
| 290 |
+
st.write(true_labels[idx])
|
| 291 |
+
|
| 292 |
+
with col2:
|
| 293 |
+
st.write("**Предсказанные метки:**")
|
| 294 |
+
predicted_labels = [all_labels[i] for i, prob in enumerate(predictions[idx]['probs']) if prob >= 0.5]
|
| 295 |
+
st.write(predicted_labels)
|
| 296 |
+
|
| 297 |
+
st.write("**Вероятности:**")
|
| 298 |
+
prob_df = pd.DataFrame({
|
| 299 |
+
'Метка': all_labels,
|
| 300 |
+
'Вероятность': predictions[idx]['probs']
|
| 301 |
+
}).sort_values('Вероятность', ascending=False)
|
| 302 |
+
st.dataframe(prob_df, use_container_width=True)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def process_test_file(uploaded_file, task_type):
|
| 306 |
+
"""Обработка загруженного JSONL файла"""
|
| 307 |
+
data = []
|
| 308 |
+
for line in uploaded_file:
|
| 309 |
+
data.append(json.loads(line.decode('utf-8')))
|
| 310 |
+
|
| 311 |
+
df = pd.DataFrame(data)
|
| 312 |
+
return df
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def calculate_metrics(df, predictions, task_type):
|
| 316 |
+
"""Расчет метрик качества"""
|
| 317 |
+
if task_type == "Бинарная":
|
| 318 |
+
y_true = df['label'].apply(lambda x: 1 if x == 'positive' else 0)
|
| 319 |
+
y_pred = [1 if pred['probs'] >= 0.5 else 0 for pred in predictions]
|
| 320 |
+
y_scores = [pred['probs'] for pred in predictions]
|
| 321 |
+
|
| 322 |
+
# ROC curve
|
| 323 |
+
fpr, tpr, _ = roc_curve(y_true, y_scores)
|
| 324 |
+
roc_auc = auc(fpr, tpr)
|
| 325 |
+
|
| 326 |
+
# Precision-Recall curve
|
| 327 |
+
precision, recall, _ = precision_recall_curve(y_true, y_scores)
|
| 328 |
+
|
| 329 |
+
return {
|
| 330 |
+
'fpr': fpr,
|
| 331 |
+
'tpr': tpr,
|
| 332 |
+
'roc_auc': roc_auc,
|
| 333 |
+
'precision': precision,
|
| 334 |
+
'recall': recall,
|
| 335 |
+
'y_true': y_true,
|
| 336 |
+
'y_pred': y_pred,
|
| 337 |
+
'y_scores': y_scores
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
elif task_type == "Многоклассовая":
|
| 341 |
+
# Для многоклассовой нужна более сложная обработка
|
| 342 |
+
return {"message": "Многоклассовые метрики требуют дополнительной реализации"}
|
| 343 |
+
|
| 344 |
+
else:
|
| 345 |
+
return {"message": "Многометочные метрики требуют дополнительной реализации"}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# Основной интерфейс
|
| 349 |
+
st.title("📊 Анализ классификаторов текста")
|
| 350 |
+
|
| 351 |
+
# Сайдбар для навигации
|
| 352 |
+
st.sidebar.title("Навигация")
|
| 353 |
+
app_mode = st.sidebar.selectbox(
|
| 354 |
+
"Выберите режим",
|
| 355 |
+
["Интерактивная классификация", "Анализ тестовой выборки"],
|
| 356 |
+
key="main_navigation" # Уникальный ключ
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
# Интерактивная классификация
|
| 360 |
+
if app_mode == "Интерактивная классификация":
|
| 361 |
+
st.header("🔍 Интерактивная классификация")
|
| 362 |
+
|
| 363 |
+
col1, col2 = st.columns([1, 1])
|
| 364 |
+
|
| 365 |
+
with col1:
|
| 366 |
+
task_type = st.selectbox(
|
| 367 |
+
"Тип задачи",
|
| 368 |
+
["Бинарная", "Многоклассовая", "Многометочная"],
|
| 369 |
+
key="interactive_task_type" # Уникальный ключ
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
available_models = list(load_models(task_type).keys())
|
| 373 |
+
if not available_models:
|
| 374 |
+
st.error("Нет доступных моделей для выбранного типа задачи")
|
| 375 |
+
st.stop()
|
| 376 |
+
|
| 377 |
+
selected_models = st.multiselect(
|
| 378 |
+
"Выберите модели для сравнения",
|
| 379 |
+
available_models,
|
| 380 |
+
default=available_models[0] if available_models else None,
|
| 381 |
+
key="interactive_models" # Уникальный ключ
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
text_input = st.text_area(
|
| 385 |
+
"Введите текст для классификации",
|
| 386 |
+
height=150,
|
| 387 |
+
placeholder="Введите текст здесь...",
|
| 388 |
+
key="interactive_text_input" # Уникальный ключ
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
with col2:
|
| 392 |
+
if text_input and selected_models:
|
| 393 |
+
models = load_models(task_type)
|
| 394 |
+
|
| 395 |
+
for model_name in selected_models:
|
| 396 |
+
st.subheader(f"Модель: {model_name}")
|
| 397 |
+
|
| 398 |
+
try:
|
| 399 |
+
result = models[model_name](text_input)
|
| 400 |
+
|
| 401 |
+
# Отображение результатов
|
| 402 |
+
if task_type == "Бинарная":
|
| 403 |
+
sentiment = "Positive" if result['probs'] >= 0.5 else "Negative"
|
| 404 |
+
confidence = result['probs'] if result['probs'] >= 0.5 else 1 - result['probs']
|
| 405 |
+
|
| 406 |
+
st.write(f"**Результат**: {sentiment}")
|
| 407 |
+
st.write(f"**Уверенность**: {confidence:.3f}")
|
| 408 |
+
|
| 409 |
+
# Визуализация вероятностей
|
| 410 |
+
fig = plot_probabilities(result['probs'], result.get('labels', ['Negative', 'Positive']),
|
| 411 |
+
model_name)
|
| 412 |
+
st.pyplot(fig)
|
| 413 |
+
|
| 414 |
+
else:
|
| 415 |
+
if task_type == "Многоклассовая":
|
| 416 |
+
predicted_idx = np.argmax(result['probs'])
|
| 417 |
+
predicted_label = result['labels'][predicted_idx]
|
| 418 |
+
confidence = result['probs'][predicted_idx]
|
| 419 |
+
|
| 420 |
+
st.write(f"**Предсказанный класс**: {predicted_label}")
|
| 421 |
+
st.write(f"**Уверенность**: {confidence:.3f}")
|
| 422 |
+
|
| 423 |
+
else: # Многометочная
|
| 424 |
+
predicted_labels = [result['labels'][i] for i, prob in enumerate(result['probs']) if
|
| 425 |
+
prob >= 0.5]
|
| 426 |
+
st.write(f"**Предсказанные классы**: {', '.join(predicted_labels)}")
|
| 427 |
+
|
| 428 |
+
# Визуализация вероятностей
|
| 429 |
+
fig = plot_probabilities(result['probs'], result['labels'], model_name)
|
| 430 |
+
st.pyplot(fig)
|
| 431 |
+
|
| 432 |
+
# Таблица вероятностей
|
| 433 |
+
prob_df = pd.DataFrame({
|
| 434 |
+
'Класс': result['labels'],
|
| 435 |
+
'Вероятность': result['probs']
|
| 436 |
+
}).sort_values('Вероятность', ascending=False)
|
| 437 |
+
|
| 438 |
+
st.dataframe(prob_df, use_container_width=True)
|
| 439 |
+
|
| 440 |
+
except Exception as e:
|
| 441 |
+
st.write("График не поддерживается у данной модели")
|
| 442 |
+
|
| 443 |
+
# Анализ тестовой выборки
|
| 444 |
+
elif app_mode == "Анализ тестовой выборки":
|
| 445 |
+
st.header("📈 Анализ тестовой выборки")
|
| 446 |
+
|
| 447 |
+
uploaded_file = st.file_uploader(
|
| 448 |
+
"Загрузите JSONL файл с тестовой выборкой",
|
| 449 |
+
type=['jsonl'],
|
| 450 |
+
help="Файл должен содержать поля 'text' и 'label' (для бинарной/многоклассовой) или 'labels' (для многометочной)",
|
| 451 |
+
key="file_uploader"
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
if uploaded_file:
|
| 455 |
+
task_type = st.selectbox(
|
| 456 |
+
"Тип задачи для анализа",
|
| 457 |
+
["Бинарная", "Многоклассовая", "Многометочная"],
|
| 458 |
+
key="analysis_task_type"
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
available_models = list(load_models(task_type).keys())
|
| 462 |
+
if not available_models:
|
| 463 |
+
st.error("Нет доступных моделей для выбранного типа задачи")
|
| 464 |
+
st.stop()
|
| 465 |
+
|
| 466 |
+
selected_model = st.selectbox(
|
| 467 |
+
"Выберите модель для анализа",
|
| 468 |
+
available_models,
|
| 469 |
+
key="analysis_model"
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
if st.button("Запустить анализ", key="analyze_button"):
|
| 473 |
+
with st.spinner("Обработка данных..."):
|
| 474 |
+
# Загрузка и обработка данных
|
| 475 |
+
df = process_test_file(uploaded_file, task_type)
|
| 476 |
+
st.write(f"Загружено {len(df)} примеров")
|
| 477 |
+
|
| 478 |
+
# Проверка структуры данных
|
| 479 |
+
st.subheader("Структура данных")
|
| 480 |
+
st.dataframe(df.head(), use_container_width=True)
|
| 481 |
+
|
| 482 |
+
if task_type == "Многометочная" and 'labels' not in df.columns:
|
| 483 |
+
st.error("Для многометочной классификации в файле должно быть поле 'labels'")
|
| 484 |
+
st.stop()
|
| 485 |
+
elif task_type != "Многометочная" and 'label' not in df.columns:
|
| 486 |
+
st.error("Для бинарной и многоклассовой классификации в файле должно быть поле 'label'")
|
| 487 |
+
st.stop()
|
| 488 |
+
|
| 489 |
+
# Предсказания
|
| 490 |
+
model = load_models(task_type)[selected_model]
|
| 491 |
+
predictions = []
|
| 492 |
+
true_labels = []
|
| 493 |
+
|
| 494 |
+
progress_bar = st.progress(0)
|
| 495 |
+
for i, row in df.iterrows():
|
| 496 |
+
try:
|
| 497 |
+
result = model(row['text'])
|
| 498 |
+
predictions.append(result)
|
| 499 |
+
|
| 500 |
+
# Сохраняем истинные метки в нужном формате
|
| 501 |
+
if task_type == "Многометочная":
|
| 502 |
+
true_labels.append(row['labels'])
|
| 503 |
+
else:
|
| 504 |
+
true_labels.append(row['label'])
|
| 505 |
+
|
| 506 |
+
except Exception as e:
|
| 507 |
+
st.error(f"Ошибка при обработке примера {i}: {str(e)}")
|
| 508 |
+
predictions.append(None)
|
| 509 |
+
true_labels.append(None)
|
| 510 |
+
|
| 511 |
+
progress_bar.progress((i + 1) / len(df))
|
| 512 |
+
|
| 513 |
+
# Удаляем примеры с ошибками
|
| 514 |
+
valid_indices = [i for i, pred in enumerate(predictions) if pred is not None]
|
| 515 |
+
predictions = [predictions[i] for i in valid_indices]
|
| 516 |
+
true_labels = [true_labels[i] for i in valid_indices]
|
| 517 |
+
|
| 518 |
+
st.write(f"Успешно обработано {len(predictions)} из {len(df)} примеров")
|
| 519 |
+
|
| 520 |
+
# Расчет и отображение метрик для разных типов задач
|
| 521 |
+
if task_type == "Бинарная":
|
| 522 |
+
calculate_and_display_binary_metrics(true_labels, predictions)
|
| 523 |
+
|
| 524 |
+
elif task_type == "Многоклассовая":
|
| 525 |
+
calculate_and_display_multiclass_metrics(true_labels, predictions)
|
| 526 |
+
|
| 527 |
+
elif task_type == "Многометочная":
|
| 528 |
+
calculate_and_display_multilabel_metrics(true_labels, predictions)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# Информация в сайдбаре
|
| 532 |
+
st.sidebar.markdown("---")
|
| 533 |
+
st.sidebar.info("""
|
| 534 |
+
**Инструкция:**
|
| 535 |
+
1. **Интерактивная классификация**: Тестируйте модели на произвольном тексте
|
| 536 |
+
2. **Анализ тестовой выборки**: Загрузите JSONL файл для оценки качества
|
| 537 |
+
""")
|
| 538 |
|
| 539 |
+
# CSS для улучшения внешнего вида
|
| 540 |
+
st.markdown("""
|
| 541 |
+
<style>
|
| 542 |
+
.main-header {
|
| 543 |
+
font-size: 2.5rem;
|
| 544 |
+
color: #1f77b4;
|
| 545 |
+
text-align: center;
|
| 546 |
+
margin-bottom: 2rem;
|
| 547 |
+
}
|
| 548 |
+
.metric-card {
|
| 549 |
+
background-color: #f0f2f6;
|
| 550 |
+
padding: 1rem;
|
| 551 |
+
border-radius: 0.5rem;
|
| 552 |
+
margin: 0.5rem 0;
|
| 553 |
+
}
|
| 554 |
+
</style>
|
| 555 |
+
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/use_ml.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#1 привести полученный текст к приемлемому виду
|
| 2 |
+
#2 подать текст на вход к модели и получить результат
|
| 3 |
+
import spacy
|
| 4 |
+
from joblib import load
|
| 5 |
+
|
| 6 |
+
def predict_sentiment():
|
| 7 |
+
model_binary = load("ml_binary.joblib")
|
| 8 |
+
def _inner(text: str):
|
| 9 |
+
pred = model_binary.predict([preprocess_text(text)])[0]
|
| 10 |
+
res = {
|
| 11 |
+
"labels": "positive" if pred == 1 else "negative",
|
| 12 |
+
"probs": pred
|
| 13 |
+
}
|
| 14 |
+
return res
|
| 15 |
+
return _inner
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def predict_category():
|
| 20 |
+
model_category = load("ml_category.joblib")
|
| 21 |
+
|
| 22 |
+
def _inner(text: str):
|
| 23 |
+
pred = model_category.predict([preprocess_text(text)])[0]
|
| 24 |
+
labels = [
|
| 25 |
+
"политика",
|
| 26 |
+
"экономика",
|
| 27 |
+
"спорт",
|
| 28 |
+
"культура"
|
| 29 |
+
]
|
| 30 |
+
probs = [0, 0, 0, 0]
|
| 31 |
+
probs[pred] = 1
|
| 32 |
+
res = {
|
| 33 |
+
"labels": labels,
|
| 34 |
+
"probs": probs
|
| 35 |
+
}
|
| 36 |
+
return res
|
| 37 |
+
return _inner
|
| 38 |
+
|
| 39 |
+
def predict_categorys():
|
| 40 |
+
model_categorys = load("ml_categorys.joblib")
|
| 41 |
+
|
| 42 |
+
def _inner(text: str):
|
| 43 |
+
pred = model_categorys.predict([preprocess_text(text)])[0]
|
| 44 |
+
labels = [
|
| 45 |
+
"политика",
|
| 46 |
+
"экономика",
|
| 47 |
+
"спорт",
|
| 48 |
+
"культура"
|
| 49 |
+
]
|
| 50 |
+
res = {
|
| 51 |
+
"labels": labels,
|
| 52 |
+
"probs": pred
|
| 53 |
+
}
|
| 54 |
+
return res
|
| 55 |
+
return _inner
|
| 56 |
+
|
| 57 |
+
def preprocess_text(text: str) -> str:
|
| 58 |
+
if text is None:
|
| 59 |
+
return ""
|
| 60 |
+
|
| 61 |
+
nlp = spacy.load("ru_core_news_md", disable=["ner"])
|
| 62 |
+
|
| 63 |
+
text = " ".join(text.split()).lower()
|
| 64 |
+
|
| 65 |
+
doc = nlp(text)
|
| 66 |
+
tokens = []
|
| 67 |
+
|
| 68 |
+
for t in doc:
|
| 69 |
+
if t.is_stop or t.is_punct or t.is_space:
|
| 70 |
+
continue
|
| 71 |
+
lemma = t.lemma_.strip()
|
| 72 |
+
if len(lemma) <= 1:
|
| 73 |
+
continue
|
| 74 |
+
tokens.append(lemma)
|
| 75 |
+
|
| 76 |
+
return " ".join(tokens)
|
src/use_nn.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from use_ml import preprocess_text
|
| 3 |
+
|
| 4 |
+
def predict_sentiment():
|
| 5 |
+
model = tf.keras.models.load_model("nn_binary.keras")
|
| 6 |
+
vectorizer = tf.keras.models.load_model("nn_vectorizer_binary.keras")
|
| 7 |
+
def _inner(text: str) -> str:
|
| 8 |
+
p_text = preprocess_text(text)
|
| 9 |
+
vec = vectorizer([p_text])
|
| 10 |
+
pred = model.predict(vec)[0][0]
|
| 11 |
+
res = {
|
| 12 |
+
"labels": "positive" if pred >= 0.5 else "negative",
|
| 13 |
+
"probs": pred
|
| 14 |
+
}
|
| 15 |
+
return res
|
| 16 |
+
return _inner
|
| 17 |
+
|
| 18 |
+
def predict_category():
|
| 19 |
+
model = tf.keras.models.load_model("nn_category.keras")
|
| 20 |
+
vectorizer = tf.keras.models.load_model("nn_vectorizer_category.keras")
|
| 21 |
+
def _inner(text: str) -> str:
|
| 22 |
+
p_text = preprocess_text(text)
|
| 23 |
+
vec = vectorizer([p_text])
|
| 24 |
+
pred = model.predict(vec)[0]
|
| 25 |
+
labels = [
|
| 26 |
+
"политика",
|
| 27 |
+
"экономика",
|
| 28 |
+
"спорт",
|
| 29 |
+
"культура"
|
| 30 |
+
]
|
| 31 |
+
res = {
|
| 32 |
+
"labels": labels,
|
| 33 |
+
"probs": pred
|
| 34 |
+
}
|
| 35 |
+
return res
|
| 36 |
+
return _inner
|
| 37 |
+
|
| 38 |
+
def predict_categorys():
|
| 39 |
+
model = tf.keras.models.load_model("nn_categorys.keras")
|
| 40 |
+
vectorizer = tf.keras.models.load_model("nn_vectorizer_categorys.keras")
|
| 41 |
+
def _inner(text: str):
|
| 42 |
+
p_text = preprocess_text(text)
|
| 43 |
+
vec = vectorizer([p_text])
|
| 44 |
+
labels = [
|
| 45 |
+
"политика",
|
| 46 |
+
"экономика",
|
| 47 |
+
"спорт",
|
| 48 |
+
"культура"
|
| 49 |
+
]
|
| 50 |
+
pred = model.predict(vec)[0]
|
| 51 |
+
res = {
|
| 52 |
+
"labels": labels,
|
| 53 |
+
"probs": pred
|
| 54 |
+
}
|
| 55 |
+
return res
|
| 56 |
+
return _inner
|
src/use_transformer.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def predict_sentiment():
|
| 6 |
+
model_path = "./binary_model/checkpoint-400"
|
| 7 |
+
model_name = "cointegrated/rubert-tiny"
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 9 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
| 10 |
+
clf = pipeline(
|
| 11 |
+
"text-classification",
|
| 12 |
+
model=model,
|
| 13 |
+
tokenizer=tokenizer,
|
| 14 |
+
return_all_scores=False
|
| 15 |
+
)
|
| 16 |
+
def _inner(text: str):
|
| 17 |
+
pred = clf(text)
|
| 18 |
+
res = {
|
| 19 |
+
"labels": pred[0]["label"],
|
| 20 |
+
"probs": pred[0]["score"]
|
| 21 |
+
}
|
| 22 |
+
return res
|
| 23 |
+
return _inner
|
| 24 |
+
|
| 25 |
+
def predict_category():
|
| 26 |
+
model_path = "./category_model/checkpoint-400"
|
| 27 |
+
model_name = "cointegrated/rubert-tiny"
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 29 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
| 30 |
+
clf = pipeline(
|
| 31 |
+
"text-classification",
|
| 32 |
+
model=model,
|
| 33 |
+
tokenizer=tokenizer,
|
| 34 |
+
return_all_scores=False
|
| 35 |
+
)
|
| 36 |
+
def _inner(text: str):
|
| 37 |
+
pred = clf(text)
|
| 38 |
+
labels = {"политика": 0, "экономика": 0, "спорт": 0, "культура": 0, pred[0]["label"]: pred[0]["score"]}
|
| 39 |
+
classes = [
|
| 40 |
+
"политика",
|
| 41 |
+
"экономика",
|
| 42 |
+
"спорт",
|
| 43 |
+
"культура"
|
| 44 |
+
]
|
| 45 |
+
new_labels = []
|
| 46 |
+
for cl in classes:
|
| 47 |
+
new_labels.append(labels[cl])
|
| 48 |
+
res = {
|
| 49 |
+
"labels": classes,
|
| 50 |
+
"probs": new_labels
|
| 51 |
+
}
|
| 52 |
+
return res
|
| 53 |
+
|
| 54 |
+
return _inner
|
| 55 |
+
|
| 56 |
+
def predict_categorys():
|
| 57 |
+
model_path = "./multilabel_model/checkpoint-700"
|
| 58 |
+
model_name = "cointegrated/rubert-tiny"
|
| 59 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 60 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
| 61 |
+
model.eval()
|
| 62 |
+
classes = [
|
| 63 |
+
"политика",
|
| 64 |
+
"экономика",
|
| 65 |
+
"спорт",
|
| 66 |
+
"культура"
|
| 67 |
+
]
|
| 68 |
+
def _inner(text: str):
|
| 69 |
+
input = tokenizer(
|
| 70 |
+
text,
|
| 71 |
+
return_tensors="pt",
|
| 72 |
+
truncation=True,
|
| 73 |
+
padding="max_length",
|
| 74 |
+
max_length=256
|
| 75 |
+
)
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
logits = model(**input).logits
|
| 78 |
+
probs = torch.sigmoid(logits).squeeze().tolist()
|
| 79 |
+
res = {
|
| 80 |
+
"labels": classes,
|
| 81 |
+
"probs": probs
|
| 82 |
+
}
|
| 83 |
+
return res
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
return _inner
|