#%%
from pathlib import Path
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
# Pretty-printing helpers for numpy arrays and torch tensors; both libraries
# are given the same `deeper_width` setting.
# NOTE(review): exact effect of `deeper_width` not visible here -- see the
# lovely_numpy / lovely_tensors docs.
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
# NOTE(review): presumably patches torch.Tensor's repr globally -- confirm.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Assign `torch.inf` explicitly.
# NOTE(review): presumably a shim for torch builds lacking this attribute -- confirm.
torch.inf = float("Inf")
import sys
import os
# Use the grandparent directory of this file as the project root: append it to
# sys.path (presumably so the `baselines` package imported further down
# resolves) and make it the working directory (the sample images are opened
# via relative `samples/...` paths).
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

import torchvision
import tqdm
# %%

from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
# from baselines.ViT.ViT_LRP import deit_base_patch16_224 as vit_LRP
import types
from PIL import Image

# ImageNet channel statistics; only referenced by the alternative
# normalisation noted below, not by the active pipeline.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Active normalisation: mean 0.5 and std 0.5 on each of the three channels.
# Alternative: transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing: resize to 256, centre-crop to 224, convert to a tensor, normalise.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

# Build the pretrained ViT, move it to the GPU and put it in eval mode.
model = vit_LRP(pretrained=True)
model = model.cuda()
model.eval()

def predmap_all_layers(self, x):
    """Classify every patch token after every transformer block.

    The output of each block is pushed through the model's final ``norm`` and
    ``head``, giving per-class scores for each patch token at each depth.

    Args:
        x: Input batch accepted by ``self.patch_embed``; dim 0 is the batch.

    Returns:
        Tensor of shape (B, layers, tokens-1, classes). The CLS token
        (position 0) is removed before the head is applied.
    """
    batch_size = x.shape[0]
    depth = len(self.blocks)

    patches = self.patch_embed(x)  # (B, tokens-1, embed_dim)
    cls = self.cls_token.expand(batch_size, -1, -1)
    hidden = torch.cat((cls, patches), dim=1)  # (B, tokens, embed_dim)
    hidden = self.add([hidden, self.pos_embed])  # (B, tokens, embed_dim)

    # Keep the hidden state produced by each block.
    per_layer = []
    for block in self.blocks:
        hidden = block(hidden)
        per_layer.append(hidden)

    # Fold the layer axis into the batch axis so norm/head run in one call.
    stacked = torch.stack(per_layer, dim=1).flatten(0, 1)  # (B*layers, tokens, embed_dim)
    normed = self.norm(stacked)
    # Strip the CLS token, then classify the remaining patch tokens.
    logits = self.head(normed[..., 1:, :])  # (B*layers, tokens-1, classes)

    return logits.unflatten(0, (batch_size, depth))  # (B, layers, tokens-1, classes)

def get_all_attnmaps(self, x):
    """Return the CLS-row attention over patch tokens for every block and head.

    A forward pass is run through all blocks first; the attention tensors are
    then read back from each block via ``blk.attn.get_attn()``.
    NOTE(review): presumably ``get_attn`` returns the attention stored during
    the most recent forward pass -- confirm in the model implementation.

    Args:
        x: Input batch accepted by ``self.patch_embed``; dim 0 is the batch.

    Returns:
        Tensor of shape (B, blocks, heads, tokens-1): attention clamped to be
        non-negative, taken at query row 0 over key columns 1 onward.
    """
    batch = x.shape[0]
    embedded = self.patch_embed(x)
    cls = self.cls_token.expand(batch, -1, -1)
    state = self.add([torch.cat((cls, embedded), dim=1), self.pos_embed])

    # The block outputs are not used afterwards; only the attention read back
    # from each block below is returned.
    for layer in self.blocks:
        state = layer(state)

    per_block = [layer.attn.get_attn() for layer in self.blocks]  # each (B, heads, tokens, tokens)
    attn = torch.stack(per_block, dim=1).clamp(min=0)  # (B, blocks, heads, tokens, tokens)
    return attn[..., 0, 1:]  # (B, blocks, heads, tokens-1)

