from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

from experiments.multivariate_normal.accurate_posterior import (
    analytic_posterior_draws,
    posterior_bias,
    posterior_mmd,
)
from experiments.multivariate_normal.generative_model import get_generative_model
from experiments.multivariate_normal.network_config import (
    get_trainer,
    get_trainer_no_sc,
)
from experiments.multivariate_normal.plots import (
    comparison_lines,
    reference_contour,
    reference_scatter,
)
from src.self_consistency_real.utils import delete_checkpoints
import pickle


def train_model(dimension: int, replication: int, rerun=True):
    """Build (and if needed train) the NPE + self-consistency model for one setting.

    A NumPy generator seeded with ``replication`` is shared by the simulated
    training set and the trainer, so each replication is reproducible.

    Args:
        dimension: Parameter dimensionality forwarded to the generative model
            and to ``get_trainer``.
        replication: Replication index; used as the RNG seed and as part of the
            checkpoint directory name (``checkpoints/dim_<d>/npe_with_sc_<r>``).
        rerun: When True, the checkpoint directory is cleared and the model is
            trained again. When False, training only happens if
            ``trainer.loss_history.latest == 0`` (presumably meaning no saved
            training history was found — confirm in the trainer implementation).

    Returns:
        The trainer object returned by ``get_trainer``.
    """
    seeded_rng = np.random.default_rng(replication)
    ckpt_path = (
        Path(__file__).parent
        / "checkpoints"
        / f"dim_{dimension}"
        / f"npe_with_sc_{replication}"
    )

    # Order matters: the simulator draws from `seeded_rng` before the trainer
    # is built with the same generator.
    simulator = get_generative_model(dimension=dimension, rng=seeded_rng)
    offline_data = simulator(1024)

    trainer = get_trainer(
        dimension=dimension, rng=seeded_rng, checkpoint_path=ckpt_path
    )

    needs_training = rerun or trainer.loss_history.latest == 0
    if needs_training:
        # Clear the checkpoint directory before (re)training.
        delete_checkpoints(ckpt_path)
        trainer.train_offline(offline_data, epochs=100, lr=0.0005, batch_size=32)
    return trainer


def train_model_no_sc(dimension: int, replication: int, rerun=True):
    """Build (and if needed train) the plain NPE model (no self-consistency).

    Mirrors ``train_model`` but uses ``get_trainer_no_sc`` and the checkpoint
    directory ``checkpoints/dim_<d>/npe_only_<r>``. A NumPy generator seeded
    with ``replication`` is shared by the simulated data and the trainer.

    Args:
        dimension: Parameter dimensionality forwarded to the generative model
            and to ``get_trainer_no_sc``.
        replication: Replication index; used as the RNG seed and in the
            checkpoint directory name.
        rerun: When True, the checkpoint directory is cleared and the model is
            trained again. When False, training only happens if
            ``trainer.loss_history.latest == 0`` (presumably meaning no saved
            training history was found — confirm in the trainer implementation).

    Returns:
        The trainer object returned by ``get_trainer_no_sc``.
    """
    seeded_rng = np.random.default_rng(replication)
    ckpt_path = (
        Path(__file__).parent
        / "checkpoints"
        / f"dim_{dimension}"
        / f"npe_only_{replication}"
    )

    # Order matters: the simulator draws from `seeded_rng` before the trainer
    # is built with the same generator.
    simulator = get_generative_model(dimension=dimension, rng=seeded_rng)
    offline_data = simulator(1024)

    trainer = get_trainer_no_sc(
        dimension=dimension, rng=seeded_rng, checkpoint_path=ckpt_path
    )

    needs_training = rerun or trainer.loss_history.latest == 0
    if needs_training:
        # Clear the checkpoint directory before (re)training.
        delete_checkpoints(ckpt_path)
        trainer.train_offline(offline_data, epochs=100, lr=0.0005, batch_size=32)
    return trainer


