""" Ablation study on learning efficiency """

import random
import numpy as np

from dataset import FeatIndex, Adult, Compas, LawSchool, Dutch
from eval import Evaluator
from model import LogReg, SenSeI, Synthesizer, NeuralNets


class RandomData:
    """Sampler of purely random feature matrices and labels.

    At construction time the observed ``(min, max)`` of every numerical
    feature in ``X`` is recorded. ``sample`` then draws numerical values
    uniformly from those ranges, picks one active column uniformly at random
    inside each categorical feature's column group, and draws binary labels
    uniformly from ``{0, 1}``.
    """

    def __init__(self, X: np.ndarray, feat_idx: FeatIndex):
        """Store ``feat_idx`` and the value range of each numerical feature of ``X``."""
        self.feat_idx = feat_idx
        self.num_feat2range = {}
        for name in self.feat_idx.num_feat:
            column = X[:, self.feat_idx.feat2idx[name]]
            self.num_feat2range[name] = (np.min(column), np.max(column))

    def sample(self, N: int):
        """Return ``(X, y)`` holding ``N`` random rows and ``N`` random 0/1 labels."""
        width = len(self.feat_idx.cat_idx) + len(self.feat_idx.num_idx)
        X = np.zeros((N, width))

        # Numerical features: uniform draw inside the recorded [min, max].
        for name in self.feat_idx.num_feat:
            lower, upper = self.num_feat2range[name]
            cols = self.feat_idx.feat2idx[name]
            X[:, cols[0]] = np.random.uniform(low=lower, high=upper, size=N)

        # Categorical features: activate one column per row. The offset is
        # added to the smallest column index of the group, so the group's
        # columns are treated as one contiguous run.
        rows = np.arange(N)
        for name in self.feat_idx.cat_feat:
            cols = self.feat_idx.feat2idx[name]
            offsets = np.random.randint(low=0, high=len(cols), size=N)
            X[rows, min(cols) + offsets] = 1.

        y = np.random.choice(a=[0, 1], size=N)
        return X, y


def _resample_one_hot(row: np.ndarray, cols) -> None:
    """Switch the one-hot group ``row[cols]`` to a different category, in place.

    The replacement category is drawn uniformly from every category except
    the currently active one (the ``argmax`` of the group). A group with fewer
    than two columns has no alternative value and is left untouched.
    """
    n_cat = len(cols)
    if n_cat < 2:
        return
    current = int(np.argmax(row[cols]))
    alternatives = [c for c in range(n_cat) if c != current]
    one_hot = np.zeros(n_cat)
    one_hot[np.random.choice(alternatives)] = 1.
    row[cols] = one_hot


def random_comp(
        train_X: np.ndarray,
        feat_idx: FeatIndex,
        categorical_thr: int = 1,
        numerical_thr: float = 0.025,
) -> np.ndarray:
    """ Randomly generate comparable samples, with normalized numerical features as input

    A perturbed copy of ``train_X`` is returned; the input is not modified.
    For every row:

    * each numerical column gets uniform noise in
      ``[-numerical_thr, numerical_thr]`` and is clipped to ``[0, 1]``;
    * between 0 and ``categorical_thr`` draws (with replacement) of
      non-sensitive categorical features are switched to another category;
    * between 1 and ``len(feat_idx.sen_feat)`` *distinct* sensitive features
      are switched to another category, so at least one sensitive attribute
      always differs from the original row.

    Categorical and sensitive features are read through
    ``feat_idx.feat2idx[name]`` as one-hot column groups.

    :param train_X: 2-D feature matrix to perturb.
    :param feat_idx: feature index providing ``cat_feat``, ``sen_feat``,
        ``num_idx`` and ``feat2idx``.
    :param categorical_thr: maximum number of non-sensitive categorical draws per row.
    :param numerical_thr: half-width of the uniform noise added to numerical columns.
    :return: perturbed array with the same shape as ``train_X``.
    """

    comp_X = train_X.copy()
    n_rows = comp_X.shape[0]
    feat2idx = feat_idx.feat2idx

    # Loop-invariant feature lists, built once. Order follows ``cat_feat`` so
    # a seeded run is reproducible (a set difference of strings is not).
    sen_feat = list(feat_idx.sen_feat)
    sen_lookup = set(sen_feat)
    non_sen_cat_feat = list(dict.fromkeys(
        name for name in feat_idx.cat_feat if name not in sen_lookup
    ))

    # perturbation on numerical feature (all rows and columns at once)
    num_idx = list(feat_idx.num_idx)
    if num_idx:
        noise = np.random.uniform(
            low=-numerical_thr, high=numerical_thr, size=(n_rows, len(num_idx)),
        )
        comp_X[:, num_idx] = np.clip(comp_X[:, num_idx] + noise, 0., 1.)

    # Iterating a 2-D array yields row views, so the helper edits comp_X in place.
    for row in comp_X:
        # perturbation on categorical feature
        if non_sen_cat_feat:
            pert_cat = np.random.choice(non_sen_cat_feat, size=categorical_thr, replace=True)
            num_cat_pert = np.random.randint(len(pert_cat) + 1)
            for feat in pert_cat[:num_cat_pert]:
                _resample_one_hot(row, feat2idx[feat])

        # perturbation on sensitive feature: sample WITHOUT replacement,
        # otherwise the same feature can be flipped twice and end up back at
        # its original value.
        if sen_feat:
            num_sen_pert = np.random.randint(1, len(sen_feat) + 1)
            pert_sen_feat = np.random.choice(sen_feat, size=num_sen_pert, replace=False)
            for feat in pert_sen_feat:
                _resample_one_hot(row, feat2idx[feat])

    return comp_X


