"""Leave-one-out data valuation"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neural_network import MLPClassifier as MLP
from .decision_boundary import Scatter2D
from tqdm import tqdm


class LOODV:
    """Data valuation by retraining an ``MLPClassifier`` once per candidate point.

    For every row ``X[i]`` passed to :meth:`predict_dv`, a fresh
    ``MLPClassifier`` is fit and scored with the user-supplied ``metric``.
    Two training regimes are supported:

    * base mode (``X_base`` is not None): fit on ``X_base`` plus the single
      point ``(X[i], y[i])``;
    * leave-one-out mode (``X_base`` is None): fit on ``X``/``y`` with row
      ``i`` removed.
    """

    def __init__(self, hidden_layer_sizes, metric, X_base=None, y_base=None,
                 max_iter=2000, random_state=None):
        """

        Args:
            hidden_layer_sizes: passed unchanged to ``MLPClassifier``.
            metric: callable that receives a fitted model's ``predict``
                function and returns the score for that model.
            X_base: optional base training features. When given, each model is
                trained on ``X_base`` plus one candidate point; when None,
                leave-one-out training on the candidate set is used.
            y_base: labels matching ``X_base``.
            max_iter: ``max_iter`` forwarded to ``MLPClassifier`` (2000 keeps
                the previous hard-coded value).
            random_state: ``random_state`` forwarded to ``MLPClassifier``.
                None keeps sklearn's default (unseeded); pass an int for
                reproducible valuations.
        """
        self.hidden_layer_sizes = hidden_layer_sizes
        self.metric = metric

        self.X_base = X_base
        self.y_base = y_base

        self.max_iter = max_iter
        self.random_state = random_state

    def _training_set(self, X, y, i):
        """Return the (features, labels) pair used to fit the model for row ``i``."""
        if self.X_base is not None:
            # Base mode: base data plus the single candidate point.
            return (np.vstack([self.X_base, X[i]]),
                    np.hstack([self.y_base, y[i]]))
        # Leave-one-out mode: every candidate point except row i.
        return np.delete(X, i, 0), np.delete(y, i)

    def _score(self, predict_fn, inv_diff):
        """Return ``metric(predict_fn)``, or ``1 - metric(predict_fn)`` if ``inv_diff``."""
        value = self.metric(predict_fn)
        return 1 - value if inv_diff else value

    def _maybe_plot(self, X, y, sct, model, base_model_f, cmap, cmap_base,
                    i, value, save_folder=None):
        """Draw the decision boundary of the model trained for row ``i``.

        Args:
            X, y: candidate points and labels; ``X[i]`` is highlighted.
            sct: ``Scatter2D`` instance that owns the figure.
            model: fitted classifier for row ``i``.
            base_model_f: optional predict function whose boundary is also drawn.
            cmap, cmap_base: optional colormaps; when given, the respective
                boundary is additionally drawn with that colormap.
            i: index of the candidate point being valued.
            value: data value computed for row ``i``; shown in the title.
            save_folder: if not None, the figure is saved there as
                ``dv_loo<i>.png`` before being shown.
        """
        if cmap is not None:
            sct.add_boundary(model.predict, cmap=cmap)
        sct.add_boundary(model.predict)

        if base_model_f is not None:
            sct.add_boundary(base_model_f)
            if cmap_base is not None:
                sct.add_boundary(base_model_f, cmap=cmap_base)

        # NOTE(review): in leave-one-out mode X_base/y_base are None here;
        # confirm Scatter2D.scatter accepts None before relying on that mode.
        sct.scatter(self.X_base, self.y_base)

        # Highlight the point being valued, coloured by its label.
        color = "tab:orange" if y[i] == 1 else "tab:blue"
        plt.scatter(X[i, 0], X[i, 1], marker="D", color=color, s=100)

        plt.title("Data Value: " + str(value))
        plt.xlim(sct.x_lim)
        plt.ylim(sct.y_lim)

        if save_folder is not None:
            plt.savefig(save_folder + "/dv_loo" + str(i) + ".png", dpi=150, bbox_inches='tight')

        sct.show(scatter=False)

    def predict_dv(self, X, y, inv_diff=False, plot_step=False, save_folder=None,
                   base_model_f=None, cmap=None, cmap_base=None):
        """Compute one data value per row of ``X``.

        Args:
            X: candidate feature matrix (rows are indexed as ``X[i]``; plotting
                additionally reads columns 0 and 1).
            y: labels for ``X``.
            inv_diff: if True, each value is ``1 - metric(...)`` instead of
                ``metric(...)``.
            plot_step: if True, plot the decision boundary after each fit.
            save_folder: folder to save per-step plots into (only used when
                ``plot_step`` is True).
            base_model_f: optional predict function drawn alongside each model.
            cmap, cmap_base: optional colormaps for the boundary plots.

        Returns:
            ``np.asarray`` of the per-row values, in row order (a 1-D array
            assuming ``metric`` returns a scalar).
        """
        sct = Scatter2D(X, y) if plot_step else None

        values = []
        for i in tqdm(range(X.shape[0])):
            X_train, y_train = self._training_set(X, y, i)

            model = MLP(hidden_layer_sizes=self.hidden_layer_sizes, activation='relu',
                        max_iter=self.max_iter, random_state=self.random_state)
            model.fit(X_train, y_train)

            value = self._score(model.predict, inv_diff)
            values.append(value)

            if plot_step:
                # Pass i / value / save_folder explicitly: _maybe_plot has no
                # access to this method's locals.
                self._maybe_plot(X, y, sct, model, base_model_f, cmap, cmap_base,
                                 i, value, save_folder)

        return np.asarray(values)
