import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
import seaborn as sns


def posterior_predictive_check(trainer):
    """Plot one simulated data set against re-simulations from its posterior.

    A batch of 200 simulations is drawn from ``trainer.generative_model``,
    passed through ``trainer.configurator`` and handed to
    ``trainer.amortizer.sample`` with ``n_samples=100``. The sampled
    parameters are fed back into ``generative_model.simulator`` and the
    resulting ``"sim_data"`` is drawn behind the first simulated data set.

    Parameters
    ----------
    trainer
        Object exposing ``generative_model``, ``configurator`` and
        ``amortizer`` attributes.

    Returns
    -------
    The figure produced by ``lineplot``.
    """
    generative_model, configure, amortizer = (
        trainer.generative_model,
        trainer.configurator,
        trainer.amortizer,
    )

    simulated = generative_model(200)
    posterior_draws = amortizer.sample(configure(simulated), n_samples=100)

    # Only the first simulated data set serves as the reference curve.
    observed = simulated["sim_data"][0, :]

    # NOTE(review): the posterior draws for the whole batch are passed to the
    # simulator, while only data set 0 is shown as reference -- confirm that
    # the leading entries of the result correspond to data set 0.
    predicted = generative_model.simulator(posterior_draws)["sim_data"]

    return lineplot(observed, predicted)


def lineplot(x, pred, max_lines=100):
    """Draw a reference series on top of a faint bundle of predicted series.

    Parameters
    ----------
    x
        Series drawn as a single dark line in the foreground.
    pred
        Indexable, sized collection; ``pred[i]`` is drawn as one faint grey
        line in the background for every plotted index ``i``.
    max_lines : int, optional
        Upper bound on how many leading entries of ``pred`` are drawn.
        Defaults to 100.

    Returns
    -------
    The matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(x, color="#012F47", zorder=1)

    # Never index past the end of ``pred``: with fewer than ``max_lines``
    # entries, a fixed-length loop would raise IndexError.
    num_lines = min(max_lines, len(pred))
    for i in range(num_lines):
        ax.plot(pred[i], color="#AAAAAA", alpha=0.05, zorder=0)

    fig.tight_layout()

    return fig


def plot_recovery(trainer):
    """Plot posterior estimates against ground-truth parameters, one panel each.

    A batch of 200 simulations is drawn from ``trainer.generative_model``,
    configured with ``trainer.configurator`` and passed to
    ``trainer.amortizer.sample`` with ``n_samples=5000``. For every parameter
    (last axis of the samples) a panel shows the posterior mean versus the
    corresponding ``"prior_draws"`` value, with a vertical bar spanning the
    2.5 % to 97.5 % posterior quantiles and a dashed identity line.

    Parameters
    ----------
    trainer
        Object exposing ``generative_model``, ``configurator`` and
        ``amortizer`` attributes.

    Returns
    -------
    The matplotlib figure containing all panels.

    Notes
    -----
    The summaries reduce over axis 1 of the posterior samples, which is
    presumably the sample axis (data sets, samples, parameters) -- confirm
    against the amortizer in use.
    """
    model = trainer.generative_model
    configurator = trainer.configurator
    amortizer = trainer.amortizer

    sim_dict = model(200)
    conf_dict = configurator(sim_dict)

    posterior_samples = amortizer.sample(conf_dict, n_samples=5000)
    prior_samples = sim_dict["prior_draws"]

    # Summarise the posterior once for all parameters; these reductions touch
    # the whole sample array and do not depend on the panel being drawn.
    posterior_means = np.mean(posterior_samples, axis=1)
    all_lower_q, all_upper_q = np.quantile(
        posterior_samples, [0.025, 0.975], axis=1
    )

    # plot layout
    num_params = posterior_samples.shape[-1]
    num_cols = min(4, num_params)
    num_rows = int(np.ceil(num_params / num_cols))
    fig, axarr = centered_subplots(
        num_cols, num_params, figsize=(1.5 + 2 * num_cols, 1.5 + 2 * num_rows)
    )

    for i in range(num_params):
        ax = axarr[i]
        ax.tick_params(axis="both", labelsize=13)
        sns.despine(ax=ax)

        # per-parameter ground truth and posterior summaries
        x = prior_samples[:, i]
        mean = posterior_means[:, i]
        lower_q = all_lower_q[:, i]
        upper_q = all_upper_q[:, i]

        # add errorbars: one vertical segment per simulated data set
        lines = [
            [(truth, low), (truth, high)]
            for truth, low, high in zip(x, lower_q, upper_q)
        ]
        lc = LineCollection(lines, colors="#AAAAAA", linewidth=0.5, zorder=1)
        ax.add_collection(lc)

        # add diagonal line; drawn in axes coordinates, it coincides with the
        # identity line because both axes get the same limits below
        ax.plot(
            [0, 1],
            [0, 1],
            color="black",
            linestyle="dashed",
            transform=ax.transAxes,
            zorder=0,
        )

        # add scatter points
        ax.scatter(x, mean, color="#012F47", zorder=2, s=1)

        # make plots quadratic; the shared limits cover the ground truth as
        # well as the intervals so that no point is clipped on the x-axis
        lower = min(np.min(lower_q), np.min(x))
        upper = max(np.max(upper_q), np.max(x))
        eps = (upper - lower) * 0.1

        ax.set_xlim(lower - eps, upper + eps)
        ax.set_ylim(lower - eps, upper + eps)
        ax.set_aspect("equal")
        ax.set_title(i)

    fig.supxlabel("\nGround truth", fontsize=18)
    fig.supylabel("Estimate", fontsize=18)

    fig.tight_layout()

    return fig


def centered_subplots(num_cols, num_params, **kwargs):
    """Create a grid of axes whose incomplete last row is horizontally centred.

    Parameters
    ----------
    num_cols : int
        Number of panels per full row.
    num_params : int
        Total number of panels to create.
    **kwargs
        Forwarded unchanged to ``plt.figure``.

    Returns
    -------
    fig, axes
        The new figure and a flat list of its axes in row-major order.
    """
    fig = plt.figure(**kwargs)
    n_rows = int(np.ceil(num_params / num_cols))

    # Every panel spans two grid columns, so the final row can be shifted in
    # steps of half a panel width.
    grid = fig.add_gridspec(nrows=n_rows, ncols=num_cols * 2)
    final_row = n_rows - 1

    # All rows except the last one are completely filled.
    axes = [
        fig.add_subplot(grid[row, (2 * col) : (2 * col + 2)])
        for row in range(final_row)
        for col in range(num_cols)
    ]

    # The last row holds the remaining panels, offset by half a panel width
    # for each panel that is missing from a full row.
    n_remaining = num_params - (num_cols * final_row)
    offset = num_cols - n_remaining
    for left in range(offset, offset + 2 * n_remaining, 2):
        axes.append(fig.add_subplot(grid[final_row, left : (left + 2)]))

    return fig, axes
