#%%
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Display configuration for lovely-numpy summaries.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Patch torch tensors with lovely-tensors and apply the same display setting.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Define `torch.inf` explicitly.
# NOTE(review): presumably a shim for code that expects the attribute on a
# torch build that lacks it -- confirm it is still needed.
torch.inf = float("Inf")
import sys
import os
# Directory two levels above this file.
base_dir = Path(__file__).parent.parent
# Put that directory on the import path; the `samples.*` / `baselines.*`
# imports further down presumably resolve against it.
sys.path.append(str(base_dir))
# Also make it the working directory so relative paths are resolved from it.
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%

from samples.CLS2IDX import CLS2IDX
# from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
import types

# IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
# IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
# Per-channel (x - 0.5) / 0.5, i.e. ToTensor's [0, 1] range becomes [-1, 1].
# NOTE(review): the ImageNet mean/std commented out above is the alternative;
# which of the two the checkpoint expects is not visible in this file.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Geometry + tensor conversion only, so the un-normalised image can be reused
# as the background of the overlays.
_resize_crop_to_tensor = [
    transforms.Resize(224),      # shorter side -> 224 (Resize(256) was the alternative)
    transforms.CenterCrop(224),  # square 224 x 224 crop
    transforms.ToTensor(),       # PIL image -> float tensor, channels first
]
transform_without_normalize = transforms.Compose(_resize_crop_to_tensor)

# Full preprocessing: the steps above followed by normalisation.
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# Build the pretrained ViT-B/16 (224 px), move it to the GPU and switch it to
# evaluation mode.
# Alternative implementation: model = vit_LRP(pretrained=True).cuda()
model = vit_new(pretrained=True)
model = model.cuda()
model.eval()


# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Blend a JET colour-mapped ``mask`` onto ``img``.

    ``mask`` is scaled by 255 and cast to uint8 before colour mapping, so it
    is presumably expected in [0, 1] (the callers in this file pass min-max
    normalised maps). ``img`` is added as float32 without any channel
    reordering.

    NOTE: ``cv2.applyColorMap`` returns its colours in OpenCV (BGR) channel
    order, so the sum is only channel-consistent if ``img`` is BGR as well.

    Returns the sum of colour map and image divided by its global maximum,
    so the largest value of the result is 1.
    """
    mask_u8 = np.uint8(255 * mask)
    colour_map = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    blended = colour_map + 1 * np.float32(img)
    return blended / np.max(blended)


# %%


# Dataset roots. Only the ImageNet validation set is used below; ImageNet-R
# is kept as a ready-made alternative source for `cls_dir`.
imagenet_r_dir = Path("~/datasets/imagenet-r").expanduser()
imagenet_val_dir = Path("~/ramdisk/datasets/imagenet/val").expanduser()

# Classes that can be explored: name -> (class directory name, class index).
# The directory name selects the image folder; the index selects the class
# channel of the model outputs. Previously these were ten consecutive
# re-assignments of `cls_wind` / `cls_idx` of which only the last took effect.
CLASS_CHOICES = {
    "acorn": ("n12267677", 988),
    "space_shuttle": ("n04266014", 812),
    "fire_engine": ("n03345487", 555),
    "violin": ("n04536866", 889),
    "ambulance": ("n02701002", 407),
    "cucumber": ("n07718472", 943),
    "golden_retriever": ("n02099601", 207),
    "strawberry": ("n07745940", 949),
    "lemon": ("n07749582", 951),
    "persian_cat": ("n02123394", 283),
}

# Change this key to switch class (same selection as before: Persian cat).
selected_class = "persian_cat"
cls_wind, cls_idx = CLASS_CHOICES[selected_class]

# Layer index forwarded to the model's explanation method.
layer = 11
# cls_dir = imagenet_r_dir / cls_wind
cls_dir = imagenet_val_dir / cls_wind
# apply_softmax_classes = False
apply_softmax_classes = True

# Sorted so runs are reproducible; directories are skipped because the loop
# below opens every entry as an image.
files = sorted(p for p in cls_dir.iterdir() if p.is_file())

# One output folder per (class, layer, softmax setting) combination.
output_dir: Path = base_dir / f"figures/artifacts/predicatt_decomp_explore/{cls_wind}_l{layer}_softmaxclasses-{apply_softmax_classes}"
output_dir.mkdir(exist_ok=True, parents=True)

# Figures are only written to disk, so keep matplotlib non-interactive.
plt.ioff()
# plt.ion()

# Rebinds `tqdm` from the module (imported at the top) to the progress-bar
# callable used by the loop below.
from tqdm import tqdm


# %%

def _upsample_patch_map(patch_values, grid=14, scale=16, eps=1e-12):
    """Min-max normalise a per-patch map and upsample it to image size.

    ``patch_values`` is any tensor with ``grid * grid`` elements (one value
    per image patch). It is reshaped to ``(1, 1, grid, grid)``, rescaled to
    [0, 1] and bilinearly upsampled by ``scale``. A constant map would make
    the min-max denominator zero; it is clamped to ``eps`` so the result is
    all zeros instead of NaN.

    Returns a ``(grid * scale, grid * scale)`` float numpy array.
    """
    patch_map = patch_values.detach().cpu().float().reshape(1, 1, grid, grid)
    lowest = patch_map.amin()
    value_range = (patch_map.amax() - lowest).clamp_min(eps)
    patch_map = (patch_map - lowest) / value_range
    patch_map = torch.nn.functional.interpolate(
        patch_map, scale_factor=scale, mode='bilinear', align_corners=False
    )
    return patch_map.squeeze().numpy()


def _overlay_on_image(image_rgb, heatmap):
    """Return a uint8 RGB overlay of ``heatmap`` on ``image_rgb``.

    ``show_cam_on_image`` adds an OpenCV colour map (BGR channel order) to
    the image without reordering channels. The image is therefore handed
    over in BGR and the blended result is converted back to RGB, so both the
    photo and the colour map are displayed with correct colours by
    matplotlib.
    """
    image_bgr = np.ascontiguousarray(image_rgb[..., ::-1])
    blended = show_cam_on_image(image_bgr, heatmap)
    blended = np.uint8(255 * blended)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


plt.close("all")
plt.ioff()
for f in tqdm(files):
    # Load as 3-channel RGB regardless of the file's mode (grayscale, palette,
    # RGBA, CMYK, ...) and release the file handle right away.
    with Image.open(f) as pil_image:
        image = transform_without_normalize(pil_image.convert("RGB"))  # (3, 224, 224)
    image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)

    # NOTE(review): if predmap15_batched_layer does not rely on autograd, this
    # call could run under torch.no_grad() -- not visible from this file.
    # xai_map, extras = model.predmap15_softmax_classes_batched_layer(image_norm, cls_idx, layer=layer, return_extras=True) # (1, tokens-1)
    xai_map, extras = model.predmap15_batched_layer(
        image_norm,
        cls_idx,
        layer=layer,
        return_extras=True,
        apply_softmax_classes=apply_softmax_classes,
    )  # xai_map: (1, tokens-1)

    # (blocks, heads, tokens-1) for the single image in the batch.
    all_attnmap = extras["all_attnmap"][0].detach().cpu()

    # (title, per-patch values) for every panel, left to right.
    panels = [
        ('predicatt', xai_map),
        (f'predmap l{layer} {f.stem}', extras["predmap"][0, :, cls_idx]),
        ('weighted attnmap', extras['weighted_attnmap'][0, :, cls_idx]),
        # Mean over every head of every block.
        ('mean attnmap', all_attnmap.flatten(0, 1).mean(0)),
        # Mean over the heads of the last block only.
        ('mean last attnmap', all_attnmap[-1].mean(0)),
    ]

    # Text box listing the five highest-scoring classes.
    cls_pred_vec = extras['xpredmap'][0, 0].detach().cpu()  # (1000)
    topk = torch.topk(cls_pred_vec, 5)
    pred_str = "\n".join(
        f"{val.item(): >4.2f} {CLS2IDX[ind.item()][:20]}"
        for ind, val in zip(topk.indices.squeeze(), topk.values.squeeze())
    )

    image_np = image.permute(1, 2, 0).numpy()  # (224, 224, 3), RGB in [0, 1]
    fig, axes = plt.subplots(1, len(panels), figsize=(25, 5))
    for ax, (title, patch_values) in zip(axes, panels):
        ax.imshow(_overlay_on_image(image_np, _upsample_patch_map(patch_values)))
        ax.set_title(title)
    axes[0].text(0, 0, pred_str, color='white', fontsize=19, ha='left', va='top', backgroundcolor='black')

    path = output_dir / f"{cls_wind}_{f.stem}_decomp.png"
    fig.savefig(path)
    # Release the figure; otherwise one figure per image stays alive in pyplot.
    plt.close(fig)
# %%