def posterior_metrics(seed: int = 0):
    """Compute, or load from cache, posterior-quality metrics for the NPE + SC models.

    Results are cached as a pickle at ``data/posterior_metrics.pkl`` next to this
    file. If the cache exists it is loaded and returned without recomputation.

    The returned array has shape ``(3, 3, 12, 10)``, indexed as
    ``[metric, dimension, location, replication]``:

    - metric: 0 = posterior mean bias, 1 = posterior SD bias, 2 = MMD
    - dimension: index into ``(2, 10, 100)``
    - location: observation centre ``0, 1, ..., 11``
    - replication: trained-model replication ``0, ..., 9``

    Only the first component of the bias vectors returned by ``posterior_bias``
    is stored (assumes they are indexable per parameter dimension — confirm).

    Args:
        seed: Seed for the jitter added to each observation location. The draws
            depend only on ``seed`` and the fixed loop order, so recomputing the
            cache evaluates the models at the same observations.

    Returns:
        The metrics array described above.
    """
    dimensions = (2, 10, 100)
    n_locations = 12
    n_replications = 10
    file_path = Path(__file__).parent / "data" / "posterior_metrics.pkl"

    if file_path.exists():
        # Local, self-generated cache file; not untrusted input.
        with open(file_path, "rb") as file:
            return pickle.load(file)

    # A dedicated seeded generator, instead of the global np.random state, makes
    # the jittered observations (and therefore the cached metrics) reproducible.
    rng = np.random.default_rng(seed)
    metrics = np.zeros((3, len(dimensions), n_locations, n_replications))

    for dim_idx, dimension in enumerate(dimensions):
        for replication in range(n_replications):
            trainer = train_model(dimension, replication, rerun=False)
            for location in range(n_locations):
                # Observation equal to `location` in every coordinate, plus small jitter.
                location_ = rng.normal(loc=location, scale=0.1, size=dimension)
                mean, std = posterior_bias(trainer, location_)
                mmd = posterior_mmd(trainer, location_)

                metrics[0, dim_idx, location, replication] = mean[0]
                metrics[1, dim_idx, location, replication] = std[0]
                metrics[2, dim_idx, location, replication] = mmd

    # open(..., "wb") does not create parent directories.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, "wb") as file:
        pickle.dump(metrics, file)

    return metrics


def posterior_metrics_no_sc(seed: int = 0):
    """Compute, or load from cache, posterior-quality metrics for the NPE-only models.

    Results are cached as a pickle at ``data/posterior_metrics_no_sc.pkl`` next to
    this file. If the cache exists it is loaded and returned without recomputation.

    The returned array has shape ``(3, 3, 12, 10)``, indexed as
    ``[metric, dimension, location, replication]``:

    - metric: 0 = posterior mean bias, 1 = posterior SD bias, 2 = MMD
    - dimension: index into ``(2, 10, 100)``
    - location: observation centre ``0, 1, ..., 11``
    - replication: trained-model replication ``0, ..., 9``

    Only the first component of the bias vectors returned by ``posterior_bias``
    is stored (assumes they are indexable per parameter dimension — confirm).

    Args:
        seed: Seed for the jitter added to each observation location. The draws
            depend only on ``seed`` and the fixed loop order, so recomputing the
            cache evaluates the models at the same observations.

    Returns:
        The metrics array described above.
    """
    dimensions = (2, 10, 100)
    n_locations = 12
    n_replications = 10
    file_path = Path(__file__).parent / "data" / "posterior_metrics_no_sc.pkl"

    if file_path.exists():
        # Local, self-generated cache file; not untrusted input.
        with open(file_path, "rb") as file:
            return pickle.load(file)

    # A dedicated seeded generator, instead of the global np.random state, makes
    # the jittered observations (and therefore the cached metrics) reproducible.
    rng = np.random.default_rng(seed)
    metrics = np.zeros((3, len(dimensions), n_locations, n_replications))

    for dim_idx, dimension in enumerate(dimensions):
        for replication in range(n_replications):
            trainer = train_model_no_sc(dimension, replication, rerun=False)
            for location in range(n_locations):
                # Observation equal to `location` in every coordinate, plus small jitter.
                location_ = rng.normal(loc=location, scale=0.1, size=dimension)
                mean, std = posterior_bias(trainer, location_)
                mmd = posterior_mmd(trainer, location_)

                metrics[0, dim_idx, location, replication] = mean[0]
                metrics[1, dim_idx, location, replication] = std[0]
                metrics[2, dim_idx, location, replication] = mmd

    # open(..., "wb") does not create parent directories.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, "wb") as file:
        pickle.dump(metrics, file)

    return metrics


