Spaces:
Runtime error
Runtime error
| import dash | |
| import dash_bootstrap_components as dbc | |
| from dash import dcc | |
| from dash import html | |
| from dash.dependencies import Input, Output, State | |
| from typing import List, Tuple | |
| from scipy.spatial.distance import cdist | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
# Table with one row per protein. The rest of this file reads the columns
# "classification", "splits" and "pdb_id" from it.
# NOTE: unpickling executes code embedded in the file, so this must only ever
# point at a trusted, locally produced pickle.
df = pd.read_pickle('all_embeddings_with_splits.p')

app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])

# Shared spacing for the three dropdowns in the control column.
_DROPDOWN_STYLE = {"margin-bottom": "10px"}

_algorithm_dropdown = dcc.Dropdown(
    id="algorithm-dropdown",
    options=[
        {"label": "PCA", "value": "pca"},
        {"label": "UMAP", "value": "umap"},
        {"label": "tSNE", "value": "tsne"},
        {"label": "PaCMAP", "value": "pacmap"},
    ],
    value="pacmap",
    clearable=False,
    searchable=False,
    style=_DROPDOWN_STYLE,
)

_num_components_dropdown = dcc.Dropdown(
    id="num-components-dropdown",
    options=[
        {"label": "2", "value": 2},
        {"label": "3", "value": 3},
    ],
    value=3,
    clearable=False,
    searchable=False,
    style=_DROPDOWN_STYLE,
)

_color_by_dropdown = dcc.Dropdown(
    id="color-by",
    options=[
        {
            "label": "Protein Classification",
            "value": "classification",
        },
        {
            "label": "Split (train/test/val/gpcr)",
            "value": "split",
        },
    ],
    value="classification",
    clearable=False,
    searchable=False,
    style=_DROPDOWN_STYLE,
)

# "Keep the top [N] classes." -- N is capped at the number of distinct
# classification values present in the data.
_top_n_selector = html.Span(
    [
        "Keep the top ",
        dcc.Input(
            id="top-n-classes",
            type="number",
            value=10,
            min=1,
            max=len(df["classification"].unique()),
            step=1,
            style={"width": "50px"},
        ),
        " classes.",
    ],
    style={"margin-bottom": "20px"},
)

# Left-hand column: plot controls plus the scrollable list of nearest points.
_controls_column = dbc.Col(
    [
        html.Label('Algorithm:'),
        _algorithm_dropdown,
        html.Label('Number of dimensions:'),
        _num_components_dropdown,
        html.Label('Color by:'),
        _color_by_dropdown,
        _top_n_selector,
        html.Br(),
        dbc.Button(
            "Update",
            id="update-button",
            color="primary",
            n_clicks=0,
            style={"width": "100%", "margin": "10px 0px"},
        ),
        dbc.Container(
            id="closest-points",
            style={"max-height": "65vh", "overflow-y": "auto"},
        ),
    ],
    width={"size": 2, "order": 1},
)

# Right-hand column: the embedding scatter plot.
_graph_column = dbc.Col(
    dcc.Graph(
        id="embedding-graph",
        style={"height": "100%", "width": "100%"},
    ),
    width={"size": 10, "order": 2},
)

app.layout = dbc.Container(
    [
        html.H1("Embedding Plots"),
        html.Hr(),
        html.Div(
            [
                dbc.Row(
                    [_controls_column, _graph_column],
                    style={"height": "95vh"},
                )
            ],
            # Was "100hv", which is not a CSS unit and was ignored by the
            # browser; the viewport-height unit is "vh".
            style={"height": "100vh"},
        ),
        html.Hr(),
    ],
    fluid=True,
)
def load_embedding(algorithm: str, num_components: int) -> np.ndarray:
    """Load a precomputed embedding from a ``.npy`` file in the working directory.

    Parameters
    ----------
    algorithm : str
        Name of the dimensionality-reduction algorithm. ``"pca"`` always
        loads ``pca.npy``; any other value is used as the file-name prefix.
    num_components : int
        Number of embedding dimensions. Used only to build the file name for
        non-PCA algorithms (``f"{algorithm}{num_components}d.npy"``).

    Returns
    -------
    np.ndarray
        The array stored in the selected file.

    Raises
    ------
    OSError
        Propagated from ``np.load`` when the file cannot be opened (for
        example ``FileNotFoundError`` for an unknown algorithm name).

    Notes
    -----
    NOTE(review): for PCA the same file is returned regardless of
    ``num_components`` -- presumably ``pca.npy`` stores at least three
    components and callers slice the columns they need; confirm.
    """
    if algorithm == "pca":
        return np.load("pca.npy")
    return np.load(f"{algorithm}{num_components}d.npy")