def main():
    """Run the ablation experiment on the Adult dataset.

    The active experiment trains a ``NeuralNets`` model on the training data
    augmented with several independently generated sets of random comparable
    samples (see ``random_comp``), each reusing the original labels, and then
    reports the evaluator's metrics on the test split. The alternative
    experiments are kept below as commented-out code.
    """
    dataset = Adult()
    train_X, train_y = dataset.train_data(scale="num")
    test_X, test_y = dataset.test_data(scale="all")
    test_comp_data = dataset.comp_data(batch_size=1024, train=False, scale="all")
    evaluator = Evaluator(test_X, test_y, test_comp_data)
    feat_idx = dataset.feat_idx

    # train_X = dataset.scale(train_X, method="all")

    # model = LogReg()

    model = NeuralNets(input_dim=train_X.shape[1])
    # model.fit(dataset.scale(train_X, method="cat"), train_y)
    # evaluator(model.pred, model.pred_proba)

    # --- Random Data ---

    # random_data = RandomData(X=train_X, feat_idx=dataset.feat_idx)
    # random_X, random_y = random_data.sample(N=len(train_X))
    # random_X = dataset.scale(random_X, method="all")
    #
    # model.fit(random_X, random_y)
    # evaluator(model.pred, model.pred_proba)

    # --- Random comparable samples ---

    # Number of independently perturbed copies appended to the original data.
    n_comp_copies = 5
    # Every copy is generated from the ORIGINAL training matrix and keeps the
    # original labels.
    augmented_X = [train_X]
    augmented_X.extend(random_comp(train_X, feat_idx) for _ in range(n_comp_copies))
    train_X = np.concatenate(augmented_X)
    train_y = np.concatenate([train_y] * (n_comp_copies + 1))
    model.fit(dataset.scale(train_X, method="cat"), train_y)
    evaluator(model.pred, model.pred_proba)

    # --- Perturbations in SenSeI ---

    # clf = SenSeI(
    #     input_dim=train_X.shape[1],
    #     n_input=train_X.shape[0],
    #     rho=1e+4,
    #     adv_lr=1e-1,
    #     n_iter=1000,
    # )
    # pert_X, pert_y = clf.fit(X=train_X, y=train_y, train_df=dataset.train_df, sen_feat=feat_idx.sen_feat,
    #                              sen_idx=feat_idx.sen_idx, verbose=True)
    # pert_X = np.concatenate(pert_X)
    # pert_y = np.concatenate(pert_y)
    # select_idx = random.sample(range(0, len(pert_X)), train_X.shape[0])
    #
    # pert_X = np.take(pert_X, select_idx, axis=0)
    # pert_y = np.take(pert_y, select_idx, axis=0)
    #
    # model.fit(pert_X, pert_y)
    # evaluator(model.pred, model.pred_proba)

    # --- Antidote Data ---

    # train_comp_data = dataset.comp_data(batch_size=1024, train=True, scale="num")
    # syn = Synthesizer(epochs=500, batch_size=4096)
    # # syn.fit(train_X, dataset.feat_idx, comp_data=train_comp_data, cond=True, comp_func=dataset.is_comparable)
    # model_save_path = "./ablation/adult_model.pt"
    # syn.save(model_save_path)
    # syn = Synthesizer.load("./ablation/adult_model.pt")

    # antidote_train_X = train_X.copy()
    # antidote_train_y = train_y.copy()
    # for i in range(2):
    #     syn_X_list = syn.sample_all(train_X)
    #     comparable_list = [dataset.is_comparable(train_X, syn_X) for syn_X in syn_X_list]
    #     for syn_X, comp in zip(syn_X_list, comparable_list):
    #         syn_X = syn_X[comp]
    #         syn_y = train_y.copy()
    #         syn_y = syn_y[comp]
    #         antidote_train_X = np.concatenate([antidote_train_X, syn_X], axis=0)
    #         antidote_train_y = np.concatenate([antidote_train_y, syn_y], axis=0)
    #
    # antidote_train_X = antidote_train_X[train_X.shape[0]:]
    # antidote_train_y = antidote_train_y[train_y.shape[0]:]
    # select_idx = random.sample(range(0, len(antidote_train_X)), train_X.shape[0])
    # antidote_train_X = np.take(antidote_train_X, select_idx, axis=0)
    # antidote_train_y = np.take(antidote_train_y, select_idx, axis=0)
    # antidote_train_X = dataset.scale(antidote_train_X, method="cat")
    #
    # syn_X = syn.sample(train_X)
    # syn_X = dataset.scale(syn_X, method="cat")
    # #
    # model.fit(syn_X, train_y)
    # evaluator(model.pred, model.pred_proba)


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
