Ludvig commited on
Commit
621a5a4
·
1 Parent(s): b995dc9

Fixes, styling, improvements

Browse files
Files changed (6) hide show
  1. README.md +5 -12
  2. app.py +109 -92
  3. design.py +8 -16
  4. plot.R +25 -11
  5. text_sections.py +78 -24
  6. utils.py +13 -2
README.md CHANGED
@@ -3,26 +3,19 @@ title: plot_confusion_matrix
3
  sdk: docker
4
  app_file: app.py
5
  pinned: true
 
 
 
6
  ---
7
 
8
- # cvms_plot_app
9
 
10
  Streamlit application for plotting a confusion matrix.
11
 
12
- emoji: {{emoji}}
13
- colorFrom: {{colorFrom}}
14
- colorTo: {{colorTo}}
15
-
16
 
17
  ## TODOs
18
-
19
- - IMPORTANT! Allow specifying which class probabilities are of! (See plot prob_of_class)
20
- - Allow setting threshold - manual, max J, spec/sens
21
- - Add bg box around confusion matrix plot as text dissappears on dark mode!
22
- - ggsave does not use dpi??
23
  - Allow svg, pdf?
24
  - Add full reset button (empty cache on different files) - callback?
25
- - Handle <2 classes in design box (add st.error)
26
- - Handle classes with spaces in them?
27
  - Add option to change zero-tile background (e.g. to black for black backgrounds)
28
  - Add option to format total-count tile in sum tiles
 
3
  sdk: docker
4
  app_file: app.py
5
  pinned: true
6
+ emoji: 🍀
7
+ colorFrom: fe7120
8
+ colorTo: 8511a5
9
  ---
10
 
11
+ # Plot Confusion Matrix Streamlit Application
12
 
13
  Streamlit application for plotting a confusion matrix.
14
 
 
 
 
 
15
 
16
  ## TODOs
17
+ - ggsave only uses DPI for scaling? We would expect output files to have the given DPI?
 
 
 
 
18
  - Allow svg, pdf?
19
  - Add full reset button (empty cache on different files) - callback?
 
 
20
  - Add option to change zero-tile background (e.g. to black for black backgrounds)
21
  - Add option to format total-count tile in sum tiles
app.py CHANGED
@@ -12,7 +12,7 @@ from pandas.api.types import is_float_dtype
12
  from itertools import combinations
13
  from collections import OrderedDict
14
 
15
- from utils import call_subprocess, clean_string_for_non_alphanumerics
16
  from data import read_data, read_data_cached, DownloadHeader, generate_data
17
  from design import design_section
