import json
import argparse
import numpy as np

from dataset import fetch_dataset, FeatIndex, Adult
from eval import Evaluator
from model.manifold.generator import sen_all_diff
from model import LogReg, NeuralNets, SenDrop


def parse_args():
    """Build the command-line parser for the baseline script and parse ``sys.argv``.

    Options:
        --dataset: name of the dataset (string, defaults to "adult").
        --save: flag; true when present on the command line.

    Returns:
        The ``argparse.Namespace`` holding ``dataset`` and ``save``.
    """
    cli = argparse.ArgumentParser(description='Baselines')
    cli.add_argument(
        '--dataset',
        default="adult",
        type=str,
        help="name of the dataset",
    )
    cli.add_argument('--save', action="store_true")
    return cli.parse_args()


def random_comp(
        antidote_train_X: np.ndarray,
        feat_idx: FeatIndex,
        categorical_thr: int = 1,
        numerical_thr: float = 0.025
) -> np.ndarray:
    """ Randomly generate comparable samples, with normalized numerical features as input

    Every row is perturbed IN PLACE:

    * categorical: ``categorical_thr`` non-sensitive categorical features are
      drawn (with replacement), a random-length prefix of that draw (possibly
      empty) is kept, and each kept feature has its one-hot group rewritten
      to a category different from the current arg-max one;
    * numerical: every column in ``feat_idx.num_idx`` is shifted by uniform
      noise from ``[-numerical_thr, numerical_thr)`` and clipped to ``[0, 1]``.

    Sensitive features are never touched. Categorical features that have a
    single category, or the case where no non-sensitive categorical feature
    exists, are skipped instead of raising.

    Args:
        antidote_train_X: 2-D sample matrix, modified in place.
        feat_idx: feature index; this function reads ``cat_feat``,
            ``sen_feat`` (iterables of feature names), ``feat2idx`` (feature
            name -> sequence of its column indices) and ``num_idx``
            (numerical column indices).
        categorical_thr: number of categorical features drawn per row, i.e.
            the maximum number that can be changed.
        numerical_thr: half-width of the uniform noise on numerical columns.

    Returns:
        ``antidote_train_X`` itself (the same array object).
    """
    # Loop-invariant, so built once. Ordered (not set-ordered) so that a
    # seeded run picks the same features regardless of string hash
    # randomisation; dict.fromkeys keeps the de-duplication the set gave.
    sensitive = set(feat_idx.sen_feat)
    non_sen_cat_feat = list(dict.fromkeys(
        feat for feat in feat_idx.cat_feat if feat not in sensitive
    ))

    for row in range(antidote_train_X.shape[0]):
        # perturbation on categorical feature
        # (np.random.choice raises on an empty population, hence the guard)
        if non_sen_cat_feat:
            pert_cat = np.random.choice(non_sen_cat_feat, size=categorical_thr)
            # how many of the drawn features are actually perturbed: 0..len
            num_cat_pert = np.random.choice(len(pert_cat) + 1)
            for feat in pert_cat[:num_cat_pert]:
                cols = feat_idx.feat2idx[feat]
                n_categories = len(cols)
                if n_categories < 2:
                    # no alternative category to switch to
                    continue
                curr_val = np.argmax(antidote_train_X[row, cols])
                avail_set = [c for c in range(n_categories) if c != curr_val]
                perturb_idx = np.random.choice(avail_set)
                one_hot = np.zeros(n_categories)
                one_hot[perturb_idx] = 1.
                antidote_train_X[row, cols] = one_hot

        # perturbation on numerical feature
        for idx in feat_idx.num_idx:
            noise = np.random.uniform(low=-numerical_thr, high=numerical_thr)
            # values are expected in [0, 1]; keep them there after the shift
            antidote_train_X[row, idx] = np.clip(antidote_train_X[row, idx] + noise, 0., 1.)

    return antidote_train_X


def main(args=None):
    """Train and evaluate a logistic-regression baseline on randomly perturbed data.

    The training set is augmented with copies whose sensitive columns are
    replaced by the alternatives returned by ``sen_all_diff``; the augmented
    copies are then randomly perturbed with ``random_comp`` before fitting.
    Evaluation results are optionally written to
    ``../save/<dataset name>/random_comp.json``.

    Args:
        args: parsed command-line namespace; only ``args.save`` is read.
            When omitted, the options are parsed with ``parse_args()`` so the
            function no longer depends on a module-level ``args`` global.
    """
    import os

    if args is None:
        args = parse_args()

    # NOTE(review): the ``--dataset`` option is not consulted here; the Adult
    # dataset is hard-coded.
    dataset = Adult(sensitive_feat=("race", "sex", "marital-status"))

    # Training data is fetched with scale="num" and the remaining "cat"
    # scaling is applied after the perturbation below.
    train_X, train_y = dataset.train_data(scale="num")
    test_X, test_y = dataset.test_data(scale="all")
    test_comp_data = dataset.comp_data(batch_size=1024, train=False, scale="all")

    evaluator = Evaluator(test_X, test_y, test_comp_data)

    # presumably sen_feat[:, i] holds the i-th alternative sensitive-feature
    # assignment for every training sample — confirm against sen_all_diff
    sen_feat = sen_all_diff(train_X, dataset.feat_idx)
    n_train = train_X.shape[0]

    # Collect the parts and concatenate once instead of re-copying the
    # growing array on every iteration.
    x_parts = [train_X.copy()]
    for i in range(sen_feat.shape[1]):
        flip_X = train_X.copy()
        flip_X[:, dataset.feat_idx.sen_idx] = sen_feat[:, i]
        x_parts.append(flip_X)
    antidote_train_X = np.concatenate(x_parts, axis=0)
    # every block of samples reuses the original labels
    antidote_train_y = np.concatenate([train_y] * len(x_parts), axis=0)

    # Only the augmented rows are perturbed; the original rows stay intact.
    antidote_train_X[n_train:] = random_comp(antidote_train_X[n_train:], dataset.feat_idx)
    antidote_train_X = dataset.scale(antidote_train_X, method="cat")

    print("\n=> Logistic Regression")
    clf = LogReg()
    clf.fit(antidote_train_X, antidote_train_y)
    res = evaluator(clf.pred, clf.pred_proba)
    if args.save:
        save_path = "../save/%s/random_comp.json" % dataset.name
        # make sure the target directory exists so the results are not lost
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w") as f:
            json.dump(res, f)


# Script entry point: parse the command-line options into the module-level
# name `args`, then run the experiment.
if __name__ == "__main__":
    args = parse_args()
    main()