# Attach both helper functions to this model instance as bound methods,
# each under its own function name.
for _helper in (predmap_all_layers, get_all_attnmaps):
    setattr(model, _helper.__name__, types.MethodType(_helper, model))
del _helper

# %%
# Open each sample image and run it through the preprocessing pipeline.
# `image` is left bound to the last opened PIL image.
_sample_tensors = []
for _path in ('samples/catdog.png', 'samples/el2.png', 'samples/dogbird.png'):
    image = Image.open(_path)
    _sample_tensors.append(transform(image))
dog_cat_image, tusker_zebra_image, dog_bird_image = _sample_tensors
del _path, _sample_tensors

# %%

# (preprocessed image, class indices to visualise) pairs.
data = [
    # 282 = cat, 243 = dog (895 = warplane, currently disabled)
    (dog_cat_image, [282, 243]),
    # 101 = tusker, 340 = zebra
    (tusker_zebra_image, [101, 340]),
    # 161 = basset, 87 = parrot
    (dog_bird_image, [161, 87]),
]
# %%
# Display the per-layer prediction maps: one figure per (image, class) pair,
# one subplot per transformer layer, all sharing a single colour scale.

for img, labels in data:
    img = img.unsqueeze(0).cuda()
    labels = torch.tensor(labels).cuda()
    predmap = model.predmap_all_layers(img) # (1, layers, tokens-1, classes)
    for clsidx in labels:
        predmap_cls = predmap[..., clsidx] # (1, layers, tokens-1)

        # Optional variants to try here:
        #   deltas between consecutive layers:
        #     predmap_cls = predmap_cls[:, 1:, :] - predmap_cls[:, :-1, :] # (1, layers-1, tokens-1)
        #   softmax over tokens:
        #     predmap_cls = predmap_cls.softmax(dim=-1)

        # Lay the patch tokens back out on the 14x14 patch grid.
        predmap_cls = predmap_cls.unflatten(-1, (14,14)) # (1, layers, 14, 14)
        predmap_cls = predmap_cls[0] # (layers, 14, 14)
        # Upsample by 16 (14 * 16 = 224) for display.
        predmap_cls = torch.nn.functional.interpolate(
            predmap_cls.unsqueeze(1), scale_factor=16, mode="bilinear", align_corners=False
        ).squeeze(1) # (layers, 224, 224)

        # Grid layout: up to 3 rows, with enough columns to hold every layer.
        # Ceil division keeps all layers when the depth is not a multiple of
        # 3; for 12 layers this is the same 3x4 grid as before.
        n_layers = predmap_cls.shape[0]
        rows = min(3, n_layers)
        cols = -(-n_layers // rows)
        fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*2, rows*2), squeeze=False)

        # Colour limits and the host copy are the same for every subplot, so
        # compute them once as plain Python floats / one numpy array.
        vmin = predmap_cls.min().item()
        vmax = predmap_cls.max().item()
        predmap_np = predmap_cls.detach().cpu().numpy()

        for i, ax in enumerate(axes.flat):
            if i >= n_layers:
                # Spare cell in the grid: hide it instead of indexing past the last layer.
                ax.axis("off")
                continue
            ax.imshow(predmap_np[i], cmap='jet', vmin=vmin, vmax=vmax)
            ax.set_title(f"Layer {i}")
        fig.tight_layout()
