import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from sklearn.ensemble import HistGradientBoostingClassifier

from data_transforms import EmptyTransform
from existing_methods import cnd


def nearest_neighbors(
    from_pts: np.ndarray, to_pts: np.ndarray, k: int = None, remove_first: bool = False
) -> np.ndarray:
    """Rank the rows of ``to_pts`` by distance from each row of ``from_pts``.

    Pairwise distances are computed with ``scipy.spatial.distance.cdist``
    using its default metric, and each row is argsorted in ascending order.

    Args:
        from_pts: N1 x D array of query points.
        to_pts: N2 x D array of reference points.
        k: Maximum number of neighbors to keep per query point. ``None``
            keeps every remaining neighbor.
        remove_first: If True, drop the closest neighbor of every query
            point before applying ``k``. Meant for the case where each query
            point is also present in ``to_pts`` (and is therefore its own
            nearest neighbor); the column is dropped unconditionally.

    Returns:
        np.ndarray: N1 x m array of row indices into ``to_pts``, nearest
        first, where m is at most ``k`` when ``k`` is given.
    """
    ranking = np.argsort(cdist(from_pts, to_pts), axis=1)
    # Drop the self-match first so that ``k`` counts only "real" neighbors.
    ranking = ranking[:, 1:] if remove_first else ranking
    return ranking if k is None else ranking[:, :k]


