# %%
# Setup cell: imports, display configuration for numpy/torch summaries, and
# repository-root handling so the project packages imported below resolve.
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)  # lovely_numpy display setting
import lovely_tensors as lt
import cv2
lt.monkey_patch()  # replace torch.Tensor's repr with lovely_tensors summaries
lt.set_config(deeper_width=12)  # same display setting for lovely_tensors
# NOTE(review): overwrites torch.inf with a plain Python float; presumably a
# compatibility shim for torch versions lacking torch.inf — confirm it is still
# needed.
torch.inf = float("Inf")
import sys
import os
# Repository root = parent of the directory containing this file. Requires
# __file__ to be defined (script run, or an interactive window that sets it).
base_dir = Path(__file__).parent.parent
# Make top-level project packages (`samples`, `baselines`) importable ...
sys.path.append(str(base_dir))
# ... and resolve relative paths (e.g. 'samples/...' images) from the repo root.
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%

from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new

# Per-channel normalisation with mean = std = 0.5, mapping ToTensor's [0, 1]
# range to [-1, 1]. (The ImageNet statistics mean=(0.485, 0.456, 0.406),
# std=(0.229, 0.224, 0.225) are the alternative that was left commented out.)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Geometry only: resize the shorter side to 224 (so the 224x224 centre crop
# keeps the full shorter dimension), then convert PIL -> float tensor.
# Kept separate from normalisation so the un-normalised image can be saved
# and used as the overlay background.
transform_without_normalize = transforms.Compose(
    [transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()]
)

# Full pipeline: geometric preprocessing followed by normalisation.
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# Build the pretrained `vit_base_patch16_224` model from baselines.ViT.ViT_new,
# move its parameters to the GPU, and put it in inference mode (dropout etc.
# disabled). The final bare expression is kept so interactive runs echo it.
model = vit_new(pretrained=True)
model = model.cuda()
model.eval()


# %%

def show_cam_on_image(img, mask):
    """Blend a JET-coloured saliency mask onto an image.

    Args:
        img: H x W x 3 float image, expected in [0, 1]. It is added element-wise
            to the colour map, which OpenCV produces in BGR channel order, so
            callers should pass ``img`` in BGR order for a consistent result.
        mask: H x W saliency map, expected in [0, 1]; it is scaled by 255 and
            cast to uint8 before colouring.

    Returns:
        float32 H x W x 3 array ``colour_map + img`` rescaled so its maximum
        value is 1.
    """
    coloured = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)) / 255
    blended = coloured + np.float32(img)
    return blended / np.max(blended)

def gen_visualization(image, cls_idx, layer):
    """Render one layer's prediction map for one class as a heatmap overlay.

    Args:
        image: un-normalised (3, 224, 224) image tensor on the CPU, as produced
            by ``transform_without_normalize`` (RGB channel order).
        cls_idx: class index forwarded to ``model.predmap15_ablation_predmap``.
        layer: layer index forwarded to ``model.predmap15_ablation_predmap``.

    Returns:
        (224, 224, 3) uint8 array in RGB channel order, ready for
        ``plt.imsave`` / ``plt.imshow``.
    """
    image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)
    predmap = model.predmap15_ablation_predmap(image_norm, cls_idx, layer=layer)  # (B, tokens-1)
    # One score per patch token, laid out on the 14x14 patch grid.
    xai_map = predmap.detach().cpu().float().reshape(1, 1, 14, 14)

    # Min-max normalise to [0, 1]. Clamping the range avoids 0/0 = NaN for a
    # constant map, which would otherwise be cast to garbage uint8 values below.
    vmin, vmax = xai_map.amin(), xai_map.amax()
    xai_map = (xai_map - vmin) / torch.clamp(vmax - vmin, min=1e-12)

    # Upsample the patch grid to pixel resolution (14 * 16 = 224).
    xai_map = torch.nn.functional.interpolate(xai_map, scale_factor=16, mode='bilinear')  # (1, 1, 224, 224)
    xai_map = xai_map.squeeze().numpy()  # (224, 224)

    # show_cam_on_image adds an OpenCV colour map, which is BGR, to the image.
    # Feed it the image in BGR order too so both terms agree, then convert the
    # combined result to RGB. Mixing the RGB image with the BGR map would leave
    # the photo's red/blue channels swapped in the saved figure.
    image_bgr = np.ascontiguousarray(image.permute(1, 2, 0).numpy()[..., ::-1])
    vis = show_cam_on_image(image_bgr, xai_map)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)


imagenet_dir = Path("~/ramdisk/datasets/imagenet/").expanduser()
imagenet_val_path = imagenet_dir.joinpath("val")

# (image path, [class indices to visualise]) pairs. Starts empty; the selected
# ImageNet validation samples are appended to it later in the script.
# Hand-picked examples that can be re-enabled by adding them to this list:
#   ('samples/catdog.png', [282, 243]),     # cat, dog
#   ('samples/el2.png', [101, 340]),        # tusker, zebra
#   ('samples/dogbird.png', [161, 87]),     # basset, parrot
#   (imagenet_val_path / 'n02099601/ILSVRC2012_val_00038781.JPEG', [207]),  # golden retriever
data = []

import torchvision.datasets

# ImageNet validation split. Only its (path, class_index) sample index is read
# here; the transform applies only if samples are fetched through the dataset.
imagenet_ds = torchvision.datasets.ImageNet(imagenet_dir, split='val', transform=transform_without_normalize)

# Positions in imagenet_ds.samples of the validation images to visualise.
imagenet_indices = [
    440, 4115, 4480, 8628, 11341, 13449, 16904, 18517,
    24681, 25569, 29858, 36293, 46451, 49031, 49165,
]

# Queue each selected sample as (path, [its labelled class index]).
data.extend(
    (Path(sample_path), [target])
    for sample_path, target in (imagenet_ds.samples[i] for i in imagenet_indices)
)


# Output folder for the raw inputs and per-layer prediction-map overlays;
# created (with any missing parents) if it does not exist yet.
predmap_per_layer_dir = base_dir.joinpath('figures', 'artifacts', 'supp', 'predmap_per_layer')
predmap_per_layer_dir.mkdir(parents=True, exist_ok=True)
predmap_per_layer_dir
# %%

# For every (image, class list) pair: save the preprocessed input once, then
# one prediction-map overlay per layer and class.
for image_path, clsidx_lst in data:
    print(image_path)
    img_name = Path(image_path).stem

    # The input does not depend on the layer, so load/preprocess/save it once.
    # convert("RGB") guards against greyscale, RGBA or CMYK files, which would
    # otherwise give 1- or 4-channel tensors that `normalize` and imsave reject.
    with Image.open(image_path) as pil_image:
        image = transform_without_normalize(pil_image.convert("RGB"))  # (3, 224, 224)
    plt.imsave(predmap_per_layer_dir / f'predmap_{img_name}.png', image.permute(1, 2, 0).numpy())

    # NOTE(review): 12 is presumably the model's number of transformer blocks;
    # confirm against the model definition.
    for layer in range(12):
        for cls_idx in clsidx_lst:
            vis = gen_visualization(image, cls_idx, layer)  # (224, 224, 3)
            # File names use 1-based layer numbers (l1 ... l12).
            path = predmap_per_layer_dir / f'predmap_{img_name}_c{cls_idx}_l{layer+1}.png'
            plt.imsave(path, vis)

# %%
