diff --git a/panel/simdec_app.py b/panel/simdec_app.py index 9403d97..5939dc6 100644 --- a/panel/simdec_app.py +++ b/panel/simdec_app.py @@ -1,5 +1,6 @@ import bisect import io +import re from bokeh.models import PrintfTickFormatter from bokeh.models.widgets.tables import NumberFormatter @@ -15,10 +16,8 @@ from simdec.sensitivity_indices import SensitivityAnalysisResult from simdec.visualization import sequential_cmaps, single_color_to_colormap - # panel app -pn.extension("tabulator") -pn.extension("floatpanel") +pn.extension("tabulator", "floatpanel", notifications=True) pn.config.sizing_mode = "stretch_width" pn.config.throttled = True @@ -42,35 +41,58 @@ # save_layout=True, ) +VALID_CHARACTERS = re.compile(r"[A-Za-z0-9_ \-.]") +GENERIC_ERROR_MSG = ( + "Could not parse the CSV file. " + "Please check that it uses commas ',' as the delimiter " + "and that column names contain no special characters." +) + @pn.cache def load_data(text_fname): if text_fname is None: - text_fname = "tests/data/stress.csv" - else: - text_fname = io.BytesIO(text_fname) - - data = pd.read_csv(text_fname) - return data + return pd.read_csv("tests/data/stress.csv") + try: + raw = bytes(text_fname) + first_line = raw.decode("utf-8").split("\n")[0].strip() + if "," not in first_line: + raise ValueError("No comma delimiter") + col_names = [c.strip().strip('"').strip("'") for c in first_line.split(",")] + if any(VALID_CHARACTERS.search(c) for c in col_names): + raise ValueError("Bad column names") + return pd.read_csv(io.BytesIO(raw)) + except Exception: + pn.state.notifications.error(GENERIC_ERROR_MSG, duration=0) + return pd.read_csv("tests/data/stress.csv") @pn.cache def column_inputs(data, output): + if data is None: + return [] inputs = list(data.columns) - inputs.remove(output) + if output in inputs: + inputs.remove(output) return inputs @pn.cache def column_output(data): + if data is None: + return [] return list(data.columns) @pn.cache def filtered_data(data, output_name): + if data is None or not output_name: + return pd.Series(dtype=float) try: return data[output_name] except KeyError: + if isinstance(output_name, list): + return data.iloc[:, [0]] return data.iloc[:, 0] @@ -350,7 +372,7 @@ def csv_data( interactive_column_output = pn.bind(column_output, interactive_file) # hack to make the default selection faster -interactive_output_ = pn.bind(lambda x: x[0], interactive_column_output) +interactive_output_ = pn.bind(lambda x: x[0] if x else None, interactive_column_output) selector_output = pn.widgets.Select( name="Output", value=interactive_output_, options=interactive_column_output ) diff --git a/src/simdec/__init__.py b/src/simdec/__init__.py index d0b238b..9394a8f 100644 --- a/src/simdec/__init__.py +++ b/src/simdec/__init__.py @@ -8,6 +8,7 @@ "states_expansion", "decomposition", "visualization", + "two_output_visualization", "tableau", "palette", ] diff --git a/src/simdec/visualization.py b/src/simdec/visualization.py index 6576ba0..ae3f7c0 100644 --- a/src/simdec/visualization.py +++ b/src/simdec/visualization.py @@ -11,7 +11,7 @@ import pandas as pd from pandas.io.formats.style import Styler -__all__ = ["visualization", "tableau", "palette"] +__all__ = ["visualization", "two_output_visualization", "tableau", "palette"] SEQUENTIAL_PALETTES = [ @@ -189,6 +189,105 @@ def visualization( return ax +def two_output_visualization( + *, + bins: pd.DataFrame, + bins2: pd.DataFrame, + palette: list[list[float]], + n_bins: str | int = "auto", + output_name: str = "Output 1", + output_name2: str = "Output 2", + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + r_scatter: float = 1.0, +) -> tuple[plt.Figure, np.ndarray]: + """Two-output visualization. + Produces a 2x2 figure + * top-left : stacked histogram for *output 1* (axes hidden) + * bottom-left : scatter of output 1 vs output 2, coloured by scenario + * bottom-right: rotated stacked histogram for *output 2* (axes hidden) + * top-right : empty + + Parameters + ---------- + bins : DataFrame + Multidimensional bins for the primary output. + bins2 : DataFrame + Multidimensional bins for the secondary output. + palette : list of int of size (n, 4) + List of colours corresponding to scenarios. + n_bins : str or int + Number of bins for the histograms. + output_name : str, default "Output 1" + Axis label for the primary output. + output_name2 : str, default "Output 2" + Axis label for the secondary output. + xlim : tuple of float, optional + Limits for the primary output axis (scatter x / top histogram). + ylim : tuple of float, optional + Limits for the secondary output axis (scatter y / right histogram). + r_scatter : float, default 1.0 + Fraction of data points shown in the scatter plot. + + Returns + ------- + fig : Figure + axs : ndarray of shape (2, 2) + + """ + fig, axs = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(8, 8)) + + axs[0, 1].axis("off") + + visualization(bins=bins.copy(), palette=palette, n_bins=n_bins, ax=axs[0, 0]) + if xlim is not None: + axs[0, 0].set_xlim(xlim) + axs[0, 0].set_box_aspect(1) + axs[0, 0].axis("off") + + data = pd.concat([pd.melt(bins), pd.melt(bins2)["value"]], axis=1) + data.columns = ["c", "x", "y"] + if r_scatter < 1.0: + data = data.sample(frac=r_scatter) + + sns.scatterplot( + data=data, + x="x", + y="y", + hue="c", + palette=palette, + ax=axs[1, 0], + legend=False, + ) + axs[1, 0].set(xlabel=output_name, ylabel=output_name2) + if xlim is not None: + axs[1, 0].set_xlim(xlim) + if ylim is not None: + axs[1, 0].set_ylim(ylim) + axs[1, 0].set_box_aspect(1) + + sns.histplot( + data, + y="y", + hue="c", + multiple="stack", + stat="probability", + palette=palette, + common_bins=True, + common_norm=True, + bins=40, + legend=False, + ax=axs[1, 1], + ) + if ylim is not None: + axs[1, 1].set_ylim(ylim) + axs[1, 1].set_box_aspect(1) + axs[1, 1].axis("off") + + fig.subplots_adjust(wspace=-0.015, hspace=0) + return fig, axs + + def tableau( *, var_names: list[str], diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..a974ae0 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,64 @@ +import pytest +import pandas as pd +import matplotlib.pyplot as plt +import simdec as sd + + +@pytest.fixture(autouse=True) +def close_plots(): + yield + plt.close("all") + + +def test_visualization_histogram(): + bins = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]}) + palette = [[1, 0, 0, 1], [0, 1, 0, 1]] + ax = sd.visualization(bins=bins, palette=palette, kind="histogram") + assert isinstance(ax, plt.Axes) + + +def test_visualization_boxplot(): + bins = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]}) + palette = [[1, 0, 0, 1], [0, 1, 0, 1]] + ax = sd.visualization(bins=bins, palette=palette, kind="boxplot") + assert isinstance(ax, plt.Axes) + + +def test_visualization_invalid_kind(): + bins = pd.DataFrame({"s1": [1]}) + with pytest.raises(ValueError, match="'kind' can only be 'histogram' or 'boxplot'"): + sd.visualization(bins=bins, palette=[[1, 0, 0, 1]], kind="invalid") + + +def test_two_output_visualization_returns_correct_types(): + bins = pd.DataFrame({"s1": [1, 2]}) + bins2 = pd.DataFrame({"s1": [5, 6]}) + palette = [[1, 0, 0, 1]] + fig, axs = sd.two_output_visualization(bins=bins, bins2=bins2, palette=palette) + assert isinstance(fig, plt.Figure) + assert axs.shape == (2, 2) + + +def test_two_output_visualization_axis_labels(): + bins = pd.DataFrame({"s1": [1, 2]}) + bins2 = pd.DataFrame({"s1": [5, 6]}) + palette = [[1, 0, 0, 1]] + _, axs = sd.two_output_visualization( + bins=bins, + bins2=bins2, + palette=palette, + output_name="Stress", + output_name2="Displacement", + ) + assert axs[1, 0].get_xlabel() == "Stress" + assert axs[1, 0].get_ylabel() == "Displacement" + + +def test_two_output_visualization_r_scatter(): + bins = pd.DataFrame({"s1": list(range(100))}) + bins2 = pd.DataFrame({"s1": list(range(100))}) + palette = [[1, 0, 0, 1]] + fig, axs = sd.two_output_visualization( + bins=bins, bins2=bins2, palette=palette, r_scatter=0.5 + ) + assert isinstance(fig, plt.Figure)