from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from experiments.multivariate_normal import network_config as analytic_likelihood
from experiments.multivariate_normal.accurate_posterior import (
    posterior_bias,
    posterior_mmd,
)
from experiments.multivariate_normal.generative_model import get_generative_model

from experiments.multivariate_normal.estimated_likelihood.network_config import (
    get_trainer,
    get_trainer_no_sc,
)
from experiments.multivariate_normal.plots import comparison_lines
from src.self_consistency_real.utils import delete_checkpoints


def train_model(estimated_likelihood: bool, replication: int, rerun=True):
    """Build the NPE + self-consistency trainer for one replication, training it if needed.

    Args:
        estimated_likelihood: If True, use ``get_trainer`` from
            ``estimated_likelihood.network_config``; otherwise use
            ``analytic_likelihood.get_trainer``.
        replication: Seeds the random generator (shared by the generative
            model and the trainer) and names the checkpoint directory
            ``checkpoints/<likelihood>/npe_with_sc_<replication>``.
        rerun: If True, delete the checkpoints and train. If False, train only
            when ``trainer.loss_history.latest == 0`` (presumably: no training
            has been recorded yet).

    Returns:
        The trainer object returned by the selected ``get_trainer`` factory.
    """
    seed_rng = np.random.default_rng(replication)
    likelihood_kind = (
        "estimated_likelihood" if estimated_likelihood else "analytic_likelihood"
    )
    ckpt = Path(__file__).parent.joinpath(
        "checkpoints", likelihood_kind, f"npe_with_sc_{replication}"
    )

    # 1024 offline simulations drawn from the seeded generative model.
    simulate = get_generative_model(rng=seed_rng)
    offline_sims = simulate(1024)

    make_trainer = (
        get_trainer if estimated_likelihood else analytic_likelihood.get_trainer
    )
    trainer = make_trainer(rng=seed_rng, checkpoint_path=ckpt)

    # NOTE(review): the trainer is constructed (and may restore state from
    # `ckpt`) before the checkpoints are deleted, so rerun=True might continue
    # from restored weights rather than start fresh -- confirm get_trainer's
    # restore behavior.
    needs_training = rerun or trainer.loss_history.latest == 0
    if needs_training:
        delete_checkpoints(ckpt)
        trainer.train_offline(offline_sims, epochs=100, lr=0.0005, batch_size=32)

    return trainer


def train_model_no_sc(estimated_likelihood: bool, replication: int, rerun=True):
    """Build the plain NPE trainer (no self-consistency) for one replication, training it if needed.

    Args:
        estimated_likelihood: If True, use ``get_trainer_no_sc`` from
            ``estimated_likelihood.network_config``; otherwise use
            ``analytic_likelihood.get_trainer_no_sc``.
        replication: Seeds the random generator (shared by the generative
            model and the trainer) and names the checkpoint directory
            ``checkpoints/<likelihood>/npe_only_<replication>``.
        rerun: If True, delete the checkpoints and train. If False, train only
            when ``trainer.loss_history.latest == 0`` (presumably: no training
            has been recorded yet).

    Returns:
        The trainer object returned by the selected ``get_trainer_no_sc`` factory.
    """
    seed_rng = np.random.default_rng(replication)
    likelihood_kind = (
        "estimated_likelihood" if estimated_likelihood else "analytic_likelihood"
    )
    ckpt = Path(__file__).parent.joinpath(
        "checkpoints", likelihood_kind, f"npe_only_{replication}"
    )

    # 1024 offline simulations drawn from the seeded generative model.
    simulate = get_generative_model(rng=seed_rng)
    offline_sims = simulate(1024)

    make_trainer = (
        get_trainer_no_sc
        if estimated_likelihood
        else analytic_likelihood.get_trainer_no_sc
    )
    trainer = make_trainer(rng=seed_rng, checkpoint_path=ckpt)

    # NOTE(review): the trainer is constructed (and may restore state from
    # `ckpt`) before the checkpoints are deleted, so rerun=True might continue
    # from restored weights rather than start fresh -- confirm the factory's
    # restore behavior.
    needs_training = rerun or trainer.loss_history.latest == 0
    if needs_training:
        delete_checkpoints(ckpt)
        trainer.train_offline(offline_sims, epochs=100, lr=0.0005, batch_size=32)

    return trainer


