import json
import os
import argparse
import wandb

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig, ViTForImageClassification, PretrainedConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from typing import Dict, List, Optional, Set, Tuple, Union
from collections import defaultdict

from pathlib import Path
import sys
# Make the directory two levels above this file's directory importable
# (needed by the project-local imports further down). The three ancestor
# directories are echoed first as a debugging aid.
_ancestors = Path(__file__).parents
print(_ancestors[0])
print(_ancestors[1])
print(_ancestors[2])
path_root = _ancestors[2]
print(path_root)
sys.path.append(str(path_root))
from PIL import Image

# NOTE(review): hard-coded absolute path to one user's exlib checkout. On any
# other machine this entry is a no-op, so exlib must already be importable
# (e.g. installed) for the imports below to succeed.
sys.path.append('/nlp/data/weiqiuy/exlib/src')
import exlib
from exlib.evaluators.attributions import NNZ, InsDel, CompSuff
from exlib.evaluators.comp_suff import CompSuffText
from exlib.modules.sop import SOP
from exlib.explainers.torch_explainer import TorchImageLime, TorchTextLime, TorchImageSHAP, TorchTextSHAP, TorchImageIntGrad, TorchTextIntGrad
from exlib.explainers.rise import TorchImageRISE, TorchTextRISE
from exlib.explainers.common import patch_segmenter
from exlib.explainers.gradcam import TorchImageGradCAM

# from models.wrapper import SumOfPartsWrapper, get_inverse_sqrt_with_separate_heads_schedule_with_warmup

from models.cnn import CNNModel
from custom_datasets import get_datasets, get_masked_datasets

