import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import experiments.hodgkin_huxley.metrics as metrics
from experiments.hodgkin_huxley.generative_model import (
    SimulatorWithLogProb,
)
from experiments.hodgkin_huxley.network_config import get_trainer, get_trainer_no_sc
from experiments.hodgkin_huxley.paths import CHECKPOINT_DIR, DATA_DIR, PLOT_DIR

# Use matplotlib's built-in text rendering instead of an external LaTeX install.
plt.rcParams.update({"text.usetex": False})


def ppc_ood_data(trainer_no_sc, trainer_sc, mean=-2):
    """Plot posterior predictive checks for one simulated observation.

    A single 7-dimensional parameter vector is drawn from N(``mean``, 1) and
    simulated once. Each trainer's amortizer then draws 200 posterior samples
    conditioned on that simulation, and the samples are re-simulated. The
    top panel shows the trainer without SC ("NPE only"). The bottom panel
    shows the trainer with SC ("NPE+SC"). Each panel overlays the observed
    trace on the predictive traces.

    Args:
        trainer_no_sc: Trainer whose ``amortizer`` is used for the top panel.
        trainer_sc: Trainer whose ``amortizer`` is used for the bottom panel.
        mean: Location of the normal distribution the parameters are drawn
            from. NOTE(review): presumably values far from 0 produce
            out-of-distribution data — confirm against the training prior.

    Returns:
        The matplotlib ``Figure`` containing both panels.
    """
    simulator = SimulatorWithLogProb()

    fig, axarray = plt.subplots(2, 1, figsize=(3.5 * 1.2, 2.3 * 1.2))

    theta_true = tf.random.normal((1, 7), mean, 1).numpy()
    observed = simulator(theta_true)
    # Time axis spans 0-60 (labelled ms below), one point per simulated value.
    time_grid = np.linspace(0, 60, len(observed[0]))

    panels = (
        (axarray[0], trainer_no_sc, "NPE only"),
        (axarray[1], trainer_sc, "NPE+SC"),
    )
    for row, (ax, trainer, title) in enumerate(panels):
        ax.set_title(title, loc="left")

        posterior_draws = trainer.amortizer.sample(
            {"summary_conditions": observed}, n_samples=200
        )
        predictive = simulator(posterior_draws).numpy()

        # Observed trace on top (zorder=1), faint predictive traces beneath it.
        ax.plot(time_grid, observed[0], color="#012F47", zorder=1)
        ax.plot(time_grid, predictive.T, color="#777777", alpha=0.05, zorder=0)

        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        if row == 0:
            # Only the bottom panel keeps its x ticks.
            ax.set_xticks([])

    fig.supxlabel("time [ms]", y=0.06, fontsize=12)
    fig.supylabel("membrane potential [mV]", x=0.04, fontsize=12)
    fig.tight_layout()

    return fig


def mean_absolute_bias_plot(trainer_no_sc, trainer_sc, mean=-2, n_samples=1000):
    """Histogram the per-data-set difference in mean absolute bias (MAB).

    Loads or computes per-data-set MAB values for both trainers via
    ``mean_absolute_bias`` and plots MAB(NPE) - MAB(NPE+SC). Positive values
    mean the NPE+SC trainer had the lower MAB on that data set.

    Args:
        trainer_no_sc: Trainer without SC (the "NPE" term).
        trainer_sc: Trainer with SC (the "NPE+SC" term).
        mean: Forwarded to ``mean_absolute_bias``.
        n_samples: Number of simulated data sets. Also sets the bin count.

    Returns:
        The matplotlib ``Figure`` with the histogram.
    """
    data_dict = mean_absolute_bias(
        trainer_no_sc, trainer_sc, mean=mean, n_samples=n_samples
    )

    differences = [
        bias_no_sc - bias_sc
        for bias_no_sc, bias_sc in zip(data_dict["bias_no_sc"], data_dict["bias_sc"])
    ]

    # Scale the bin count with the actual sample size. The previous code
    # hard-coded 1000 here. Clamp to >= 1 so tiny/empty runs still plot.
    n_bins = max(1, int(np.sqrt(n_samples) * 1.5))

    fig, ax = plt.subplots(figsize=(3.5 * 1.2, 2.3 * 1.2))

    ax.hist(
        differences,
        bins=n_bins,
        color="#012F47",
        linewidth=2,
    )
    # Zero line: left of it NPE alone is better, right of it NPE+SC is better.
    ax.axvline(x=0, linewidth=4, color="#777777", linestyle="--")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.set_xlabel("MAB(NPE) - MAB(NPE+SC)", fontsize=12)
    ax.set_ylabel("count", fontsize=12)
    fig.tight_layout()

    return fig