# %%
# Display, for every (image, class) pair and every layer (one row per layer):
#   column 0    -- the layer's prediction map for the class,
#   column 1    -- the dot product of that map with every (block, head) CLS attention map,
#   columns 2.. -- the k attention maps with the largest dot product.
for img, labels in data:
    img = img.unsqueeze(0).cuda()
    labels = torch.tensor(labels).cuda()
    predmap = model.predmap_all_layers(img) # (1, layers, tokens-1, classes)

    # The attention maps depend only on the image, so fetch them once per
    # image instead of once per class.
    all_attnmap = model.get_all_attnmaps(img) # (1, layers, heads, tokens-1)
    n_heads = all_attnmap.shape[2]
    flat_attnmap = all_attnmap.flatten(0, 2) # (layers*heads, tokens-1)
    # Number of top attention maps shown per row; capped so topk cannot ask
    # for more entries than exist.
    k = min(4, flat_attnmap.shape[0])

    for clsidx in labels:
        predmap_cls_flat = predmap[..., clsidx] # (1, layers, tokens-1)

        # Optional variants to try here:
        #   deltas between consecutive layers:
        #     predmap_cls_flat = predmap_cls_flat[:, 1:, :] - predmap_cls_flat[:, :-1, :]
        #   softmax over tokens:
        #     predmap_cls_flat = predmap_cls_flat.softmax(dim=-1)

        # Lay the patch tokens back out on the 14x14 patch grid.
        predmap_cls = predmap_cls_flat.unflatten(-1, (14,14))[0] # (layers, 14, 14)

        n_layers = predmap_cls.shape[0]
        rows = n_layers
        cols = k+2
        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*2, rows*2), squeeze=False)

        # Shared colour scale and host copy, computed once per figure.
        vmin = predmap_cls.min().item()
        vmax = predmap_cls.max().item()
        predmap_np = predmap_cls.detach().cpu().numpy()

        for i, layer_map in enumerate(predmap_np):
            # predmap
            ax = axes[i, 0]
            ax.imshow(layer_map, cmap='jet', vmin=vmin, vmax=vmax)
            ax.set_title(f"predmap l{i}")

            # corrmap: dot product between this layer's predmap and every attention map
            ax = axes[i, 1]
            predmap_layer = predmap_cls_flat[:,[i]].unsqueeze(-2) # (1, 1, 1, tokens-1)
            corrmap = (predmap_layer * all_attnmap).sum(dim=-1) # (1, layers, heads)
            # Unnormalised; a cosine-style variant would divide by
            # predmap_layer.norm(dim=-1) * all_attnmap.norm(dim=-1).
            ax.matshow(corrmap[0].detach().cpu().numpy())

            # Take the top-k values of the dot products once, then plot the
            # attention maps of those (block, head) pairs.
            top = corrmap.flatten().topk(k)
            for j, (flat_idx, corr) in enumerate(zip(top.indices.tolist(), top.values.tolist())):
                # Flat index -> (block, head); valid because the batch size is 1.
                block_idx, head_idx = divmod(flat_idx, n_heads)
                attnmap = flat_attnmap[flat_idx].reshape(14, 14)
                ax = axes[i, j+2]
                ax.matshow(attnmap.detach().cpu().numpy(), cmap='jet')
                ax.set_title(f"b {block_idx}, h {head_idx}, corr {corr:.2f}")
        fig.tight_layout()

# %%


# Threshold one layer's prediction map for a chosen class at its per-layer
# mean, show the resulting mask, then plot the selected tokens' scores across
# all layers.
# NOTE(review): `predmap` is whatever the preceding cell computed last, i.e.
# the final entry of `data` (dog_bird_image, listed with classes 161/87).
# Confirm that is the intended image for the class chosen here.
# clsidx = 282 #cat
clsidx = 243 #dog
# clsidx = 340 #zebra
predmap_object = predmap[..., clsidx] # (1, layers, tokens-1)
# NOTE(review): this keeps tokens scoring *below* the per-layer mean although
# the result is named "object tokens" -- confirm the comparison direction.
mask = predmap_object < predmap_object.mean(dim=-1, keepdim=True) # (1, layers, tokens-1)
mask = mask[0,9] # (tokens-1, ), mask of layer index 9
# Squeeze only the trailing index dim: a bare squeeze() would collapse a
# single selected token to a 0-d index, which then drops the token axis in
# the indexing below and hands matshow a 1-D array.
object_tokens_idx = mask.nonzero().squeeze(-1) # (object_tokens, )

plt.matshow(mask.reshape(14,14).detach().cpu().numpy())

# %%


# Rows are layers, columns are the selected tokens.
plt.matshow(predmap_object[0][..., object_tokens_idx].detach().cpu().numpy())
plt.xlabel("Tokens")
plt.ylabel("Layers")

# plt.matshow(object_tokens_idx[0,10].reshape(14,14).detach().cpu().numpy())
# %%

# %%
