from copy import deepcopy

import numpy as np

from analysis.pate import calculate_fairness_gaps

class PATEQuery():
    """Plain PATE query mechanism: noisy threshold check followed by GNMax labelling.

    Args:
        num_classes: Number of output classes (width of the vote matrix).
        threshold: A query is answered only if its noisy top vote count
            exceeds this value.
        sigma_threshold: Std. dev. of the Gaussian noise added to the top
            vote count in the threshold mechanism.
        sigma_gnmax: Std. dev. of the Gaussian noise added to every vote
            before taking the argmax; must be > 0.
    """

    def __init__(self, num_classes, threshold, sigma_threshold, sigma_gnmax):
        self.threshold = threshold
        self.sigma_threshold = sigma_threshold
        self.sigma_gnmax = sigma_gnmax
        self.num_classes = num_classes

    def create_student_training_set(self, queryset_features, queryset_votes):
        """Label the queryset with the threshold + GNMax mechanisms.

        Args:
            queryset_features: Query inputs; only ``shape[0]`` (the number of
                queries) is used.
            queryset_votes: Teacher vote counts, one row per query and one
                column per class (``(num_queries, num_classes)``).

        Returns:
            ``(answered, preds)``: a boolean mask over all queries marking the
            ones that passed the threshold mechanism, and the noisy-argmax
            labels of the answered queries only.
        """
        num_samples = queryset_features.shape[0]

        # Threshold mechanism: answer a query only if its noisy top vote
        # count clears the threshold.
        noise_threshold = np.random.normal(0., self.sigma_threshold, num_samples)
        vote_counts = queryset_votes.max(axis=1)
        answered = (vote_counts + noise_threshold) > self.threshold

        # GNMax mechanism: label = argmax of the Gaussian-perturbed votes.
        # (Drawn for every query so the RNG stream does not depend on `answered`.)
        assert self.sigma_gnmax > 0
        noise_gnmax = np.random.normal(0., self.sigma_gnmax, (num_samples, self.num_classes))
        preds = np.argmax(queryset_votes + noise_gnmax, axis=1)

        return answered, preds[answered]

