import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from pytorch_wavelets import DWT1DForward, DWT1DInverse
from scipy.integrate import trapezoid
from sklearn.metrics import auc

# Torch device shared by every metric in this module: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def add_noise(ts):
    """Return a flattened copy of `ts` with Gaussian noise added to every point.

    The noise has mean 0 and standard deviation 0.1 and is drawn from NumPy's
    global random state, so seed `np.random` first for reproducible output.

    Args:
        ts: array whose elements all lie along the first axis, e.g. shape
            (n,) or (n, 1); it is reshaped to (n,) before the noise is added.

    Returns:
        1-D array of length `ts.shape[0]` holding the noisy series.
    """
    n_points = ts.shape[0]
    gaussian = np.random.normal(0, 0.1, n_points)
    flat_series = ts.reshape(n_points)
    return np.add(flat_series, gaussian)

def robustness(explanations, noisy_explanations):
    """Fraction of explanations whose feature ranking survives input noise.

    For every index, the weight dictionaries stored at position 0 of the clean
    and the noisy explanation are read in sorted-key order and ranked with
    `np.argsort`. A pair counts as robust only when both rankings have the same
    length and are element-wise identical.

    Args:
        explanations: sequence whose items hold a {key: weight} dict at [0].
        noisy_explanations: same layout, computed on the noisy inputs; must
            be at least as long as `explanations`.

    Returns:
        Number of robust pairs divided by `len(explanations)`.

    Raises:
        ZeroDivisionError: if `explanations` is empty.
    """
    def _ranking(weights):
        # Read the weights in a key order that does not depend on dict insertion order.
        ordered_values = [weights[key] for key in sorted(weights)]
        return np.argsort(np.array(ordered_values))

    stable = 0
    for idx in range(len(explanations)):
        clean_rank = _ranking(explanations[idx][0])
        noisy_rank = _ranking(noisy_explanations[idx][0])
        if len(clean_rank) != len(noisy_rank):
            continue
        if np.array_equal(clean_rank, noisy_rank):
            stable += 1
    return stable / len(explanations)

def reverse_segment(ts, index0, index1):
    """Return a copy of `ts` with the slice [index0:index1] reversed along axis 0.

    Only the first axis is flipped, so for a multi-column series the rows of
    the segment change order while the columns of each row stay in place.

    Args:
        ts: NumPy array; the first axis is the one being sliced.
        index0: start of the segment (inclusive).
        index1: end of the segment (exclusive); negative values follow the
            usual slice semantics.

    Returns:
        A new array; `ts` itself is not modified.

    Raises:
        ValueError: if `index1` is larger than `ts.shape[0]`.
    """
    if ts.shape[0] < index1:
        # Raise instead of print + exit(): a bare exit() ends the whole process
        # with a success status and hides the failure from the caller.
        raise ValueError(
            f"segment end {index1} exceeds series length {ts.shape[0]} "
            f"(shape={ts.shape}, index0={index0})"
        )
    perturbed_ts = ts.copy()
    # axis=0 keeps np.flip from also reversing any trailing (channel) axes.
    perturbed_ts[index0:index1] = np.flip(ts[index0:index1], axis=0)
    return perturbed_ts

def global_perturbed_segment(ts, index0, index1, mean, std):
    """Return a copy of `ts` whose rows [index0:index1] are replaced by Gaussian noise.

    `index1 - index0` values are drawn from a normal distribution with the
    given `mean` and `std` (NumPy global random state) and written as a
    single column, so `ts` needs a trailing axis the column can broadcast
    over, e.g. shape (n, 1).

    Args:
        ts: NumPy array to perturb; it is left untouched.
        index0: first row of the segment (inclusive).
        index1: end row of the segment (exclusive).
        mean: mean of the replacement noise.
        std: standard deviation of the replacement noise.

    Returns:
        The perturbed copy of `ts`.
    """
    span = index1 - index0
    noise_column = np.random.normal(loc=mean, scale=std, size=span).reshape(span, 1)
    result = ts.copy()
    result[index0:index1] = noise_column
    return result

