from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np

from experiments.multivariate_normal.accurate_posterior import (
    analytic_posterior_draws,
    posterior_bias,
    posterior_mmd,
)
from experiments.multivariate_normal.generative_model import get_generative_model
from experiments.multivariate_normal.network_config import (
    get_trainer,
    get_trainer_no_sc,
)
from experiments.multivariate_normal.plots import reference_scatter, comparison_lines
from src.self_consistency_real.utils import delete_checkpoints


def train_model(dataset_mean: float, replication: int, rerun=True):
    """Build the trainer for the ``npe_with_sc`` variant and train it when needed.

    The replication index seeds a fresh NumPy generator. That generator is
    shared by the generative model and the trainer, so each replication is
    reproducible. Checkpoints live under
    ``checkpoints/mean_<dataset_mean>/npe_with_sc_<replication>`` next to this
    file.

    Args:
        dataset_mean: Passed to ``get_trainer`` and used in the checkpoint path.
        replication: Replication index; also the RNG seed.
        rerun: If True, always retrain. If False, retrain only when
            ``trainer.loss_history.latest == 0``. NOTE(review): presumably
            that means no checkpoint was restored; confirm in ``get_trainer``.

    Returns:
        The trainer returned by ``get_trainer``.
    """
    generator = np.random.default_rng(replication)
    ckpt_path = Path(__file__).parent.joinpath(
        "checkpoints", f"mean_{dataset_mean}", f"npe_with_sc_{replication}"
    )

    # 1024 simulations are drawn even when no training happens below. This
    # keeps the generator state handed to the trainer the same either way.
    simulated = get_generative_model(rng=generator)(1024)

    trainer = get_trainer(
        dataset_mean=dataset_mean, rng=generator, checkpoint_path=ckpt_path
    )

    needs_training = rerun or trainer.loss_history.latest == 0
    if not needs_training:
        return trainer

    # Clear any stale checkpoints before training from scratch.
    delete_checkpoints(ckpt_path)
    trainer.train_offline(simulated, epochs=100, lr=0.0005, batch_size=32)
    return trainer


def train_model_no_sc(dataset_mean: float, replication: int, rerun=True):
    """Build the trainer for the ``npe_only`` variant and train it when needed.

    This mirrors ``train_model`` but uses ``get_trainer_no_sc`` and
    checkpoints under ``checkpoints/mean_<dataset_mean>/npe_only_<replication>``.
    The replication index seeds a fresh NumPy generator, which is shared by
    the generative model and the trainer.

    Args:
        dataset_mean: Passed to ``get_trainer_no_sc`` and used in the
            checkpoint path.
        replication: Replication index; also the RNG seed.
        rerun: If True, always retrain. If False, retrain only when
            ``trainer.loss_history.latest == 0``. NOTE(review): presumably
            that means no checkpoint was restored; confirm.

    Returns:
        The trainer returned by ``get_trainer_no_sc``.
    """
    generator = np.random.default_rng(replication)
    ckpt_path = Path(__file__).parent.joinpath(
        "checkpoints", f"mean_{dataset_mean}", f"npe_only_{replication}"
    )

    # Simulate first, as in the SC variant, so the generator state passed to
    # the trainer does not depend on whether training happens.
    simulated = get_generative_model(rng=generator)(1024)

    trainer = get_trainer_no_sc(
        dataset_mean=dataset_mean, rng=generator, checkpoint_path=ckpt_path
    )

    needs_training = rerun or trainer.loss_history.latest == 0
    if not needs_training:
        return trainer

    delete_checkpoints(ckpt_path)
    trainer.train_offline(simulated, epochs=100, lr=0.0005, batch_size=32)
    return trainer


