#%%
import sys
import os
from pathlib import Path
# Put the directory two levels above this file (presumably the repo root) on
# sys.path and make it the working directory, so the 'samples'/'baselines'
# imports below and the relative 'samples/...' image paths used later resolve.
# These statements must run before the project imports that follow.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from samples.CLS2IDX import CLS2IDX
import types


# NOTE(review): lovely_tensors.monkey_patch() presumably swaps tensor printing
# for compact summaries (affects how bare tensor expressions display) — confirm.
import lovely_tensors as lt
lt.monkey_patch()



# Two ViT-B/16-224 implementations (LRP-instrumented and an alternative) plus
# the LRP explanation generator used by get_heatmaps.
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_explanation_generator import LRP
# %%

# ImageNet channel statistics, kept for the commented-out alternative below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# mean = std = 0.5 per channel maps ToTensor's [0, 1] range onto [-1, 1].
# NOTE(review): presumably this matches the pretrained weights' preprocessing — confirm.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
# normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

# Resize (shorter side 256), 224 x 224 center crop, to tensor, normalize.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

# Pretrained vit_base_patch16_224 (LRP variant) on the GPU, switched to eval mode.
model = vit_LRP(pretrained=True).cuda().eval()
# model = vit_new(pretrained=True).cuda()
# %%

def print_top_classes(predictions, k=5, labels=None):
    """Print the ``k`` highest-scoring classes for the first sample of a batch.

    Args:
        predictions: Logits indexed as ``predictions[0, class_idx]``, i.e.
            shape (batch, num_classes); only row 0 is reported.
        k: How many classes to print (clamped to ``num_classes``).
        labels: Mapping from class index to display name. Defaults to the
            module-level ``CLS2IDX`` table.

    Each output line holds the class index, its name padded to the longest
    printed name, the raw logit and the softmax probability in percent.
    """
    if labels is None:
        labels = CLS2IDX
    # Detached view: printing must never extend the autograd graph.
    logits = predictions.detach()
    prob = torch.softmax(logits, dim=1)
    k = min(k, logits.shape[1])
    class_indices = logits.topk(k, dim=1).indices[0].tolist()
    # Pad names to a common width so the value/prob columns line up.
    max_str_len = max((len(labels[i]) for i in class_indices), default=0)

    print(f'Top {k} classes:')
    for cls_idx in class_indices:
        name = labels[cls_idx]
        output_string = '\t{} : {}'.format(cls_idx, name)
        output_string += ' ' * (max_str_len - len(name)) + '\t\t'
        output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(
            logits[0, cls_idx].item(), (100 * prob[0, cls_idx]).item())
        print(output_string)
# %%

# Load the three sample images. convert('RGB') guards against RGBA/palette
# PNGs, whose extra channel would not match the 3-channel Normalize stats.
image = Image.open('samples/catdog.png').convert('RGB')
dog_cat_image = transform(image)  # (3, 224, 224)

tusker_zebra_image_raw = Image.open('samples/el2.png').convert('RGB')
tusker_zebra_image = transform(tusker_zebra_image_raw)

dog_bird_image_raw = Image.open('samples/dogbird.png').convert('RGB')
dog_bird_image = transform(dog_bird_image_raw)

# Show the three raw input images side by side.
fig, axs = plt.subplots(1, 3)
for ax, raw in zip(axs, (image, tusker_zebra_image_raw, dog_bird_image_raw)):
    ax.imshow(raw)
    ax.axis('off')

