""" Ablation study on algorithmic tradeoff """

import json
import os
from typing import Callable

import numpy as np
import pandas as pd

from dataset import fetch_dataset, Dataset, FeatIndex
from model import Synthesizer, NeuralNets, SenSeI, LCIFR, DRO
from eval import Evaluator

# Name of the dataset under study. It is passed to fetch_dataset() and embedded in the
# "./save/<DATASET>/model_cond.pt" synthesizer path and the "./ablation/<DATASET>_*.json" result files.
DATASET = "compas"


def anti(n_iter: int, comp_func: Callable, scale_func: Callable, evaluator: Evaluator, train_X: np.ndarray,
         train_y: np.ndarray, n_repeat: int = 5) -> None:
    """ Tradeoff of NN + Anti.

    Loads the synthesizer saved at ``./save/<DATASET>/model_cond.pt``, draws ``n_iter`` rounds of synthetic
    samples for ``train_X``, keeps the rows selected by ``comp_func`` and appends them (together with the
    labels of the same rows of ``train_y``) to the training set. ``n_repeat`` networks are then trained on the
    augmented data and the collected metrics are written to ``./ablation/<DATASET>_anti_nn_<n_iter>.json``.

    Args:
        n_iter: Number of sampling rounds requested from the synthesizer.
        comp_func: Called as ``comp_func(train_X, syn_X)``; its result is used to index the rows of ``syn_X``
            and ``train_y`` that are kept.
        scale_func: Called as ``scale_func(X, method="cat")`` on the augmented feature matrix.
        evaluator: Called as ``evaluator(clf.pred, clf.pred_proba)``; the result is read through
            ``res["or"]["true"]``, ``res["or"]["false"]``, ``res["roc"]`` and ``res["ap"]``.
        train_X: Training features.
        train_y: Training labels, row-aligned with ``train_X``.
        n_repeat: Number of independent training runs.
    """

    model_path = "./save/%s/model_cond.pt" % DATASET
    syn = Synthesizer.load(model_path)

    N_original = train_X.shape[0]

    # Collect every piece first and concatenate once at the end: concatenating inside the loop would copy the
    # whole (growing) array again for each synthetic batch.
    X_parts = [train_X]
    y_parts = [train_y]
    for _ in range(n_iter):
        for syn_X in syn.sample_all(train_X):
            comp = comp_func(train_X, syn_X)
            X_parts.append(syn_X[comp])
            # A kept synthetic row inherits the label of the training row at the same position.
            y_parts.append(train_y[comp])

    # np.concatenate always allocates a new array, so the caller's train_X / train_y are never aliased.
    antidote_train_X = np.concatenate(X_parts, axis=0)
    antidote_train_y = np.concatenate(y_parts, axis=0)

    antidote_train_X = scale_func(antidote_train_X, method="cat")
    N_antidote = antidote_train_X.shape[0] - N_original
    rate_antidote = float(N_antidote) / N_original
    print("\n=> Additional %.5f antidote data" % rate_antidote)

    save = {"pos_mean": [], "neg_mean": [], "roc": [], "ap": []}
    for _ in range(n_repeat):
        print("\n=> Neural Networks w. Antidote Data")
        clf = NeuralNets(input_dim=train_X.shape[1])
        clf.fit(antidote_train_X, antidote_train_y)
        res = evaluator(clf.pred, clf.pred_proba)

        save["pos_mean"].append(np.mean(res["or"]["true"]))
        save["neg_mean"].append(np.mean(res["or"]["false"]))
        save["roc"].append(res["roc"])
        save["ap"].append(res["ap"])

    save["rate"] = rate_antidote
    save["n_iter"] = n_iter

    # Make sure the output directory exists so the finished runs are not lost to a FileNotFoundError.
    os.makedirs("./ablation", exist_ok=True)
    with open("./ablation/%s_anti_nn_%d.json" % (DATASET, n_iter), "w") as f:
        json.dump(save, f)


