#%%
# Setup cell: imports, debugging pretty-printers, repo-root path handling and
# global matplotlib styling shared by all later cells.
# NOTE(review): several imports below are unused in the visible code
# (collections, ticker, lo, ImageFolder, tqdm) and pyplot is imported twice
# under the same name `plt`; kept as-is because other cells may rely on them.
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Configure lovely_numpy's summary output (deeper_width=12).
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Patch torch.Tensor so tensors print as compact lovely-tensors summaries.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Force `torch.inf` to exist as float infinity; presumably a shim for older
# torch versions or code that expects the attribute — TODO confirm still needed.
torch.inf = float("Inf")
import sys
import os
# Directory two levels above this file. It is appended to sys.path so the
# project-local imports in the next cell (samples.*, baselines.*) resolve, and
# it becomes the working directory.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
from matplotlib import pyplot as plt
# Global figure styling. `text.usetex` requires a working LaTeX installation
# (plus the amsmath package referenced in the preamble).
plt.rcParams.update({
    "text.usetex": True,  # Enables latex equations
    "font.family": "cmu-serif", # Sets the correct font
    "mathtext.fontset": "cm",   # --"--
    "font.size": 21,            # Set the font according to what you need
    "text.latex.preamble": r"\usepackage{amsmath}"   # You can add this to enable complicated math stuff
})
plt.style.use('tableau-colorblind10')  # You can use this to get a colorblind color palette

# %%

from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
import types

# Per-channel normalisation with mean = std = 0.5, i.e. x -> 2x - 1, mapping
# ToTensor's [0, 1] range onto [-1, 1].
# (Standard ImageNet statistics, kept for reference:
#  mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225).)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Geometric preprocessing only: scale the shorter side to 224, take the
# centred 224x224 crop, and convert to a float CHW tensor.
_resize_crop_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
transform_without_normalize = transforms.Compose(_resize_crop_steps)

# Complete model-input pipeline: geometric steps, then normalisation.
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# Pretrained vit_base_patch16_224 (from baselines.ViT.ViT_new), placed on the
# GPU and switched to inference mode; Module.eval() returns the module itself.
model = vit_new(pretrained=True).cuda().eval()

# Disabled alternative: LRP-capable ViT and its explanation generator.
# model_LRP = vit_LRP(pretrained=True).cuda()
# model_LRP.eval()
# from baselines.ViT.ViT_explanation_generator import Baselines, LRP
# lrp = LRP(model_LRP)


def show_cam_on_image(img, mask):
    """Blend a JET-colormapped heatmap of ``mask`` onto ``img``.

    Args:
        img: (H, W, 3) RGB image with float values in [0, 1] (callers in this
            file pass ``image.permute(1, 2, 0).numpy()`` from ``ToTensor``).
        mask: (H, W) array expected in [0, 1]. NaNs are treated as 0 and the
            values are clipped to [0, 1] before the uint8 cast, so out-of-range
            or NaN entries cannot wrap around.

    Returns:
        float32 (H, W, 3) array in BGR channel order (OpenCV convention),
        divided by its maximum so the largest value is 1. Callers swap it to
        RGB with ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)`` (a plain channel
        reversal) before handing it to matplotlib.
    """
    mask = np.clip(np.nan_to_num(np.asarray(mask, dtype=np.float32)), 0.0, 1.0)
    # applyColorMap returns a BGR uint8 image.
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    # Reverse img's RGB channels so both summands are BGR. Adding RGB to BGR
    # would swap the photo's red/blue channels after the caller's swap.
    cam = heatmap + np.float32(img)[..., ::-1]
    cam = cam / np.max(cam)
    return cam


# %%
imagenet_r_path = Path(os.environ["IMAGENET_R_PATH"])
imagenet_path = Path(os.environ["IMAGENET_PATH"])
imagenet_val_path = imagenet_path / "val"

# Each sample is (image path, target class index, `layer` argument passed to
# model.predmap15_batched_layer, extra (block, head) attention maps to dump).
#
# Other single-image examples used previously (ImageNet-R):
#   cucumber (943):      imagenet_r_path / 'n07718472/sticker_0.jpg'
#   space_shuttle (812): imagenet_r_path / 'n04266014/toy_7.jpg'
CURATED_DATA = [
    (
        imagenet_val_path / 'n07753592/ILSVRC2012_val_00017021.JPEG',
        954,  # banana
        11,  # layer
        [],
    ),
    (
        imagenet_val_path / 'n07745940/ILSVRC2012_val_00021854.JPEG',
        949,  # Strawberry
        10,  # layer
        [
            (6, 10),
            (0, 8),
        ],
    ),
]

