import json
import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import sys
# Debug output: the directory containing this file, its parent, and its
# grandparent (Path.parents is indexed from the nearest ancestor outwards).
print(Path(__file__).parents[0])
print(Path(__file__).parents[1])
print(Path(__file__).parents[2])
# The grandparent of this file's directory is appended to sys.path.
# NOTE(review): presumably this is the repository root that makes the
# `models` / `custom_datasets` imports further down resolvable — confirm.
# This must run before those imports, so keep it above them.
path_root = Path(__file__).parents[2]
print(path_root)
sys.path.append(str(path_root))
import numpy as np
import random
from PIL import Image
import pickle
from torch import nn

from models.wrapper import SumOfPartsWrapper
from models.attention import AttentionClassifierTwoHeadsFrontExt, AttentionClassifierTwoHeadsDropoutSelect
from models.cnn import CNNModel

from custom_datasets import get_datasets, get_masked_datasets


def get_chained_attr(obj, attr_path):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at *obj*.

    Args:
        obj: Object on which the first attribute name is looked up.
        attr_path: Dot-separated attribute names, or ``None``.

    Returns:
        ``None`` when *attr_path* is ``None``; otherwise the value reached by
        applying ``getattr`` once per name, left to right.

    Raises:
        AttributeError: Propagated from ``getattr`` when a name in the path
            does not exist on the object reached so far.
    """
    if attr_path is not None:
        current = obj
        # Walk one level deeper for every component of the dotted path.
        for name in attr_path.split('.'):
            current = getattr(current, name)
        return current
    return None


def parse_args():
    """Build the command-line argument parser for this script.

    Despite its name, this function does not parse anything: it returns the
    configured ``argparse.ArgumentParser`` and the caller is expected to call
    ``.parse_args()`` on it.

    NOTE(review): several options registered here (learning rate, scheduler,
    warmup, tracking, ...) are not read by ``main()`` in this file; they are
    presumably kept so the same flags work across sibling scripts — confirm
    before removing any of them.

    Returns:
        argparse.ArgumentParser: Parser with every option registered.
    """
    parser = argparse.ArgumentParser()

    # models and configs
    parser.add_argument('--blackbox-model-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box model name')
    parser.add_argument('--blackbox-processor-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box processor name')
    parser.add_argument('--wrapper-config-filepath', type=str,
                        default='actions/wrapper/configs/imagenet_vit_wrapper_default.json',
                        help='wrapper config file')
    parser.add_argument('--exp-dir', type=str,
                        default='exps/imgenet_wrapper',
                        help='exp dir')
    parser.add_argument('--model-type', type=str,
                        default='image', choices=['image', 'text'],
                        help='image or text model')
    parser.add_argument('--projection-layer-name', type=str,
                        default=None,  # e.g. distilbert.embeddings
                        help='projection layer if specified, else train from scratch')

    # data
    parser.add_argument('--dataset', type=str,
                        default='imagenet', choices=['imagenet', 'imagenet_m'],
                        help='which dataset to use')
    parser.add_argument('--train-size', type=int,
                        default=-1,
                        help='-1: use all, else randomly choose k per class')
    parser.add_argument('--val-size', type=int,
                        default=-1,
                        help='-1: use all, else randomly choose k per class')

    # training
    parser.add_argument('--num-epochs', type=int,
                        default=1,
                        help='num epochs')
    parser.add_argument('--num-train-reps', type=int,
                        default=1,
                        help='number of times to train each head')
    parser.add_argument('--lr', type=float,
                        default=1e-4,
                        help='learning rate')
    parser.add_argument('--lr-scheduler-step-size', type=int,
                        default=1,
                        help='learning rate scheduler step size (by epoch)')
    parser.add_argument('--lr-scheduler-gamma', type=float,
                        default=0.1,
                        help='learning rate scheduler gamma')
    parser.add_argument('--weight-decay', type=float,
                        default=0.01,
                        help='weight decay')
    parser.add_argument('--warmup-epochs', type=float,
                        default=-1,
                        help='number of epochs to warmup')
    parser.add_argument('--warmup-steps', type=float,
                        default=0,
                        help='number of steps to warmup. Use this when warmup epochs is -1')
    parser.add_argument('--scheduler-type', type=str,
                        default='inverse_sqrt_heads', choices=['cosine',
                                                               'constant',
                                                               'inverse_sqrt_heads'],
                        help='scheduler type')
    parser.add_argument('--batch-size', type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--train-loss-track-interval', type=int,
                        default=100,
                        help='interval to report average of train loss')
    # The help text here used to be a copy of the train-loss one above.
    parser.add_argument('--eval-interval', type=int,
                        default=1000,
                        help='interval at which to run evaluation')
    parser.add_argument('--mask-batch-size', type=int,
                        default=16,
                        help='mask batch size')
    parser.add_argument('--project-name', type=str,
                        default='attn',
                        help='wandb project name')
    parser.add_argument('--seed', type=int,
                        default=42,
                        help='seed')
    parser.add_argument('--track', action='store_true',
                        default=False,
                        help='track')

    # Wrapper-config overrides: when one of these is not None it replaces the
    # value of the same name coming from the wrapper config JSON.
    parser.add_argument('--attn-patch-size', type=int,
                        default=None,
                        help='attn patch size, does not have to match black box model patch size')
    parser.add_argument('--attn-stride-size', type=int,
                        default=None,
                        help='attn stride size, if smaller than patch size then there is overlap')
    parser.add_argument('--num-heads', type=int,
                        default=None,
                        help='hidden dim for the first attention layer')
    parser.add_argument('--num-masks-sample', type=int,
                        default=None,
                        help='number of masks to retain for mask dropout.')
    parser.add_argument('--num-masks-max', type=int,
                        default=None,
                        help='number of maximum masks to retain.')
    parser.add_argument('--finetune-layers', type=str,
                        default=None,
                        help='Which layer to finetune, separated by comma.')

    # to add
    # output_attn_hidden_dim
    parser.add_argument('--aggr-type', type=str,
                        default='joint', choices=['joint', 'independent'],
                        help='usually we use independent for regression, but can also try joint')
    parser.add_argument('--proj-hid-size', type=int,
                        default=None,
                        help='If specified, use this instead of hidden size for projection')
    parser.add_argument('--mean-center-scale',
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering')

    return parser

def main():
    """Run a saved SumOfPartsWrapper over the validation set and dump results.

    Steps, as implemented below:
      1. Parse CLI arguments and seed the torch / numpy / Python RNGs
         (skipped when ``--seed -1``).
      2. Load the black-box processor and config, then overlay the wrapper
         JSON config and any non-None CLI overrides onto the config object.
      3. Load the validation dataset, the black-box model and the wrapper
         from ``args.exp_dir``.
      4. For every validation sample, keep the masks whose ``attn_weights2``
         row has a non-zero sum, de-duplicate them, and pickle one result
         dict to ``<exp_dir>/val_results/<sample index>.pkl``.

    Results are streamed to disk one file per sample; nothing is accumulated
    in memory across batches.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for name, value in vars(args).items():
        # Show the type of the parsed value (the option name is always a str).
        print(name, value, '\t', type(value))

    os.makedirs(args.exp_dir, exist_ok=True)

    # Set the seed for reproducibility; --seed -1 leaves every RNG unseeded.
    if args.seed != -1:
        # Torch RNG
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        # Python RNG
        np.random.seed(args.seed)
        random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): a tokenizer is loaded for --model-type text, but the
    # evaluation loop below always builds PIL images and reads
    # 'pixel_values', so the text path looks unsupported here — confirm.
    if args.model_type == 'image':
        processor = AutoImageProcessor.from_pretrained(args.blackbox_processor_name)
    else:
        processor = AutoTokenizer.from_pretrained(args.blackbox_processor_name)

    # Merge blackbox config and wrapper config: wrapper keys are set as
    # attributes on the black-box config object.
    config = AutoConfig.from_pretrained(args.blackbox_model_name)
    with open(args.wrapper_config_filepath) as input_file:
        wrapper_config = json.load(input_file)
    for k, v in wrapper_config.items():
        setattr(config, k, v)
    config.blackbox_model_name = args.blackbox_model_name
    config.blackbox_processor_name = args.blackbox_processor_name

    # Allow CLI values to override what the JSON file specified. These are
    # written straight into config.__dict__ (unlike the setattr calls above).
    specifiable_arg_list = ['attn_patch_size',
                            'attn_stride_size',
                            'num_heads',
                            'num_masks_sample',
                            'num_masks_max',
                            'finetune_layers']
    for specifiable_arg in specifiable_arg_list:
        arg_value = getattr(args, specifiable_arg)
        if arg_value is None:
            continue
        if specifiable_arg == 'finetune_layers':
            # Comma-separated string on the CLI -> list of layer names.
            arg_value = arg_value.split(',')
        config.__dict__[specifiable_arg] = arg_value

    val_dataset = get_datasets(args.dataset,
                               transform=None,
                               debug=False,
                               train_size=args.train_size,
                               val_size=args.val_size,
                               label2id=config.label2id,
                               val_only=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    print('args.exp_dir', args.exp_dir)

    if args.model_type == 'image':
        inner_model = AutoModelForImageClassification.from_pretrained(args.blackbox_model_name)
    else:
        inner_model = AutoModelForSequenceClassification.from_pretrained(args.blackbox_model_name)

    # None unless --projection-layer-name names a sub-module of the model.
    projection_layer = get_chained_attr(inner_model, args.projection_layer_name)

    model = SumOfPartsWrapper.from_pretrained(args.exp_dir,
                                              config=config,
                                              blackbox_model=inner_model,
                                              model_type=args.model_type,
                                              projection_layer=projection_layer,
                                              aggr_type=args.aggr_type,
                                              pooler=False)
    model = model.to(device)

    output_dirname = os.path.join(args.exp_dir, 'val_results')
    os.makedirs(output_dirname, exist_ok=True)

    model.eval()
    # Running index over individual samples, used for the output file name.
    # Naming files by batch index made every sample in a batch overwrite the
    # same file whenever batch_size > 1.
    sample_idx = 0
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            labels = labels.to(device)

            # The dataset was built with transform=None, so the raw arrays
            # are turned into PIL images and preprocessed here.
            # NOTE(review): assumes each item is an image array that
            # Image.fromarray accepts (e.g. uint8 HxWxC) — confirm.
            inputs_numpy = inputs.cpu().numpy()
            images = [Image.fromarray(arr) for arr in inputs_numpy]
            pixel_values = processor(images, return_tensors='pt')['pixel_values'].to(device)
            if len(pixel_values.shape) == 5:
                pixel_values = pixel_values.squeeze(0)

            # NOTE(review): `epoch` is set to the configured number of heads;
            # presumably this selects the fully-trained state of the wrapper —
            # confirm against SumOfPartsWrapper.forward.
            outputs_avg, outputs, attn_weights1, attn_weights2 = model(
                pixel_values,
                epoch=model.config.num_heads,
                mask_batch_size=args.mask_batch_size)

            _, predicted = torch.max(outputs_avg.data, -1)
            _, predicted_raw = torch.max(outputs.data, -1)

            for i, image in enumerate(images):
                # Rows of attn_weights2 with a non-zero sum over the last
                # axis, i.e. masks that are used for any class.
                used = attn_weights2[i].sum(-1).bool()
                masks_used = attn_weights1[i][used]
                mask_weights = attn_weights2[i][used]
                output_used = outputs[i][used]
                predicted_raw_used = predicted_raw[i][used]

                unique_masks_used, reverse_indices, counts = torch.unique(
                    masks_used,
                    dim=0,
                    return_inverse=True,
                    return_counts=True)

                # For each unique mask, find the first row of `masks_used`
                # that maps to it (single pass instead of list.index per id).
                first_pos = {}
                for pos, unique_id in enumerate(reverse_indices.cpu().tolist()):
                    first_pos.setdefault(unique_id, pos)
                indices = [first_pos[j] for j in range(len(counts))]

                # A mask that occurred k times contributes k times its weight.
                unique_mask_weights = mask_weights[indices] * counts.view(-1, 1)
                unique_outputs = output_used[indices]
                unique_preds = predicted_raw_used[indices].tolist()

                # NOTE(review): the tensor-valued fields are pickled on
                # whatever device they live on; loading them on a CPU-only
                # machine may need extra care — confirm with consumers.
                entry = {'image': image,
                         'outputs_avg': outputs_avg[i],
                         'outputs': unique_outputs,
                         'masks': attn_weights1[i].cpu().numpy(),
                         'masks_all': attn_weights1[i].cpu().numpy(),
                         'masks_used': unique_masks_used,
                         'mask_weights': unique_mask_weights,
                         'pred': predicted[i].tolist(),
                         'preds': unique_preds,
                         'label': labels[i].tolist(),
                         'counts': counts.cpu().numpy(),
                         'id2label': model.config.id2label}

                output_filename = os.path.join(output_dirname, f'{sample_idx}.pkl')
                with open(output_filename, 'wb') as output_file:
                    pickle.dump(entry, output_file)
                sample_idx += 1

        
if __name__ == '__main__':
    main()