import json
import os
import argparse
import wandb

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig, ViTForImageClassification, PretrainedConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from typing import Dict, List, Optional, Set, Tuple, Union

from pathlib import Path
import sys
# Debug aid: echo this file's three nearest ancestor directories, then put the
# grandparent of this file's directory on sys.path -- presumably so the local
# `models` and `custom_datasets` packages imported below can be resolved.
print(*(Path(__file__).parents[depth] for depth in range(3)), sep='\n')
path_root = Path(__file__).parents[2]
print(path_root)
sys.path.append(str(path_root))
from PIL import Image

# `exlib` (imported right below) is taken from a source checkout rather than an
# installed package. The checkout defaults to /tmp/exlib/src; set the
# EXLIB_SRC environment variable to point at a different location.
sys.path.append(os.environ.get('EXLIB_SRC', '/tmp/exlib/src'))
import exlib
from exlib.evaluators.attributions import NNZ, InsDel, CompSuff
from exlib.modules.sop import SOP
from exlib.explainers.torch_explainer import TorchImageLime, TorchTextLime, TorchImageSHAP, TorchTextSHAP, TorchImageIntGrad, TorchTextIntGrad
from exlib.explainers.rise import TorchImageRISE, TorchTextRISE
from exlib.explainers.common import patch_segmenter
from exlib.explainers.gradcam import TorchImageGradCAM
from exlib.evaluators.attributions import CompSuffSem, InsDelSem

# from models.wrapper import SumOfPartsWrapper, get_inverse_sqrt_with_separate_heads_schedule_with_warmup

from models.cnn import CNNModel
from custom_datasets import get_datasets, get_masked_datasets