def segment_perturb(ts, position, pertubation):
    """Return a copy of `ts` with the entries selected by `position` set to zero.

    Args:
        ts: NumPy array to perturb; it is left untouched.
        position: any valid NumPy index (list of indices, slice, mask, ...).
        pertubation: not used by this function.

    Returns:
        The zeroed copy of `ts`.
    """
    zeroed = ts.copy()
    zeroed[position] = 0
    return zeroed

def insert_segment_perturb(ts, position, pertubation):
    """Return a copy of `pertubation` with the entries at `position` taken from `ts`.

    Args:
        ts: array supplying the values to keep.
        position: any valid NumPy index selecting the entries copied from `ts`.
        pertubation: background array; it is copied, not modified.

    Returns:
        The background copy with the selected entries overwritten.
    """
    blended = pertubation.copy()
    blended[position] = ts[position]
    return blended

def delete_segment(ts, index0, index1):
    """Return a copy of `ts` with the slice [index0:index1] set to zero.

    Args:
        ts: NumPy array; the first axis is the one being sliced.
        index0: start of the segment (inclusive).
        index1: end of the segment (exclusive).

    Returns:
        A new array; `ts` itself is not modified.

    Raises:
        ValueError: if `index1` is larger than `ts.shape[0]`.
    """
    if ts.shape[0] < index1:
        # Raise instead of print + exit(): a bare exit() ends the whole process
        # with a success status and hides the failure from the caller.
        raise ValueError(
            f"segment end {index1} exceeds series length {ts.shape[0]} "
            f"(shape={ts.shape}, index0={index0})"
        )
    perturbed_ts = ts.copy()
    perturbed_ts[index0:index1] = 0
    return perturbed_ts

def insert_segment(ts, index0, index1):
    """Return a copy of `ts` that keeps only [index0:index1]; everything else is zeroed.

    Args:
        ts: NumPy array; the first axis is the one being sliced.
        index0: start of the segment to keep (inclusive).
        index1: end of the segment to keep (exclusive).

    Returns:
        A new array; `ts` itself is not modified.
    """
    kept = ts.copy()
    for outside in (slice(None, index0), slice(index1, None)):
        kept[outside] = 0
    return kept

def faithfulness(explanations, x_test, y_test, original_predictions, model, model_type):
    """Measure how much reversing each sample's top-weighted segment changes the model output.

    For every explanation, the segment with the largest absolute weight is
    reversed in a copy of the matching test series and the model is queried
    on the reversed series.

    Args:
        explanations: sequence of items where [0] is an array-like of segment
            weights and [1] is a list of segment boundary indices
            (presumably segment start positions -- TODO confirm).
        x_test: indexable collection of NumPy series, aligned with `explanations`.
        y_test: ground-truth labels, used only when `model_type == 'proba'`.
        original_predictions: model output on the unperturbed series (class
            scores for 'proba', labels otherwise).
        model: object exposing `predict`.
        model_type: 'proba' when `model.predict` returns per-class scores; any
            other value is treated as a label-predicting model.

    Returns:
        'proba': mean absolute change of the originally predicted class score
        over the selected samples (`nan` when none is selected).
        otherwise: fraction of samples whose predicted label is unchanged.
    """
    flipped = []
    for idx in range(len(explanations)):
        weights, boundaries = explanations[idx][0], explanations[idx][1]
        best = np.argmax(np.abs(weights))
        # -1 closes the last segment.
        # NOTE(review): with -1 as end index the final point of the series is
        # never part of the reversed slice -- confirm this is intended.
        edges = boundaries + [-1]
        series = x_test[idx].copy()
        flipped.append(reverse_segment(series, edges[best], edges[best + 1]))

    if model_type == 'proba':
        flipped_scores = model.predict(np.asarray(flipped))
        # NOTE(review): samples are selected by comparing the label with the
        # prediction on the *reversed* series, whereas pytorch_faithfulness
        # selects on the original prediction -- confirm which one is wanted.
        selected = [
            row for row in range(len(y_test))
            if y_test[row] == np.argmax(flipped_scores[row])
        ]
        gaps = []
        for row in selected:
            cls = int(np.argmax(original_predictions[row]))
            gaps.append(np.abs(original_predictions[row][cls] - flipped_scores[row][cls]))
        return np.mean(gaps)

    flipped_batch = np.asarray(flipped)
    # Label models receive a 2-D input: only the first two axes are kept.
    flipped_labels = model.predict(flipped_batch.reshape(flipped_batch.shape[:2]))
    unchanged = sum(
        1 for row in range(len(original_predictions))
        if original_predictions[row] == flipped_labels[row]
    )
    return unchanged / len(original_predictions)


