import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from experiments.multivariate_normal.accurate_posterior import (
    posterior_bias,
    posterior_mmd,
)
from experiments.multivariate_normal.generative_model import get_generative_model
from experiments.multivariate_normal.network_config import (
    get_trainer,
    get_trainer_no_sc,
)
from experiments.multivariate_normal.plots import comparison_lines
from src.self_consistency_real.utils import delete_checkpoints


def train_model(dataset_size: int, replication: int, rerun=True):
    """Build the self-consistency trainer and train it offline if required.

    A NumPy generator seeded with ``replication`` is shared by the generative
    model and the trainer. Checkpoints live under
    ``checkpoints/dataset_size_<dataset_size>/npe_with_sc_<replication>``
    next to this file.

    Args:
        dataset_size: Forwarded to ``get_trainer`` and used in the checkpoint
            directory name.
        replication: Replication index; doubles as the random seed.
        rerun: If truthy, existing checkpoints are deleted and the model is
            trained again. Otherwise training only happens when
            ``trainer.loss_history.latest == 0`` (presumably meaning that no
            trained checkpoint was restored -- confirm against the trainer).

    Returns:
        The trainer returned by ``get_trainer``, trained or restored.
    """
    rng = np.random.default_rng(replication)
    checkpoint_dir = Path(__file__).parents[0].joinpath(
        "checkpoints",
        f"dataset_size_{dataset_size}",
        f"npe_with_sc_{replication}",
    )

    # The data are drawn before the trainer is built so that the shared
    # generator is consumed in a fixed order.
    simulator = get_generative_model(dimension=10, rng=rng)
    offline_data = simulator(1024)

    trainer = get_trainer(
        dimension=10,
        dataset_size=dataset_size,
        rng=rng,
        checkpoint_path=checkpoint_dir,
    )

    needs_training = rerun or trainer.loss_history.latest == 0
    if not needs_training:
        return trainer

    delete_checkpoints(checkpoint_dir)
    trainer.train_offline(offline_data, epochs=100, lr=0.0005, batch_size=32)
    return trainer


def train_model_no_sc(dimension: int, replication: int, rerun=True):
    """Build the trainer without self-consistency and train it if required.

    The random generator is seeded with ``replication`` and passed to both
    the generative model and the trainer. Checkpoints are kept under
    ``checkpoints/dataset_size_0/npe_only_<replication>`` next to this file.

    Args:
        dimension: Forwarded to ``get_generative_model`` and
            ``get_trainer_no_sc``.
        replication: Replication index; doubles as the random seed.
        rerun: If truthy, existing checkpoints are deleted and training is
            repeated. Otherwise training only runs when
            ``trainer.loss_history.latest == 0`` (presumably "nothing trained
            yet" -- confirm against the trainer implementation).

    Returns:
        The trainer returned by ``get_trainer_no_sc``.
    """
    rng = np.random.default_rng(replication)

    here = Path(__file__).parents[0]
    checkpoint_dir = here / "checkpoints" / "dataset_size_0"
    checkpoint_dir = checkpoint_dir / f"npe_only_{replication}"

    # Draw the training data first; the trainer receives the same generator
    # afterwards.
    train_data = get_generative_model(dimension=dimension, rng=rng)(1024)

    trainer = get_trainer_no_sc(
        dimension=dimension,
        rng=rng,
        checkpoint_path=checkpoint_dir,
    )

    should_train = rerun or trainer.loss_history.latest == 0
    if should_train:
        delete_checkpoints(checkpoint_dir)
        trainer.train_offline(
            train_data,
            epochs=100,
            lr=0.0005,
            batch_size=32,
        )

    return trainer


