Ludvig commited on
Commit
148c712
·
1 Parent(s): 448553b

Minor refactor and adds lots of design choices (not "connected to api")

Browse files
Files changed (2) hide show
  1. app.py +18 -117
  2. design.py +223 -0
app.py CHANGED
@@ -23,23 +23,26 @@ from collections import OrderedDict
23
 
24
  from utils import call_subprocess, clean_string_for_non_alphanumerics
25
  from data import read_data, read_data_cached, DownloadHeader, generate_data
 
26
  from text_sections import (
27
  intro_text,
28
  columns_text,
29
  upload_predictions_text,
30
  upload_counts_text,
31
  generate_data_text,
32
- design_text,
33
  enter_count_data_text,
34
  )
35
 
36
- st.markdown("""
 
37
  <style>
38
  .small-font {
39
  font-size:0.85em !important;
40
  }
41
  </style>
42
- """, unsafe_allow_html=True)
 
 
43
 
44
 
45
  # Create temporary directory
@@ -308,113 +311,11 @@ if st.session_state["step"] >= 2:
308
  f"Got {num_classes} target classes."
309
  )
310
 
311
- with st.form(key="settings_form"):
312
- design_text()
313
- col1, col2 = st.columns(2)
314
- with col1:
315
- selected_classes = st.multiselect(
316
- "Select classes (min=2, order is respected)",
317
- options=st.session_state["classes"],
318
- default=st.session_state["classes"],
319
- help="Select the classes to create the confusion matrix for. "
320
- "Any observation with either a target or prediction "
321
- "of another class is excluded.",
322
- )
323
- with col2:
324
- if (
325
- st.session_state["input_type"] == "data"
326
- and predictions_are_probabilities
327
- ):
328
- prob_of_class = st.selectbox(
329
- "Probabilities are of (not working)",
330
- options=st.session_state["classes"],
331
- index=1,
332
- )
333
- else:
334
- prob_of_class = None
335
-
336
- with st.expander("Advanced"):
337
-
338
- default_elements = [
339
- "Counts",
340
- "Normalized Counts (%)",
341
- "Zero Shading",
342
- "Arrows",
343
- ]
344
- if num_classes < 6:
345
- # Percentages clutter too much with many classes
346
- default_elements += [
347
- "Row Percentages",
348
- "Column Percentages",
349
- ]
350
- elements_to_add = st.multiselect(
351
- "Add the following elements",
352
- options=[
353
- "Sum Tiles",
354
- "Counts",
355
- "Normalized Counts (%)",
356
- "Row Percentages",
357
- "Column Percentages",
358
- "Zero Shading",
359
- "Zero Percentages",
360
- "Zero Text",
361
- "Arrows",
362
- ],
363
- default=default_elements,
364
- )
365
-
366
- col1, col2, col3 = st.columns(3)
367
- with col1:
368
- counts_on_top = st.checkbox(
369
- "Counts on top (not working)",
370
- help="Whether to switch the positions of the counts and normalized counts (%). "
371
- "That is, the counts become the big centralized numbers and the "
372
- "normalized counts go below with a smaller font size.",
373
- )
374
- with col2:
375
- diag_percentages_only = st.checkbox("Diagonal row/column percentages only")
376
- with col3:
377
- num_digits = st.number_input(
378
- "Digits", value=2, help="Number of digits to round percentages to."
379
- )
380
-
381
- element_flags = [
382
- key
383
- for key, val in {
384
- "--add_sums": "Sum Tiles" in elements_to_add,
385
- "--add_counts": "Counts" in elements_to_add,
386
- "--add_normalized": "Normalized Counts (%)" in elements_to_add,
387
- "--add_row_percentages": "Row Percentages" in elements_to_add,
388
- "--add_col_percentages": "Column Percentages" in elements_to_add,
389
- "--add_zero_percentages": "Zero Percentages" in elements_to_add,
390
- "--add_zero_text": "Zero Text" in elements_to_add,
391
- "--add_zero_shading": "Zero Shading" in elements_to_add,
392
- "--add_arrows": "Arrows" in elements_to_add,
393
- "--counts_on_top": counts_on_top,
394
- "--diag_percentages_only": diag_percentages_only,
395
- }.items()
396
- if val
397
- ]
398
-
399
- palette = st.selectbox(
400
- "Color Palette",
401
- options=["Blues", "Greens", "Oranges", "Greys", "Purples", "Reds"],
402
- )
403
-
404
- # Ask for output parameters
405
- # TODO: Set default based on number of classes and sum tiles
406
- col1, col2, col3 = st.columns(3)
407
- with col1:
408
- width = st.number_input("Width (px)", value=1200 + 100 * (num_classes - 2))
409
- with col2:
410
- height = st.number_input(
411
- "Height (px)", value=1200 + 100 * (num_classes - 2)
412
- )
413
- with col3:
414
- dpi = st.number_input("DPI (not working)", value=320)
415
-
416
- if st.form_submit_button(label="Generate plot"):
417
- st.session_state["step"] = 3
418
 