def dro(n_iter: int, comp_func: Callable, scale_func: Callable, evaluator: Evaluator, train_X: np.ndarray,
        train_y: np.ndarray, feat_idx: FeatIndex, n_repeat: int = 5) -> None:
    """ Tradeoff of DRO + Anti.

    Loads the synthesizer saved at ``./save/<DATASET>/model_cond.pt``, draws ``n_iter`` rounds of synthetic
    samples for ``train_X`` and hands all of them, together with their comparability flags, to ``DRO.fit``.
    ``n_repeat`` models are trained and the collected metrics are written to
    ``./ablation/<DATASET>_dro_<n_iter>.json``.

    Args:
        n_iter: Number of sampling rounds requested from the synthesizer. With no sampled array at all,
            ``np.concatenate`` raises ``ValueError``.
        comp_func: Called as ``comp_func(train_X, syn_X)``; each result becomes one column of the
            comparability matrix.
        scale_func: Called as ``scale_func(X, method="cat")`` on every synthetic array and on ``train_X``.
        evaluator: Called as ``evaluator(clf.pred, clf.pred_proba)``; the result is read through
            ``res["or"]["true"]``, ``res["or"]["false"]``, ``res["roc"]`` and ``res["ap"]``.
        train_X: Training features.
        train_y: Training labels.
        feat_idx: Feature index; only ``feat_idx.sen_idx`` is read here.
        n_repeat: Number of independent training runs.
    """
    model_path = "./save/%s/model_cond.pt" % DATASET
    syn = Synthesizer.load(model_path)

    syn_X_list = []
    comp_list = []
    for _ in range(n_iter):
        syn_X_temp_list = syn.sample_all(train_X)
        # Each comparability result is turned into a column so the results can be joined side by side.
        comparable_list = [comp_func(train_X, syn_X)[:, np.newaxis] for syn_X in syn_X_temp_list]
        syn_X_list.extend(syn_X_temp_list)
        comp_list.extend(comparable_list)

    # Scale each synthetic array, then stack all of them along a new axis 1; comp_mat gets the matching
    # columns in the same order.
    syn_X_list = [scale_func(syn_X, method="cat") for syn_X in syn_X_list]
    syn_X_list = [syn_X[:, np.newaxis, :] for syn_X in syn_X_list]
    syn_X = np.concatenate(syn_X_list, axis=1)
    comp_mat = np.concatenate(comp_list, axis=1)

    N_antidote = np.sum(comp_mat)
    rate_antidote = float(N_antidote) / train_X.shape[0]
    print("\n=> Additional %.5f antidote data for DRO" % rate_antidote)

    train_X_scale = scale_func(train_X, method="cat")

    save = {"pos_mean": [], "neg_mean": [], "roc": [], "ap": []}
    for _ in range(n_repeat):
        print("\n=> DRO with Antidote Data")
        clf = DRO(sen_idx=feat_idx.sen_idx, drop_sen=True)
        clf.fit(train_X_scale, train_y, syn_X, comp_mat)
        res = evaluator(clf.pred, clf.pred_proba)

        save["pos_mean"].append(np.mean(res["or"]["true"]))
        save["neg_mean"].append(np.mean(res["or"]["false"]))
        save["roc"].append(res["roc"])
        save["ap"].append(res["ap"])

    save["rate"] = rate_antidote
    save["n_iter"] = n_iter

    # Make sure the output directory exists so the finished runs are not lost to a FileNotFoundError.
    os.makedirs("./ablation", exist_ok=True)
    with open("./ablation/%s_dro_%d.json" % (DATASET, n_iter), "w") as f:
        json.dump(save, f)


def sensei(rho: float, evaluator: Evaluator, train_X: np.ndarray, train_y: np.ndarray, train_df: pd.DataFrame,
           feat_idx: FeatIndex, n_repeat: int = 5) -> None:
    """ Tradeoff of SenSeI for one value of ``rho``.

    Trains ``n_repeat`` SenSeI models and writes the collected metrics to
    ``./ablation/<DATASET>_sensei_<rho>.json``.

    Args:
        rho: Forwarded to ``SenSeI(rho=...)`` and used to name the result file. Integral values are written
            as plain integers (``1000.0`` -> ``1000``); any other value uses its ``repr``.
        evaluator: Called as ``evaluator(clf.pred, clf.pred_proba)``; the result is read through
            ``res["or"]["true"]``, ``res["or"]["false"]``, ``res["roc"]`` and ``res["ap"]``.
        train_X: Training features; its shape provides ``n_input`` and ``input_dim``.
        train_y: Training labels.
        train_df: Forwarded to ``SenSeI.fit`` as ``train_df``.
        feat_idx: Feature index; ``feat_idx.sen_feat`` and ``feat_idx.sen_idx`` are forwarded to ``fit``.
        n_repeat: Number of independent training runs.
    """
    n_input, input_dim = train_X.shape[0], train_X.shape[1]

    save = {"pos_mean": [], "neg_mean": [], "roc": [], "ap": []}
    for _ in range(n_repeat):
        clf = SenSeI(
            input_dim=input_dim,
            n_input=n_input,
            rho=rho,
        )
        clf.fit(X=train_X, y=train_y, train_df=train_df, sen_feat=feat_idx.sen_feat, sen_idx=feat_idx.sen_idx,
                verbose=True)
        res = evaluator(clf.pred, clf.pred_proba)
        save["pos_mean"].append(np.mean(res["or"]["true"]))
        save["neg_mean"].append(np.mean(res["or"]["false"]))
        save["roc"].append(res["roc"])
        save["ap"].append(res["ap"])

    save["rho"] = rho

    # "%d" silently truncates a fractional rho (0.1 and 0.5 both become "0"), which would make different
    # settings overwrite each other's results. Integral values keep their "%d" name; anything else uses repr.
    rho_tag = "%d" % rho if float(rho).is_integer() else repr(float(rho))

    # Make sure the output directory exists so the finished runs are not lost to a FileNotFoundError.
    os.makedirs("./ablation", exist_ok=True)
    with open("./ablation/%s_sensei_%s.json" % (DATASET, rho_tag), "w") as f:
        json.dump(save, f)


