Figure 2e

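Plot interpolation results

Compare the best training (d6) and test (d4) losses, shown as $W_2$ distance, of scDiffEq and PRESCIENT across five runs per method.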

[286]:
%load_ext nb_black

import scdiffeq_analyses as sdq_an
import glob, os
import pandas as pd
import seaborn as sns
import numpy as np

# Per-seed scDiffEq metrics logs (seeds 0-4), keyed by seed.
metrics_base_path = "./scDiffEq.task_one.timepoint_recovery.seed_{}/version_0/fit_logs/version_0/metrics.csv"
Metrics = {seed: pd.read_csv(metrics_base_path.format(seed)) for seed in range(5)}


def _best_train_test(metrics_df):
    """Return ``{"train": ..., "test": ...}`` best losses for one seed's metrics table."""
    summary = sdq_an.TaskOneSummary(metrics_df)
    return {"train": summary.best_train_loss, "test": summary.best_test_loss}


BestTrainTest = {seed: _best_train_test(df) for seed, df in Metrics.items()}

# Transpose so rows are seeds and columns are (train, test).
best_sdq = pd.DataFrame(BestTrainTest).T
print(best_sdq.mean())
best_sdq
The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black
train    13.748313
test     24.556157
dtype: float64
[286]:
train test
0 13.113493 24.334347
1 13.892048 24.157946
2 13.807034 25.201881
3 13.941353 24.506296
4 13.987637 24.580315
[287]:
# PRESCIENT interpolation results (no weights), one row per run.
# The file's first line is skipped and the value columns are named via
# `names`; `.loc` then reorders them to (train, test) to match `best_sdq`.
# NOTE(review): this assumes the CSV stores the test column before the train
# column. Confirm against the file's own header row.
prescient_results_csv = "../../notebooks/task_2_nbs/folder_09_prescient/PRESCIENT.interpolation.no_weights.results.csv"
best_prescient = pd.read_csv(
    prescient_results_csv,
    skiprows=1,
    names=["test", "train"],
    index_col=0,
).loc[:, ["train", "test"]]
print(best_prescient.mean())
best_prescient
train    14.944709
test     25.845095
dtype: float64
[287]:
train test
0 15.265286 26.038500
1 14.392807 25.847347
2 14.809107 25.358963
3 15.089598 26.003586
4 15.166747 25.977081
[288]:
# One frame per split: columns are methods, rows are runs/seeds.
train_df, test_df = (
    pd.DataFrame({"PRESCIENT": best_prescient[split], "scDiffEq": best_sdq[split]})
    for split in ("train", "test")
)
[321]:
import matplotlib.pyplot as plt
from scdiffeq.core.utils import AutoParseBase


class StyledBoxPlot(AutoParseBase):
    """Box-plus-jittered-points comparison drawn on a two-panel figure.

    Calling an instance draws one panel. Each column of ``data`` gets:

    - a translucent filled box,
    - an outlined box with a dashed mean line on top,
    - the raw values as horizontally jittered points at x = 1, 2, ...

    Medians and fliers are hidden.

    Parameters
    ----------
    figsize, figsize_height, figsize_width
        Forwarded to ``vinplots.quick_plot`` when the figure is built.
    box_face_zorder, box_edge_zorder, scatter_zorder
        Draw order of the filled boxes, the outlined boxes and the points.
    lw
        Line width for box outlines, whiskers and caps; mean lines use
        ``2 * lw``.
    widths
        Box widths passed to ``Axes.boxplot``.
    """

    # Fallback colors, one per column in column order, used when __call__
    # is not given explicit ``colors``.
    default_colors = ("#fb8500", "#023047")

    def __init__(
        self,
        figsize=None,
        figsize_height=1,
        figsize_width=1,
        box_face_zorder=7,
        box_edge_zorder=8,
        scatter_zorder=10,
        lw=0.5,
        widths=0.5,
    ):
        # __parse__ comes from AutoParseBase; the methods below read these
        # arguments back as attributes (self.figsize, self.lw, ...).
        self.__parse__(locals())
        self.build_plot()

    def __configure__(self, data):
        """Record the column names of ``data`` (one box per column)."""
        self.columns = data.columns.tolist()

    def build_plot(self):
        """Create the two-panel figure and store it as ``self.fig`` / ``self.axes``."""
        # `vinplots` was used here but never imported in this notebook, so a
        # fresh kernel raised NameError. Import it where it is needed.
        import vinplots

        self.fig, self.axes = vinplots.quick_plot(
            nplots=2,
            ncols=2,
            figsize=self.figsize,
            figsize_height=self.figsize_height,
            figsize_width=self.figsize_width,
            wspace=0.4,
        )

    def scatter(
        self, ax, x: float, y: pd.Series, w_scale: float = 0.02, **scatter_kwargs
    ):
        """Scatter ``y`` around horizontal position ``x`` with Gaussian x-jitter.

        ``w_scale`` is the standard deviation of the jitter. Remaining keyword
        arguments are forwarded to ``ax.scatter``.
        """
        jittered_x = np.random.normal(loc=x, scale=w_scale, size=y.shape[0])
        ax.scatter(jittered_x, y, edgecolor="None", **scatter_kwargs)

    def run_scatter(self, ax, alpha=0.8):
        """Draw the jittered points for every column at x = 1, 2, ..."""
        for i, column in enumerate(self.columns):
            self.scatter(
                ax,
                x=i + 1,
                y=self.data[column],
                c=self.colors[i],
                zorder=self.scatter_zorder,
                alpha=alpha,
                s=35,
            )

    def _whisker_cap_colors(self):
        # Matplotlib's boxplot returns two whiskers and two caps per box, in
        # box order, so each column color is emitted twice.
        for color in self.colors:
            yield color
            yield color

    @property
    def whisker_cap_colors(self):
        """Column colors, each repeated twice to line up with whiskers/caps."""
        return list(self._whisker_cap_colors())

    def _edge_box(self, ax):
        """Overlay outlined boxes, colored whiskers/caps and dashed mean lines."""
        box_edges = ax.boxplot(
            self.data.values,
            patch_artist=True,
            showmeans=True,
            meanline=True,
            showfliers=False,
            widths=self.widths,
            zorder=self.box_edge_zorder,
        )

        # Build the repeated color list once rather than on every index lookup.
        cap_colors = self.whisker_cap_colors
        for key in ("whiskers", "caps"):
            for line, color in zip(box_edges[key], cap_colors):
                line.set_c(color)
                line.set_linewidth(self.lw)

        for mean_line, color in zip(box_edges["means"], self.colors):
            mean_line.set_color(color)
            mean_line.set_linestyle("--")
            mean_line.set_linewidth(self.lw * 2)
        for median_line in box_edges["medians"]:
            median_line.set_visible(False)
        for box, color in zip(box_edges["boxes"], self.colors):
            # Transparent face: the fill is drawn by run_boxplot underneath.
            box.set_facecolor([0, 0, 0, 0])
            box.set_edgecolor(color)
            box.set_linewidth(self.lw)

    def run_boxplot(self, ax, alpha=0.3):
        """Draw translucent filled boxes (no whiskers, caps or medians), then the outlines."""
        box = ax.boxplot(
            self.data.values,
            showmeans=False,
            meanline=False,
            widths=self.widths,
            # Matplotlib 3.9 renamed this parameter to `tick_labels`; kept as
            # `labels` for compatibility with older releases.
            labels=self.columns,
            patch_artist=True,
            showfliers=False,
            zorder=self.box_face_zorder,
        )
        ax.set_xticks([])

        for key in ("whiskers", "caps", "medians"):
            for artist in box[key]:
                artist.set_visible(False)

        for patch, color in zip(box["boxes"], self.colors):
            patch.set_facecolor(color)
            patch.set_alpha(alpha)

        self._edge_box(ax)

    def __call__(self, ax, data, colors=None):
        """Draw one panel on ``ax`` from the columns of ``data``.

        Parameters
        ----------
        ax : matplotlib Axes
            Target axes (e.g. one of ``self.axes``).
        data : pandas.DataFrame
            One column per group; each column is drawn as a box plus jittered
            points at x = 1, 2, ... in column order.
        colors : sequence of color specs, optional
            One color per column of ``data``, in column order. Defaults to
            ``default_colors``.

        Raises
        ------
        ValueError
            If fewer colors than columns are available.
        """
        self.data = data
        self.colors = list(self.default_colors if colors is None else colors)
        self.__configure__(data)

        if len(self.colors) < len(self.columns):
            raise ValueError(
                f"Need at least {len(self.columns)} colors (one per column), "
                f"got {len(self.colors)}."
            )

        self.run_scatter(ax=ax)
        self.run_boxplot(ax=ax)

Plot

[324]:
sbp = StyledBoxPlot(figsize_width=0.2, figsize_height=0.5)

# Panel specs in axes order: (data, (y_min, y_max), title).
panels = [
    (train_df, (13, 16), "Training (d6)"),
    (test_df, (24, 27), "Test (d4)"),
]
for panel_idx, (split_df, (y_min, y_max), title) in enumerate(panels):
    ax = sbp.axes[panel_idx]
    sbp(ax=ax, data=split_df)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(range(1, 3), ["PRESCIENT", "scDiffEq"], fontsize=6)
    ax.set_yticks(range(y_min, y_max + 1), range(y_min, y_max + 1), fontsize=6)
    ax.set_title(title, fontsize=10)

sbp.axes[0].set_ylabel("$W_{2}$ Distance")
plt.savefig("task_one.interpolation.results.svg")
../_images/_analyses_Figure2E_6_0.png