import json
import os
import argparse
import wandb

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig, ViTForImageClassification, PretrainedConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from typing import Dict, List, Optional, Set, Tuple, Union

from pathlib import Path
import sys
# Debug output: show the three nearest ancestor directories of this script.
_script_path = Path(__file__)
for _depth in range(3):
    print(_script_path.parents[_depth])

# The directory two levels above the script's own directory is put on
# sys.path so the imports further down can be resolved from it.
path_root = _script_path.parents[2]
print(path_root)
sys.path.append(str(path_root))
from PIL import Image

# Add a hard-coded directory to the import search path ahead of the
# `import exlib` that follows.
# NOTE(review): presumably a local exlib checkout lives at this location;
# the path is machine-specific — confirm before running elsewhere.
sys.path.append('/tmp/exlib/src')
import exlib
from exlib.evaluators.attributions import NNZ, InsDelSem, CompSuffSem
from exlib.modules.sop import SOP
from exlib.modules.fresh import FRESH
from exlib.explainers.torch_explainer import TorchImageLime, TorchImageSHAP
from exlib.explainers.rise import TorchImageRISE
from exlib.explainers.common import patch_segmenter
from exlib.explainers.gradcam import TorchImageGradCAM

# from models.wrapper import SumOfPartsWrapper, get_inverse_sqrt_with_separate_heads_schedule_with_warmup

from models.cnn import CNNModel
from custom_datasets import get_datasets, get_masked_datasets