class ViTForImageClassificationLocal(ViTForImageClassification):
    """``ViTForImageClassification`` whose ``forward`` returns the logits tensor.

    Intended for callers that expect ``model(x)`` to be a plain tensor rather
    than a ``transformers`` output object.
    """

    def forward(self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> torch.Tensor:
        """Run the parent forward pass and return only its ``logits``.

        All arguments are forwarded to ``ViTForImageClassification.forward``
        except ``return_dict``. It is accepted for signature compatibility, but
        the parent is always called with ``return_dict=True``. With a tuple
        output (``return_dict=False`` or a config that disables return dicts)
        there is no ``.logits`` attribute to read.
        """
        outputs = super().forward(pixel_values=pixel_values,
                                  head_mask=head_mask,
                                  labels=labels,
                                  output_attentions=output_attentions,
                                  output_hidden_states=output_hidden_states,
                                  interpolate_pos_encoding=interpolate_pos_encoding,
                                  return_dict=True)
        return outputs.logits
    
class GetLogits(nn.Module):
    """Wrap a model so that calling the wrapper returns ``model(X).logits``."""

    def __init__(self, model):
        # nn.Module.__init__ must run before a submodule can be assigned as an
        # attribute; without it the assignment below raises AttributeError.
        super().__init__()
        self.model = model

    def forward(self, X):
        """Return the ``logits`` attribute of the wrapped model's output for ``X``."""
        return self.model(X).logits
        

def get_chained_attr(obj, attr_path):
    """Resolve a dotted attribute path such as ``'a.b.c'`` starting from ``obj``.

    Returns None when ``attr_path`` is None. Otherwise each dot-separated
    name is looked up with ``getattr`` in turn, so a missing (or empty)
    component raises ``AttributeError``.
    """
    if attr_path is None:
        return None
    head, separator, remainder = attr_path.partition('.')
    resolved = getattr(obj, head)
    if not separator:
        return resolved
    # More components follow the first dot: continue from the resolved object.
    return get_chained_attr(resolved, remainder)

def parse_args():
    """Build the command-line parser for this evaluation script.

    Despite its name, this returns the ``argparse.ArgumentParser`` itself; the
    caller is responsible for calling ``.parse_args()`` on it.
    """
    parser = argparse.ArgumentParser()

    # models and configs
    parser.add_argument('--blackbox-model-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box model name')
    parser.add_argument('--blackbox-processor-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box processor name')
    parser.add_argument('--wrapper-config-filepath', type=str,
                        default='actions/wrapper/configs/imagenet_vit_wrapper_default.json',
                        help='wrapper config file')
    parser.add_argument('--exp-dir', type=str,
                        default='exps/imgenet_wrapper',
                        help='directory where evaluation results are written')
    parser.add_argument('--sop-exp-dir', type=str,
                        default='exps/imgenet_wrapper',
                        help='directory of the trained SOP model to load')
    parser.add_argument('--mask-dir', type=str,
                        default=None,
                        help='directory of precomputed masks; if omitted, the regular dataset is used')
    # NOTE(review): main() in this file writes its results under --exp-dir and
    # does not read this value.
    parser.add_argument('--output-filename', type=str,
                        default='exps/baselines/baselines.json',
                        help='path of the output results json')
    parser.add_argument('--model-type', type=str,
                        default='image', choices=['image', 'text'],
                        help='image or text model')
    parser.add_argument('--model-type-spec', type=str,
                        default='encoder', choices=['encoder', 'clip', 'gpt2'],
                        help='what specific model to use')
    parser.add_argument('--projection-layer-name', type=str,
                        default=None,  # e.g. distilbert.embeddings
                        help='projection layer if specified, else train from scratch')

    # data
    parser.add_argument('--dataset', type=str,
                        default='imagenet', choices=['imagenet', 'imagenet_m', 'multirc',
                                                     'movies', 'voc', 'cosmogrid'],
                        help='which dataset to use')
    parser.add_argument('--train-size', type=int,
                        default=-1,
                        help='-1: use all, else randomly choose k per class')
    parser.add_argument('--val-size', type=int,
                        default=-1,
                        help='-1: use all, else randomly choose k per class')

    # training
    parser.add_argument('--num-epochs', type=int,
                        default=1,
                        help='num epochs')
    parser.add_argument('--num-train-reps', type=int,
                        default=1,
                        help='number of times to train each head')
    parser.add_argument('--lr', type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--lr-scheduler-step-size', type=int,
                        default=1,
                        help='learning rate scheduler step size (by epoch)')
    parser.add_argument('--lr-scheduler-gamma', type=float,
                        default=0.1,
                        help='learning rate scheduler gamma')
    parser.add_argument('--warmup-epochs', type=float,
                        default=-1,
                        help='number of epochs to warmup')
    parser.add_argument('--warmup-steps', type=float,
                        default=0,
                        help='number of steps to warmup. Use this when warmup epochs is -1')
    parser.add_argument('--scheduler-type', type=str,
                        default='inverse_sqrt_heads', choices=['cosine',
                                                               'constant',
                                                               'inverse_sqrt_heads'],
                        help='scheduler type')
    parser.add_argument('--criterion', type=str,
                        default='ce', choices=['ce', 'bce', 'mse'],
                        help='which criterion to use, if multi-label then bce')
    parser.add_argument('--batch-size', type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--mask-batch-size', type=int,
                        default=16,
                        help='mask batch size')
    parser.add_argument('--project-name', type=str,
                        default='attn',
                        help='wandb project name')
    parser.add_argument('--seed', type=int,
                        default=42,
                        help='seed')
    parser.add_argument('--shuffle-val', action='store_true',
                        default=False,
                        help='if true then shuffle val')
    parser.add_argument('--num-masks-sample-inference', type=int,
                        default=None,
                        help='number of masks to retain for mask dropout.')

    # Wrapper-config overrides: when given (not None), these take precedence
    # over the values in the wrapper config json.
    parser.add_argument('--attn-patch-size', type=int,
                        default=None,
                        help='attn patch size, does not have to match black box model patch size')
    parser.add_argument('--attn-stride-size', type=int,
                        default=None,
                        help='attn stride size, if smaller than patch size then there is overlap')
    parser.add_argument('--num-heads', type=int,
                        default=None,
                        help='number of heads; overrides the wrapper config value')
    parser.add_argument('--num-masks-sample', type=int,
                        default=None,
                        help='number of masks to retain for mask dropout.')
    parser.add_argument('--num-masks-max', type=int,
                        default=None,
                        help='number of maximum masks to retain.')
    parser.add_argument('--finetune-layers', type=str,
                        default=None,
                        help='Which layers to finetune, separated by comma.')
    parser.add_argument('--ins-del-step-size',
                        default=None,
                        type=int,
                        help='ins del step size for insertion deletion')

    parser.add_argument('--aggr-type', type=str,
                        default='joint', choices=['joint', 'independent'],
                        help='usually we use joint for classification, but can also try independent')
    parser.add_argument('--proj-hid-size', type=int,
                        default=None,
                        help='If specified, use this instead of hidden size for projection')
    parser.add_argument('--mean-center-offset',
                        default=0,
                        type=float,
                        help='offset to mean center before projection')
    parser.add_argument('--mean-center-offset2',
                        default=None,
                        type=float,
                        help='offset to mean center after projection')
    parser.add_argument('--mean-center-scale',
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering before projection')
    parser.add_argument('--mean-center-scale2',
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering after projection')

    parser.add_argument('--start', type=int,
                        default=-1,
                        help='if not -1, set a start idx for eval')
    parser.add_argument('--end', type=int,
                        default=-1,
                        help='if not -1, set a end idx for eval')

    return parser

def main():
    """Evaluate post-hoc explainers against the masks of a trained SOP model.

    Pipeline:
      1. Parse CLI arguments, seed the RNGs and pick the device.
      2. Build the input processor and merge the black-box config with the
         wrapper json config (non-None CLI overrides win).
      3. Load the evaluation split (optionally with precomputed masks).
      4. Load the black-box model and the trained SOP wrapper from
         ``--sop-exp-dir``.
      5. Build the explainers and the semantic evaluators.
      6. For every batch, explain the black-box model and score each
         explanation with the SOP masks (``sop_outputs.flat_masks``):
         insertion/deletion for image models, comprehensiveness/sufficiency
         for text models.
      7. Print per-explainer averages (as percentages) and dump them to json
         under ``--exp-dir``.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg in vars(args):
        value = getattr(args, arg)
        # Report the type of the parsed value (type(arg) would always be str).
        print(arg, value, '\t', type(value))

    os.makedirs(args.exp_dir, exist_ok=True)

    # Set the seed for reproducibility (-1 disables seeding).
    if args.seed != -1:
        # Torch RNG
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        # NumPy and Python RNG
        np.random.seed(args.seed)
        random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Input processor: cosmogrid needs none, other image datasets use an image
    # processor, text models use a tokenizer.
    if args.model_type == 'image':
        if args.dataset == 'cosmogrid':
            processor = None
        else:
            processor = AutoImageProcessor.from_pretrained(args.blackbox_processor_name)
    else:
        processor = AutoTokenizer.from_pretrained(args.blackbox_processor_name)

    # Merge the black-box config and the wrapper config. For cosmogrid the
    # black box is a local CNN checkpoint (loaded below), so start from an
    # empty PretrainedConfig instead of a hub config.
    if args.dataset == 'cosmogrid':
        config = PretrainedConfig()
    else:
        config = AutoConfig.from_pretrained(args.blackbox_model_name)
    with open(args.wrapper_config_filepath) as input_file:
        wrapper_config = json.load(input_file)
    config.__dict__.update(wrapper_config)
    config.blackbox_model_name = args.blackbox_model_name
    config.blackbox_processor_name = args.blackbox_processor_name

    # Allow CLI args to override values from the wrapper json file.
    specifiable_arg_list = ['attn_patch_size',
                            'attn_stride_size',
                            'num_heads',
                            'num_masks_sample',
                            'num_masks_max',
                            'finetune_layers']
    for specifiable_arg in specifiable_arg_list:
        arg_value = getattr(args, specifiable_arg)
        if arg_value is not None:
            if specifiable_arg == 'finetune_layers':
                config.__dict__[specifiable_arg] = arg_value.split(',')
            else:
                config.__dict__[specifiable_arg] = arg_value

    if args.mask_dir is None:
        val_dataset = get_datasets(args.dataset,
                                   processor,
                                   debug=False,
                                   train_size=args.train_size,
                                   val_size=args.val_size,
                                   label2id=config.label2id,
                                   val_only=True,
                                   mode='test')
    else:
        val_dataset = get_masked_datasets(args.dataset,
                                          processor=None,
                                          transform=lambda x: x,
                                          mask_dir=args.mask_dir,
                                          seg_mask_cut_off=args.num_masks_max,
                                          debug=False,
                                          train_size=args.train_size,
                                          val_size=args.val_size,
                                          val_only=True)

    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    # Black-box model: a local CNN checkpoint for cosmogrid, otherwise a
    # pretrained transformers classifier.
    if args.model_type == 'image':
        if args.dataset == 'cosmogrid':
            inner_model = CNNModel(config.num_labels)
            state_dict = torch.load(args.blackbox_model_name)
            inner_model.load_state_dict(state_dict=state_dict)
        else:
            inner_model = AutoModelForImageClassification.from_pretrained(args.blackbox_model_name)
    else:
        inner_model = AutoModelForSequenceClassification.from_pretrained(args.blackbox_model_name)

    # None when --projection-layer-name is not given.
    projection_layer = get_chained_attr(inner_model, args.projection_layer_name)

    def mask_combine(inputs, masks):
        """Blend input embeddings with the embedding of token id 0 according to ``masks``.

        Where a mask value is 1 the input embedding is kept, where it is 0 the
        id-0 embedding is used; values in between interpolate linearly.
        Presumably ``masks`` is (batch, num_masks, seq) and the result is
        (batch, num_masks, seq, hidden) -- confirm against the explainers.
        """
        with torch.no_grad():
            inputs_embed = projection_layer(inputs)
            mask_embed = projection_layer(torch.tensor([0]).int().to(inputs.device))
            masked_inputs_embeds = inputs_embed.unsqueeze(1) * masks.unsqueeze(-1) + \
                    mask_embed.view(1, 1, 1, -1) * (1 - masks.unsqueeze(-1))
        return masked_inputs_embeds

    model = SOP.from_pretrained(args.sop_exp_dir,
                                config=config,
                                blackbox_model=inner_model,
                                model_type=args.model_type,
                                projection_layer=projection_layer,
                                aggr_type=args.aggr_type,
                                pooler=False if args.model_type_spec == 'encoder' and args.dataset != 'cosmogrid' else True,
                                return_tuple=True,
                                mean_center_scale=args.mean_center_scale,
                                mean_center_scale2=args.mean_center_scale2,
                                mean_center_offset=args.mean_center_offset,
                                mean_center_offset2=args.mean_center_offset2)
    model = model.to(device)
    model.eval()

    # Keyword arguments forwarded to the LIME explainers.
    if args.model_type == 'image':
        eik = {
            "segmentation_fn": patch_segmenter,
            "top_labels": 5,
            "hide_color": 0,
            "num_samples": 1000
        }
    else:
        eik = {
            "top_labels": 5,
            "num_samples": 1000
        }
    gimk = {
        "positive_only": False
    }

    if args.model_type == 'image':
        lime_explainer = TorchImageLime(inner_model,
                                        explain_instance_kwargs=eik,
                                        get_image_and_mask_kwargs=gimk,
                                        postprocess=lambda x: x.logits,
                                        task='clf' if args.criterion != 'mse' else 'reg'
                                        ).to(device)
        shap_explainer = TorchImageSHAP(inner_model,
                                        postprocess=lambda x: x.logits).to(device)
        rise_explainer = TorchImageRISE(inner_model, (224, 224) if args.dataset != 'cosmogrid' else (66, 66),
                                        gpu_batch=args.batch_size,
                                        postprocess=lambda x: x.logits).to(device)
        intgrad_explainer = TorchImageIntGrad(inner_model, postprocess=lambda x: x.logits)
        # Choose the Grad-CAM target layer and output postprocess as a pair.
        # The lambdas are selected explicitly: a bare
        # `lambda x: A if c else lambda x: B` parses as one lambda returning
        # either A or a *function object*, which broke the cosmogrid case.
        if args.dataset != 'cosmogrid':
            gradcam_layers = [inner_model.vit.encoder.layer[11].layernorm_after]
            gradcam_postprocess = lambda x: x.hidden_states[-1]
        else:
            gradcam_layers = [inner_model.relu6]
            gradcam_postprocess = lambda x: x.pooler
        gradcam_explainer = TorchImageGradCAM(inner_model,
                                              gradcam_layers,
                                              postprocess=gradcam_postprocess).to(device)
    else:
        lime_explainer = TorchTextLime(inner_model,
                                       tokenizer=processor,
                                       explain_instance_kwargs=eik,
                                       get_image_and_mask_kwargs=gimk,
                                       postprocess=lambda x: x.logits).to(device)
        shap_explainer = TorchTextSHAP(inner_model, tokenizer=processor,
                                       postprocess=lambda x: x.logits).to(device)
        rise_explainer = TorchTextRISE(inner_model, 512,
                                       gpu_batch=args.batch_size,
                                       postprocess=lambda x: x.logits,
                                       mask_combine=mask_combine).to(device)
        intgrad_explainer = TorchTextIntGrad(inner_model, postprocess=lambda x: x.logits,
                                             mask_combine=mask_combine)
        gradcam_explainer = TorchImageGradCAM(inner_model, [inner_model.bert.encoder.layer[-1].output.LayerNorm],
                                              postprocess=lambda x: x.hidden_states[-1]).to(device)

    # Explainers to evaluate; uncomment entries to evaluate more of them.
    explainers = {
        # 'lime': lime_explainer,
        # 'shap': shap_explainer,
        # 'rise': rise_explainer,
        # 'intgrad': intgrad_explainer,
        'gradcam': gradcam_explainer
    }

    # Size handed to InsDelSem: --ins-del-step-size if given, else the wrapper
    # config's ins_del_step_size (if present and truthy), else 224.
    ins_del_step_size = args.ins_del_step_size
    if ins_del_step_size is None:
        ins_del_step_size = getattr(config, 'ins_del_step_size', None)
    net_in_size = ins_del_step_size if ins_del_step_size else 224

    # cosmogrid is scored as regression, every other dataset as classification.
    task_type = 'cls' if args.dataset != 'cosmogrid' else 'reg'
    ins_sem_evaluator = InsDelSem(inner_model,
                                  'ins',
                                  net_in_size,
                                  substrate_fn=torch.zeros_like,
                                  postprocess=lambda x: x.logits,
                                  task_type=task_type).to(device)
    if args.model_type == 'image':
        # Gaussian-blur substrate for the blur-insertion metric.
        # NOTE: its scoring is currently disabled in the loop below, so the
        # reported blur-insertion average stays 0.
        klen = 11
        ksig = 5
        kern = InsDelSem.gkern(klen, ksig, config.num_channels).to(device)
        blur = lambda x: nn.functional.conv2d(x, kern, padding=klen // 2)
        ins_sem_blur_evaluator = InsDelSem(inner_model,
                                           'ins',
                                           net_in_size,
                                           substrate_fn=blur,
                                           postprocess=lambda x: x.logits,
                                           task_type=task_type).to(device)
    else:
        ins_sem_blur_evaluator = None
    del_sem_evaluator = InsDelSem(inner_model,
                                  'del',
                                  net_in_size,
                                  substrate_fn=torch.zeros_like,
                                  postprocess=lambda x: x.logits,
                                  task_type=task_type).to(device)

    comp_sem_evaluator = CompSuffSem(inner_model,
                                     'comp',
                                     k_fraction=0.2,
                                     postprocess=lambda x: x.logits).to(device)

    suff_sem_evaluator = CompSuffSem(inner_model,
                                     'suff',
                                     k_fraction=0.2,
                                     postprocess=lambda x: x.logits).to(device)

    results = {}
    # Running per-explainer sums of per-sample scores plus the sample count.
    results_raw = {
        expl_name: {
            'correct': 0,
            'ins_sem_scores': 0,
            'ins_sem_blur_scores': 0,
            'del_sem_scores': 0,
            'comp_sem_scores': 0,
            'suff_sem_scores': 0,
            'ins_sem_scores_max': 0,
            'ins_sem_blur_scores_max': 0,
            'del_sem_scores_max': 0,
            'total': 0,
        }
        for expl_name in explainers
    }

    print('Eval..')

    progress_bar_eval = tqdm(range(len(val_loader)))
    # NOTE(review): gradient-based explainers (gradcam, intgrad) are called
    # inside this no_grad context; presumably exlib re-enables gradients
    # internally -- confirm.
    with torch.no_grad():
        for b_i, batch in enumerate(val_loader):
            # Optional [start, end) window over batch indices.
            if args.start != -1 and b_i < args.start:
                progress_bar_eval.update(1)
                continue
            if args.end != -1 and b_i >= args.end:
                progress_bar_eval.update(1)
                break
            if args.model_type == 'image':
                if args.mask_dir is None:
                    inputs, labels = batch
                else:
                    # The precomputed masks are loaded with the batch but not
                    # used by this evaluation.
                    inputs, labels, _masks, _masks_i = batch
                if args.dataset != 'cosmogrid':
                    inputs, labels = inputs.to(device), labels.to(device)
                else:
                    inputs = inputs.to(device, dtype=torch.float)
                    labels = labels.to(device, dtype=torch.float)
                token_type_ids = None
                attention_mask = None
                kwargs = {}
            else:
                if not isinstance(batch['input_ids'], torch.Tensor):
                    # Collated as a list of tensors: stack and swap the first
                    # two axes so the batch axis comes first.
                    inputs = torch.stack(batch['input_ids']).transpose(0, 1).to(device)
                    if 'token_type_ids' in batch:
                        token_type_ids = torch.stack(batch['token_type_ids']).transpose(0, 1).to(device)
                    else:
                        token_type_ids = None
                    attention_mask = torch.stack(batch['attention_mask']).transpose(0, 1).to(device)
                else:
                    inputs = batch['input_ids'].to(device)
                    if 'token_type_ids' in batch:
                        token_type_ids = batch['token_type_ids'].to(device)
                    else:
                        token_type_ids = None
                    attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                kwargs = {'token_type_ids': token_type_ids,
                          'attention_mask': attention_mask}

            # SOP forward pass; its flat_masks are the semantic groups used
            # by every evaluator below.
            sop_outputs = model(inputs,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                epoch=model.config.num_heads,
                                mask_batch_size=args.mask_batch_size)
            outputs = inner_model(inputs)
            _, predicted = torch.max(outputs.logits, -1)

            for expl_name, explainer in tqdm(explainers.items()):
                print('expl_name', expl_name)

                if args.model_type == 'image':
                    expln = explainer(inputs)
                else:
                    if expl_name in ['lime', 'shap']:
                        # These explainers take raw text: decode without the
                        # first token and without the token preceding the
                        # first 0 id (or the last token when there is no 0).
                        texts = []
                        for input_ids in inputs:
                            if 0 in input_ids.tolist():
                                texts.append(processor.decode(input_ids[1:input_ids.tolist().index(0) - 1]))
                            else:
                                texts.append(processor.decode(input_ids[1:-1]))
                        expln = explainer(texts)
                    elif expl_name == 'gradcam':
                        if args.criterion == 'mse':
                            criterion = nn.MSELoss()

                            def gradcam_target_func_new(label):
                                # Bind `label` now; the returned target is
                                # called with the model output only.
                                def inner_func(output):
                                    return criterion(output, label)
                                return inner_func
                            expln = explainer(inputs, [gradcam_target_func_new(label) for label in labels])
                        else:
                            expln = explainer(inputs)
                    else:
                        expln = explainer(inputs, predicted, kwargs)

                if args.model_type == 'image':
                    ins_sem_score = ins_sem_evaluator(inputs, expln.attributions, sop_outputs.flat_masks, kwargs)
                    del_sem_score = del_sem_evaluator(inputs, expln.attributions, sop_outputs.flat_masks, kwargs)
                    if args.dataset == 'cosmogrid':
                        # Average over the last axis to get one score per sample.
                        ins_sem_score = ins_sem_score.mean(-1)
                        del_sem_score = del_sem_score.mean(-1)
                    results_raw[expl_name]['ins_sem_scores'] += ins_sem_score.sum().item()
                    results_raw[expl_name]['del_sem_scores'] += del_sem_score.sum().item()
                else:  # text
                    comp_sem_score = comp_sem_evaluator(inputs, expln.attributions, sop_outputs.flat_masks, kwargs)
                    suff_sem_score = suff_sem_evaluator(inputs, expln.attributions, sop_outputs.flat_masks, kwargs)
                    results_raw[expl_name]['comp_sem_scores'] += comp_sem_score.sum().item()
                    results_raw[expl_name]['suff_sem_scores'] += suff_sem_score.sum().item()
                results_raw[expl_name]['total'] += labels.size(0)

            progress_bar_eval.update(1)
    progress_bar_eval.close()

    for expl_name in explainers:
        print(expl_name)
        raw = results_raw[expl_name]
        total = raw['total']
        if total == 0:
            # Nothing was evaluated (e.g. --start is past the last batch);
            # skip instead of dividing by zero.
            print(f'Explainer {expl_name}: no samples evaluated, skipping')
            continue
        ins_sem_avg = 100 * raw['ins_sem_scores'] / total
        ins_sem_blur_avg = 100 * raw['ins_sem_blur_scores'] / total
        del_sem_avg = 100 * raw['del_sem_scores'] / total
        comp_sem_avg = 100 * raw['comp_sem_scores'] / total
        suff_sem_avg = 100 * raw['suff_sem_scores'] / total
        print(f'Explainer {expl_name}')
        print(f'Insertion Sem Score Avg (high) {ins_sem_avg:.2f}%')
        print(f'Insertion Sem Blur Score Avg (high) {ins_sem_blur_avg:.2f}%')
        print(f'Deletion Sem Score Avg (low) {del_sem_avg:.2f}%')
        print(f'Comprehensiveness Sem Score Avg (high) {comp_sem_avg:.2f}%')
        print(f'Sufficiency Sem Score Avg (low) {suff_sem_avg:.2f}%')

        results[expl_name] = {'dataset': args.dataset,
                              'insertion_semantic': ins_sem_avg,
                              'insertion_blur_semantic': ins_sem_blur_avg,
                              'deletion_semantic': del_sem_avg,
                              'comprehensiveness_semantic': comp_sem_avg,
                              'sufficiency_semantic': suff_sem_avg}
        if args.dataset == 'cosmogrid':
            # cosmogrid writes one file per explainer, each holding all
            # results accumulated so far.
            output_filename = os.path.join(args.exp_dir, f'eval_results_explainers_semantics_{args.val_size}_{expl_name}.json')
            with open(output_filename, 'wt') as output_file:
                json.dump(results,
                          output_file,
                          indent=4)

    if args.dataset != 'cosmogrid':
        # Runs over a [start, end) window get a distinct file name.
        if args.start == -1 and args.end == -1:
            output_filename = os.path.join(args.exp_dir, f'eval_results_explainers_semantics_{args.val_size}.json')
        else:
            output_filename = os.path.join(args.exp_dir, f'eval_results_explainers_semantics_{args.val_size}_s{args.start}_e{args.end}.json')
        with open(output_filename, 'wt') as output_file:
            json.dump(results,
                      output_file,
                      indent=4)


if __name__ == '__main__':
    # Script entry point: run the explainer evaluation when executed directly.
    main()