def pytorch_faithfulness(explanations, test_dataset, model):
    """Mean change in predicted-class probability after reversing each sample's top segment.

    For each explanation, the segment with the largest absolute weight is
    reversed (via `reverse_segment`) in the flattened test series. The model
    is evaluated on the original and on the reversed batch, and the absolute
    softmax-probability change of the predicted class is averaged over the
    samples the model classifies correctly on the original input.

    Args:
        explanations: sequence of items where [0] is a {segment_key: weight}
            dict and [1] maps each segment key to a (start, end) pair.
        test_dataset: indexable dataset; item [0] is the input tensor and
            item [1] a label tensor. Only the first `len(explanations)` items
            are used.
        model: torch module returning class scores with classes on dim 1.

    Returns:
        Mean absolute probability change as a NumPy scalar; `nan` when no
        sample is classified correctly.
    """
    model.eval()
    original_samples = []
    perturbed_samples = []
    labels = []
    for i in range(len(explanations)):
        test_instance, test_label = test_dataset[i][0], test_dataset[i][1].cpu().item()
        labels.append(test_label)
        original_samples.append(test_instance)

        weights, segment_indices = explanations[i][0], explanations[i][1]
        # The most influential segment is the one with the largest |weight|.
        top_index = max(weights.items(), key=lambda item: abs(item[1]))[0]
        start, end = segment_indices[top_index][0], segment_indices[top_index][1]

        example_ts = test_instance.cpu().numpy().flatten()
        reversed_sample = reverse_segment(example_ts, start, end)
        # Restore a leading axis so the stacked batch is (N, 1, length).
        perturbed_samples.append(torch.Tensor(reversed_sample).unsqueeze(0))

    original_batch = torch.stack(original_samples).to(device)
    perturbed_batch = torch.stack(perturbed_samples).to(device)
    with torch.no_grad():
        original_predictions = torch.softmax(model(original_batch), dim=1).cpu().numpy()
        reversed_predictions = torch.softmax(model(perturbed_batch), dim=1).cpu().numpy()

    # Score only the samples that were actually explained: the predictions have
    # len(explanations) rows, so iterating over the whole dataset would index
    # past them whenever the dataset is longer.
    differences = []
    for i, test_label in enumerate(labels):
        prediction_index = int(np.argmax(original_predictions[i]))
        if test_label == prediction_index:  # only correctly predicted samples
            differences.append(
                np.abs(original_predictions[i][prediction_index]
                       - reversed_predictions[i][prediction_index])
            )
    return np.mean(differences)