def posterior_metrics(*, seed: int = 0):
    """Compute posterior-quality metrics for the ``train_model`` (SC) trainers.

    For every dataset mean in ``[0.0, 1.0, 2.0, 3.0, 5.0]`` and every
    replication in ``range(10)``, the trainer is loaded via
    ``train_model(..., rerun=False)``, which trains it only if needed. It is
    then evaluated at 12 observation locations. Location ``k`` is a
    length-10 vector whose entries are drawn from ``Normal(k, 0.1)``.

    Args:
        seed: Seed for the generator that draws the evaluation locations.
            Keyword-only. A fixed seed makes the results reproducible.
            Using the same seed for different model variants evaluates them
            at identical locations.

    Returns:
        Array of shape ``(3, 5, 12, 10)`` indexed as
        ``[metric, dataset_mean, location, replication]``. The metric axis is
        ``[mean_bias, std_bias, mmd]``. The two bias entries keep only
        element ``[0]`` of what ``posterior_bias`` returns.
    """
    # Use a dedicated seeded generator rather than the global ``np.random``
    # state, which made the locations non-reproducible across runs.
    rng = np.random.default_rng(seed)
    metrics = np.zeros((3, 5, 12, 10))

    for mean_idx, dataset_mean in enumerate([0.0, 1.0, 2.0, 3.0, 5.0]):
        for replication in range(10):
            trainer = train_model(dataset_mean, replication, rerun=False)
            for location in range(12):
                # 10 jittered copies of the scalar location.
                # NOTE(review): presumably 10 observations or data dimensions
                # expected by the trainer; confirm against the generative model.
                location_ = rng.normal(loc=location, scale=0.1, size=10)
                mean, std = posterior_bias(trainer, location_)
                mmd = posterior_mmd(trainer, location_)

                metrics[0, mean_idx, location, replication] = mean[0]
                metrics[1, mean_idx, location, replication] = std[0]
                metrics[2, mean_idx, location, replication] = mmd

    return metrics


def posterior_metrics_no_sc(*, seed: int = 0):
    """Compute posterior-quality metrics for the ``train_model_no_sc`` trainers.

    For every dataset mean in ``[0.0, 1.0, 2.0, 3.0, 5.0]`` and every
    replication in ``range(10)``, the trainer is loaded via
    ``train_model_no_sc(..., rerun=False)``, which trains it only if needed.
    It is then evaluated at 12 observation locations. Location ``k`` is a
    length-10 vector whose entries are drawn from ``Normal(k, 0.1)``.

    Args:
        seed: Seed for the generator that draws the evaluation locations.
            Keyword-only. A fixed seed makes the results reproducible.
            Using the same seed for different model variants evaluates them
            at identical locations.

    Returns:
        Array of shape ``(3, 5, 12, 10)`` indexed as
        ``[metric, dataset_mean, location, replication]``. The metric axis is
        ``[mean_bias, std_bias, mmd]``. The two bias entries keep only
        element ``[0]`` of what ``posterior_bias`` returns.
    """
    # Use a dedicated seeded generator rather than the global ``np.random``
    # state, which made the locations non-reproducible across runs.
    rng = np.random.default_rng(seed)
    metrics = np.zeros((3, 5, 12, 10))

    for mean_idx, dataset_mean in enumerate([0.0, 1.0, 2.0, 3.0, 5.0]):
        for replication in range(10):
            trainer = train_model_no_sc(dataset_mean, replication, rerun=False)
            for location in range(12):
                # 10 jittered copies of the scalar location.
                # NOTE(review): presumably 10 observations or data dimensions
                # expected by the trainer; confirm against the generative model.
                location_ = rng.normal(loc=location, scale=0.1, size=10)
                mean, std = posterior_bias(trainer, location_)
                mmd = posterior_mmd(trainer, location_)

                metrics[0, mean_idx, location, replication] = mean[0]
                metrics[1, mean_idx, location, replication] = std[0]
                metrics[2, mean_idx, location, replication] = mmd

    return metrics