18
  from text_sections import (
@@ -37,8 +37,6 @@ st.markdown(
37
 
38
 
39
  # Create temporary directory
40
-
41
-
42
  @st.cache_resource
43
  def set_tmp_dir():
44
  """
@@ -157,12 +155,19 @@ elif input_choice == "Upload counts":
157
 
158
  if st.form_submit_button(label="Set columns"):
159
  st.session_state["step"] = 2
160
- st.session_state["classes"] = sorted(
161
- [
162
- str(c)
163
- for c in st.session_state["count_data"][target_col].unique()
164
- ]
165
- )
 
 
 
 
 
 
 
166
 
167
  # Generate data
168
  elif input_choice == "Generate":
@@ -283,108 +288,120 @@ elif input_choice == "Enter counts":
283
  n_col = "N"
284
 
285
  if st.session_state["step"] >= 2:
 
286
  if st.session_state["input_type"] == "data":
287
  # Remove unused columns
288
  df = df.loc[:, [target_col, prediction_col]]
289
 
290
- # Ensure targets are strings
291
- df[target_col] = df[target_col].astype(str)
292
- df[target_col] = df[target_col].apply(lambda x: x.replace(" ", "_"))
 
 
 
 
 
293
 
294
- # Save to tmp directory to allow reading in R script
295
- df.to_csv(data_store_path)
 
 
296
 
297
- # Extract unique classes
298
- st.session_state["classes"] = sorted([str(c) for c in df[target_col].unique()])
299
 
300
- predictions_are_probabilities = is_float_dtype(df[prediction_col])
301
- if predictions_are_probabilities and len(st.session_state["classes"]) != 2:
302
- st.error(
303
- "Predictions can only be probabilities in binary classification. "
304
- f"Got {len(st.session_state['classes'])} classes."
305
  )
306
 
307
- st.subheader("The Data")
308
- col1, col2, col3 = st.columns([2, 2, 2])
309
- with col2:
310
- st.write(df.head(5))
311
- st.write(f"{df.shape} (Showing first 5 rows)")
312
 
313
  else:
314
- predictions_are_probabilities = False
315
  st.session_state["count_data"].to_csv(data_store_path)
 
 
 
 
 
 
 
 
 
 
 
316
 
317
- # Check the number of classes
318
- num_classes = len(st.session_state["classes"])
319
- if num_classes < 2:
320
- # TODO Handle better than throwing error?
321
- raise ValueError(
322
- "Uploaded data must contain 2 or more classes in `Targets column`. "
323
- f"Got {num_classes} target classes."
324
- )
325
 
326
- # Section for specifying design settings
327
-
328
- design_settings, design_ready, selected_classes, prob_of_class = design_section(
329
- num_classes=num_classes,
330
- predictions_are_probabilities=predictions_are_probabilities,
331
- design_settings_store_path=design_settings_store_path,
332
- )
333
-
334
- # design_ready tells us whether to proceed or wait
335
- # for user to fix issues
336
- if st.session_state["step"] >= 3 and design_ready:
337
- DownloadHeader.centered_json_download(
338
- data=design_settings,
339
- file_name="design_settings.json",
340
- label="Download design settings",
341
- help="Download the design settings to allow reusing settings in future plots.",
342
  )
343
 
344
- st.markdown("---")
345
-
346
- plotting_args = [
347
- "--data_path",
348
- f"'{data_store_path}'",
349
- "--out_path",
350
- f"'{conf_mat_path}'",
351
- "--settings_path",
352
- f"'{design_settings_store_path}'",
353
- "--target_col",
354
- f"'{target_col}'",
355
- "--prediction_col",
356
- f"'{prediction_col}'",
357
- "--classes",
358
- f"{','.join(selected_classes)}",
359
- ]
360
-
361
- if st.session_state["input_type"] == "counts":
362
- # The input data are counts
363
- plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
364
-
365
- plotting_args = " ".join(plotting_args)
366
-
367
- call_subprocess(
368
- f"Rscript plot.R {plotting_args}",
369
- message="Plotting script",
370
- return_output=True,
371
- encoding="UTF-8",
372
- )
373
 
374
- DownloadHeader.header_and_image_download(
375
- "", filepath=conf_mat_path, label="Download Plot"
376
- )
377
- col1, col2, col3 = st.columns([2, 8, 2])
378
- with col2:
379
- image = Image.open(conf_mat_path)
380
- st.image(
381
- image,
382
- caption="Confusion Matrix",
383
- clamp=False,
384
- channels="RGB",
385
- output_format="auto",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  )
387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  else:
389
  st.write("Please upload data.")
390
 
 
12
  from itertools import combinations
13
  from collections import OrderedDict
14
 
15
+ from utils import call_subprocess, clean_string_for_non_alphanumerics, clean_str_column
16
  from data import read_data, read_data_cached, DownloadHeader, generate_data
17
  from design import design_section
18
  from text_sections import (
 
37
 
38
 
39
  # Create temporary directory
 
 
40
  @st.cache_resource
41
  def set_tmp_dir():
42
  """
 
155
 
156
  if st.form_submit_button(label="Set columns"):
157
  st.session_state["step"] = 2
158
+
159
+ if st.session_state["step"] >= 2:
160
+ print(st.session_state["count_data"])
161
+ # Ensure targets and predictions are clean strings
162
+ st.session_state["count_data"][target_col] = clean_str_column(
163
+ st.session_state["count_data"][target_col]
164
+ )
165
+ st.session_state["count_data"][prediction_col] = clean_str_column(
166
+ st.session_state["count_data"][prediction_col]
167
+ )
168
+ st.session_state["classes"] = sorted(
169
+ [c for c in st.session_state["count_data"][target_col].unique()]
170
+ )
171
 
172
  # Generate data
173
  elif input_choice == "Generate":
 
288
  n_col = "N"
289
 
290
  if st.session_state["step"] >= 2:
291
+ data_is_ready = False
292
  if st.session_state["input_type"] == "data":
293
  # Remove unused columns
294
  df = df.loc[:, [target_col, prediction_col]]
295
 
296
+ predictions_are_probabilities = is_float_dtype(df[prediction_col])
297
+ if predictions_are_probabilities:
298
+ st.error(
299
+ "Predictions should be the predicted classes - not probabilities. "
300
+ )
301
+ data_is_ready = False
302
+ else:
303
+ data_is_ready = True
304
 
305
+ if data_is_ready:
306
+ # Ensure targets and predictions are clean strings
307
+ df[target_col] = clean_str_column(df[target_col])
308
+ df[prediction_col] = clean_str_column(df[prediction_col])
309
 
310
+ # Save to tmp directory to allow reading in R script
311
+ df.to_csv(data_store_path)
312
 
313
+ # Extract unique classes
314
+ st.session_state["classes"] = sorted(
315
+ [str(c) for c in df[target_col].unique()]
 
 
316
  )
317
 
318
+ st.subheader("The Data")
319
+ col1, col2, col3 = st.columns([2, 2, 2])
320
+ with col2:
321
+ st.write(df.head(5))
322
+ st.write(f"{df.shape} (Showing first 5 rows)")
323
 
324
  else:
 
325
  st.session_state["count_data"].to_csv(data_store_path)
326
+ data_is_ready = True
327
+
328
+ if data_is_ready:
329
+ # Check the number of classes
330
+ num_classes = len(st.session_state["classes"])
331
+ if num_classes < 2:
332
+ # TODO Handle better than throwing error?
333
+ raise ValueError(
334
+ "Uploaded data must contain 2 or more classes in `Targets column`. "
335
+ f"Got {num_classes} target classes."
336
+ )
337
 
338
+ # Section for specifying design settings
 
 
 
 
 
 
 
339
 
340
+ design_settings, design_ready, selected_classes = design_section(
341
+ num_classes=num_classes,
342
+ design_settings_store_path=design_settings_store_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  )
344
 
345
+ # design_ready tells us whether to proceed or wait
346
+ # for user to fix issues
347
+ if st.session_state["step"] >= 3 and design_ready:
348
+ DownloadHeader.centered_json_download(
349
+ data=design_settings,
350
+ file_name="design_settings.json",
351
+ label="Download design settings",
352
+ help="Download the design settings to allow reusing settings in future plots.",
353
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
+ st.markdown("---")
356
+
357
+ selected_classes_string = ",".join([f"'{c}'" for c in selected_classes])
358
+
359
+ plotting_args = [
360
+ "--data_path",
361
+ f"'{data_store_path}'",
362
+ "--out_path",
363
+ f"'{conf_mat_path}'",
364
+ "--settings_path",
365
+ f"'{design_settings_store_path}'",
366
+ "--target_col",
367
+ f"'{target_col}'",
368
+ "--prediction_col",
369
+ f"'{prediction_col}'",
370
+ "--classes",
371
+ f"{selected_classes_string}",
372
+ ]
373
+
374
+ if st.session_state["input_type"] == "counts":
375
+ # The input data are counts
376
+ plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
377
+
378
+ plotting_args = " ".join(plotting_args)
379
+
380
+ call_subprocess(
381
+ f"Rscript plot.R {plotting_args}",
382
+ message="Plotting script",
383
+ return_output=True,
384
+ encoding="UTF-8",
385
  )
386
 
387
+ DownloadHeader.header_and_image_download(
388
+ "", filepath=conf_mat_path, label="Download plot"
389
+ )
390
+
391
+ col1, col2, col3 = st.columns([2, 8, 2])
392
+ with col2:
393
+ st.write(" ")
394
+ image = Image.open(str(conf_mat_path)[:-3] + "jpg")
395
+ st.image(
396
+ image,
397
+ caption="Confusion Matrix",
398
+ clamp=False,
399
+ channels="RGB",
400
+ output_format="auto",
401
+ )
402
+ st.write(" ")
403
+ st.write("Note: The downloadable file has a transparent background.")
404
+
405
  else:
406
  st.write("Please upload data.")
407
 
design.py CHANGED
@@ -31,7 +31,6 @@ def _add_select_box(
31
 
32
  def design_section(
33
  num_classes,
34
- predictions_are_probabilities,
35
  design_settings_store_path,
36
  ):
37
  output = {}
@@ -80,19 +79,7 @@ def design_section(
80
  "of another class is excluded.",
81
  )
82
  with col2:
83
- prob_of_class = None
84
- # Not respected, so disabled for now
85
- # if (
86
- # st.session_state["input_type"] == "data"
87
- # and predictions_are_probabilities
88
- # ):
89
- # prob_of_class = st.selectbox(
90
- # "Probabilities are of (not working)",
91
- # options=st.session_state["classes"],
92
- # index=1,
93
- # )
94
- # else:
95
- # prob_of_class = None
96
 
97
  # Color palette
98
  output["palette"] = _add_select_box(
@@ -124,9 +111,11 @@ def design_section(
124
  )
125
  with col3:
126
  output["dpi"] = st.number_input(
127
- "DPI (not working)",
128
  value=get_uploaded_setting(key="dpi", default=320, type_=int),
129
  step=10,
 
 
130
  )
131
 
132
  st.write(" ") # Slightly bigger gap between the two sections
@@ -469,8 +458,11 @@ def design_section(
469
  "the sum tiles under **Tiles** >> *Sum tile settings*."
470
  )
471
  design_ready = False
 
 
 
472
 
473
- return output, design_ready, selected_classes, prob_of_class
474
 
475
 
476
  # defaults: dict,
 
31
 
32
  def design_section(
33
  num_classes,
 
34
  design_settings_store_path,
35
  ):
36
  output = {}
 
79
  "of another class is excluded.",
80
  )
81
  with col2:
82
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Color palette
85
  output["palette"] = _add_select_box(
 
111
  )
112
  with col3:
113
  output["dpi"] = st.number_input(
114
+ "DPI (scaling)",
115
  value=get_uploaded_setting(key="dpi", default=320, type_=int),
116
  step=10,
117
+ help="While the output file *currently* won't have this DPI, "
118
+ "the DPI setting affects scaling of elements. ",
119
  )
120
 
121
  st.write(" ") # Slightly bigger gap between the two sections
 
458
  "the sum tiles under **Tiles** >> *Sum tile settings*."
459
  )
460
  design_ready = False
461
+ if len(selected_classes) < 2:
462
+ st.error("At least 2 classes must be selected.")
463
+ design_ready = False
464
 
465
+ return output, design_ready, selected_classes
466
 
467
 
468
  # defaults: dict,
plot.R CHANGED
@@ -42,10 +42,6 @@ option_list <- list(
42
  "Comma-separated class names. ",
43
  "Only these classes will be used - in the specified order."
44
  )
45
- ),
46
- make_option(c("--prob_of_class"),
47
- type = "character",
48
- help = "Name of class that probabilities are of."
49
  )
50
  )
51
 
@@ -104,10 +100,10 @@ if (isTRUE(dev_mode)) {
104
  print(df)
105
  }
106
 
107
- if (!target_col %in% colnames(df)){
108
  stop("Specified `target_col` not a column in the data.")
109
  }
110
- if (!prediction_col %in% colnames(df)){
111
  stop("Specified `target_col` not a column in the data.")
112
  }
113
 
@@ -157,10 +153,6 @@ if (!isTRUE(data_are_counts)) {
157
  "multinomial"
158
  )
159
 
160
- # TODO : use prob_of_class to ensure probabilities
161
- # are interpreted correctly!!
162
- # TODO : Set / calculate threshold
163
- # Might need to invert them to get it to work!
164
  evaluation <- tryCatch(
165
  {
166
  cvms::evaluate(
@@ -320,7 +312,29 @@ tryCatch(
320
  )
321
  },
322
  error = function(e) {
323
- print(paste0("Failed to ggsave plot to: ", opt$out_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  print(e)
325
  stop(e)
326
  }
 
42
  "Comma-separated class names. ",
43
  "Only these classes will be used - in the specified order."
44
  )
 
 
 
 
45
  )
46
  )
47
 
 
100
  print(df)
101
  }
102
 
103
+ if (!target_col %in% colnames(df)) {
104
  stop("Specified `target_col` not a column in the data.")
105
  }
106
+ if (!prediction_col %in% colnames(df)) {
107
  stop("Specified `target_col` not a column in the data.")
108
  }
109
 
 
153
  "multinomial"
154
  )
