"""
Script to run the deletion attacks.

Run as follows:
export PYTHONPATH="."
## Usage: python3 deletion_sgd_batch.py <dataset> <model> <recoursemethod> <params_file> <outputfile> <fold>

<dataset> can be admission
The attacks to run ("sgd" and/or "random") are listed under the "attacks" key of <params_file>.
"""

import sys
import json
import os
from turtle import up

from ML_Models.data_loader import *
from ML_Models.LR.model import Regression
from ML_Models.KernelRidgeRegression.model import KernelRidgeRegression, compute_kernel_derivative, squared_exp_kernel_deriv
from Recourse_Methods.Generative_Model.vae_config import *
from Scripts.deletion_attack import get_attack_params, get_data, get_generative_model, get_model, get_recourse_object

# Attack mechanisms
from Tools.deletion_methods import select_random_point, find_impactful_data_greedy, compute_invalidation_curve
from tqdm import tqdm
import copy
import torch.distributions.normal as normal_distribution
from Tools.eval_deletion import action_instab_summed, outcome_instab_summed
from Tools import jackknife
## Usage: python3 deletion_attack_batch.py <dataset> <model> <recoursemethod> <params_file> <outputfile> <fold>

def update_results(results_dict, ds, modeltype, recource_model, fold, outputfile):
    """Merge the results of one fold into the JSON log stored at ``outputfile``.

    The log is a nested mapping
    ``{dataset: {model type: {recourse method: [results of fold 0, ...]}}}``.
    Missing levels are created on demand, and a fresh log is started when
    ``outputfile`` does not exist yet.

    Args:
        results_dict: JSON-serializable results of the current run.
        ds: Key of the first nesting level (dataset).
        modeltype: Key of the second nesting level (model type).
        recource_model: Key of the third nesting level (recourse method).
        fold: Index of the fold. An existing entry at this index is
            overwritten; if the list has at most ``fold`` entries, the results
            are appended at the end instead.
        outputfile: Path of the JSON log file that is read and rewritten.
    """
    if os.path.isfile(outputfile):
        with open(outputfile) as infile:
            main_results_dict = json.load(infile)
    else:
        main_results_dict = {}

    # Walk down dataset -> model type -> recourse method, creating levels as needed.
    fold_results = (
        main_results_dict
        .setdefault(ds, {})
        .setdefault(modeltype, {})
        .setdefault(recource_model, [])
    )

    if len(fold_results) <= fold:
        fold_results.append(results_dict)
    else:
        fold_results[fold] = results_dict

    # Serialize before opening the file for writing, so that a serialization
    # error cannot leave a truncated log behind.
    serialized = json.dumps(main_results_dict)
    with open(outputfile, "w") as outfile:
        outfile.write(serialized)


