#%%
# --- Environment setup -------------------------------------------------------
# Imports, pretty-printing config and working-directory setup for this
# cell-based (`# %%`) exploration script. Statement order matters here: the
# project root is appended to sys.path before the project-local imports in the
# next cell, and the chdir makes the relative 'samples/...' paths used below
# resolve against the project root.
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Debugging aid: presumably swaps torch tensor reprs for lovely_tensors'
# summary view (see the lovely_tensors docs).
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): (re)defines torch.inf globally — presumably a shim for code
# or a torch version that expects this attribute; confirm it is still needed.
torch.inf = float("Inf")
import sys
import os
# Project root: two directory levels above this file.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%

from samples.CLS2IDX import CLS2IDX
# from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
import types

# Preprocessing. Images are kept un-normalised as (C, 224, 224) tensors so they
# can be shown under the heatmap overlays; `normalize` is applied separately,
# right before each model call.
# Per-channel mean = std = 0.5, i.e. x -> (x - 0.5) / 0.5 = 2x - 1.
# (ImageNet statistics mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
# are the commented-out alternative.)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
transform_without_normalize = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# Pretrained ViT (project-local `vit_new`) on the GPU, switched to eval mode.
# The commented-out alternative was `vit_LRP(pretrained=True).cuda()`.
model = vit_new(pretrained=True)
model = model.cuda()
model.eval()


# %%
# Load two sample images as un-normalised (3, 224, 224) tensors. The files are
# closed by the context manager, and convert('RGB') guarantees three channels
# even for RGBA / palette / grayscale files, which would otherwise not match
# the 3-channel `normalize` and the model input.
with Image.open('samples/el2.png') as pil_img:
    tusker_zebra_image = transform_without_normalize(pil_img.convert('RGB'))
with Image.open('samples/catdog.png') as pil_img:
    cat_dog_image = transform_without_normalize(pil_img.convert('RGB'))

# ImageNet class indices of the objects in the two sample images.
cls_idx_tusker = 101
cls_idx_zebra = 340
cls_idx_cat = 282
cls_idx_dog = 243

def show_cam_on_image(img, mask, use_rgb=False, colormap=cv2.COLORMAP_JET):
    """Blend a colour-mapped saliency mask onto an image.

    Args:
        img: (H, W, 3) float image, typically in [0, 1].
        mask: (H, W) float saliency map, expected in [0, 1]. NaNs are treated
            as 0 and values are clipped to [0, 1], so the uint8 conversion
            below cannot wrap around.
        use_rgb: ``cv2.applyColorMap`` returns a BGR heatmap. With the default
            False the BGR heatmap is added to ``img`` unchanged, so ``img``
            should be BGR for consistent colours. With True the heatmap is
            converted to RGB first, so ``img`` should be RGB.
        colormap: OpenCV colormap id (JET by default).

    Returns:
        (H, W, 3) float32 array, rescaled so that its maximum is 1.
    """
    mask = np.clip(np.nan_to_num(mask), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam


# %%

def gen_visualization(image, cls_idx, layer):
    """Overlay the map from `model.predmap15_softmax_classes_batched_layer` on `image`.

    Args:
        image: un-normalised (3, H, W) image tensor on the CPU (e.g. the
            output of `transform_without_normalize`).
        cls_idx: ImageNet class index to explain.
        layer: layer index forwarded to the model method.

    Returns:
        (H, W, 3) uint8 array in RGB order, ready for plt.imshow / plt.imsave.

    Raises:
        ValueError: if the returned map does not have a square number of
            patch tokens.
    """
    image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, H, W)
    xai_map, _extras = model.predmap15_softmax_classes_batched_layer(
        image_norm, cls_idx, layer=layer, return_extras=True
    )

    # One score per patch token on a square grid (14x14 for 224px / 16px patches).
    xai_map = xai_map.detach().cpu()
    n_tokens = xai_map.numel()
    side = int(round(n_tokens ** 0.5))
    if side * side != n_tokens:
        raise ValueError(f"expected a square patch grid, got {n_tokens} tokens")
    xai_map = xai_map.reshape(1, 1, side, side)

    # Min-max normalise; a constant map would otherwise give 0/0 = NaN.
    vmin, vmax = xai_map.amin(), xai_map.amax()
    if vmax > vmin:
        xai_map = (xai_map - vmin) / (vmax - vmin)
    else:
        xai_map = torch.zeros_like(xai_map)

    # Upsample to the image's own size (identical to scale_factor=16 at 224px).
    xai_map = torch.nn.functional.interpolate(
        xai_map, size=tuple(image.shape[-2:]), mode='bilinear'
    )
    xai_map = xai_map.squeeze().numpy()  # (H, W)

    # cv2.applyColorMap yields a BGR heatmap, so blend it with a BGR copy of the
    # image and swap the whole result to RGB. Blending the RGB image directly
    # leaves the photo's red and blue channels swapped in the output.
    img_bgr = np.ascontiguousarray(image.permute(1, 2, 0).numpy()[..., ::-1])
    vis = show_cam_on_image(img_bgr, xai_map)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)