class ViTForImageClassificationLocal(ViTForImageClassification):
    """``ViTForImageClassification`` variant whose ``forward`` yields logits only.

    The parent forward pass is run unchanged; instead of handing back the
    parent's output object, only its ``.logits`` attribute is returned.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Forward every argument to the parent model and return its logits."""
        parent_result = super().forward(
            pixel_values=pixel_values,
            head_mask=head_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        return parent_result.logits
    
class GetLogits(nn.Module):
    """Wrap a model so that calling it returns only the ``.logits`` of its output.

    Args:
        model: a callable module whose output exposes a ``logits`` attribute.
    """

    def __init__(self, model):
        # nn.Module must be initialised before a sub-module can be assigned;
        # without this call the assignment below raises AttributeError.
        super().__init__()
        self.model = model

    def forward(self, X):
        """Run the wrapped model on ``X`` and return ``output.logits``."""
        return self.model(X).logits
        

def get_chained_attr(obj, attr_path):
    """Resolve a dotted attribute path such as ``'a.b.c'`` starting from ``obj``.

    Returns ``None`` when ``attr_path`` is ``None``; otherwise each dot-separated
    name is looked up with ``getattr`` on the result of the previous lookup and
    the final value is returned. A missing (or empty) name raises
    ``AttributeError``.
    """
    if attr_path is None:
        return None

    node = obj
    remaining = attr_path
    more = True
    while more:
        # Peel off the next name; `sep` is empty once the last name is reached.
        name, sep, remaining = remaining.partition('.')
        node = getattr(node, name)
        more = bool(sep)
    return node

def parse_args():
    """Build the command-line parser for this script.

    Note that, despite the name, this returns the configured
    ``argparse.ArgumentParser`` itself; the caller is expected to invoke
    ``.parse_args()`` on it.
    """
    cli = argparse.ArgumentParser()

    # Each entry is (flag, keyword arguments for ``add_argument``); options are
    # registered in the order listed.
    option_table = [
        # models and configs
        ('--blackbox-model-name', dict(
            type=str, default='google/vit-base-patch16-224',
            help='black box model name')),
        ('--blackbox-processor-name', dict(
            type=str, default='google/vit-base-patch16-224',
            help='black box processor name')),
        ('--wrapper-config-filepath', dict(
            type=str,
            default='actions/wrapper/configs/imagenet_vit_wrapper_default.json',
            help='wrapper config file')),
        ('--wrapper-model', dict(
            type=str, default='sop', choices=['sop', 'fresh'],
            help='sop or fresh')),
        ('--exp-dir', dict(
            type=str, default='exps/imgenet_wrapper',
            help='exp dir')),
        ('--mask-dir', dict(
            type=str, default=None,
            help='mask dir, if none, generate, else use')),
        ('--sop-exp-dir', dict(
            type=str, default='exps/imgenet_wrapper',
            help='sop exp dir')),
        ('--output-filename', dict(
            type=str, default='exps/baselines/baselines.json',
            help='exp dir')),
        ('--model-type', dict(
            type=str, default='image', choices=['image', 'text'],
            help='image or text model')),
        ('--model-type-spec', dict(
            type=str, default='encoder', choices=['encoder', 'clip', 'gpt2'],
            help='what specific model to use')),
        ('--projection-layer-name', dict(
            type=str, default=None,  # e.g. distilbert.embeddings
            help='projection layer if specified, else train from scratch')),

        # data
        ('--dataset', dict(
            type=str, default='imagenet',
            choices=['imagenet', 'imagenet_m', 'multirc',
                     'movies', 'voc', 'cosmogrid'],
            help='which dataset to use')),
        ('--train-size', dict(
            type=int, default=-1,
            help='-1: use all, else randomly choose k per class')),
        ('--val-size', dict(
            type=int, default=-1,
            help='-1: use all, else randomly choose k per class')),

        # training
        ('--num-epochs', dict(
            type=int, default=1,
            help='num epochs')),
        ('--num-train-reps', dict(
            type=int, default=1,
            help='number of times to train each head')),
        ('--lr', dict(
            type=float, default=1e-4,
            help='learning rate')),
        ('--lr-scheduler-step-size', dict(
            type=int, default=1,
            help='learning rate scheduler step size (by epoch)')),
        ('--lr-scheduler-gamma', dict(
            type=float, default=0.1,
            help='learning rate scheduler gamma')),
        ('--warmup-epochs', dict(
            type=float, default=-1,
            help='number of epochs to warmup')),
        ('--warmup-steps', dict(
            type=float, default=0,
            help='number of steps to warmup. Use this when warmup epochs is -1')),
        ('--scheduler-type', dict(
            type=str, default='inverse_sqrt_heads',
            choices=['cosine', 'constant', 'inverse_sqrt_heads'],
            help='scheduler type')),
        ('--criterion', dict(
            type=str, default='ce', choices=['ce', 'bce', 'mse'],
            help='which criterion to use, if multi-label then bce')),
        ('--batch-size', dict(
            type=int, default=32,
            help='batch size')),
        ('--mask-batch-size', dict(
            type=int, default=16,
            help='mask batch size')),
        ('--project-name', dict(
            type=str, default='attn',
            help='wandb project name')),
        ('--seed', dict(
            type=int, default=42,
            help='seed')),
        ('--shuffle-val', dict(
            action='store_true', default=False,
            help='if true then shuffle val')),
        ('--num-masks-sample-inference', dict(
            type=int, default=None,
            help='number of masks to retain for mask dropout.')),

        # wrapper config overrides: when not None these take precedence over
        # the values found in the wrapper config file
        ('--attn-patch-size', dict(
            type=int, default=None,
            help='attn patch size, does not have to match black box model patch size')),
        ('--attn-stride-size', dict(
            type=int, default=None,
            help='attn stride size, if smaller than patch size then there is overlap')),
        ('--num-heads', dict(
            type=int, default=None,
            help='hidden dim for the first attention layer')),
        ('--num-masks-sample', dict(
            type=int, default=None,
            help='number of masks to retain for mask dropout.')),
        ('--num-masks-max', dict(
            type=int, default=None,
            help='number of maximum masks to retain.')),
        ('--finetune-layers', dict(
            type=str, default=None,
            help='Which layer to finetune, seperated by comma.')),
        ('--ins-del-step-size', dict(
            type=int, default=None,
            help='ins del step size for insertion deletion')),

        ('--aggr-type', dict(
            type=str, default='joint', choices=['joint', 'independent'],
            help='usually we use joint for classification, but can also try independent')),
        ('--proj-hid-size', dict(
            type=int, default=None,
            help='If specified, use this instead of hidden size for projection')),
        ('--mean-center-offset', dict(
            type=float, default=0,
            help='offset to mean center before projection')),
        ('--mean-center-offset2', dict(
            type=float, default=None,
            help='offset to mean center after projection')),
        ('--mean-center-scale', dict(
            type=float, default=0,
            help='scale factor to scale up the features after mean centering before projection')),
        ('--mean-center-scale2', dict(
            type=float, default=0,
            help='scale factor to scale up the features after mean centering after projection')),
    ]

    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli

def _seed_everything(seed):
    """Seed the torch (CPU and CUDA), NumPy and Python RNGs; ``-1`` skips seeding."""
    if seed == -1:
        return
    # Torch RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NumPy / Python RNG
    np.random.seed(seed)
    random.seed(seed)


def _percent_avg(score_sum, total):
    """Return ``100 * score_sum / total``, or ``0.0`` when no example was counted."""
    return 100 * score_sum / total if total else 0.0


def main():
    """Score the attributions of an SOP or FRESH wrapper on the validation split.

    Steps:
      1. Parse the command line and seed the RNGs.
      2. Build the processor, the merged (black-box + wrapper JSON) config,
         the validation dataset/loader and the black-box model.
      3. Load the SOP model from ``--sop-exp-dir`` and the wrapper selected by
         ``--wrapper-model`` from ``--exp-dir``.
      4. For image models accumulate insertion/deletion scores, for text
         models comprehensiveness/sufficiency scores.
      5. Print the per-example averages (as percentages) and write them to
         ``<exp-dir>/eval_results_semantics_<val-size>.json``.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg, value in vars(args).items():
        # Report the type of the parsed value; the name itself is always a str.
        print(arg, value, '\t', type(value))

    os.makedirs(args.exp_dir, exist_ok=True)

    # Set the seed for reproducibility
    _seed_everything(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    is_image = args.model_type == 'image'
    is_cosmogrid = args.dataset == 'cosmogrid'

    # Input processor: tokenizer for text, image processor for images
    # (cosmogrid inputs are used without one).
    if not is_image:
        processor = AutoTokenizer.from_pretrained(args.blackbox_processor_name)
    elif is_cosmogrid:
        processor = None
    else:
        processor = AutoImageProcessor.from_pretrained(args.blackbox_processor_name)

    # Merge blackbox config and wrapper config. cosmogrid starts from an empty
    # PretrainedConfig; every other dataset starts from the black-box model's
    # own config.
    if is_cosmogrid:
        config = PretrainedConfig()
    else:
        config = AutoConfig.from_pretrained(args.blackbox_model_name)
    with open(args.wrapper_config_filepath) as input_file:
        wrapper_config = json.load(input_file)
    config.__dict__.update(wrapper_config)
    config.blackbox_model_name = args.blackbox_model_name
    config.blackbox_processor_name = args.blackbox_processor_name

    # Allow command-line values to override the ones from the JSON file.
    specifiable_arg_list = ['attn_patch_size',
                            'attn_stride_size',
                            'num_heads',
                            'num_masks_sample',
                            'num_masks_max',
                            'finetune_layers',
                            'ins_del_step_size']
    for specifiable_arg in specifiable_arg_list:
        arg_value = getattr(args, specifiable_arg)
        if arg_value is None:
            continue
        if specifiable_arg == 'finetune_layers':
            # Given as a comma-separated string on the command line.
            arg_value = arg_value.split(',')
        config.__dict__[specifiable_arg] = arg_value

    if args.mask_dir is None:
        val_dataset = get_datasets(args.dataset,
                                   processor,
                                   debug=False,
                                   train_size=args.train_size,
                                   val_size=args.val_size,
                                   label2id=config.label2id,
                                   val_only=True,
                                   mode='test')
    else:
        val_dataset = get_masked_datasets(args.dataset,
                                          processor=None,
                                          transform=lambda x: x,
                                          mask_dir=args.mask_dir,
                                          seg_mask_cut_off=args.num_masks_max,
                                          debug=False,
                                          train_size=args.train_size,
                                          val_size=args.val_size,
                                          val_only=True)

    # Re-seed so the loader's shuffling does not depend on how much randomness
    # dataset construction consumed.
    _seed_everything(args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=args.shuffle_val)

    if is_image:
        if is_cosmogrid:
            inner_model = CNNModel(config.num_labels)
            # Load onto CPU first so a checkpoint saved from a GPU also loads
            # on a CPU-only machine; load_state_dict copies the values into
            # the model's own parameters.
            state_dict = torch.load(args.blackbox_model_name, map_location='cpu')
            inner_model.load_state_dict(state_dict=state_dict)
        else:
            inner_model = AutoModelForImageClassification.from_pretrained(args.blackbox_model_name)
    else:
        inner_model = AutoModelForSequenceClassification.from_pretrained(args.blackbox_model_name)

    projection_layer = get_chained_attr(inner_model, args.projection_layer_name)

    # Keyword arguments shared by both SOP loads below.
    sop_kwargs = dict(
        config=config,
        blackbox_model=inner_model,
        model_type=args.model_type,
        projection_layer=projection_layer,
        aggr_type=args.aggr_type,
        pooler=not (args.model_type_spec == 'encoder' and not is_cosmogrid),
        return_tuple=True,
        mean_center_scale=args.mean_center_scale,
        mean_center_scale2=args.mean_center_scale2,
        mean_center_offset=args.mean_center_offset,
        mean_center_offset2=args.mean_center_offset2,
    )

    sop_model = SOP.from_pretrained(args.sop_exp_dir, **sop_kwargs)

    if args.wrapper_model == 'sop':
        model = SOP.from_pretrained(args.exp_dir, **sop_kwargs)
    else:
        model = FRESH.from_pretrained(args.exp_dir,
                                      config=config,
                                      blackbox_model=inner_model,
                                      model_type=args.model_type,
                                      return_tuple=True,
                                      postprocess_attn=lambda x: x.attentions[-1].mean(dim=1)[:, 0],
                                      postprocess_logits=lambda x: x.logits)

    sop_model = sop_model.to(device)
    sop_model.eval()
    model = model.to(device)
    model.eval()

    def take_logits(model_output):
        return model_output.logits

    task_type = 'reg' if is_cosmogrid else 'cls'
    net_in_size = config.ins_del_step_size if config.ins_del_step_size else 224
    ins_sem_evaluator = InsDelSem(inner_model,
                                  'ins',
                                  net_in_size,
                                  substrate_fn=torch.zeros_like,
                                  postprocess=take_logits,
                                  task_type=task_type).to(device)
    if is_image:
        klen = 11
        ksig = 5
        kern = InsDelSem.gkern(klen, ksig, config.num_channels).to(device)
        blur = lambda x: nn.functional.conv2d(x, kern, padding=klen // 2)
        # NOTE(review): this evaluator is built but not applied in the loop
        # below, so the reported insertion-blur score is always 0.
        ins_sem_blur_evaluator = InsDelSem(inner_model,
                                           'ins',
                                           net_in_size,
                                           substrate_fn=blur,
                                           postprocess=take_logits,
                                           task_type=task_type).to(device)
    else:
        ins_sem_blur_evaluator = None
    del_sem_evaluator = InsDelSem(inner_model,
                                  'del',
                                  net_in_size,
                                  substrate_fn=torch.zeros_like,
                                  postprocess=take_logits,
                                  task_type=task_type).to(device)

    comp_sem_evaluator = CompSuffSem(inner_model,
                                     'comp',
                                     k_fraction=0.2,
                                     postprocess=take_logits).to(device)

    suff_sem_evaluator = CompSuffSem(inner_model,
                                     'suff',
                                     k_fraction=0.2,
                                     postprocess=take_logits).to(device)

    expl_name = args.wrapper_model
    # Running sums over all evaluated examples; 'total' counts the examples.
    raw = {
        'ins_sem_scores': 0,
        'ins_sem_blur_scores': 0,
        'del_sem_scores': 0,
        'comp_sem_scores': 0,
        'suff_sem_scores': 0,
        'total': 0,
    }

    print('Eval..')

    with torch.no_grad():
        for batch in tqdm(val_loader):
            if is_image:
                if args.mask_dir is None:
                    inputs, labels = batch
                else:
                    # Masked datasets also yield masks; they are not used here.
                    inputs, labels, _masks, _masks_i = batch
                if is_cosmogrid:
                    inputs = inputs.to(device, dtype=torch.float)
                    labels = labels.to(device, dtype=torch.float)
                else:
                    inputs, labels = inputs.to(device), labels.to(device)
                token_type_ids = None
                attention_mask = None
                kwargs = {}
            else:
                if not isinstance(batch['input_ids'], torch.Tensor):
                    # Non-tensor fields are stacked and transposed before use.
                    inputs = torch.stack(batch['input_ids']).transpose(0, 1).to(device)
                    if 'token_type_ids' in batch:
                        token_type_ids = torch.stack(batch['token_type_ids']).transpose(0, 1).to(device)
                    else:
                        token_type_ids = None
                    attention_mask = torch.stack(batch['attention_mask']).transpose(0, 1).to(device)
                else:
                    inputs = batch['input_ids'].to(device)
                    if 'token_type_ids' in batch:
                        token_type_ids = batch['token_type_ids'].to(device)
                    else:
                        token_type_ids = None
                    attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                kwargs = {'token_type_ids': token_type_ids,
                          'attention_mask': attention_mask}

            # The SOP model always runs: its flat masks are passed to the
            # evaluators regardless of which wrapper is being scored.
            sop_outputs = sop_model(inputs,
                                    token_type_ids=token_type_ids,
                                    attention_mask=attention_mask,
                                    epoch=model.config.num_heads,
                                    mask_batch_size=args.mask_batch_size,
                                    label=torch.abs(labels) - 1 if args.criterion == 'bce' else None)

            if args.wrapper_model == 'sop':
                # NOTE(review): the scored attributions come from sop_model
                # (--sop-exp-dir); `model` (--exp-dir) only supplies
                # config.num_heads in this case — confirm this is intended.
                outputs = sop_outputs
            else:  # 'fresh'
                outputs = model(inputs,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask)

            # Count examples for both modalities; previously only the image
            # branch did, so text runs divided by zero when averaging.
            raw['total'] += labels.size(0)

            if is_image:
                ins_sem_score = ins_sem_evaluator(inputs, outputs.attributions,
                                                  sop_outputs.flat_masks, kwargs)
                del_sem_score = del_sem_evaluator(inputs, outputs.attributions,
                                                  sop_outputs.flat_masks, kwargs)
                if is_cosmogrid:
                    # Average over the last dimension before summing.
                    ins_sem_score = ins_sem_score.mean(-1)
                    del_sem_score = del_sem_score.mean(-1)
                raw['ins_sem_scores'] += ins_sem_score.sum().item()
                raw['del_sem_scores'] += del_sem_score.sum().item()
            else:  # text
                comp_sem_score = comp_sem_evaluator(inputs, outputs.attributions,
                                                    sop_outputs.flat_masks, kwargs)
                suff_sem_score = suff_sem_evaluator(inputs, outputs.attributions,
                                                    sop_outputs.flat_masks, kwargs)
                raw['comp_sem_scores'] += comp_sem_score.sum().item()
                raw['suff_sem_scores'] += suff_sem_score.sum().item()

    total = raw['total']
    ins_sem_avg = _percent_avg(raw['ins_sem_scores'], total)
    ins_sem_blur_avg = _percent_avg(raw['ins_sem_blur_scores'], total)
    del_sem_avg = _percent_avg(raw['del_sem_scores'], total)
    comp_sem_avg = _percent_avg(raw['comp_sem_scores'], total)
    suff_sem_avg = _percent_avg(raw['suff_sem_scores'], total)

    print(expl_name)
    print(f'Explainer {expl_name}')
    print(f'Insertion Sem Score Avg (high) {ins_sem_avg:.4f}%')
    print(f'Insertion Sem Blur Score Avg (high) {ins_sem_blur_avg:.4f}%')
    print(f'Deletion Sem Score Avg (low) {del_sem_avg:.4f}%')
    print(f'Comprehensiveness Sem Score Avg (high) {comp_sem_avg:.4f}%')
    print(f'Sufficiency Sem Score Avg (low) {suff_sem_avg:.4f}%')

    results = {
        expl_name: {'dataset': args.dataset,
                    'insertion_semantic': ins_sem_avg,
                    'insertion_blur_semantic': ins_sem_blur_avg,
                    'deletion_semantic': del_sem_avg,
                    'comprehensiveness_semantic': comp_sem_avg,
                    'sufficiency_semantic': suff_sem_avg},
    }

    # Results are written inside the experiment directory; --output-filename
    # is not consulted here.
    output_filename = os.path.join(args.exp_dir, f'eval_results_semantics_{args.val_size}.json')
    with open(output_filename, 'wt') as output_file:
        json.dump(results,
                  output_file,
                  indent=4)


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()