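"""
Script to run the deletion attacks.

Run as follows:
export PYTHONPATH="."
python3 Scripts/deletion_attack.py <configfile>

<configfile> is a JSON file with the keys "datasets", "models", "recourse_methods",
"attacks", "folds", "max_deleted_points" and "outputfile".
Supported datasets: admission, heloc, twomoons, german, compas, adult, diabetes.
Supported models: kernelridge, ntk, linreg, logreg, kernelsvm, kernellssvm, ann, annreg.
Supported recourse methods: scfe, dice, cem, revise.
Supported attacks: greedycnt, greedymaxinval, greedymaxcng, random.
"""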

import sys
import json
import os

from ML_Models.data_loader import *
from Recourse_Methods.Generative_Model.vae_config import *

# Machine Learning models
from ML_Models.KernelRidgeRegression.model import KernelRidgeRegression, squared_exp_kernel, polynomial_kernel, neural_tangent_kernel, neural_tangent_kernel_derivative, polynomial_kernel_derivative
from ML_Models.KernelSVM.model import KernelSVM, KernelLeastSquaresSVM
from ML_Models.ANN.model import ANN
from ML_Models.LR.model import Regression

# Generative models
import Recourse_Methods.Generative_Model.model as model_vae

# Recourse Methods: IMF assumption
from Recourse_Methods.gradient_methods import SCFE
from Recourse_Methods.gradient_methods import DICE

# Recourse Methods: Manifold assumption
from Recourse_Methods.manifold_methods import CEM
from Recourse_Methods.manifold_methods import REVISE

# Attack mechanisms
from Tools.deletion_methods import select_random_point, find_impactful_data_greedy, compute_invalidation_curve
from tqdm import tqdm
import copy

# Module-level cache of the basis-point indices used by the "kernellssvm" model.
# Filled lazily by get_model() on first use and reset to None for every dataset
# in the main loop, because the basis is selected from that dataset's training data.
basis_idx = None


def get_data(dataset_str):
    """Load a named dataset and return it as float32 tensors.

    Returns the tuple (train_data, train_labels, test_data, test_labels, binary_mode).
    binary_mode tells the caller how to turn the labels into binary classes:
        'median' -- split the labels at the median,
        'mean'   -- split the labels at the mean,
        'plain'  -- the labels are already binary.

    Raises ValueError for an unsupported dataset name.
    """
    # (name, target_scaler, binary_mode). Every dataset uses a standard input scaler;
    # only the datasets split at the median also get a standardized target.
    dataset_table = (
        ("admission", "standard", "median"),
        ("heloc", "standard", "median"),
        ("twomoons", None, "plain"),
        ("german", None, "plain"),
        ("compas", None, "plain"),
        ("adult", None, "plain"),
        ("diabetes", None, "plain"),
    )
    entry = next((row for row in dataset_table if dataset_str == row[0]), None)
    if entry is None:
        raise ValueError("More datasets will be supported soon.")
    name, target_scaler, binary_mode = entry
    train_load, test_load = return_loaders(name, True, scaler="standard", target_scaler=target_scaler)

    def to_float_tensor(values):
        return torch.tensor(values, dtype=torch.float32)

    return (
        to_float_tensor(train_load.dataset.data),
        to_float_tensor(train_load.dataset.targets),
        to_float_tensor(test_load.dataset.data),
        to_float_tensor(test_load.dataset.targets),
        binary_mode,
    )


def get_generative_model(input_size, vae_meta_info):
    """Build the VAE described by ``vae_meta_info`` and load its saved weights.

    :param input_size: Number of input features of the VAE.
    :param vae_meta_info: Dict providing 'activFun', 'd', 'H1', 'H2' (passed positionally
        to the VAE constructor after ``input_size``) and 'vae_path' (the saved state dict).
    :return: The VAE with the loaded weights.
    """
    vae_cls = model_vae.VAE_model
    architecture = [vae_meta_info[key] for key in ("activFun", "d", "H1", "H2")]
    vae = vae_cls(input_size, *architecture)
    saved_state = torch.load(vae_meta_info["vae_path"])
    vae.load_state_dict(saved_state)
    return vae
    

def get_model(model_str, input_dim, train_sz, data_train):
    """Create an unfitted model by name.

    :param model_str: Model name. One of "kernelridge", "ntk", "linreg", "logreg",
        "kernelsvm", "kernellssvm", "ann", "annreg".
    :param input_dim: Number of input features.
    :param train_sz: Number of training points; forwarded as ``train_set_size``.
    :param data_train: Training inputs; only used by "kernellssvm" to select basis points.
    :return: Tuple ``(model, binary)``. ``binary`` is True when the caller should train
        the model on binarized labels, False when it is trained on the raw targets.
    :raises ValueError: If ``model_str`` is not supported.
    """
    # The docstring must precede this statement, otherwise it is not a docstring.
    global basis_idx
    if model_str == "kernelridge":
        model = KernelRidgeRegression(polynomial_kernel, lambd=2.0, weighted_model=True, train_set_size=train_sz, kernel_deriv=polynomial_kernel_derivative)
        binary = False
    elif model_str == "ntk":
        # Stronger regularization for very low-dimensional inputs.
        model = KernelRidgeRegression(neural_tangent_kernel, lambd=(2.0 if input_dim > 4 else 5.0), weighted_model=True, train_set_size=train_sz, kernel_deriv=neural_tangent_kernel_derivative)
        binary = False
    elif model_str == "linreg":
        model = Regression(input_dim, task="regression", weighted_model=True, train_set_size=train_sz)
        binary = False
    elif model_str == "logreg":
        model = Regression(input_dim, task="classification", weighted_model=True, train_set_size=train_sz)
        binary = True
    elif model_str == "kernelsvm":
        model = KernelSVM(input_dim, weighted_model=True, train_set_size=train_sz, kernel_str="rbf")
        binary = True
    elif model_str == "kernellssvm":
        from ML_Models.KernelSVM.find_basis import find_basis_points
        # The basis is computed once and cached in the module-level ``basis_idx``.
        # Use ``is None``: ``== None`` is elementwise for array results and would make
        # the ``if`` ambiguous on the second call.
        if basis_idx is None:
            print("Computing basis for lssvm...")
            basis_idx = find_basis_points(data_train, squared_exp_kernel, num_candidates=100, max_basis_vectors=50, error_thres=0.02, num_eval_cand=200)
        model = KernelLeastSquaresSVM(input_dim, support_points=data_train[basis_idx], weighted_model=True, train_set_size=train_sz, gamma=1.0)
        binary = True
    elif model_str == "ann":
        model = ANN(input_dim, hidden_layer=30, n_epochs=15000, weighted_model=True, train_set_size=train_sz, fit_scipy=True, l2_reg_lambda=0.15)
        binary = True
    elif model_str == "annreg":  # ANN for regression.
        model = ANN(input_dim, hidden_layer=40, n_epochs=15000, weighted_model=True,
                    train_set_size=train_sz, fit_scipy=True, l2_reg_lambda=0.15, task="regression")
        binary = False
    else:
        raise ValueError("More models will be supported soon.")
    return model, binary


def get_recourse_object(recourse_str, model, vae_model):
    """Instantiate the recourse method named ``recourse_str`` for ``model``.

    All methods are built with the same optimisation settings. "cem" and "revise"
    additionally receive the generative model ``vae_model``; "scfe" and "dice" do not.

    :raises ValueError: If ``recourse_str`` is not supported.
    """
    shared_kwargs = dict(setting="regression", lr=2e-3, max_iter=2000, target_thres=0.0, _lambda=0.00, step=0.0)

    if recourse_str in ("scfe", "dice"):
        method_cls = SCFE if recourse_str == "scfe" else DICE
        return method_cls(model, **shared_kwargs)

    if recourse_str in ("cem", "revise"):
        method_cls = CEM if recourse_str == "cem" else REVISE
        return method_cls(model, vae_model, **shared_kwargs)

    raise ValueError("More recourse methods will be supported soon.")


def get_attack_params(attack_str):
    """Return the ``(function, additional_params)`` pair for the named attack.

    The greedy attacks all use ``find_impactful_data_greedy`` and differ only in the
    "mode" entry of a freshly created params dict. "random" uses
    ``select_random_point`` with ``None`` as params.

    :raises ValueError: If ``attack_str`` is not supported.
    """
    if attack_str == "random":
        return select_random_point, None

    greedy_variants = (
        ("greedycnt", "cnt"),
        ("greedymaxinval", "maxinval"),
        ("greedymaxcng", "maxcng"),
    )
    for name, mode in greedy_variants:
        if attack_str == name:
            return find_impactful_data_greedy, {"mode": mode}

    raise ValueError(f"Unsupported attack '{attack_str}'.")


if __name__ == "__main__":
    # Run every (dataset, model, recourse method, fold, attack) combination listed in the
    # JSON config and store the resulting invalidation curves in config["outputfile"].
    # An existing output file is loaded first and updated in place.
    if len(sys.argv) < 2:
        print("Please pass a config file as argument.")
        sys.exit(1)
    print("Config file: ", sys.argv[1])
    with open(sys.argv[1]) as config_file:
        config = json.load(config_file)
    if os.path.isfile(config["outputfile"]):
        with open(config["outputfile"]) as results_file:
            main_results_dict = json.load(results_file)
    else:
        main_results_dict = {}

    for ds in config["datasets"]:
        # Reset the module-level LSSVM basis cache: basis points are selected per dataset.
        basis_idx = None
        print("Running dataset:", ds)
        sdata, labels_load, data_test, _, bin_mode = get_data(ds)
        print("Loading VAE")
        vae_dict = vae_meta_dictionary[ds]
        vae_model = get_generative_model(sdata.shape[1], vae_dict)
        results_lvl2 = main_results_dict.setdefault(ds, {})

        for modeltype in config["models"]:
            print("Using model:", modeltype)
            # Only used to learn whether the model wants binary labels; a fresh model is
            # built and fitted for every fold below.
            _, binary = get_model(modeltype, sdata.shape[1], len(sdata), sdata)

            if binary:  # Convert labels to binary.
                if bin_mode == "plain":
                    labels = labels_load.long()
                elif bin_mode == "mean":
                    labels = (labels_load > labels_load.mean()).long()
                elif bin_mode == "median":
                    labels = (labels_load > labels_load.median()).long()
                else:
                    raise ValueError("Unsupported Binary mode. The data set is not properly configured for a classification task.")
            else:
                labels = labels_load

            results_lvl3 = results_lvl2.setdefault(modeltype, {})

            for recourse_method in config["recourse_methods"]:
                results_lvl4 = results_lvl3.setdefault(recourse_method, [])

                # Split the test set into folds.
                for k in range(config["folds"]):
                    # Create and fit a fresh model, then build the recourse object on it.
                    model, binary = get_model(modeltype, sdata.shape[1], len(sdata), sdata)
                    print("Fitting model...")
                    model.fit(sdata, labels)
                    recourse = get_recourse_object(recourse_method, model, vae_model)

                    # Reuse a stored entry for this fold so that keys from attacks not
                    # rerun in this invocation are kept.
                    results_dict = results_lvl4[k] if k < len(results_lvl4) else {}
                    results_dict["fold"] = k

                    # Round-robin split: fold k holds the test points whose index is k modulo folds.
                    ind_use = (torch.arange(0, len(data_test)) % config["folds"] == k)
                    data_test_fold = data_test[ind_use]
                    n_features = data_test_fold.size(1)
                    pred_test_fold = model.predict_with_logits(data_test_fold)

                    # Recourse is only sought for points the model currently scores below 0.
                    data_test_fold_neg = data_test_fold[pred_test_fold < 0]
                    n_candidates = len(data_test_fold_neg)
                    print(f"Computing recourse for {n_candidates} samples.")
                    results_dict["n_recourse_cand"] = n_candidates

                    mycfs = []
                    for i in tqdm(range(n_candidates)):
                        org_point = data_test_fold_neg[i].reshape(1, -1)
                        res = recourse.generate_counterfactuals(org_point, target_class=1)
                        # None means no counterfactual was found; an empty (0, n_features)
                        # tensor contributes no rows to the concatenation below.
                        mycfs.append(res if res is not None else torch.empty(0, n_features))
                    # torch.cat rejects an empty list, which occurs when no test point in the
                    # fold is scored negatively.
                    mycfs = torch.cat(mycfs, dim=0) if mycfs else torch.empty(0, n_features)
                    results_dict["n_recourse_found"] = len(mycfs)
                    # Undefined without candidates; stored as JSON null instead of dividing by zero.
                    results_dict["recourse_success_rate"] = len(mycfs) / n_candidates if n_candidates else None

                    for attackmethod in config["attacks"]:
                        attack_fn, additional_params = get_attack_params(attackmethod)
                        # Attack a copy so the fitted model stays intact for the next attack.
                        model_inval = copy.deepcopy(model)
                        invalidation_rates, indices_removed = compute_invalidation_curve(
                            model_inval, mycfs, sdata, labels, attack_fn,
                            n_points_removed=config["max_deleted_points"], params=additional_params)
                        results_dict[f"invalidation_{attackmethod}"] = invalidation_rates.flatten().numpy().tolist()
                        results_dict[f"indices_{attackmethod}"] = indices_removed.flatten().numpy().tolist()

                    if k >= len(results_lvl4):
                        results_lvl4.append(results_dict)
                    else:
                        results_lvl4[k] = results_dict
                    # Checkpoint after every fold; the context manager guarantees the file is
                    # flushed and closed so partial results survive an interrupted run.
                    with open(config["outputfile"], "w") as results_file:
                        json.dump(main_results_dict, results_file)