# Sanity check: top-5 predictions for the cat/dog image.
output = model(dog_cat_image.unsqueeze(0).cuda())
print_top_classes(output)
# %%
def predmap_scores(model, x):
    """Apply the classifier head to every token, not only the CLS token.

    Runs patch embedding, prepends the CLS token, adds the positional
    embedding, then applies the transformer blocks, the final norm and the
    head to the full token sequence.
    NOTE(review): presumably mirrors the model's own forward pass minus the
    CLS pooling step — confirm against the model implementation.

    Args:
        model: ViT exposing ``patch_embed``, ``cls_token``, ``add``,
            ``pos_embed``, ``save_inp_grad``, ``blocks``, ``norm``, ``head``.
        x: Image batch; ``x.shape[0]`` is the batch size (assumed (B, C, H, W)).

    Returns:
        Per-token head outputs of shape (B, tokens, classes). Token 0 is the
        CLS token; the remaining tokens are patch tokens.
    """
    batch = x.shape[0]
    tokens = model.patch_embed(x)

    cls_tokens = model.cls_token.expand(batch, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    tokens = torch.cat((cls_tokens, tokens), dim=1)  # CLS token first
    tokens = model.add([tokens, model.pos_embed])

    # model.save_inp_grad receives the gradient w.r.t. this tensor on backward.
    # Registering a hook on a tensor that does not require grad raises, so skip
    # it when gradients are disabled (e.g. under torch.no_grad()).
    if tokens.requires_grad:
        tokens.register_hook(model.save_inp_grad)

    for blk in model.blocks:
        tokens = blk(tokens)

    tokens = model.norm(tokens)  # (B, tokens, embed_dim)
    return model.head(tokens)  # (B, tokens, classes)

import torchvision

# Per-token class logits for the cat/dog image, moved off the GPU: (1, tokens, classes).
predmap = predmap_scores(model, dog_cat_image.unsqueeze(0).cuda()).detach().cpu()
# Drop the CLS token (index 0) and lay the patch tokens out on a 14 x 14 grid.
predmap = predmap[0, 1:].reshape(14, 14, -1)  # (14, 14, classes)
# Per-patch class probabilities, with the class axis moved to the front.
predmap = predmap.softmax(dim=-1).permute(2, 0, 1)  # (classes, 14, 14)

# Preview the maps of class indices 0..63, 8 per row, each scaled on its own.
imgs = torchvision.utils.make_grid(
    predmap[:64].unsqueeze(1), nrow=8, normalize=True, scale_each=True
)
plt.imshow(imgs.permute(1, 2, 0).numpy())


# %%
# Rank all classes by image-level probability for the cat/dog image and
# reorder the per-patch probability maps (predmap) by that ranking.
predictions = model(dog_cat_image.unsqueeze(0).cuda())
prob = torch.softmax(predictions, dim=1)
class_indices_sorted = prob.sort(descending=True).indices # (B, classes)
# ImageNet val split, used here for its class names; expand "~" explicitly.
IN = torchvision.datasets.ImageNet(Path("~/imagenet").expanduser(), split='val')
# predmap is on the CPU while the ranking comes from the GPU model; indexing a
# CPU tensor with CUDA indices raises, so move the indices to predmap's device.
predmap_sorted = predmap[class_indices_sorted[0].to(predmap.device)] # (classes, 14, 14)
predmap_sorted
# IN.classes[class_indices_sorted[0,1]]

imgs = torchvision.utils.make_grid(predmap_sorted[0:64].unsqueeze(1), nrow=8, normalize=True, scale_each=True)
plt.imshow(imgs.permute(1, 2, 0).numpy())


# %%
# ImageNet names of the 64 most probable classes, one per line.
print(*(IN.classes[idx] for idx in class_indices_sorted[0][:64]), sep="\n")
# %%

# Top-5 predictions for the cat/dog image. Reuses print_top_classes instead of
# duplicating its formatting logic inline.
print_top_classes(predictions)

# %%
def get_heatmaps(model, method, img, classes):
    """Compute one 14 x 14 heatmap per requested class for a single image.

    Args:
        model: The ViT model (LRP-capable for ``"transformer_attribution"``).
        method: ``"predmap"`` uses the per-patch head outputs from
            ``predmap_scores``; ``"transformer_attribution"`` uses LRP.
        img: Preprocessed image, (3, 224, 224).
        classes: Iterable / 1-D index tensor of class indices.

    Returns:
        Detached CPU tensor of shape (len(classes), 14, 14) for both methods.

    Raises:
        NotImplementedError: If ``method`` is not recognised.
    """
    batch = img.unsqueeze(0).cuda()
    if method == "predmap":
        ps = predmap_scores(model, batch)  # (1, tokens, classes)
        # Optional normalisations (ps is (1, tokens, classes)):
        # ps = ps.softmax(dim=-1)  # softmax over classes
        # ps = ps.softmax(dim=-2)  # softmax over tokens
        ps = ps[0, 1:, classes]  # drop CLS token -> (tokens-1, len(classes))
        ps = ps.permute(1, 0).reshape(-1, 14, 14)  # (len(classes), 14, 14)
        # Detach and move to CPU so both methods return the same kind of tensor.
        return ps.detach().cpu()
    if method == "transformer_attribution":
        lrp = LRP(model)
        # TODO: Should start_layer be 1 or 0?
        maps = [
            lrp.generate_LRP(batch, start_layer=0, index=int(idx), method="transformer_attribution")
            .reshape(-1, 14, 14).cpu().detach()
            for idx in classes
        ]
        return torch.cat(maps, dim=0)  # (len(classes), 14, 14)
    raise NotImplementedError(f"unknown heatmap method: {method!r}")

def interpolate_heatmaps(heatmaps):
    """Placeholder for upsampling heatmaps to image resolution.

    Not implemented yet: the argument is ignored and ``None`` is returned.
    """
    return None

# Choose the input image and the heatmap method (swap the commented lines).
# img = dog_cat_image
img = tusker_zebra_image
method = "predmap"
# method = "transformer_attribution"

predictions = model(img.unsqueeze(0).cuda())  # (1, classes) logits
prob = predictions.softmax(dim=1)  # (1, classes)
# Every class index ordered from most to least probable; heatmaps for the top 64.
class_indices_sorted_full = torch.argsort(prob, descending=True)  # (1, classes)
class_indices_sorted = class_indices_sorted_full[0, :64]  # (classes_subset)
heatmaps = get_heatmaps(model, method, img, class_indices_sorted)  # (classes_subset, 14, 14)
# Add a channel axis and upsample each map 16x (14 -> 224) bilinearly.
heatmaps = torch.nn.functional.interpolate(
    heatmaps.unsqueeze(1), scale_factor=16, mode='bilinear'
)  # (classes_subset, 1, 224, 224)

# %%

plt.close("all")
# Tile the heatmaps 8 per row, each scaled to its own range; index [0] keeps a
# single channel of the grid so the jet colormap applies.
plt.imshow(
    torchvision.utils.make_grid(
        heatmaps.detach(), nrow=8, normalize=True, scale_each=True
    )[0].cpu().numpy(),
    cmap="jet",
)

# heatmaps_interpolate = interpolate_heatmaps(heatmaps) # (classes_subset, 3, 224, 224)
# Legend for the tiles: rank, class index, ImageNet names, probability.
for rank, cls in enumerate(class_indices_sorted):
    print(f"{rank:02} {cls} {IN.classes[cls]} {prob[0, cls]*100:.3f}%")
# %%
# Plotting with matplotlib is much slower than make_grid()
# plt.close("all")
# fig, axes = plt.subplots(8,8, figsize=(64,64))
# for i, ax in enumerate(axes.flatten()):
#     ax.imshow(heatmaps[i].squeeze().cpu().detach().numpy(), cmap="jet")
#     ax.set_title(f"{class_indices_sorted[i]} {IN.classes[class_indices_sorted[i]]} {prob[0, class_indices_sorted[i]]*100:.3f}%")
#     ax.axis('off')
# %%
