import re
import os
import glob
import sys
import torch
import random
import torchvision
import numpy as np
import pandas as pd
from scipy.stats import entropy
from PIL import Image

# from tqdm.notebook import tqdm
from tqdm import tqdm
from scipy.stats import wasserstein_distance_nd
from scipy.spatial.distance import jensenshannon
from rex_xai.output.database import db_to_pandas


def calculate_iou(mask1, mask2):
    """
    Compute the Intersection over Union (IoU) of two masks.

    Every strictly positive entry is treated as foreground; everything
    else is background.

    Args:
        mask1: First mask (numpy array or torch tensor)
        mask2: Second mask (numpy array or torch tensor)

    Returns:
        float: IoU in [0, 1]; 1.0 when both masks are empty
    """
    def _as_binary(mask):
        # Anything exposing ``.cpu`` is treated as a torch tensor.
        if hasattr(mask, 'cpu'):
            mask = mask.cpu().numpy()
        return (mask > 0).astype(np.float32)

    first, second = _as_binary(mask1), _as_binary(mask2)

    overlap = np.sum(first * second)
    combined = np.sum(first) + np.sum(second) - overlap

    # An empty union means both masks are empty.
    if combined == 0:
        return 1.0 if overlap == 0 else 0.0

    return overlap / combined

def normalize_landscape(landscape):
    """
    Min-max normalise a landscape into the range [0, 1].

    Args:
        landscape: Landscape (numpy array or array-like)

    Returns:
        numpy array of the same shape, rescaled so its minimum maps to 0
        and its maximum to 1. A constant landscape (max == min) has no
        range to rescale and is returned as all zeros instead of the NaNs
        a 0/0 division would produce.
    """
    landscape = np.asarray(landscape)
    lowest = np.min(landscape)
    value_range = np.max(landscape) - lowest

    # Guard against 0/0 for flat landscapes.
    if value_range == 0:
        return np.zeros_like(landscape, dtype=float)

    return (landscape - lowest) / value_range

def calculate_dice(mask1, mask2):
    """
    Calculate DICE coefficient for two binary masks.

    Every strictly positive entry is treated as foreground.

    Args:
        mask1: Binary mask (numpy array or torch tensor)
        mask2: Binary mask (numpy array or torch tensor)

    Returns:
        float: DICE coefficient between 0 and 1; 1.0 when both masks are empty
    """
    # Convert to numpy if torch tensors
    if hasattr(mask1, 'cpu'):
        mask1 = mask1.cpu().numpy()
    if hasattr(mask2, 'cpu'):
        mask2 = mask2.cpu().numpy()

    # Ensure binary masks
    mask1 = (mask1 > 0).astype(np.float32)
    mask2 = (mask2 > 0).astype(np.float32)

    intersection = np.sum(mask1 * mask2)
    total = np.sum(mask1) + np.sum(mask2)

    # Check for two empty masks *before* dividing, so no 0/0 is evaluated.
    if total == 0:
        return 1.0

    return (2.0 * intersection) / total

def calculate_pixel_contribution(mask, confidence):
    """
    Spread a confidence value evenly over the active pixels of a mask.

    Args:
        mask: Mask (numpy array or torch tensor); entries > 0 are active
        confidence: Float confidence value

    Returns:
        float: confidence divided by the number of active pixels,
        or 0.0 when the mask has no active pixels
    """
    # Anything exposing ``.cpu`` is treated as a torch tensor.
    binary = mask.cpu().numpy() if hasattr(mask, 'cpu') else mask
    binary = (binary > 0).astype(np.float32)

    active = np.sum(binary)

    # Nothing to share the confidence between.
    if active == 0:
        return 0.0

    return confidence / active

def calculate_spectral_entropy(resp_map):
    '''
    Compute the base-2 entropy of a response map.

    The map is divided by its total so it sums to one, and only the
    strictly positive entries are passed to the entropy computation.

    Args:
        resp_map: Response map (numpy array)

    Returns:
        float: Spectral entropy value (in bits)

    '''
    distribution = resp_map / np.sum(resp_map)
    positive_mass = distribution[distribution > 0]
    return entropy(positive_mass, base=2)

def _count_retained_predictions(batch_data, batch_classes, mask_data,
                                mask_val, model, ad_model, device):
    """Return how many masked inputs of one batch keep their original class."""
    if not batch_data:
        return 0

    if mask_val != 'AD':
        batched_data = torch.stack(batch_data, dim = 0)
        batched_classes = torch.tensor(batch_classes)

        with torch.no_grad():
            model.eval()
            model.to(device)
            _, output_class = torch.topk(model(batched_data.to(device)).softmax(dim = 1), k = 1)

        # topk(k=1) returns shape (batch, 1). Flatten it so the comparison
        # with the (batch,) label tensor is element-wise rather than being
        # broadcast to a (batch, batch) matrix.
        return (output_class.cpu().reshape(-1) == batched_classes).sum().item()

    # The AD model is queried one sample at a time, together with its mask.
    retained = 0
    for masked_mutant, input_mask, expected_class in zip(batch_data, mask_data, batch_classes):
        with torch.no_grad():
            ad_model.eval()
            ad_model.to(device)
            _, output_class = torch.topk(ad_model(masked_mutant.float().to(device).unsqueeze(0),
                                                  explanation_mode=True,
                                                  explanation_mask=input_mask.float().to(device).unsqueeze(0)).softmax(dim = 1), k = 1)

        if output_class.cpu().item() == expected_class:
            retained += 1

    return retained


def reapplied_loss(base_diretory,
                        data_directory,
                        seeds = [42, 43, 44, 45],
                        thresholds = [0.9, 0.7, 0.5, 0.3, 0.1, 0],
                        model = None,
                        ad_model = None,
                        device = None):
    """
    Calculate the accuracy of the masks when applied again to the input.

    For every threshold / seed / masking value the experiment's
    ``*_secondary_.csv`` is read, the originally correctly classified samples
    are selected, their saved ``<name>_mask.npy`` is multiplied onto the
    transformed ``<name>.jpeg`` image, and the masked image is classified
    again (in batches of 64 by ``model``, or one by one by ``ad_model`` for
    the ``'AD'`` setting).

    Args:
        base_diretory: Directory containing the ``Threshold_<t>`` folders
        data_directory: Directory containing the ``.jpeg`` input images
        seeds: Seeds whose experiment folders are evaluated
        thresholds: Thresholds whose folders are evaluated
        model: Classifier exposing ``transforms``; used for non-AD settings
            and for preprocessing in every setting
        ad_model: Classifier called with ``explanation_mode`` /
            ``explanation_mask`` for the ``'AD'`` setting
        device: Device the models and inputs are moved to

    Returns:
        tuple: (original_accuracy, accuracy_dict, accuracy_count), each a
        dict ``{seed: {mask_val: value}}``. ``accuracy_dict`` holds NaN for a
        setting without any correctly classified sample.

    NOTE(review): the result dictionaries are keyed by seed only, so when
    several thresholds are given each threshold overwrites the previous one
    and only the last threshold's numbers are returned. Kept as is to
    preserve the return shape.
    """
    original_accuracy = {}
    accuracy_dict = {}
    accuracy_count = {}
    mask_vals = ['min', 'max', 'mean', 0, 'AD']
    batch_size = 64

    for threshold in tqdm(thresholds, desc = "Thresholds"):
        threshold_dir = os.path.join(base_diretory, f"Threshold_{threshold}")
        for seed in tqdm(seeds, desc = "Seeds"):
            original_accuracy[seed] = {}
            accuracy_dict[seed] = {}
            accuracy_count[seed] = {}
            for mask_val in tqdm(mask_vals, desc = "Mask Values"):
                if mask_val != 'AD':
                    exp_name = f"ResNet_masking_val_{mask_val}_seed_{seed}_threshold_{threshold}"
                else:
                    exp_name = f"ResNet_AD_seed_{seed}_threshold_{threshold}"
                exp_dir = os.path.join(threshold_dir, exp_name)

                # The results csv carries the experiment folder's name.
                results_df = pd.read_csv(os.path.join(exp_dir, exp_name + '_secondary_.csv'))
                is_correct = results_df['actual classification'] == results_df['predicted classification']
                original_accuracy[seed][mask_val] = is_correct.sum() / len(results_df)

                correct_rows = results_df[is_correct]
                correctly_predicted_list = correct_rows['Filename'].map(lambda x: x.split('/')[-1].split('.')[0]).tolist()
                correct_classes = correct_rows['actual classification'].values.tolist()

                # Load every mask, apply it to its input and count how many
                # predictions are retained.
                batch_data = []
                batch_classes = []
                mask_data = []
                num_correct = 0
                last_idx = len(correctly_predicted_list) - 1
                for idx, file in enumerate(correctly_predicted_list):
                    mask_file = torch.from_numpy(np.load(os.path.join(exp_dir, file + '_mask.npy'))).squeeze()
                    img_file = torchvision.io.decode_image(os.path.join(data_directory, file + '.jpeg'))
                    # Grayscale images are expanded to three channels.
                    if len(img_file.shape) == 2 or img_file.shape[0] == 1:
                        img_file = img_file.expand(3, -1, -1)
                    masked_mutant = model.transforms(img_file) * mask_file.unsqueeze(0)
                    batch_data.append(masked_mutant)
                    batch_classes.append(correct_classes[idx])
                    mask_data.append(mask_file)

                    # Flush only after the current sample has joined the
                    # batch, so that no sample is left out of the evaluation.
                    if len(batch_data) == batch_size or idx == last_idx:
                        num_correct += _count_retained_predictions(
                            batch_data, batch_classes, mask_data,
                            mask_val, model, ad_model, device)
                        batch_data = []
                        batch_classes = []
                        mask_data = []

                num_samples = len(correctly_predicted_list)
                accuracy_dict[seed][mask_val] = num_correct / num_samples if num_samples else float('nan')
                accuracy_count[seed][mask_val] = num_correct

    return original_accuracy, accuracy_dict, accuracy_count



