import numpy as np
import pandas as pd
from holisticai.bias.mitigation import LearningFairRepresentation
from scipy.stats import mode
from sklearn.tree import DecisionTreeClassifier


class EmptyTransform:
    """Identity transform: hands the features back untouched.

    Instance attributes set in ``__init__``:
        name: label returned by ``str()`` on the instance (``"raw"``).
        requires_fit: ``False``; this class defines no ``fit`` method.
    """

    def __init__(self):
        self.name, self.requires_fit = "raw", False

    def __str__(self) -> str:
        """Return the current value of ``self.name``."""
        return self.name

    def transform(self, features: pd.DataFrame, in_group: np.ndarray):
        """Return ``features`` as-is; ``in_group`` is accepted but not used."""
        return features


class LFR(EmptyTransform):
    """Thin wrapper around holisticai's ``LearningFairRepresentation``.

    All constructor arguments are forwarded unchanged to
    ``LearningFairRepresentation``; see that class for their meaning.
    In both ``fit`` and ``transform`` the wrapped model receives
    ``group_a = not in_group`` and ``group_b = in_group``.
    """

    def __init__(
        self,
        k: int = 5,
        Ax: float = 0.01,
        Ay: float = 1.0,
        Az: float = 50.0,
        maxiter: int = 5000,
        maxfun: int = 5000,
        verbose: int = 0,
        seed: int = 0,
    ):
        super().__init__()
        self.name = "LFR"
        self.requires_fit = True
        self.lfr_model = LearningFairRepresentation(
            k=k,
            Ax=Ax,
            Ay=Ay,
            Az=Az,
            maxiter=maxiter,
            maxfun=maxfun,
            verbose=verbose,
            seed=seed,
        )

    @staticmethod
    def _as_float_array(features):
        """Return ``features`` (DataFrame or array-like) as a float ndarray."""
        # DataFrames keep the original ``.values.astype(float)`` path; anything
        # else is routed through np.asarray so plain arrays/lists work too.
        if isinstance(features, pd.DataFrame):
            raw = features.values
        else:
            raw = np.asarray(features)
        return raw.astype(float)

    def transform(self, features: np.ndarray, in_group: np.ndarray):
        """Transform ``features`` with the wrapped model.

        Args:
            features: DataFrame or array-like; converted to a float array
                the same way ``fit`` does so both methods accept the same
                inputs.
            in_group: group-membership mask; passed as ``group_b`` and its
                logical negation as ``group_a``.

        Returns:
            Whatever ``LearningFairRepresentation.transform`` returns.
        """
        return self.lfr_model.transform(
            self._as_float_array(features),
            group_a=np.logical_not(in_group),
            group_b=in_group,
        )

    def fit(self, features: pd.DataFrame, labels: np.ndarray, in_group: np.ndarray):
        """Fit the wrapped model.

        Args:
            features: DataFrame or array-like of training features; converted
                to a float array before fitting.
            labels: training labels, forwarded unchanged.
            in_group: group-membership mask; passed as ``group_b`` and its
                logical negation as ``group_a``.
        """
        self.lfr_model.fit(
            self._as_float_array(features),
            labels,
            group_a=np.logical_not(in_group),
            group_b=in_group,
        )


class FARE(EmptyTransform):
    """Tree-based representation: each sample becomes a per-leaf prototype.

    ``fit`` trains a decision tree and stores, for every leaf reached by the
    training data, one prototype row: the per-column median, or the mode for
    columns whose dtype is ``category``. ``transform`` maps each sample to
    the prototype of the leaf it lands in.

    NOTE(review): the ``fair_gini_*`` criterion and the extra arguments passed
    to ``tree.fit`` (sensitive attribute, ``cat_pos``, ``alpha``) are not part
    of stock scikit-learn; this presumably relies on a patched
    ``DecisionTreeClassifier`` -- confirm against the environment.
    """

    def __init__(
        self, max_leaf_nodes=200, min_samples_leaf=100, gamma=0.85, gini_metric="dp"
    ):
        """Build the (unfitted) tree.

        Args:
            max_leaf_nodes: forwarded to ``DecisionTreeClassifier``.
            min_samples_leaf: forwarded to ``DecisionTreeClassifier``.
            gamma: stored as ``self.alpha`` and passed to ``tree.fit`` as
                ``alpha``.
            gini_metric: suffix of the split criterion name,
                ``f"fair_gini_{gini_metric}"``.
        """
        super().__init__()
        self.name = "FARE"
        self.requires_fit = True
        self.alpha = gamma

        criterion = f"fair_gini_{gini_metric}"  # e.g., fair_gini_dp
        self.tree = DecisionTreeClassifier(
            criterion=criterion,
            max_leaf_nodes=max_leaf_nodes,
            random_state=0,
            min_samples_leaf=min_samples_leaf,
        )

        # Filled by ``fit``: maps leaf id -> prototype row (1-D float array).
        self.medians = None

    @staticmethod
    def _column_mode(column: np.ndarray):
        """Most frequent value of ``column`` after casting to int.

        Ties resolve to the smallest value: ``np.unique`` returns the values
        sorted and ``argmax`` picks the first maximal count.
        """
        values, counts = np.unique(column.astype(int), return_counts=True)
        return values[np.argmax(counts)]

    def transform(self, features: np.ndarray, in_group: np.ndarray = None):
        """Replace every sample by the prototype of its tree leaf.

        Args:
            features: samples to transform; a copy is handed to ``tree.apply``.
            in_group: accepted for interface compatibility, not used.

        Returns:
            2-D array with one prototype row per input sample.
        """
        leaves = self.tree.apply(features.copy())
        prototypes = self.medians
        return np.vstack([prototypes[leaf] for leaf in leaves])

    def fit(self, features: pd.DataFrame, labels: np.ndarray, in_group: np.ndarray):
        """Fit the tree and compute one prototype row per leaf.

        Args:
            features: training features; columns with dtype ``category`` are
                treated as categorical (mode), all others as continuous
                (median).
            labels: training labels, one per row.
            in_group: sensitive-attribute indicator, one per row.
        """
        cat_pos = np.where(features.dtypes == "category")[0]
        # Set of column positions for cheap iteration/membership below.
        categorical = set(cat_pos.tolist())

        x_train = features.values.astype(float)
        # np.asarray first: pandas Series do not support ``[:, None]`` indexing
        # on newer pandas versions, whereas ndarrays do.
        s_train = np.asarray(in_group)[:, None]
        y_train = np.asarray(labels)[:, None]
        self.tree = self.tree.fit(
            x_train, y_train, s_train, cat_pos=cat_pos, alpha=self.alpha
        )

        cells_train = self.tree.apply(x_train)

        medians = {}
        for cid in np.unique(cells_train):
            # all training rows that end up in this leaf
            members = x_train[cells_train == cid]

            # continuous columns take the median ...
            prototype = np.median(members, axis=0)
            # ... categorical columns are overwritten with the mode
            for col in categorical:
                prototype[col] = self._column_mode(members[:, col])
            medians[cid] = prototype
        self.medians = medians
