# Copyright (c) Microsoft Corporation and Fairlearn contributors.
# Licensed under the MIT License.

import numpy as np
import pandas as pd
from .moment import ClassificationMoment
from .moment import _ALL, _LABEL, _GROUP_ID

from fairlearn.utils._input_validation import _validate_and_reformat_input


class ErrorRate(ClassificationMoment):
    """Misclassification error.

    The moment has a single entry, indexed by ``_ALL``, holding the mean
    absolute difference between the labels and the predictions.
    """

    short_name = "Err"

    def load_data(self, X, y, *, sensitive_features, control_features=None):
        """Load the specified data into the object.

        ``y``, ``sensitive_features`` and ``control_features`` are passed
        through input validation (binary labels are enforced). The validated
        control features are not used any further by this moment.
        """
        _, labels, groups, _ = _validate_and_reformat_input(
            X,
            y,
            enforce_binary_labels=True,
            sensitive_features=sensitive_features,
            control_features=control_features,
        )
        # Hand the caller's X (not the validated copy) to the parent so that
        # estimators receive X untouched.
        super().load_data(X, labels, sensitive_features=groups)
        self.index = [_ALL]

    def gamma(self, predictor, prediction=None):
        """Return the gamma values for the given predictor.

        If ``prediction`` is given it is used directly; otherwise
        ``predictor`` is called on the loaded ``X``.
        """
        if prediction is None:
            pred = predictor(self.X)
        else:
            pred = prediction

        if isinstance(pred, np.ndarray):
            # TensorFlow returns an (n, 1) array; without squeezing, the
            # subtraction below would broadcast to an (n, n) array.
            pred = np.squeeze(pred)

        mean_abs_diff = (self.tags[_LABEL] - pred).abs().mean()
        error = pd.Series(data=mean_abs_diff, index=self.index)
        self._gamma_descr = str(error)
        return error

    def project_lambda(self, lambda_vec):
        """Return the lambda values (unchanged for this moment)."""
        return lambda_vec

    def signed_weights(self, lambda_vec=None):
        """Return the signed weights.

        Labels in {0, 1} are mapped to {-1, +1}; when ``lambda_vec`` is
        given, the result is scaled by its ``_ALL`` entry.
        """
        if lambda_vec is not None:
            return lambda_vec[_ALL] * (2 * self.tags[_LABEL] - 1)
        return 2 * self.tags[_LABEL] - 1


def get_lb_grp_row(df, y, s):
    """Return the first row of ``df`` with label ``y`` and group ``s``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the ``_LABEL`` and ``_GROUP_ID`` columns.
    y
        Value to match in the ``_LABEL`` column.
    s
        Value to match in the ``_GROUP_ID`` column.

    Returns
    -------
    pandas.Series
        The first matching row.

    Raises
    ------
    IndexError
        If no row matches both ``y`` and ``s``.
    """
    label_matches = df[_LABEL] == y
    group_matches = df[_GROUP_ID] == s
    matching_rows = df[label_matches & group_matches]
    return matching_rows.iloc[0]