class Collective_Strategy:
    """Base class for collective action strategies.

    ``act`` splits the data into the collective, the majority and the
    minority, then delegates the actual relabelling to ``get_new_labels``,
    which concrete strategies must override.

    Subclasses are expected to set ``self.name``; ``__str__`` reads it and
    this base class does not define it.
    """

    def __init__(
        self,
        data_transform: "EmptyTransform | None" = None,
        use_partial: bool = False,
    ):
        """Store the feature transform and the partial-visibility flag.

        Args:
            data_transform: Object whose ``transform(features, in_group)``
                method is applied to the features in ``act``. Defaults to a
                fresh ``EmptyTransform()``.
            use_partial: If True, ``act`` only exposes rows that belong to
                the collective mask to the strategy (see ``act``).
        """
        if data_transform is None:
            data_transform = EmptyTransform()
        self.transform = data_transform
        self.use_partial = use_partial

    def act(
        self,
        features: pd.DataFrame,
        labels: np.ndarray,
        in_group: np.ndarray,
        in_collective: np.ndarray,
    ):
        """Split the data and return the collective's new labels.

        Args:
            features: Feature table; converted to a float array.
            labels: Per-row labels, indexed with the masks below.
            in_group: Per-row mask of minority-group membership
                (combined with ``&`` / ``~``, so expected to be boolean).
            in_collective: Per-row mask of collective membership
                (expected to be boolean, same length as ``in_group``).

        Returns:
            Whatever ``get_new_labels`` returns for the rows selected by
            ``in_collective & in_group``.
        """
        features = features.values.astype(float)
        if self.transform is not None:
            features = self.transform.transform(features, in_group)

        # The acting collective is the same in both visibility modes.
        collective_inds = in_collective & in_group
        collective_features = features[collective_inds]
        collective_labels = labels[collective_inds]

        if self.use_partial:
            # Partial view: only rows flagged in ``in_collective`` are seen,
            # and the minority is just the collective itself.
            majority_inds = in_collective & (~in_group)
            majority_features = features[majority_inds]
            majority_labels = labels[majority_inds]

            minority_features = collective_features.copy()
            minority_labels = collective_labels.copy()
        else:
            # Full view: every out-of-group row is majority, every in-group
            # row is minority.
            majority_features = features[~in_group]
            majority_labels = labels[~in_group]

            minority_features = features[in_group]
            minority_labels = labels[in_group]

        return self.get_new_labels(
            collective_features,
            collective_labels,
            majority_features,
            majority_labels,
            minority_features,
            minority_labels,
        )

    def __str__(self) -> str:
        """Return ``"<name>-<transform>"``, plus ``"-partial"`` if enabled."""
        to_return = f"{self.name}-{str(self.transform)}"
        if self.use_partial:
            to_return += "-partial"
        return to_return

    def get_new_labels(
        self,
        collective_features,
        collective_labels,
        majority_features,
        majority_labels,
        minority_features,
        minority_labels,
    ) -> np.ndarray:
        """Compute the collective's new labels; must be overridden.

        Declared with ``self`` so that calling it through ``act`` on a class
        that did not override it raises ``NotImplementedError`` rather than
        a ``TypeError`` about the number of arguments.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


# class Nearest_Neighbour(Collective_Strategy):
#     def __init__(
#         self,
#         k: int = 1,
#         only_negatives=False,
#         threshold_pos2neg: float = 0.5,
#         threshold_neg2pos: float = 0.5,
#         data_transform: EmptyTransform = None,
#         use_partial: bool = False,
#     ):
#         super().__init__(data_transform=data_transform, use_partial=use_partial)
#         self.name = "nn"
#         if only_negatives:
#             self.name = "nnNeg"
#         self.k = k
#         self.threshold_pos2neg = threshold_pos2neg
#         self.threshold_neg2pos = threshold_neg2pos
#         self.only_negatives = only_negatives

#     def get_new_labels(
#         self,
#         collective_features,
#         collective_labels,
#         majority_features,
#         majority_labels,
#         minority_features,
#         minority_labels,
#     ):
#         # dist_from = collective_features
#         # dist_to = majority_features
#         # distances = cdist(dist_from, dist_to)

#         # closest_id = np.argsort(distances, axis=1)[:, : self.k]
#         closest_id = nearest_neighbors(
#             from_pts=collective_features,
#             to_pts=majority_features,
#             k=self.k,
#             remove_first=False,
#         )
#         closest_lbls = majority_labels[closest_id]
#         closest_lbls = closest_lbls.mean(axis=1)
#         new_lables = collective_labels.copy()
#         if self.only_negatives:
#             # new_lables[new_lables == 0] = closest_lbls[new_lables == 0].round()
#             neg2pos = np.logical_and(
#                 collective_labels == 0,
#                 closest_lbls >= self.threshold_neg2pos,
#             )
#             new_lables[neg2pos] = 1
#         else:
#             # new_lables = closest_lbls.mean(axis=1).round()
#             pos2neg = np.logical_and(
#                 collective_labels == 1,
#                 closest_lbls <= (1 - self.threshold_pos2neg),
#             )

#             neg2pos = np.logical_and(
#                 collective_labels == 0,
#                 closest_lbls >= self.threshold_neg2pos,
#             )

#             new_lables[pos2neg] = 0
#             new_lables[neg2pos] = 1

#         return new_lables


class Ranked_Labels(Collective_Strategy):
    """Flip the negatives whose nearest majority neighbors are most positive.

    Each negatively-labelled collective member is scored by the sum of the
    labels of its ``k`` nearest majority points; the highest-scoring members
    are relabelled to 1, up to the flip budget.
    """

    def __init__(
        self,
        k: int = 1,
        num_to_flip: int = None,
        ratio_to_flip: float = None,
        data_transform: EmptyTransform = None,
        use_partial: bool = False,
    ):
        """Configure the strategy.

        Args:
            k: Number of majority neighbors used for scoring.
            num_to_flip: Absolute number of labels to flip.
            ratio_to_flip: Fraction of the collective to flip.
            data_transform: Forwarded to ``Collective_Strategy``.
            use_partial: Forwarded to ``Collective_Strategy``.

        Raises:
            ValueError: If neither or both of ``num_to_flip`` and
                ``ratio_to_flip`` are given.
        """
        has_count = num_to_flip is not None
        has_ratio = ratio_to_flip is not None
        if not (has_count or has_ratio):
            raise ValueError("Must provide either num_to_flip or ratio_to_flip")
        if has_count and has_ratio:
            raise ValueError("Cannot provide both num_to_flip and ratio_to_flip")
        super().__init__(data_transform=data_transform, use_partial=use_partial)
        self.name = "ranked_labels"
        self.k = k
        self.num_to_flip = num_to_flip
        self.ratio_to_flip = ratio_to_flip

    def set_k(self, k: int):
        """Replace the neighbor count ``k``."""
        self.k = k

    def get_new_labels(
        self,
        collective_features,
        collective_labels,
        majority_features,
        majority_labels,
        minority_features,
        minority_labels,
    ):
        """Return a copy of ``collective_labels`` with the chosen 0s set to 1.

        The minority arguments are accepted for interface compatibility and
        are not used.
        """
        # Flip budget: explicit count, or a fraction of the collective size.
        budget = (
            self.num_to_flip
            if self.num_to_flip is not None
            else int(self.ratio_to_flip * collective_labels.shape[0])
        )

        # Only currently-negative members can be flipped.
        negatives = np.flatnonzero(collective_labels == 0)

        neighbor_ids = nearest_neighbors(
            from_pts=collective_features[negatives],
            to_pts=majority_features,
            k=self.k,
            remove_first=False,
        )
        # Score = sum of the neighboring majority labels.
        neighbor_score = majority_labels[neighbor_ids].sum(axis=1)

        # Highest score first, truncated to the budget.
        best_first = np.argsort(neighbor_score)[::-1]
        chosen = negatives[best_first[:budget]]

        updated = collective_labels.copy()
        updated[chosen] = 1
        return updated


# class RandomFlipping(Collective_Strategy):
#     # Ablation for nearest neighbor, by randomly flipping labels
#     def __init__(
#         self,
#         only_negatives=False,
#     ):
#         super().__init__()
#         self.name = "RandomFlip"
#         if only_negatives:
#             self.name = "RandomFlipNeg"
#         self.only_negatives = only_negatives

#     def get_new_labels(
#         self,
#         collective_features,
#         collective_labels,
#         majority_features,
#         majority_labels,
#         minority_features,
#         minority_labels,
#     ):
#         if self.only_negatives:
#             new_lables = collective_labels.copy()
#             new_lables[collective_labels == 0] = self.randomize_labels(
#                 new_lables[collective_labels == 0]
#             )
#         else:
#             new_lables = self.randomize_labels(collective_labels)
#         return new_lables

#     def randomize_labels(self, cur_labels):
#         return np.random.choice([0, 1], size=cur_labels.shape[0])


class Ranked_Distance(Collective_Strategy):
    """Flip the negatives that lie closest to positively-labelled majority points.

    Each negatively-labelled collective member is scored by the mean of its
    ``k`` smallest distances to majority points with label 1; the members
    with the smallest score are relabelled to 1, up to the flip budget.
    """

    def __init__(
        self,
        num_to_flip: int = None,
        ratio_to_flip: float = None,
        k: int = 1,
        data_transform: EmptyTransform = None,
        use_partial: bool = False,
    ):
        """Configure the strategy.

        Args:
            num_to_flip: Absolute number of labels to flip.
            ratio_to_flip: Fraction of the collective to flip.
            k: Number of nearest positive majority points averaged per member.
            data_transform: Forwarded to ``Collective_Strategy``.
            use_partial: Forwarded to ``Collective_Strategy``.

        Raises:
            ValueError: If neither or both of ``num_to_flip`` and
                ``ratio_to_flip`` are given.
        """
        has_count = num_to_flip is not None
        has_ratio = ratio_to_flip is not None
        if not (has_count or has_ratio):
            raise ValueError("Must provide either num_to_flip or ratio_to_flip")
        if has_count and has_ratio:
            raise ValueError("Cannot provide both num_to_flip and ratio_to_flip")
        super().__init__(data_transform=data_transform, use_partial=use_partial)
        self.name = "ranked_distance"
        self.k = k
        self.num_to_flip = num_to_flip
        self.ratio_to_flip = ratio_to_flip

    def set_k(self, k: int):
        """Replace the neighbor count ``k``."""
        self.k = k

    def get_new_labels(
        self,
        collective_features,
        collective_labels,
        majority_features,
        majority_labels,
        minority_features,
        minority_labels,
    ):
        """Return a copy of ``collective_labels`` with the chosen 0s set to 1.

        The minority arguments are accepted for interface compatibility and
        are not used.
        """
        # Flip budget: explicit count, or a fraction of the collective size.
        budget = (
            self.num_to_flip
            if self.num_to_flip is not None
            else int(self.ratio_to_flip * collective_labels.shape[0])
        )

        is_negative = collective_labels == 0
        negative_pts = collective_features[is_negative]
        positive_majority_pts = majority_features[majority_labels == 1]

        pairwise = cdist(negative_pts, positive_majority_pts)

        # Mean over the k smallest distances of every negative member.
        k_smallest = np.sort(pairwise, axis=1)[:, : self.k]
        mean_distance = np.mean(k_smallest, axis=1)

        # Closest members first, truncated to the budget, mapped back to
        # positions inside the collective.
        nearest_first = np.argsort(mean_distance)[:budget]
        ids_to_flip = np.nonzero(is_negative)[0][nearest_first]

        updated = collective_labels.copy()
        updated[ids_to_flip] = 1
        return updated


class Random_Choice(Collective_Strategy):
    """Flip randomly chosen negative labels of the collective to positive.

    Per the original author's note this is like EqDP, but picks random
    candidates to flip instead of using k-NN.
    """

    def __init__(self, num_to_flip: int = None):
        """Configure the strategy.

        Args:
            num_to_flip: Number of labels to flip. If ``None``, the number is
                derived in ``get_new_labels`` from the gap between the
                majority and minority positive rates.
        """
        super().__init__()
        self.name = "Random"
        self.num_to_flip = num_to_flip

    def get_new_labels(
        self,
        collective_features,
        collective_labels,
        majority_features,
        majority_labels,
        minority_features,
        minority_labels,
    ):
        """Return a copy of ``collective_labels`` with random 0s set to 1.

        When ``num_to_flip`` was not given, the count is
        ``int((mean(majority_labels) - mean(minority_labels)) * len(minority_labels))``.
        The feature arguments are not used. Sampling uses NumPy's global
        random state.
        """
        if self.num_to_flip is None:
            minority_pos_chance = np.mean(minority_labels)
            group_size = minority_labels.shape[0]
            majority_pos_chance = np.mean(majority_labels)

            num_to_flip = int((majority_pos_chance - minority_pos_chance) * group_size)
        else:
            num_to_flip = self.num_to_flip

        # A negative count (minority already has the higher positive rate)
        # must mean "flip nothing". Used directly as a slice bound it would
        # instead select all but the last |n| candidates.
        num_to_flip = max(num_to_flip, 0)

        flip_candidates = np.flatnonzero(collective_labels == 0)
        sampled_indices = np.random.permutation(flip_candidates)[:num_to_flip]

        new_lbls = collective_labels.copy()
        new_lbls[sampled_indices] = 1
        return new_lbls


# class KDP(Collective_Strategy):
#     """
#     From https://dl.acm.org/doi/abs/10.1145/2020408.2020488
#     "k-NN as an implementation of situation testing
#     for discrimination discovery and prevention"
#     """