def pytorch_faithfulness_perturbed(explanations, test_dataset, model):
    """Mean change in predicted-class probability after zeroing each sample's top segment.

    For each explanation, the segment with the largest absolute weight is set
    to zero (via `delete_segment`) in the flattened test series. The model is
    evaluated on the original and on the perturbed batch, and the absolute
    softmax-probability change of the predicted class is averaged over the
    samples the model classifies correctly on the original input.

    Args:
        explanations: sequence of items where [0] is a {segment_key: weight}
            dict and [1] maps each segment key to a (start, end) pair.
        test_dataset: indexable dataset; item [0] is the input tensor and
            item [1] a label tensor. Only the first `len(explanations)` items
            are used.
        model: torch module returning class scores with classes on dim 1.

    Returns:
        Mean absolute probability change as a NumPy scalar; `nan` when no
        sample is classified correctly.
    """
    model.eval()
    original_samples = []
    perturbed_samples = []
    labels = []
    for i in range(len(explanations)):
        test_instance, test_label = test_dataset[i][0], test_dataset[i][1].cpu().item()
        labels.append(test_label)
        original_samples.append(test_instance)

        weights, segment_indices = explanations[i][0], explanations[i][1]
        # The most influential segment is the one with the largest |weight|.
        top_index = max(weights.items(), key=lambda item: abs(item[1]))[0]
        start, end = segment_indices[top_index][0], segment_indices[top_index][1]

        example_ts = test_instance.cpu().numpy().flatten()
        deleted_sample = delete_segment(example_ts, start, end)
        # Restore a leading axis so the stacked batch is (N, 1, length).
        perturbed_samples.append(torch.Tensor(deleted_sample).unsqueeze(0))

    original_batch = torch.stack(original_samples).to(device)
    perturbed_batch = torch.stack(perturbed_samples).to(device)
    with torch.no_grad():
        original_predictions = torch.softmax(model(original_batch), dim=1).cpu().numpy()
        perturbed_predictions = torch.softmax(model(perturbed_batch), dim=1).cpu().numpy()

    # Score only the samples that were actually explained: the predictions have
    # len(explanations) rows, so iterating over the whole dataset would index
    # past them whenever the dataset is longer.
    differences = []
    for i, test_label in enumerate(labels):
        prediction_index = int(np.argmax(original_predictions[i]))
        if test_label == prediction_index:  # only correctly predicted samples
            differences.append(
                np.abs(original_predictions[i][prediction_index]
                       - perturbed_predictions[i][prediction_index])
            )
    return np.mean(differences)

def pytorch_faithfulness_attribution_dwt(attributions, test_dataset, model, wavelet_level, pertubation, thresholds = [8, 20, 30]):
    """Probability drop of the true class after zeroing the top-k attributed features.

    Every sample is (optionally) moved to the Haar wavelet domain by
    concatenating the approximation and detail coefficients along the last
    axis. For each `k` in `thresholds`, the `k` features with the highest
    attribution are set to zero with `segment_perturb`, and the softmax
    probability of the true label before and after is compared.

    Args:
        attributions: per-sample attribution vectors, ranked with `np.argsort`
            (presumably one value per model-input feature -- TODO confirm).
        test_dataset: indexable dataset; item [0] is the input tensor and
            item [1] an integer label tensor. Only the first
            `len(attributions)` items are used.
        model: torch module returning class scores with classes on dim 1.
        wavelet_level: number of DWT levels; 0 skips the transform.
        pertubation: forwarded to `segment_perturb`.
        thresholds: numbers of top features to zero out. Note that 0 selects
            *all* features because of Python's `[-0:]` slice semantics.

    Returns:
        (results, top_positions): `results` maps each threshold to the mean of
        `p_original - p_perturbed` for the true class; `top_positions` holds,
        per sample, the feature indices selected for the last threshold.
    """
    dwt1d = DWT1DForward(wave='haar', J=wavelet_level).to(device)
    model.eval()
    n_samples = len(attributions)

    # --- threshold-independent work, done once ---------------------------
    model_inputs = []
    labels = []
    for i in range(n_samples):
        test_instance, test_label = test_dataset[i][0], test_dataset[i][1].cpu().item()
        labels.append(test_label)
        test_instance = test_instance.unsqueeze(dim=0).to(device)
        if wavelet_level != 0:
            testl, testh = dwt1d(test_instance)
            # Approximation followed by the detail bands, joined on the last axis.
            test_instance = torch.cat([testl] + testh, dim=-1)
        model_inputs.append(test_instance)

    flat_inputs = [sample.cpu().numpy().flatten() for sample in model_inputs]
    feature_orders = [np.argsort(attributions[i]) for i in range(n_samples)]

    with torch.no_grad():
        original_predictions = torch.softmax(
            model(torch.cat(model_inputs, dim=0).to(device)), dim=1
        )
    original_predictions = original_predictions.cpu().numpy()

    # --- one perturbed forward pass per threshold ------------------------
    results_dict = {}
    top_positions_sets = []
    for threshold in thresholds:
        record_positions = threshold == thresholds[-1]
        perturbed_samples = []
        for i in range(n_samples):
            # argsort is ascending, so the tail holds the highest attributions.
            top_features = feature_orders[i][-threshold:]
            if record_positions:
                top_positions_sets.append(top_features.tolist())
            perturbed_sample = segment_perturb(flat_inputs[i], top_features, pertubation)
            # Back to (1, 1, length) so the batch matches the model input layout.
            perturbed_samples.append(torch.Tensor(perturbed_sample).unsqueeze(0).unsqueeze(0))

        with torch.no_grad():
            pertubed_predictions = torch.softmax(
                model(torch.cat(perturbed_samples, dim=0).to(device)), dim=1
            )
        pertubed_predictions = pertubed_predictions.cpu().numpy()

        # Iterate over the explained samples only: predictions have n_samples
        # rows, so looping over the whole dataset would index past them when
        # the dataset is longer than `attributions`.
        differences = [
            original_predictions[i][label] - pertubed_predictions[i][label]
            for i, label in enumerate(labels)
        ]
        results_dict[threshold] = np.mean(differences).item()
    return results_dict, top_positions_sets