419
  if st.session_state["step"] >= 3:
420
  plotting_args = [
@@ -427,24 +328,24 @@ if st.session_state["step"] >= 2:
427
  "--prediction_col",
428
  f"'{prediction_col}'",
429
  "--width",
430
- f"{width}",
431
  "--height",
432
- f"{height}",
433
  "--dpi",
434
- f"{dpi}",
435
  "--classes",
436
- f"{','.join(selected_classes)}",
437
  "--digits",
438
- f"{num_digits}",
439
  "--palette",
440
- f"{palette}",
441
  ]
442
 
443
  if st.session_state["input_type"] == "counts":
444
  # The input data are counts
445
  plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
446
 
447
- plotting_args += element_flags
448
 
449
  plotting_args = " ".join(plotting_args)
450
 
 
23
 
24
  from utils import call_subprocess, clean_string_for_non_alphanumerics
25
  from data import read_data, read_data_cached, DownloadHeader, generate_data
26
+ from design import design_section
27
  from text_sections import (
28
  intro_text,
29
  columns_text,
30
  upload_predictions_text,
31
  upload_counts_text,
32
  generate_data_text,
 
33
  enter_count_data_text,
34
  )
35
 
36
+ st.markdown(
37
+ """
38
  <style>
39
  .small-font {
40
  font-size:0.85em !important;
41
  }
42
  </style>
43
+ """,
44
+ unsafe_allow_html=True,
45
+ )
46
 
47
 
48
  # Create temporary directory
 
311
  f"Got {num_classes} target classes."
312
  )
313
 
