from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from experiments.multivariate_normal.accurate_posterior import (
    posterior_bias,
    posterior_mmd,
)
from experiments.multivariate_normal.plots import comparison_lines
from experiments.multivariate_normal.summary_network.generative_model import (
    get_generative_model,
)
from experiments.multivariate_normal.summary_network.network_config import (
    get_trainer,
    get_trainer_no_sc,
)
from src.self_consistency_real.utils import delete_checkpoints


def train_model(with_summary_network: bool, replication: int, rerun=True):
    """Return the self-consistency trainer for one replication, training it if required.

    Args:
        with_summary_network: Selects the model variant; forwarded to both the
            generative model and the trainer factory, and used to pick the
            checkpoint sub-folder.
        replication: Replication index. Seeds the NumPy generator and is part
            of the checkpoint directory name.
        rerun: When true, checkpoints are deleted and the model is trained
            unconditionally. When false, training only happens if
            ``trainer.loss_history.latest`` equals 0.

    Returns:
        The trainer object produced by ``get_trainer``.
    """
    seeded_rng = np.random.default_rng(replication)

    if with_summary_network:
        variant = "with_summary_net"
    else:
        variant = "no_summary_net"
    run_dir = (
        Path(__file__).parent / "checkpoints" / variant / f"npe_with_sc_{replication}"
    )

    simulator = get_generative_model(
        with_summary_network=with_summary_network, dimension=10, rng=seeded_rng
    )
    # The same generator is handed to the trainer below, so the data is drawn
    # first and unconditionally to keep the order of calls on it fixed.
    offline_data = simulator(1024)

    sc_trainer = get_trainer(
        with_summary_network=with_summary_network,
        dimension=10,
        rng=seeded_rng,
        checkpoint_path=run_dir,
    )

    # presumably `latest == 0` means no training epochs were restored from
    # `run_dir` — TODO confirm against the trainer implementation.
    needs_training = rerun or sc_trainer.loss_history.latest == 0
    if not needs_training:
        return sc_trainer

    # NOTE(review): the trainer is constructed before the checkpoints are
    # removed; verify that a forced rerun really starts from fresh weights.
    delete_checkpoints(run_dir)
    sc_trainer.train_offline(offline_data, epochs=100, lr=0.0005, batch_size=32)
    return sc_trainer


def train_model_no_sc(with_summary_network: bool, replication: int, rerun=True):
    """Return the trainer built by ``get_trainer_no_sc`` for one replication.

    Args:
        with_summary_network: Selects the model variant; forwarded to both the
            generative model and the trainer factory, and used to pick the
            checkpoint sub-folder.
        replication: Replication index. Seeds the NumPy generator and is part
            of the checkpoint directory name.
        rerun: When true, checkpoints are deleted and the model is trained
            unconditionally. When false, training only happens if
            ``trainer.loss_history.latest`` equals 0.

    Returns:
        The trainer object produced by ``get_trainer_no_sc``.
    """
    seeded_rng = np.random.default_rng(replication)

    variant = {True: "with_summary_net", False: "no_summary_net"}[
        bool(with_summary_network)
    ]
    run_dir = Path(__file__).parent / "checkpoints" / variant / f"npe_only_{replication}"

    simulator = get_generative_model(
        with_summary_network=with_summary_network, dimension=10, rng=seeded_rng
    )
    # The same generator is handed to the trainer below, so the data is drawn
    # first and unconditionally to keep the order of calls on it fixed.
    offline_data = simulator(1024)

    npe_trainer = get_trainer_no_sc(
        with_summary_network=with_summary_network,
        rng=seeded_rng,
        checkpoint_path=run_dir,
    )

    # presumably `latest == 0` means no training epochs were restored from
    # `run_dir` — TODO confirm against the trainer implementation.
    needs_training = rerun or npe_trainer.loss_history.latest == 0
    if not needs_training:
        return npe_trainer

    # NOTE(review): the trainer is constructed before the checkpoints are
    # removed; verify that a forced rerun really starts from fresh weights.
    delete_checkpoints(run_dir)
    npe_trainer.train_offline(offline_data, epochs=100, lr=0.0005, batch_size=32)
    return npe_trainer


def _collect_posterior_metrics(
    train_fn, *, n_locations=12, n_replications=10, dimension=10
):
    """Evaluate posterior bias and MMD for every variant, replication and location.

    Shared implementation behind ``posterior_metrics`` and
    ``posterior_metrics_no_sc``; the two only differ in which training
    function supplies the trainer.

    Args:
        train_fn: Callable with the signature
            ``train_fn(with_summary_network, replication, rerun=...)`` that
            returns a trainer accepted by ``posterior_bias`` and
            ``posterior_mmd``.
        n_locations: Number of integer locations ``0..n_locations-1`` at which
            an evaluation point is drawn.
        n_replications: Number of replications ``0..n_replications-1`` to load.
        dimension: Length of the evaluation vector drawn at each location.

    Returns:
        ``numpy.ndarray`` of shape ``(3, 2, n_locations, n_replications)`` with
        axes: metric (mean bias, SD bias, MMD), with_summary_network
        (True, False), location, replication.
    """
    metrics = np.zeros((3, 2, n_locations, n_replications))

    for summary_idx, with_summary_net in enumerate([True, False]):
        for replication in range(n_replications):
            # rerun=False: the training function decides whether an existing
            # run can be reused instead of being trained again.
            trainer = train_fn(with_summary_net, replication, rerun=False)
            for location in range(n_locations):
                # NOTE(review): these draws use NumPy's global RNG, which is
                # not seeded here, so the evaluation points are not tied to
                # the replication seed used for training.
                if with_summary_net:
                    # Wrapped in a list and drawn with scale 0.1; presumably
                    # the summary-network variant expects a collection of
                    # observations — TODO confirm.
                    location_ = [
                        np.random.normal(loc=[location] * dimension, scale=0.1)
                    ]
                else:
                    location_ = np.random.normal(loc=[location] * dimension)
                mean, std = posterior_bias(trainer, location_)
                mmd = posterior_mmd(trainer, location_)

                # Only the first entry of each bias result is stored.
                metrics[0, summary_idx, location, replication] = mean[0]
                metrics[1, summary_idx, location, replication] = std[0]
                metrics[2, summary_idx, location, replication] = mmd

    return metrics


