import json
import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import ViTImageProcessor, AutoModel, PretrainedConfig, AutoModelForImageClassification
from pathlib import Path
import sys
# Make the repository root (two directories above this file's directory)
# importable, so the local `models` and `custom_datasets` imports below resolve
# regardless of the working directory. The ancestor paths are echoed for debugging.
_file_parents = Path(__file__).parents
for _depth in range(3):
    print(_file_parents[_depth])
path_root = _file_parents[2]
print(path_root)
sys.path.append(str(path_root))
import numpy as np
import random
from PIL import Image
import pickle
from torch import nn
from copy import deepcopy
import matplotlib.pyplot as plt

from models.wrapper import SumOfPartsWrapper
from models.attention import AttentionClassifierTwoHeadsFrontExt, AttentionClassifierTwoHeadsDropoutSelect
from models.cnn import CNNModel

from custom_datasets import get_datasets, get_masked_datasets


def parse_args():
    """Build the command-line parser for the wrapper evaluation script.

    Note: despite its name this returns the ``argparse.ArgumentParser`` itself,
    not the parsed namespace; callers invoke ``.parse_args()`` on the result.

    Returns:
        argparse.ArgumentParser: parser with model, data, training and
        wrapper-override options.
    """
    parser = argparse.ArgumentParser()

    # models and configs
    parser.add_argument('--blackbox-model-name', type=str, 
                        default='google/vit-base-patch16-224', 
                        help='black box model name')
    parser.add_argument('--blackbox-processor-name', type=str, 
                        default='google/vit-base-patch16-224', 
                        help='black box processor name')
    parser.add_argument('--wrapper-config-filepath', type=str, 
                        default='actions/wrapper/configs/imagenet_vit_wrapper_default.json', 
                        help='wrapper config file')
    parser.add_argument('--exp-dir', type=str, 
                        default='exps/imgenet_wrapper', 
                        help='exp dir')
    parser.add_argument('--mask-dir', type=str, 
                        default=None, 
                        help='mask dir, if none, generate, else use')
    
    # data
    parser.add_argument('--dataset', type=str, 
                        default='cosmogrid', choices=['cosmogrid'],
                        help='which dataset to use')
    parser.add_argument('--train-size', type=int, 
                        default=-1, 
                        help='-1: use all, else randomly choose k per class')
    parser.add_argument('--val-size', type=int, 
                        default=-1, 
                        help='-1: use all, else randomly choose k per class')
    
    # training
    parser.add_argument('--num-epochs', type=int, 
                        default=1, 
                        help='num epochs')
    parser.add_argument('--num-train-reps', type=int, 
                        default=1, 
                        help='number of times to train each head')
    parser.add_argument('--lr', type=float, 
                        default=1e-4, 
                        help='learning rate')
    parser.add_argument('--lr-scheduler-step-size', type=int, 
                        default=1, 
                        help='learning rate scheduler step size (by epoch)')
    parser.add_argument('--lr-scheduler-gamma', type=float, 
                        default=0.1, 
                        help='learning rate scheduler gamma')
    parser.add_argument('--weight-decay', type=float, 
                        default=0.01, 
                        help='weight decay')
    parser.add_argument('--warmup-epochs', type=float, 
                        default=-1, 
                        help='number of epochs to warmup')
    parser.add_argument('--warmup-steps', type=float, 
                        default=0, 
                        help='number of steps to warmup. Use this when warmup epochs is -1')
    parser.add_argument('--scheduler-type', type=str, 
                        default='inverse_sqrt_heads', choices=['cosine',
                                                               'constant',
                                                               'inverse_sqrt_heads'],
                        help='scheduler type')
    parser.add_argument('--batch-size', type=int, 
                        default=32, 
                        help='batch size')
    parser.add_argument('--train-loss-track-interval', type=int, 
                        default=100, 
                        help='interval to report average of train loss')
    parser.add_argument('--eval-interval', type=int, 
                        default=1000, 
                        help='interval at which to run evaluation')
    parser.add_argument('--mask-batch-size', type=int, 
                        default=16, 
                        help='mask batch size')
    parser.add_argument('--project-name', type=str, 
                        default='attn', 
                        help='wandb project name')
    parser.add_argument('--seed', type=int, 
                        default=42, 
                        help='seed')
    parser.add_argument('--track', action='store_true', 
                        default=False, 
                        help='track')

    # Wrapper-config overrides: when given (not None), these replace the
    # corresponding values loaded from the wrapper JSON config file.
    parser.add_argument('--attn-patch-size', type=int, 
                        default=None, 
                        help='attn patch size, does not have to match black box model patch size')
    parser.add_argument('--attn-stride-size', type=int, 
                        default=None, 
                        help='attn stride size, if smaller than patch size then there is overlap')
    parser.add_argument('--num-heads', type=int, 
                        default=None, 
                        help='hidden dim for the first attention layer')
    parser.add_argument('--num-masks-sample', type=int, 
                        default=None, 
                        help='number of masks to retain for mask dropout.')
    parser.add_argument('--num-masks-max', type=int, 
                        default=None, 
                        help='number of maximum masks to retain.')
    parser.add_argument('--finetune-layers', type=str, 
                        default=None, 
                        help='Which layer to finetune, separated by comma.')
    
    # TODO: add an option for output_attn_hidden_dim
    parser.add_argument('--aggr-type', type=str, 
                        default='independent', choices=['joint', 'independent'],
                        help='usually we use independent for regression, but can also try joint')
    parser.add_argument('--proj-hid-size', type=int, 
                        default=None, 
                        help='If specified, use this instead of hidden size for projection')
    parser.add_argument('--mean-center-offset', 
                        default=0,
                        type=float,
                        help='offset to mean center before projection')
    parser.add_argument('--mean-center-offset2', 
                        default=None,
                        type=float,
                        help='offset to mean center after projection')
    parser.add_argument('--mean-center-scale', 
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering before projection')
    parser.add_argument('--mean-center-scale2', 
                        default=0,
                        type=float,
                        help='scale factor to scale up the features after mean centering after projection')
    
    parser.add_argument('--scatter-only', action='store_true', 
                        default=False, 
                        help='Only make scatter plot, without storing')
    
    return parser

def _first_occurrences(inverse_indices, num_unique):
    """Return the position of the first occurrence of each id in ``inverse_indices``.

    Equivalent to ``[inverse_indices.index(j) for j in range(num_unique)]`` but
    done in one pass instead of one linear scan per id.

    Args:
        inverse_indices: list of ints, e.g. the inverse mapping returned by
            ``torch.unique(..., return_inverse=True)`` converted to a list;
            every id in ``range(num_unique)`` is expected to appear.
        num_unique: number of unique ids.

    Returns:
        list[int] of length ``num_unique``.
    """
    first_seen = {}
    for position, unique_id in enumerate(inverse_indices):
        first_seen.setdefault(unique_id, position)
    return [first_seen[unique_id] for unique_id in range(num_unique)]


def _save_scatter_grid(out_path, gold, preds_wrapped, preds_original,
                       mean_loss_original, mean_loss_wrapped, *,
                       diff_colors=None, show_preds=True, show_lines=True,
                       clamp_axes=False):
    """Save a 1 x num_outputs grid of gold-vs-prediction scatter plots.

    Args:
        out_path: image file to write.
        gold: (N, num_outputs) array of gold labels.
        preds_wrapped: (N, num_outputs) predictions of the wrapped model.
        preds_original: (N, num_outputs) predictions of the original model.
        mean_loss_original: per-output mean MSE (already rounded), shown in titles.
        mean_loss_wrapped: per-output mean MSE (already rounded), shown in titles.
        diff_colors: optional (wrapped_color, original_color); when given, the
            relative error (pred - gold) / gold is scattered in these colors
            (inf/nan wherever gold == 0).
        show_preds: scatter raw predictions (blue = wrapped, orange = original).
        show_lines: draw the y = x line and degree-1 least-squares fit lines.
        clamp_axes: restrict both axes to each output's gold-label range.
    """
    upp_lims = np.nanmax(gold, axis=0)
    low_lims = np.nanmin(gold, axis=0)

    # squeeze=False keeps `axes` 2-D, so indexing also works for a single output.
    fig, axes = plt.subplots(nrows=1, ncols=len(low_lims), figsize=(20, 10),
                             squeeze=False)

    for ind, (low_lim, upp_lim) in enumerate(zip(low_lims, upp_lims)):
        ax = axes[0, ind]
        gold_i = gold[:, ind]
        wrapped_i = preds_wrapped[:, ind]
        original_i = preds_original[:, ind]

        if diff_colors is not None:
            wrapped_color, original_color = diff_colors
            ax.scatter(gold_i, (wrapped_i - gold_i) / gold_i, color=wrapped_color, label='Wrapped')
            ax.scatter(gold_i, (original_i - gold_i) / gold_i, color=original_color, label='Original')
        if show_preds:
            ax.scatter(gold_i, wrapped_i, color="blue", label='Wrapped')
            ax.scatter(gold_i, original_i, color="orange", label='Original')
        if show_lines:
            p_w = np.poly1d(np.polyfit(gold_i, wrapped_i, 1))
            p_o = np.poly1d(np.polyfit(gold_i, original_i, 1))
            ax.plot([low_lim, upp_lim], [low_lim, upp_lim], color="black")
            ax.plot([low_lim, upp_lim], [p_w(low_lim), p_w(upp_lim)], color="black", ls=":", label='Wrapped')
            ax.plot([low_lim, upp_lim], [p_o(low_lim), p_o(upp_lim)], color="black", ls="-.", label='Original')
        if clamp_axes:
            ax.set_xlim([low_lim, upp_lim])
            ax.set_ylim([low_lim, upp_lim])

        ax.set_xlabel('Gold')
        ax.set_ylabel('Pred')
        ax.legend()
        ax.set_aspect('equal', adjustable='box')
        ax.set_title(f'Output {ind}, \n' +
                     f'Original loss {mean_loss_original[ind]},\n' +
                     f'Wrapped loss {mean_loss_wrapped[ind]}')

    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9)
    plt.savefig(out_path)
    plt.close(fig)


