Supplementary Figure 12D

Import packages

[1]:
import scdiffeq as sdq
import scdiffeq_analyses as sdq_an
import pandas as pd
import cellplots as cp
import matplotlib.pyplot as plt

Load project

[2]:
# Directory holding the trained LightningSDE model versions for this figure.
project_dir = "./LightningSDE-FixedPotential-RegularizedVelocityRatio/"
project = sdq.io.Project(project_dir)

Organize loss values for each model version

[3]:
# Wrap every model version in the project with sdq_an.parsers.SummarizedLoss,
# keyed by version name. Only the version names from project._VERSION_PATHS
# are needed here (the paths themselves were never used); each name is
# resolved to its version object via getattr on the project.
SummarizedLoss = {
    vname: sdq_an.parsers.SummarizedLoss(version=getattr(project, vname))
    for vname in project._VERSION_PATHS
}

Plot loss curves

[4]:
# Build a grid of panels (cp.plot(5, 5, ...)), one per model version, each
# labelled with epoch on x and Sinkhorn divergence on y.
fig, axes = cp.plot(
    5,
    5,
    height=0.5,
    width=0.6,
    hspace=0.4,
    wspace=0.4,
    x_label=["Epoch"] * 5,
    y_label=["Sinkhorn Divergence"] * 5,
)
# Only the parsed losses are needed; version names are not used for plotting.
# NOTE(review): axes[i] assumes the project holds at most as many versions as
# panels created above — an extra version would raise IndexError. Confirm.
for i, summarized_loss in enumerate(SummarizedLoss.values()):
    sdq_an.pl.fit_loss(summarized_loss=summarized_loss, ax=axes[i])
# Save this specific figure (rather than pyplot's "current" figure) in both
# vector and raster formats.
for ext in ("svg", "png"):
    fig.savefig(f"pancreas_scdiffeq_loss_curves.{ext}", dpi=500)
../_images/_analyses_FigureS12D_7_0.png