from pathlib import Path
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from numpy import *
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
import sys
# Repository root: two directories above this script. It is resolved first so
# that a relative ``__file__`` does not collapse ``.parent.parent`` to ``.``,
# and so it agrees with the resolved path used for ``results_dir`` below.
base_dir = Path(__file__).resolve().parent.parent
# Make project packages (``data``, ``baselines``) importable below.
sys.path.append(str(base_dir))
# NOTE: this runs before argument parsing, so a relative --imagenet-seg-path is
# interpreted relative to the repository root, not the caller's cwd.
os.chdir(base_dir)


from data.Imagenet import Imagenet_Segmentation

from baselines.ViT.ViT_explanation_generator import Baselines, LRP
from baselines.ViT.ViT_new import vit_base_patch16_224
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_LRP import deit_base_patch16_224 as vit_deit
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

import matplotlib.pyplot as plt

import torch.nn.functional as F

plt.switch_backend('agg')

import lovely_tensors as lt
lt.monkey_patch()

# hyperparameters
num_workers = 0
batch_size = 1


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so it cannot be used for an
    option that takes an explicit boolean value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Values accepted by --method. Only the ``predmap*`` entries are handled by
# eval_batch in this script.
METHOD_CHOICES = [
    'rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
    'attn_last_layer', 'attn_gradcam',
    'attn_last_layer2',
    'attn_last_layer3',
    'attn_last_layer4',
    'attn_last_layer5',
    'attn_last_layer6',
    'attn_last_layer7',
    'attn_last_layer8',
    'predmap',
    'predmap2',
    'predmap3',
    'predmap4',
    'predmap5',
    'predmap6',
    'predmap7',
    'predmap8',
    'predmap9',
    'predmap10',
    'predmap11',
    'predmap12_attn_prev',
    'predmap13',
    'predmap14',
    'predmap15',
    'predmap_temperature1_1',
    'predmap_temperature0_9',
    'predmap_temperature1_2',
    'predmap_temperature1_3',
    'predmap_temperature1_35',
    'predmap_temperature1_4',
    'predmap_temperature1_45',
    'predmap_temperature1_5',
    'predmap_temperature1_6',
]

# Args
parser = argparse.ArgumentParser(description='Training multi-class classifier')
parser.add_argument('--arc', type=str, default='vgg', metavar='N',
                    help='Model architecture')
parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
                    help='Testing Dataset')
parser.add_argument('--method', type=str,
                    default='grad_rollout',
                    choices=METHOD_CHOICES,
                    help='')
parser.add_argument('--thr', type=float, default=0.,
                    help='threshold')
parser.add_argument('--K', type=int, default=1,
                    help='new - top K results')
parser.add_argument('--save-img', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-ia', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fgx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-m', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-reg', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--is-ablation', type=_str2bool,
                    default=False,
                    help='')
parser.add_argument('--imagenet-seg-path', type=str, required=True)
parser.add_argument('--use-median', action='store_true', default=False)
args = parser.parse_args()

# argparse checks ``choices`` only for values given on the command line, never
# for the default. The default 'grad_rollout' is not a valid choice, so reject
# it here instead of failing after the dataset and model have been loaded.
if args.method not in METHOD_CHOICES:
    parser.error(
        f"argument --method: invalid choice: {args.method!r} "
        f"(choose from {', '.join(map(repr, METHOD_CHOICES))})"
    )

args.checkname = args.method + '_' + args.arc

alpha = 2

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Define Saver
# Cache directory one level above this script's directory; created if missing.
results_dir = Path(__file__).resolve().parents[1] / 'cache'
results_dir.mkdir(parents=True, exist_ok=True)


# Data
# Maps each channel from [0, 1] (after ToTensor) to [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Nearest-neighbour resizing so target values are copied, never blended
# (targets are presumably segmentation masks -- confirm in Imagenet_Segmentation).
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
])

ds = Imagenet_Segmentation(args.imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
# Use the declared ``num_workers`` hyperparameter instead of a hard-coded value.
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

# LRP
# Pick the constructor by architecture, build it pretrained on the GPU, and
# switch it to inference mode.
# NOTE(review): for --arc deit this uses ``vit_base_patch16_224`` from
# ViT_new, not the imported ``vit_deit`` (DeiT) constructor -- confirm intent.
model_LRP = (vit_base_patch16_224 if args.arc == 'deit' else vit_LRP)(pretrained=True).cuda()
model_LRP.eval()
lrp = LRP(model_LRP)


# Progress-bar wrapper over the DataLoader; consumed by the main loop.
iterator = tqdm(dl)

def eval_batch(image, labels, model, index, method=None):
    """Run one ``predmap*`` explanation method on a single batch.

    Args:
        image: input batch; moved to CUDA before the call (presumably
            ``(1, 3, 224, 224)`` given the 224x224 resize and batch size 1 --
            TODO confirm).
        labels: unused; kept for call-site compatibility.
        model: object exposing ``method`` as an attribute. It is called as
            ``method(image, idx=None, return_extras=True)`` and must return
            ``(result, extras)``.
        index: unused; kept for call-site compatibility.
        method: name of the method to call. Defaults to ``args.method``.

    Returns:
        ``(all_attnmap, xpredmap, cls_self_attend)``: detached CPU tensors read
        from the ``extras`` dict under the keys ``"all_attnmap"``,
        ``"xpredmap"`` and ``"CLS_self_attend"``.

    Raises:
        ValueError: if ``method`` does not start with ``'predmap'``. No other
            methods are supported here.
    """
    if method is None:
        method = args.method
    method = str(method)
    # Fail loudly instead of reaching ``return`` with unbound results.
    if not method.startswith('predmap'):
        raise ValueError(f"eval_batch only supports 'predmap*' methods, got {method!r}")

    model.zero_grad()
    image.requires_grad = False # (1, 3, 224, 224)

    _, extras = getattr(model, method)(image.cuda(), idx=None, return_extras=True)
    all_attnmap = extras["all_attnmap"].detach().cpu() # (blocks, heads, tokens-1)
    xpredmap = extras["xpredmap"].detach().cpu() # (tokens, classes)
    cls_self_attend = extras["CLS_self_attend"].detach().cpu() # (blocks, heads)

    return all_attnmap, xpredmap, cls_self_attend

# Per-batch outputs, stacked along a new leading dimension and cached below.
all_attnmap_lst = []
xpredmap_lst = []
cls_self_attend_lst = []

for batch_idx, (image, labels) in enumerate(iterator):
    # eval_batch ignores ``labels``, so they are not copied to the GPU.
    all_attnmap, xpredmap, cls_self_attend = eval_batch(image, labels, model_LRP, batch_idx)
    all_attnmap_lst.append(all_attnmap)
    xpredmap_lst.append(xpredmap)
    cls_self_attend_lst.append(cls_self_attend)

# torch.stack on an empty list raises an opaque error; explain the cause instead.
if not all_attnmap_lst:
    raise RuntimeError(
        f"No samples were loaded from {args.imagenet_seg_path!r}; nothing to cache."
    )

all_attnmap_pt = torch.stack(all_attnmap_lst, dim=0) # (image, blocks, heads, tokens-1)
xpredmap_pt = torch.stack(xpredmap_lst, dim=0) # (image, tokens, classes)
cls_self_attend_pt = torch.stack(cls_self_attend_lst, dim=0) # (image, blocks, heads)

data_dict = {
    "all_attnmap": all_attnmap_pt,
    "xpredmap": xpredmap_pt,
    "cls_self_attend": cls_self_attend_pt,
}
torch.save(data_dict, results_dir / "cache.pt")