155
 
 
 
 
 
156
  evaluation <- tryCatch(
157
  {
158
  cvms::evaluate(
 
312
  )
313
  },
314
  error = function(e) {
315
+ print(paste0("png: Failed to ggsave plot to: ", opt$out_path))
316
+ print(e)
317
+ stop(e)
318
+ }
319
+ )
320
+
321
+ # Create a jpg version as well
322
+ tryCatch(
323
+ {
324
+ ggplot2::ggsave(
325
+ paste0(substr(
326
+ opt$out_path,
327
+ start = 1,
328
+ stop = nchar(opt$out_path) - 3
329
+ ), "jpg"),
330
+ width = design_settings$width,
331
+ height = design_settings$height,
332
+ dpi = design_settings$dpi,
333
+ units = "px"
334
+ )
335
+ },
336
+ error = function(e) {
337
+ print(paste0("jpg: Failed to ggsave plot to: ", opt$out_path))
338
  print(e)
339
  stop(e)
340
  }
text_sections.py CHANGED
@@ -1,7 +1,27 @@
1
  import streamlit as st
 
2
  from utils import call_subprocess
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @st.cache_resource
6
  def get_cvms_version():
7
  return (
@@ -19,6 +39,27 @@ def get_cvms_version():
19
  )
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def intro_text():
23
  col1, col2 = st.columns([8, 2])
24
  with col1:
@@ -41,18 +82,20 @@ def intro_text():
41
  st.subheader("Have your data ready?")
42
  st.markdown( # TODO: Make A,B, etc. icons
43
  "Upload a csv file with either: \n\n"
44
- "A) **Targets** and **predictions**. \n\n"
45
- "B) Existing confusion matrix **counts**. \n\n"
46
- "--> Specify the columns to use.\n\n"
47
- "--> Press **Generate plot**.\n\n"
 
48
  )
49
  with col2:
50
  st.subheader("No data to upload?")
51
  st.markdown(
52
  "No worries! Either: \n\n"
53
- "C) **Input** your counts directly! \n\n"
54
- "D) **Generate** some data with **very** easy controls! \n\n"
55
- "--> Press **Generate plot**.\n\n"
 
56
  )
57
  st.markdown("""---""")
58
  st.markdown(
@@ -97,28 +140,38 @@ def upload_counts_text():
97
  st.subheader("Upload your counts")
98
  st.write(
99
  "Plot an existing confusion matrix (counts of target-prediction combinations). "
100
- "The application expects a `.csv` file with: \n"
101
- "1) A `target classes` column. \n\n"
102
- "2) A `predicted classes` column. \n\n"
103
- "3) A `combination count` column for the "
104
- "combination frequency of 1 and 2. \n\n"
105
- "Other columns are currently ignored. "
106
- "See example of such a .csv file [here] (TODO). "
107
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  def upload_predictions_text():
111
  st.subheader("Upload your predictions")
112
- st.markdown(
113
- "The application expects a `.csv` file with: \n"
114
- "1) A `target` column. \n"
115
- "Targets will be converted into strings. \n\n"
116
- "2) A `prediction` column. \n"
117
- "Predictions can be probabilities (binary classification only) or class predictions. \n\n"
118
- "Other columns are currently ignored. \n\n"
119
- "You will have the option to select the names of these two columns, so don't "
120
- "worry too much about the column names in the uploaded data."
121
- )
 
 
 
122
 
123
 
124
  def columns_text():
@@ -131,6 +184,7 @@ def columns_text():
131
  def design_text():
132
  st.subheader("Design your plot")
133
  st.write("This is where you customize the design of your confusion matrix plot.")
 
134
  st.markdown(
135
  "The *width* and *height* settings are usually necessary to adjust as they "
136
  "change the relative size of the elements. Try adjusting 100px at a "
 
1
  import streamlit as st
2
+ import pandas as pd
3
  from utils import call_subprocess
4
 
5
 
6
+ def insert_arrow():
7
+ return '<svg xmlns="http://www.w3.org/2000/svg" style="width:25px;" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-6 h-6"><path stroke-linecap="round" stroke-linejoin="round" d="M17.25 8.25L21 12m0 0l-3.75 3.75M21 12H3" /></svg>'
8
+
9
+
10
+ def insert_chart_icon(choice=0):
11
+ if choice == 0:
12
+ return '<svg xmlns="http://www.w3.org/2000/svg" style="width:25px;" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5"><path fill-rule="evenodd" d="M3 3.5A1.5 1.5 0 014.5 2h6.879a1.5 1.5 0 011.06.44l4.122 4.12A1.5 1.5 0 0117 7.622V16.5a1.5 1.5 0 01-1.5 1.5h-11A1.5 1.5 0 013 16.5v-13zM13.25 9a.75.75 0 01.75.75v4.5a.75.75 0 01-1.5 0v-4.5a.75.75 0 01.75-.75zm-6.5 4a.75.75 0 01.75.75v.5a.75.75 0 01-1.5 0v-.5a.75.75 0 01.75-.75zm4-1.25a.75.75 0 00-1.5 0v2.5a.75.75 0 001.5 0v-2.5z" clip-rule="evenodd" /></svg>'
13
+ else:
14
+ return '<svg xmlns="http://www.w3.org/2000/svg" style="width:25px;" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5"><path fill-rule="evenodd" d="M4.5 2A1.5 1.5 0 003 3.5v13A1.5 1.5 0 004.5 18h11a1.5 1.5 0 001.5-1.5V7.621a1.5 1.5 0 00-.44-1.06l-4.12-4.122A1.5 1.5 0 0011.378 2H4.5zm2.25 8.5a.75.75 0 000 1.5h6.5a.75.75 0 000-1.5h-6.5zm0 3a.75.75 0 000 1.5h6.5a.75.75 0 000-1.5h-6.5z" clip-rule="evenodd" /></svg>'
15
+
16
+
17
+ def insert_edit_icon():
18
+ return '<svg xmlns="http://www.w3.org/2000/svg" style="width:25px;" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5"><path d="M5.433 13.917l1.262-3.155A4 4 0 017.58 9.42l6.92-6.918a2.121 2.121 0 013 3l-6.92 6.918c-.383.383-.84.685-1.343.886l-3.154 1.262a.5.5 0 01-.65-.65z" /><path d="M3.5 5.75c0-.69.56-1.25 1.25-1.25H10A.75.75 0 0010 3H4.75A2.75 2.75 0 002 5.75v9.5A2.75 2.75 0 004.75 18h9.5A2.75 2.75 0 0017 15.25V10a.75.75 0 00-1.5 0v5.25c0 .69-.56 1.25-1.25 1.25h-9.5c-.69 0-1.25-.56-1.25-1.25v-9.5z" /></svg>'
19
+
20
+
21
+ def insert_generate_icon():
22
+ return '<svg xmlns="http://www.w3.org/2000/svg" style="width:25px;" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5"><path fill-rule="evenodd" d="M10 1a.75.75 0 01.75.75v1.5a.75.75 0 01-1.5 0v-1.5A.75.75 0 0110 1zM5.05 3.05a.75.75 0 011.06 0l1.062 1.06A.75.75 0 116.11 5.173L5.05 4.11a.75.75 0 010-1.06zm9.9 0a.75.75 0 010 1.06l-1.06 1.062a.75.75 0 01-1.062-1.061l1.061-1.06a.75.75 0 011.06 0zM3 8a.75.75 0 01.75-.75h1.5a.75.75 0 010 1.5h-1.5A.75.75 0 013 8zm11 0a.75.75 0 01.75-.75h1.5a.75.75 0 010 1.5h-1.5A.75.75 0 0114 8zm-6.828 2.828a.75.75 0 010 1.061L6.11 12.95a.75.75 0 01-1.06-1.06l1.06-1.06a.75.75 0 011.06 0zm3.594-3.317a.75.75 0 00-1.37.364l-.492 6.861a.75.75 0 001.204.65l1.043-.799.985 3.678a.75.75 0 001.45-.388l-.978-3.646 1.292.204a.75.75 0 00.74-1.16l-3.874-5.764z" clip-rule="evenodd" /></svg>'
23
+
24
+
25
  @st.cache_resource
26
  def get_cvms_version():
27
  return (
 
39
  )
40
 
41
 
42
+ @st.cache_data
43
+ def get_example_counts():
44
+ return pd.DataFrame(
45
+ {
46
+ "Target": ["cl1", "cl2", "cl1", "cl2"],
47
+ "Prediction": ["cl1", "cl2", "cl2", "cl1"],
48
+ "N": [12, 10, 3, 5],
49
+ }
50
+ )
51
+
52
+
53
+ @st.cache_data
54
+ def get_example_data():
55
+ return pd.DataFrame(
56
+ {
57
+ "Target": ["cl1", "cl1", "cl2", "cl2", "cl1", "cl1"],
58
+ "Prediction": ["cl1", "cl2", "cl2", "cl1", "cl1", "cl2"],
59
+ }
60
+ )
61
+
62
+
63
  def intro_text():
64
  col1, col2 = st.columns([8, 2])
65
  with col1:
 
82
  st.subheader("Have your data ready?")
83
  st.markdown( # TODO: Make A,B, etc. icons
84
  "Upload a csv file with either: \n\n"
85
+ f"{insert_chart_icon(1)} **Targets** and **predictions** \n\n"
86
+ f"{insert_chart_icon(0)} Existing confusion matrix **counts** \n\n"
87
+ f"{insert_arrow()} Specify the columns to use\n\n"
88
+ f"{insert_arrow()} Press **Generate plot**\n\n",
89
+ unsafe_allow_html=True,
90
  )
91
  with col2:
92
  st.subheader("No data to upload?")
93
  st.markdown(
94
  "No worries! Either: \n\n"
95
+ f"{insert_edit_icon()} **Input** your counts directly! \n\n"
96
+ f"{insert_generate_icon()} **Generate** some data with **very** easy controls! \n\n"
97
+ f"{insert_arrow()} Press **Generate plot**\n\n",
98
+ unsafe_allow_html=True,
99
  )
100
  st.markdown("""---""")
101
  st.markdown(
 
140
  st.subheader("Upload your counts")
141
  st.write(
142
  "Plot an existing confusion matrix (counts of target-prediction combinations). "
 
 
 
 
 
 
 
143
  )
144
+ col1, col2 = st.columns([5, 4])
145
+ with col1:
146
+ st.markdown(
147
+ "The application expects a `.csv` file with: \n"
148
+ "1) A `target classes` column. \n\n"
149
+ "2) A `predicted classes` column. \n\n"
150
+ "3) A `combination count` column for the "
151
+ "combination frequency of 1 and 2. \n\n"
152
+ "Other columns are currently ignored. "
153
+ "In the next step, you will be asked to select the names of these two columns. "
154
+ )
155
+ with col2:
156
+ st.write("Example of such a file:")
157
+ st.write(get_example_counts())
158
 
159
 
160
  def upload_predictions_text():
161
  st.subheader("Upload your predictions")
162
+ col1, col2 = st.columns([5, 4])
163
+ with col1:
164
+ st.markdown(
165
+ "The application expects a `.csv` file with: \n"
166
+ "1) A `target` column. \n"
167
+ "2) A `prediction` column. \n"
168
+ "Predictions should be class predictions (not probabilities). \n\n"
169
+ "Other columns are currently ignored. \n\n"
170
+ "In the next step, you will be asked to select the names of these two columns. "
171
+ )
172
+ with col2:
173
+ st.write("Example of such a file:")
174
+ st.write(get_example_data())
175
 
176
 
177
  def columns_text():
 
184
  def design_text():
185
  st.subheader("Design your plot")
186
  st.write("This is where you customize the design of your confusion matrix plot.")
187
+ st.markdown("We suggest you go directly to `Generate plot` to see the starting point. Then go back and tweak to your liking!")
188
  st.markdown(
189
  "The *width* and *height* settings are usually necessary to adjust as they "
190
  "change the relative size of the elements. Try adjusting 100px at a "
utils.py CHANGED
@@ -21,5 +21,16 @@ def call_subprocess(call_, message, return_output=False, encoding="UTF-8"):
21
 
22
 
23
  def clean_string_for_non_alphanumerics(s):
24
- pattern = re.compile("[\W'_']+")
25
- return pattern.sub("", s)
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def clean_string_for_non_alphanumerics(s):
24
+ # Remove non-alphanumerics (keep spaces)
25
+ pattern1 = re.compile("[^0-9a-zA-Z\s]+")
26
+ # Replace multiple spaces with a single space
27
+ pattern2 = re.compile("\s+")
28
+ # Apply replacements
29
+ s = pattern1.sub("", s)
30
+ s = pattern2.sub(" ", s)
31
+ # Trim whitespace in start and end
32
+ return s.strip()
33
+
34
+
35
+ def clean_str_column(x):
36
+ return x.astype(str).apply(lambda x: clean_string_for_non_alphanumerics(x))