# %%
# Script: render class-specific explanation heatmaps for several ViT
# explainability methods on a fixed set of sample images and save them as PNGs
# under <repo>/figures/artifacts/class_specific_visualization/.
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Debug-printing helpers (lovely_numpy / lovely_tensors) used for interactive
# inspection. NOTE(review): presumably display-only; confirm nothing below
# depends on them before removing.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): presumably a shim for torch builds that lacked `torch.inf`;
# recent torch already defines it with this value — verify whether still needed.
torch.inf = float("Inf")
import sys
import os
# Repository root = two directories above this file. It is put on sys.path and
# made the working directory so that the project-local imports below
# (`samples.*`, `baselines.*`) and the relative 'samples/...' image paths resolve.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%

from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_explanation_generator import Baselines, LRP
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
import types

# Input normalisation for the ViT checkpoints loaded below: per-channel
# mean = std = 0.5, which maps ToTensor's [0, 1] range onto [-1, 1].
# (The ImageNet statistics mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
# are intentionally not used here.)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Geometry-only preprocessing: resize the shorter side to 256 px, take the
# central 224x224 crop, and convert to a float (3, 224, 224) tensor in [0, 1].
# Kept separate from `normalize` so the un-normalised image can be reused as
# the background of the visualisations.
transform_without_normalize = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Full model-input preprocessing: geometry first, then normalisation.
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# Two pretrained ViT-B/16 copies, moved to the GPU and switched to eval mode
# (nn.Module.eval() returns the module itself, so the calls can be chained):
#  * model_LRP (ViT_LRP variant) drives the LRP-based explainers via `lrp`;
#  * model_new (ViT_new variant) drives the baseline explainers via `baselines`
#    and is also called directly for predictions / PredicAtt maps.
model_LRP = vit_LRP(pretrained=True).cuda().eval()
lrp = LRP(model_LRP)

model_new = vit_new(pretrained=True).cuda().eval()
baselines = Baselines(model_new)


# %%

# create heatmap from mask on image
def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + 1*np.float32(img)
    cam = cam / np.max(cam)
    return cam


# %%

def gen_visualization(image, xai_map):
    # xai_map  # (1, 1, 14, 14)
    xai_map = (xai_map - xai_map.amin())/(xai_map.amax() - xai_map.amin())
    xai_map = torch.nn.functional.interpolate(xai_map, scale_factor=16, mode='bilinear') # (1, 1, 224, 224)
    xai_map = xai_map.squeeze().detach().cpu().numpy() # (224, 224)
    
    vis = show_cam_on_image(image.permute(1, 2, 0).numpy(), xai_map)
    vis =  np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis


data = [
    ('samples/catdog.png',
     [
        282, #cat
        243, #dog
     ]),
     ('samples/el2.png',
      [
        101, #tusker
        340, #zebra
      ]),
      # SUPPLEMENTARY
    ('samples/dogbird.png',
     [
        163, #basset
        87, #parrot
     ]),
    ('samples/dogcat2.png',
     [
        207, # golden retriever, 207, 
        285, # Egyptian cat, 285,
     ]),
     ('samples/el1.png',
      [
        386, # African elephant
        340, #zebra
      ]),
     ('samples/el3.png',
      [
        386, # African elephant
        340, # zebra
      ]),
     ('samples/el4.png',
      [
        386, # African elephant
        340, # zebra
      ]),
     ('samples/el5.png',
      [
        386, # African elephant
        340, # zebra
      ]),
 ]

method_name2func={
    f"PredicAtt_11": lambda image_norm, cls_idx: (model_new.predmap15_softmax_classes_batched_layer(image_norm, cls_idx, layer=11-1)
                                              .detach().cpu().reshape(1, 1, 14,14)),
    f"PredicAtt_12": lambda image_norm, cls_idx: (model_new.predmap15_softmax_classes_batched_layer(image_norm, cls_idx, layer=12-1)
                                              .detach().cpu().reshape(1, 1, 14,14)),
    "TransAttr": lambda image_norm, cls_idx: (lrp.generate_LRP(image_norm, cls_idx, method="transformer_attribution")
                                              .detach().cpu().reshape(1, 1, 14,14)),
    "Partial-LRP": lambda image_norm, cls_idx: (lrp.generate_LRP(image_norm, cls_idx, method="last_layer")
                                              .detach().cpu().reshape(1, 1, 14,14)),
    "GradCAM": lambda image_norm, cls_idx: (baselines.generate_cam_attn(image_norm, cls_idx)
                                              .detach().cpu().reshape(1, 1, 14,14)),
    "Raw-Attention": lambda image_norm, cls_idx: (lrp.generate_LRP(image_norm, cls_idx, method="last_layer_attn")
                                              .detach().cpu().reshape(1, 1, 14,14)),
    "Rollout": lambda image_norm, cls_idx: (baselines.generate_rollout(image_norm, start_layer=1)
                                              .detach().cpu().reshape(1, 1, 14,14)),
}

output_dir: Path = base_dir / f"figures/artifacts/class_specific_visualization/"
output_dir.mkdir(exist_ok=True, parents=True)
plt.ioff()
# plt.ion()
plt.close("all")
for image_path, clsidx_lst in data:
    print(image_path)
    img_name = Path(image_path).stem
    image = Image.open(image_path)
    image = transform_without_normalize(image) # (3, 224, 224)
    path = output_dir / f'{img_name}_transformed.png'
    print(path)
    plt.imsave(path, image.permute(1, 2, 0).numpy())
    image_norm = normalize(image).unsqueeze(0).cuda()
    ################################
    model_new(image_norm)
    cls_pred_vec = model_new(image_norm)
    topk = torch.topk(cls_pred_vec, 5)
    print("Top predictions:")
    for ci in topk.indices.squeeze():
        print(f"\t {cls_pred_vec[0, ci].item(): >6.2f} \t {CLS2IDX[ci.item()]}, {ci.item()}, ")
    ################################
    for method_name, method_func in method_name2func.items():
        for cls_idx in clsidx_lst:
            xai_map = method_func(image_norm, cls_idx)

            vis = gen_visualization(image, xai_map)
            cls_name = CLS2IDX[cls_idx]
            plt.figure()
            plt.imshow(vis)
            plt.title(f'{img_name}_{cls_name}_{method_name}')
            # plt.imsave(f'{img_name}_{cls_name}_{method_name}')
            path = output_dir / f'{img_name}_{method_name}_{cls_name}.png'
            plt.imsave(path, vis)
        # break
    # break


# %%

# %%