class OptimizeDataWeightsMC:
    """Optimize per-sample training data weights to destabilize recourse.

    A vector of data weights (one entry per training point) is optimized with
    Adam. The objective combines a Monte-Carlo estimate of the (negated)
    recourse instability under noisy, clamped weights with ``alpha`` times a
    smooth surrogate that is evaluated on ``1 - weights``.

    The model passed in must provide ``compute_parameters_from_data_weights``,
    ``predict_with_logits`` and ``predict_from_parameters``; the action
    objective and the refit step use further model methods (see ``get_delta``
    and ``optimize_objective``).
    """

    def __init__(self, model, data_train: torch.Tensor, labels_train: torch.Tensor, losstype="output",
                 lam: float = 0.0, lr: float = 0.005, max_iter: int = 1000, alpha: float = 1, sigma: float = 0.2,
                 norm_delta: int = 1, target_score: float = 0, k_mc_samples: int = 1000, solve_ana=True, refit_iter = int(1e8)):
        """Store the model, the training data and the optimization settings.

        Args:
            model: Model whose parameters can be recomputed from data weights.
            data_train: Training inputs.
            labels_train: Training labels (converted to float).
            losstype: ``"action"`` selects the action instability loss; every
                other value selects the outcome instability loss.
            lam: Regularization constant added to the denominator of the
                analytical counterfactual step.
            lr: Learning rate of the Adam optimizer.
            max_iter: Number of optimization iterations.
            alpha: Weight of the sparsity surrogate in the total loss.
            sigma: Noise scale of the stochastic data weights.
            norm_delta: Stored on the instance; the loss functions take their
                own ``norm_delta`` argument.
            target_score: Target logit used by the analytical counterfactual step.
            k_mc_samples: Number of Monte-Carlo samples per loss evaluation.
            solve_ana: True if the counterfactual problem should be solved
                analytically, False to use the jackknife approximation.
            refit_iter: The model is refitted every ``refit_iter`` iterations.
        """
        self.data_train = data_train
        self.labels_train = labels_train.float()
        self.k_mc_samples = k_mc_samples
        self.target_score = target_score
        self.norm_delta = norm_delta
        self.max_iter = max_iter
        self.sigma = sigma
        self.alpha = alpha
        self.lam = lam
        self.lr = lr
        self.model = model
        self.objective = losstype
        self.loss_fn = self.loss_fn_action_instability if losstype=="action" else self.loss_fn_outcome_instability
        self.solve_ana = solve_ana
        self.refit_iter = refit_iter

    ## Helper functions for loss estimation.
    def _get_z(self, mu, eps):
        """Return the noisy weights ``mu + eps`` clamped element-wise to [0, 1]."""
        z = torch.max(torch.tensor(0), torch.min(torch.tensor(1), mu + eps))
        return z

    def compute_beta(self, data_weights):
        """Return the model parameters that result from the given data weights."""
        beta = self.model.compute_parameters_from_data_weights(data_weights, self.data_train, self.labels_train)
        return beta

    def get_delta(self, beta, x, delta_star, analytical=True, xcf=None):
        """Compute the counterfactual perturbation under parameters ``beta``.

        Vectorized over the rows of ``x``; one perturbation is returned per row.

        Args:
            beta: Model parameters obtained from (perturbed) data weights.
            x: Factual inputs, one per row.
            delta_star: Perturbations found for the original model
                (counterfactual minus factual).
            analytical: If True, use the closed-form step (supported for
                ``Regression`` and ``KernelRidgeRegression`` models). Otherwise
                use the jackknife approximation around the current model.
            xcf: Counterfactuals of the original model. Only used by the
                jackknife branch; reconstructed as ``x + delta_star`` when omitted.

        Raises:
            TypeError: If ``analytical`` is True and the model type is not supported.
        """
        if analytical:
            # Exact type match is kept (instead of isinstance) so that the
            # dispatch for subclasses stays as it was.
            if type(self.model) is Regression:
                # Subtract offset: the last parameter entry is dropped.
                beta = beta.flatten()
                coefficients = beta[:-1]
                residual = self.target_score - self.model.predict_with_logits(x)
                step = residual / (self.lam + torch.norm(coefficients, 2)**2)
                return step.reshape(-1,1).matmul(coefficients.reshape(1,-1))
            if type(self.model) is KernelRidgeRegression:
                # Compute a linear approximation for the kernel regression model.
                linear_beta = self.model.kernel_deriv(x, self.data_train).transpose(1,2).matmul(beta.reshape(-1,1)).squeeze(2)
                print(linear_beta.shape)
                residual = self.target_score - self.model.predict_with_logits(x).reshape(-1,1)
                delta = (residual / (self.lam + torch.sum(linear_beta.pow(2), dim=1)).reshape(-1,1)) * linear_beta
                return delta # shape [len(x), D]
            raise TypeError(
                f"Analytical counterfactual step is not implemented for model type {type(self.model).__name__}."
            )

        if xcf is None:
            xcf = x + delta_star
        print(x.shape, xcf.shape, delta_star.shape)
        # Use the jackknife approximation on the counterfactual problem again.
        objective = jackknife.scfe_recourse_objective(self.model, 0.0)
        # Jacobian of the counterfactuals w.r.t. the model parameters.
        jac_mat = jackknife.jackknife_compute_jacobians(xcf, objective, self.model.get_all_params().flatten().clone().requires_grad_(True), x)
        param_shift = self.model.get_all_params().flatten() - beta
        print(torch.norm(param_shift))
        return (jac_mat @ param_shift.reshape(-1,1)).squeeze(2) + delta_star

    def loss_ell0_surrogate(self, dat_weights_mu: torch.Tensor):
        """Return the mean standard-normal CDF of ``dat_weights_mu / sigma``.

        The mean is taken over the first dimension. For an empty input the
        integer 0 is returned.
        """
        n_weights = dat_weights_mu.shape[0]
        if n_weights == 0:
            return 0
        normal = normal_distribution.Normal(loc=0, scale=1)
        # Single vectorized CDF evaluation instead of one call per weight.
        return normal.cdf(dat_weights_mu / self.sigma).sum(dim=0) / n_weights

    def loss_delta_mc(self, dat_weights_mu, x, xcf, delta_star, org_output):
        """Monte-Carlo estimate of the instability loss under noisy data weights.

        For each of the ``k_mc_samples`` draws, Gaussian noise is added to the
        weights, the result is clamped to [0, 1], the model parameters are
        recomputed from these weights and the selected loss is evaluated.

        Returns:
            The average loss over all Monte-Carlo samples.
        """
        n_weights = dat_weights_mu.shape[0]
        # NOTE(review): the noise standard deviation is sigma**2 here, while the
        # sparsity surrogate scales by sigma -- confirm this is intended.
        noise_std = float(self.sigma) ** 2
        loss_delta = 0
        for _ in range(self.k_mc_samples):
            eps = torch.normal(0.0, noise_std, size=(n_weights,))
            z = self._get_z(dat_weights_mu, eps)
            beta = self.model.compute_parameters_from_data_weights(z, self.data_train, self.labels_train)
            if self.objective == "action":
                loss_delta += self.loss_fn_action_instability(beta, x, delta_star, xcf=xcf)
            else:
                loss_delta += self.loss_fn_outcome_instability(beta, xcf, org_output)

        return loss_delta/self.k_mc_samples

    def optimize_objective(self, dat_weights_mu, x, mycf):
        """Optimize the data weights and return them.

        Args:
            dat_weights_mu: Initial data weights vector. It is modified in
                place: a small random perturbation is added, gradients are
                enabled on it and the optimizer updates it directly.
            x: Factual inputs, one per row.
            mycf: Counterfactual explanations for the rows of ``x``.

        Returns:
            The optimized data weights, detached from the autograd graph.
        """
        # Perturbation and output of the counterfactuals under the original model.
        delta_star = mycf - x
        f_star = self.model.predict_with_logits(mycf)
        print(f_star)
        # Break the symmetry of the initial weights.
        dat_weights_mu += 0.001*torch.randn_like(dat_weights_mu)

        dat_weights_mu.requires_grad_(True)
        optim = torch.optim.Adam([dat_weights_mu], self.lr)
        for i in range(self.max_iter):
            # compute losses
            loss_ell0 = self.loss_ell0_surrogate(1-dat_weights_mu)
            print('ell0 loss:', loss_ell0)
            loss_delta = self.loss_delta_mc(dat_weights_mu, x=x, xcf=mycf, delta_star=delta_star, org_output=f_star)
            print('delta loss:', loss_delta)
            # NOTE(review): loss_delta enters the total loss twice -- kept as is
            # to leave the optimization unchanged; confirm whether intended.
            total_loss = loss_delta + self.alpha * loss_ell0 + loss_delta
            # update weights; steps with NaN gradients are skipped
            optim.zero_grad()
            total_loss.backward()
            if not dat_weights_mu.grad.isnan().any():
                optim.step()
                print(total_loss.detach(), torch.sum(dat_weights_mu>0.99))
            # Refit:
            if i % self.refit_iter == self.refit_iter - 1:
                print("Refitting model.")
                z = torch.max(torch.tensor(0), torch.min(torch.tensor(1), dat_weights_mu))
                self.model.data_weights_vector.data = z
                self.model.fit(self.data_train, self.labels_train)
        return dat_weights_mu.detach()

    ## Loss functions
    def loss_fn_outcome_instability(self, beta, xcf, org_output=0):
        """Negated mean of ``|sigmoid(10 * f_beta(xcf)) - 1|`` over the counterfactuals.

        ``f_beta`` is the model output under parameters ``beta``;
        ``org_output`` is accepted but not used.
        """
        return -torch.mean(torch.abs(torch.sigmoid(10.0*self.model.predict_from_parameters(xcf, beta.reshape(1, -1)))-1))

    def loss_fn_action_instability(self, beta, x, delta_star=0, org_output=0, norm_delta=2, xcf=None):
        """Negated mean norm of the change in the counterfactual perturbation.

        Args:
            beta: Model parameters obtained from (perturbed) data weights.
            x: Factual inputs, one per row.
            delta_star: Perturbations found for the original model.
            org_output: Accepted but not used.
            norm_delta: Order of the norm taken over each row.
            xcf: Counterfactuals of the original model, forwarded to ``get_delta``.
        """
        delta = self.get_delta(beta, x=x, delta_star=delta_star, analytical=self.solve_ana, xcf=xcf)
        loss_delta = -torch.mean(torch.norm(delta_star - delta, norm_delta, dim=1))
        return loss_delta