class FairPATEQuery():
    """PATE query mechanism with an online fairness filter on answered queries.

    The threshold + GNMax aggregation in ``create_student_training_set`` is
    followed by ``apply_fairness_constraint``. That filter may reject further
    answers whose tentative fairness gap (as returned by
    ``calculate_fairness_gaps``) is not below ``max_fairness_violation``.

    The per-group counters live on the instance and are updated by every
    call, so successive calls share the same running fairness state.
    """

    def __init__(self, sensitive_group_list, min_group_count, max_fairness_violation,
                 num_classes, threshold, sigma_threshold, sigma_gnmax,
                 fairness_metric="DemParity", dataset=None):
        self.threshold = threshold
        self.sigma_threshold = sigma_threshold
        self.sigma_gnmax = sigma_gnmax
        self.dataset = dataset
        self.num_classes = num_classes
        self.fairness_metric = fairness_metric
        self.sensitive_group_list = sensitive_group_list
        self.min_group_count = min_group_count
        self.max_fairness_violation = max_fairness_violation
        num_groups = len(sensitive_group_list)
        # Answered queries per sensitive group.
        self.sensitive_group_count = np.zeros(shape=(num_groups,))
        # Answered queries per (predicted class, sensitive group).
        self.per_class_pos_classified_group_count = np.zeros(shape=(num_classes, num_groups))
        # Gaps from calculate_fairness_gaps for the most recently answered
        # query; None until one has been answered.
        self.fairness_disparity_gaps = None

    def apply_fairness_constraint(self, preds, answered, sensitive, for_z=None, but_z=None):
        """Pass-through filter that blocks answers exceeding the fairness budget.

        Queries are processed in order:

        * A query that is already unanswered stays unanswered.
        * An answered query with predicted class ``c`` and group ``z`` is kept
          if fewer than ``min_group_count`` answered ``(c, z)`` queries have
          been recorded so far. This is treated as too few points to estimate
          fairness.
        * It is also kept if the tentative gap for ``(c, z)`` is below
          ``max_fairness_violation``. This is the gap that answering the query
          would produce.
        * Otherwise it is blocked.

        Has side effects: every kept query increments ``sensitive_group_count``
        and ``per_class_pos_classified_group_count``, and it replaces
        ``fairness_disparity_gaps``.

        Args:
            preds: Predicted class index per query.
            answered: Per-query answered flags; not modified.
            sensitive: Sensitive-group index per query (numpy/torch scalars
                or plain Python numbers).
            for_z, but_z: Forwarded unchanged to ``calculate_fairness_gaps``.

        Returns:
            A copy of ``answered`` with blocked queries set to ``False``.
        """
        _answered = deepcopy(answered)

        for s_id, z in enumerate(sensitive):
            # Accept array/tensor scalars (``.item()``) as well as plain numbers.
            z = int(z.item()) if hasattr(z, "item") else int(z)
            pred = int(preds[s_id])
            was_answered = _answered[s_id]

            # Tentative counters, as if this query were answered.
            tentative_class_group_count = self.per_class_pos_classified_group_count.copy()
            tentative_class_group_count[pred, z] += 1
            tentative_group_count = self.sensitive_group_count.copy()
            tentative_group_count[z] += 1

            all_gaps = calculate_fairness_gaps(tentative_group_count, tentative_class_group_count, rule_over_classes="all",
                                               fairness_metric=self.fairness_metric, dataset=self.dataset, for_z=for_z, but_z=but_z)
            new_gaps = all_gaps[pred, :]

            if self.per_class_pos_classified_group_count[pred, z] < self.min_group_count:
                # Too few answered (pred, z) queries to estimate fairness; don't block.
                _answered[s_id] = was_answered
            # The zero-scale draw adds nothing to the gap, but it still consumes
            # numpy's global RNG; kept so seeded runs reproduce earlier results.
            elif new_gaps[z] + np.random.normal(0., 0.0) < self.max_fairness_violation:
                # Within budget; don't block.
                _answered[s_id] = was_answered
            else:
                # Over budget; block.
                _answered[s_id] = False

            # Only answered queries count toward the running fairness state.
            if _answered[s_id]:
                self.sensitive_group_count[z] += 1
                self.per_class_pos_classified_group_count[pred, z] += 1
                self.fairness_disparity_gaps = all_gaps

        return _answered

    def create_student_training_set(self, queryset_features, queryset_sensitives, queryset_votes, for_z=None, but_z=None):
        """Label the queryset with threshold + GNMax, then apply the fairness filter.

        Args:
            queryset_features: Query inputs; only ``shape[0]`` (the number of
                queries) is used.
            queryset_sensitives: Sensitive-group index per query.
            queryset_votes: Teacher vote counts, one row per query and one
                column per class (``(num_queries, num_classes)``).
            for_z, but_z: Forwarded to ``apply_fairness_constraint``.

        Returns:
            ``(answered, preds)``: a boolean mask over all queries, and the
            noisy-argmax labels of the answered queries only.
        """
        num_samples = queryset_features.shape[0]

        # Threshold mechanism: answer a query only if its noisy top vote
        # count clears the threshold.
        noise_threshold = np.random.normal(0., self.sigma_threshold, num_samples)
        vote_counts = queryset_votes.max(axis=1)
        answered = (vote_counts + noise_threshold) > self.threshold

        # GNMax mechanism: label = argmax of the Gaussian-perturbed votes.
        assert self.sigma_gnmax > 0
        noise_gnmax = np.random.normal(0., self.sigma_gnmax, (num_samples, self.num_classes))
        preds = np.argmax(queryset_votes + noise_gnmax, axis=1)

        # Fairness filter; updates the running counters as a side effect.
        answered = self.apply_fairness_constraint(preds, answered, queryset_sensitives, for_z=for_z, but_z=but_z)

        return answered, preds[answered]