def posterior_metric_plot():
    """Plot posterior mean bias, SD bias and MMD for the SC and no-SC trainers.

    Calls ``posterior_metrics()`` and ``posterior_metrics_no_sc()``. For each
    of the three metrics it draws one panel via ``comparison_lines``, passing
    the mean and the standard deviation over replications at each location.
    Each panel has six series: the no-SC baseline plus one SC series per
    dataset mean (0, 1, 2, 3, 5). The figure is saved to
    ``plots/posterior_metric_plot.pdf`` next to this file; the ``plots``
    directory is created if missing.

    Returns:
        The matplotlib figure.
    """
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    axes = axes.flatten()

    titles = ["posterior mean bias", "posterior SD bias", "posterior distance"]
    ylabels = [
        r"$\mu_{\text{est.}} - \mu_{\text{true}}$",
        r"$\sigma_{\text{est.}} - \sigma_{\text{true}}$",
        "MMD",
    ]
    series_labels = ["no SC", "0", "1", "2", "3", "5"]
    # Shift only the baseline horizontally so it does not overlap the SC lines.
    offsets = [-0.2] + [0] * 5
    linestyles = [
        "solid",
        (0, (1, 5)),
        (0, (1, 1)),
        (0, (3, 5, 1, 5)),
        (0, (5, 5)),
        "solid",
    ]
    colors = ["#7b001b", "grey", "#a5a6d9", "#7788c0", "#4b4c9b", "#052744"]

    for i, ax in enumerate(axes):
        # NOTE(review): the no-SC baseline uses dataset-mean index 1 (mean 1.0)
        # only. Confirm the no-SC trainer does not depend on dataset_mean, so
        # that this choice is arbitrary.
        means = [np.mean(metrics_no_sc[i, 1, :, :], axis=-1)] + [
            np.mean(metrics[i, mean_idx, :, :], axis=-1) for mean_idx in range(5)
        ]
        stds = [np.std(metrics_no_sc[i, 1, :, :], axis=-1)] + [
            np.std(metrics[i, mean_idx, :, :], axis=-1) for mean_idx in range(5)
        ]

        # Store whatever comparison_lines returns; axes[1] and axes[0] are
        # used again below.
        axes[i] = comparison_lines(
            ax,
            means,
            stds,
            linestyles=linestyles,
            colors=colors,
            offsets=offsets,
            labels=series_labels,
            ylabel=ylabels[i],
            title=titles[i],
        )

    # Reference line in the SD-bias panel: the value sigma_est - sigma_true
    # would take if sigma_est were 0 (drawn at -1/sqrt(2)).
    axes[1].axhline(-1 / np.sqrt(2), linestyle="dotted", color="black")
    axes[1].text(-0.3, -1 / np.sqrt(2) + 0.05, r"$\sigma_{\text{est}}=0$", fontsize=14)

    # Blank multi-line x label; presumably reserves space at the bottom of
    # the figure for the legend.
    fig.supxlabel("\n\n\n", fontsize=16)

    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        legend_handles,
        legend_labels,
        loc="lower right",
        ncol=6,
        fontsize=16,
        title="dataset mean",
        title_fontsize=16,
        bbox_to_anchor=(1, 0),
        frameon=False,
    )
    # Hard-coded offset that moves the legend title beside the entries.
    # NOTE(review): tuned for figsize=(14, 5); revisit if the layout changes.
    legend.get_title().set_position((-340, -23))
    fig.suptitle("Mean of the unlabeled training data", fontsize=22, x=0, ha="left")
    fig.tight_layout()

    path = Path(__file__).parents[0] / "plots" / "posterior_metric_plot.pdf"
    # savefig does not create missing directories.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


if __name__ == "__main__":
    # Make sure a trainer exists for every (dataset mean, replication) pair
    # in both variants (rerun=False retrains only when needed), then build
    # and save the comparison figure.
    means_to_run = [0.0, 1.0, 2.0, 3.0, 5.0]
    for mean_value in means_to_run:
        for rep in range(10):
            train_model(mean_value, rep, rerun=False)
            train_model_no_sc(mean_value, rep, rerun=False)

    posterior_metric_plot()