def posterior_metric_plot():
    """Plot posterior mean bias, SD bias and MMD for NPE with and without SC.

    There is one panel per metric, using the arrays from ``posterior_metrics()``
    and ``posterior_metrics_no_sc()``. For each parameter dimensionality
    (2, 10, 100) and each model variant, ``comparison_lines`` receives:

    - the per-location mean over replications;
    - the per-location standard deviation over replications.

    The figure is saved to ``plots/posterior_metric_plot.pdf`` next to this file.

    Returns:
        The matplotlib ``Figure``.
    """
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    axes = axes.flatten()

    titles = ["posterior mean bias", "posterior SD bias", r"posterior distance"]
    ylabels = [
        r"$\mu_{\text{est.}} - \mu_{\text{true}}$",
        r"$\sigma_{\text{est.}} - \sigma_{\text{true}}$",
        "MMD",
    ]
    # Curve order is (dimensionality, variant): SC first, then no-SC.
    labels = ["2", "2 (no SC)", "10", "10 (no SC)", "100", "100 (no SC)"]
    offsets = [-0.1, -0.1, 0, 0, 0.1, 0.1]

    for i, ax in enumerate(axes):
        # Arrays of shape (location, replication), interleaved to match `labels`.
        curves = [
            variant[i, dim_idx]
            for dim_idx in range(3)
            for variant in (metrics, metrics_no_sc)
        ]
        means = [np.mean(curve, axis=-1) for curve in curves]
        stds = [np.std(curve, axis=-1) for curve in curves]

        axes[i] = comparison_lines(
            ax,
            means,
            stds,
            linestyles=[":", ":", "-.", "-.", "--", "--"],
            colors=["#96b7d6", "#ffacac", "#2a63ab", "#b33b4b", "#001842", "#7b001b"],
            labels=labels,
            offsets=offsets,
            ylabel=ylabels[i],
            title=titles[i],
        )

    # Reference line for the SD bias of a degenerate posterior (sigma_est = 0).
    # It equals -sigma_true and assumes sigma_true = 1/sqrt(2) in every
    # configuration; confirm against the generative model.
    sigma_zero_bias = -1 / np.sqrt(2)
    axes[1].axhline(sigma_zero_bias, linestyle="dotted", color="black")
    axes[1].text(-0.3, sigma_zero_bias + 0.05, r"$\sigma_{\text{est}}=0$", fontsize=14)

    # Blank multi-line x-label, apparently used to reserve room at the bottom
    # for the figure legend.
    fig.supxlabel("\n\n\n", fontsize=16)

    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        labels,
        loc="lower right",
        ncol=3,
        fontsize=16,
        title="Parameter dimensionality",
        title_fontsize=16,
        bbox_to_anchor=(1, 0),
        frameon=False,
    )
    # Hand-tuned offset for the legend title; re-tune if figsize/fontsize change.
    legend.get_title().set_position((-350, -34))
    fig.suptitle("Parameter dimensionality", fontsize=22, x=0, ha="left")
    fig.tight_layout()

    path = Path(__file__).parent / "plots" / "posterior_metric_plot.pdf"
    # savefig does not create missing directories.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