def organize_experiment_data(base_directory, seed = 42):
    """
    Collect the .npy file names of the masking and AD experiment folders.

    Only sub-directories of ``base_directory`` whose name contains
    ``str(seed)`` are considered. A folder is a masking folder when its name
    contains ``'_masking'`` and an AD folder when it contains ``'_AD'``.

    Args:
        base_directory: Path to directory containing the experiment subdirectories
        seed: Seed whose string form must occur in the directory name

    Returns:
        tuple: (resnet_masking_dict, resnet_ad_dict), each mapping a
        directory name to the list of .npy file names found inside it
    """
    seed_tag = str(seed)

    # Sub-directories belonging to the requested seed.
    candidate_dirs = []
    for entry in os.listdir(base_directory):
        if os.path.isdir(os.path.join(base_directory, entry)) and seed_tag in entry:
            candidate_dirs.append(entry)

    def _npy_names_by_dir(marker):
        # Map each matching directory to the base names of its .npy files.
        collected = {}
        for name in candidate_dirs:
            if marker not in name:
                continue
            pattern = os.path.join(base_directory, name, '*.npy')
            collected[name] = [os.path.basename(path) for path in glob.glob(pattern)]
        return collected

    return _npy_names_by_dir('_masking'), _npy_names_by_dir('_AD')