def posterior_metrics():
    """Return posterior metrics of the self-consistency models, cached on disk.

    If ``data/posterior_metrics.pkl`` (next to this file) does not exist, the
    metrics are computed for every dataset size, replication and location and
    pickled to that path. The result is then always read back from the file.

    Returns:
        Array of shape ``(3, 3, 12, 10)`` indexed as
        ``[metric, dataset_size, location, replication]`` where

        * metric is ``0`` = mean bias, ``1`` = std bias, ``2`` = MMD,
        * dataset_size indexes ``[1, 4, 32]``,
        * location indexes ``range(12)``,
        * replication indexes ``range(10)``.
    """
    file_path = Path(__file__).parents[0] / "data" / "posterior_metrics.pkl"

    if not file_path.exists():
        # Create the cache directory *before* the expensive sweep: otherwise
        # a missing "data" folder would only surface as a FileNotFoundError
        # at write time, after all models were trained and evaluated.
        file_path.parent.mkdir(parents=True, exist_ok=True)

        metrics = np.zeros((3, 3, 12, 10))

        for dataset_size_idx, dataset_size in enumerate([1, 4, 32]):
            for replication in range(10):
                trainer = train_model(dataset_size, replication, rerun=False)
                for location in range(12):
                    # NOTE(review): this uses NumPy's global RNG, which is not
                    # seeded in this function, so the jittered locations are
                    # presumably not reproducible across fresh runs.
                    location_ = np.random.normal(loc=[location] * 10, scale=0.1)
                    mean, std = posterior_bias(trainer, location_)
                    mmd = posterior_mmd(trainer, location_)

                    # Only the first entry of the returned mean/std is kept
                    # (presumably the first parameter dimension -- confirm
                    # against posterior_bias).
                    metrics[0, dataset_size_idx, location, replication] = mean[0]
                    metrics[1, dataset_size_idx, location, replication] = std[0]
                    metrics[2, dataset_size_idx, location, replication] = mmd

        with open(file_path, "wb") as file:
            pickle.dump(metrics, file)

    # The pickle is a cache written by this function; do not point this at
    # files from untrusted sources.
    with open(file_path, "rb") as file:
        metrics = pickle.load(file)

    return metrics


def posterior_metrics_no_sc():
    """Return posterior metrics of the NPE-only models, cached on disk.

    If ``data/posterior_metrics_no_sc.pkl`` (next to this file) does not
    exist, the metrics are computed for every replication and location and
    pickled to that path. The result is then always read back from the file.

    Returns:
        Array of shape ``(3, 12, 10)`` indexed as
        ``[metric, location, replication]`` where

        * metric is ``0`` = mean bias, ``1`` = std bias, ``2`` = MMD,
        * location indexes ``range(12)``,
        * replication indexes ``range(10)``.
    """
    file_path = Path(__file__).parents[0] / "data" / "posterior_metrics_no_sc.pkl"

    if not file_path.exists():
        # Create the cache directory up front so that a missing "data" folder
        # cannot cause a FileNotFoundError after the expensive sweep has run.
        file_path.parent.mkdir(parents=True, exist_ok=True)

        metrics = np.zeros((3, 12, 10))

        for replication in range(10):
            trainer = train_model_no_sc(10, replication, rerun=False)
            for location in range(12):
                # NOTE(review): this uses NumPy's global RNG, which is not
                # seeded in this function, so the jittered locations are
                # presumably not reproducible across fresh runs.
                location_ = np.random.normal(loc=[location] * 10, scale=0.1)
                mean, std = posterior_bias(trainer, location_)
                mmd = posterior_mmd(trainer, location_)

                # Only the first entry of the returned mean/std is kept
                # (presumably the first parameter dimension -- confirm
                # against posterior_bias).
                metrics[0, location, replication] = mean[0]
                metrics[1, location, replication] = std[0]
                metrics[2, location, replication] = mmd

        with open(file_path, "wb") as file:
            pickle.dump(metrics, file)

    # The pickle is a cache written by this function; do not point this at
    # files from untrusted sources.
    with open(file_path, "rb") as file:
        metrics = pickle.load(file)

    return metrics