def get_top_n_classifications(df: pd.DataFrame, n: int) -> List[str]:
    """Return the ``n`` most frequent values of the ``classification`` column.

    Parameters
    ----------
    df : pd.DataFrame
        Table containing a ``classification`` column.
    n : int
        How many of the most common classes to keep.

    Returns
    -------
    List[str]
        Class names ordered from most to least frequent.
    """
    frequencies = df["classification"].value_counts()
    most_common = frequencies.nlargest(n)
    return most_common.index.tolist()
# Hover text shared by the 2D and 3D traces; customdata is (pdb_id, classification).
_HOVER_TEMPLATE = (
    "<b>PDB ID</b>: %{customdata[0]}<br>"
    "<b>Classification</b>: %{customdata[1]}<br>"
    "<extra></extra>"
)


def _assign_colors(color_by: str, top_n_classes: int) -> Tuple[pd.Series, dict]:
    """Return the per-row legend labels and a label -> colour mapping.

    Labels missing from the returned mapping are meant to be drawn in grey.
    """
    if color_by == "split":
        split_colors = {
            "gpcr": "red",
            "train": "blue",
            "val": "green",
            "test": "orange",
            "unknown": "grey",
        }
        return df["splits"].copy(), split_colors

    top_classes = get_top_n_classifications(df, n=top_n_classes)
    classes = df["classification"]
    # Everything outside the top-n classes is collapsed into one "other" group.
    labels = classes.where(classes.isin(top_classes), "other")
    # The palette has a fixed length; cycle through it so that asking for more
    # classes than palette entries repeats colours instead of raising.
    palette = px.colors.qualitative.Plotly
    class_colors = {
        name: palette[i % len(palette)] for i, name in enumerate(top_classes)
    }
    return labels, class_colors


def update_embedding_graph(n_clicks: int,
                           algorithm: str,
                           num_components: int,
                           top_n_classes: int,
                           color_by: str) -> go.Figure:
    """Build the 2D or 3D scatter plot of the selected embedding.

    One trace is added per legend label (split name, or top-n classification
    plus "other"), so that each group can be toggled from the legend.

    NOTE(review): no ``@app.callback`` registration is visible for this
    function in this file -- confirm how it is wired to the layout.

    Parameters
    ----------
    n_clicks : int
        Click count of the "Update" button; nothing is drawn until it is
        positive.
    algorithm : str
        Embedding algorithm, forwarded to ``load_embedding``.
    num_components : int
        2 for a ``go.Scatter`` plot, 3 for a ``go.Scatter3d`` plot.
    top_n_classes : int
        Number of most frequent classifications that get their own colour
        (ignored when ``color_by == "split"``).
    color_by : str
        ``"split"`` to colour by the ``splits`` column, anything else to
        colour by ``classification``.

    Returns
    -------
    go.Figure
        The assembled figure.

    Raises
    ------
    dash.exceptions.PreventUpdate
        When the button has not been clicked yet, or when the top-n input is
        empty while colouring by classification.
    ValueError
        When ``num_components`` is neither 2 nor 3.
    """
    if not n_clicks:
        raise dash.exceptions.PreventUpdate
    # A cleared number input sends None; keep the current figure in that case.
    if color_by != "split" and top_n_classes is None:
        raise dash.exceptions.PreventUpdate
    if num_components not in (2, 3):
        raise ValueError(
            f"num_components must be 2 or 3, got {num_components!r}"
        )

    embedding = load_embedding(algorithm, num_components)
    # Labels are kept in a local Series rather than written into the shared
    # global DataFrame, so overlapping requests cannot overwrite each other.
    labels, color_map = _assign_colors(color_by, top_n_classes)

    is_3d = num_components == 3
    trace_type = go.Scatter3d if is_3d else go.Scatter
    marker_size = 2.5 if is_3d else 7.5

    fig = go.Figure()
    for label in labels.unique():
        rows = np.flatnonzero((labels == label).to_numpy())
        if rows.size == 0:
            # e.g. a missing label, which never compares equal to itself
            continue
        points = embedding[rows]
        color = color_map.get(label, "grey")
        coords = dict(x=points[:, 0], y=points[:, 1])
        if is_3d:
            coords["z"] = points[:, 2]
        fig.add_trace(
            trace_type(
                mode='markers',
                name=str(label),
                marker=dict(
                    size=marker_size,
                    color=color,
                    # De-emphasise the grey catch-all group.
                    opacity=1 if color != 'grey' else 0.3,
                ),
                hovertemplate=_HOVER_TEMPLATE,
                customdata=df.iloc[rows][['pdb_id', 'classification']].to_numpy(),
                **coords,
            )
        )

    if is_3d:
        hidden_axis = dict(
            showgrid=False, showticklabels=False, title="", visible=False
        )
        fig.update_layout(
            scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis)
        )
    else:
        # Scenes only exist for 3D plots; cartesian axes are hidden through
        # update_xaxes / update_yaxes.
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)

    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            itemsizing='constant',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            traceorder='reversed',
            itemwidth=30,
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
| #### GET CLOSEST POINTS | |
def extract_info_from_clickData(clickData: dict) -> Tuple[str, str]:
    """Pull the PDB id and classification out of a scatter-plot click event.

    Only the first clicked point is inspected. Its ``customdata`` entry is
    expected to hold the PDB id at position 0 and the classification at
    position 1, e.g.::

        {
            "points": [
                {
                    "x": 11.330583,
                    "y": 15.741333,
                    "z": -5.3435574,
                    "curveNumber": 2,
                    "pointNumber": 982,
                    "customdata": [
                        "1zfp",
                        "complex (signal transduction/peptide)"
                    ]
                }
            ]
        }

    Parameters
    ----------
    clickData : dict
        Click payload of a point on a go.Figure graph.

    Returns
    -------
    Tuple[str, str]
        ``(pdb_id, classification)`` of the clicked point.
    """
    clicked_point = clickData["points"][0]
    metadata = clicked_point["customdata"]
    return metadata[0], metadata[1]
