from aif360.datasets import StandardDataset
from sklearn.model_selection import train_test_split
from aif360.algorithms.postprocessing.reject_option_classification import RejectOptionClassification
import numpy as np
from .utils import sk_model, sample_celoss
from config import cfg
import random

class Rejection:
    """Reject Option Classification (ROC) post-processing debiaser.

    A base classifier (``sk_model('linear')``) is fitted on a training split,
    then AIF360's ``RejectOptionClassification`` is tuned on a validation
    split using the base classifier's positive-class scores. Both fitted
    objects are kept on the instance so later calls with ``train=False`` can
    reuse them.

    Input frames must contain a ``'target'`` label column (favorable value 1)
    and a ``'sensitive'`` protected-attribute column (privileged value 1).
    """

    # Short constraint keys -> metric names passed to RejectOptionClassification.
    _METRIC_NAMES = {"sp": "Statistical parity difference", "eo": "Average odds difference"}

    def __init__(self, constraint, params):
        """Store the fairness constraint and its bound.

        Args:
            constraint: key selecting the fairness metric, ``"sp"`` or ``"eo"``.
            params: numeric bound; ``[-params, params]`` is used as the allowed
                range of the fairness metric when fitting ROC.
        """
        cname = 'sensitive'
        self._privileged_groups = [{cname: 1}]
        self._unprivileged_groups = [{cname: 0}]

        self.constraint = constraint
        self.params = params
        self.debias_model = None  # fitted RejectOptionClassification
        self.base_model = None    # fitted base classifier

    @staticmethod
    def _to_standard_dataset(df):
        """Wrap a frame as an AIF360 StandardDataset with this project's column conventions."""
        return StandardDataset(df, label_name='target',
                               favorable_classes=[1],
                               protected_attribute_names=['sensitive'],
                               privileged_classes=[[1]])

    @staticmethod
    def _with_scores(dataset, model):
        """Return a deep copy of ``dataset`` whose scores are ``model``'s favorable-class probabilities."""
        # Column of predict_proba that corresponds to the favorable label.
        pos_ind = np.where(model.classes_ == dataset.favorable_label)[0][0]
        scored = dataset.copy(deepcopy=True)
        scored.scores = model.predict_proba(dataset.features)[:, pos_ind].reshape(-1, 1)
        return scored

    def debias(self, x, train = True):
        """Fit (optionally) and apply the ROC post-processor.

        Args:
            x: tabular data accepted by ``StandardDataset`` and
                ``train_test_split`` (presumably a pandas DataFrame) holding
                the ``'target'`` and ``'sensitive'`` columns.
            train: when true, fit the base model and ROC on splits of ``x``
                and return debiased predictions for all of ``x``. When false,
                reuse the previously fitted models to predict on ``x``.

        Returns:
            ``np.int64`` array of debiased labels as produced by ROC, or, when
            ``train`` is false and ``cfg['setting']['metric_name'] == 'loss'``,
            the ``np.float32`` result of ``sample_celoss``.

        Raises:
            KeyError: if ``self.constraint`` is not a supported key.
            RuntimeError: if ``train`` is false and no model has been fitted.
        """
        try:
            metric_name = self._METRIC_NAMES[self.constraint]
        except KeyError:
            raise KeyError(
                "Unknown constraint %r; expected one of %s"
                % (self.constraint, sorted(self._METRIC_NAMES))
            ) from None

        if train:
            return self._fit_and_predict(x, metric_name)
        return self._predict(x)

    def _fit_and_predict(self, x, metric_name):
        """Fit base model + ROC on splits of ``x`` and return debiased labels for all of ``x``."""
        hparams = self.params
        print("## Epsilon =" + str(hparams))

        train_df, holdout_df = train_test_split(
            x, test_size=0.3, random_state=random.randint(0, 1000))
        # Only the validation half is used to tune ROC. The split is kept so the
        # validation partition (and consumption of the `random` stream) is the
        # same as before; the other half is intentionally discarded.
        valid_df, _ = train_test_split(
            holdout_df, test_size=0.5, random_state=random.randint(0, 1000))

        dataset_full = self._to_standard_dataset(x)
        dataset_train = self._to_standard_dataset(train_df)
        dataset_valid = self._to_standard_dataset(valid_df)

        model = sk_model('linear')
        model.fit(dataset_train.features, dataset_train.labels.ravel())
        self.base_model = model

        dataset_valid_pred = self._with_scores(dataset_valid, model)

        roc = RejectOptionClassification(unprivileged_groups=self._unprivileged_groups,
                                         privileged_groups=self._privileged_groups,
                                         low_class_thresh=0.01, high_class_thresh=0.99,
                                         num_class_thresh=100, num_ROC_margin=50,
                                         metric_name=metric_name,
                                         metric_ub=hparams, metric_lb=-hparams)
        roc = roc.fit(dataset_valid, dataset_valid_pred)
        self.debias_model = roc

        print(
            "Optimal classification threshold (with fairness constraints) = %.4f" % roc.classification_threshold)
        print("Optimal ROC margin = %.4f" % roc.ROC_margin)

        # ROC.predict thresholds on `dataset.scores`. A freshly built AIF360
        # dataset defaults its scores to a copy of the true labels, so the
        # full dataset must be scored by the base model first; otherwise the
        # "predictions" would be derived from the ground truth.
        dataset_transf_pred = roc.predict(self._with_scores(dataset_full, model))
        return dataset_transf_pred.labels.astype(np.int64)

    def _predict(self, x):
        """Apply the already-fitted base model and ROC to ``x``."""
        if self.base_model is None or self.debias_model is None:
            raise RuntimeError(
                "Rejection.debias(train=False) called before the model was "
                "fitted; call debias(x, train=True) first.")

        dataset_test = self._to_standard_dataset(x)
        y_test = dataset_test.labels.ravel()

        dataset_scored = self._with_scores(dataset_test, self.base_model)
        dataset_transf_pred = self.debias_model.predict(dataset_scored)

        setting = cfg.get('setting')
        if setting and setting.get('metric_name') == 'loss':
            prob = dataset_transf_pred.scores.astype(np.float32)
            if prob.shape[1] == 1:
                # Expand a single positive-class column to [P(neg), P(pos)].
                prob = np.hstack((1 - prob, prob))
            loss = sample_celoss(prob, y_test.astype(np.int64))
            return loss.astype(np.float32)
        return dataset_transf_pred.labels.astype(np.int64)

