"""Leave-pair-out data valuation"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neural_network import MLPClassifier as MLP
from .decision_boundary import Scatter2D
from tqdm import tqdm


class LPODV:
    """Leave-pair-out data valuation.

    For every ordered index pair ``(i, j)`` a fresh MLP classifier is trained
    and scored with ``metric``; the scores are collected in a square matrix.
    """

    def __init__(self, hidden_layer_sizes, metric, X_base=None, y_base=None):
        """Store the model configuration, the metric and the optional base set.

        Args:
            hidden_layer_sizes: Forwarded unchanged to ``MLPClassifier`` as its
                ``hidden_layer_sizes`` argument.
            metric: Callable that is handed the fitted model's ``predict``
                method and returns a score.
            X_base: Optional base features. When given, each model is trained
                on the base set plus the pair instead of on the data with the
                pair removed.
            y_base: Labels belonging to ``X_base``.
        """
        self.hidden_layer_sizes = hidden_layer_sizes
        self.metric = metric

        self.X_base = X_base
        self.y_base = y_base

    def predict_dv(self, X, y, inv_diff=False, plot_step=False, save_folder=None, base_model_f=None):
        """Compute the pairwise valuation matrix for ``X`` / ``y``.

        Every ordered pair ``(i, j)`` is evaluated, including ``i == j`` and
        both ``(i, j)`` and ``(j, i)``; one model is trained per pair.

        Args:
            X: Feature array; rows are samples (``X.shape[0]`` is used as the
                number of samples).
            y: Labels, indexable in step with the rows of ``X``.
            inv_diff: If True, store ``1 - metric(model.predict)`` instead of
                the raw metric value.
            plot_step: If True, a ``Scatter2D(X, y)`` is constructed up front.
                It is not used any further by this method.
            save_folder: Currently unused; kept for interface compatibility.
            base_model_f: Currently unused; kept for interface compatibility.

        Returns:
            ``numpy.ndarray`` of shape ``(n, n)`` with ``n = X.shape[0]``,
            where entry ``[i, j]`` is the (optionally inverted) score of the
            model trained for the pair ``(i, j)``.
        """
        n_samples = X.shape[0]
        dv = np.zeros((n_samples, n_samples))

        if plot_step:
            # NOTE(review): the scatter object is created but never drawn on or
            # saved here; kept in case its constructor has side effects.
            sct = Scatter2D(X, y)

        pairs = [(i, j) for i in range(n_samples) for j in range(n_samples)]

        for i, j in tqdm(pairs):

            if self.X_base is not None:
                # Base-set mode: train on the base data plus the two samples.
                X_train_new = np.vstack([self.X_base, X[i], X[j]])
                y_train_new = np.hstack([self.y_base, y[i], y[j]])
            else:
                # Leave-pair-out mode: train on everything except the pair.
                X_train_new = np.delete(X, [i, j], 0)
                y_train_new = np.delete(y, [i, j])

            model = MLP(hidden_layer_sizes=self.hidden_layer_sizes, activation='relu', max_iter=2000)
            model.fit(X_train_new, y_train_new)

            # The score must be stored for both settings of ``inv_diff``;
            # guarding the assignment with ``if inv_diff`` left the matrix
            # all zeros for the default ``inv_diff=False``.
            score = self.metric(model.predict)
            dv[i, j] = 1 - score if inv_diff else score

        return dv
