#%%
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from pathlib import Path
import sys
import os
# Repository root = the parent of the directory containing this file.
# Appending it to sys.path makes the `samples` and `baselines` packages imported
# below resolvable, and chdir-ing into it lets the relative paths used later
# (e.g. 'samples/catdog.png') resolve.
# NOTE(review): relies on __file__, which may be undefined in a plain Jupyter kernel.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)
from samples.CLS2IDX import CLS2IDX
import types


import lovely_tensors as lt
# Patch torch tensor printing with lovely-tensors' compact summaries; this affects
# how tensors look when printed/displayed later in the script.
lt.monkey_patch()


# %% 
# Auxiliary Functions

from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_explanation_generator import LRP

# Normalisation constants. The active Normalize uses mean = std = 0.5 per channel,
# which maps ToTensor's [0, 1] output to [-1, 1]. The ImageNet statistics are kept
# for the commented-out alternative below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
normalize = transforms.Normalize(std=[0.5, 0.5, 0.5], mean=[0.5, 0.5, 0.5])
# normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

# PIL image -> normalised tensor: resize the shorter side to 256, center-crop to 224x224.
transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
)

def show_cam_on_image(img, mask):
    """Blend a JET colour-mapped ``mask`` onto ``img`` and rescale the result.

    Args:
        img: float image array (H, W, 3); presumably already scaled to [0, 1].
        mask: float array (H, W) in [0, 1]; scaled to 0-255 and cast to uint8
            before colour mapping.

    Returns:
        float32 array of shape (H, W, 3), divided by its own maximum.

    Note:
        ``cv2.applyColorMap`` returns its colours in BGR channel order, so
        ``img`` should be given in BGR order too for the channels to line up.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    return overlay / np.max(overlay)

# Pretrained ViT-B/16 (224px) from baselines.ViT.ViT_LRP, moved to the GPU and put
# in eval mode; attribution_generator wraps it to produce relevance maps.
# Alternative: model = vit_new(pretrained=True).cuda()
model = vit_LRP(pretrained=True).cuda().eval()
attribution_generator = LRP(model)

def generate_visualization(original_image, class_index=None, method=None):
    """Overlay the relevance heatmap for ``class_index`` on ``original_image``.

    Args:
        original_image: preprocessed CHW image tensor (output of ``transform``);
            a batch dimension is added and it is moved to the GPU for the LRP pass.
        class_index: passed as ``index`` to ``generate_LRP`` (None lets it choose).
        method: relevance method name for ``generate_LRP``; any falsy value
            means "transformer_attribution".

    Returns:
        uint8 RGB array of shape (224, 224, 3), ready for ``plt.imshow``.
    """
    method = method if method else "transformer_attribution"
    attribution = attribution_generator.generate_LRP(
        original_image.unsqueeze(0).cuda(), method=method, index=class_index
    ).detach()  # (1, 196): one relevance score per 16x16 patch
    attribution = attribution.reshape(1, 1, 14, 14)
    attribution = torch.nn.functional.interpolate(attribution, scale_factor=16, mode='bilinear')
    mask = attribution.reshape(224, 224).cpu().numpy()
    # Min-max normalise; a constant map would otherwise divide by zero (NaN).
    lo, hi = mask.min(), mask.max()
    mask = (mask - lo) / (hi - lo) if hi > lo else np.zeros_like(mask)

    image = original_image.permute(1, 2, 0).detach().cpu().numpy()
    lo, hi = image.min(), image.max()
    image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)

    # cv2.applyColorMap (inside show_cam_on_image) yields a BGR heatmap, so the
    # RGB image is flipped to BGR before blending. Otherwise the photo's red and
    # blue channels come out swapped after the final conversion.
    vis = show_cam_on_image(image[:, :, ::-1], mask)
    vis = np.uint8(255 * vis)
    # BGR blend -> RGB for matplotlib.
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

def print_top_classes(predictions, k=5, **kwargs):
    """Print the ``k`` highest-scoring classes for row 0 of ``predictions``.

    Args:
        predictions: logits tensor indexed as ``predictions[0, class_idx]``;
            softmax is taken over dim 1 and only the first row is reported.
        k: number of classes to list (default 5, clamped to the class count).
        **kwargs: accepted and ignored, kept for call-site compatibility.
    """
    prob = torch.softmax(predictions, dim=1)
    k = min(k, predictions.shape[1])
    class_indices = predictions.data.topk(k, dim=1)[1][0].tolist()
    # Pad every name to the longest one so the value/prob columns line up.
    max_str_len = max((len(CLS2IDX[idx]) for idx in class_indices), default=0)

    print(f'Top {k} classes:')
    for cls_idx in class_indices:
        name = CLS2IDX[cls_idx]
        padding = ' ' * (max_str_len - len(name))
        print(
            f'\t{cls_idx} : {name}{padding}\t\t'
            f'value = {predictions[0, cls_idx]:.3f}\t prob = {100 * prob[0, cls_idx]:.1f}%'
        )


    
# %% Examples

# %%
# Cat-Dog

# Pre-load the other two samples; the token-norm cell further down reuses
# tusker_zebra_image before its own section redefines it.
tusker_zebra_image = transform(Image.open('samples/el2.png'))
dog_bird_image = transform(Image.open('samples/dogbird.png'))
# Load the cat-dog picture last so that `image` (shown in panel 0) is the one
# the heatmaps below are computed for.
image = Image.open('samples/catdog.png')
dog_cat_image = transform(image)

fig, axs = plt.subplots(1, 3)
axs[0].imshow(image);
axs[0].axis('off');

output = model(dog_cat_image.unsqueeze(0).cuda())
print_top_classes(output)

# cat (282) - the predicted class. Passing method=None instead falls back to
# generate_visualization's default "transformer_attribution".
cat = generate_visualization(dog_cat_image, class_index=282, method="predmap15")

# dog: class 243 'bull mastiff'
dog = generate_visualization(dog_cat_image, class_index=243, method="predmap15")

axs[1].imshow(cat);
axs[1].axis('off');
axs[2].imshow(dog);
axs[2].axis('off');

# %%

def backbone(self, x):
    """Run the ViT trunk on ``x`` and return every token's last-block output.

    Steps: patch embedding, prepend one copy of ``self.cls_token`` per sample,
    add ``self.pos_embed`` through ``self.add``, then apply each module of
    ``self.blocks`` in order. The final ``self.norm`` is intentionally NOT
    applied, so the raw activations can be inspected.
    (CLS-token handling adapted from Phil Wang's implementation.)

    Args:
        x: input batch; ``x.shape[0]`` is taken as the batch size.

    Returns:
        Token tensor (batch, tokens, embed_dim) with the CLS token at index 0.
    """
    batch_size = x.shape[0]
    patches = self.patch_embed(x)
    cls = self.cls_token.expand(batch_size, -1, -1)
    tokens = self.add([torch.cat((cls, patches), dim=1), self.pos_embed])
    for block in self.blocks:
        tokens = block(tokens)
    return tokens
# Based on: "Override a method at instance level"
# https://stackoverflow.com/a/42154067
# Bind the norm-free `backbone` to this model instance only.
model.backbone = types.MethodType(backbone, model)

# Log L2 norm of each last-block token for the tusker/zebra image, shape (1, tokens).
x = (
    model.backbone(tusker_zebra_image.unsqueeze(0).cuda())
    .detach()
    .norm(dim=-1, p=2)
    .log()
)
display(x[:, 0])  # CLS token
x = x[:, 1:]  # patch tokens only, (1, tokens - 1)
cax = plt.matshow(x.reshape(14, 14).cpu())
plt.colorbar(cax)

# %%

########################################
# %%
# Tusker-Zebra

image = Image.open('samples/el2.png')
tusker_zebra_image = transform(image)

# Panel 0: the input picture.
fig, axs = plt.subplots(1, 3)
axs[0].imshow(image)
axs[0].axis('off')

output = model(tusker_zebra_image.unsqueeze(0).cuda())
print_top_classes(output)

# "predmap9" relevance for class 101 (tusker, the predicted class) and class
# 340 (zebra). method=None would fall back to "transformer_attribution".
tusker, zebra = (
    generate_visualization(tusker_zebra_image, class_index=idx, method="predmap9")
    for idx in (101, 340)
)

# Panels 1 and 2: the heatmap overlays.
for ax, overlay in zip(axs[1:], (tusker, zebra)):
    ax.imshow(overlay)
    ax.axis('off')


# %% 
# Dog Bird
image = Image.open('samples/dogbird.png')
dog_bird_image = transform(image)

# Panel 0: the input picture.
fig, axs = plt.subplots(1, 3)
axs[0].imshow(image)
axs[0].axis('off')

output = model(dog_bird_image.unsqueeze(0).cuda())
print_top_classes(output)

# "predmap15" relevance for class 161 (basset, the predicted class) and class 87
# ('African grey, African gray, Psittacus erithacus (grey parrot)').
# method=None would fall back to "transformer_attribution".
basset, parrot = (
    generate_visualization(dog_bird_image, class_index=idx, method="predmap15")
    for idx in (161, 87)
)

# Panels 1 and 2: the heatmap overlays.
for ax, overlay in zip(axs[1:], (basset, parrot)):
    ax.imshow(overlay)
    ax.axis('off')


# %%
def get_mask(original_image, class_index=None, predmap=False, method=None):
    """Return the min-max normalised 224x224 relevance mask for ``original_image``.

    Args:
        original_image: image tensor as produced by ``transform``; a batch
            dimension is added and it is moved to the GPU for the LRP pass.
        class_index: passed as ``index`` to ``generate_LRP`` (None lets it choose).
        predmap: select "predmap9" when True, "transformer_attribution" otherwise.
        method: explicit method name; when given it overrides ``predmap``
            (e.g. "predmap" or "predmap15").

    Returns:
        numpy array of shape (224, 224) with values in [0, 1]; all zeros if the
        attribution map is constant.
    """
    if method is None:
        method = "predmap9" if predmap else "transformer_attribution"
    attribution = attribution_generator.generate_LRP(
        original_image.unsqueeze(0).cuda(), method=method, index=class_index
    ).detach()  # (1, 196): one relevance score per 16x16 patch
    attribution = attribution.reshape(1, 1, 14, 14)
    attribution = torch.nn.functional.interpolate(attribution, scale_factor=16, mode='bilinear')
    mask = attribution.reshape(224, 224).cpu().numpy()
    lo, hi = mask.min(), mask.max()
    if hi > lo:
        return (mask - lo) / (hi - lo)
    # Constant map: avoid 0/0 -> NaN.
    return np.zeros_like(mask)


# Alternative predmap implementations (not defined in this file) can be bound with:
# model.predmap = types.MethodType(predmap_clean, model)
# model.predmap = types.MethodType(predmap_softmax, model)
# model.predmap = types.MethodType(predmap_weight_softmax, model)

predmap = True  # False -> plain "transformer_attribution" masks
fig, axes = plt.subplots(1, 2)
# Left: class 161 (basset); right: class 87 (grey parrot). Other pairs to try:
# dog_cat_image with 282 / 243, tusker_zebra_image with 101 / 340.
for ax, target in zip(axes, (161, 87)):
    cax = ax.imshow(
        get_mask(dog_bird_image, class_index=target, predmap=predmap),
        cmap='jet',
    )
# get_mask min-max normalises both masks, so one shared colorbar covers both panels.
fig.colorbar(cax, ax=axes.ravel().tolist(), orientation="horizontal")
# plt.close()

# %%