# Select the hand-picked samples above (True) or a strided sweep over the
# ImageNet validation set (False).
USE_CURATED = False
NUM_SAMPLES = 50       # number of validation images in the sweep
SAMPLE_STRIDE = 500    # take every SAMPLE_STRIDE-th entry of imagenet_ds.samples
SWEEP_LAYERS = (11, 10)  # one data entry per image and per layer, in this order

if USE_CURATED:
    data = list(CURATED_DATA)
else:
    imagenet_ds = torchvision.datasets.ImageNet(imagenet_path, split='val')
    data = []
    for i in range(NUM_SAMPLES):
        print(i)  # progress
        img_path, cls_idx = imagenet_ds.samples[SAMPLE_STRIDE * i]
        img_path = Path(img_path)
        for sweep_layer in SWEEP_LAYERS:
            data.append((img_path, cls_idx, sweep_layer, []))
# %%




output_dir = base_dir / 'figures' / 'artifacts' / 'supp' / 'dotprod'
output_dir.mkdir(exist_ok=True, parents=True)

PATCH_GRID = 14  # patch tokens per side: the (tokens-1) maps are reshaped to 14x14
UPSAMPLE = 16    # bilinear upsampling factor: 14 * 16 = 224 pixels
TOPK = 5         # number of highest-scoring heads to visualise
BOTK = 5         # number of lowest-scoring heads to visualise
APPLY_SOFTMAX_CLASSES = True


def _minmax(t):
    """Rescale ``t`` to [0, 1]; a constant map becomes zeros instead of NaN."""
    t_min, t_max = t.amin(), t.amax()
    span = t_max - t_min
    if span == 0:
        return torch.zeros_like(t)
    return (t - t_min) / span


def _upsampled_map(patch_map):
    """Normalise a patch-grid map and bilinearly upsample it to a (224, 224) numpy array.

    ``patch_map`` may be flat (tokens-1,) or (14, 14); min-max normalisation is
    applied at patch resolution, before upsampling.
    """
    m = patch_map.detach().cpu().reshape(1, 1, PATCH_GRID, PATCH_GRID)
    m = _minmax(m)
    m = torch.nn.functional.interpolate(m, scale_factor=UPSAMPLE, mode='bilinear')
    return m.squeeze().numpy()  # (224, 224)


def _overlay(image_hwc, patch_map):
    """Return a uint8 (224, 224, 3) heatmap overlay of ``patch_map`` on ``image_hwc``."""
    vis = show_cam_on_image(image_hwc, _upsampled_map(patch_map))
    vis = np.uint8(255 * vis)
    # Channel swap before matplotlib: the heatmap comes from cv2.applyColorMap (BGR).
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)


# plt.ion()
plt.ioff()
plt.close("all")

# Font settings for the saved dot-product figures. This is loop-invariant,
# so it is applied once instead of on every iteration.
plt.rcParams.update({
    "text.usetex": True,  # Enables latex equations
    "font.family": "cmu-serif", # Sets the correct font
    "mathtext.fontset": "cm",   # --"--
    "font.size": 30,            # Set the font according to what you need
    "text.latex.preamble": r"\usepackage{amsmath}"   # You can add this to enable complicated math stuff
})