def calculate_mask_metrics(masking_data, ad_data, 
                           dataset_path, model, ad_model, 
                           base_dir = '../ImageNet-onek/Results/Threshold_0.9/',
                           seed = 42,
                           num_backgrounds = 100,
                           device = None):
    """
    Compare the explanation masks of several masking experiments and one AD
    experiment, and probe how robust each explanation is to its background.

    For every mask file listed under the first masking directory the function
    computes IoU / DICE between the masks, the per-pixel confidence
    contribution and pixel count of each mask, and re-classifies the
    explanation pasted onto other images and onto coloured backgrounds.

    Args:
        masking_data: Dict mapping masking directory names to mask file names;
            the file list of the first key is the one iterated
        ad_data: Dict whose single key is the AD directory name
        dataset_path: Root that is walked for background images and joined
            with the image name to open the explained image
        model: Classifier exposing ``transforms``
        ad_model: Classifier exposing ``transforms`` and accepting
            ``explanation_mode`` / ``explanation_mask``
        base_dir: Directory containing the experiment directories
        seed: Seed passed to ``random.seed``
        num_backgrounds: Total number of colour backgrounds, including the
            five named ones ('min', 'max', 'mean', 'zero', 'ad')
        device: Device the model inputs are moved to

    Returns:
        tuple: (iou_dict, dice_dict, per_pixel_dict, pixel_count_dict,
        robustness_dict_masked, robustness_dict_ad,
        robustness_dict_masked_images, robustness_dict_ad_images).
        Robustness entries are tuples
        ('Y' | 'N', top-1 probability, top-1 class, explanation confidence).
    """

    mask_dirs = list(masking_data.keys())
    mask_files = masking_data[mask_dirs[0]]
    ad_dir = list(ad_data.keys())
    assert len(ad_dir) == 1, "Only one AD directory is supported"
 
    mask_dbs = [db_to_pandas(db=os.path.join(base_dir,mask_dir,mask_dir+'.db')) for mask_dir in mask_dirs]
    ad_db = db_to_pandas(db=os.path.join(base_dir,ad_dir[0],ad_dir[0]+'.db'))

    #Getting the colours for coloured backgrounds experiments
    colours = []
    base_backgrounds =  ['min', 'max', 'mean', 'zero', 'ad']
    
    random.seed(seed)
    for i in range(num_backgrounds - len(base_backgrounds)):
        colours.append(tuple(random.randint(0, 255) for _ in range(3)))

    # The named backgrounds come last; the batching below relies on 'ad'
    # being the final entry and 'zero' the one right before it.
    colours = colours + base_backgrounds

    #Get 100 random images from the dataset
    all_background_files = []
    for root, dirs, files in os.walk(dataset_path):
        for file in files:
            if file.endswith('.jpeg') or file.endswith('.jpg'):
                all_background_files.append(os.path.join(root, file))

    # NOTE(review): the sample size is hard-coded to 100 and does not follow
    # num_backgrounds; random.sample raises ValueError when fewer than 100
    # images are found.
    all_background_files = random.sample(all_background_files, 100)

    #Adding the filename column for easy indexing
    for mask_db in mask_dbs:
        mask_db['filename'] = mask_db['path'].apply(lambda x: x.split('/')[-1].split('.')[0])
    ad_db['filename'] = ad_db['path'].apply(lambda x: x.split('/')[-1].split('.')[0])
    
    iou_dict = {}
    dice_dict = {}
    per_pixel_dict = {}
    pixel_count_dict = {}
    robustness_dict_ad = {}
    robustness_dict_masked = {}
    robustness_dict_ad_images = {}
    robustness_dict_masked_images = {}
    
    for mask_file in tqdm(mask_files, desc="Calculating metrics"):
        # Skip files that have no row in the AD database.
        if ad_db[ad_db['filename'] == mask_file.replace('.jpg', '').rsplit('_',1)[0]].empty:
            continue
        iou_dict[mask_file] = []
        dice_dict[mask_file] = []
        per_pixel_dict[mask_file] = []
        pixel_count_dict[mask_file] = []

        mask_paths = []

        for mask_dir in mask_dirs:
            mask_path = os.path.join(base_dir, mask_dir, mask_file)
            mask_paths.append(mask_path)

        ad_path = os.path.join(base_dir, ad_dir[0], mask_file)

        # Now we calculate the metrics
        for i in range(len(mask_paths)):
            # NOTE(review): the bare except silently skips this masking
            # directory on any error, not only on a missing file.
            try:
                mask1 = np.load(mask_paths[i])
                ad_mask = np.load(ad_path)
            except:
                continue
            
            failed_load = False
            for j in range(i+1, len(mask_paths)):
                try:
                    mask2 = np.load(mask_paths[j])
                except:
                    # Pairwise scores appended for earlier j stay in the lists.
                    failed_load = True
                    break
                #Calculate the pairwise metrics
                iou_dict[mask_file].append(calculate_iou(mask1, mask2))
                dice_dict[mask_file].append(calculate_dice(mask1, mask2))

            if failed_load:
                continue

            #Calculate the metrics for the masking masks
            mask_exp_conf = mask_dbs[i][mask_dbs[i]['filename'] == mask_file.replace('.jpg', '').rsplit('_',1)[0]]['explanation_confidence'].values[0]
            mask_exp_class = mask_dbs[i][mask_dbs[i]['filename'] == mask_file.replace('.jpg', '').rsplit('_',1)[0]]['target'].values[0]
            per_pixel_dict[mask_file].append(calculate_pixel_contribution(mask1, mask_exp_conf))
            pixel_count_dict[mask_file].append(np.sum(mask1))
                
            #Calculate the metrics for the AD mask
            iou_dict[mask_file].append(calculate_iou(mask1, ad_mask))
            dice_dict[mask_file].append(calculate_dice(mask1, ad_mask))

            #Calculate the robustness metrics
            # The robustness dictionaries are keyed by the experiment
            # directory name (second-to-last path component).
            if mask_paths[i].split('/')[-2] not in robustness_dict_masked:
                robustness_dict_masked[mask_paths[i].split('/')[-2]] = {}
            
            if mask_paths[i].split('/')[-2] not in robustness_dict_masked_images:
                robustness_dict_masked_images[mask_paths[i].split('/')[-2]] = {}
       
            # NOTE(review): only the CalTech-256 branch strips '.jpg' from the
            # mask file name before building the image path.
            if 'CalTech-256' in dataset_path:
                # img_file = torchvision.io.decode_image(os.path.join(dataset_path, mask_file.rsplit('_',1)[0] + '.jpg'))
                img_file = Image.open(os.path.join(dataset_path, mask_file.replace('.jpg', '').rsplit('_',1)[0] + '.jpg')).convert('RGB')
            else:
                # img_file = torchvision.io.decode_image(os.path.join(dataset_path, mask_file.rsplit('_',1)[0] + '.jpeg')).float()
                img_file = Image.open(os.path.join(dataset_path, mask_file.rsplit('_',1)[0] + '.jpeg')).convert('RGB')
            
            # NOTE(review): img_file was just converted to 'RGB', so this
            # grayscale fallback looks unreachable - confirm before removing.
            if len(np.array(img_file).shape) == 2 or np.array(img_file).shape[-1] == 1:
                # img_file = img_file.expand(3, -1, -1)
                img_file = Image.fromarray(np.stack([np.array(img_file)]*3, axis = -1))
            robustness_dict_masked[mask_paths[i].split('/')[-2]][mask_file] = {}
            robustness_dict_masked_images[mask_paths[i].split('/')[-2]][mask_file] = {}

            #Calculate the output of the model on a different image as background
            batch_data_masked_images = []
            batch_data_masked_backgrounds = []
            for idx, background_file in enumerate(all_background_files):
                # background_img_file = torchvision.io.decode_image(background_file).float()
                background_img_file = Image.open(background_file).convert('RGB')
                if len(np.array(background_img_file).shape) == 2 or np.array(background_img_file).shape[-1] == 1:
                    # background_img_file = background_img_file.expand(3, -1, -1)
                    background_img_file = Image.fromarray(np.stack([np.array(background_img_file)]*3, axis = -1))
                
                transformed_background_img_file = model.transforms(background_img_file)
                transformed_mutant_img_file = model.transforms(img_file)
                mutant_mask = torch.from_numpy(mask1)

                # Background pixels where the mask is 0, explained image elsewhere.
                background_img_mutant = torch.where(mutant_mask == 0, transformed_background_img_file, transformed_mutant_img_file)
                batch_data_masked_images.append(background_img_mutant)
                batch_data_masked_backgrounds.append(background_file)

                # Classify once 64 composites are collected or the last
                # background has been added.
                if len(batch_data_masked_images) == 64 or idx == len(all_background_files) - 1:
                    with torch.no_grad():
                        model.eval()
                        masked_output_prob, masked_output_class = torch.topk(model(torch.stack(batch_data_masked_images).to(device)).softmax(dim=1), k=1)
                        
                        batch_data_masked_images = []
                        masked_entries_images = []
                        # 'Y' when the top-1 class equals the explanation's target class.
                        for d in range(len(masked_output_class)):
                            if masked_output_class[d] != mask_exp_class:
                                masked_entries_images.append(('N', masked_output_prob[d].item(), masked_output_class[d].item(), mask_exp_conf))
                            else:
                                masked_entries_images.append(('Y', masked_output_prob[d].item(), masked_output_class[d].item(), mask_exp_conf))

                        for d in range(len(batch_data_masked_backgrounds)):
                            robustness_dict_masked_images[mask_paths[i].split('/')[-2]][mask_file][batch_data_masked_backgrounds[d]] = masked_entries_images[d]
                        batch_data_masked_backgrounds = []

            #Calculate the output of the model using a different colour background
            batch_data_masked = []
            batch_data_masked_colour = []
            for idx, colour in enumerate(colours):
                # Resolve the colour entry into a fill value and a result key.
                if colour == 'min':
                    # colour = model.transforms(img_file).reshape(3,-1).min(dim=1,keepdim=True)[0].view(3,1,1)
                    colour = torch.min(model.transforms(img_file)).item()
                    key = 'min'
                elif colour == 'max':
                    # colour = model.transforms(img_file).reshape(3,-1).max(dim=1,keepdim=True)[0].view(3,1,1)
                    colour = torch.max(model.transforms(img_file)).item()
                    key = 'max'
                elif colour == 'mean':
                    # colour = torch.round(torch.mean(model.transforms(img_file), dim=(1,2))).int().view(3,1,1)
                    # colour = torch.ones_like(model.transforms(img_file)) * torch.mean(model.transforms(img_file))
                    colour = torch.mean(model.transforms(img_file)).item()
                    key = 'mean'
                elif colour == 'zero':
                    colour = torch.zeros_like(model.transforms(img_file))
                    key = 'zero'
                elif colour != 'ad':
                    # Random RGB tuple: build a solid-colour image and pass it
                    # through the same transforms as the input.
                    # colour = torch.ones_like(torch.from_numpy(np.array(img_file))) * torch.tensor(colour, dtype=torch.uint8)#.view(3, 1, 1)
                    colour = np.ones_like(np.array(img_file)) * colour
                    colour = Image.fromarray(colour.astype(np.uint8))
                    colour = model.transforms(colour)
                    key = colours[idx]

                # NOTE(review): colour may have been rebound to a float or a
                # tensor above; this relies on comparing such a value with a
                # string evaluating as false - confirm for the torch version in use.
                if colour == 'ad':
                    masked_mutant = torch.from_numpy(mask1) * ad_model.transforms(img_file)
                    with torch.no_grad():
                        ad_model.eval()
                        masked_output_prob, masked_output_class = torch.topk(ad_model(masked_mutant.to(device).float().unsqueeze(0), explanation_mode = True,
                                                                                  explanation_mask = torch.from_numpy(mask1).unsqueeze(0).to(device).float()).softmax(dim=1), k=1)
                    key = 'ad'

                    if masked_output_class != mask_exp_class:
                        masked_entry = ('N', masked_output_prob.item(), masked_output_class.item(), mask_exp_conf)
                    else:
                        masked_entry = ('Y', masked_output_prob.item(), masked_output_class.item(), mask_exp_conf)

                    robustness_dict_masked[mask_paths[i].split('/')[-2]][mask_file][key] = masked_entry

                else:
                    masked_mutant = torch.from_numpy(mask1) * model.transforms(img_file)
                    # Every element that is exactly 0 after masking is replaced
                    # by the background value.
                    masked_mutant = torch.where(masked_mutant == 0, colour, masked_mutant)
                    batch_data_masked.append(masked_mutant)
                    batch_data_masked_colour.append(key)

                    if len(batch_data_masked) == 64 or idx == len(colours) - 2: # -2 because we want to disregard ad
                        with torch.no_grad():
                            model.eval()
                            masked_output_prob, masked_output_class = torch.topk(model(torch.stack(batch_data_masked).to(device)).softmax(dim=1), k=1)
                            # masked_output_prob, masked_output_class = torch.topk(model(masked_mutant.to(device).unsqueeze(0)).softmax(dim=1), k=1)

                        batch_data_masked = []
                        masked_entires = []
                        for z in range(len(masked_output_class)):
                            if masked_output_class[z] != mask_exp_class:
                                masked_entires.append(('N', masked_output_prob[z].item(), masked_output_class[z].item(), mask_exp_conf))
                            else:
                                masked_entires.append(('Y', masked_output_prob[z].item(), masked_output_class[z].item(), mask_exp_conf))

                        for z in range(len(batch_data_masked_colour)):
                            robustness_dict_masked[mask_paths[i].split('/')[-2]][mask_file][batch_data_masked_colour[z]] = masked_entires[z]
                        batch_data_masked_colour = []

        # NOTE(review): ad_mask and img_file are only assigned inside the loop
        # over mask_paths above. If every iteration was skipped for this file
        # they still hold the previous file's values (or are undefined on the
        # first file) - confirm this cannot happen with the data in use.
        ad_exp_conf = ad_db[ad_db['filename'] == mask_file.replace('.jpg', '').rsplit('_',1)[0]]['explanation_confidence'].values[0]
        ad_exp_class = ad_db[ad_db['filename'] == mask_file.replace('.jpg', '').rsplit('_',1)[0]]['target'].values[0]
        per_pixel_dict[mask_file].append(calculate_pixel_contribution(ad_mask, ad_exp_conf))
        pixel_count_dict[mask_file].append(np.sum(ad_mask))

        #Calculate the robustness metrics for the AD mask
        if ad_path.split('/')[-2] not in robustness_dict_ad:
            robustness_dict_ad[ad_path.split('/')[-2]] = {}

        if ad_path.split('/')[-2] not in robustness_dict_ad_images:
            robustness_dict_ad_images[ad_path.split('/')[-2]] = {}
        
        robustness_dict_ad[ad_path.split('/')[-2]][mask_file] = {}
        robustness_dict_ad_images[ad_path.split('/')[-2]][mask_file] = {}

        #Calculate the output of the model on a different image as background
        for idx, background_file in enumerate(all_background_files):
            # background_img_file = torchvision.io.decode_image(background_file).float()
            background_img_file = Image.open(background_file).convert('RGB')
            if len(np.array(background_img_file).shape) == 2 or np.array(background_img_file).shape[-1] == 1:
                # background_img_file = background_img_file.expand(3, -1, -1)
                background_img_file = Image.fromarray(np.stack([np.array(background_img_file)]*3, axis = -1))
            
            transformed_background_img_file = ad_model.transforms(background_img_file)
            transformed_mutant_img_file = ad_model.transforms(img_file)
            mutant_mask = torch.from_numpy(ad_mask)

            background_img_mutant = torch.where(mutant_mask == 0, transformed_background_img_file, transformed_mutant_img_file)

            # NOTE(review): the composite is preprocessed with
            # ad_model.transforms but classified by `model`, one image at a
            # time - confirm that `model` (not `ad_model`) is intended here.
            with torch.no_grad():
                model.eval()
                masked_output_prob, masked_output_class = torch.topk(model(background_img_mutant.to(device).unsqueeze(0)).softmax(dim=1), k=1)
                
                if masked_output_class != ad_exp_class:
                    masked_entry = ('N', masked_output_prob.item(), masked_output_class.item(), ad_exp_conf)
                else:
                    masked_entry = ('Y', masked_output_prob.item(), masked_output_class.item(), ad_exp_conf)

                robustness_dict_ad_images[ad_path.split('/')[-2]][mask_file][background_file] = masked_entry

        
        # Same colour-background probe as above, applied to the AD mask.
        batch_data_ad = []
        batch_data_ad_colour = []
        for idx, colour in enumerate(colours):
            if colour == 'min':
                # colour = model.transforms(img_file).reshape(3,-1).min(dim=1,keepdim=True)[0].view(3,1,1)
                colour = torch.min(model.transforms(img_file)).item()
                key = 'min'
            elif colour == 'max':
                # colour = model.transforms(img_file).reshape(3,-1).max(dim=1,keepdim=True)[0].view(3,1,1)
                colour = torch.max(model.transforms(img_file)).item()
                key = 'max'
            elif colour == 'mean':
                # colour = torch.round(torch.mean(model.transforms(img_file), dim=(1,2))).view(3,1,1)
                # colour = torch.ones_like(model.transforms(img_file)) * torch.mean(model.transforms(img_file))
                colour = torch.mean(model.transforms(img_file)).item()
                key = 'mean'
            elif colour == 'zero':
                colour = torch.zeros_like(model.transforms(img_file))
                key = 'zero'
            elif colour != 'ad':
                # colour = torch.ones_like(torch.from_numpy(np.array(img_file))) * torch.tensor(colour, dtype=torch.uint8)#.view(3, 1, 1)
                colour = np.ones_like(np.array(img_file)) * colour
                colour = Image.fromarray(colour.astype(np.uint8))
                colour = model.transforms(colour)
                key = colours[idx]

            if colour == 'ad':
                ad_mutant = torch.from_numpy(ad_mask) * ad_model.transforms(img_file)
                with torch.no_grad():
                    ad_model.eval()
                    ad_output_prob, ad_output_class = torch.topk(ad_model(ad_mutant.to(device).float().unsqueeze(0), explanation_mode = True,
                                                                      explanation_mask = torch.from_numpy(ad_mask).unsqueeze(0).to(device).float()).softmax(dim=1), k=1)
                key = 'ad'

                if ad_output_class != ad_exp_class:
                    ad_entry = ('N', ad_output_prob.item(), ad_output_class.item(), ad_exp_conf)
                else:
                    ad_entry = ('Y', ad_output_prob.item(), ad_output_class.item(), ad_exp_conf)

                robustness_dict_ad[ad_path.split('/')[-2]][mask_file][key] = ad_entry

            else:
                ad_mutant = torch.from_numpy(ad_mask) * model.transforms(img_file)
                ad_mutant = torch.where(ad_mutant == 0, colour, ad_mutant)
                batch_data_ad.append(ad_mutant)
                batch_data_ad_colour.append(key)

                if len(batch_data_ad) == 64 or idx == len(colours) - 2: # -2 because we want to disregard ad
                    with torch.no_grad():
                        model.eval()
                        ad_output_prob, ad_output_class = torch.topk(model(torch.stack(batch_data_ad).to(device)).softmax(dim=1), k=1)
                        # ad_output_prob, ad_output_class = torch.topk(model(ad_mutant.to(device).unsqueeze(0)).softmax(dim=1), k=1)

                    batch_data_ad = []
                    ad_entires = []

                    for z in range(len(ad_output_class)):
                        if ad_output_class[z] != ad_exp_class:
                            ad_entires.append(('N', ad_output_prob[z].item(), ad_output_class[z].item(), ad_exp_conf))
                        else:
                            ad_entires.append(('Y', ad_output_prob[z].item(), ad_output_class[z].item(), ad_exp_conf))

                    for z in range(len(batch_data_ad_colour)):
                        robustness_dict_ad[ad_path.split('/')[-2]][mask_file][batch_data_ad_colour[z]] = ad_entires[z]
                    batch_data_ad_colour = []

    return iou_dict, dice_dict, per_pixel_dict, pixel_count_dict, robustness_dict_masked, robustness_dict_ad, robustness_dict_masked_images, robustness_dict_ad_images

