import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from experiments.hodgkin_huxley.generative_model import (
    SimulatorWithLogProb,
)
from experiments.hodgkin_huxley.network_config import get_trainer, get_trainer_no_sc
from experiments.hodgkin_huxley.paths import CHECKPOINT_DIR, PLOT_DIR


def ppc_ood_data(trainer_no_sc, trainer_sc, nrows=5, mean=-2, suptitle=""):
    """Plot re-simulated posterior draws for two trainers side by side.

    For every row, one parameter vector of shape ``(1, 7)`` is drawn from
    ``Normal(mean, 1)`` and passed through ``SimulatorWithLogProb`` to get a
    reference trace. Each trainer's amortizer is then asked for 200 posterior
    samples conditioned on that trace; the samples are fed back through the
    simulator and the resulting traces are drawn in light gray behind the
    reference trace.

    Args:
        trainer_no_sc: Trainer whose ``amortizer.sample`` fills the left
            column (titled "NPE only").
        trainer_sc: Trainer whose ``amortizer.sample`` fills the right
            column (titled "NPE+SC").
        nrows: Number of independent parameter draws, one subplot row each.
        mean: Mean of the unit-variance normal the parameters are drawn from.
        suptitle: Title placed above the whole figure.

    Returns:
        The matplotlib figure holding the ``nrows x 2`` grid of subplots.
    """
    simulator = SimulatorWithLogProb()

    # squeeze=False keeps the axes array two-dimensional even when nrows == 1,
    # so the [row][col] indexing below is valid for every row count.
    fig, axarray = plt.subplots(
        nrows, 2, figsize=(7 * 1.2, 1 * 1.2 * nrows), squeeze=False
    )

    axarray[0][0].set_title("NPE only", fontsize=22)
    axarray[0][1].set_title("NPE+SC", fontsize=22)

    # Column order matters: column 0 is the trainer without SC, column 1 with.
    trainers = (trainer_no_sc, trainer_sc)

    for row_idx in range(nrows):
        prior_draw = tf.random.normal((1, 7), mean, 1).numpy()

        y = simulator(prior_draw)
        # Time axis spans 0..60 with one point per sample of the trace.
        x = np.linspace(0, 60, len(y[0]))

        for col_idx, trainer in enumerate(trainers):
            ax = axarray[row_idx][col_idx]

            posterior = trainer.amortizer.sample(
                {"summary_conditions": y}, n_samples=200
            )
            pred = simulator(posterior).numpy()

            # Reference trace on top (zorder=1); re-simulated traces behind.
            ax.plot(x, y[0], color="#012F47", zorder=1)
            # presumably pred is (n_samples, n_timepoints); the transpose
            # makes matplotlib draw one line per posterior sample -- TODO confirm
            ax.plot(x, pred.T, color="#AAAAAA", alpha=0.05, zorder=0)

            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

    fig.supxlabel("time [ms]", fontsize=20)
    fig.supylabel("membrane potential [mV]", fontsize=20)
    fig.suptitle(suptitle, fontsize=25)

    fig.tight_layout()

    return fig


if __name__ == "__main__":
    # Build both trainers, each pointed at its own checkpoint directory.
    trainer_no_sc = get_trainer_no_sc(checkpoint_path=CHECKPOINT_DIR / "no_sc")
    trainer_sc = get_trainer(checkpoint_path=CHECKPOINT_DIR / "sc")

    # One entry per output figure: (prior mean, figure title, file name).
    figure_specs = (
        (0, "In-simulation data (θ ~ Normal(0, 1))", "id_data.pdf"),
        (-2, "Out-of-simulation data (θ ~ Normal(-2, 1))", "ood_data.pdf"),
    )

    for prior_mean, title, file_name in figure_specs:
        fig = ppc_ood_data(
            trainer_no_sc,
            trainer_sc,
            mean=prior_mean,
            suptitle=title,
        )
        fig.savefig(PLOT_DIR / file_name)
