#%%
"""Benchmark and visualise ViT explainability (XAI) methods.

Notebook-style script (``# %%`` cells): builds a ``Runner`` holding three ViT
variants, times each XAI method over several repeats on one sample image,
and plots the resulting maps in a grid.
"""
import time
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Pretty-printing configuration for lovely_numpy summaries.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Patch torch.Tensor's repr to print lovely_tensors' compact summaries.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): overrides torch.inf with a plain Python float; presumably a
# shim for code that expects this attribute — confirm it is still needed.
torch.inf = float("Inf")
import sys
import os
# Make the parent of this file's directory both importable and the working
# directory, so the `baselines.*` imports below and relative paths such as
# 'samples/el2.png' resolve regardless of where the script is launched.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm

# Three implementations of each backbone: LRP-instrumented, original LRP,
# and the "new" implementation that provides the fast predmap15 path.
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_LRP import vit_large_patch16_224 as vit_large_LRP
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from baselines.ViT.ViT_orig_LRP import vit_large_patch16_224 as vit_large_orig_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_new import vit_large_patch16_224 as vit_large_new

from baselines.ViT.ViT_explanation_generator import Baselines, LRP

class Runner:
    """Hold the ViT model variants and run/time XAI methods on them.

    Three copies of the selected backbone are loaded on the GPU in eval mode:
    ``model_LRP`` (wrapped by ``attribution_generator``), ``model_orig_LRP``
    (wrapped by ``orig_lrp``) and ``model_new`` (wrapped by ``baselines``).

    Args:
        model_type: ``"vit-b"`` for the base patch16/224 models or
            ``"vit-l"`` for the large patch16/224 models.

    Raises:
        ValueError: if ``model_type`` is not one of the supported keys.
    """

    def __init__(self, model_type):
        # Per backbone: (LRP model, original-LRP model, "new" model) factories.
        model_factories = {
            "vit-b": (vit_LRP, vit_orig_LRP, vit_new),
            "vit-l": (vit_large_LRP, vit_large_orig_LRP, vit_large_new),
        }
        if model_type not in model_factories:
            raise ValueError(f"Invalid model type: {model_type}")
        make_lrp, make_orig_lrp, make_new = model_factories[model_type]

        self.model_LRP = self._load(make_lrp)
        self.model_orig_LRP = self._load(make_orig_lrp)
        self.model_new = self._load(make_new)

        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.transform_without_normalize = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        self.transform = transforms.Compose([
            self.transform_without_normalize,
            self.normalize,
        ])
        self.attribution_generator = LRP(self.model_LRP)
        self.orig_lrp = LRP(self.model_orig_LRP)
        self.baselines = Baselines(self.model_new)

    @staticmethod
    def _load(factory):
        """Build a pretrained model, move it to the GPU and switch to eval mode."""
        model = factory(pretrained=True).cuda()
        model.eval()
        return model

    @staticmethod
    def _cuda_sync():
        """Block until all queued CUDA work has finished (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def run(self, method_name, image, cls_idx):
        """Compute one XAI map for ``image`` and measure how long it takes.

        Args:
            method_name: one of the keys of the dispatch table below, e.g.
                ``"lrp"``, ``"rollout"``, ``"predmap15"``.
            image: a single image tensor of shape (C, H, W); a batch dimension
                is added and it is moved to the GPU here.
            cls_idx: target class index, forwarded to the class-specific
                methods.

        Returns:
            ``(xai_map, exec_time)``: the detached map returned by the method
            and the duration of the method call in seconds.

        Raises:
            ValueError: if ``method_name`` is not a known method.
        """

        def predmap15_fast(batch):
            # Fast path: run without autograd tracking.
            with torch.no_grad():
                return self.model_new.predmap15_softmax_classes_batched_layer(batch, cls_idx)

        dispatch = {
            "predmap15": predmap15_fast,
            "predmap15_slow": lambda batch: self.model_LRP.predmap15_softmax_classes_batched_layer(
                batch, cls_idx
            ),
            "transformer_attribution": lambda batch: self.attribution_generator.generate_LRP(
                batch, start_layer=0, method="transformer_attribution", index=cls_idx
            ),
            "rollout": lambda batch: self.baselines.generate_rollout(batch, start_layer=1),
            "last_layer_attn": lambda batch: self.baselines.generate_raw_attn(batch),
            "attn_gradcam": lambda batch: self.baselines.generate_cam_attn(batch, index=cls_idx),
            # NOTE(review): "rollout", "last_layer_attn", "lrp" and
            # "lrp_last_layer" are not given cls_idx, so they presumably
            # explain the model's own prediction — confirm this is intended.
            "lrp": lambda batch: self.orig_lrp.generate_LRP(batch, method="full"),
            "lrp_last_layer": lambda batch: self.orig_lrp.generate_LRP(batch, method="last_layer"),
        }
        try:
            method_func = dispatch[method_name]
        except KeyError:
            raise ValueError(
                f"Invalid method name: {method_name!r}; expected one of {sorted(dispatch)}"
            ) from None

        batch = image.unsqueeze(0).cuda()  # (1, C, H, W)

        # CUDA kernels execute asynchronously: synchronize before and after so
        # the timer covers the finished computation only (not the pending
        # host-to-device copy, and not just the kernel launches).
        self._cuda_sync()
        start_time = time.perf_counter()
        xai_map = method_func(batch)
        self._cuda_sync()
        exec_time = time.perf_counter() - start_time

        return xai_map.detach(), exec_time
    


# Target class indices (presumably ImageNet-1k labels — confirm).
cls_idx_tusker = 101
cls_idx_zebra = 340

# Backbone to benchmark: "vit-b" (base) or "vit-l" (large).
model_type = "vit-l"

runner = Runner(model_type)
image_PIL = Image.open('samples/el2.png')
# PNG files may carry an alpha or palette channel; the 3-channel Normalize in
# runner.transform needs RGB input. image_PIL itself is kept untouched for display.
image = runner.transform(image_PIL.convert("RGB"))  # (C, H, W)


# XAI methods to benchmark, in display order.
method_name_lst = [
    "lrp",
    "lrp_last_layer",
    "transformer_attribution",
    "attn_gradcam",
    "rollout",
    "predmap15",
    # "predmap15_slow",
    "last_layer_attn",
]
# method_name = "transformer_attribution"
# method_name = "attn_gradcam"
# method_name = "lrp"
# method_name = "rollout"
# method_name = "predmap15"
# method_name = "last_layer_attn"


# Time every method `repeats` times and keep the first map per method for plotting.
# NOTE(review): the first call of each method includes warm-up cost, which is
# included in Mean/Std; inspect "first samples" or add a warm-up if that matters.
repeats = 10
method2xai_map = {}
print(f"Repeats: {repeats}")
for method_name in method_name_lst:
    print(f"Method: {method_name}")
    exec_time_lst = []
    xai_map_lst = []
    for _ in range(repeats):
        # runner.run already returns a detached map.
        xai_map, exec_time = runner.run(method_name, image, cls_idx_tusker)
        # xai_map, exec_time = runner.run(method_name, image, cls_idx_zebra)
        exec_time_lst.append(exec_time)
        xai_map_lst.append(xai_map)
    exec_time_pt = torch.tensor(exec_time_lst)
    # Repeats are expected to be deterministic; flag any map that differs from the first.
    reference_map = xai_map_lst[0]
    if not all(torch.allclose(reference_map, x) for x in xai_map_lst[1:]):
        print("WARNING: XAI maps are not equal between repeats")

    print("first samples:\t", ", ".join(f"{v*1000:.4} ms" for v in exec_time_pt[:5]))
    print(f"Mean:\t {exec_time_pt.mean()*1000:.4f} ms")
    print(f"Std:\t {exec_time_pt.std()*1000:.4f} ms")
    print()
    method2xai_map[method_name] = reference_map
# %%
# Verify raw attention maps are equal
# fast_raw_attn, fast_exec_time = runner.run("last_layer_attn", image, cls_idx_tusker)
# start_time = time.time()
# slow_raw_attn = runner.orig_lrp.generate_LRP(image.unsqueeze(0).cuda(), method="last_layer_attn")
# end_time = time.time()
# slow_exec_time = end_time - start_time
# torch.allclose(fast_raw_attn, slow_raw_attn)
# print(f"Fast execution time: {fast_exec_time:.4f} seconds")
# print(f"Slow execution time: {slow_exec_time:.4f} seconds")
# %%

# fast_predmap15, fast_exec_time = runner.run("predmap15", image, cls_idx_tusker)
# slow_predmap15, slow_exec_time = runner.run("predmap15_slow", image, cls_idx_tusker)
# torch.allclose(slow_predmap15,fast_predmap15)
# %%

# %%
import math

# Plot every collected XAI map in a grid with `ncols` columns.
ncols = 3
nrows = max(1, math.ceil(len(method2xai_map) / ncols))
# squeeze=False keeps `axes` a 2-D array even for a single row/cell, so
# iterating over it always works.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
axes_flat = axes.ravel()
for ax, (method_name, xai_map_raw) in zip(axes_flat, method2xai_map.items()):
    # Maps come either per pixel (e.g. "lrp": 224*224 values) or per patch
    # (14*14 values); infer the square side from the element count.
    n_values = xai_map_raw.numel()
    side = math.isqrt(n_values)
    if side * side != n_values:
        raise ValueError(f"{method_name}: cannot reshape {n_values} values into a square map")
    xai_map = xai_map_raw.reshape(side, side)

    ax.imshow(xai_map.cpu().numpy(), cmap="jet")
    ax.set_title(method_name)
    ax.axis("off")
# Hide grid cells left over when the number of maps is not a multiple of ncols.
for ax in axes_flat[len(method2xai_map):]:
    ax.axis("off")
# %%
# Display the original, untransformed input image (only renders as the last
# expression of an interactive cell).
image_PIL
# %%