for img_path, cls_idx, layer, attnmap_indices in data:
    cls_wind = img_path.parent.stem
    plt.close("all")
    print(img_path)

    # convert("RGB") handles grayscale (L), palette (P), RGBA and CMYK files
    # correctly, and the context manager closes the file handle.
    with Image.open(img_path) as pil_image:
        image = transform_without_normalize(pil_image.convert("RGB"))  # (3, 224, 224)
    image_hwc = image.permute(1, 2, 0).numpy()  # (224, 224, 3)
    image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)

    # xai_map is not visualised in this cell; the call is made for `extras`.
    xai_map, extras = model.predmap15_batched_layer(
        image_norm, cls_idx, layer=layer, return_extras=True,
        apply_softmax_classes=APPLY_SOFTMAX_CLASSES,
    )

    predmap = extras["predmap"][0, :, cls_idx]  # (tokens-1,)
    # Block-major flattening: flat index = block * heads + head.
    all_attnmap = extras["all_attnmap"][0].flatten(0, 1).detach().cpu()  # (blocks*heads, tokens-1)
    # No squeeze(): the result must stay 2-D even when there is a single block.
    dotprod = extras["projections"][0, ..., cls_idx].detach().cpu()  # (blocks, heads)
    n_heads = dotprod.shape[-1]

    vis_predmap = _overlay(image_hwc, predmap)

    # Save original image
    path = output_dir / f"dotprod_{cls_wind}_{img_path.stem}_original.png"
    plt.imsave(path, image_hwc)

    # Save predmap
    path = output_dir / f"dotprod_{cls_wind}_{img_path.stem}_p{cls_idx}_predmap_l{layer+1}.png"
    plt.imsave(path, vis_predmap)

    # Save dotprod matrix as a standalone figure
    fig, ax = plt.subplots()
    ax.matshow(dotprod.numpy())
    ticks = np.arange(2, 12 + 1, 2)
    ax.set_xticks(ticks - 1)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks - 1)
    ax.set_yticklabels(ticks)
    ax.set_xlabel("Heads", fontsize=30)
    ax.set_ylabel("Layers", fontsize=30)
    path = output_dir / f"dotprod_{cls_wind}_{img_path.stem}_p{cls_idx}_dotprod_l{layer+1}.svg"
    print(path)
    fig.savefig(path, transparent=True, bbox_inches='tight', pad_inches=0.0)

    # Overview figure: image | predmap | dotprod | top-k heads | bottom-k heads
    fig, axes = plt.subplots(1, 3 + TOPK + BOTK, figsize=(20, 4))
    axes[0].imshow(image_hwc)
    axes[0].axis('off')

    axes[1].imshow(vis_predmap)
    axes[1].set_title(f'PredMap l{layer+1}\n{CLS2IDX[cls_idx]}')
    axes[1].axis('off')

    axes[2].matshow(dotprod.numpy())
    axes[2].set_title('dot-product')
    axes[2].axis('on')

    # Flat (block*heads) indices sorted by ascending dot-product score.
    order = dotprod.flatten().argsort().tolist()
    panels = [("top", j, flat_idx) for j, flat_idx in enumerate(order[::-1][:TOPK])]
    panels += [("bot", j, flat_idx) for j, flat_idx in enumerate(order[:BOTK])]
    for ax, (tag, j, flat_idx) in zip(axes[3:], panels):
        block_idx, head_idx = divmod(flat_idx, n_heads)
        vis_attnmap = _overlay(image_hwc, all_attnmap[flat_idx])
        ax.imshow(vis_attnmap)
        path = output_dir / f"dotprod_{cls_wind}_{img_path.stem}_p{cls_idx}_layer{layer+1}_b{block_idx}_h{head_idx}_{tag}{j}.png"
        print(path)
        plt.imsave(path, vis_attnmap)

        corr = dotprod[block_idx, head_idx].item()
        ax.set_title(f"b {block_idx}, h {head_idx}, corr {corr:.2f}")
        ax.axis('off')

    # Extra (block, head) maps requested per sample: written to disk only,
    # not drawn into the overview figure.
    for block_idx, head_idx in attnmap_indices:
        vis_attnmap = _overlay(image_hwc, all_attnmap[block_idx * n_heads + head_idx])
        path = output_dir / f"dotprod_{cls_wind}_{img_path.stem}_p{cls_idx}_b{block_idx}_h{head_idx}.png"
        print(path)
        plt.imsave(path, vis_attnmap)

    fig.tight_layout()
    # fig.savefig(output_dir / f"dotprod_{cls_wind}_{img_path.stem}_p{cls_idx}_layer{layer+1}.png")

    # Grid of every head's overlay: one row per block, one column per head.
    attnmap_lst = [
        torch.from_numpy(_overlay(image_hwc, attnmap)).permute(2, 0, 1)  # (3, 224, 224)
        for attnmap in all_attnmap
    ]
    all_attnmap_img = torchvision.utils.make_grid(
        torch.stack(attnmap_lst, dim=0),  # (blocks*heads, 3, 224, 224)
        nrow=n_heads,
    )
    path = output_dir / f"dotprod_{cls_wind}_{img_path.stem}_allattnmap.png"
    plt.imsave(path, all_attnmap_img.permute(1, 2, 0).numpy())
    # break

# %%