def calculate_landscape_metrics(masking_data, ad_data, base_dir = ''):
    """
    Calculate landscape similarity metrics for the explanation data.

    For every mask file, the responsibility landscape of the AD experiment is
    compared with the landscape of each masking experiment and with their
    average, and each masking landscape is compared with that average. Both
    the Jensen-Shannon distance and the Wasserstein distance are recorded.

    Args:
        masking_data: Dictionary mapping masking directory names to their
            mask file names (a list, or a dict whose values are the names);
            the files of the first directory are the ones iterated
        ad_data: Dictionary whose single key is the AD directory name
        base_dir: Directory the experiment directories live in; the default
            resolves them relative to the working directory

    Returns:
        tuple: (landscape_similarity_dict_jsd,
        landscape_similarity_dict_wasserstein). Each maps a mask file name to
        a dict of scores keyed by ``'ad_vs_<mask_dir>'``,
        ``'avg_vs_<mask_dir>'`` and ``'ad_vs_avg'``. Files without a row in
        the AD database are skipped.

    NOTE(review): jensenshannon works along axis 0, so for a 2-D landscape
    each score is a per-column vector rather than a scalar - confirm whether
    the landscapes should be flattened first.
    """
    #Initialise the dictionaries
    landscape_similarity_dict_jsd = {}
    landscape_similarity_dict_wasserstein = {}

    mask_dirs = list(masking_data.keys())
    ad_dir = list(ad_data.keys())
    assert len(ad_dir) == 1, "Only one AD directory is supported"
    if not mask_dirs:
        return landscape_similarity_dict_jsd, landscape_similarity_dict_wasserstein

    # organize_experiment_data stores a list of file names per directory;
    # a dict of names is accepted as well.
    first_entry = masking_data[mask_dirs[0]]
    mask_files = list(first_entry.values()) if isinstance(first_entry, dict) else list(first_entry)

    mask_dbs = {mask_dir: db_to_pandas(db=os.path.join(base_dir, mask_dir, mask_dir + '.db')) for mask_dir in mask_dirs}
    ad_db = db_to_pandas(db=os.path.join(base_dir, ad_dir[0], ad_dir[0] + '.db'))

    #Adding the filename column for easy indexing (on the dataframes, not the dict keys)
    for db in list(mask_dbs.values()) + [ad_db]:
        db['filename'] = db['path'].apply(lambda x: x.split('/')[-1].split('.')[0])

    for mask_file in mask_files:
        image_name = mask_file.replace('.jpg', '').rsplit('_', 1)[0]

        ad_rows = ad_db[ad_db['filename'] == image_name]
        if ad_rows.empty:
            continue
        ad_landscape = ad_rows['responsibility'].values[0]
        normalised_ad_landscape = normalize_landscape(ad_landscape)

        # Float accumulator so the in-place division below also works when
        # the stored landscapes have an integer dtype.
        avg_landscape = np.zeros_like(ad_landscape, dtype=float)

        jsd_scores = {}
        wasserstein_scores = {}
        normalised_mask_landscapes = {}

        # AD landscape against each masking landscape.
        for mask_dir in mask_dirs:
            mask_db = mask_dbs[mask_dir]
            mask_landscape = mask_db[mask_db['filename'] == image_name]['responsibility'].values[0]
            normalised_mask_landscape = normalize_landscape(mask_landscape)
            normalised_mask_landscapes[mask_dir] = normalised_mask_landscape

            jsd_scores[f'ad_vs_{mask_dir}'] = jensenshannon(normalised_ad_landscape, normalised_mask_landscape)
            wasserstein_scores[f'ad_vs_{mask_dir}'] = wasserstein_distance_nd(normalised_ad_landscape, normalised_mask_landscape)
            avg_landscape += mask_landscape

        avg_landscape /= len(mask_dirs)
        normalised_avg_landscape = normalize_landscape(avg_landscape)

        # Average masking landscape against each masking landscape.
        for mask_dir, normalised_mask_landscape in normalised_mask_landscapes.items():
            jsd_scores[f'avg_vs_{mask_dir}'] = jensenshannon(normalised_avg_landscape, normalised_mask_landscape)
            wasserstein_scores[f'avg_vs_{mask_dir}'] = wasserstein_distance_nd(normalised_avg_landscape, normalised_mask_landscape)

        # AD landscape against the average masking landscape.
        jsd_scores['ad_vs_avg'] = jensenshannon(normalised_ad_landscape, normalised_avg_landscape)
        wasserstein_scores['ad_vs_avg'] = wasserstein_distance_nd(normalised_ad_landscape, normalised_avg_landscape)

        landscape_similarity_dict_jsd[mask_file] = jsd_scores
        landscape_similarity_dict_wasserstein[mask_file] = wasserstein_scores

    # Return only after every mask file has been processed.
    return landscape_similarity_dict_jsd, landscape_similarity_dict_wasserstein
    

