import os

import matplotlib.pyplot as plt
import numpy as np


def _cdf_with_replacement(i, n, N):
    return (i / N) ** n


def _compute_variance(N, cur_data, expected_max_cond_n, pdfs):
    """
    this computes the standard error of the max.
    this is what the std dev of the bootstrap estimates of the mean of the max converges to, as
    is stated in the last sentence of the summary on page 10 of
    http://www.stat.cmu.edu/~larry/=stat705/Lecture13.pdf
    """
    variance_of_max_cond_n = []
    for n in range(N):
        # for a given n, estimate variance with \sum(p(x) * (x-mu)^2), where mu is \sum(p(x) * x).
        cur_var = 0
        for i in range(N):
            cur_var += (cur_data[i] - expected_max_cond_n[n]) ** 2 * pdfs[n][i]
        cur_var = np.sqrt(cur_var)
        variance_of_max_cond_n.append(cur_var)
    return variance_of_max_cond_n


# this implementation assumes sampling with replacement for computing the empirical cdf
def samplemax(validation_performance):
    """Expected maximum score as a function of the number of trials n.

    The observed scores are treated as an empirical distribution, and ``n``
    trials are assumed to be drawn from it *with replacement*. For every
    ``n`` in ``1..N`` (``N`` = number of scores) this builds the discrete
    distribution of the maximum of ``n`` draws over the sorted scores. It
    then returns that distribution's mean and spread.

    Args:
        validation_performance: iterable of numeric scores. It is copied
            before sorting, so the caller's object is not modified.

    Returns:
        dict with keys:
          - "mean": list of length N; entry ``n - 1`` is E[max of n draws].
          - "var": list of length N, as returned by ``_compute_variance``
            (square roots of the variances, despite the key name).
          - "max" / "min": ``np.max`` / ``np.min`` of the scores.
    """
    scores = sorted(validation_performance)
    N = len(scores)

    pdfs = []
    for n in range(1, N + 1):
        # CDF of the max of n draws, evaluated at each sorted score.
        cdf = [_cdf_with_replacement(k, n, N) for k in range(1, N + 1)]
        # Probability mass at each sorted score: successive CDF differences,
        # with the CDF taken to be 0 below the smallest score.
        pdfs.append([hi - lo for lo, hi in zip([0] + cdf[:-1], cdf)])

    # E[max of n draws] = sum_k score_k * P(max lands on score_k).
    # Explicit accumulation (not sum()) keeps the float results bit-identical
    # across Python versions.
    expected_max_cond_n = []
    for pdf in pdfs:
        acc = 0
        for score, mass in zip(scores, pdf):
            acc += score * mass
        expected_max_cond_n.append(acc)

    spread_of_max_cond_n = _compute_variance(N, scores, expected_max_cond_n, pdfs)

    return {
        "mean": expected_max_cond_n,
        "var": spread_of_max_cond_n,
        "max": np.max(scores),
        "min": np.min(scores),
    }


