"""Serialization of Figure object to matplotlib."""

from typing import Any, Dict, List, Optional

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axis import Axis as AxisMPL
from matplotlib.figure import Figure as FigureMPL
from matplotlib.gridspec import GridSpec
from sbmlutils import log

from sbmlsim.plot import Axis, Curve, Figure, SubPlot
from sbmlsim.plot.plotting import (
    AbstractCurve,
    AxisScale,
    CurveType,
    LineType,
    ShadedArea,
    Style,
    YAxisPosition,
)

# Module-level logger; used below to report interpolation and plotting problems.
logger = log.get_logger(__name__)
def interp(x, xp, fp):
    """Interpolate linearly for speedup of plots.

    Thin wrapper around ``numpy.interp`` which additionally logs an error if
    the interpolated result contains non-finite values.

    :param x: coordinates at which the interpolated values are evaluated
    :param xp: x-coordinates of the data points
    :param fp: y-coordinates of the data points
    :return: interpolated values at ``x``
    """
    # Linear interpolation is used on purpose: spline interpolation would be
    # smoother, but had NaN issues with zero values.
    y = np.interp(x, xp, fp)
    if np.isfinite(y).all():
        return y

    # Non-finite values are reported, but still handed back to the caller.
    logger.error(f"NaN or Inf values in interpolation: {fp} -> {y}")
    return y