def posterior_scatter_plot():
    """Scatter posterior draws of the 10-D models for observations mu_obs = 0..11.

    Each of the 12 panels uses an observation equal to ``i`` in every
    coordinate. It shows 100 draws from each of:

    - the NPE + SC model (replication 0);
    - the NPE-only model (replication 0);
    - the analytic reference posterior (prior N(0, 1) per coordinate).

    The figure is saved to ``plots/posterior_scatter.pdf`` next to this file.

    Returns:
        The matplotlib ``Figure``.
    """
    dimension = 10
    n_samples = 100

    trainer = train_model(dimension, 0, rerun=False)
    trainer_no_sc = train_model_no_sc(dimension, 0, rerun=False)

    fig, axes = plt.subplots(nrows=3, ncols=4, sharex=True, sharey=True)
    axes = axes.flatten()

    # scatter plots
    for i, ax in enumerate(axes):
        # "parameters" is a zero placeholder; presumably the amortizer input
        # format requires the key even when only sampling (confirm).
        input_dict = {
            "parameters": np.zeros((1, dimension), dtype=np.float32),
            "direct_conditions": np.full((1, dimension), i, dtype=np.float32),
        }
        posterior_draws = trainer.amortizer.sample(input_dict, n_samples=n_samples)
        accurate_draws = analytic_posterior_draws(
            [i] * dimension,
            prior_mean=[0.0] * dimension,
            prior_std=[1.0] * dimension,
            y_std=None,
            n_samples=n_samples,
        )
        no_sc_posterior = trainer_no_sc.amortizer.sample(
            input_dict, n_samples=n_samples
        )

        axes[i] = reference_scatter(
            ax,
            posterior_draws,
            accurate_draws,
            no_sc_posterior=no_sc_posterior,
            title=r"$\mu_{\text{obs}}$=" + f"{i}",
        )

    # legend (proxy patches; colours presumably match reference_scatter — confirm)
    handles = [
        Patch(color="#2a63bb", label="NPE + SC"),
        Patch(color="#b33b4b", label="NPE only"),
        Patch(color="grey", label="reference"),
    ]

    fig.legend(
        handles=handles,
        loc="lower right",
        frameon=False,
        ncol=3,
        bbox_to_anchor=(1, 0.0),
    )

    # labels
    subxlabel = fig.supxlabel("dimension 1\n\n")
    subxlabel.set_y(0.05)

    fig.supylabel("dimension 2")

    fig.tight_layout()

    path = Path(__file__).parent / "plots" / "posterior_scatter.pdf"
    # savefig does not create missing directories.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


def posterior_contour_plot(seed: int = 0):
    """Contour plots of 10-D posteriors at a few observation locations.

    For each ``mu_obs`` in ``(0, 1, 2, 3, 5, 8, 11)``, the observation is
    ``mu_obs`` in every coordinate plus N(0, 0.1^2) jitter. Three sets of 2000
    draws each go to ``reference_contour``:

    - the NPE + SC posterior;
    - the NPE-only posterior;
    - the analytic reference (prior N(0, 1) per coordinate).

    The axes are labelled "dimension 1" / "dimension 2". The figure is saved to
    ``plots/posterior_contour.pdf`` next to this file.

    Args:
        seed: Seed for the observation jitter, so the figure is reproducible.

    Returns:
        The matplotlib ``Figure``.
    """
    dimension = 10
    n_samples = 2000
    mu = [0, 1, 2, 3, 5, 8, 11]

    trainer = train_model(dimension, 0, rerun=False)
    trainer_no_sc = train_model_no_sc(dimension, 0, rerun=False)
    # Seeded generator instead of the global np.random state -> reproducible figure.
    rng = np.random.default_rng(seed)

    fig, axes = plt.subplots(
        nrows=1, ncols=len(mu), sharex=True, sharey=True, figsize=(9, 2.4)
    )
    axes = axes.flatten()

    # contour plots
    for i, (ax, mu_obs) in enumerate(zip(axes, mu)):
        location = rng.normal(loc=mu_obs, scale=0.1, size=dimension)
        # "parameters" is a zero placeholder; presumably the amortizer input
        # format requires the key even when only sampling (confirm).
        input_dict = {
            "parameters": np.zeros((1, dimension), dtype=np.float32),
            "direct_conditions": np.array([location], dtype=np.float32),
        }
        posterior_draws = trainer.amortizer.sample(input_dict, n_samples=n_samples)
        accurate_draws = analytic_posterior_draws(
            location,
            prior_mean=[0.0] * dimension,
            prior_std=[1.0] * dimension,
            y_std=None,
            n_samples=n_samples,
        )
        no_sc_posterior = trainer_no_sc.amortizer.sample(
            input_dict, n_samples=n_samples
        )

        axes[i] = reference_contour(
            ax,
            posterior_draws,
            accurate_draws,
            no_sc_posterior=no_sc_posterior,
            title=r"  $\mu_{\text{obs}}$=" + f"{mu_obs}",
            title_fontsize=13,
        )
        axes[i].tick_params(axis="both", labelsize=12)

    # legend (proxy patches; colours presumably match reference_contour — confirm)
    handles = [
        Patch(color="#2a63bb", label="NPE + SC"),
        Patch(color="#b33b4b", label="NPE only"),
        Patch(color="grey", label="reference"),
    ]

    fig.legend(
        handles=handles,
        loc="lower right",
        frameon=False,
        ncol=3,
        bbox_to_anchor=(1, 0),
        fontsize=12,
    )

    # labels (positions hand-tuned for this figsize)
    supxlabel = fig.supxlabel("dimension 1\n", fontsize=13)
    supxlabel.set_y(0.11)

    supylabel = fig.supylabel("dimension 2", fontsize=13)
    supylabel.set_x(0.05)
    supylabel.set_y(0.62)

    fig.tight_layout()

    path = Path(__file__).parent / "plots" / "posterior_contour.pdf"
    # savefig does not create missing directories.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, bbox_inches="tight")

    return fig