def posterior_metric_plot():
    """Draw the three-panel comparison of posterior metrics and save it.

    Loads the metrics from ``posterior_metrics`` and
    ``posterior_metrics_no_sc``, reduces them over the last (replication)
    axis to a mean and a standard deviation, and passes them to
    ``comparison_lines`` -- one panel each for mean bias, SD bias and MMD.
    The figure is written to ``plots/posterior_metric_plot.pdf`` next to this
    file; the ``plots`` directory is created if it is missing.

    Returns:
        The matplotlib figure.
    """
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    axes = axes.flatten()

    titles = ["posterior mean bias", "posterior SD bias", r"posterior distance"]
    ylabels = [
        r"$\mu_{\text{est.}} - \mu_{\text{true}}$",
        r"$\sigma_{\text{est.}} - \sigma_{\text{true}}$",
        "MMD",
    ]
    labels = ["0 (no SC)", "1", "4", "32"]
    offsets = [0, -0.2, 0, 0.2]

    for i, ax in enumerate(axes):
        # First entry: baseline without self-consistency; then one entry per
        # dataset size (second axis of `metrics`).
        means = [np.mean(metrics_no_sc[i, :, :], axis=-1)] + [
            np.mean(metrics[i, size_idx, :, :], axis=-1) for size_idx in range(3)
        ]
        stds = [np.std(metrics_no_sc[i, :, :], axis=-1)] + [
            np.std(metrics[i, size_idx, :, :], axis=-1) for size_idx in range(3)
        ]

        axes[i] = comparison_lines(
            ax,
            means,
            stds,
            linestyles=["solid", ":", "-.", "--"],
            colors=["#7b001b", "#96b7d6", "#2a63ab", "#001842"],
            offsets=offsets,
            labels=labels,
            ylabel=ylabels[i],
            title=titles[i],
        )

    # Reference line in the SD-bias panel, annotated as sigma_est = 0
    # (presumably the true posterior SD is 1/sqrt(2) -- confirm).
    axes[1].axhline(-1 / np.sqrt(2), linestyle="dotted", color="black")
    axes[1].text(-0.3, -1 / np.sqrt(2) + 0.05, r"$\sigma_{\text{est}}=0$", fontsize=14)

    # Blank multi-line x label; presumably reserves room below the axes for
    # the figure-level legend.
    fig.supxlabel("\n\n\n", fontsize=16)

    handles, legend_labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        legend_labels,
        loc="lower right",
        ncol=4,
        fontsize=16,
        title="dataset size",
        title_fontsize=16,
        bbox_to_anchor=(1, 0),
        frameon=False,
    )
    legend.get_title().set_position((-270, -23))
    fig.suptitle("Size of the unlabeled training data", fontsize=22, x=0, ha="left")
    fig.tight_layout()

    path = Path(__file__).parents[0] / "plots" / "posterior_metric_plot.pdf"
    # savefig does not create missing directories; make sure the target exists.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


def mmd_plot():
    """Draw the single-panel MMD comparison and save it.

    Takes the MMD slice (metric index ``2``) of the arrays returned by
    ``posterior_metrics`` and ``posterior_metrics_no_sc``, reduces it over
    the last (replication) axis to a mean and a standard deviation, and
    passes the result to ``comparison_lines``. The figure is written to
    ``plots/mmd_plot.pdf`` next to this file; the ``plots`` directory is
    created if it is missing.

    Returns:
        The matplotlib figure.
    """
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7.5 * 1.2, 4.5 * 1.2))

    ylabel = "MMD"
    labels = ["0 (no SC)", "1", "4", "32"]
    offsets = [0, -0.2, 0, 0.2]

    # First entry: baseline without self-consistency; then one entry per
    # dataset size (second axis of `metrics`).
    means = [np.mean(metrics_no_sc[2, :, :], axis=-1)] + [
        np.mean(metrics[2, size_idx, :, :], axis=-1) for size_idx in range(3)
    ]
    stds = [np.std(metrics_no_sc[2, :, :], axis=-1)] + [
        np.std(metrics[2, size_idx, :, :], axis=-1) for size_idx in range(3)
    ]

    ax = comparison_lines(
        ax,
        means,
        stds,
        linestyles=["solid", ":", "-.", "--"],
        colors=["#7b001b", "#96b7d6", "#2a63ab", "#001842"],
        labels=labels,
        offsets=offsets,
        ylabel=ylabel,
        linewidth=3,
        title_fontsize=26,
        xlabel_fontsize=25,
        ylabel_fontsize=25,
        tick_labelsize=20,
        title="size of unlabeled training data",
    )

    # Blank multi-line x label; presumably reserves room below the axes for
    # the figure-level legend.
    fig.supxlabel("\n\n\n", fontsize=16)

    handles, legend_labels = ax.get_legend_handles_labels()
    fig.legend(
        handles,
        legend_labels,
        loc="lower right",
        ncol=4,
        fontsize=22,
        bbox_to_anchor=(0.97, 0),
        frameon=False,
    )
    fig.tight_layout()

    path = Path(__file__).parents[0] / "plots" / "mmd_plot.pdf"
    # savefig does not create missing directories; make sure the target exists.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


if __name__ == "__main__":
    # Optional training sweep. Uncomment to train every configuration before
    # plotting; with rerun=False a model is only trained when
    # `trainer.loss_history.latest == 0`.
    #
    # for dataset_size in [1, 4, 32]:
    #     for replication in range(10):
    #         train_model(dataset_size, replication, rerun=False)
    #
    # for replication in range(10):
    #     train_model_no_sc(10, replication, rerun=False)

    # Uncomment to also produce the three-panel metric figure.
    # posterior_metric_plot()
    mmd_plot()