class MatplotlibFigureSerializer:
    """Serializer for figures to matplotlib.

    The public entry point is :meth:`to_figure`; all other methods are private
    helpers used while rendering a single subplot.
    """

    @classmethod
    def _get_scale(cls, axis: Axis) -> str:
        """Get string representation of the scale.

        :param axis: axis whose ``scale`` attribute is translated
        :return: ``"linear"`` for ``AxisScale.LINEAR``, ``"log"`` for
            ``AxisScale.LOG10``
        :raises ValueError: for any other scale
        """
        if axis.scale == AxisScale.LINEAR:
            return "linear"
        elif axis.scale == AxisScale.LOG10:
            return "log"
        else:
            raise ValueError(f"Unsupported axis scale: '{axis.scale}'")

    @staticmethod
    def _on_right_axis(abstract_curve: AbstractCurve) -> bool:
        """Check whether a curve or area is assigned to the right y-axis."""
        return bool(
            abstract_curve.yaxis_position
            and abstract_curve.yaxis_position == YAxisPosition.RIGHT
        )

    @staticmethod
    def _first_column(data: Any) -> Any:
        """Reduce data to the values which are plotted.

        :param data: object providing ``shape`` and ``magnitude``, or None
        :return: None if there is no data, ``magnitude[:, 0]`` for
            two-dimensional data, otherwise ``magnitude`` unchanged
        """
        if data is None:
            return None
        # FIXME: necessary to get the individual curves out of the data cube
        # TODO: iterate over all repeats in the data
        if len(data.shape) == 2:
            return data.magnitude[:, 0]
        return data.magnitude

    @classmethod
    def _apply_axis_settings(cls, sax: Axis, ax: plt.Axes, axis_type: str) -> None:
        """Apply limits, scale, label, tick and spine settings to one axis.

        :param sax: axis definition providing the settings
        :param ax: matplotlib axes which is modified in place
        :param axis_type: ``"x"`` or ``"y"``, selects which axis of ``ax`` is set
        :raises ValueError: if ``axis_type`` is neither ``"x"`` nor ``"y"``
        """
        if axis_type not in ["x", "y"]:
            raise ValueError(f"axis_type must be 'x' or 'y', but is '{axis_type}'")

        # handle the reverse flag
        if sax.reverse:
            ax_min, ax_max = sax.max, sax.min
        else:
            ax_min, ax_max = sax.min, sax.max

        # NOTE(review): with `reverse` and only one of min/max defined the value
        # passed below is None, i.e. the defined limit is not applied.
        if sax.min is not None:
            if axis_type == "x":
                ax.set_xlim(xmin=ax_min)
            elif axis_type == "y":
                ax.set_ylim(ymin=ax_min)

        if sax.max is not None:
            if axis_type == "x":
                ax.set_xlim(xmax=ax_max)
            elif axis_type == "y":
                ax.set_ylim(ymax=ax_max)

        if axis_type == "x":
            ax.set_xscale(cls._get_scale(sax))
        elif axis_type == "y":
            ax.set_yscale(cls._get_scale(sax))

        if sax.label_visible and sax.name:
            if axis_type == "x":
                ax.set_xlabel(sax.name)
            elif axis_type == "y":
                ax.set_ylabel(sax.name)

        if not sax.ticks_visible:
            if axis_type == "x":
                ax.set_xticklabels([])  # hide ticks
            elif axis_type == "y":
                ax.set_yticklabels([])  # hide ticks

        # style of the spines belonging to this axis
        if sax.style and sax.style.line:
            if axis_type == "x":
                directions = ["bottom", "top"]
            elif axis_type == "y":
                directions = ["left", "right"]

            style: Style = sax.style
            if style.line:
                if style.line.thickness:
                    linewidth = style.line.thickness
                    # tick width is set without `axis=`, i.e. for both axes
                    ax.tick_params(width=linewidth)
                    for axis in directions:
                        if np.isclose(linewidth, 0.0):
                            # (almost) zero width: hide spine in background color
                            ax.spines[axis].set_color(Figure.fig_facecolor)
                        else:
                            ax.spines[axis].set_linewidth(linewidth)

                if style.line.color:
                    color = style.line.color
                    for axis in directions:
                        ax.spines[axis].set_color(str(color))

                if style.line.type and style.line.type == LineType.NONE:
                    for axis in directions:
                        ax.spines[axis].set_color(Figure.fig_facecolor)

    @classmethod
    def to_figure(
        cls, experiment: "SimulationExperiment", figure: Figure  # noqa: F821
    ) -> FigureMPL:
        """Convert sbmlsim.Figure to matplotlib figure.

        For every subplot of ``figure`` an axes is created in a grid, the curves
        and shaded areas of the plot are drawn in the order given by their
        ``order`` attribute, and axis, title, grid and legend settings are
        applied.

        :param experiment: simulation experiment handed to the ``get_data`` calls
            which provide the x, y and error data
        :param figure: figure definition which is rendered
        :return: newly created matplotlib figure
        :raises ValueError: if curves exist but x-axis or y-axis is missing, if a
            curve or area is positioned on the right y-axis but the plot has no
            ``yaxis_right``, or if the x data of stacked bars does not match
        """
        # create new figure
        fig: plt.Figure = plt.figure(
            figsize=(figure.width, figure.height),
            dpi=Figure.fig_dpi,
            facecolor=Figure.fig_facecolor,
        )
        if figure.name:
            fig.suptitle(
                figure.name,
                fontsize=Figure.fig_titlesize,
                fontweight=Figure.fig_titleweight,
            )

        # create grid for figure
        gs = GridSpec(
            nrows=figure.num_rows,
            ncols=figure.num_cols,
            figure=fig,
            # done via subplots adjust below
            # hspace=figure.fig_subplots_hspace,
            # wspace=figure.fig_subplots_wspace,
        )

        subplot: SubPlot
        for subplot in figure.subplots:
            plot = subplot.plot
            xax: Axis = plot.xaxis if plot.xaxis else Axis()
            yax: Axis = plot.yaxis if plot.yaxis else Axis()
            yax_right = plot.yaxis_right

            # subplot positions are 1-based, the grid is 0-based
            ridx = subplot.row - 1
            cidx = subplot.col - 1
            ax1: plt.Axes = fig.add_subplot(
                gs[ridx : ridx + subplot.row_span, cidx : cidx + subplot.col_span]
            )

            # secondary axis: only created if a right y-axis is defined and at
            # least one curve or area is drawn on it
            ax2: Optional[plt.Axes] = None
            axes: List[plt.Axes] = [ax1]
            if yax_right:
                if any(cls._on_right_axis(item) for item in plot.curves + plot.areas):
                    ax2 = ax1.twinx()
                    axes.append(ax2)
                else:
                    logger.error("Position right defined by no yAxis right.")

            # units
            # NOTE(review): xax and yax fall back to `Axis()` above, so the
            # `is None` branches below are not reached with the current fallback.
            if xax is None:
                logger.warning(f"No xaxis in plot: {subplot}")
                ax1.spines["bottom"].set_color(Figure.fig_facecolor)
                ax1.spines["top"].set_color(Figure.fig_facecolor)
            if yax is None:
                logger.warning(f"No yaxis in plot: {subplot}")
                ax1.spines["right"].set_color(Figure.fig_facecolor)
                ax1.spines["left"].set_color(Figure.fig_facecolor)

            if (not xax) or (not yax):
                if len(plot.curves) > 0:
                    raise ValueError(
                        f"xaxis and yaxis are required for plotting curves, but "
                        f"'xaxis={xax}' and 'yaxis={yax}'."
                    )

            xunit = xax.unit if xax else None
            yunit_left = yax.unit if yax else None
            yunit_right = yax_right.unit if yax_right else None

            # memory for stacked bars
            barstack_x = None
            barstack_y = None
            barhstack_x = None
            barhstack_y = None

            # plot ordered curves
            abstract_curves: List[AbstractCurve] = sorted(
                plot.curves + plot.areas, key=lambda item: item.order
            )
            for abstract_curve in abstract_curves:
                if cls._on_right_axis(abstract_curve):
                    # right axis
                    if ax2 is None:
                        # without this check plotting fails on `None` below
                        raise ValueError(
                            "Curve or area is positioned on the right y-axis, "
                            "but the plot defines no 'yaxis_right'."
                        )
                    yunit = yunit_right
                    ax = ax2
                else:
                    # left axis
                    yunit = yunit_left
                    ax = ax1

                if isinstance(abstract_curve, Curve):
                    # --- Curve ---
                    curve: Curve = abstract_curve
                    x = curve.x.get_data(experiment=experiment, to_units=xunit)
                    y = curve.y.get_data(experiment=experiment, to_units=yunit)
                    xerr = None
                    if curve.xerr is not None:
                        xerr = curve.xerr.get_data(
                            experiment=experiment, to_units=xunit
                        )
                    yerr = None
                    if curve.yerr is not None:
                        yerr = curve.yerr.get_data(
                            experiment=experiment, to_units=yunit
                        )

                    label = curve.name if curve.name else "__nolabel__"

                    x_data = cls._first_column(x)
                    y_data = cls._first_column(y)
                    xerr_data = cls._first_column(xerr)
                    yerr_data = cls._first_column(yerr)

                    kwargs: Dict[str, Any] = {}
                    if curve.style:
                        style: Style = curve.style
                        if curve.type == CurveType.POINTS:
                            kwargs = style.to_mpl_points_kwargs()
                        else:
                            # bar plot
                            kwargs = style.to_mpl_bar_kwargs()

                    if curve.type == CurveType.POINTS:
                        ax.errorbar(
                            x=x_data,
                            y=y_data,
                            xerr=xerr_data,
                            yerr=yerr_data,
                            label=label,
                            **kwargs,
                        )

                    elif curve.type == CurveType.BAR:
                        ax.bar(
                            x=x_data,
                            height=y_data,
                            xerr=xerr_data,
                            yerr=yerr_data,
                            label=label,
                            **kwargs,
                        )

                    elif curve.type == CurveType.HORIZONTALBAR:
                        # horizontal bars: x/y and their errors swap roles
                        ax.barh(
                            y=x_data,
                            width=y_data,
                            xerr=yerr_data,
                            yerr=xerr_data,
                            label=label,
                            **kwargs,
                        )

                    elif curve.type == CurveType.BARSTACKED:
                        if barstack_x is None:
                            barstack_x = x_data
                            barstack_y = np.zeros_like(y_data)

                        if not np.all(np.isclose(barstack_x, x_data)):
                            raise ValueError("x data must match for stacked bars.")
                        ax.bar(
                            x=x_data,
                            height=y_data,
                            bottom=barstack_y,
                            xerr=xerr_data,
                            yerr=yerr_data,
                            label=label,
                            **kwargs,
                        )
                        # next stacked bar starts on top of this one
                        barstack_y = barstack_y + y_data

                    elif curve.type == CurveType.HORIZONTALBARSTACKED:
                        if barhstack_x is None:
                            barhstack_x = x_data
                            barhstack_y = np.zeros_like(y_data)

                        if not np.all(np.isclose(barhstack_x, x_data)):
                            raise ValueError("x data must match for stacked bars.")
                        ax.barh(
                            y=x_data,
                            width=y_data,
                            left=barhstack_y,
                            xerr=yerr_data,
                            yerr=xerr_data,
                            label=label,
                            **kwargs,
                        )
                        barhstack_y = barhstack_y + y_data

                elif isinstance(abstract_curve, ShadedArea):
                    # --- ShadedArea ---
                    area: ShadedArea = abstract_curve
                    x = area.x.get_data(experiment=experiment, to_units=xunit)
                    yfrom = area.yfrom.get_data(experiment=experiment, to_units=yunit)
                    yto = area.yto.get_data(experiment=experiment, to_units=yunit)

                    # FIXME: support multidimensional results
                    x_data = x.magnitude[:, 0] if x is not None else None
                    yfrom_data = yfrom.magnitude[:, 0] if yfrom is not None else None
                    yto_data = yto.magnitude[:, 0] if yto is not None else None

                    label = area.name if area.name else "__nolabel__"

                    kwargs: Dict[str, Any] = {}
                    if area.style:
                        style: Style = area.style
                        kwargs = style.to_mpl_area_kwargs()

                    ax.fill_between(
                        x=x_data, y1=yfrom_data, y2=yto_data, label=label, **kwargs
                    )

            # plot settings
            if plot.name and plot.title_visible:
                ax1.set_title(plot.name)

            if xax:
                cls._apply_axis_settings(xax, ax1, axis_type="x")
            if yax:
                # guard on the y-axis itself (not on the x-axis)
                cls._apply_axis_settings(yax, ax1, axis_type="y")
            if yax_right and ax2 is not None:
                # ax2 only exists if something is plotted on the right axis
                cls._apply_axis_settings(yax_right, ax2, axis_type="y")

            # recompute the ax.dataLim
            # ax.relim()
            # update ax.viewLim using the new dataLim
            # ax.autoscale_view()

            # figure styling
            for ax in axes:
                ax.title.set_fontsize(Figure.axes_titlesize)
                ax.title.set_fontweight(Figure.axes_titleweight)
                ax.xaxis.label.set_fontsize(Figure.axes_labelsize)
                ax.xaxis.label.set_fontweight(Figure.axes_labelweight)
                ax.yaxis.label.set_fontsize(Figure.axes_labelsize)
                ax.yaxis.label.set_fontweight(Figure.axes_labelweight)
                ax.tick_params(axis="x", labelsize=Figure.xtick_labelsize)
                ax.tick_params(axis="y", labelsize=Figure.ytick_labelsize)

            # hide none-existing axes
            if xax is None:
                ax1.tick_params(axis="x", colors=Figure.fig_facecolor)
                ax1.xaxis.label.set_color(Figure.fig_facecolor)
            if yax is None:
                ax1.tick_params(axis="y", colors=Figure.fig_facecolor)
                ax1.yaxis.label.set_color(Figure.fig_facecolor)

            xgrid = xax.grid if xax else None
            ygrid = yax.grid if yax else None

            if xgrid and ygrid:
                ax1.grid(True, axis="both")
            elif xgrid:
                ax1.grid(True, axis="x")
            elif ygrid:
                ax1.grid(True, axis="y")
            else:
                ax1.grid(False)

            if plot.legend:
                if len(axes) == 1:
                    handles1, _ = ax1.get_legend_handles_labels()
                    if handles1:
                        if figure.legend_position == "inside":
                            ax1.legend(
                                fontsize=Figure.legend_fontsize, loc=Figure.legend_loc
                            )
                        elif figure.legend_position == "outside":
                            ax1.legend(
                                fontsize=Figure.legend_fontsize,
                                loc="upper left",
                                bbox_to_anchor=(1.04, 1),
                            )
                elif len(axes) == 2:
                    # separate legends for left and right axis
                    handles1, _ = ax1.get_legend_handles_labels()
                    if handles1:
                        ax1.legend(fontsize=Figure.legend_fontsize, loc="upper left")
                    handles2, _ = ax2.get_legend_handles_labels()
                    if handles2:
                        ax2.legend(fontsize=Figure.legend_fontsize, loc="upper right")

        wspace = figure.fig_subplots_wspace
        hspace = figure.fig_subplots_hspace
        if figure.legend_position == "outside":
            # additional horizontal space for legends placed next to the axes
            wspace += 1.0

        fig.subplots_adjust(wspace=wspace, hspace=hspace)

        return fig