Source code for swiftemulator.comparison.visualisation

"""
Visualisation functions for comparison datasets.

Allows you to project a plausibility region for each
parameter cross-correlation.
"""

from swiftemulator.backend.model_parameters import ModelParameters
from swiftemulator.backend.model_specification import ModelSpecification
from swiftemulator.backend.model_values import ModelValues

from typing import Dict, Hashable, Tuple, Iterable, Optional
from matplotlib.colors import Normalize
from scipy.stats import binned_statistic_2d

try:
    from swiftsimio.visualisation.projection import scatter

    swiftsimio_available = True
except (ImportError, ModuleNotFoundError):
    swiftsimio_available = False

import matplotlib.pyplot as plt
import numpy as np


def visualise_penalties_mean(
    model_specification: ModelSpecification,
    model_parameters: ModelParameters,
    penalties: Dict[Hashable, float],
    norm: Normalize = Normalize(vmin=0.2, vmax=0.7, clip=True),
    remove_ticks: bool = True,
    figsize: Optional[Tuple[float]] = None,
    use_parameters: Optional[Iterable[str]] = None,
    use_colorbar: Optional[bool] = False,
    highlight_model: Optional[Hashable] = None,
) -> Tuple[plt.Figure, Iterable[plt.Axes]]:
    """
    Visualises the penalties using SPH smoothing for each individual model
    point.

    Parameters
    ----------

    model_specification: ModelSpecification
        The appropriate model specification. Used for the limits of the
        figure.

    model_parameters: ModelParameters
        Parameters of the model, with the appropriate unique IDs.

    penalties: Dict[Hashable, float]
        Penalties for all parameters in ``model_parameters``, with the key in
        this dictionary being the unique IDs.

    norm: Normalize, optional
        A ``matplotlib`` normalisation object. By default this uses
        ``vmin=0.2`` and ``vmax=0.7``. The default instance is created once
        and shared between calls.

    remove_ticks: bool, optional
        Remove the axes ticks? This is recommended, as the plot can become
        very cluttered if you don't do this. When ``True`` the diagonal
        panels are not plotted. Default: ``True``.

    figsize: Tuple[float], optional
        The figure size to use. Defaults to 7 inches by 7 inches (7 by 7.4
        inches when a colorbar is requested), the size for a ``figure*`` in
        the MNRAS template.

    use_parameters: Iterable[str], optional
        The parameters to include in the figure, in the order in which they
        should appear. If not provided, all parameters in the
        ``model_specification`` are used.

    use_colorbar: Bool, optional
        Include a colorbar? Default: False

    highlight_model: Hashable, optional
        The model unique ID to highlight. If not provided, no model is
        highlighted.

    Returns
    -------

    fig: Figure
        The figure object.

    axes: np.ndarray[Axes]
        The individual axes, as a 2D array of shape ``(n, n)`` where ``n`` is
        the number of parameters shown. Row ``i`` shows parameter ``i`` on
        its y-axis and column ``j`` shows parameter ``j`` on its x-axis.

    Notes
    -----

    You can either change how the figure looks by using the figure and axes
    objects that are returned, or by modifying the ``matplotlib`` stylesheet
    you are currently using.

    If ``swiftsimio`` is not available, the smoothed projection is replaced
    by a plain binned mean on a 16 x 16 grid spanning the parameter limits.
    """

    if use_parameters is None:
        use_parameters = model_specification.parameter_names

    # Materialise so that any iterable (e.g. a generator) can be measured and
    # traversed more than once below.
    use_parameters = list(use_parameters)

    if figsize is None:
        figsize = (7.0, 7.4) if use_colorbar else (7.0, 7.0)

    # Global indices into the model specification, one per displayed parameter.
    parameter_indices = [
        model_specification.parameter_names.index(name) for name in use_parameters
    ]

    grid_size = len(use_parameters)

    # squeeze=False always yields a 2D array of axes, even for one parameter.
    fig, axes_grid = plt.subplots(
        grid_size,
        grid_size,
        figsize=figsize,
        squeeze=False,
        sharex="col",
        sharey="row",
    )

    visualisation_size = 2.0 / np.sqrt(len(model_parameters))

    simulation_ordering = list(model_parameters.keys())

    if highlight_model is not None:
        highlight_index = simulation_ordering.index(highlight_model)

    # Build temporary 1D arrays of penalties in the simulation ordering.
    ordered_penalties = np.array([penalties[x] for x in simulation_ordering])
    unit_masses = np.ones_like(ordered_penalties)
    smoothing_lengths = unit_masses * visualisation_size

    limits = model_specification.parameter_limits

    # Parameters must be re-scaled to the range [0, 1] for smoothed projection.
    # Entry ``i`` corresponds to ``use_parameters[i]`` (a position in the
    # figure), NOT to the global index in ``model_specification``.
    ordered_parameters = [
        (
            np.array(
                [
                    model_parameters.model_parameters[x][parameter]
                    for x in simulation_ordering
                ]
            )
            - limits[index][0]
        )
        / (limits[index][1] - limits[index][0])
        for index, parameter in zip(parameter_indices, use_parameters)
    ]

    # Stays None if no panel is drawn (e.g. a single parameter with ticks
    # removed); guards the colorbar below.
    im = None

    # Row ``row`` carries the y-parameter and column ``column`` carries the
    # x-parameter, matching sharex="col" / sharey="row" above.
    for row, (parameter_y, axes_row) in enumerate(zip(parameter_indices, axes_grid)):
        for column, (parameter_x, ax) in enumerate(zip(parameter_indices, axes_row)):
            limits_x = limits[parameter_x]
            limits_y = limits[parameter_y]

            name_x = model_specification.parameter_printable_names[parameter_x]
            name_y = model_specification.parameter_printable_names[parameter_y]

            is_center_line = parameter_x == parameter_y

            if not (is_center_line and remove_ticks):
                x_values = ordered_parameters[column]
                y_values = ordered_parameters[row]

                if swiftsimio_available:
                    norm_grid = scatter(
                        x=x_values,
                        y=y_values,
                        m=unit_masses,
                        h=smoothing_lengths,
                        res=512,
                    )
                    weighted_grid = scatter(
                        x=x_values,
                        y=y_values,
                        m=ordered_penalties,
                        h=smoothing_lengths,
                        res=512,
                    )

                    # Avoid dividing by zero where no model contributes.
                    norm_grid[norm_grid == 0.0] = 1.0
                    ratio_grid = weighted_grid / norm_grid
                else:
                    # Fallback without swiftsimio: a plain binned mean over the
                    # unit square, so the grid lines up with ``extent`` below.
                    ratio_grid, _, _, _ = binned_statistic_2d(
                        x=x_values,
                        y=y_values,
                        values=ordered_penalties,
                        statistic="mean",
                        bins=16,
                        range=[[0.0, 1.0], [0.0, 1.0]],
                    )

                # Grids are indexed [x, y]; imshow wants [row=y, column=x].
                im = ax.imshow(
                    ratio_grid.T,
                    extent=[*limits_x, *limits_y],
                    origin="lower",
                    norm=norm,
                    rasterized=True,
                )

                if highlight_model is not None:
                    # Re-scale from 0->1 back to 'real' parameter space.
                    highlight_x = limits_x[0] + x_values[highlight_index] * (
                        limits_x[1] - limits_x[0]
                    )
                    highlight_y = limits_y[0] + y_values[highlight_index] * (
                        limits_y[1] - limits_y[0]
                    )

                    ax.scatter(
                        highlight_x,
                        highlight_y,
                        color="white",
                        edgecolor="black",
                    )

            ax.set_xlim(*limits_x)
            ax.set_ylim(*limits_y)

            if remove_ticks:
                ax.tick_params(
                    axis="both",
                    which="both",
                    bottom=False,
                    left=False,
                    right=False,
                    top=False,
                    labelbottom=False,
                    labelleft=False,
                    labelright=False,
                    labeltop=False,
                )

            if is_center_line:
                ax.text(
                    0.5,
                    0.5,
                    f"{limits_x[0]:3.3f} <\n{name_x}\n< {limits_x[1]:3.3f}",
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                )
            else:
                ax.set_xlabel(name_x)
                ax.set_ylabel(name_y)

            # Set square in data reference frame
            ax.set_aspect(1.0 / ax.get_data_ratio())

    if use_colorbar and im is not None:
        fig.colorbar(
            im,
            ax=axes_grid.ravel().tolist(),
            orientation="horizontal",
            label="Mean penalty along line of sight",
        )

    # Keep x-labels only on the bottom row and y-labels only on the first column.
    for a in axes_grid[:-1, :].flat:
        a.set_xlabel(None)

    for a in axes_grid[:, 1:].flat:
        a.set_ylabel(None)

    # As of matplotlib 3.3.4, with a large number of sub-plots this hangs...
    if grid_size > 4:
        # NOTE(review): assigning this attribute does not call the figure's
        # layout API, so it may not disable a constrained layout enabled via
        # the stylesheet - confirm against the matplotlib version in use.
        fig.constrained_layout = False
        fig.subplots_adjust(0, 0, 1, 1, 0.005, 0.005)

    return fig, axes_grid