def find_closest_n_points(df: pd.DataFrame,
                          embedding: np.array,
                          index: int = None,
                          pdb_id: str = None,
                          n: int = 20) -> Tuple[list, list]:
    """Find the ``n`` rows whose embedding is closest to a reference point.

    The reference point is selected either by ``pdb_id`` or by ``index``;
    when both are given, ``pdb_id`` takes precedence. The reference point
    itself is part of the result (it is at distance 0).

    Parameters
    ----------
    df : pd.DataFrame
        Table with ``pdb_id`` and ``classification`` columns whose rows are
        in the same order as the rows of ``embedding``.
    embedding : np.ndarray
        A 2D numpy array with the embedding coordinates, one row per point.
    index : int, optional
        Row position (0-based) of the reference point in ``embedding``.
    pdb_id : str, optional
        PDB id of the reference point. If several rows share the id, the
        first one is used.
    n : int
        The number of closest points to retrieve.

    Returns
    -------
    closest_ids : list
        The ``pdb_id`` values of the n closest points, nearest first.
    closest_ids_classifications : list
        The ``classification`` values of the same points, in the same order.

    Raises
    ------
    ValueError
        If neither ``index`` nor ``pdb_id`` is given, or if ``pdb_id`` does
        not occur in ``df``.
    """
    if pdb_id:
        # Resolve to a row *position*: the embedding is a plain array, so an
        # index label would only be correct for a default RangeIndex.
        matches = np.flatnonzero((df["pdb_id"] == pdb_id).to_numpy())
        if matches.size == 0:
            raise ValueError(f"pdb_id {pdb_id!r} not found in df")
        index = int(matches[0])
    if index is None:
        raise ValueError("either index or pdb_id must be provided")

    # Distances from the reference point to every point (shape (1, N)).
    distances = cdist(embedding[index, np.newaxis], embedding)
    closest_indices = np.argsort(distances)[0][:n]
    nearest = df.iloc[closest_indices]
    closest_ids = nearest["pdb_id"].tolist()
    closest_ids_classifications = nearest["classification"].tolist()
    return closest_ids, closest_ids_classifications
def update_closest_points_div(
        clickData: dict,
        algorithm: str,
        num_components: int) -> list:
    """Build the list of cards describing the points nearest to a clicked one.

    NOTE(review): no ``@app.callback`` registration is visible for this
    function in this file -- confirm how it is wired to the layout.

    Parameters
    ----------
    clickData : dict
        Click payload from the embedding graph, or ``None`` when nothing has
        been clicked yet.
    algorithm : str
        Embedding algorithm, forwarded to ``load_embedding``.
    num_components : int
        Number of embedding dimensions, forwarded to ``load_embedding``.

    Returns
    -------
    list
        One ``dbc.Card`` per neighbour (PDB id and classification), nearest
        first, or a single placeholder ``html.Div`` when there is no click.
    """
    if not clickData:
        # No id on the placeholder: the layout already owns the
        # "closest-points" id, and component ids must be unique.
        return [html.Div("Click on a data point to see the closest points.")]

    pdb_id, _ = extract_info_from_clickData(clickData)
    # Only touch the disk once we know there is something to look up.
    embedding = load_embedding(algorithm, num_components)
    closest_ids, closest_ids_classifications = find_closest_n_points(
        df, embedding, pdb_id=pdb_id)

    return [
        dbc.Card(
            dbc.CardBody(
                [
                    html.P(neighbour_id, className="card-title"),
                    html.P(neighbour_class, className="card-text"),
                ]
            ),
            className="mb-3",
        )
        for neighbour_id, neighbour_class in zip(
            closest_ids, closest_ids_classifications)
    ]
if __name__ == "__main__":
    # Newer Dash releases expose the server entry point as ``app.run`` and no
    # longer provide ``run_server``; older ones only have ``run_server``.
    # Prefer ``run`` when it exists so the script starts on both.
    start_server = getattr(app, "run", None) or app.run_server
    start_server(debug=False, host='0.0.0.0', port=7680)