# Images to analyse, each paired with the ImageNet class indices to visualise.
data = [
    ('samples/catdog.png', [282, 243]),   # cat, dog
    ('samples/el2.png', [101, 340]),      # tusker, zebra
    ('samples/dogbird.png', [161, 87]),   # basset, parrot
]


# For every (image, class) pair in `data`, draw one figure with the spatial
# maps returned by `predmap15_batched_layer` at a fixed layer, plus the
# per-(layer, head) projection scores for that class.
layer = 11
for image_path, clsidx_lst in data:
    img_name = Path(image_path).stem
    # convert('RGB') so RGBA / palette / grayscale files still give 3 channels.
    with Image.open(image_path) as pil_img:
        image = transform_without_normalize(pil_img.convert('RGB'))  # (3, 224, 224)
    for cls_idx in clsidx_lst:
        image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)
        xai_map, extras = model.predmap15_batched_layer(image_norm, cls_idx, layer=layer, return_extras=True)
        xai_map = xai_map.detach().cpu().reshape(14, 14).numpy()  # (14, 14)

        # extras['projections']: (1, layers, heads, classes) -> (layers, heads)
        projections = extras['projections'][0, ..., cls_idx].detach().cpu().numpy()
        # extras['weighted_attnmap'] / extras['predmap']: (1, tokens-1, classes) -> (14, 14)
        weighted_attnmap = extras['weighted_attnmap'][0, :, cls_idx].reshape(14, 14).detach().cpu().numpy()
        predmap = extras['predmap'][0, :, cls_idx].reshape(14, 14).detach().cpu().numpy()
        # extras["all_attnmap"]: (B, blocks, heads, tokens-1) -> (blocks*heads, 14, 14)
        all_attnmap = extras["all_attnmap"][0].flatten(0, 1).unflatten(-1, (14, 14)).detach()
        mean_attnmap = all_attnmap.mean(0).cpu().numpy()  # (14, 14)

        fig, axes = plt.subplots(1, 5, figsize=(15, 3))
        # Label each figure; otherwise the figures from this loop are indistinguishable.
        fig.suptitle(f'{img_name} / {CLS2IDX[cls_idx]} ({cls_idx}), layer {layer}')

        spatial_panels = [
            ('mean attnmap', mean_attnmap),
            ('weighted attnmap', weighted_attnmap),
            ('predmap', predmap),
            ('predicatt', xai_map),
        ]
        for ax, (title, grid) in zip(axes[:4], spatial_panels):
            ax.matshow(grid, cmap='jet')
            ax.set_title(title)
            ax.axis('off')

        ax = axes[4]
        ax.matshow(projections, cmap='jet')
        ax.set_title('dot-product')
        ax.set_xlabel('head')
        ax.set_ylabel('layer')
        # To save an overlay instead:
        # plt.imsave(f'figures/artifacts/{img_name}_predicatt_l{layer}_{CLS2IDX[cls_idx]}.png',
        #            gen_visualization(image, cls_idx, layer))


# %%

# %%

# Image to analyse. Other candidates tried (relative to base_dir):
#   samples/el2.png, 'samples/0.002886_sandal _ sandal_0.9212473.jpg',
#   samples/misc_93.jpg, samples/misc_83.jpg, samples/misc_89.jpg,
#   samples/misc_28.jpg, samples/sticker_0.jpg, samples/sculpture_8.jpg,
#   samples/misc_20.jpg, samples/videogame_7.jpg, samples/toy_3.jpg,
#   samples/misc_41.jpg
# The active image is read from the directory given by $IMAGENET_R_PATH.
_imagenet_r_root = os.environ.get("IMAGENET_R_PATH")
if _imagenet_r_root is None:
    raise KeyError("IMAGENET_R_PATH is not set; point it at the ImageNet-R directory "
                   "(the one containing n07718472/)")
imagenet_r_path = Path(_imagenet_r_root)
img_path = imagenet_r_path / "n07718472/sticker_0.jpg"  # cucumber
# convert('RGB') so grayscale / CMYK / RGBA files still give 3 channels.
with Image.open(img_path) as pil_img:
    image = transform_without_normalize(pil_img.convert("RGB"))  # (3, 224, 224)
layer = 10

image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)

# Print the model's top-5 predicted classes for this image.
res = model(image_norm)
topk = torch.topk(res, 5)
for cls_idx in topk.indices.squeeze():
    print(CLS2IDX[cls_idx.item()], cls_idx.item())
# %%
# Class to explain. Other indices tried: 101 tusker, 774 sandal, 168 redbone,
# 6 stingray, 949 strawberry, 988 acorn, 889 violin, 717 pickup_truck,
# 658 mitten.
cls_idx = 943  # cucumber


