## Measures for evaluating the robustness of model outputs and counterfactual explanations.
import copy
from statistics import mode
from tqdm import tqdm
from matplotlib.pyplot import get
from ML_Models.base_model import BaseMLModel
from Scripts.deletion_attack import get_recourse_object, get_generative_model
import torch

def outcome_instab_summed(model: BaseMLModel, new_data_weights, data_train, labels_train, xcf, thres=0):
    """Measure how a model's outputs on counterfactuals change after refitting.

    A deep copy of ``model`` gets ``new_data_weights`` written into its
    ``data_weights_vector`` and is refit on ``(data_train, labels_train)``.
    The original and the refit model are then evaluated on ``xcf``.

    Args:
        model: The originally trained model. It is not modified (a deep copy
            is refit).
        new_data_weights: Values assigned to the copy's
            ``data_weights_vector.data`` before refitting.
        data_train: Training inputs passed to ``fit``.
        labels_train: Training labels passed to ``fit``.
        xcf: Counterfactual points on which both models are evaluated.
        thres: Refit-model outputs strictly below this value are counted as
            invalidated.

    Returns:
        dict with
        ``"absolute_change"``: mean (despite the function name, not a sum) of
        the absolute difference between original and refit outputs on ``xcf``;
        ``"invalidation"``: number of refit outputs below ``thres`` divided by
        ``len`` of the output.
    """
    model_refit = copy.deepcopy(model)
    model_refit.data_weights_vector.data = new_data_weights
    model_refit.fit(data_train, labels_train)
    # Compare the signed outputs of both models. Taking abs() of only one side
    # would distort the change whenever that side's output is negative.
    output_prev = model.predict_with_logits(xcf)
    output_new = model_refit.predict_with_logits(xcf)
    abs_chng = torch.mean(torch.abs(output_prev - output_new))
    inval = torch.sum(output_new < thres).float() / len(output_new)
    return {"absolute_change": abs_chng.item(), "invalidation": inval.item()}


def action_instab_summed(model: BaseMLModel, new_data_weights, data_train, labels_train, x, cfs_old, recourse_str="scfe", vae_model=None):
    """Measure how counterfactual explanations change after refitting the model.

    A deep copy of ``model`` gets ``new_data_weights`` written into its
    ``data_weights_vector`` and is refit on ``(data_train, labels_train)``.
    New counterfactuals (target class 1) are then generated for every row of
    ``x`` against the refit model and compared with the matching rows of
    ``cfs_old``. Points for which the recourse method returns ``None`` are
    skipped.

    Args:
        model: The originally trained model. It is not modified and is used
            to score both the old and the new counterfactuals.
        new_data_weights: Values assigned to the copy's
            ``data_weights_vector.data`` before refitting.
        data_train: Training inputs passed to ``fit``.
        labels_train: Training labels passed to ``fit``.
        x: Points to generate recourse for, indexed row by row.
        cfs_old: Previously computed counterfactuals, row-aligned with ``x``.
        recourse_str: Name of the recourse method, passed to
            ``get_recourse_object``.
        vae_model: Optional generative model, passed to
            ``get_recourse_object``.

    Returns:
        dict with
        ``"absolute_change"``: mean L2 distance between each new CF and its
        old CF;
        ``"modeloutputchng"``: mean absolute difference of ``model``'s output
        on old vs. new CFs;
        ``"sign_flips"``: fraction of pairs where ``model``'s output has
        opposite signs on the old and the new CF.
        All three values are NaN if no new counterfactual was found.
    """
    model_refit = copy.deepcopy(model)
    model_refit.data_weights_vector.data = new_data_weights
    model_refit.fit(data_train, labels_train)
    recourse_obj = get_recourse_object(recourse_str, model_refit, vae_model)
    print(f"Evaluating deletion for {torch.sum(new_data_weights).item()} points. Generating new CFs.")
    new_cfs = []
    old_cfs = []
    for i in tqdm(range(len(x))):
        org_point = x[i].reshape(1, -1)
        res = recourse_obj.generate_counterfactuals(org_point, target_class=1)
        # Keep only points where recourse succeeded, paired with their old CF,
        # so both lists stay row-aligned.
        if res is not None:
            new_cfs.append(res.reshape(1, -1))
            old_cfs.append(cfs_old[i].reshape(1, -1))
    if not old_cfs:
        # torch.cat would raise on an empty list; report "no data" instead.
        nan = float("nan")
        return {"absolute_change": nan, "modeloutputchng": nan, "sign_flips": nan}
    new_cfs = torch.cat(new_cfs, dim=0)
    old_cfs = torch.cat(old_cfs, dim=0)
    # Score both CF sets once with the original model and reuse the outputs.
    out_old = model.predict_with_logits(old_cfs)
    out_new = model.predict_with_logits(new_cfs)
    model_output_cng = torch.mean(torch.abs(out_old - out_new))
    # A flip is an old/new pair whose outputs have strictly opposite signs.
    sign_flips = torch.sum(torch.sign(out_old) * torch.sign(out_new) < 0).float() / len(old_cfs)
    # Per-row L2 distance (dim=1); without dim, torch.norm returns one
    # Frobenius norm of the whole matrix and the mean would be a no-op.
    abs_chng = torch.mean(torch.norm(new_cfs - old_cfs, dim=1))
    return {"absolute_change": abs_chng.item(), "modeloutputchng": model_output_cng.item(), "sign_flips": sign_flips.item()}