def calculate_avg_exp_size(mask_size = (224, 224),
                           results_dir = None):
    """
    Compute the mean explanation size, as a fraction of the mask area, per method.

    For every hard-coded threshold and seed, the pixel-count CSV
    ``<results_dir>/Threshold_<t>/pixel_count_results_seed_<s>.csv`` is read,
    each method column is divided by ``mask_size[0] * mask_size[1]`` and
    averaged over the rows, and the result is written next to the input as
    ``avg_exp_size_seed_<s>.csv`` (one row per method).

    Args:
        mask_size: (height, width) pair whose product normalises the pixel counts
        results_dir: Directory that contains the ``Threshold_<t>`` sub-directories

    Returns:
        None. The averages are only written to disk.
    """
    method_columns = ('Max', 'Zero', 'Mean', 'Min', 'AD')

    for threshold in (0.9, 0.7, 0.5, 0.3, 0.1, 0):
        threshold_dir = os.path.join(results_dir, f"Threshold_{threshold}")

        for seed in (42, 43, 44, 45):
            pixel_counts = pd.read_csv(os.path.join(threshold_dir, f"pixel_count_results_seed_{seed}.csv"))

            # Insertion order of the dict fixes the row order of the output CSV.
            mean_fraction = {}
            for method in method_columns:
                fraction_per_image = pixel_counts[method] / (mask_size[0] * mask_size[1])
                mean_fraction[method] = fraction_per_image.mean()

            summary = pd.DataFrame.from_dict(mean_fraction, orient='index')
            summary.to_csv(os.path.join(threshold_dir, f"avg_exp_size_seed_{seed}.csv"))

def _check_explanation_log(threshold_dir, mask_val, seed):
    """
    Locate and sanity-check the explanation log for one masking value and seed.

    A log is any ``.csv`` file (without ``-old`` in its name) inside a
    sub-directory of ``threshold_dir`` whose name contains both ``str(mask_val)``
    and ``str(seed)``. When several files match, the last one listed is used.

    Args:
        threshold_dir: Directory holding the experiment sub-directories
        mask_val: Masking value whose log is wanted (e.g. 'AD', 'min', 0)
        seed: Seed whose log is wanted

    Returns:
        tuple: (number of correct predictions, number of incorrect predictions)

    Raises:
        FileNotFoundError: If no matching log file exists
        KeyError: If the log has neither of the two supported column layouts
    """
    log_path = None
    for exp_dir in os.listdir(threshold_dir):
        if str(mask_val) not in exp_dir or str(seed) not in exp_dir:
            continue
        exp_path = os.path.join(threshold_dir, exp_dir)
        # Result CSVs live next to the experiment directories; never list a plain file.
        if not os.path.isdir(exp_path):
            continue
        for file in os.listdir(exp_path):
            if file.endswith('.csv') and '-old' not in file:
                log_path = os.path.join(exp_path, file)

    if log_path is None:
        # Fail loudly instead of hitting an unbound variable or silently
        # re-using the log loaded for a previous seed / masking value.
        raise FileNotFoundError(
            f"No explanation log found in {threshold_dir} for masking value {mask_val} and seed {seed}"
        )

    explanation_log = pd.read_csv(log_path)

    # Logs come with one of two column layouts (note the leading spaces).
    try:
        target = explanation_log[' target']
        predicted = explanation_log[' classification']
    except KeyError:
        target = explanation_log[' actual classification']
        predicted = explanation_log[' predicted classification']

    total_correct_predictions = (target == predicted).sum()
    total_incorrect_predictions = (target != predicted).sum()

    assert total_correct_predictions + total_incorrect_predictions == len(explanation_log), "Total correct and incorrect predictions do not sum to the total number of predictions"

    return total_correct_predictions, total_incorrect_predictions


