"""
Script to run the deletion attacks.

Run as follows:
export PYTHONPATH = "."
python3 Scripts/deletion_attack.py <configfile>

<dataset> can be admission
<attacks
"""

import sys
import json
import os
from turtle import up

from ML_Models.data_loader import *
from Recourse_Methods.Generative_Model.vae_config import *
from Scripts.deletion_attack import get_attack_params, get_data, get_generative_model, get_model, get_recourse_object

# Attack mechanisms
from Tools.deletion_methods import select_random_point, find_impactful_data_greedy, compute_invalidation_curve
from tqdm import tqdm
import copy

## Usage: python3 deletion_attack_batch.py <dataset> <model> <recoursemethod> <params_file> <outputfile>

def update_results(results_dict, ds, modeltype, recource_model, fold, outputfile):
    """Merge the results of one fold into the JSON results log.

    The log stored at ``outputfile`` is a nested mapping of the form
    ``{ds: {modeltype: {recource_model: [fold results, ...]}}}``. Existing
    content is preserved. The list entry at index ``fold`` is overwritten if
    it exists; otherwise ``results_dict`` is appended to the list.

    Args:
        results_dict: JSON-serializable results for this fold.
        ds: Data set name (first-level key).
        modeltype: Model type name (second-level key).
        recource_model: Recourse method name (third-level key).
        fold: Index of the fold inside the list of results.
        outputfile: Path of the JSON log; created if it does not exist.
    """
    if os.path.isfile(outputfile):
        with open(outputfile) as f:
            main_results_dict = json.load(f)
    else:
        main_results_dict = {}

    # Create the intermediate levels on demand and descend to the fold list.
    fold_results = (
        main_results_dict.setdefault(ds, {})
        .setdefault(modeltype, {})
        .setdefault(recource_model, [])
    )

    if len(fold_results) <= fold:
        # NOTE: this appends at the end, so list positions only match fold
        # indices when folds are written in order 0, 1, 2, ...
        fold_results.append(results_dict)
    else:
        fold_results[fold] = results_dict

    # Write to a temporary file and atomically swap it in, so an interrupted
    # dump cannot leave a truncated, unparseable log behind.
    tmp_file = f"{outputfile}.tmp"
    with open(tmp_file, "w") as f:
        json.dump(main_results_dict, f)
    os.replace(tmp_file, outputfile)


if __name__ == "__main__":
    if len(sys.argv) < 6:
        print("Please pass the following arguments: <dataset> <model> <recoursemethod> <params_file> <outputfile>")
        exit(-1)
    else:
        ds = sys.argv[1]
        modeltype= sys.argv[2]
        recource_model = sys.argv[3]
        params_file = sys.argv[4]
        outputfile = sys.argv[5]

    print("Config file: ", params_file)
    config = json.load(open(params_file))

    print("Running dataset:", ds)
    sdata, labels_load, data_test, labels_test_load, bin_mode = get_data(ds)
    print("Loading VAE")
    vae_dict = vae_meta_dictionary[ds]
    vae_model = get_generative_model(sdata.shape[1], vae_dict)
    
    print("Using model:", modeltype)
    # The ground truth model.
    model, binary = get_model(modeltype, sdata.shape[1], len(sdata), sdata)

    if binary: # Convert labels to binary.
        if bin_mode == "plain":
            labels = labels_load.long()
            labels_test = labels_test_load.long()
        elif bin_mode == "mean":
            labels = (labels_load > labels_load.mean()).long()
            labels_test = (labels_test_load > labels_load.mean()).long()
        elif bin_mode == "median":
            labels = (labels_load > labels_load.median()).long()
            labels_test = (labels_test_load > labels_load.median()).long()
        else:
            raise ValueError("Unsupported Binary mode. The data set is not properly configured for a classification task.")
    else:
        labels = labels_load
        labels_test = labels_test_load

    # Split the testset
    for k in range(config["folds"]):
        # Create model and recourse object.
        model, binary = get_model(modeltype, sdata.shape[1], len(sdata), sdata)
        print("Fitting model...")
        model.fit(sdata, labels)
        recourse = get_recourse_object(recource_model, model, vae_model)
        #print(model.alpha_target)
        # Create new dict if not existant already.
        results_dict = {}
        results_dict["fold"] = k
        # Select indices in fold
        ind_use = (torch.arange(0, len(data_test)) % config["folds"] == k)
        #ind_use_ids = torch.nonzero(ind_use).flatten()
        #print(ind_use_ids)
        #print(ind_use)
        data_test_fold = data_test[ind_use]
        pred_test_fold = model.predict_with_logits(data_test_fold)

        data_test_fold_neg = data_test_fold[pred_test_fold < 0]
        pred_test_fold_neg = pred_test_fold[pred_test_fold < 0]
        # Compute recourse
        print(f"Computing recourse for {len(data_test_fold_neg)} samples.")
        results_dict["n_recourse_cand"] = len(data_test_fold_neg)
        mycfs = []
        for i in tqdm(range(len(data_test_fold_neg))):
            org_point = data_test_fold_neg[i].reshape(1,-1)
            #print(org_point)
            res = recourse.generate_counterfactuals(org_point, target_class=1)
            mycfs.append(res) 
        mycfs = [i.reshape(1, -1) if i is not None else torch.empty(0, data_test_fold.size(1)) for i in mycfs]
        mycfs = torch.cat(mycfs, dim=0)
        #print(mycfs.grad_fn)
        results_dict["n_recourse_found"] = len(mycfs)
        results_dict["recourse_success_rate"] = len(mycfs)/len(data_test_fold_neg)
        for attackmethod in config["attacks"]:
            attack_fn, addtional_params = get_attack_params(attackmethod)
            model_inval = copy.deepcopy(model)
            invalidation_rates, indices_removed = compute_invalidation_curve(model_inval, mycfs, sdata, labels, 
                    attack_fn, n_points_removed=config["max_deleted_points"], params=addtional_params)
            results_dict[f"invalidation_{attackmethod}"] = invalidation_rates.flatten().numpy().tolist()
            results_dict[f"indices_{attackmethod}"] = indices_removed.flatten().numpy().tolist()
        update_results(results_dict, ds, modeltype, recource_model, k, outputfile)