def _token_map_to_mask(token_map, out_size):
    """Min-max normalise a per-patch-token map and upsample it to ``out_size``.

    ``token_map`` holds one value per patch token; any shape with a square
    number of elements works (e.g. (1, 196) or (14, 14)). Returns an
    ``out_size`` float numpy array in [0, 1]. A constant map becomes all
    zeros instead of the NaNs a plain 0/0 min-max would produce.
    """
    grid = token_map.detach().cpu()
    n_tokens = grid.numel()
    side = int(round(n_tokens ** 0.5))
    if side * side != n_tokens:
        raise ValueError(f"expected a square patch grid, got {n_tokens} tokens")
    grid = grid.reshape(1, 1, side, side)
    vmin, vmax = grid.amin(), grid.amax()
    if vmax > vmin:
        grid = (grid - vmin) / (vmax - vmin)
    else:
        grid = torch.zeros_like(grid)
    grid = torch.nn.functional.interpolate(grid, size=out_size, mode='bilinear')
    return grid.squeeze().numpy()


def _overlay_rgb(img_chw, mask):
    """Blend ``mask`` as a heatmap onto a (3, H, W) RGB tensor; return uint8 RGB."""
    # show_cam_on_image adds a BGR heatmap (cv2.applyColorMap) to the image, so
    # feed it BGR and swap the blended result to RGB for matplotlib. Feeding RGB
    # directly leaves the photo's red and blue channels swapped.
    img_bgr = np.ascontiguousarray(img_chw.permute(1, 2, 0).numpy()[..., ::-1])
    vis = np.uint8(255 * show_cam_on_image(img_bgr, mask))
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)


# Alternative map: model.predmap15_softmax_classes_batched_layer(image_norm, cls_idx, layer=layer, return_extras=True)
xai_map, extras = model.predmap15_batched_layer(image_norm, cls_idx, layer=layer, return_extras=True, apply_softmax_classes=False)  # (1, tokens-1)
out_size = tuple(image.shape[-2:])  # (224, 224)

# All maps below end up as (224, 224) numpy arrays in [0, 1].
xai_map = _token_map_to_mask(xai_map, out_size)
# extras["predmap"] / extras['weighted_attnmap']: (1, tokens-1, classes)
predmap = _token_map_to_mask(extras["predmap"][0, :, cls_idx], out_size)
weighted_attnmap = _token_map_to_mask(extras['weighted_attnmap'][0, :, cls_idx], out_size)

# extras["all_attnmap"]: (B, blocks, heads, tokens-1) -> (blocks*heads, 14, 14).
# all_attnmap stays a CPU tensor; the next cell reuses it.
all_attnmap = extras["all_attnmap"][0]
patch_side = int(round(all_attnmap.shape[-1] ** 0.5))
all_attnmap = all_attnmap.flatten(0, 1).unflatten(-1, (patch_side, patch_side)).detach().cpu()
mean_attnmap = _token_map_to_mask(all_attnmap.mean(0), out_size)

vis_predicatt = _overlay_rgb(image, xai_map)
vis_predmap = _overlay_rgb(image, predmap)
vis_weighted_attnmap = _overlay_rgb(image, weighted_attnmap)
vis_mean_attnmap = _overlay_rgb(image, mean_attnmap)

# Draw the four overlays side by side.
panels = [
    ('predicatt', vis_predicatt),
    ('predmap', vis_predmap),
    ('weighted attnmap', vis_weighted_attnmap),
    ('mean attnmap', vis_mean_attnmap),
]
fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
for ax, (title, vis) in zip(axes, panels):
    ax.imshow(vis)
    ax.set_title(title)


# %%
# Per-head attention maps of the last block as one grid, plus their mean.
# extras["all_attnmap"] is (B, blocks, heads, tokens-1); all_attnmap is its
# (blocks*heads, 14, 14) reshaping from the previous cell.
n_blocks, n_heads = extras["all_attnmap"].shape[1:3]
last_block_heads = all_attnmap.unflatten(0, (n_blocks, n_heads))[-1]  # (heads, 14, 14)
head_grid = torchvision.utils.make_grid(
    last_block_heads.unsqueeze(1), nrow=6, scale_each=True, normalize=True, pad_value=1
).permute(1, 2, 0).detach().cpu().numpy()  # (H_grid, W_grid, channels)
# The input is single-channel, so one channel of the grid presumably carries
# all of the information (make_grid replicates grayscale inputs).
plt.matshow(head_grid[..., -1])
plt.title('last block: per-head attention')
# plt.matshow creates its own figure (fignum=None), so plt.figure() is not
# needed; calling it first only leaves an extra blank figure.
plt.matshow(last_block_heads.mean(0).detach().cpu().numpy())
plt.title('last block: mean over heads')
# %%