def posterior_metrics(seed: int = 0):
    """Evaluate posterior quality of the NPE + self-consistency trainers.

    For each likelihood variant (estimated, analytic) and each replication, the
    trainer is obtained via ``train_model(..., rerun=False)``. It is then
    evaluated at each integer location. The observation for a location is a
    10-dimensional vector drawn from ``Normal(location, 0.1)``.

    Args:
        seed: Seed for the generator that draws the noisy observations. The
            draw order depends only on the loop indices, so any evaluation that
            uses the same seed and loop order sees identical observations.

    Returns:
        Array of shape (3, 2, 12, 10) indexed as
        ``[metric, likelihood, location, replication]``. Here ``metric`` is
        (mean bias, SD bias, MMD), and likelihood index 0 is estimated, 1 is
        analytic.
    """
    n_locations = 12
    n_replications = 10
    n_dims = 10
    metrics = np.zeros((3, 2, n_locations, n_replications))

    # A dedicated seeded generator keeps the observation noise reproducible
    # across runs (the global np.random state is never seeded here).
    obs_rng = np.random.default_rng(seed)

    for likelihood_idx, estimated_likelihood in enumerate([True, False]):
        for replication in range(n_replications):
            trainer = train_model(estimated_likelihood, replication, rerun=False)
            for location in range(n_locations):
                observation = obs_rng.normal(loc=[location] * n_dims, scale=0.1)
                mean, std = posterior_bias(trainer, observation)
                mmd = posterior_mmd(trainer, observation)

                # Only the first component of the mean/SD bias is recorded.
                metrics[0, likelihood_idx, location, replication] = mean[0]
                metrics[1, likelihood_idx, location, replication] = std[0]
                metrics[2, likelihood_idx, location, replication] = mmd

    return metrics


def posterior_metrics_no_sc(seed: int = 0):
    """Evaluate posterior quality of the plain NPE trainers (no self-consistency).

    For each likelihood variant (estimated, analytic) and each replication, the
    trainer is obtained via ``train_model_no_sc(..., rerun=False)``. It is then
    evaluated at each integer location. The observation for a location is a
    10-dimensional vector drawn from ``Normal(location, 0.1)``.

    Args:
        seed: Seed for the generator that draws the noisy observations. The
            draw order depends only on the loop indices, so any evaluation that
            uses the same seed and loop order sees identical observations.

    Returns:
        Array of shape (3, 2, 12, 10) indexed as
        ``[metric, likelihood, location, replication]``. Here ``metric`` is
        (mean bias, SD bias, MMD), and likelihood index 0 is estimated, 1 is
        analytic.
    """
    n_locations = 12
    n_replications = 10
    n_dims = 10
    metrics = np.zeros((3, 2, n_locations, n_replications))

    # A dedicated seeded generator keeps the observation noise reproducible
    # across runs (the global np.random state is never seeded here).
    obs_rng = np.random.default_rng(seed)

    for likelihood_idx, estimated_likelihood in enumerate([True, False]):
        for replication in range(n_replications):
            trainer = train_model_no_sc(estimated_likelihood, replication, rerun=False)
            for location in range(n_locations):
                observation = obs_rng.normal(loc=[location] * n_dims, scale=0.1)
                mean, std = posterior_bias(trainer, observation)
                mmd = posterior_mmd(trainer, observation)

                # Only the first component of the mean/SD bias is recorded.
                metrics[0, likelihood_idx, location, replication] = mean[0]
                metrics[1, likelihood_idx, location, replication] = std[0]
                metrics[2, likelihood_idx, location, replication] = mmd

    return metrics