def mmd_plot():
    """Plot MMD to the reference posterior for NPE with and without SC.

    Uses metric index 2 (MMD) from ``posterior_metrics()`` and
    ``posterior_metrics_no_sc()``. For each parameter dimensionality (2, 10, 100)
    and each model variant, ``comparison_lines`` receives:

    - the per-location mean over replications;
    - the per-location standard deviation over replications.

    The figure is saved to ``plots/mmd_plot.pdf`` next to this file.

    Returns:
        The matplotlib ``Figure``.
    """
    mmd_idx = 2
    metrics = posterior_metrics()
    metrics_no_sc = posterior_metrics_no_sc()

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7.5 * 1.2, 4.5 * 1.2))

    ylabel = "MMD"
    # Curve order is (dimensionality, variant): SC first, then no-SC.
    labels = ["2", "2 (no SC)", "10", "10 (no SC)", "100", "100 (no SC)"]
    offsets = [-0.1, -0.1, 0, 0, 0.1, 0.1]

    # Arrays of shape (location, replication), interleaved to match `labels`.
    curves = [
        variant[mmd_idx, dim_idx]
        for dim_idx in range(3)
        for variant in (metrics, metrics_no_sc)
    ]
    means = [np.mean(curve, axis=-1) for curve in curves]
    stds = [np.std(curve, axis=-1) for curve in curves]

    ax = comparison_lines(
        ax,
        means,
        stds,
        linestyles=[":", ":", "-.", "-.", "--", "--"],
        colors=["#96b7d6", "#ffacac", "#2a63ab", "#b33b4b", "#001842", "#7b001b"],
        labels=labels,
        offsets=offsets,
        ylabel=ylabel,
        linewidth=3,
        title_fontsize=26,
        xlabel_fontsize=25,
        ylabel_fontsize=25,
        tick_labelsize=20,
        title="parameter dimensionality",
    )

    # Blank multi-line x-label, apparently used to reserve room at the bottom
    # for the figure legend.
    fig.supxlabel("\n\n\n", fontsize=16)

    handles, labels = ax.get_legend_handles_labels()
    # This legend has no title, so no title repositioning is needed here.
    fig.legend(
        handles,
        labels,
        loc="lower right",
        ncol=3,
        fontsize=22,
        bbox_to_anchor=(1, 0),
        frameon=False,
    )
    fig.tight_layout()

    path = Path(__file__).parent / "plots" / "mmd_plot.pdf"
    # savefig does not create missing directories.
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)

    return fig


if __name__ == "__main__":
    # Ensure every (dimensionality, replication) model exists for both variants.
    # With rerun=False, training only runs when the trainer reports
    # loss_history.latest == 0.
    for dim in (2, 10, 100):
        for rep in range(10):
            for train_fn in (train_model, train_model_no_sc):
                train_fn(dim, rep, rerun=False)

    # Other figures in this module: posterior_metric_plot(),
    # posterior_scatter_plot(), posterior_contour_plot().
    mmd_plot()