def main():
    """Evaluate a trained SumOfPartsWrapper against its original black-box CNN.

    Steps, all driven by the CLI flags from ``parse_args``:
      1. Seed torch / numpy / ``random`` unless ``--seed -1``.
      2. Build the config from a blank ``PretrainedConfig`` updated with the
         JSON file at ``--wrapper-config-filepath`` plus any CLI overrides.
      3. Load the validation split, using pre-computed masks when
         ``--mask-dir`` is set.
      4. Load the CNN black box from the state-dict file given by
         ``--blackbox-model-name``. Keep an untouched copy as the "Original"
         baseline, and load the wrapper from ``--exp-dir``.
      5. Run both models over the validation set, accumulating per-output MSE.
         Unless ``--scatter-only`` is set, write one pickle per validation
         sample to ``<exp-dir>/val_results/<sample index>.pkl``. A batch whose
         sample pickles all already exist is skipped, so an interrupted run can
         resume. Skipped batches are then absent from the plots and losses.
      6. Save ``combined_scatter.png``, ``combined_scatter_diff.png`` and
         ``combined_scatter_diff_only.png`` to ``--exp-dir``, and print the
         per-output mean losses.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg_name, arg_value in vars(args).items():
        # Report the type of the parsed value (not of the attribute name).
        print(arg_name, arg_value, '\t', type(arg_value))

    os.makedirs(args.exp_dir, exist_ok=True)

    # Set the seed for reproducibility
    if args.seed != -1:
        # Torch RNG
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        # Python RNG
        np.random.seed(args.seed)
        random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Start from a blank PretrainedConfig and overlay the wrapper's JSON config.
    config = PretrainedConfig()
    with open(args.wrapper_config_filepath) as input_file:
        wrapper_config = json.load(input_file)
    config.__dict__.update(wrapper_config)
    config.blackbox_model_name = args.blackbox_model_name
    config.blackbox_processor_name = args.blackbox_processor_name

    # CLI values, when given, override the corresponding JSON config entries.
    specifiable_arg_list = ['attn_patch_size', 
                            'attn_stride_size', 
                            'num_heads',
                            'num_masks_sample',
                            'num_masks_max',
                            'finetune_layers']
    for specifiable_arg in specifiable_arg_list:
        arg_value = getattr(args, specifiable_arg)
        if arg_value is None:
            continue
        if specifiable_arg == 'finetune_layers':
            config.__dict__[specifiable_arg] = arg_value.split(',')
        else:
            config.__dict__[specifiable_arg] = arg_value

    if args.mask_dir is None:
        val_dataset = get_datasets(args.dataset, 
                                   transform=lambda x: x,
                                   debug=False,
                                   train_size=args.train_size,
                                   val_size=args.val_size,
                                   val_only=True)
    else:
        val_dataset = get_masked_datasets(args.dataset, 
                                          processor=None,
                                          transform=lambda x: x,
                                          mask_dir=args.mask_dir,
                                          seg_mask_cut_off=args.num_masks_max,
                                          debug=False,
                                          train_size=args.train_size,
                                          val_size=args.val_size,
                                          val_only=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    print('args.exp_dir', args.exp_dir)
    inner_model = CNNModel(config.num_labels)
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
    # both models are moved to `device` below.
    state_dict = torch.load(args.blackbox_model_name, map_location='cpu')
    inner_model.load_state_dict(state_dict=state_dict)
    # Separate copy of the black box for the "Original" baseline, presumably so
    # that anything the wrapper does to `inner_model` cannot affect it.
    original_model = deepcopy(inner_model)

    model = SumOfPartsWrapper.from_pretrained(args.exp_dir,
                                              config=config, 
                                              blackbox_model=inner_model,
                                              model_type='image',
                                              projection_layer=None,
                                              aggr_type=args.aggr_type,
                                              proj_hid_size=args.proj_hid_size,
                                              mean_center_scale=args.mean_center_scale,
                                              mean_center_scale2=args.mean_center_scale2,
                                              mean_center_offset=args.mean_center_offset,
                                              mean_center_offset2=args.mean_center_offset2)

    model = model.to(device)
    original_model = original_model.to(device)
    # Both models need eval mode; torch.no_grad() alone does not disable
    # dropout or switch batch-norm to running statistics.
    model.eval()
    original_model.eval()

    criterion = nn.MSELoss(reduction='none')

    output_dirname = os.path.join(args.exp_dir, 'val_results')
    os.makedirs(output_dirname, exist_ok=True)

    preds_wrapped = []
    preds_original = []
    gold_labels = []
    total_loss_wrapped = 0
    total_loss_original = 0
    total_count = 0
    next_sample_id = 0  # global index of the first sample of the current batch

    progress_bar = tqdm(range(len(val_loader)))
    with torch.no_grad():
        for batch in val_loader:
            # batch[0] is the input tensor in both dataset variants.
            batch_len = len(batch[0])
            sample_ids = range(next_sample_id, next_sample_id + batch_len)
            next_sample_id += batch_len

            # Resume support: skip a batch whose per-sample results were all
            # stored by an earlier run. Only relevant when storing; in
            # --scatter-only mode every batch must be evaluated for the plots.
            if not args.scatter_only and all(
                    os.path.exists(os.path.join(output_dirname, f'{sample_id}.pkl'))
                    for sample_id in sample_ids):
                progress_bar.update(1)
                continue

            if args.mask_dir is None:
                inputs, labels = batch
                masks = None
                masks_i = None
            else:
                inputs, labels, masks, masks_i = batch
                masks = masks.to(device)
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            images = inputs.cpu().numpy()

            # NOTE(review): `epoch` is fed config.num_heads; kept as-is, intent
            # not visible from this file.
            outputs_avg, outputs, attn_weights1, attn_weights2, pooler_outputs = model(
                inputs,
                masks=masks,
                epoch=model.config.num_heads,
                mask_batch_size=args.mask_batch_size,
                return_pooler=True)
            original_outputs = original_model(inputs).logits

            preds_wrapped.append(outputs_avg)
            preds_original.append(original_outputs)
            gold_labels.append(labels)

            # Element-wise MSE (reduction='none'), summed over the batch.
            loss_avg = criterion(outputs_avg, labels)
            total_loss_wrapped += loss_avg.sum(0)
            loss_original = criterion(original_outputs, labels)
            total_loss_original += loss_original.sum(0)
            total_count += labels.size(0)

            if not args.scatter_only:
                # Per-mask loss: compare every mask's output with the label.
                bsz, num_masks, num_labels = outputs.shape
                loss = criterion(outputs.reshape(-1, model.config.num_labels),
                                 labels.unsqueeze(1).expand(bsz,
                                                            num_masks,
                                                            num_labels).reshape(-1, num_labels))
                loss = loss.reshape(outputs.shape)

                for i, sample_id in enumerate(sample_ids):
                    # Keep only masks with nonzero weight for at least one class.
                    used = attn_weights2[i].sum(-1).bool()
                    masks_used = attn_weights1[i][used]
                    mask_weights = attn_weights2[i][used]
                    output_used = outputs[i][used]
                    pooler_used = pooler_outputs[i][used]
                    loss_used = loss[i][used]

                    # Collapse duplicate masks. Each unique mask keeps the data
                    # of its first occurrence, and its weight is scaled by how
                    # many times it occurred.
                    unique_masks_used, reverse_indices, counts = torch.unique(
                        masks_used,
                        dim=0,
                        return_inverse=True,
                        return_counts=True)
                    indices = _first_occurrences(reverse_indices.cpu().numpy().tolist(),
                                                 len(counts))
                    unique_mask_weights = mask_weights[indices] * counts.view(-1, 1)
                    unique_outputs = output_used[indices]
                    unique_preds = loss_used[indices].tolist()
                    unique_pooler = pooler_used[indices]

                    # NOTE: 'pred' and 'preds' hold MSE losses (of the averaged
                    # output and of each unique mask's output), not predictions.
                    entry = {'image': images[i],
                             'outputs_avg': outputs_avg[i],
                             'outputs': unique_outputs,
                             'pooler': unique_pooler,
                             'outputs_original': original_outputs[i],
                             'masks': attn_weights1[i].cpu().numpy(),
                             'masks_all': masks_i[i].cpu().numpy()
                                 if masks_i is not None
                                 else None,
                             'masks_used': unique_masks_used,
                             'mask_weights': unique_mask_weights,
                             'pred': loss_avg[i].tolist(),
                             'preds': unique_preds,
                             'label': labels[i].tolist(),
                             'counts': counts.cpu().numpy(),
                             'num_labels': model.config.num_labels}

                    # One file per sample (global index), so samples in the
                    # same batch do not overwrite each other.
                    output_filename = os.path.join(output_dirname, f'{sample_id}.pkl')
                    with open(output_filename, 'wb') as output_file:
                        pickle.dump(entry, output_file)

            progress_bar.update(1)
            del images

    progress_bar.close()

    if not gold_labels:
        print('No batches were evaluated; nothing to plot.')
        return

    wrapped_arr = torch.cat(preds_wrapped).cpu().numpy()
    original_arr = torch.cat(preds_original).cpu().numpy()
    gold_arr = torch.cat(gold_labels).cpu().numpy()

    # Per-output mean MSE over all evaluated samples, rounded for display.
    mean_loss_original = [round(v, 4) for v in (total_loss_original / total_count).cpu().tolist()]
    mean_loss_wrapped = [round(v, 4) for v in (total_loss_wrapped / total_count).cpu().tolist()]

    plot_data = (gold_arr, wrapped_arr, original_arr, mean_loss_original, mean_loss_wrapped)
    # Predictions vs gold, axes clamped to the gold range.
    _save_scatter_grid(os.path.join(args.exp_dir, 'combined_scatter.png'), *plot_data,
                       clamp_axes=True)
    # Relative errors plus predictions, unclamped.
    _save_scatter_grid(os.path.join(args.exp_dir, 'combined_scatter_diff.png'), *plot_data,
                       diff_colors=('green', 'red'))
    # Relative errors only.
    _save_scatter_grid(os.path.join(args.exp_dir, 'combined_scatter_diff_only.png'), *plot_data,
                       diff_colors=('blue', 'orange'), show_preds=False, show_lines=False)

    print('mean-center-offset', args.mean_center_offset)
    print('mean-center-scale', args.mean_center_scale)
    print('mean-center-offset2', args.mean_center_offset2)
    print('mean-center-scale2', args.mean_center_scale2)
    print('original_loss', mean_loss_original)
    print('wrapped_loss', mean_loss_wrapped)

        
if __name__ == '__main__':
    # Script entry point. main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())