Skip to content

heatmap

plot_heatmap

Draw a colormap plot based on search results.

Requires exactly 2 params and 1 result (after accounting for the ignore_keys).

Parameters:

Name Type Description Default
search Union[Search, str]

The search results (in memory or path to disk file) to be visualized.

required
title Optional[str]

The plot title to use.

None
ignore_keys Union[None, str, Sequence[str]]

Which keys in the params/results should be ignored.

None

Returns:

Type Description
FigureFE

A plotly figure instance.

Source code in fastestimator/fastestimator/search/visualize/heatmap.py
def plot_heatmap(search: Union[Search, str],
                 title: Optional[str] = None,
                 ignore_keys: Union[None, str, Sequence[str]] = None) -> FigureFE:
    """Render search results as a grid of categorical heatmaps.

    The first two params (after `ignore_keys` is applied) become the x and y axes. Every result is drawn as its own
    heatmap subplot, laid out in a roughly square grid. Whether the data is suitable is checked by
    `_heatmap_supports_data`; the published docs describe the requirement as exactly 2 params and 1 result, after
    accounting for the ignore_keys.

    Args:
        search: The search results (in memory or path to disk file) to be visualized.
        title: The plot title to use. When None, the search's name is used.
        ignore_keys: Which keys in the params/results should be ignored.

    Returns:
        A FigureFE built (via `FigureFE.from_figure`) from the assembled plotly figure.
    """
    if isinstance(search, str):
        search = _load_search_file(search)
    plot_title = search.name if title is None else title
    # Flip the colorscale when the search is minimizing its objective
    reverse_colors = search.best_mode == 'min'
    data = SearchData(search=search, ignore_keys=ignore_keys)
    _heatmap_supports_data(data)

    x_key = data.params[0]
    y_key = data.params[1]

    # Map every param value to a category so both axes are discrete, then order the categories human-friendly
    x_vals = [data.to_category(key=x_key, val=raw) for raw in data.data[x_key]]
    x_order = humansorted(set(x_vals))
    y_vals = [data.to_category(key=y_key, val=raw) for raw in data.data[y_key]]
    y_order = humansorted(set(y_vals))

    # Near-square layout: ceil(sqrt(n)) columns, then only as many rows as are needed (so rows <= columns)
    metrics = data.results
    n_plots = len(metrics)
    n_cols = math.ceil(math.sqrt(n_plots))
    n_rows = math.ceil(n_plots / n_cols)
    v_gap, h_gap = 0.15 / n_rows, 0.2 / n_cols

    fig = make_subplots(rows=n_rows,
                        cols=n_cols,
                        subplot_titles=metrics,
                        shared_xaxes='all',
                        shared_yaxes='all',
                        vertical_spacing=v_gap,
                        horizontal_spacing=h_gap,
                        x_title=x_key,
                        y_title=y_key)
    fig.update_layout({'title': plot_title, 'title_x': 0.5})

    # With shared x-axes, make_subplots hides x tick labels above the bottom row. When the last row is only
    # partially filled, re-enable the labels on the subplots sitting directly above its empty slots.
    first_empty_col = n_plots % n_cols or n_cols
    above_offset = max((n_rows - 2) * n_cols, 0)
    for col_idx in range(first_empty_col, n_cols):
        fig['layout'][f'xaxis{above_offset + col_idx + 1}']['showticklabels'] = True

    # Pin the category order on the (shared) axes
    fig['layout']['xaxis']['categoryarray'] = x_order
    fig['layout']['yaxis']['categoryarray'] = y_order

    # Size of one subplot in paper (0..1) coordinates; used to place each subplot's colorbar
    cell_h = (1 - (n_rows - 1) * v_gap) / n_rows
    cell_w = (1 - (n_cols - 1) * h_gap) / n_cols

    # Tie each subplot's axis scales to the first x-axis to keep the aspect ratio fixed.
    # NOTE(review): for the first subplot this anchors 'x' to itself, which plotly presumably ignores.
    x_axis_settings = {'scaleanchor': 'x', 'scaleratio': 1, 'constrain': 'domain'}
    y_axis_settings = {'scaleanchor': 'x', 'constrain': 'domain'}

    for position, metric in enumerate(metrics):
        row, col = divmod(position, n_cols)
        z_vals = data.data[metric]
        # Colorbar hugs the right edge of this subplot and is top-aligned with it
        colorbar = {'len': cell_h,
                    'lenmode': 'fraction',
                    'yanchor': 'top',
                    'y': 1 - row * (cell_h + v_gap),
                    'xanchor': 'left',
                    'x': col * (cell_w + h_gap) + cell_w}
        hover = x_key + ": %{x}<br>" + y_key + ": %{y}<br>" + metric + ": %{z}"
        fig.add_trace(Heatmap(x=x_vals,
                              y=y_vals,
                              z=z_vals,
                              colorscale="Viridis",
                              reversescale=reverse_colors,
                              colorbar=colorbar,
                              name="",
                              hovertemplate=hover,
                              hoverongaps=False),
                      row=row + 1,
                      col=col + 1)

        subplot = fig.get_subplot(row=row + 1, col=col + 1)
        axis_updates = ((subplot.xaxis.plotly_name, x_axis_settings), (subplot.yaxis.plotly_name, y_axis_settings))
        for axis_name, settings in axis_updates:
            for prop, value in settings.items():
                fig['layout'][axis_name][prop] = value

    if in_notebook():
        # Notebooks need an explicit size: grow with the category count (beyond 5) and with the subplot grid
        fig.update_layout(height=500 * max(1.0, len(y_order) / 5.0) * n_rows,
                          width=500 * max(1.0, len(x_order) / 5.0) * n_cols)

    return FigureFE.from_figure(fig)

visualize_heatmap

Display or save a heatmap based on search results.

Parameters:

Name Type Description Default
search Union[Search, str]

The search results (in memory or path to disk file) to be visualized.

required
title Optional[str]

The plot title to use.

None
ignore_keys Union[None, str, Sequence[str]]

Which keys in the params/results should be ignored.

None
save_path Optional[str]

The path where the figure should be saved, or None to display the figure to the screen.

None
verbose bool

Whether to print out the save location.

True
Source code in fastestimator/fastestimator/search/visualize/heatmap.py
def visualize_heatmap(search: Union[Search, str],
                      title: Optional[str] = None,
                      ignore_keys: Union[None, str, Sequence[str]] = None,
                      save_path: Optional[str] = None,
                      verbose: bool = True) -> None:
    """Display or save a heatmap based on search results.

    Builds the figure with `plot_heatmap`, then hands it to the figure's `show` method. That method either displays
    it or writes it to `save_path`. `scale=3` is always forwarded to `show` (presumably an export resolution
    multiplier; confirm against `FigureFE.show`).

    Args:
        search: The search results (in memory or path to disk file) to be visualized.
        title: The plot title to use.
        ignore_keys: Which keys in the params/results should be ignored.
        save_path: The path where the figure should be saved, or None to display the figure to the screen.
        verbose: Whether to print out the save location.
    """
    fig = plot_heatmap(search=search, title=title, ignore_keys=ignore_keys)
    fig.show(save_path=save_path, verbose=verbose, scale=3)