from .utils import sk_model, get_idx_wo_protected, sample_celoss

from aif360.datasets import StandardDataset
from aif360.algorithms.postprocessing.calibrated_eq_odds_postprocessing import CalibratedEqOddsPostprocessing
from aif360.algorithms.postprocessing.eq_odds_postprocessing import EqOddsPostprocessing
import numpy as np
from config import cfg
from sklearn.model_selection import train_test_split
import random


class EqOdds:
    """Equalized-odds post-processing on top of a base classifier.

    For binary classification problems with a binary sensitive attribute.
    The input data must provide a ``'target'`` label column (favorable
    class ``1``) and a ``'sensitive'`` protected-attribute column
    (privileged class ``1``).

    Attributes are set by :meth:`debias` with ``train=True`` and reused by
    later calls with ``train=False``.
    """

    def __init__(self):
        # AIF360 expects each group as a list of {attribute_name: value} dicts.
        self._privileged_groups = [{'sensitive': 1}]
        self._unprivileged_groups = [{'sensitive': 0}]
        # Fitted AIF360 post-processor; None until debias(train=True) ran.
        self.debias_model = None
        # Fitted base classifier; None until debias(train=True) ran.
        self.base_model = None
        # When True, CalibratedEqOddsPostprocessing is used instead of
        # EqOddsPostprocessing.
        self.calibrated = False

    @staticmethod
    def _to_standard_dataset(frame):
        """Wrap ``frame`` in an AIF360 ``StandardDataset`` with the fixed schema."""
        return StandardDataset(frame, label_name='target',
                               favorable_classes=[1],
                               protected_attribute_names=['sensitive'],
                               privileged_classes=[[1]])

    def _base_predictions(self, dataset, class_thresh):
        """Return a copy of ``dataset`` carrying the base model's predictions.

        The copy's ``scores`` hold the predicted probability of the favorable
        class and its ``labels`` hold those scores thresholded at
        ``class_thresh``. The protected attribute is excluded from the
        features passed to the base model.
        """
        idx_wo_protected = get_idx_wo_protected(dataset.feature_names, ['sensitive'])
        features = dataset.features[:, idx_wo_protected]
        fav_idx = np.where(self.base_model.classes_ == dataset.favorable_label)[0][0]
        fav_prob = self.base_model.predict_proba(features)[:, fav_idx]

        dataset_pred = dataset.copy(deepcopy=True)
        dataset_pred.scores = fav_prob.reshape(-1, 1)
        dataset_pred.labels = np.where(
            dataset_pred.scores >= class_thresh,
            dataset.favorable_label,
            dataset.unfavorable_label,
        ).astype(dataset.labels.dtype)
        return dataset_pred

    def debias(self, x, train: bool = True) -> np.ndarray:
        """Fit (``train=True``) or apply (``train=False``) the debiasing pipeline.

        Args:
            x: Tabular data with ``'target'`` and ``'sensitive'`` columns. It is
                handed to ``train_test_split`` and ``StandardDataset``
                (presumably a pandas DataFrame -- confirm against callers).
            train: If True, fit the base classifier on a random 70% split, fit
                the post-processor on a random 15% validation split, store both
                on the instance, and return debiased predictions for all of
                ``x``. If False, apply the stored models to ``x``.

        Returns:
            The debiased labels as an ``int64`` array. With ``train=False`` and
            ``cfg['setting']['metric_name'] == 'loss'``, the ``float32`` result
            of ``sample_celoss`` on the post-processed scores is returned
            instead.

        Raises:
            RuntimeError: If called with ``train=False`` before a successful
                ``train=True`` call.
        """
        constraint = 'weighted'
        class_thresh = 0.5

        if not train:
            if self.base_model is None or self.debias_model is None:
                raise RuntimeError(
                    "EqOdds.debias(train=False) requires a prior call with train=True")

            dataset_test = self._to_standard_dataset(x)
            dataset_test_pred = self._base_predictions(dataset_test, class_thresh)
            dataset_transf_pred = self.debias_model.predict(dataset_test_pred)

            # loss ablation, calculate loss in model
            if cfg.get('setting') and cfg.get('setting').get('metric_name') == 'loss':
                prob = dataset_transf_pred.scores.astype(np.float32)
                if prob.shape[1] == 1:
                    # Expand P(favorable) into [P(unfavorable), P(favorable)].
                    prob = np.hstack((1 - prob, prob))
                loss = sample_celoss(prob, dataset_test.labels.astype(np.int64).ravel())
                return loss.astype(np.float32)

            return dataset_transf_pred.labels.astype(np.int64)

        # 70% train / 15% validation; the remaining 15% is left unused.
        frame_train, frame_vt = train_test_split(
            x, test_size=0.3, random_state=random.randint(0, 1000))
        frame_valid, _ = train_test_split(
            frame_vt, test_size=0.5, random_state=random.randint(0, 1000))

        dataset_full = self._to_standard_dataset(x)
        dataset_train = self._to_standard_dataset(frame_train)
        dataset_valid = self._to_standard_dataset(frame_valid)

        # Train the base classifier without the protected attribute.
        idx_wo_protected = get_idx_wo_protected(dataset_train.feature_names, ['sensitive'])
        X_train = dataset_train.features[:, idx_wo_protected]
        y_train = dataset_train.labels.ravel()
        model = sk_model('linear')
        model.fit(X_train, y_train, sample_weight=dataset_train.instance_weights)
        self.base_model = model

        dataset_valid_pred = self._base_predictions(dataset_valid, class_thresh)

        # Build the post-processor.
        if self.calibrated:
            cpp = CalibratedEqOddsPostprocessing(privileged_groups=self._privileged_groups,
                                                 unprivileged_groups=self._unprivileged_groups,
                                                 seed=random.randint(0, 1000),
                                                 cost_constraint=constraint)
            print("Starting Calibrated EqualizedOdds...")
        else:
            cpp = EqOddsPostprocessing(privileged_groups=self._privileged_groups,
                                       unprivileged_groups=self._unprivileged_groups,
                                       seed=random.randint(0, 1000))
            print("Starting EqualizedOdds...")

        # Fit on (ground truth, base prediction) pairs from the validation split.
        cpp = cpp.fit(dataset_valid, dataset_valid_pred)
        self.debias_model = cpp

        # Post-process the base model's predictions for the full data. Passing
        # the ground-truth dataset here would transform the true labels rather
        # than the predictions.
        dataset_full_pred = self._base_predictions(dataset_full, class_thresh)
        dataset_transf_pred = cpp.predict(dataset_full_pred)

        return dataset_transf_pred.labels.astype(np.int64)


