"""Data valuation with memorization."""
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neural_network import MLPClassifier as MLP
from .decision_boundary import Scatter2D
from tqdm import tqdm


class MemorizationDV:
    """Data valuation from the effect of a single sample on its own prediction.

    For every sample ``x_i`` two ensembles of ``num_models`` classifiers are
    compared on their prediction for ``x_i``:

    * without a base set, models trained on all of ``X`` versus models
      trained on ``X`` with ``x_i`` left out;
    * with a base set, models trained on the base set alone versus models
      trained on the base set plus ``x_i``.

    The value of ``x_i`` is the mean difference between the paired
    predictions of the two ensembles.
    """

    def __init__(self, hidden_layer_sizes, X_base=None, y_base=None, num_models=5,
                 *, model_factory=None, show_progress=True):
        """Store the configuration; no model is trained here.

        Args:
            hidden_layer_sizes: Hidden layer sizes handed to the default
                ``MLPClassifier``. Ignored when ``model_factory`` is given.
            X_base: Optional base training features. When given, each sample
                is valued by adding it to this set instead of removing it
                from ``X``.
            y_base: Labels for ``X_base``; required whenever ``X_base`` is set.
            num_models: Number of models trained per training set (>= 1).
            model_factory: Optional zero-argument callable returning a new,
                unfitted classifier exposing ``fit(X, y)`` and ``predict(X)``.
                Defaults to an ``MLPClassifier`` with ReLU activation and
                ``max_iter=2000``.
            show_progress: Whether to wrap the per-sample loop in ``tqdm``.

        Raises:
            ValueError: If ``num_models`` is smaller than 1, or ``X_base`` is
                given without ``y_base``.
        """
        if num_models < 1:
            raise ValueError("num_models must be at least 1")
        if X_base is not None and y_base is None:
            raise ValueError("y_base is required when X_base is given")

        self.hidden_layer_sizes = hidden_layer_sizes

        self.X_base = X_base
        self.y_base = y_base
        self.num_models = num_models
        self.model_factory = model_factory
        self.show_progress = show_progress

    def _new_model(self):
        """Return a fresh, unfitted classifier."""
        if self.model_factory is not None:
            return self.model_factory()
        return MLP(hidden_layer_sizes=self.hidden_layer_sizes, activation='relu', max_iter=2000)

    def _ensemble_predict(self, X_train, y_train, X_query):
        """Fit ``num_models`` fresh models and return their predictions.

        The result has shape ``(num_models, len(X_query))``.
        """
        outputs = []
        for _ in range(self.num_models):
            model = self._new_model()
            model.fit(X_train, y_train)
            outputs.append(model.predict(X_query))
        return np.array(outputs)

    def predict_dv(self, X, y):
        """Compute one data value per row of ``X``.

        Args:
            X: Array-like of samples, one row per sample.
            y: Array-like of labels, one per row of ``X``.

        Returns:
            1-D array with one value per sample: the mean, over the
            ``num_models`` model pairs, of the difference between the
            baseline prediction and the prediction of the model trained
            without (or, with a base set, with) that sample. Numeric labels
            use the absolute difference, which for 0/1 labels is the
            disagreement rate; non-numeric labels use a 0/1 disagreement
            indicator.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        num_samples = X.shape[0]
        if num_samples == 0:
            return np.zeros(0)

        has_base = self.X_base is not None

        # Baseline ensemble: trained on the base set if there is one,
        # otherwise on the full data. Transposed to (num_samples, num_models).
        if has_base:
            fit_X, fit_y = self.X_base, self.y_base
        else:
            fit_X, fit_y = X, y
        base_predictions = self._ensemble_predict(fit_X, fit_y, X).T

        indices = range(num_samples)
        if self.show_progress:
            indices = tqdm(indices)

        # Per-sample ensembles: base set plus x_i, or all data minus x_i.
        predictions = []
        for i in indices:
            if has_base:
                X_train_new = np.vstack([self.X_base, X[i]])
                y_train_new = np.hstack([self.y_base, y[i]])
            else:
                X_train_new = np.delete(X, i, 0)
                y_train_new = np.delete(y, i)

            # Only the prediction for x_i is needed, so query that row alone.
            sample_pred = self._ensemble_predict(X_train_new, y_train_new, X[i:i + 1])
            predictions.append(sample_pred[:, 0])

        predictions = np.array(predictions)

        # Compare predictions
        numeric = (np.issubdtype(predictions.dtype, np.number)
                   and np.issubdtype(base_predictions.dtype, np.number))
        if numeric:
            diff = np.abs(predictions - base_predictions)
        else:
            # Subtraction is undefined for string/bool labels.
            diff = predictions != base_predictions
        return np.mean(diff, axis=1)