def segment_pertube_with_mask(ts, mask):
    """Blend `ts` with standard-normal noise according to `mask`.

    Computes `mask * ts + (1 - mask) * noise`, so entries where `mask` is 1
    keep their value and entries where it is 0 are replaced by noise drawn
    from N(0, 1) with NumPy's global random state.

    Args:
        ts: NumPy array to perturb; it is not modified.
        mask: array broadcastable against `ts`.

    Returns:
        The blended array.
    """
    noise = np.random.normal(loc=0.0, scale=1.0, size=ts.shape)
    kept_part = mask * ts
    replaced_part = (1 - mask) * noise
    return kept_part + replaced_part

def pytorch_AUCStop_attribution_dwt(attributions, test_dataset, model, wavelet_level, pertubation):
    """Score attributions by replacing top- and bottom-ranked features with noise.

    For every sample and every ratio in `ratios`, two perturbed inputs are
    built with `segment_pertube_with_mask`: one where the features whose
    attribution is at or above the (1 - ratio) quantile are replaced by
    standard-normal noise, and one where the features at or below the `ratio`
    quantile are replaced. The relative change in the softmax probability of
    the true label is integrated with the trapezoid rule over the fraction of
    replaced features. Only samples whose unperturbed prediction equals the
    label contribute to the averages.

    Args:
        attributions: per-sample NumPy attribution arrays (`.astype` and
            `.shape` are used on the comparison results). Presumably one value
            per model-input feature -- TODO confirm.
        test_dataset: indexable dataset; item [0] is the input tensor and
            item [1] an integer label tensor.
        model: torch module; softmax is applied over dim 1 of its output.
        wavelet_level: number of Haar DWT levels; 0 skips the transform.
        pertubation: not used by this function.

    Returns:
        (mean AUCStop score, mean F1S score) over the correctly classified
        samples.

    Raises:
        ZeroDivisionError: when no sample is classified correctly, because the
            score lists are then empty.
    """
    dwt1d = DWT1DForward(wave='haar', J=wavelet_level).to(device)
    # NOTE(review): idwt1d, test_series, results_dict, top_positions_sets and
    # AUCS_bottom_scores are assigned but never read afterwards in this function.
    idwt1d = DWT1DInverse(wave='haar').to(device)
    test_series, _ = test_dataset[0][0], test_dataset[0][1].cpu().item()
    results_dict = {}
    top_positions_sets = []
    # Fractions of features to replace; 1.0 replaces every feature.
    ratios = [0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.0]
    AUCStop_scores = []
    AUCS_bottom_scores = []
    F1S_scores = []
    for i in range(0,len(attributions)):
        attribution = attributions[i]
        test_instance, test_label = test_dataset[i][0], test_dataset[i][1].cpu().item()
        test_instance = test_instance.unsqueeze(dim=0).to(device)
        if wavelet_level!=0:
            # Wavelet-domain input: approximation followed by the detail bands.
            testl, testh = dwt1d(test_instance)
            wave_input = [testl] + testh
            test_instance = torch.cat(wave_input, dim=-1)
        example_ts = test_instance.cpu().numpy().copy()
        model.eval()
        perturbed_samples = []
        bottom_perturbed_samples = []
        original_samples = []
        # Both curves start at (0 features replaced, 0 change); the matching
        # zero score is prepended to the change scores below.
        n_features_ratios = [0]
        bottom_n_features_ratios = [0]
        for ratio in ratios:
            # Mask convention of segment_pertube_with_mask: 0 = replace with
            # noise, 1 = keep. Here the top-attributed features get 0.
            threshold = np.quantile(attribution, 1-ratio)
            mask = (attribution>=threshold).astype(int)
            mask = 1- mask
            
            # Same convention for the lowest-attributed features.
            bottom_threshold = np.quantile(attribution, ratio)
            bottom_mask = (attribution<=bottom_threshold).astype(int)
            bottom_mask = 1- bottom_mask
            
            # x-coordinates of the curves: actual fraction of replaced features
            # (ties at the quantile can make it differ from `ratio`).
            n_features_ratios.append(float(np.sum(mask==0)/mask.shape[0]))
            bottom_n_features_ratios.append(float(np.sum(bottom_mask==0)/bottom_mask.shape[0]))
            example_ts = example_ts.flatten()
            
            # The top perturbation is drawn before the bottom one; this order
            # determines which random numbers each of them receives.
            perturbed_sample = segment_pertube_with_mask(example_ts, mask)
            perturbed_sample = torch.Tensor(perturbed_sample).unsqueeze(0).unsqueeze(0)
            perturbed_samples.append(perturbed_sample)
            
            bottom_perturbed_sample = segment_pertube_with_mask(example_ts, bottom_mask)
            bottom_perturbed_sample = torch.Tensor(bottom_perturbed_sample).unsqueeze(0).unsqueeze(0)
            bottom_perturbed_samples.append(bottom_perturbed_sample)
            
            # The unperturbed input is repeated once per ratio so the three
            # batches line up row by row.
            original_samples.append(test_instance)
        original_samples = torch.cat(original_samples, dim=0)
        perturbed_samples = torch.cat(perturbed_samples, dim=0)
        bottom_perturbed_samples = torch.cat(bottom_perturbed_samples, dim=0)
        with torch.no_grad():
            predictions = model(original_samples.to(device))
            predictions = torch.nn.Softmax(dim=1)(predictions)
            perturb_predictions = model(perturbed_samples.to(device))
            perturb_predictions = torch.nn.Softmax(dim=1)(perturb_predictions)
            bottom_perturb_predictions = model(bottom_perturbed_samples.to(device))
            bottom_perturb_predictions = torch.nn.Softmax(dim=1)(bottom_perturb_predictions)
        label_pred = torch.argmax(predictions[0]).cpu().item()
        # Misclassified samples are skipped entirely.
        if label_pred == test_label:
            original_pred = predictions[0, test_label].cpu().item()
            # Relative drop of the true-class probability for each ratio.
            change_scores = ((predictions[:, test_label] - perturb_predictions[:, test_label])/original_pred).cpu().numpy()
            change_scores = np.concatenate((np.array([0]), change_scores), axis=0)
            bottom_change_scores = ((predictions[:, test_label] - bottom_perturb_predictions[:, test_label])/original_pred).cpu().numpy()
            bottom_change_scores = np.concatenate((np.array([0]), bottom_change_scores), axis=0)
            
            # Area under the (replaced fraction, relative drop) curve.
            AUCStop_score = trapezoid(change_scores, np.array(n_features_ratios))
            AUCStop_scores.append(AUCStop_score) 
            AUCS_bottom_score = trapezoid(bottom_change_scores, np.array(bottom_n_features_ratios))
            AUCS_bottom_scores.append(AUCS_bottom_score)
            # NOTE(review): product-over-sum of the two terms, i.e. a harmonic
            # mean without the usual factor 2 -- confirm this is intended. The
            # denominator is zero when AUCStop_score == AUCS_bottom_score - 1.
            F1S_score = (AUCStop_score*(1-AUCS_bottom_score))/(AUCStop_score+(1-AUCS_bottom_score))
            F1S_scores.append(F1S_score)
    # NOTE(review): both divisions raise ZeroDivisionError when no sample was
    # classified correctly -- decide whether nan or an explicit error is wanted.
    AUCStop_mean_score = sum(AUCStop_scores)/len(AUCStop_scores)
    F1S_mean_score = sum(F1S_scores)/len(F1S_scores) 
    return AUCStop_mean_score, F1S_mean_score