class PATESPreProcessor():
    """Fairness filter over an already-labelled set (no PATE noise).

    It applies the same online fairness constraint as the fair PATE query
    mechanism. The input is a fixed vector of predictions rather than
    teacher votes. Counters live on the instance and persist across calls.
    """

    def __init__(self, sensitive_group_list, min_group_count, max_fairness_violation,
                 num_classes, fairness_metric="DemParity", dataset=None):
        self.num_classes = num_classes
        self.dataset = dataset

        self.max_fairness_violation = max_fairness_violation
        self.min_group_count = min_group_count
        self.sensitive_group_list = sensitive_group_list

        num_groups = len(sensitive_group_list)
        # Kept (answered) samples per sensitive group.
        self.sensitive_group_count = np.zeros(shape=(num_groups,))
        self.fairness_metric = fairness_metric
        # Kept samples per (predicted class, sensitive group).
        self.per_class_pos_classified_group_count = np.zeros(shape=(num_classes, num_groups))

        # Gaps from calculate_fairness_gaps for the most recently kept
        # sample; None until one has been kept.
        self.fairness_disparity_gaps = None

    def apply_fairness_constraint(self, preds, answered, sensitive, for_z=None, but_z=None):
        """Pass-through filter that blocks samples exceeding the fairness budget.

        Samples are processed in order:

        * A sample already marked unanswered stays unanswered.
        * A kept sample with predicted class ``c`` and group ``z`` stays kept
          if fewer than ``min_group_count`` kept ``(c, z)`` samples have been
          recorded so far. This is treated as too few points to estimate
          fairness.
        * It also stays kept if the tentative gap for ``(c, z)`` is below
          ``max_fairness_violation``. This is the gap that keeping the sample
          would produce.
        * Otherwise it is blocked.

        Has side effects: every kept sample increments ``sensitive_group_count``
        and ``per_class_pos_classified_group_count``, and it replaces
        ``fairness_disparity_gaps``.

        Args:
            preds: Predicted class index per sample.
            answered: Per-sample keep flags; not modified.
            sensitive: Sensitive-group index per sample (numpy/torch scalars
                or plain Python numbers).
            for_z, but_z: Forwarded unchanged to ``calculate_fairness_gaps``.

        Returns:
            A copy of ``answered`` with blocked samples set to ``False``.
        """
        _answered = deepcopy(answered)

        for s_id, z in enumerate(sensitive):
            # Accept array/tensor scalars (``.item()``) as well as plain numbers.
            z = int(z.item()) if hasattr(z, "item") else int(z)
            pred = int(preds[s_id])
            was_answered = _answered[s_id]

            # Tentative counters, as if this sample were kept.
            tentative_class_group_count = self.per_class_pos_classified_group_count.copy()
            tentative_class_group_count[pred, z] += 1
            tentative_group_count = self.sensitive_group_count.copy()
            tentative_group_count[z] += 1

            all_gaps = calculate_fairness_gaps(tentative_group_count, tentative_class_group_count, rule_over_classes="all",
                                               fairness_metric=self.fairness_metric, dataset=self.dataset, for_z=for_z, but_z=but_z)
            new_gaps = all_gaps[pred, :]

            if self.per_class_pos_classified_group_count[pred, z] < self.min_group_count:
                # Too few kept (pred, z) samples to estimate fairness; don't block.
                _answered[s_id] = was_answered
            # The zero-scale draw adds nothing to the gap, but it still consumes
            # numpy's global RNG; kept so seeded runs reproduce earlier results.
            elif new_gaps[z] + np.random.normal(0., 0.0) < self.max_fairness_violation:
                # Within budget; don't block.
                _answered[s_id] = was_answered
            else:
                # Over budget; block.
                _answered[s_id] = False

            # Only kept samples count toward the running fairness state.
            if _answered[s_id]:
                self.sensitive_group_count[z] += 1
                self.per_class_pos_classified_group_count[pred, z] += 1
                self.fairness_disparity_gaps = all_gaps

        return _answered

    def filter_student_training_set(self, preds, sensitives, for_z=None, but_z=None):
        """Return a boolean keep-mask over ``preds`` that respects the fairness budget.

        Every sample starts as kept; ``apply_fairness_constraint`` then blocks
        the ones that would exceed ``max_fairness_violation``.

        Args:
            preds: Predicted class index per sample (array-like with ``shape``).
            sensitives: Sensitive-group index per sample; same length as ``preds``.
            for_z, but_z: Forwarded to ``apply_fairness_constraint``.

        Returns:
            A boolean numpy array of length ``len(preds)``.
        """
        assert preds.shape[0] == sensitives.shape[0]
        initial_mask = np.ones(len(preds), dtype=bool)
        return self.apply_fairness_constraint(preds, initial_mask, sensitives,
                                              for_z=for_z, but_z=but_z)