# %%
# Imports and environment bootstrap for this notebook-style script
# (cells are delimited by "# %%" markers).
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Configure lovely_numpy's output (deeper_width=12; see the library docs for
# the exact meaning of this option).
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Install lovely_tensors' patches and apply the same deeper_width setting as
# for lovely_numpy above.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Explicitly (re)define torch.inf — presumably a shim for torch versions that
# lack the attribute; TODO confirm it is still needed.
torch.inf = float("Inf")
import sys
import os
# Treat the directory two levels above this file as the project root: append it
# to sys.path (the project-local imports further down rely on this) and make it
# the working directory.
# NOTE: relies on __file__ being defined by whatever executes this script.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%

from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new

# Per-channel normalisation with mean = std = 0.5, i.e. (x - 0.5) / 0.5.
# The ImageNet default statistics, mean (0.485, 0.456, 0.406) and
# std (0.229, 0.224, 0.225), are intentionally not applied here.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Resize / crop / tensor conversion only. Kept separate from `normalize` so the
# un-normalised image is still available for display.
transform_without_normalize = transforms.Compose(
    [
        transforms.Resize(224),  # a Resize(256) variant was tried and disabled
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Full preprocessing: the steps above followed by normalisation.
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# Build the pretrained ViT (vit_base_patch16_224 from baselines.ViT.ViT_new),
# move it to the GPU and switch it to evaluation mode.
model = vit_new(pretrained=True)
model = model.cuda()
model.eval()


# %%


def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    ``mask`` is scaled by 255 and cast to uint8 before colour mapping, so it is
    expected to hold values in [0, 1]. ``img`` must be broadcast-compatible
    with the (H, W, 3) colour map. The colour map comes straight from OpenCV;
    no channel reordering is performed here, so the caller is responsible for
    passing ``img`` in a matching channel order.

    Returns:
        float32 array: colour map plus image, divided by its own maximum so
        that the largest value is 1.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    return overlay / np.max(overlay)


# %%

def gen_visualization(image, cls_idx, layer):
    """Render the prediction map of one layer as a heat-map overlay on ``image``.

    Args:
        image: un-normalised CPU image tensor of shape (3, 224, 224) in RGB
            channel order (as produced by ``transform_without_normalize``).
        cls_idx: class index forwarded to ``model.predmap15_ablation_predmap``.
        layer: layer index forwarded to ``model.predmap15_ablation_predmap``.

    Returns:
        np.ndarray: uint8 array of shape (224, 224, 3) in RGB channel order.
    """
    image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)
    predmap = model.predmap15_ablation_predmap(image_norm, cls_idx, layer=layer)  # (B, tokens-1)
    # One value per 16x16 patch; the reshape requires exactly 14 * 14 values.
    xai_map = predmap.detach().cpu().reshape(1, 1, 14, 14)

    # Min-max normalise to [0, 1]. A constant (or non-finite-range) map would
    # divide by zero and yield NaNs that are later cast to uint8, so it is
    # rendered as an all-zero map instead.
    lowest = xai_map.amin()
    value_range = xai_map.amax() - lowest
    if value_range > 0:
        xai_map = (xai_map - lowest) / value_range
    else:
        xai_map = torch.zeros_like(xai_map)

    # Upsample the 14x14 patch grid to pixel resolution.
    xai_map = torch.nn.functional.interpolate(xai_map, scale_factor=16, mode='bilinear')  # (1, 1, 224, 224)
    xai_map = xai_map.squeeze().numpy()  # (224, 224)

    # cv2.applyColorMap (used by show_cam_on_image) emits BGR, so hand the image
    # over in BGR as well; otherwise the final channel swap would leave the
    # photograph with red and blue exchanged while the heat map looked right.
    image_bgr = image.permute(1, 2, 0).numpy()[..., ::-1]
    vis = show_cam_on_image(image_bgr, xai_map)
    vis = np.uint8(255 * vis)
    # Both layers are BGR at this point; convert once to RGB for display.
    vis = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)
    return vis




# %%
# ImageNet root directory; the validation split is loaded from here.
# Earlier experiments (now disabled) instead globbed *.png / *.jpg files from
# base_dir / "samples", or listed a single class directory under
# <imagenet_dir>/val/<wnid> — e.g. n07718472 (cucumber) or n02099601 (golden
# retriever) — filtered to file stems containing 38781, 6981, 11373, 12100,
# 26887 or 37455.
imagenet_dir = Path("~/ramdisk/datasets/imagenet/").expanduser()

import torchvision.datasets

# The dataset applies only the un-normalised transform; normalisation is done
# separately where the model is invoked.
imagenet_ds = torchvision.datasets.ImageNet(
    root=imagenet_dir,
    split="val",
    transform=transform_without_normalize,
)

# Seed the global RNG so the shuffled index order is identical on every run.
torch.manual_seed(0)
rand_indices = torch.randperm(len(imagenet_ds))

# %%

# Render, for the first 50 images of the seeded permutation, one figure that
# shows the original image followed by the prediction-map overlay of every
# layer in `layers`, and save it under `output_dir`.
plt.ion()
# plt.ioff()
output_dir: Path = base_dir / "figures/artifacts/supp/predmap_per_layer_explore"
output_dir.mkdir(exist_ok=True, parents=True)
layers = range(5, 12)  # loop-invariant, so defined once
for i in rand_indices[:50].tolist():
    plt.close("all")
    # Use the dataset's (path, class index) record directly so the file name is
    # available for the output path; the ground-truth class is what is explained.
    img_path, cls_idx = imagenet_ds.samples[i]
    img_path = Path(img_path)
    print(img_path)

    # Close the file handle deterministically and force RGB up front, so that
    # grayscale, palette, CMYK and alpha-carrying files all become 3-channel RGB.
    with Image.open(img_path) as pil_image:
        image = transform_without_normalize(pil_image.convert("RGB"))  # (3, 224, 224)

    fig, axes = plt.subplots(figsize=(len(layers), 1))
    vis_layers = [image.unsqueeze(0)]  # first tile: the unmodified input image
    for layer in layers:
        vis = gen_visualization(image, cls_idx, layer)  # (224, 224, 3)
        vis = torch.from_numpy(vis).permute(2, 0, 1).unsqueeze(0).float().div(255)  # (1, 3, 224, 224)
        vis_layers.append(vis)
    vis_layers = torch.cat(vis_layers, dim=0)  # (len(layers) + 1, 3, 224, 224)
    img_grid = torchvision.utils.make_grid(vis_layers, nrow=len(layers)+1, scale_each=True, normalize=True, pad_value=1)
    axes.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
    axes.set_title(f"{i=:,}, {CLS2IDX[cls_idx]}")
    axes.axis('off')

    fig.tight_layout()
    path = output_dir / f"{img_path.parent.stem}_{img_path.stem}_i{i}_predmap_per_layer.png"
    print(path)
    fig.savefig(path)
    # NOTE(review): interactive mode is switched off at the end of the first
    # iteration and stays off afterwards — confirm this is intended rather than
    # a leftover from running a single iteration with `break`.
    plt.ioff()
# %%
