"""Leave-one-out data valuation"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neural_network import MLPClassifier as MLP
from .decision_boundary import Scatter2D
from tqdm import tqdm


class LOODV:
    """Leave-one-out (LOO) data valuation using an sklearn MLP classifier.

    Two modes, selected by whether a base training set is supplied:

    * ``X_base``/``y_base`` given: for each candidate point ``(X[i], y[i])``
      a fresh model is trained on the base set plus that single point.
    * no base set: for each ``i`` a fresh model is trained on ``X``/``y``
      with row ``i`` removed (classic leave-one-out).

    Each fitted model's ``predict`` function is passed to ``metric`` and the
    resulting score (optionally inverted) is the value of point ``i``.
    """

    def __init__(self, hidden_layer_sizes, metric, X_base=None, y_base=None,
                 activation='relu', max_iter=2000, random_state=None):
        """
        Args:
            hidden_layer_sizes: forwarded to ``MLPClassifier(hidden_layer_sizes=...)``.
            metric: callable that receives a fitted model's ``predict`` function
                and returns a score.
            X_base: optional base feature matrix. When given, every model is
                trained on this set plus one candidate point.
            y_base: labels for ``X_base``; must be given exactly when
                ``X_base`` is.
            activation: MLP activation function (default ``'relu'``).
            max_iter: MLP iteration cap (default ``2000``).
            random_state: MLP seed; ``None`` (default) leaves training unseeded.

        Raises:
            ValueError: if only one of ``X_base`` / ``y_base`` is provided.
        """
        if (X_base is None) != (y_base is None):
            raise ValueError("X_base and y_base must be provided together")

        self.hidden_layer_sizes = hidden_layer_sizes
        self.metric = metric

        self.X_base = X_base
        self.y_base = y_base

        self.activation = activation
        self.max_iter = max_iter
        self.random_state = random_state

    def _training_set(self, X, y, i):
        """Return the ``(X_train, y_train)`` pair used to value row ``i``."""
        if self.X_base is not None:
            # Base mode: base set plus the single candidate point.
            return np.vstack([self.X_base, X[i]]), np.hstack([self.y_base, y[i]])
        # LOO mode: every row except ``i``.
        return np.delete(X, i, 0), np.delete(y, i)

    def _plot_step(self, sct, model, X, y, i, value, save_folder, base_model_f):
        """Draw the boundary of ``model`` and highlight point ``i``; optionally save."""
        sct.add_boundary(model.predict)

        if base_model_f is not None:
            sct.add_boundary(base_model_f)

        # NOTE(review): in LOO mode X_base/y_base are None here; whether
        # Scatter2D.scatter accepts None is not visible from this file — confirm.
        sct.scatter(self.X_base, self.y_base)
        marker = "x" if y[i] == 1 else "D"
        sct.scatter(np.array([X[i]]), np.array([y[i]]), marker=marker, scatter_size=np.array([100]))

        # Save before show(), since show() may clear/close the current figure.
        if save_folder is not None:
            plt.savefig(save_folder + "/dv_loo" + str(i))
        sct.show(scatter=False, title="Data Value: " + str(value))

    def predict_dv(self, X, y, inv_diff=False, plot_step=False, save_folder=None, base_model_f=None):
        """Compute one data value per row of ``X``.

        Args:
            X: candidate feature matrix; one model is fitted per row.
            y: labels for ``X``.
            inv_diff: if True, store ``1 - metric(predict)`` instead of
                ``metric(predict)``.
            plot_step: if True, plot each fitted model's decision boundary.
            save_folder: when plotting, save each figure as
                ``<save_folder>/dv_loo<i>``.
            base_model_f: optional predict function drawn as an extra
                boundary for comparison when plotting.

        Returns:
            np.ndarray with one value per row of ``X``.
        """
        values = []
        sct = Scatter2D(X, y) if plot_step else None

        for i in tqdm(range(X.shape[0])):
            X_train_new, y_train_new = self._training_set(X, y, i)

            model = MLP(hidden_layer_sizes=self.hidden_layer_sizes,
                        activation=self.activation,
                        max_iter=self.max_iter,
                        random_state=self.random_state)
            model.fit(X_train_new, y_train_new)

            score = self.metric(model.predict)
            values.append(1 - score if inv_diff else score)

            # Plotting applies in both modes; previously it was nested under
            # the non-inverted branch and was skipped when inv_diff=True.
            if plot_step:
                self._plot_step(sct, model, X, y, i, values[-1], save_folder, base_model_f)

        return np.asarray(values)