314
+ # Section for specifying design settings
315
+ design_settings = design_section(
316
+ num_classes=num_classes,
317
+ predictions_are_probabilities=predictions_are_probabilities,
318
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  if st.session_state["step"] >= 3:
321
  plotting_args = [
 
328
  "--prediction_col",
329
  f"'{prediction_col}'",
330
  "--width",
331
+ f"{design_settings['width']}",
332
  "--height",
333
+ f"{design_settings['height']}",
334
  "--dpi",
335
+ f"{design_settings['dpi']}",
336
  "--classes",
337
+ f"{','.join(design_settings['selected_classes'])}",
338
  "--digits",
339
+ f"{design_settings['num_digits']}",
340
  "--palette",
341
+ f"{design_settings['palette']}",
342
  ]
343
 
344
  if st.session_state["input_type"] == "counts":
345
  # The input data are counts
346
  plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
347
 
348
+ plotting_args += design_settings["element_flags"]
349
 
350
  plotting_args = " ".join(plotting_args)
351
 
design.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from text_sections import (
4
+ design_text,
5
+ )
6
+
7
+ # arrow_size = 0.048,
8
+ # arrow_nudge_from_text = 0.065,
9
+ # sums_settings = sum_tile_settings(),
10
+ # intensity_by
11
+
12
+ # darkness = 0.8
13
+
14
+
15
+ def design_section(
16
+ num_classes,
17
+ predictions_are_probabilities,
18
+ ):
19
+ output = {}
20
+
21
+ with st.form(key="settings_form"):
22
+ design_text()
23
+ col1, col2 = st.columns(2)
24
+ with col1:
25
+ output["selected_classes"] = st.multiselect(
26
+ "Select classes (min=2, order is respected)",
27
+ options=st.session_state["classes"],
28
+ default=st.session_state["classes"],
29
+ help="Select the classes to create the confusion matrix for. "
30
+ "Any observation with either a target or prediction "
31
+ "of another class is excluded.",
32
+ )
33
+ with col2:
34
+ if (
35
+ st.session_state["input_type"] == "data"
36
+ and predictions_are_probabilities
37
+ ):
38
+ output["prob_of_class"] = st.selectbox(
39
+ "Probabilities are of (not working)",
40
+ options=st.session_state["classes"],
41
+ index=1,
42
+ )
43
+ else:
44
+ output["prob_of_class"] = None
45
+
46
+ with st.expander("Elements"):
47
+ default_elements = [
48
+ "Counts",
49
+ "Normalized Counts (%)",
50
+ "Zero Shading",
51
+ "Arrows",
52
+ ]
53
+ if num_classes < 6:
54
+ # Percentages clutter too much with many classes
55
+ default_elements += [
56
+ "Row Percentages",
57
+ "Column Percentages",
58
+ ]
59
+ elements_to_add = st.multiselect(
60
+ "Add the following elements",
61
+ options=[
62
+ "Sum Tiles",
63
+ "Counts",
64
+ "Normalized Counts (%)",
65
+ "Row Percentages",
66
+ "Column Percentages",
67
+ "Zero Shading",
68
+ "Zero Percentages",
69
+ "Zero Text",
70
+ "Arrows",
71
+ ],
72
+ default=default_elements,
73
+ )
74
+
75
+ st.markdown("""---""")
76
+
77
+ col1, col2, col3 = st.columns(3)
78
+ with col1:
79
+ counts_on_top = st.checkbox(
80
+ "Counts on top (not working)",
81
+ help="Whether to switch the positions of the counts and normalized counts (%). "
82
+ "That is, the counts become the big centralized numbers and the "
83
+ "normalized counts go below with a smaller font size.",
84
+ )
85
+ with col2:
86
+ diag_percentages_only = st.checkbox(
87
+ "Diagonal row/column percentages only"
88
+ )
89
+
90
+ with st.expander("Text"):
91
+ col1, col2, col3 = st.columns(3)
92
+ with col1:
93
+ output["num_digits"] = st.number_input(
94
+ "Digits", value=2, help="Number of digits to round percentages to."
95
+ )
96
+ with col2:
97
+ rotate_y_text = st.checkbox("Rotate y text", value=True)
98
+ with col3:
99
+ place_x_axis_above = st.checkbox("Place X axis on top", value=True)
100
+
101
+ with st.expander("Fonts"):
102
+ font_dicts = {}
103
+ for font_type in ["Counts", "Normalized (%)", "Row Percentage", "Column Percentage"]:
104
+ st.subheader(font_type)
105
+ num_cols = 3
106
+ font_dicts[font_type] = font_inputs(key_prefix=font_type)
107
+ for i, (setting_name, setting_widget) in enumerate(
108
+ font_dicts[font_type].items()
109
+ ):
110
+ if i % num_cols == 0:
111
+ cols = st.columns(num_cols)
112
+ with cols[i % num_cols]:
113
+ setting_widget()
114
+
115
+ st.markdown("""---""")
116
+
117
+ with st.expander("Tiles"):
118
+ col1, col2, col3 = st.columns(3)
119
+ with col1:
120
+ pass
121
+ with col2:
122
+ output["intensity_by"] = st.selectbox("Intensity based on", options=["Counts", "Normalized (%)"])
123
+ with col3:
124
+ output["darkness"] = st.slider(
125
+ "Darkness",
126
+ min_value=0.0,
127
+ max_value=1.0,
128
+ value=0.8,
129
+ step=0.01,
130
+ help="How dark the darkest colors should be, between 0 and 1, where 1 is darkest.",
131
+ )
132
+
133
+ st.markdown("""---""")
134
+
135
+ col1, col2, col3, col4 = st.columns(4)
136
+ with col1:
137
+ output["tile_border_add"] = st.checkbox("Add tile borders", value=False)
138
+ with col2:
139
+ output["tile_border_color"] = st.color_picker(
140
+ "Border color", value="#000000"
141
+ )
142
+ with col3:
143
+ output["tile_border_size"] = st.slider(
144
+ "Border size",
145
+ min_value=0.0,
146
+ max_value=3.0,
147
+ value=0.1,
148
+ step=0.01,
149
+ )
150
+ with col4:
151
+ output["tile_border_linetype"] = st.selectbox(
152
+ "Border linetype",
153
+ options=[
154
+ "solid",
155
+ "dashed",
156
+ "dotted",
157
+ "dotdash",
158
+ "longdash",
159
+ "twodash",
160
+ ],
161
+ )
162
+
163
+ output["element_flags"] = [
164
+ key
165
+ for key, val in {
166
+ "--add_sums": "Sum Tiles" in elements_to_add,
167
+ "--add_counts": "Counts" in elements_to_add,
168
+ "--add_normalized": "Normalized Counts (%)" in elements_to_add,
169
+ "--add_row_percentages": "Row Percentages" in elements_to_add,
170
+ "--add_col_percentages": "Column Percentages" in elements_to_add,
171
+ "--add_zero_percentages": "Zero Percentages" in elements_to_add,
172
+ "--add_zero_text": "Zero Text" in elements_to_add,
173
+ "--add_zero_shading": "Zero Shading" in elements_to_add,
174
+ "--add_arrows": "Arrows" in elements_to_add,
175
+ "--counts_on_top": counts_on_top,
176
+ "--diag_percentages_only": diag_percentages_only,
177
+ }.items()
178
+ if val
179
+ ]
180
+
181
+ output["palette"] = st.selectbox(
182
+ "Color Palette",
183
+ options=["Blues", "Greens", "Oranges", "Greys", "Purples", "Reds"],
184
+ )
185
+
186
+ # Ask for output parameters
187
+ # TODO: Set default based on number of classes and sum tiles
188
+ col1, col2, col3 = st.columns(3)
189
+ with col1:
190
+ output["width"] = st.number_input(
191
+ "Width (px)", value=1200 + 100 * (num_classes - 2)
192
+ )
193
+ with col2:
194
+ output["height"] = st.number_input(
195
+ "Height (px)", value=1200 + 100 * (num_classes - 2)
196
+ )
197
+ with col3:
198
+ output["dpi"] = st.number_input("DPI (not working)", value=320)
199
+
200
+ if st.form_submit_button(label="Generate plot"):
201
+ st.session_state["step"] = 3
202
+
203
+ return output
204
+
205
+
206
+ def font_inputs(key_prefix: str):
207
+ return {
208
+ "color": lambda: st.color_picker("Color", key=f"{key_prefix}_color"),
209
+ "bold": lambda: st.checkbox("Bold", key=f"{key_prefix}_bold"),
210
+ "cursive": lambda: st.checkbox("Italics", key=f"{key_prefix}_italics"),
211
+ "size": lambda: st.number_input("Size", key=f"{key_prefix}_size"),
212
+ "nudge_x": lambda: st.number_input(
213
+ "Nudge on x-axis", key=f"{key_prefix}_nudge_x"
214
+ ),
215
+ "nudge_y": lambda: st.number_input(
216
+ "Nudge on y-axis", key=f"{key_prefix}_nudge_y"
217
+ ),
218
+ "alpha": lambda: st.slider(
219
+ "Transparency", min_value=0, max_value=1, value=1, key=f"{key_prefix}_alpha"
220
+ ),
221
+ "prefix": lambda: st.text_input("Prefix", key=f"{key_prefix}_prefix"),
222
+ "suffix": lambda: st.text_input("Suffix", key=f"{key_prefix}_suffix"),
223
+ }