class ImbErrorRate(ClassificationMoment):
    """Misclassification error for imbalanced datasets.

    Each sample's error term is shifted by an offset ``Delta`` that depends
    on the sample's label and sensitive-feature group. The offsets are
    computed by :meth:`get_Deltas` from the label/group counts of the loaded
    data and the ``alpha``, ``C`` and ``delta_map`` hyperparameters.
    """

    short_name = "ImbErr"

    def __init__(self, *, alpha, C, delta_map):
        """Store the hyperparameters.

        Parameters
        ----------
        alpha : float
            Hyperparameter weighting the per-group term of the effective
            sample size (see :meth:`compute_effective_n`).
        C : float
            Hyperparameter; numerator of the ``C / n_tilde`` part of every
            offset.
        delta_map : :class:`dict`
            Maps ``(y, s)`` to ``delta_{y, s}``. It is looked up as
            ``delta_map[(y, s)]`` for every (label, group) combination
            present in the loaded data.
        """
        super().__init__()
        self.alpha = alpha
        self.C = C
        self.delta_map = delta_map

    @staticmethod
    def compute_effective_n(a, n_i, n_is, S, y):
        """Return the effective sample size for label ``y``.

        With ``n`` the number of samples labelled ``y``, ``n_s`` the number
        of those samples in group ``s`` and ``P`` the product of ``n_s`` over
        all ``s`` in ``S``, the result is::

            n * P / (sqrt(P) + a * sum_s sqrt(n * P / n_s)) ** 2

        Parameters
        ----------
        a : float
            Weight of the per-group sum in the denominator.
        n_i : pandas.DataFrame
            Per-label counts with columns ``_LABEL`` and ``'count'``.
        n_is : pandas.DataFrame
            Per-(label, group) counts with columns ``_LABEL``, ``_GROUP_ID``
            and ``'count'``.
        S : iterable
            The group identifiers to include.
        y
            The label whose effective sample size is computed.

        Returns
        -------
        float

        Raises
        ------
        IndexError
            If ``n_i`` has no row for ``y`` or ``n_is`` has no row for some
            ``(y, s)``.
        """
        n = n_i[n_i[_LABEL] == y].iloc[0]['count']
        # Floats keep the product from overflowing a fixed-width integer.
        group_counts = np.array(
            [get_lb_grp_row(n_is, y, s)['count'] for s in S], dtype=float)
        s_prod = np.prod(group_counts)
        # One term per group: n times the product of the *other* groups'
        # counts (s_prod / n_s).
        # NOTE(review): confirm this matches the intended derivation.
        per_group_terms = np.sqrt(n * s_prod / group_counts)
        denominator = (np.sqrt(s_prod) + a * np.sum(per_group_terms)) ** 2
        return n * s_prod / denominator

    def get_Deltas(self):
        """Compute the per-(label, group) offsets and store ``self.Deltas``.

        ``self.Deltas`` is a DataFrame indexed by ``[_LABEL, _GROUP_ID]``
        with a single column ``'Delta'`` equal to
        ``C / n_tilde[label] + delta_map[(label, group)]``, where ``n_tilde``
        comes from :meth:`compute_effective_n`.

        Raises
        ------
        ValueError
            If some (label, group) combination does not occur in the data.
        KeyError
            If ``delta_map`` has no entry for an occurring combination.
        """
        per_group_counts = (
            self.tags.groupby([_LABEL, _GROUP_ID]).size()
            .reset_index(name='count'))
        # Drop empty combinations (possible with categorical columns).
        per_group_counts = per_group_counts[per_group_counts['count'] > 0]
        per_label_counts = (
            self.tags.groupby([_LABEL]).size().reset_index(name='count'))

        Ys = self.tags[_LABEL].unique()
        Ss = self.tags[_GROUP_ID].unique()

        # A missing combination would make the effective sample size
        # undefined, so fail with a clear message instead.
        if len(per_group_counts) != len(Ys) * len(Ss):
            raise ValueError(
                "Every (label, sensitive feature group) combination must "
                "occur at least once in the data.")

        n_tildes = pd.Series(
            data=[self.compute_effective_n(self.alpha, per_label_counts,
                                           per_group_counts, Ss, y)
                  for y in Ys],
            index=Ys)

        Deltas = per_group_counts[[_LABEL, _GROUP_ID]].copy()
        Deltas['Delta'] = [
            self.C / n_tildes.loc[y] + self.delta_map[(y, s)]
            for y, s in zip(Deltas[_LABEL], Deltas[_GROUP_ID])
        ]
        self.Deltas = Deltas.set_index([_LABEL, _GROUP_ID])

    def load_data(self, X, y, *, sensitive_features, control_features=None):
        """Load the specified data into the object.

        After the parent class has stored the data, the per-(label, group)
        offsets are computed so that :meth:`gamma` can use them.
        """
        _, y_train, sf_train, cf_train = \
            _validate_and_reformat_input(X, y,
                                         enforce_binary_labels=True,
                                         sensitive_features=sensitive_features,
                                         control_features=control_features)
        # The following uses X  so that the estimators get X untouched
        super().load_data(X, y_train, sensitive_features=sf_train)
        self.index = [_ALL]
        self.get_Deltas()

    def gamma(self, predictor, prediction=None):
        """Return the gamma values for the given predictor.

        The value is the mean of ``|label - prediction - Delta|``, where
        ``Delta`` is the offset of each sample's (label, group) combination.
        """
        pred = prediction if prediction is not None else predictor(self.X)
        if isinstance(pred, np.ndarray):
            # TensorFlow is returning an (n,1) array, which results
            # in the subtraction in the 'error =' line generating an
            # (n,n) array
            pred = np.squeeze(pred)
        # Look up each sample's offset; the join keeps the index of
        # self.tags so the subtraction below stays row-aligned.
        row_deltas = self.tags[[_LABEL, _GROUP_ID]].join(
            self.Deltas['Delta'], on=[_LABEL, _GROUP_ID])['Delta']
        error = pd.Series(
            data=(self.tags[_LABEL] - pred - row_deltas).abs().mean(),
            index=self.index)
        self._gamma_descr = str(error)
        return error

    def project_lambda(self, lambda_vec):
        """Return the lambda values."""
        return lambda_vec

    def signed_weights(self, lambda_vec=None):
        """Return the signed weights."""
        if lambda_vec is None:
            return 2 * self.tags[_LABEL] - 1
        else:
            return lambda_vec[_ALL] * (2 * self.tags[_LABEL] - 1)