if __name__ == "__main__":
    # Command line: <dataset> <model> <recoursemethod> <params_file> <outputfile> <fold>
    if len(sys.argv) < 7:
        print("Please pass the following arguments: <dataset> <model> <recoursemethod> <params_file> <outputfile> <fold>")
        sys.exit(-1)

    ds = sys.argv[1]
    modeltype = sys.argv[2]
    recource_model = sys.argv[3]
    params_file = sys.argv[4]
    outputfile = sys.argv[5]
    fold_id = int(sys.argv[6])

    print("Config file: ", params_file)
    with open(params_file) as config_file:
        config = json.load(config_file)

    print("Running dataset:", ds)
    sdata, labels_load, data_test, labels_test_load, bin_mode = get_data(ds)
    print("Loading VAE")
    vae_dict = vae_meta_dictionary[ds]
    vae_model = get_generative_model(sdata.shape[1], vae_dict)

    print("Using model:", modeltype)

    k = fold_id

    # This first instance is only used to learn whether the model needs binary labels.
    model, binary = get_model(modeltype, sdata.shape[1], len(sdata), sdata)

    if binary: # Convert labels to binary.
        if bin_mode == "plain":
            labels = labels_load.long()
            labels_test = labels_test_load.long()
        elif bin_mode == "mean":
            # The threshold is always computed on the training labels.
            labels = (labels_load > labels_load.mean()).long()
            labels_test = (labels_test_load > labels_load.mean()).long()
        elif bin_mode == "median":
            labels = (labels_load > labels_load.median()).long()
            labels_test = (labels_test_load > labels_load.median()).long()
        else:
            raise ValueError("Unsupported Binary mode. The data set is not properly configured for a classification task.")
    else:
        labels = labels_load
        labels_test = labels_test_load

    # Create model and recourse object.
    model, binary = get_model(modeltype, sdata.shape[1], len(sdata), sdata)
    print("Fitting model...")
    model.fit(sdata, labels)
    recourse = get_recourse_object(recource_model, model, vae_model)

    results_dict = {}
    results_dict["fold"] = k
    # Select the test points of this fold: every config["folds"]-th point, offset by k.
    ind_use = (torch.arange(0, len(data_test)) % config["folds"] == k)
    data_test_fold = data_test[ind_use]
    pred_test_fold = model.predict_with_logits(data_test_fold)
    # Recourse is only computed for points with a negative logit.
    data_test_fold_neg = data_test_fold[pred_test_fold < 0]
    pred_test_fold_neg = pred_test_fold[pred_test_fold < 0]

    # Compute recourse
    print(f"Computing recourse for {len(data_test_fold_neg)} samples.")
    results_dict["n_recourse_cand"] = len(data_test_fold_neg)
    mycfs = []
    x_orgs = [] # List of x's for which a CF has been found.
    for i in tqdm(range(len(data_test_fold_neg))):
        org_point = data_test_fold_neg[i].reshape(1,-1)
        res = recourse.generate_counterfactuals(org_point, target_class=1)
        if res is not None:
            mycfs.append(res.reshape(1, -1))
            x_orgs.append(data_test_fold_neg[i].reshape(1,-1))

    if not mycfs:
        # Without any counterfactual there is nothing to attack; fail with a
        # clear message instead of an opaque concatenation error.
        raise RuntimeError(
            f"No counterfactual was found for any of the {len(data_test_fold_neg)} recourse candidates in fold {k}."
        )

    mycfs = torch.cat(mycfs, dim=0)
    x_orgs = torch.cat(x_orgs, dim=0)
    results_dict["n_recourse_found"] = len(mycfs)
    results_dict["recourse_success_rate"] = len(mycfs)/len(data_test_fold_neg)

    # Only set by the "sgd" attack; stays None if that attack has not run.
    data_weights_mu = None
    for attack in config["attacks"]:
        if attack == "sgd":
            dw_optimizer = OptimizeDataWeightsMC(model, data_train=sdata, labels_train=labels, losstype=config["losstype"],
                    lam=config["lambda"], max_iter=config["max_iter"], alpha = config["alpha"], k_mc_samples=config["k_mc_samples"],
                    sigma=config["sigma"], solve_ana = config["analytical"], refit_iter=config["refit_iter"])
            data_weights_mu = torch.ones(len(sdata))
            mu_ret = dw_optimizer.optimize_objective(data_weights_mu, x_orgs, mycfs)
            # Points with the smallest optimized weight are deleted first.
            deleted_idx = torch.argsort(mu_ret)
        elif attack == "random":
            deleted_idx = torch.randperm(len(sdata))
        else:
            raise ValueError("Unknown attack")

        results_dict[f"indices_{attack}"] = deleted_idx[:config["max_deleted_points"]].numpy().tolist()
        vlist = []
        # Evaluate the instability after deleting the first 0, 1, ..., max_deleted_points points.
        for kpnt in range(config["max_deleted_points"]+1):
            new_weights = torch.ones(len(sdata), dtype=torch.long)
            new_weights[deleted_idx[:(kpnt)]] = 0
            if data_weights_mu is not None:
                print(data_weights_mu[deleted_idx[:(kpnt)]])
            if config["losstype"] == "outcome":
                vlist.append(outcome_instab_summed(model, new_weights, sdata, labels, mycfs, thres=0))
            else:
                vlist.append(action_instab_summed(model, new_weights, sdata, labels, x_orgs, mycfs, recourse_str=recource_model, vae_model=vae_model))
        results_dict[f"values_{attack}"] = vlist

    update_results(results_dict, ds, modeltype, recource_model, k, outputfile)