def posterior_metric_plot():
    """Plot posterior metrics for NPE with and without self-consistency, and save as PDF.

    Each of the three panels (mean bias, SD bias, MMD) shows four curves: the
    across-replication mean and standard deviation of the metric for the
    estimated and analytic likelihood, each with and without self-consistency
    (SC). The data come from ``posterior_metrics`` and ``posterior_metrics_no_sc``.

    The figure is written to ``plots/posterior_metric_plot.pdf`` next to this
    file. The ``plots`` directory is created if it does not exist.

    Returns:
        The matplotlib Figure.
    """
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    axes = axes.flatten()

    titles = ["posterior mean bias", "posterior SD bias", "posterior distance"]
    ylabels = [
        r"$\mu_{\text{est.}} - \mu_{\text{true}}$",
        r"$\sigma_{\text{est.}} - \sigma_{\text{true}}$",
        "MMD",
    ]
    # Curve order (shared by labels, offsets, linestyles and colors):
    # [estimated SC, estimated no-SC, analytic SC, analytic no-SC].
    labels = ["estimated", "estimated (no SC)", "analytic", "analytic (no SC)"]
    offsets = [-0.1, -0.1, 0.1, 0.1]

    for i, ax in enumerate(axes):
        # Aggregate over the replication axis (last axis) per likelihood index.
        means_sc = [np.mean(metrics[i, idx, :, :], axis=-1) for idx in range(2)]
        means_no_sc = [
            np.mean(metrics_no_sc[i, idx, :, :], axis=-1) for idx in range(2)
        ]
        # Interleave SC / no-SC so the order matches `labels` above.
        means = [val for pair in zip(means_sc, means_no_sc) for val in pair]

        stds_sc = [np.std(metrics[i, idx, :, :], axis=-1) for idx in range(2)]
        stds_no_sc = [np.std(metrics_no_sc[i, idx, :, :], axis=-1) for idx in range(2)]
        stds = [val for pair in zip(stds_sc, stds_no_sc) for val in pair]

        # NOTE(review): comparison_lines is assumed to return an Axes (presumably
        # `ax` itself); axes[1] and axes[0] are used after the loop.
        axes[i] = comparison_lines(
            ax,
            means,
            stds,
            linestyles=[":", ":", "solid", "solid"],
            colors=["#96b7d6", "#ffacac", "#001842", "#7b001b"],
            offsets=offsets,
            labels=labels,
            ylabel=ylabels[i],
            title=titles[i],
        )

    # Reference line in the SD-bias panel at -1/sqrt(2), labelled sigma_est = 0
    # (i.e. the bias value when the estimated SD collapses to zero, as implied
    # by the label).
    axes[1].axhline(-1 / np.sqrt(2), linestyle="dotted", color="black")
    axes[1].text(-0.3, -1 / np.sqrt(2) + 0.05, r"$\sigma_{\text{est}}=0$", fontsize=14)

    # Blank multi-line x-label; appears to reserve space at the bottom of the
    # figure for the legend under tight_layout.
    fig.supxlabel("\n\n\n", fontsize=16)

    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        labels,
        loc="lower right",
        ncol=2,
        fontsize=16,
        title="likelihood",
        title_fontsize=16,
        bbox_to_anchor=(1, 0),
        frameon=False,
    )
    # Hand-tuned offset of the legend title for this figure size.
    legend.get_title().set_position((-280, -34))
    fig.suptitle("Likelihood function", fontsize=22, x=0, ha="left")
    fig.tight_layout()

    path = Path(__file__).parents[0] / "plots" / "posterior_metric_plot.pdf"
    # savefig does not create missing directories.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


if __name__ == "__main__":
    # Ensure both model kinds exist for every (likelihood, replication) pair.
    # With rerun=False, training only happens when the trainer's
    # loss_history.latest == 0 (presumably: nothing trained yet).
    for use_estimated in (True, False):
        for rep in range(10):
            for fit in (train_model, train_model_no_sc):
                fit(use_estimated, rep, rerun=False)

    posterior_metric_plot()