def posterior_metrics():
    """Collect posterior metrics for the models returned by ``train_model``.

    Returns:
        Array of shape ``(3, 2, 12, 10)`` with axes: metric (mean bias,
        SD bias, MMD), with_summary_network (True, False), location
        (``range(12)``), replication (``range(10)``).
    """
    return _collect_posterior_metrics(train_model)


def posterior_metrics_no_sc():
    """Collect posterior metrics for the models returned by ``train_model_no_sc``.

    Returns:
        Array of shape ``(3, 2, 12, 10)`` with axes: metric (mean bias,
        SD bias, MMD), with_summary_network (True, False), location
        (``range(12)``), replication (``range(10)``).
    """
    return _collect_posterior_metrics(train_model_no_sc)


def posterior_metric_plot():
    """Plot mean bias, SD bias and MMD for all four model variants and save the figure.

    The metrics come from ``posterior_metrics`` and ``posterior_metrics_no_sc``.
    For each metric the mean and standard deviation over the replication axis
    are passed to ``comparison_lines``, one panel per metric. The figure is
    written to ``plots/posterior_metric_plot.pdf`` next to this file; the
    ``plots`` directory is created if it does not exist yet.

    Returns:
        The matplotlib figure that was saved.
    """
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    axes = axes.flatten()

    titles = ["posterior mean bias", "posterior SD bias", r"posterior distance"]
    ylabels = [
        r"$\mu_{\text{est.}} - \mu_{\text{true}}$",
        r"$\sigma_{\text{est.}} - \sigma_{\text{true}}$",
        "MMD",
    ]
    labels = ["with", "with (no SC)", "without", "without (no SC)"]
    offsets = [-0.1, -0.1, 0.1, 0.1]

    for i, ax in enumerate(axes):
        # Index 0/1 on the second axis is with/without summary network.
        means_sc = [np.mean(metrics[i, idx, :, :], axis=-1) for idx in range(2)]
        means_no_sc = [
            np.mean(metrics_no_sc[i, idx, :, :], axis=-1) for idx in range(2)
        ]
        # Interleave so the order matches `labels`:
        # with, with (no SC), without, without (no SC).
        means = [val for pair in zip(means_sc, means_no_sc) for val in pair]

        stds_sc = [np.std(metrics[i, idx, :, :], axis=-1) for idx in range(2)]
        stds_no_sc = [np.std(metrics_no_sc[i, idx, :, :], axis=-1) for idx in range(2)]
        stds = [val for pair in zip(stds_sc, stds_no_sc) for val in pair]

        axes[i] = comparison_lines(
            ax,
            means,
            stds,
            linestyles=[":", ":", "-.", "-."],
            colors=["#96b7d6", "#ffacac", "#052744", "#7b001b"],
            offsets=offsets,
            labels=labels,
            ylabel=ylabels[i],
            title=titles[i],
        )

    # Reference line on the SD-bias panel, labelled as the value reached when
    # the estimated SD is 0; this assumes the true SD is 1/sqrt(2) — TODO confirm.
    axes[1].axhline(-1 / np.sqrt(2), linestyle="dotted", color="black")
    axes[1].text(-0.3, -1 / np.sqrt(2) + 0.05, r"$\sigma_{\text{est}}=0$", fontsize=14)

    # Blank multi-line label; presumably reserves room below the panels for
    # the figure-level legend.
    fig.supxlabel("\n\n\n", fontsize=16)

    # Separate names so the panel `labels` list above is not rebound.
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        legend_handles,
        legend_labels,
        loc="lower right",
        ncol=2,
        fontsize=16,
        title="summary network",
        title_fontsize=16,
        bbox_to_anchor=(1, 0),
        frameon=False,
    )
    # Hand-tuned offset that moves the legend title relative to the entries.
    legend.get_title().set_position((-270, -34))
    fig.suptitle("Summary network", fontsize=22, x=0, ha="left")
    fig.tight_layout()

    path = Path(__file__).parents[0] / "plots" / "posterior_metric_plot.pdf"
    # savefig does not create missing directories, so make sure the target
    # folder exists (e.g. on a fresh checkout).
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


if __name__ == "__main__":
    # Visit every (variant, replication) pair with both training functions.
    # rerun=False lets each function skip training when its trainer already
    # reports a non-zero `loss_history.latest`.
    for use_summary_net in (True, False):
        for seed in range(10):
            for train in (train_model, train_model_no_sc):
                train(use_summary_net, seed, rerun=False)

    # Compute the metrics from the trainers and write the comparison figure.
    posterior_metric_plot()