#     def __init__(
#         self,
#         k: int,
#         threshold: float = None,
#         use_quantiles: bool = False,
#         use_relative=False,
#         data_transform: EmptyTransform = None,
#     ):
#         super().__init__(data_transform=data_transform)
#         self.name = "kdp"
#         self.k = k
#         self.threshold = threshold
#         self.use_quantiles = use_quantiles
#         self.use_relative = use_relative

#     def get_new_labels(
#         self,
#         collective_features,
#         collective_labels,
#         majority_features,
#         majority_labels,
#         minority_features,
#         minority_labels,
#     ):
#         closest_minority = nearest_neighbors(
#             collective_features, minority_features, self.k, remove_first=True
#         )
#         closest_majority = nearest_neighbors(
#             collective_features, majority_features, self.k, remove_first=False
#         )

#         rate_majority = (majority_labels[closest_majority] == 0).mean(axis=1)
#         rate_minority = (minority_labels[closest_minority] == 0).mean(axis=1)
#         diffs = rate_minority - rate_majority
#         if self.use_relative:
#             x0 = diffs.min()
#             x1 = diffs.max()
#             cur_threshold = ((1 - self.threshold) * x0) + (self.threshold * x1)
#         elif self.use_quantiles:
#             cur_threshold = np.quantile(diffs, self.threshold)
#         else:
#             cur_threshold = self.threshold
#         to_flip = diffs > cur_threshold
#         to_flip = np.logical_and(to_flip, collective_labels == 0)