def evp_plot(
    datas,
    data_name,
    fontsize,
    metric_name,
    colors,
    linewidth,
    model,
    alias,
    dataset_alias,
    plot_name,
    pretrain,
    logx=False,
    plot_errorbar=True,
    avg_time=0,
    errorbar_kind="shade",
    errorbar_alpha=0.1,
    linestyle="-",
):
    """Plot expected-validation-performance (EVP) curves on one axes and save them.

    Args:
        datas: mapping from a curve key to a dict with keys "mean", "var",
            "max" and "min". These match the keys returned by ``samplemax``;
            presumably that is the producer, so confirm against callers.
        data_name: key into ``dataset_alias`` used for the title.
        fontsize: font size of the title and axis labels.
        metric_name: appended to "Expected validation " for the y label.
        colors: mapping curve key -> matplotlib color.
        linewidth: line width of the curves.
        model: shown in the title.
        alias: mapping curve key -> legend label.
        dataset_alias: mapping dataset name -> display name.
        plot_name: destination passed to ``save_plot``.
        pretrain: unused here; kept so existing call sites keep working.
        logx: use a log-scaled x axis.
        plot_errorbar: draw the spread ("var") around each mean curve.
        avg_time: if non-zero, the k-th point is placed at ``avg_time * k``
            and the x axis is labelled "Training duration". Otherwise the x
            axis is the number of hyperparameter assignments.
        errorbar_kind: "shade" draws a band clipped to [min, max]. Any other
            value draws matplotlib error bars instead.
        errorbar_alpha: opacity of the shaded band.
        linestyle: line style of the curves.
    """
    # NOTE(review): the figure is not closed here (callers may still want to
    # show it); repeated calls accumulate open figures.
    fig, cur_ax = plt.subplots(1, 1)
    x_axis_time = avg_time != 0

    # Figure-wide decorations do not depend on the curve; set them once.
    cur_ax.set_title(f"EVP {dataset_alias[data_name]} {model}", fontsize=fontsize)
    cur_ax.set_ylabel(
        "Expected validation " + metric_name,
        fontsize=fontsize,
    )
    if x_axis_time:
        cur_ax.set_xlabel("Training duration", fontsize=fontsize)
    else:
        cur_ax.set_xlabel("Hyperparameter assignments", fontsize=fontsize)
    if logx:
        cur_ax.set_xscale("log")

    for key, stats in datas.items():
        means = stats["mean"]
        spreads = stats["var"]
        max_acc = stats["max"]
        min_acc = stats["min"]

        steps = range(1, len(means) + 1)
        if x_axis_time:
            x_axis = [avg_time * k for k in steps]
        else:
            x_axis = list(steps)

        if plot_errorbar:
            if errorbar_kind == "shade":
                # Clip the band to the observed score range.
                minus_vars = [
                    m - s if (m - s) >= min_acc else min_acc
                    for m, s in zip(means, spreads)
                ]
                plus_vars = [
                    m + s if (m + s) <= max_acc else max_acc
                    for m, s in zip(means, spreads)
                ]
                cur_ax.fill_between(
                    x_axis,
                    minus_vars,
                    plus_vars,
                    alpha=errorbar_alpha,
                    color=colors[key],
                )
            else:
                cur_ax.errorbar(
                    x_axis,
                    means,
                    yerr=spreads,
                    linestyle=linestyle,
                    linewidth=linewidth,
                    color=colors[key],
                )
        cur_ax.plot(
            x_axis,
            means,
            linestyle=linestyle,
            linewidth=linewidth,
            label=alias[key],
            color=colors[key],
        )

    # Pin the x range only after every curve is drawn. set_xlim disables
    # x autoscaling, so pinning inside the loop would clip later, longer
    # curves to the first curve's range.
    left, right = cur_ax.get_xlim()
    cur_ax.set_xlim(left, right)
    cur_ax.locator_params(axis="y", nbins=10)
    fig.tight_layout()

    cur_ax.legend(loc="lower right")
    cur_ax.grid(color="#dadada", linewidth=0.5)
    save_plot(plot_name)


def save_plot(name):
    """Save the current matplotlib figure to ``name`` at 300 dpi.

    A ``plots`` directory is always ensured to exist, as before. When
    ``name`` is a filesystem path with a directory component, that parent
    directory is created as well, so saving into a nested, not-yet-existing
    path does not fail. File-like targets are passed to ``savefig``
    untouched.

    Args:
        name: destination accepted by ``plt.savefig`` (path or file object).
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("plots", exist_ok=True)
    if isinstance(name, (str, bytes, os.PathLike)):
        parent = os.path.dirname(name)
        if parent:
            os.makedirs(parent, exist_ok=True)
    plt.savefig(name, dpi=300)


def _max(row):
    result = (0, 0, 0, 0)
    for t in row:
        if t[0] > result[0]:
            result = t
    return result


_max.__name__ = "max_mean"
