import random

from aif360.datasets import StandardDataset
from aif360.algorithms.inprocessing.exponentiated_gradient_reduction import ExponentiatedGradientReduction
import numpy as np
from .utils import sk_model, sample_celoss
from config import cfg


class Reduction:
    """Fairness debiasing via AIF360's exponentiated-gradient reduction.

    Wraps ``ExponentiatedGradientReduction`` around the estimator returned by
    ``sk_model('linear')``. The fairness constraint is either demographic
    parity (``"sp"``) or equalized odds (``"eo"``).
    """

    # Short constraint keys accepted by ``__init__`` -> constraint names
    # understood by ExponentiatedGradientReduction.
    _CONSTRAINTS = {"sp": "DemographicParity", "eo": "EqualizedOdds"}

    def __init__(self, constraint, params):
        """Store the configuration; no model is fitted here.

        Args:
            constraint: Key into ``_CONSTRAINTS`` (``"sp"`` or ``"eo"``).
            params: Value forwarded as ``eps`` to
                ExponentiatedGradientReduction (the allowed constraint
                violation).
        """
        self.constraint = constraint
        self.params = params
        # Fitted ExponentiatedGradientReduction; set by debias(train=True).
        self.debias_model = None

    @staticmethod
    def _to_dataset(df):
        """Wrap *df* in a StandardDataset using this module's column names."""
        # NOTE(review): per aif360 docs, StandardDataset drops rows containing
        # NA values, so outputs may be shorter than *df* — verify inputs.
        return StandardDataset(df, label_name='target',
                               favorable_classes=[1],
                               protected_attribute_names=['sensitive'],
                               privileged_classes=[[1]])

    @staticmethod
    def _sample_loss(pred, dataset):
        """Return float32 cross-entropy losses of *pred* scores vs. true labels."""
        prob = pred.scores.astype(np.float32)
        # Accept 1-D scores as well as a single (n, 1) column.
        if prob.ndim == 1:
            prob = prob[:, np.newaxis]
        # A single column presumably holds P(favorable label = 1); expand it
        # to [P(y=0), P(y=1)] so it can be indexed by the integer label.
        if prob.shape[1] == 1:
            prob = np.hstack((1 - prob, prob))
        loss = sample_celoss(prob, dataset.labels.astype(np.int64).ravel())
        return loss.astype(np.float32)

    def debias(self, x, train = True):
        """Fit the reduction on *x* (train) or apply the fitted one (eval).

        Args:
            x: DataFrame passed to ``StandardDataset``. It must contain a
                ``'target'`` label column and a ``'sensitive'`` protected
                attribute column; the favorable/privileged value is 1.
            train: If True, fit a new reduction, store it on
                ``self.debias_model`` and predict on *x*. If False, predict
                with the previously fitted model.

        Returns:
            Predicted labels as an int64 array. In eval mode, when
            ``cfg['setting']['metric_name'] == 'loss'``, the method instead
            returns the float32 output of ``sample_celoss`` (presumably
            per-sample losses).

        Raises:
            RuntimeError: If ``train`` is False and no model has been fitted.
            KeyError: If ``self.constraint`` is not ``"sp"`` or ``"eo"``
                (training only).
        """
        if not train and self.debias_model is None:
            raise RuntimeError(
                "Reduction.debias(train=False) called before the model was "
                "trained; call debias(x, train=True) first.")

        dataset = self._to_dataset(x)

        # Reseed numpy's global RNG from Python's `random`, presumably so that
        # seeding `random` upstream also makes the reduction reproducible.
        np.random.seed(random.randint(0, 1000))

        if train:
            constraint = self._CONSTRAINTS[self.constraint]
            print("## Epsilon = " + str(self.params))
            print("Starting Reduction...")
            # drop_prot_attr=False keeps 'sensitive' as a model input feature.
            exp_grad_red = ExponentiatedGradientReduction(
                estimator=sk_model('linear'),
                constraints=constraint,
                drop_prot_attr=False,
                eps=self.params)
            exp_grad_red.fit(dataset)
            self.debias_model = exp_grad_red
            pred = exp_grad_red.predict(dataset)
        else:
            pred = self.debias_model.predict(dataset)
            setting = cfg.get('setting') or {}
            if setting.get('metric_name') == 'loss':
                return self._sample_loss(pred, dataset)

        return pred.labels.astype(np.int64)