def mean_absolute_bias(trainer_no_sc, trainer_sc, mean=-2, n_samples=1000):
    """Compute, or load from cache, per-data-set mean absolute bias for both trainers.

    Draws ``n_samples`` 7-dimensional parameter vectors from N(``mean``, 1).
    Each vector is simulated once with ``trainer_no_sc.generative_model``.
    ``metrics.mean_absolute_bias`` is then evaluated on every simulated data
    set for each trainer. Results are pickled to
    ``DATA_DIR/mean_absolute_bias_{mean}_{n_samples}.pkl`` and reused by later
    calls with the same arguments.

    NOTE(review): the cache key only encodes ``mean`` and ``n_samples``. A
    cached file is reused even if the trainers or checkpoints change. Delete
    the file to force recomputation.

    Args:
        trainer_no_sc: Trainer without SC. Its generative model produces the data.
        trainer_sc: Trainer with SC.
        mean: Location of the normal distribution for the parameter draws.
        n_samples: Number of simulated data sets.

    Returns:
        dict with keys "prior", "y", "bias_no_sc" and "bias_sc". The last two
        are lists of length ``n_samples``.
    """
    file_path = DATA_DIR / f"mean_absolute_bias_{mean}_{n_samples}.pkl"

    if not file_path.exists():
        model = trainer_no_sc.generative_model
        prior = tf.random.normal((n_samples, 7), mean, 1)
        y = model.simulator(prior)["sim_data"]

        bias_no_sc = [
            metrics.mean_absolute_bias(trainer_no_sc, tf.expand_dims(y[i], 0))[0]
            for i in tqdm(range(n_samples))
        ]
        bias_sc = [
            metrics.mean_absolute_bias(trainer_sc, tf.expand_dims(y[i], 0))[0]
            for i in tqdm(range(n_samples))
        ]

        data_dict = {
            "prior": prior,
            "y": y,
            "bias_no_sc": bias_no_sc,
            "bias_sc": bias_sc,
        }

        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file and move it into place afterwards. An
        # interrupted run then cannot leave a truncated cache file that every
        # later call would fail to unpickle.
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        with open(tmp_path, "wb") as file:
            pickle.dump(data_dict, file)
        tmp_path.replace(file_path)

    # Always return the unpickled object so fresh and cached calls yield the
    # same object types. SECURITY: pickle.load can execute arbitrary code; only
    # load cache files written by this script.
    with open(file_path, "rb") as file:
        data_dict = pickle.load(file)

    return data_dict


if __name__ == "__main__":
    trainer_no_sc = get_trainer_no_sc(checkpoint_path=CHECKPOINT_DIR / "no_sc")
    trainer_sc = get_trainer(checkpoint_path=CHECKPOINT_DIR / "sc")

    PLOT_DIR.mkdir(parents=True, exist_ok=True)

    # (mean, file-name prefix): the mean=0 plots are saved under the appendix
    # prefix. Call order matches the original script: ppc then MAB for each mean.
    for mean, prefix in ((-2, "03_ppc"), (0, "03_ppc_appendix")):
        for plot_fn, panel in ((ppc_ood_data, "a"), (mean_absolute_bias_plot, "b")):
            fig = plot_fn(trainer_no_sc, trainer_sc, mean=mean)
            fig.savefig(PLOT_DIR / f"{prefix}_{panel}.pdf")
            # pyplot keeps every open figure alive until explicitly closed.
            plt.close(fig)