#         new_lables = collective_labels.copy()
#         new_lables[to_flip] = 1
#         return new_lables


class Ranked_KDP(Collective_Strategy):
    """
    Ranked variant of k-NN situation testing.

    From https://dl.acm.org/doi/abs/10.1145/2020408.2020488
    "k-NN as an implementation of situation testing
    for discrimination discovery and prevention"

    Each negatively-labelled collective member is scored by the difference
    between the share of negative labels among its ``k`` nearest minority
    neighbors and among its ``k`` nearest majority neighbors; the members
    with the largest difference are relabelled to 1, up to the flip budget.
    """

    def __init__(
        self,
        k: int = 1,
        num_to_flip: int = None,
        ratio_to_flip: float = None,
        data_transform: EmptyTransform = None,
        use_partial: bool = False,
    ):
        """Configure the strategy.

        Args:
            k: Number of neighbors taken from each group.
            num_to_flip: Absolute number of labels to flip.
            ratio_to_flip: Fraction of the collective to flip.
            data_transform: Forwarded to ``Collective_Strategy``.
            use_partial: Forwarded to ``Collective_Strategy``.

        Raises:
            ValueError: If neither or both of ``num_to_flip`` and
                ``ratio_to_flip`` are given.
        """
        has_count = num_to_flip is not None
        has_ratio = ratio_to_flip is not None
        if not (has_count or has_ratio):
            raise ValueError("Must provide either num_to_flip or ratio_to_flip")
        if has_count and has_ratio:
            raise ValueError("Cannot provide both num_to_flip and ratio_to_flip")
        super().__init__(data_transform=data_transform, use_partial=use_partial)
        self.name = "kdp"
        self.k = k
        self.num_to_flip = num_to_flip
        self.ratio_to_flip = ratio_to_flip

    def set_k(self, k: int):
        """Replace the neighbor count ``k``."""
        self.k = k

    def get_new_labels(
        self,
        collective_features,
        collective_labels,
        majority_features,
        majority_labels,
        minority_features,
        minority_labels,
    ):
        """Return a copy of ``collective_labels`` with the chosen 0s set to 1."""
        # Flip budget: explicit count, or a fraction of the collective size.
        budget = (
            self.num_to_flip
            if self.num_to_flip is not None
            else int(self.ratio_to_flip * len(collective_labels))
        )

        # Only currently-negative members can be flipped.
        negatives = np.flatnonzero(collective_labels == 0)
        negative_pts = collective_features[negatives]

        # The closest minority point is dropped: it is presumed to be the
        # query point itself, since the collective is drawn from the minority.
        minority_nbrs = nearest_neighbors(
            negative_pts,
            minority_features,
            self.k,
            remove_first=True,
        )
        majority_nbrs = nearest_neighbors(
            negative_pts,
            majority_features,
            self.k,
            remove_first=False,
        )

        # Share of negative labels in each neighborhood.
        majority_neg_rate = (majority_labels[majority_nbrs] == 0).mean(axis=1)
        minority_neg_rate = (minority_labels[minority_nbrs] == 0).mean(axis=1)
        rate_gap = minority_neg_rate - majority_neg_rate

        # Largest gap first, truncated to the budget.
        largest_gap_first = np.argsort(rate_gap)[::-1]
        to_flip = negatives[largest_gap_first[:budget]]

        updated = collective_labels.copy()
        updated[to_flip] = 1
        return updated