def _split_files_by_prediction(robustness_df, files, masking_value, model_pattern, is_caltech):
    """
    Split ``files`` into correctly and incorrectly predicted ones.

    The true class is parsed from the ``_class<id>_`` part of the file name and
    compared with the ``ClassID`` of the first row whose 'Masking Value' equals
    ``masking_value`` (and, if ``model_pattern`` is given, whose 'Model'
    contains it). Files without a class id in their name, or without such a
    row, are left out of both lists.

    Returns:
        tuple: (correct files, incorrect files)
    """
    correct_files, incorrect_files = [], []
    if len(files) == 0:
        return correct_files, incorrect_files

    row_filter = robustness_df['Masking Value'] == masking_value
    if model_pattern is not None:
        row_filter = row_filter & robustness_df['Model'].str.contains(model_pattern)
    candidate_rows = robustness_df[row_filter]

    for file in files:
        match = re.search(r'_class(\d+)_', file)
        if not match:
            continue
        true_class_id = int(match.group(1))

        file_rows = candidate_rows[candidate_rows['File'] == file]
        if file_rows.empty:
            continue

        if is_caltech:
            # For CalTech runs the id in the file name is one above the stored ClassID.
            true_class_id -= 1
        predicted_class_id = file_rows['ClassID'].iloc[0]

        if predicted_class_id == true_class_id:
            correct_files.append(file)
        else:
            incorrect_files.append(file)

    return correct_files, incorrect_files


def _count_same_as_explanation(robustness_df, files, model_pattern, num_backgrounds):
    """
    Count, per file, the rows whose 'Same as explanation?' value is 'Y'.

    Rows are restricted to the given file and, if ``model_pattern`` is not
    None, to models whose name contains it.

    Raises:
        AssertionError: If a file does not have exactly ``num_backgrounds`` rows
    """
    if len(files) == 0:
        return []

    model_filter = None
    if model_pattern is not None:
        model_filter = robustness_df['Model'].str.contains(model_pattern)

    running_count = []
    for current_file in files:
        row_filter = robustness_df['File'] == current_file
        if model_filter is not None:
            row_filter = row_filter & model_filter
        file_entries = robustness_df[row_filter]

        assert len(file_entries) == num_backgrounds, f"Expected {num_backgrounds} entries for {current_file}, but found {len(file_entries)}"

        running_count.append(len(file_entries[file_entries['Same as explanation?'] == 'Y']))

    return running_count


def _beta_goodness_summary(running_count, num_backgrounds):
    """
    Summarise per-file 'Y' counts.

    Returns a dict with 'mean', 'var', 'std' (all 0 for an empty input),
    'total' and one '<i>s' entry for every i in 1..num_backgrounds giving how
    many files had exactly i 'Y' rows. Files with zero 'Y' rows influence the
    mean/var/std but appear in neither 'total' nor any '<i>s' bucket.
    """
    if running_count:
        mean_count = np.mean(running_count)
        var_count = np.var(running_count)
        std_count = np.sqrt(var_count)
    else:
        mean_count = 0
        var_count = 0
        std_count = 0

    count_dict = {f'{i}s': 0 for i in range(1, num_backgrounds + 1)}
    for y_count in running_count:
        bucket = f'{y_count}s'
        if bucket in count_dict:
            count_dict[bucket] += 1
    total_count = sum(count_dict.values())

    return {
        'mean': mean_count,
        'var': var_count,
        'std': std_count,
        'total': total_count,
        **count_dict
    }


def robustness_metrics(threshold_dir,
                        mask_vals = ['min', 'mean', 0, 'max'],
                        seeds = [42, 43, 44, 45],
                        num_backgrounds = 100):
    """
    Summarise how often predictions stay the same as the explanation.

    For every seed, ``robustness_ad_seed_<seed>.csv`` and
    ``robustness_masked_seed_<seed>.csv`` are read from ``threshold_dir``.
    Files are split into correctly / incorrectly predicted ones (true class
    taken from ``_class<id>_`` in the file name), and for each group the number
    of 'Y' entries in 'Same as explanation?' per file is summarised. This is
    done once for the AD rows (key 'AD') and once per entry of ``mask_vals``.

    Args:
        threshold_dir: Directory with the robustness CSVs and the experiment
            sub-directories containing the explanation logs
        mask_vals: Masking values to evaluate; 0 is looked up as 'zero' in the
            'Masking Value' column
        seeds: Seeds to evaluate
        num_backgrounds: Number of rows every evaluated file must have

    Returns:
        tuple: (beta_goodness_correct, beta_goodness_incorrect), both shaped
        ``{seed: {mask_val: {'mean', 'var', 'std', 'total', '1s', ...}}}``

    Raises:
        FileNotFoundError: If an explanation log cannot be found
        AssertionError: If a file does not have exactly ``num_backgrounds`` rows
    """
    beta_goodness_correct = {}
    beta_goodness_incorrect = {}

    is_caltech = 'CalTech' in str(threshold_dir)

    for seed in seeds:
        robustness_masked = pd.read_csv(os.path.join(threshold_dir, f"robustness_masked_seed_{seed}.csv"))
        robustness_ad = pd.read_csv(os.path.join(threshold_dir, f"robustness_ad_seed_{seed}.csv"))

        beta_goodness_correct[seed] = {}
        beta_goodness_incorrect[seed] = {}

        # ---- AD samples ----
        mask_val = 'AD'
        _check_explanation_log(threshold_dir, mask_val, seed)

        correct_files, incorrect_files = _split_files_by_prediction(
            robustness_ad, robustness_ad['File'].unique(), 'ad', None, is_caltech
        )

        # NOTE(review): correct AD files are counted over all models while
        # incorrect ones are restricted to models containing 'AD'; kept as in
        # the previous implementation - confirm whether this is intended.
        beta_goodness_correct[seed][mask_val] = _beta_goodness_summary(
            _count_same_as_explanation(robustness_ad, correct_files, None, num_backgrounds),
            num_backgrounds,
        )
        beta_goodness_incorrect[seed][mask_val] = _beta_goodness_summary(
            _count_same_as_explanation(robustness_ad, incorrect_files, f'{mask_val}', num_backgrounds),
            num_backgrounds,
        )

        # ---- Masked samples, one pass per masking value ----
        for mask_val in mask_vals:
            masked_filtered = robustness_masked[robustness_masked['Model'].str.contains(f'val_{mask_val}')]
            unique_files = masked_filtered['File'].unique()

            _check_explanation_log(threshold_dir, mask_val, seed)

            search_string = 'zero' if mask_val == 0 else mask_val

            # All three conditions (file, masking value, model) are combined
            # with '&'; the missing operator used to raise a TypeError here.
            correct_files, incorrect_files = _split_files_by_prediction(
                robustness_masked, unique_files, search_string, f'val_{mask_val}', is_caltech
            )

            beta_goodness_correct[seed][mask_val] = _beta_goodness_summary(
                _count_same_as_explanation(robustness_masked, correct_files, f'_val_{mask_val}', num_backgrounds),
                num_backgrounds,
            )
            beta_goodness_incorrect[seed][mask_val] = _beta_goodness_summary(
                _count_same_as_explanation(robustness_masked, incorrect_files, f'_val_{mask_val}', num_backgrounds),
                num_backgrounds,
            )

    return beta_goodness_correct, beta_goodness_incorrect


def _robustness_records(nested_results, key_column):
    """Flatten ``{model: {file: {key: values}}}`` into one row dict per innermost entry.

    ``key_column`` names the column that receives the innermost key; the first
    four items of ``values`` fill the remaining result columns.
    """
    return [
        {
            'Model': model_name,
            'File': file_name,
            key_column: key,
            'Same as explanation?': values[0],
            'Confidence': values[1],
            'ClassID': values[2],
            'Explanation Confidence': values[3]
        }
        for model_name, files_data in nested_results.items()
        for file_name, entries in files_data.items()
        for key, values in entries.items()
    ]


def _metric_frame(metric_dict, column_names):
    """Turn ``{filename: row values}`` into a frame with a leading 'filename' column."""
    frame = pd.DataFrame.from_dict(metric_dict, orient='index')
    frame.columns = column_names
    return frame.reset_index().rename(columns={'index': 'filename'})