def post_pytorch_faithfulness_attribution_dwt(positions, test_dataset, model, wavelet_level, pertubation, thresholds = [8, 20]):
    """Probability drop of the true class after zeroing precomputed top-k positions.

    Every sample is moved to the Haar wavelet domain by concatenating the
    approximation and detail coefficients along the last axis (the transform
    is always applied here). For each `k` in `thresholds`, the last `k`
    entries of the sample's position list are set to zero with
    `segment_perturb`, and the softmax probability of the true label before
    and after is compared.

    Args:
        positions: per-sample sequences of feature indices; the last `k`
            entries are used (presumably ordered by ascending importance --
            TODO confirm).
        test_dataset: indexable dataset; item [0] is the input tensor and
            item [1] an integer label tensor. Only the first `len(positions)`
            items are used.
        model: torch module returning class scores with classes on dim 1.
        wavelet_level: number of DWT levels.
        pertubation: forwarded to `segment_perturb`.
        thresholds: numbers of positions to zero out. Note that 0 selects
            *all* positions because of Python's `[-0:]` slice semantics.

    Returns:
        (results, top_positions): `results` maps each threshold to the mean of
        `p_original - p_perturbed` for the true class; `top_positions` holds,
        per sample, the positions selected for the last threshold.
    """
    dwt1d = DWT1DForward(wave='haar', J=wavelet_level).to(device)
    model.eval()
    n_samples = len(positions)

    # --- threshold-independent work, done once ---------------------------
    model_inputs = []
    labels = []
    for i in range(n_samples):
        test_instance, test_label = test_dataset[i][0], test_dataset[i][1].cpu().item()
        labels.append(test_label)
        test_instance = test_instance.unsqueeze(dim=0).to(device)
        testl, testh = dwt1d(test_instance)
        # Approximation followed by the detail bands, joined on the last axis.
        model_inputs.append(torch.cat([testl] + testh, dim=-1))

    flat_inputs = [sample.cpu().numpy().flatten() for sample in model_inputs]

    with torch.no_grad():
        original_predictions = torch.softmax(
            model(torch.cat(model_inputs, dim=0).to(device)), dim=1
        )
    original_predictions = original_predictions.cpu().numpy()

    # --- one perturbed forward pass per threshold ------------------------
    results_dict = {}
    top_positions_sets = []
    for threshold in thresholds:
        record_positions = threshold == thresholds[-1]
        perturbed_samples = []
        for i in range(n_samples):
            top_features = positions[i][-threshold:]
            if record_positions:
                top_positions_sets.append(top_features)
            perturbed_sample = segment_perturb(flat_inputs[i], top_features, pertubation)
            # (1, length) per sample; stacking yields (N, 1, length).
            perturbed_samples.append(torch.Tensor(perturbed_sample).unsqueeze(0))

        with torch.no_grad():
            pertubed_predictions = torch.softmax(
                model(torch.stack(perturbed_samples).to(device)), dim=1
            )
        pertubed_predictions = pertubed_predictions.cpu().numpy()

        # Iterate over the explained samples only: predictions have n_samples
        # rows, so looping over the whole dataset would index past them when
        # the dataset is longer than `positions`.
        differences = [
            original_predictions[i][label] - pertubed_predictions[i][label]
            for i, label in enumerate(labels)
        ]
        results_dict[threshold] = np.mean(differences).item()
    return results_dict, top_positions_sets