class CND(Collective_Strategy):
    """
    Wrapper around the ``cnd`` routine from ``existing_methods``.

    Algorithm from "Classifying without discriminating"
    https://ieeexplore.ieee.org/abstract/document/4909197
    """

    def __init__(self, num_to_flip: int = None, seed: int = 0):
        """Store the arguments later forwarded to ``cnd``.

        Args:
            num_to_flip: Passed through to ``cnd`` unchanged.
            seed: Passed through to ``cnd`` unchanged.
        """
        super().__init__()
        self.name = "cnd"
        self.num_to_flip = num_to_flip
        self.seed = seed

    def act(
        self,
        features: pd.DataFrame,
        labels: np.ndarray,
        in_group: np.ndarray,
        in_collective: np.ndarray,
    ):
        """Return the labels produced by ``cnd``.

        Overrides the base ``act``: no transform or data split is applied
        and ``in_collective`` is not used. The features are converted to a
        float array and handed to ``cnd`` together with ``labels``,
        ``in_group``, ``num_to_flip`` and ``seed``.
        """
        feature_matrix = features.values.astype(float)
        return cnd(feature_matrix, labels, in_group, self.num_to_flip, self.seed)


# Ranked probability
class Ranked_Probability(Collective_Strategy):
    """Flip the negatives a majority-trained classifier rates most positive.

    A ``HistGradientBoostingClassifier`` is fitted on the majority data and
    used to score collective members; the negatively-labelled members with
    the highest predicted probability for column 1 are relabelled to 1, up
    to the flip budget.
    """

    def __init__(
        self,
        num_to_flip: int = None,
        ratio_to_flip: float = None,
        data_transform: EmptyTransform = None,
        use_partial: bool = False,
        seed: int = 0,
    ):
        """Configure the strategy.

        Args:
            num_to_flip: Absolute number of labels to flip.
            ratio_to_flip: Fraction of the collective to flip.
            data_transform: Forwarded to ``Collective_Strategy``.
            use_partial: Forwarded to ``Collective_Strategy``.
            seed: ``random_state`` of the classifier.

        Raises:
            ValueError: If neither or both of ``num_to_flip`` and
                ``ratio_to_flip`` are given.
        """
        has_count = num_to_flip is not None
        has_ratio = ratio_to_flip is not None
        if not (has_count or has_ratio):
            raise ValueError("Must provide either num_to_flip or ratio_to_flip")
        if has_count and has_ratio:
            raise ValueError("Cannot provide both num_to_flip and ratio_to_flip")
        super().__init__(data_transform=data_transform, use_partial=use_partial)
        self.name = "ranked_proba"
        self.num_to_flip = num_to_flip
        self.ratio_to_flip = ratio_to_flip
        self.seed = seed
        self.ranker = HistGradientBoostingClassifier(random_state=self.seed)

    def get_new_labels(
        self,
        collective_features,
        collective_labels,
        majority_features,
        majority_labels,
        minority_features,
        minority_labels,
    ):
        """Return a copy of ``collective_labels`` with the chosen 0s set to 1.

        The minority arguments are accepted for interface compatibility and
        are not used. A fresh classifier is fitted on every call and stored
        in ``self.ranker``.
        """
        # Flip budget: explicit count, or a fraction of the collective size.
        budget = (
            self.num_to_flip
            if self.num_to_flip is not None
            else int(self.ratio_to_flip * collective_labels.shape[0])
        )

        # Drop every column that equals 1 for all collective members; this
        # is how the sensitive feature is removed before fitting.
        ones_column = np.ones((len(collective_features), 1))
        all_ones = np.all(collective_features == ones_column, axis=0)
        kept_cols = ~all_ones

        self.ranker = HistGradientBoostingClassifier(random_state=self.seed)
        self.ranker.fit(majority_features[:, kept_cols], majority_labels)
        positive_proba = self.ranker.predict_proba(
            collective_features[:, kept_cols]
        )[:, 1]

        # Rank everyone by probability, keep only the current negatives,
        # then reverse so the most confident come first.
        negatives = np.flatnonzero(collective_labels == 0)
        ascending = np.argsort(positive_proba)
        ranked_negatives = ascending[np.isin(ascending, negatives)][::-1]

        updated = collective_labels.copy()
        updated[ranked_negatives[:budget]] = 1

        return updated