def calculate_all_metrics(thresholds = [0.9, 0.7, 0.5, 0.3, 0.1, 0],
                          seeds = [42, 43, 44, 45],
                          num_backgrounds = 100,
                          base_dir = "../ImageNet-onek/Results",
                          columns = ['Max vs Zero', 'Max vs Mean', 'Max vs Min', ' Max vs AD', 'Zero vs Mean', 'Zero vs Min', 'Zero vs AD', 'Mean vs Min', 'Mean vs AD', 'Min vs AD'],
                          per_pixel_columns = ['Max', 'Zero', 'Mean', 'Min', 'AD'],
                          dataset_path = "..ImageNet-onek/IN-onek_data/Test",
                          model = None,
                          ad_model = None,
                          device = None):
    """
    Compute and store the mask and robustness metrics for every threshold and seed.

    For each ``Threshold_<t>`` directory below ``base_dir`` and each seed, the
    experiment data is organised, ``calculate_mask_metrics`` is run, and its
    eight result dictionaries are written as CSV files into the threshold
    directory (IoU, DICE, per-pixel, pixel count, and the four robustness tables).

    Args:
        thresholds: Thresholds to process; each maps to ``Threshold_<t>``
        seeds: Seeds to process
        num_backgrounds: Forwarded to ``calculate_mask_metrics``
        base_dir: Directory containing the ``Threshold_<t>`` directories
        columns: Column names for the IoU and DICE tables
        per_pixel_columns: Column names for the per-pixel and pixel-count tables
        dataset_path: Forwarded to ``calculate_mask_metrics``
        model: Forwarded to ``calculate_mask_metrics``
        ad_model: Forwarded to ``calculate_mask_metrics``
        device: Forwarded to ``calculate_mask_metrics``

    Returns:
        None. All results are written to disk.
    """
    # NOTE(review): the default ' Max vs AD' label carries a leading space and
    # the default dataset_path has no separator after '..'; both are kept
    # unchanged because existing outputs/callers may depend on them - confirm.
    with tqdm(thresholds, desc="Thresholds:") as pbar:
        for threshold in pbar:
            pbar.set_description(f"Threshold: {threshold}")
            base_dir_threshold = os.path.join(base_dir, f"Threshold_{threshold}")
            with tqdm(seeds, desc="Seeds:") as pbar_seed:
                for seed in pbar_seed:
                    pbar_seed.set_description(f"Seed: {seed}")
                    masking_data, ad_data = organize_experiment_data(base_dir_threshold, seed = seed)
                    (iou_dict,
                     dice_dict,
                     per_pixel_dict,
                     pixel_count_dict,
                     robustness_dict_masked,
                     robustness_dict_ad,
                     robustness_dict_masked_images,
                     robustness_dict_ad_images) = calculate_mask_metrics(masking_data=masking_data,
                                                                         ad_data=ad_data,
                                                                         dataset_path=dataset_path,
                                                                         num_backgrounds = num_backgrounds,
                                                                         base_dir=base_dir_threshold,
                                                                         model = model,
                                                                         ad_model = ad_model,
                                                                         device = device)

                    # Per-file mask metrics, one row per file name.
                    df_iou = _metric_frame(iou_dict, columns)
                    df_dice = _metric_frame(dice_dict, columns)
                    df_per_pixel = _metric_frame(per_pixel_dict, per_pixel_columns)
                    df_pixel_count = _metric_frame(pixel_count_dict, per_pixel_columns)

                    # Robustness results, flattened to one row per (model, file, key).
                    df_masked_images = pd.DataFrame(_robustness_records(robustness_dict_masked_images, 'Background Image'))
                    df_ad_images = pd.DataFrame(_robustness_records(robustness_dict_ad_images, 'Background Image'))
                    df_masked = pd.DataFrame(_robustness_records(robustness_dict_masked, 'Masking Value'))
                    df_ad = pd.DataFrame(_robustness_records(robustness_dict_ad, 'Masking Value'))

                    # Saved exactly once per (threshold, seed). This used to sit
                    # inside the loop over AD models, so it re-ran per model and
                    # was skipped entirely when there were no AD results.
                    outputs = [
                        (df_masked_images, f"robustness_masked_images_seed_{seed}.csv"),
                        (df_ad_images, f"robustness_ad_images_seed_{seed}.csv"),
                        (df_masked, f"robustness_masked_seed_{seed}.csv"),
                        (df_ad, f"robustness_ad_seed_{seed}.csv"),
                        (df_iou, f"iou_results_seed_{seed}.csv"),
                        (df_dice, f"dice_results_seed_{seed}.csv"),
                        (df_per_pixel, f"per_pixel_results_seed_{seed}.csv"),
                        (df_pixel_count, f"pixel_count_results_seed_{seed}.csv"),
                    ]
                    for frame, csv_name in outputs:
                        frame.to_csv(os.path.join(base_dir_threshold, csv_name), index=False)


def calculate_expected_mask(explanation_data_dir,
                            seeds = [42,43,44,45]):
    '''
    Compare each AD explanation mask with the union of the other methods' masks.

    For every seed, the single AD directory and the four non-AD directories
    inside ``explanation_data_dir`` are located by name. For each ``.npy`` mask
    present in all of them, DICE and IoU are computed between the AD mask and
    the element-wise OR of the four masking-method masks. The per-seed means
    are printed and written to ``average_dice_expected_mask.csv`` and
    ``average_iou_expected_mask.csv`` in ``explanation_data_dir``.
    '''
    average_dice = {}
    average_iou = {}

    for seed in seeds:
        seed_key = f"Seed_{seed}"
        seed_tag = f'seed_{seed}'
        average_dice[seed_key] = []
        average_iou[seed_key] = []

        # Exactly one directory of this seed is expected to carry 'AD' in its name.
        ad_matches = []
        for entry in os.listdir(explanation_data_dir):
            if seed_tag in entry and 'AD' in entry and os.path.isdir(os.path.join(explanation_data_dir, entry)):
                ad_matches.append(entry)
        assert len(ad_matches) == 1, f"Expected 1 AD directory for seed {seed}, but found {len(ad_matches)}"
        ad_dir = ad_matches[0]

        # Every other directory of this seed belongs to a masking method.
        masking_dirs = []
        for entry in os.listdir(explanation_data_dir):
            if seed_tag in entry and 'AD' not in entry and os.path.isdir(os.path.join(explanation_data_dir, entry)):
                masking_dirs.append(entry)
        assert len(masking_dirs) == 4, f"Expected 4 masking directories for seed {seed}, but found {len(masking_dirs)}"

        ad_path = os.path.join(explanation_data_dir, ad_dir)
        ad_explanation_masks = [name for name in os.listdir(ad_path) if name.endswith('.npy')]
        assert len(ad_explanation_masks) == 1000, f"Expected 1000 AD explanation masks for seed {seed}, but found {len(ad_explanation_masks)}"

        # Mask file names per method, keyed by the token just before '_seed_'.
        masking_explanation_masks = {
            masking_dir.split('_seed_')[0].rsplit('_',1)[-1]: [
                name
                for name in os.listdir(os.path.join(explanation_data_dir, masking_dir))
                if name.endswith('.npy')
            ]
            for masking_dir in masking_dirs
        }

        # Only masks that exist for AD and for every masking method are evaluated.
        common_masks = set(ad_explanation_masks)
        for method_masks in masking_explanation_masks.values():
            common_masks = common_masks & set(method_masks)

        for mask in tqdm(common_masks, desc=f"Calculating DICE + IoU for seed {seed}"):
            ad_mask = np.load(os.path.join(ad_path, mask))

            union_mask = np.zeros_like(ad_mask)
            for masking_dir in masking_dirs:
                method_mask = np.load(os.path.join(explanation_data_dir, masking_dir, mask))
                union_mask = np.logical_or(union_mask, method_mask)

            average_dice[seed_key].append(calculate_dice(ad_mask, union_mask))
            average_iou[seed_key].append(calculate_iou(ad_mask, union_mask))

        # Collapse the per-mask scores of this seed into a single mean.
        average_dice[seed_key] = np.mean(average_dice[seed_key])
        average_iou[seed_key] = np.mean(average_iou[seed_key])

    #Save the average dice and iou to disk in csv format
    print(average_dice)
    print(average_iou)

    for scores, csv_name in ((average_dice, 'average_dice_expected_mask.csv'),
                             (average_iou, 'average_iou_expected_mask.csv')):
        pd.DataFrame(scores, index = [0]).to_csv(os.path.join(explanation_data_dir, csv_name), index=False)



