import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import experiments.hodgkin_huxley.metrics
from experiments.hodgkin_huxley.network_config import get_trainer, get_trainer_no_sc
from experiments.hodgkin_huxley.paths import CHECKPOINT_DIR, DATA_DIR, PLOT_DIR

# Render all figure text through LaTeX; the axis labels below are written as
# LaTeX math (e.g. r"$\mathrm{density}$").
plt.rcParams.update({"text.usetex": True})


def mean_absolute_bias_plot(trainer_no_sc, trainer_sc, n_samples=50):
    """Plot a histogram of per-data-set differences in mean absolute bias (MAB).

    For each of the ``n_samples`` simulated data sets returned by
    ``mean_absolute_bias``, the difference ``MAB(NPE) - MAB(NPE+SC)`` is taken,
    where "NPE" is ``trainer_no_sc`` and "NPE+SC" is ``trainer_sc``. Positive
    values therefore mean ``trainer_sc`` has the lower bias on that data set.

    Args:
        trainer_no_sc: Trainer shown as "NPE" in the axis label.
        trainer_sc: Trainer shown as "NPE+SC" in the axis label.
        n_samples: Number of simulated data sets (also selects the cache file).

    Returns:
        The matplotlib ``Figure`` containing the histogram.
    """
    data_dict = mean_absolute_bias(trainer_no_sc, trainer_sc, n_samples=n_samples)

    # One difference per simulated data set.
    differences = [
        bias_no_sc - bias_sc
        for bias_no_sc, bias_sc in zip(data_dict["bias_no_sc"], data_dict["bias_sc"])
    ]

    # Square-root rule (scaled by 1.5) on the actual number of samples; keep at
    # least one bin so tiny n_samples still plots.
    n_bins = max(1, int(np.sqrt(n_samples) * 1.5))

    fig, ax = plt.subplots(figsize=(3.5 * 1.2, 1.5 * 1.2))

    ax.hist(
        differences,
        bins=n_bins,
        density=True,  # the y axis is labelled "density", so normalise counts
        color="#012F47",
        linewidth=2,
    )
    # Reference line: zero difference means both trainers are equally biased.
    ax.axvline(x=0, linewidth=4, color="#AAAAAA", linestyle="--")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.set_xlabel(
        r"$\mathrm{MAB}_{i}(\mathrm{NPE}) - \mathrm{MAB}_{i}(\mathrm{NPE+SC})$"
    )
    ax.set_ylabel(r"$\mathrm{density}$")

    fig.tight_layout()

    return fig


def _per_dataset_bias(trainer, y, n_samples):
    """Evaluate the metrics-module mean absolute bias of ``trainer`` on each of
    the first ``n_samples`` entries of ``y``, one data set at a time."""
    return [
        experiments.hodgkin_huxley.metrics.mean_absolute_bias(
            # Add a leading batch axis of size 1 and keep the single result.
            trainer, tf.expand_dims(y[i], 0)
        )[0]
        for i in tqdm(range(n_samples))
    ]


def mean_absolute_bias(trainer_no_sc, trainer_sc, n_samples=1000):
    """Compute, or load from cache, per-data-set mean absolute bias for two trainers.

    On a cache miss, ``n_samples`` parameter vectors of dimension 7 are drawn
    from a normal distribution with mean -2 and std 1 (unseeded). Data sets are
    simulated from them with ``trainer_no_sc.generative_model.simulator``, and
    ``experiments.hodgkin_huxley.metrics.mean_absolute_bias`` is evaluated on
    every simulated data set for both trainers. The result is pickled to
    ``DATA_DIR / f"mean_absolute_bias_{n_samples}.pkl"`` and reused afterwards.

    NOTE(review): the cache key is only ``n_samples``; results computed with
    other trainers/checkpoints are silently reused. Delete the file to recompute.

    Args:
        trainer_no_sc: Trainer shown as "NPE"; its generative model also
            supplies the simulator.
        trainer_sc: Trainer shown as "NPE+SC".
        n_samples: Number of simulated data sets.

    Returns:
        dict with keys ``"prior"`` (the sampled parameters, shape
        ``(n_samples, 7)``), ``"y"`` (simulated data), and ``"bias_no_sc"`` /
        ``"bias_sc"`` (lists of length ``n_samples``).
    """
    file_path = DATA_DIR / f"mean_absolute_bias_{n_samples}.pkl"

    if not file_path.exists():
        # Create the cache directory before the expensive work, so a missing
        # directory fails fast instead of after all bias evaluations.
        file_path.parent.mkdir(parents=True, exist_ok=True)

        model = trainer_no_sc.generative_model
        # Parameters are drawn directly here rather than via the generative model.
        prior = tf.random.normal((n_samples, 7), -2, 1)
        y = model.simulator(prior)["sim_data"]

        data_dict = {
            "prior": prior,
            "y": y,
            "bias_no_sc": _per_dataset_bias(trainer_no_sc, y, n_samples),
            "bias_sc": _per_dataset_bias(trainer_sc, y, n_samples),
        }

        # Write to a temporary file and atomically move it into place, so an
        # interrupted dump never leaves a truncated cache that later loads fail on.
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        with open(tmp_path, "wb") as file:
            pickle.dump(data_dict, file)
        tmp_path.replace(file_path)

    # Always return the on-disk copy, so fresh and cached runs yield data of
    # identical types. pickle.load can execute code: only load trusted caches.
    with open(file_path, "rb") as file:
        data_dict = pickle.load(file)

    return data_dict


if __name__ == "__main__":
    # Build both trainers, each pointed at its own checkpoint location.
    trainer_no_sc = get_trainer_no_sc(checkpoint_path=CHECKPOINT_DIR / "no_sc")
    trainer_sc = get_trainer(checkpoint_path=CHECKPOINT_DIR / "sc")

    fig = mean_absolute_bias_plot(trainer_no_sc, trainer_sc)

    # savefig does not create missing directories; make sure the target exists.
    # Assumes PLOT_DIR is a pathlib.Path (it is joined with "/" below).
    PLOT_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(PLOT_DIR / "02_mean_absolute_bias.pdf")