class ViTForImageClassificationLocal(ViTForImageClassification):
    """ViT image classifier whose ``forward`` returns the logits directly.

    All arguments are passed through unchanged to
    ``ViTForImageClassification.forward``; only the ``logits`` attribute of
    the resulting output object is returned.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ):
        """Run the parent forward pass and return ``outputs.logits``."""
        # Collect the pass-through arguments once, then delegate to the parent.
        parent_kwargs = dict(
            pixel_values=pixel_values,
            head_mask=head_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        parent_outputs = super().forward(**parent_kwargs)
        return parent_outputs.logits
    
class GetLogits(nn.Module):
    """Wrap a model whose output exposes ``.logits`` so that calling the
    wrapper returns the logits directly.

    Args:
        model: callable (typically an ``nn.Module``) whose return value has a
            ``logits`` attribute.
    """

    def __init__(self, model):
        # nn.Module must be initialised before any sub-module is assigned;
        # without this call `self.model = model` raises AttributeError.
        super().__init__()
        self.model = model

    def forward(self, X):
        """Return ``self.model(X).logits``."""
        return self.model(X).logits
        

def get_chained_attr(obj, attr_path):
    """Resolve a dotted attribute path such as ``'a.b.c'`` starting at *obj*.

    Returns ``None`` when *attr_path* is ``None``; otherwise each
    dot-separated name is looked up with ``getattr`` on the result of the
    previous lookup, and the final value is returned. A missing attribute
    propagates the ``AttributeError`` raised by ``getattr``.
    """
    if attr_path is None:
        return None

    current = obj
    for name in attr_path.split('.'):
        # Descend one level along the path.
        current = getattr(current, name)
    return current

def parse_args():
    """Build the command-line parser for this script.

    Despite its name, this function does not parse anything: it returns the
    configured ``argparse.ArgumentParser`` and the caller is expected to call
    ``.parse_args()`` on it.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser()

    # models and configs
    parser.add_argument('--blackbox-model-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box model name')
    parser.add_argument('--blackbox-processor-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box processor name')
    parser.add_argument('--wrapper-config-filepath', type=str,
                        default='actions/wrapper/configs/imagenet_vit_wrapper_default.json',
                        help='wrapper config file')
    parser.add_argument('--exp-dir', type=str,
                        default='exps/imgenet_wrapper',
                        help='exp dir')
    parser.add_argument('--mask-dir', type=str,
                        default=None,
                        help='mask dir, if none, generate, else use')
    # Help text previously read 'exp dir' (copy-paste from --exp-dir).
    parser.add_argument('--output-filename', type=str,
                        default='exps/baselines/baselines.json',
                        help='output filename')
    parser.add_argument('--model-type', type=str,
                        default='image', choices=['image', 'text'],
                        help='image or text model')
    parser.add_argument('--model-type-spec', type=str,
                        default='encoder', choices=['encoder', 'clip', 'gpt2'],
                        help='what specific model to use')
    parser.add_argument('--projection-layer-name', type=str,
                        default=None,  # e.g. distilbert.embeddings
                        help='projection layer if specified, else train from scratch')

    # data
    parser.add_argument('--dataset', type=str,
                        default='imagenet', choices=['imagenet', 'imagenet_m', 'multirc',
                                                     'movies', 'voc',
                                                     'cosmogrid'],
                        help='which dataset to use')
    parser.add_argument('--train-size', type=int,
                        default=-1,
                        help='-1: use all, else randomly choose k per class')
    parser.add_argument('--val-size', type=int,
                        default=-1,
                        help='-1: use all, else randomly choose k per class')

    # training
    parser.add_argument('--num-epochs', type=int,
                        default=1,
                        help='num epochs')
    parser.add_argument('--num-train-reps', type=int,
                        default=1,
                        help='number of times to train each head')
    parser.add_argument('--lr', type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--lr-scheduler-step-size', type=int,
                        default=1,
                        help='learning rate scheduler step size (by epoch)')
    parser.add_argument('--lr-scheduler-gamma', type=float,
                        default=0.1,
                        help='learning rate scheduler gamma')
    parser.add_argument('--warmup-epochs', type=float,
                        default=-1,
                        help='number of epochs to warmup')
    parser.add_argument('--warmup-steps', type=float,
                        default=0,
                        help='number of steps to warmup. Use this when warmup epochs is -1')
    parser.add_argument('--scheduler-type', type=str,
                        default='inverse_sqrt_heads', choices=['cosine',
                                                               'constant',
                                                               'inverse_sqrt_heads'],
                        help='scheduler type')
    parser.add_argument('--criterion', type=str,
                        default='ce', choices=['ce', 'bce', 'mse'],
                        help='which criterion to use, if multi-label then bce')
    parser.add_argument('--batch-size', type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--mask-batch-size', type=int,
                        default=16,
                        help='mask batch size')
    parser.add_argument('--project-name', type=str,
                        default='attn',
                        help='wandb project name')
    parser.add_argument('--seed', type=int,
                        default=42,
                        help='seed')
    parser.add_argument('--shuffle-val', action='store_true',
                        default=False,
                        help='if true then shuffle val')
    parser.add_argument('--num-masks-sample-inference', type=int,
                        default=None,
                        help='number of masks to retain for mask dropout.')

    # specify wrapper config. If not None, then use this instead of ones specified in config
    parser.add_argument('--attn-patch-size', type=int,
                        default=None,
                        help='attn patch size, does not have to match black box model patch size')
    parser.add_argument('--attn-stride-size', type=int,
                        default=None,
                        help='attn stride size, if smaller than patch size then there is overlap')
    parser.add_argument('--num-heads', type=int,
                        default=None,
                        help='hidden dim for the first attention layer')
    parser.add_argument('--num-masks-sample', type=int,
                        default=None,
                        help='number of masks to retain for mask dropout.')
    parser.add_argument('--num-masks-max', type=int,
                        default=None,
                        help='number of maximum masks to retain.')
    parser.add_argument('--finetune-layers', type=str,
                        default=None,
                        help='Which layer to finetune, separated by comma.')
    parser.add_argument('--ins-del-step-size',
                        default=None,
                        type=int,
                        help='ins del step size for insertion deletion')

    parser.add_argument('--aggr-type', type=str,
                        default='joint', choices=['joint', 'independent'],
                        help='usually we use joint for classification, but can also try independent')
    parser.add_argument('--proj-hid-size', type=int,
                        default=None,
                        help='If specified, use this instead of hidden size for projection')
    parser.add_argument('--mean-center-offset',
                        default=0,
                        type=float,
                        help='offset to mean center before projection')
    parser.add_argument('--mean-center-offset2',
                        default=None,
                        type=float,
                        help='offset to mean center after projection')
    parser.add_argument('--mean-center-scale',
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering before projection')
    parser.add_argument('--mean-center-scale2',
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering after projection')

    # evaluation range / output mode
    parser.add_argument('--start', type=int,
                        default=-1,
                        help='if not -1, set a start idx for eval')
    parser.add_argument('--end', type=int,
                        default=-1,
                        help='if not -1, set a end idx for eval')
    parser.add_argument('--save', action='store_true',
                        default=False,
                        help='save the attributions if true, else compute the metrics')

    return parser

def _seed_everything(seed):
    """Seed the torch, CUDA, NumPy and Python RNGs; a seed of -1 disables seeding."""
    if seed == -1:
        return
    # Torch RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Python RNG
    np.random.seed(seed)
    random.seed(seed)


def main():
    """Run the baseline explainers on the black-box model and score or save them.

    Steps, in order:
      1. Parse CLI arguments and merge the black-box config, the wrapper
         JSON config and any CLI overrides into one config object.
      2. Build the validation dataset/loader, the black-box model and the
         SOP wrapper model.
      3. Build the LIME / SHAP / RISE / IntGrad / GradCAM explainers and the
         NNZ, insertion/deletion and comprehensiveness/sufficiency evaluators.
      4. For each batch and explainer: with ``--save`` the attributions are
         written to ``<exp_dir>/attributions_<explainer>/<batch>.pt``;
         otherwise the scores are accumulated.
      5. Without ``--save``, print the averaged scores and dump them as JSON
         into ``exp_dir``.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg in vars(args):
        value = getattr(args, arg)
        # Show the type of the parsed value; the option name is always a str.
        print(arg, value, '\t', type(value))

    os.makedirs(args.exp_dir, exist_ok=True)

    # Set the seed for reproducibility
    _seed_everything(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    is_image = args.model_type == 'image'
    is_cosmogrid = args.dataset == 'cosmogrid'

    # Input processor: image processor, tokenizer, or nothing for cosmogrid.
    if is_image:
        if is_cosmogrid:
            processor = None
        else:
            processor = AutoImageProcessor.from_pretrained(args.blackbox_processor_name)
    else:
        processor = AutoTokenizer.from_pretrained(args.blackbox_processor_name)

    # Merge blackbox config and wrapper config. For cosmogrid the black-box
    # model name is passed to torch.load below, so start from an empty config.
    if is_cosmogrid:
        config = PretrainedConfig()
    else:
        config = AutoConfig.from_pretrained(args.blackbox_model_name)
    with open(args.wrapper_config_filepath) as input_file:
        wrapper_config = json.load(input_file)
    config.__dict__.update(wrapper_config)
    config.blackbox_model_name = args.blackbox_model_name
    config.blackbox_processor_name = args.blackbox_processor_name

    # allow specifying args to be different from in the json file
    specifiable_arg_list = ['attn_patch_size',
                            'attn_stride_size',
                            'num_heads',
                            'num_masks_sample',
                            'num_masks_max',
                            'finetune_layers',
                            'ins_del_step_size']
    for specifiable_arg in specifiable_arg_list:
        arg_value = getattr(args, specifiable_arg)
        if arg_value is not None:
            if specifiable_arg == 'finetune_layers':
                config.__dict__[specifiable_arg] = arg_value.split(',')
            else:
                config.__dict__[specifiable_arg] = arg_value

    if args.mask_dir is None:
        val_dataset = get_datasets(args.dataset,
                                   processor,
                                   debug=False,
                                   train_size=args.train_size,
                                   val_size=args.val_size,
                                   label2id=config.label2id,
                                   val_only=True,
                                   mode='test')
    else:
        val_dataset = get_masked_datasets(args.dataset,
                                          processor=None,
                                          transform=lambda x: x,
                                          mask_dir=args.mask_dir,
                                          seg_mask_cut_off=args.num_masks_max,
                                          debug=False,
                                          train_size=args.train_size,
                                          val_size=args.val_size,
                                          val_only=True)

    # Re-seed after dataset construction, before the loader is created.
    _seed_everything(args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=args.shuffle_val)

    if is_image:
        if is_cosmogrid:
            inner_model = CNNModel(config.num_labels)
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(args.blackbox_model_name)
            inner_model.load_state_dict(state_dict=state_dict)
        else:
            inner_model = AutoModelForImageClassification.from_pretrained(args.blackbox_model_name)
    else:
        inner_model = AutoModelForSequenceClassification.from_pretrained(args.blackbox_model_name)

    # None when --projection-layer-name is not given.
    projection_layer = get_chained_attr(inner_model, args.projection_layer_name)

    def mask_combine(inputs, masks):
        """Embed `inputs` with the projection layer and blend each mask with
        the embedding of token id 0 where the mask is off."""
        with torch.no_grad():
            inputs_embed = projection_layer(inputs)
            mask_embed = projection_layer(torch.tensor([0]).int().to(inputs.device))
            masked_inputs_embeds = inputs_embed.unsqueeze(1) * masks.unsqueeze(-1) + \
                    mask_embed.view(1,1,1,-1) * (1 - masks.unsqueeze(-1))
        return masked_inputs_embeds

    model = SOP.from_pretrained(args.exp_dir,
                              config=config,
                              blackbox_model=inner_model,
                              model_type=args.model_type,
                              projection_layer=projection_layer,
                              aggr_type=args.aggr_type,
                              pooler=not (args.model_type_spec == 'encoder' and not is_cosmogrid),
                              return_tuple=True,
                              mean_center_scale=args.mean_center_scale,
                              mean_center_scale2=args.mean_center_scale2,
                              mean_center_offset=args.mean_center_offset,
                              mean_center_offset2=args.mean_center_offset2)

    model = model.to(device)
    model.eval()
    inner_model = inner_model.to(device)
    inner_model.eval()

    # LIME keyword arguments.
    if is_image:
        eik = {
            "segmentation_fn": patch_segmenter,
            "top_labels": 5,
            "hide_color": 0,
            "num_samples": 1000
        }
    else:
        eik = {
            "top_labels": 2,
            "num_samples": 100
        }
    gimk = {
        "positive_only": False
    }

    if is_image:
        lime_explainer = TorchImageLime(inner_model,
                                        explain_instance_kwargs=eik,
                                        get_image_and_mask_kwargs=gimk,
                                        postprocess=lambda x: x.logits,
                                        task='clf' if args.criterion != 'mse' else 'reg'
                                        ).to(device)
        shap_explainer = TorchImageSHAP(inner_model,
                                        postprocess=lambda x: x.logits).to(device)
        rise_explainer = TorchImageRISE(inner_model, (224, 224) if not is_cosmogrid else (66, 66),
                                        gpu_batch=args.batch_size,
                                        postprocess=lambda x: x.logits).to(device)
        intgrad_explainer = TorchImageIntGrad(inner_model, postprocess=lambda x: x.logits)
        # Select the GradCAM layer and postprocess up front. Written inline as
        # `lambda x: A if cond else lambda x: B`, the conditional is part of
        # the first lambda's body, so the cosmogrid case returned a function
        # instead of `x.pooler`.
        if is_cosmogrid:
            gradcam_layers = [inner_model.relu6]
            gradcam_postprocess = lambda x: x.pooler
        else:
            gradcam_layers = [inner_model.vit.encoder.layer[11].layernorm_after]
            gradcam_postprocess = lambda x: x.hidden_states[-1]
        gradcam_explainer = TorchImageGradCAM(inner_model,
                                              gradcam_layers,
                                              postprocess=gradcam_postprocess).to(device)
    else:
        lime_explainer = TorchTextLime(inner_model,
                                       tokenizer=processor,
                                       explain_instance_kwargs=eik,
                                       get_image_and_mask_kwargs=gimk,
                                       postprocess=lambda x: x.logits,
                                       task='clf',
                                       batch_size=4).to(device)
        shap_explainer = TorchTextSHAP(inner_model, tokenizer=processor,
                                       postprocess=lambda x: x.logits).to(device)
        rise_explainer = TorchTextRISE(inner_model, 512,
                                       gpu_batch=args.batch_size,
                                       postprocess=lambda x: x.logits,
                                       mask_combine=mask_combine).to(device)
        intgrad_explainer = TorchTextIntGrad(inner_model, postprocess=lambda x: x.logits,
                                             mask_combine=mask_combine)
        # NOTE(review): assumes the text model exposes a `.bert` encoder — confirm
        # before using another architecture.
        gradcam_explainer = TorchImageGradCAM(inner_model, [inner_model.bert.encoder.layer[-1].output.LayerNorm],
                                              postprocess=lambda x: x.hidden_states[-1]).to(device)

    explainers = {
        'lime': lime_explainer,
        'shap': shap_explainer,
        'rise': rise_explainer,
        'intgrad': intgrad_explainer,
        'gradcam': gradcam_explainer
    }

    nnz_evaluator = NNZ().to(device)
    # Fall back to 224 when the step size is absent from both the JSON config
    # and the CLI (plain attribute access would raise AttributeError).
    net_in_size = getattr(config, 'ins_del_step_size', None) or 224
    task_type = 'cls' if not is_cosmogrid else 'reg'
    ins_evaluator = InsDel(inner_model,
                           'ins',
                           net_in_size,
                           substrate_fn=torch.zeros_like,
                           postprocess=lambda x: x.logits,
                           task_type=task_type).to(device)
    if is_image:
        klen = 11
        ksig = 5
        kern = InsDel.gkern(klen, ksig, config.num_channels).to(device)
        blur = lambda x: nn.functional.conv2d(x, kern, padding=klen//2)
        # Built but currently not applied in the evaluation loop below.
        ins_blur_evaluator = InsDel(inner_model,
                                    'ins',
                                    net_in_size,
                                    substrate_fn=blur,
                                    postprocess=lambda x: x.logits,
                                    task_type=task_type).to(device)
    else: # doesn't make sense to blur the text
        ins_blur_evaluator = None
    del_evaluator = InsDel(inner_model,
                           'del',
                           net_in_size,
                           substrate_fn=torch.zeros_like,
                           postprocess=lambda x: x.logits,
                           task_type=task_type).to(device)

    if is_image:
        comp_evaluator = CompSuff(inner_model,
                                  'comp',
                                  k_fraction=0.2,
                                  postprocess=lambda x: x.logits).to(device)

        suff_evaluator = CompSuff(inner_model,
                                  'suff',
                                  k_fraction=0.2,
                                  postprocess=lambda x: x.logits).to(device)
    else:
        # NOTE(review): the text evaluators wrap the SOP `model`, not
        # `inner_model` like every other evaluator — confirm this is intended.
        comp_evaluator = CompSuffText(model,
                                      'comp',
                                      postprocess=lambda x: x.logits).to(device)

        suff_evaluator = CompSuffText(model,
                                      'suff',
                                      postprocess=lambda x: x.logits).to(device)

    results = {}
    # Running sums per explainer; divided by 'total' after the loop.
    results_raw = {}
    for expl_name in explainers:
        results_raw[expl_name] = {
            'correct': 0,
            'nnz_scores': 0,
            'ins_scores': 0,
            'ins_blur_scores': 0,
            'del_scores': 0,
            'comp_scores': 0,
            'suff_scores': 0,
            'total': 0,
        }

    progress_bar_eval = tqdm(range(len(val_loader)))
    with torch.no_grad():
        for b_i, batch in enumerate(val_loader):
            print(b_i, args.start)
            print(b_i, args.end)
            # --start / --end restrict evaluation to batch indices [start, end).
            if args.start != -1:
                if b_i < args.start:
                    print('cont')
                    progress_bar_eval.update(1)
                    continue
            if args.end != -1:
                if b_i >= args.end:
                    print('break')
                    progress_bar_eval.update(1)
                    break

            if is_image:
                if args.mask_dir is None:
                    inputs, labels = batch
                    masks = None
                else:
                    inputs, labels, masks, masks_i = batch
                    masks = masks.to(device)
                if not is_cosmogrid:
                    inputs, labels = inputs.to(device), labels.to(device)
                else:
                    inputs = inputs.to(device, dtype=torch.float)
                    labels = labels.to(device, dtype=torch.float)
                token_type_ids = None
                attention_mask = None
                kwargs = {}
            else:
                if not isinstance(batch['input_ids'], torch.Tensor):
                    # Fields arrive as lists of tensors: stack, then swap the
                    # first two axes so the stacked axis comes second.
                    inputs = torch.stack(batch['input_ids']).transpose(0, 1).to(device)
                    if 'token_type_ids' in batch:
                        token_type_ids = torch.stack(batch['token_type_ids']).transpose(0, 1).to(device)
                    else:
                        token_type_ids = None
                    attention_mask = torch.stack(batch['attention_mask']).transpose(0, 1).to(device)
                else:
                    inputs = batch['input_ids'].to(device)
                    if 'token_type_ids' in batch:
                        token_type_ids = batch['token_type_ids'].to(device)
                    else:
                        token_type_ids = None
                    attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                kwargs = {'token_type_ids': token_type_ids,
                          'attention_mask': attention_mask}

            bsz = inputs.size(0)

            # `kwargs` is empty for image models, so this covers both cases;
            # gradients are already disabled by the enclosing no_grad().
            outputs = inner_model(inputs, **kwargs)

            logits = outputs.logits
            # accuracy
            if args.criterion == 'ce':
                _, predicted = torch.max(logits.data, 1)
                correct = (predicted == labels).sum().item()
                explain_labels = predicted
            elif args.criterion == 'bce':
                probs = torch.sigmoid(logits)
                predicted = (probs > 0.5).float()
                correct = (predicted[range(bsz), torch.abs(labels) - 1] == (labels > 0)).sum().item()
                explain_labels = torch.abs(labels) - 1
                # NOTE(review): this reads the batch-level `labels` rather than
                # its `label` argument and never uses `label_val` — likely a
                # bug; left unchanged because the intended GradCAM target
                # semantics cannot be confirmed from this file.
                def gradcam_target_func(output, label):
                    label_class = torch.abs(labels) - 1
                    label_val = (labels > 0)
                    probs = torch.sigmoid(output)
                    predicted = (probs > 0.5).float()
                    return predicted[label_class] == label

                def gradcam_target_func_new(label):
                    def inner_func(output):
                        return gradcam_target_func(output, label)
                    return inner_func
            else: # mse
                criterion = nn.MSELoss()
                predicted = criterion(logits, labels)
                # For regression, 'correct' accumulates the batch loss value.
                correct = predicted.sum().item()
                # Keep the labels on the same device as in the other branches.
                explain_labels = torch.zeros(bsz, device=device)
                def gradcam_target_func(output, label):
                    return criterion(output, label)
                def gradcam_target_func_new(label):
                    def inner_func(output):
                        return gradcam_target_func(output, label)
                    return inner_func

            for expl_name, explainer in tqdm(explainers.items()):
                if is_image:
                    if expl_name == 'gradcam':
                        if args.criterion == 'ce':
                            expln = explainer(inputs)
                        else:
                            expln = explainer(inputs, [gradcam_target_func_new(label) for label in labels])
                    else:
                        expln = explainer(inputs)
                else:
                    if expl_name in ['lime', 'shap']:
                        # These explainers take raw text: decode each row,
                        # dropping the first token and, when a 0 id is
                        # present, everything from one position before it.
                        texts = []
                        for input_ids in inputs:
                            if 0 in input_ids.tolist():
                                texts.append(processor.decode(input_ids[1:input_ids.tolist().index(0) - 1]))
                            else:
                                texts.append(processor.decode(input_ids[1:-1]))
                        expln = explainer(texts)
                    elif expl_name == 'gradcam':
                        expln = explainer(inputs)
                    else:
                        expln = explainer(inputs, explain_labels, kwargs)

                if not args.save:
                    # get nnz score for aggregated mask
                    nnz_score = nnz_evaluator(inputs, expln.attributions, normalize=True)
                    if is_image:
                        ins_score = ins_evaluator(inputs, expln.attributions, kwargs)
                        del_score = del_evaluator(inputs, expln.attributions, kwargs)
                        if is_cosmogrid:
                            ins_score = ins_score.mean(-1)
                            del_score = del_score.mean(-1)
                        results_raw[expl_name]['ins_scores'] += ins_score.sum().item()
                        results_raw[expl_name]['del_scores'] += del_score.sum().item()
                    else:
                        comp_score = comp_evaluator(inputs, expln.attributions, kwargs)
                        suff_score = suff_evaluator(inputs, expln.attributions, kwargs)
                        results_raw[expl_name]['comp_scores'] += comp_score.sum().item()
                        results_raw[expl_name]['suff_scores'] += suff_score.sum().item()

                    results_raw[expl_name]['total'] += labels.size(0)
                    results_raw[expl_name]['correct'] += correct
                    results_raw[expl_name]['nnz_scores'] += nnz_score.sum().item()
                else:
                    os.makedirs(os.path.join(args.exp_dir, f'attributions_{expl_name}'), exist_ok=True)
                    save_filepath = os.path.join(args.exp_dir, f'attributions_{expl_name}' , f'{b_i}.pt')
                    torch.save({'inputs': inputs,
                                'labels': labels,
                                'outputs': expln}, save_filepath)

            progress_bar_eval.update(1)

    if not args.save:
        for expl_name in explainers:
            print(expl_name)
            raw = results_raw[expl_name]
            total = raw['total']
            if total == 0:
                # No batch was evaluated (e.g. --start/--end selected nothing);
                # skip instead of dividing by zero.
                print(f'Explainer {expl_name}: no examples evaluated, skipping')
                continue
            val_acc = 100 * raw['correct'] / total
            nnz_avg = 100 * raw['nnz_scores'] / total
            ins_avg = 100 * raw['ins_scores'] / total
            ins_blur_avg = 100 * raw['ins_blur_scores'] / total
            del_avg = 100 * raw['del_scores'] / total
            comp_avg = 100 * raw['comp_scores'] / total
            suff_avg = 100 * raw['suff_scores'] / total
            print(f'Explainer {expl_name}')
            print(f'Accuracy Score Avg (high) {val_acc:.2f}%')
            print(f'Number of Nonzeros Score Avg (low) {nnz_avg:.2f}%')
            print(f'Insertion Score Avg (high) {ins_avg:.2f}%')
            print(f'Insertion Blur Score Avg (high) {ins_blur_avg:.2f}%')
            print(f'Deletion Score Avg (low) {del_avg:.2f}%')
            print(f'Comprehensiveness Score Avg (high) {comp_avg:.2f}%')
            print(f'Sufficiency Score Avg (low) {suff_avg:.2f}%')

            results[expl_name] = {'dataset': args.dataset,
                                  'accuracy': val_acc,
                                  'num_nonzeros': nnz_avg,
                                  'insertion': ins_avg,
                                  'insertion_blur': ins_blur_avg,
                                  'deletion': del_avg,
                                  'comprehensiveness': comp_avg,
                                  'sufficiency': suff_avg}

            # For cosmogrid, one file is written per explainer, each holding
            # the results accumulated so far.
            if is_cosmogrid:
                output_filename = os.path.join(args.exp_dir, f'eval_results_explainers_{args.val_size}_{expl_name}.json')
                with open(output_filename, 'wt') as output_file:
                    json.dump(results,
                              output_file,
                              indent=4)

        if not is_cosmogrid:
            if args.start == -1 and args.end == -1:
                output_filename = os.path.join(args.exp_dir, f'eval_results_explainers_{args.val_size}.json')
            else:
                output_filename = os.path.join(args.exp_dir, f'eval_results_explainers_{args.val_size}_s{args.start}_e{args.end}.json')
            with open(output_filename, 'wt') as output_file:
                json.dump(results,
                          output_file,
                          indent=4)


if __name__ == '__main__':
    main()