def visualise_penalties_generic_statistic(
    model_specification: ModelSpecification,
    model_parameters: ModelParameters,
    penalties: Dict[Hashable, float],
    statistic: Optional[str] = None,
    norm: Normalize = Normalize(vmin=0.2, vmax=0.7, clip=True),
    remove_ticks: bool = True,
    figsize: Optional[Tuple[float]] = None,
    use_parameters: Optional[Iterable[str]] = None,
    use_colorbar: Optional[bool] = False,
    highlight_model: Optional[Hashable] = None,
) -> Tuple[plt.Figure, Iterable[plt.Axes]]:
    """
    Visualises the penalties using basic binning.

    Parameters
    ----------

    model_specification: ModelSpecification
        The appropriate model specification. Used for the limits of the
        figure.

    model_parameters: ModelParameters
        Parameters of the model, with the appropriate unique IDs.

    penalties: Dict[Hashable, float]
        Penalties for all parameters in ``model_parameters``, with the key in
        this dictionary being the unique IDs.

    statistic: str, optional
        The statistic that you would like to compute. Allowed values are the
        same as for ``scipy.stats.binned_statistic_2d``. Defaults to
        ``mean``.

    norm: Normalize, optional
        A ``matplotlib`` normalisation object. By default this uses
        ``vmin=0.2`` and ``vmax=0.7``. The default instance is created once
        and shared between calls.

    remove_ticks: bool, optional
        Remove the axes ticks? This is recommended, as the plot can become
        very cluttered if you don't do this. When ``True`` the diagonal
        panels are not plotted. Default: ``True``.

    figsize: Tuple[float], optional
        The figure size to use. Defaults to 7 inches by 7 inches (7 by 7.4
        inches when a colorbar is requested), the size for a ``figure*`` in
        the MNRAS template.

    use_parameters: Iterable[str], optional
        The parameters to include in the figure, in the order in which they
        should appear. If not provided, all parameters in the
        ``model_specification`` are used.

    use_colorbar: Bool, optional
        Include a colorbar? Default: False.

    highlight_model: Hashable, optional
        The model unique ID to highlight. If not provided, no model is
        highlighted.

    Returns
    -------

    fig: Figure
        The figure object.

    axes: np.ndarray[Axes]
        The individual axes, as a 2D array of shape ``(n, n)`` where ``n`` is
        the number of parameters shown. Row ``i`` shows parameter ``i`` on
        its y-axis and column ``j`` shows parameter ``j`` on its x-axis.

    Notes
    -----

    You can either change how the figure looks by using the figure and axes
    objects that are returned, or by modifying the ``matplotlib`` stylesheet
    you are currently using.

    The number of bins per axis is ``round(sqrt(N) / 4)`` for ``N`` models,
    with a minimum of one bin.
    """

    if use_parameters is None:
        use_parameters = model_specification.parameter_names

    # Materialise so that any iterable (e.g. a generator) can be measured and
    # traversed more than once below.
    use_parameters = list(use_parameters)

    if figsize is None:
        figsize = (7.0, 7.4) if use_colorbar else (7.0, 7.0)

    # Global indices into the model specification, one per displayed parameter.
    parameter_indices = [
        model_specification.parameter_names.index(name) for name in use_parameters
    ]

    grid_size = len(use_parameters)

    # squeeze=False always yields a 2D array of axes, even for one parameter.
    fig, axes_grid = plt.subplots(
        grid_size,
        grid_size,
        figsize=figsize,
        squeeze=False,
        sharex="col",
        sharey="row",
    )

    visualisation_size = 4.0 / np.sqrt(len(model_parameters))

    simulation_ordering = list(model_parameters.keys())

    if highlight_model is not None:
        highlight_index = simulation_ordering.index(highlight_model)

    # Build temporary 1D arrays of penalties in the simulation ordering.
    ordered_penalties = np.array([penalties[x] for x in simulation_ordering])

    # Raw (un-scaled) parameter values. Entry ``i`` corresponds to
    # ``use_parameters[i]`` (a position in the figure), NOT to the global
    # index in ``model_specification``.
    ordered_parameters = [
        np.array(
            [
                model_parameters.model_parameters[x][parameter]
                for x in simulation_ordering
            ]
        )
        for parameter in use_parameters
    ]

    # binned_statistic_2d rejects zero bins, which the heuristic alone would
    # produce for fewer than four models.
    bins = max(1, int(round(1.0 / visualisation_size)))

    statistic = statistic if statistic is not None else "mean"

    # Stays None if no panel is drawn (e.g. a single parameter with ticks
    # removed); guards the colorbar below.
    im = None

    # Row ``row`` carries the y-parameter and column ``column`` carries the
    # x-parameter, matching sharex="col" / sharey="row" above.
    for row, (parameter_y, axes_row) in enumerate(zip(parameter_indices, axes_grid)):
        for column, (parameter_x, ax) in enumerate(zip(parameter_indices, axes_row)):
            limits_x = model_specification.parameter_limits[parameter_x]
            limits_y = model_specification.parameter_limits[parameter_y]

            name_x = model_specification.parameter_printable_names[parameter_x]
            name_y = model_specification.parameter_printable_names[parameter_y]

            is_center_line = parameter_x == parameter_y

            if not (is_center_line and remove_ticks):
                x_values = ordered_parameters[column]
                y_values = ordered_parameters[row]

                grid, xs, ys, _ = binned_statistic_2d(
                    x=x_values,
                    y=y_values,
                    values=ordered_penalties,
                    statistic=statistic,
                    bins=bins,
                )

                # The grid is indexed [x, y]; pcolormesh wants [row=y, column=x].
                im = ax.pcolormesh(
                    xs,
                    ys,
                    grid.T,
                    norm=norm,
                    rasterized=True,
                )

                if highlight_model is not None:
                    ax.scatter(
                        x_values[highlight_index],
                        y_values[highlight_index],
                        color="white",
                        edgecolor="black",
                    )

            ax.set_xlim(*limits_x)
            ax.set_ylim(*limits_y)

            if remove_ticks:
                ax.tick_params(
                    axis="both",
                    which="both",
                    bottom=False,
                    left=False,
                    right=False,
                    top=False,
                    labelbottom=False,
                    labelleft=False,
                    labelright=False,
                    labeltop=False,
                )

            if is_center_line:
                ax.text(
                    0.5,
                    0.5,
                    f"{limits_x[0]:3.3f} <\n{name_x}\n< {limits_x[1]:3.3f}",
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                )
            else:
                ax.set_xlabel(name_x)
                ax.set_ylabel(name_y)

            # Set square in data reference frame
            ax.set_aspect(1.0 / ax.get_data_ratio())

    if use_colorbar and im is not None:
        fig.colorbar(
            im,
            ax=axes_grid.ravel().tolist(),
            orientation="horizontal",
            label=f"{statistic.capitalize()} penalty along line of sight",
        )

    # Keep x-labels only on the bottom row and y-labels only on the first column.
    for a in axes_grid[:-1, :].flat:
        a.set_xlabel(None)

    for a in axes_grid[:, 1:].flat:
        a.set_ylabel(None)

    # As of matplotlib 3.3.4, with a large number of sub-plots this hangs...
    if grid_size > 4:
        # NOTE(review): assigning this attribute does not call the figure's
        # layout API, so it may not disable a constrained layout enabled via
        # the stylesheet - confirm against the matplotlib version in use.
        fig.constrained_layout = False
        fig.subplots_adjust(0, 0, 1, 1, 0.005, 0.005)

    return fig, axes_grid