import json
import argparse
import numpy as np

from dataset import fetch_dataset, Dataset
from eval import Evaluator
from model import Synthesizer, LogReg, NeuralNets, SenDrop, DRO
from baseline import run_basic_baseline, run_fair_baseline

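"""
Per-dataset number of sampling rounds, passed on the command line as
--iter_anti (antidote augmentation, run_pre) and --iter_dro (DRO, run_dro).
The argparse defaults correspond to the Adult settings.

Adult:      iter_anti = 1; iter_dro = 5;
Compas:     iter_anti = 8; iter_dro = 10;
Law School: iter_anti = 5; iter_dro = 30;
Dutch:      iter_anti = 4; iter_dro = 15;
Oulad:      iter_anti = 7; iter_dro = 10;
"""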


def _str2bool(value):
    """Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so a dedicated converter is needed.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


def parse_args():
    """Parse the command-line options of the experiment script.

    Note that ``--load``, ``--model_save`` and ``--res_save`` use
    ``action="store_false"``: they default to True and passing the flag
    turns the behavior OFF.

    Returns:
        argparse.Namespace with attributes ``dataset``, ``load``,
        ``model_save``, ``cond``, ``iter_anti``, ``iter_dro``, ``res_save``.
    """
    parser = argparse.ArgumentParser(description='Individual Fairness with Antidote Data')
    parser.add_argument('--dataset', type=str, default="adult", help="name of the dataset")
    parser.add_argument('--load', action="store_false",
                        help="pass to train a new synthesizer instead of loading a saved one")
    parser.add_argument('--model_save', action="store_false",
                        help="pass to skip saving a newly trained synthesizer")
    # A real bool parser: type=bool would turn "--cond False" into True.
    parser.add_argument('--cond', type=_str2bool, default=True,
                        help="use the conditional synthesizer (true/false)")
    parser.add_argument('--iter_anti', type=int, default=1,
                        help="sampling rounds for antidote data augmentation")
    parser.add_argument('--iter_dro', type=int, default=5,
                        help="sampling rounds for DRO antidote candidates")
    parser.add_argument('--res_save', action="store_false",
                        help="pass to skip writing result JSON files")
    args = parser.parse_args()

    return args


def run_pre(
        train_X: np.ndarray,
        train_y: np.ndarray,
        evaluator: Evaluator,
        dataset: Dataset,
        syn: Synthesizer,
        n_iter: int,
        save: bool = True,
):
    """ Run models with antidote data as input

    The training set is augmented with synthetic ("antidote") rows. Each of
    the ``n_iter`` rounds calls ``syn.sample_all(train_X)`` on the ORIGINAL
    training data. For every returned candidate array, it keeps only the rows
    for which ``dataset.is_comparable(train_X, syn_X)`` selects them and labels
    them with ``train_y`` at the same positions. This presumably relies on
    candidate row i being derived from training row i; confirm against
    ``Synthesizer.sample_all``.

    The augmented features are then scaled with
    ``dataset.scale(..., method="cat")``. Four classifiers are trained on them:
    LogReg, SenDrop(LogReg), NeuralNets and SenDrop(NeuralNets). Each one is
    scored with ``evaluator(clf.pred, clf.pred_proba)``.

    Args:
        train_X: training features (not yet ``"cat"``-scaled).
        train_y: training labels aligned with ``train_X``.
        evaluator: callable scoring a fitted model's ``pred``/``pred_proba``.
        dataset: dataset wrapper providing ``is_comparable``, ``scale``,
            ``feat_idx``, ``feat_dim`` and ``name``.
        syn: fitted synthesizer used to draw antidote candidates.
        n_iter: number of sampling rounds.
        save: if True, write each result dict as JSON to
            ``./save/<dataset.name>/antidote_{lr,dis_lr,nn,dis_nn}.json``.
    """
    N_original = train_X.shape[0]

    # Collect all pieces and concatenate once. Growing the arrays with
    # np.concatenate inside the loop would re-copy everything every time.
    X_parts = [train_X]
    y_parts = [train_y]
    for _ in range(n_iter):
        syn_X_list = syn.sample_all(train_X)
        comparable_list = [dataset.is_comparable(train_X, syn_X) for syn_X in syn_X_list]
        for syn_X, comp in zip(syn_X_list, comparable_list):
            X_parts.append(syn_X[comp])
            # Indexing already returns a copy; no explicit .copy() needed.
            y_parts.append(train_y[comp])

    antidote_train_X = np.concatenate(X_parts, axis=0)
    antidote_train_y = np.concatenate(y_parts, axis=0)

    antidote_train_X = dataset.scale(antidote_train_X, method="cat")

    N_antidote = antidote_train_X.shape[0] - N_original
    rate_antidote = float(N_antidote) / N_original if N_original else 0.0
    print("\n=> Additional %.5f antidote data" % rate_antidote)

    sen_idx = dataset.feat_idx.sen_idx
    save_dir = "./save/%s" % dataset.name
    # (banner, model factory, result file). Factories keep model construction
    # lazy and in the original order, right before each fit.
    experiments = [
        ("\n=> Logistic Regression w. Antidote Data",
         lambda: LogReg(),
         "antidote_lr.json"),
        ("\n=> Discard Sensitive Features in Logistic Regression w. Antidote Data",
         lambda: SenDrop(drop_idx=sen_idx, model=LogReg()),
         "antidote_dis_lr.json"),
        ("\n=> Neural Networks w. Antidote Data",
         lambda: NeuralNets(input_dim=dataset.feat_dim),
         "antidote_nn.json"),
        ("\n=> Discard Sensitive Features in Neural Networks w. Antidote Data",
         # The inner network sees only the non-sensitive columns.
         lambda: SenDrop(
             drop_idx=sen_idx,
             model=NeuralNets(input_dim=dataset.feat_dim - len(sen_idx)),
         ),
         "antidote_dis_nn.json"),
    ]

    for banner, make_model, file_name in experiments:
        print(banner)
        clf = make_model()
        clf.fit(antidote_train_X, antidote_train_y)
        res = evaluator(clf.pred, clf.pred_proba)
        if save:
            with open("%s/%s" % (save_dir, file_name), "w") as f:
                json.dump(res, f)

    return


def run_dro(
        train_X: np.ndarray,
        train_y: np.ndarray,
        evaluator: Evaluator,
        dataset: Dataset,
        syn: Synthesizer,
        n_iter: int,
        save: bool = True,
):
    """ Distributionally robust optimization with antidote data

    Each of the ``n_iter`` rounds draws candidate arrays with
    ``syn.sample_all(train_X)`` and records the comparability mask
    ``dataset.is_comparable(train_X, syn_X)`` for each one. Afterwards every
    candidate is ``"cat"``-scaled and all candidates are stacked along axis 1.
    If each scaled candidate is (N, D), this gives an (N, K, D) tensor plus an
    (N, K) mask; the exact shapes depend on ``scale``/``is_comparable``, so
    please confirm. These are passed to ``DRO.fit`` together with the scaled
    training data.

    Args:
        train_X: training features (not yet ``"cat"``-scaled).
        train_y: training labels aligned with ``train_X``.
        evaluator: callable scoring a fitted model's ``pred``/``pred_proba``.
        dataset: dataset wrapper providing ``is_comparable``, ``scale``,
            ``feat_idx`` and ``name``.
        syn: fitted synthesizer used to draw antidote candidates.
        n_iter: number of sampling rounds; must yield at least one candidate.
        save: if True, write the result dict as JSON to
            ``./save/<dataset.name>/dro.json``.

    Raises:
        ValueError: if no candidate arrays were drawn (e.g. ``n_iter < 1``).
    """

    clf = DRO(sen_idx=dataset.feat_idx.sen_idx, drop_sen=True)

    syn_X_list = []
    comp_list = []
    for _ in range(n_iter):
        for syn_X in syn.sample_all(train_X):
            syn_X_list.append(syn_X)
            # Comparability is judged on the unscaled candidates.
            comp_list.append(dataset.is_comparable(train_X, syn_X))

    if not syn_X_list:
        # np.stack/np.concatenate would otherwise fail with an opaque message.
        raise ValueError("run_dro: no antidote candidates were sampled (n_iter=%r)" % n_iter)

    # Stack candidates along axis 1 so candidate k of row i is syn_X[i, k].
    syn_X = np.stack([dataset.scale(x, method="cat") for x in syn_X_list], axis=1)
    comp_mat = np.stack(comp_list, axis=1)

    N_antidote = np.sum(comp_mat)
    n_train = train_X.shape[0]
    rate_antidote = float(N_antidote) / n_train if n_train else 0.0
    print("\n=> Additional %.5f antidote data for DRO" % rate_antidote)

    train_X_scale = dataset.scale(train_X, method="cat")
    clf.fit(train_X_scale, train_y, syn_X, comp_mat)
    res = evaluator(clf.pred, clf.pred_proba)
    if save:
        with open("./save/%s/dro.json" % dataset.name, "w") as f:
            json.dump(res, f)

    return


def main(args=None):
    """Run every experiment for one dataset.

    The steps are:
    1. Load the dataset and build the test-set ``Evaluator``.
    2. Load or train the ``Synthesizer``.
    3. Run the basic and fair baselines.
    4. Run the antidote-data models (``run_pre``) and DRO (``run_dro``).

    Model and result paths are built under ``./save/<dataset.name>``.

    Args:
        args: namespace produced by ``parse_args()``. When omitted, the
            command line is parsed here. This way ``main`` no longer depends on
            a module-level ``args`` global that only exists when this file is
            executed as a script.
    """
    if args is None:
        args = parse_args()

    dataset = fetch_dataset(args.dataset)
    save_dir = "./save/%s" % dataset.name

    train_X, train_y = dataset.train_data(scale="num")
    test_X, test_y = dataset.test_data(scale="all")
    test_comp_data = dataset.comp_data(batch_size=1024, train=False, scale="all")
    evaluator = Evaluator(test_X, test_y, test_comp_data)

    train_X_scale = dataset.scale(train_X, method="cat")

    # Train or load synthesizer. A single path serves both loading and
    # saving, so the conditional/unconditional file names cannot drift apart.
    model_path = "%s/%s" % (save_dir, "model_cond.pt" if args.cond else "model.pt")
    if args.load:
        syn = Synthesizer.load(model_path)
    else:
        print("\n=> Train Synthesizer")
        train_comp_data = dataset.comp_data(batch_size=1024, train=True, scale="num")
        syn = Synthesizer()
        syn.fit(train_X, dataset.feat_idx, comp_data=train_comp_data, cond=args.cond,
                comp_func=dataset.is_comparable)
        if args.model_save:
            syn.save(model_path)

    # Baselines operate on the fully ("cat") scaled training features.
    run_basic_baseline(train_X_scale, train_y, dataset.feat_dim, dataset.feat_idx.sen_idx, evaluator,
                       save_dir)
    run_fair_baseline(train_X_scale, train_y, dataset.train_df, dataset.feat_dim, dataset.feat_idx,
                      evaluator, save_dir)

    # Antidote-data models: they receive the unscaled train_X and scale it themselves.
    run_pre(train_X, train_y, evaluator, dataset, syn, args.iter_anti, args.res_save)
    run_dro(train_X, train_y, evaluator, dataset, syn, args.iter_dro, args.res_save)

    return


if __name__ == "__main__":
    # Parse the command line at script start. `args` is bound at module
    # scope here, which makes it visible as a global to the functions above.
    args = parse_args()
    main()