def _confidence_difference_scores(samples):
    """
    Build the per-sample confidence table for one subset of robustness rows.

    Args:
        samples: DataFrame with 'Explanation Confidence' and 'Confidence' columns

    Returns:
        pd.DataFrame: One row per sample with the explanation confidence, the
        background confidence, their difference and the percentage difference
    """
    explanation_confidence = samples['Explanation Confidence'].to_numpy()
    background_confidence = samples['Confidence'].to_numpy()
    confidence_difference = explanation_confidence - background_confidence
    # The difference is expressed relative to the larger of the two confidences.
    # NOTE: this yields nan/inf for a row where that larger confidence is 0.
    percentage_difference = (confidence_difference / np.maximum(explanation_confidence, background_confidence)) * 100

    return pd.DataFrame({
        'Explanation Confidence': explanation_confidence,
        'Background Confidence': background_confidence,
        'Confidence Difference': confidence_difference,
        'Percentage Difference': percentage_difference,
    })


def _confidence_difference_stats(scores):
    """
    Summarise a per-sample confidence table with the mean and std of each column.

    Args:
        scores: DataFrame produced by _confidence_difference_scores

    Returns:
        dict: Mean and standard deviation of every column, plus the mean and
        standard deviation of the absolute percentage difference
    """
    absolute_percentage = np.abs(scores['Percentage Difference'])
    return {
        'Average Explanation Confidence': np.mean(scores['Explanation Confidence']),
        'Average Background Confidence': np.mean(scores['Background Confidence']),
        'Average Difference': np.mean(scores['Confidence Difference']),
        'Average Percentage Difference': np.mean(scores['Percentage Difference']),
        'Average Percentage Difference(absolute)': np.mean(absolute_percentage),
        'Std Explanation Confidence': np.std(scores['Explanation Confidence']),
        'Std Background Confidence': np.std(scores['Background Confidence']),
        'Std Difference': np.std(scores['Confidence Difference']),
        'Std Percentage Difference': np.std(scores['Percentage Difference']),
        'Std Percentage Difference(absolute)': np.std(absolute_percentage),
    }


def calculate_confidence_difference(threshold_dir,
                                    seeds = [42],
                                    num_samples = 150,
                                    num_backgrounds = 100):
    """
    Calculate the confidence dropoff when performing robustness experiments, both
    for coloured backgrounds and for background images.

    For every seed and test type the robustness CSVs
    ``robustness_masked[_images]_seed_{seed}.csv`` and
    ``robustness_ad[_images]_seed_{seed}.csv`` are read from ``threshold_dir``.
    The masked file is split into consecutive chunks of
    ``num_samples * num_backgrounds`` rows, one chunk per method. For each method
    three subsets are evaluated: 'common' (rows that every method marks 'Y' in
    'Same as explanation?'), 'correct' (rows this method marks 'Y') and
    'incorrect' (rows this method marks 'N'). Per-sample tables and per-method
    summary statistics are written back to ``threshold_dir`` as
    ``confidence_difference_{test_type}_{subset}_{method}_seed_{seed}.csv`` and
    ``confidence_difference_{test_type}_{subset}_stats_seed_{seed}.csv``.

    An empty subset is skipped on its own; the remaining subsets of that method
    are still written.

    Args:
        threshold_dir: Directory holding the robustness CSVs; outputs are written here too
        seeds: Seeds to process (only iterated, never modified)
        num_samples: Number of samples per method
        num_backgrounds: Number of backgrounds per sample

    Raises:
        AssertionError: If a file does not hold the expected number of rows, a
            chunk mixes several models, or the file/background order differs
            between methods
    """
    correct_column = 'Same as explanation?'

    for seed in seeds:
        for test_type in ['colored_background', 'background_images']:
            if test_type == 'colored_background':
                file_suffix = ''
                background_column = 'Masking Value'
                background_label = 'masking value'
            else:
                file_suffix = '_images'
                background_column = 'Background Image'
                background_label = 'background image'

            #Load each of the robustness files
            robustness_masked = pd.read_csv(os.path.join(threshold_dir, f"robustness_masked{file_suffix}_seed_{seed}.csv"))
            robustness_ad = pd.read_csv(os.path.join(threshold_dir, f"robustness_ad{file_suffix}_seed_{seed}.csv"))

            total_number_of_samples_per_method = num_samples * num_backgrounds
            #These many methods should be present in the masked robustness file
            total_number_of_methods = len(robustness_masked) // total_number_of_samples_per_method

            #Make sure that AD also fits the bill
            assert len(robustness_ad) == total_number_of_samples_per_method, f"Expected {total_number_of_samples_per_method} files for AD, but found {len(robustness_ad)}"
            assert robustness_ad['Model'].nunique() == 1, f"Expected 1 method for AD, but found {robustness_ad['Model'].nunique()}"

            method_types = {}
            ad_method_name = robustness_ad['Model'].unique()[0].split("_seed")[0].split("_")[1]
            method_types[ad_method_name] = robustness_ad

            for i in range(total_number_of_methods):
                start = i * total_number_of_samples_per_method
                method_files = robustness_masked.iloc[start:start + total_number_of_samples_per_method]
                #Make sure that the method is the same for all the files
                assert method_files['Model'].nunique() == 1, f"Expected 1 method for {i}th set of files, but found {method_files['Model'].nunique()}"
                #Make sure the number of files is correct
                assert len(method_files) == total_number_of_samples_per_method, f"Expected {total_number_of_samples_per_method} files for {i}th set of files, but found {len(method_files)}"

                method_name = method_files['Model'].unique()[0].split("val_")[1].split("_seed")[0]
                method_types[method_name] = method_files

            #Rows are compared by position across methods, so both the file order
            #and the background order must be the same for every method
            method_keys = list(method_types)
            file_sequence = True
            masking_value_sequence = True
            for current_key, next_key in zip(method_keys, method_keys[1:]):
                file_sequence = file_sequence and method_types[current_key]['File'].tolist() == method_types[next_key]['File'].tolist()
                masking_value_sequence = masking_value_sequence and method_types[current_key][background_column].tolist() == method_types[next_key][background_column].tolist()

            assert file_sequence, "The sequence of files is not the same for all the methods"
            assert masking_value_sequence, f"The sequence of the {background_label} is not the same for all the methods"

            #Rows that every method predicts correctly; also valid with a single method
            common_correct_samples = np.ones(total_number_of_samples_per_method, dtype=bool)
            for idx, method_key in enumerate(method_keys):
                common_correct_samples &= (method_types[method_key][correct_column] == 'Y').to_numpy()
                print(f"Common correct samples {method_keys[:idx+1]}:", int(common_correct_samples.sum()))

            #Now we need to calculate the confidence dropoff, for common, correct and incorrect samples
            subset_stats = {'common': {}, 'correct': {}, 'incorrect': {}}
            for method_key in method_keys:
                method_rows = method_types[method_key]
                subsets = {
                    'common': method_rows[common_correct_samples],
                    'correct': method_rows[method_rows[correct_column] == 'Y'],
                    'incorrect': method_rows[method_rows[correct_column] == 'N'],
                }

                print("Length of common samples: ", len(subsets['common']))
                print("Length of common correct samples: ", int(common_correct_samples.sum()))
                #Make sure all are correctly predicted
                assert (subsets['common'][correct_column] == 'Y').all(), f"All common samples should be correctly predicted, not the case for method {method_key}, type {test_type}"

                for subset_name, subset_rows in subsets.items():
                    #An empty subset only skips itself, never the other subsets of the method
                    if len(subset_rows) == 0:
                        print(f"No {subset_name} samples for method {method_key}, type {test_type}; skipping this subset")
                        continue

                    scores = _confidence_difference_scores(subset_rows)
                    scores.to_csv(os.path.join(threshold_dir, f"confidence_difference_{test_type}_{subset_name}_{method_key}_seed_{seed}.csv"), index=False)
                    subset_stats[subset_name][method_key] = _confidence_difference_stats(scores)

            #One summary file per subset, with one row per method
            for subset_name, stats in subset_stats.items():
                df_stats = pd.DataFrame.from_dict(stats, orient='index')
                df_stats.index.name = 'Method'
                df_stats.to_csv(os.path.join(threshold_dir, f"confidence_difference_{test_type}_{subset_name}_stats_seed_{seed}.csv"), index=True)
            