def lcifr(weight: float, evaluator: Evaluator, train_X: np.ndarray, train_y: np.ndarray, feat_idx: FeatIndex,
          n_repeat: int = 5) -> None:
    """ Tradeoff of LCIFR for one value of the DL2 weight.

    Trains ``n_repeat`` LCIFR models and writes the collected metrics to
    ``./ablation/<DATASET>_lcifr_<weight>.json``.

    Args:
        weight: Forwarded to ``LCIFR(dl2_weight=...)``. The file name formats it with ``%.1f``, so weights
            that agree after rounding to one decimal share the same result file.
        evaluator: Called as ``evaluator(clf.pred, clf.pred_proba)``; the result is read through
            ``res["or"]["true"]``, ``res["or"]["false"]``, ``res["roc"]`` and ``res["ap"]``.
        train_X: Training features.
        train_y: Training labels.
        feat_idx: Feature index; ``num_idx``, ``feat2idx`` and ``sen_feat`` are read here.
        n_repeat: Number of independent training runs.
    """
    num_feat_idx = feat_idx.num_idx
    # Restrict the feature -> index mapping to the features listed in feat_idx.sen_feat.
    sen_feat2idx = {feat: idx for feat, idx in feat_idx.feat2idx.items() if feat in feat_idx.sen_feat}

    save = {"pos_mean": [], "neg_mean": [], "roc": [], "ap": []}
    for _ in range(n_repeat):
        clf = LCIFR(
            sen_feat2idx=sen_feat2idx,
            num_feat_idx=num_feat_idx,
            dl2_weight=weight,
        )
        clf.fit(train_X, train_y)
        res = evaluator(clf.pred, clf.pred_proba)
        save["pos_mean"].append(np.mean(res["or"]["true"]))
        save["neg_mean"].append(np.mean(res["or"]["false"]))
        save["roc"].append(res["roc"])
        save["ap"].append(res["ap"])

    save["weight"] = weight

    # Make sure the output directory exists so the finished runs are not lost to a FileNotFoundError.
    os.makedirs("./ablation", exist_ok=True)
    with open("./ablation/%s_lcifr_%.1f.json" % (DATASET, weight), "w") as f:
        json.dump(save, f)


def main():
    """ Run the ablation sweeps for the dataset named by ``DATASET``.

    Builds one shared evaluator from the test split and then runs the DRO sweep. The antidote, SenSeI and
    LCIFR sweeps are kept below as comments and are not executed.
    """
    data = fetch_dataset(DATASET)
    X_train, y_train = data.train_data(scale="num")
    X_test, y_test = data.test_data(scale="all")
    comp_test = data.comp_data(batch_size=1024, train=False, scale="all")
    scorer = Evaluator(X_test, y_test, comp_test)

    # Antidote + NN sweep (disabled):
    # for rounds in range(6, 12):
    #     anti(n_iter=rounds, comp_func=data.is_comparable, scale_func=data.scale, evaluator=scorer,
    #          train_X=X_train, train_y=y_train)

    # DRO sweep over the number of sampling rounds.
    for rounds in range(7, 13):
        dro(
            n_iter=rounds,
            comp_func=data.is_comparable,
            scale_func=data.scale,
            evaluator=scorer,
            train_X=X_train,
            train_y=y_train,
            feat_idx=data.feat_idx,
        )

    # SenSeI and LCIFR sweeps (disabled):
    # X_train = data.scale(X_train, method="cat")
    #
    # for rho in (1e+3, 5e+3, 1e+4, 5e+4, 1e+5, 2e+5, 5e+5):
    #     sensei(rho, evaluator=scorer, train_X=X_train, train_y=y_train, train_df=data.train_df,
    #            feat_idx=data.feat_idx)
    #
    # for weight in (0.1, 1, 10, 50, 100):
    #     lcifr(weight, scorer, X_train, y_train, data.feat_idx, )


# Run the ablation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
