#%%
import time
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Adjust lovely_numpy's display configuration before anything is printed.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Activate lovely_tensors' patched tensor display and apply the same
# `deeper_width` setting as for lovely_numpy above.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Overwrites/defines `torch.inf` with Python's float infinity.
# NOTE(review): presumably a compatibility shim for torch builds that lack
# `torch.inf` -- confirm whether it is still needed.
torch.inf = float("Inf")
import sys
import os
# `base_dir` is the parent of the directory containing this file. It is put on
# `sys.path` so the `baselines.*` / `modules.*` imports below resolve, and made
# the working directory so relative paths such as 'samples/el2.png' work.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm

from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_LRP import vit_large_patch16_224 as vit_large_LRP
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from baselines.ViT.ViT_orig_LRP import vit_large_patch16_224 as vit_large_orig_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_new import vit_large_patch16_224 as vit_large_new

from baselines.ViT.ViT_explanation_generator import Baselines, LRP
import gc


from contextlib import contextmanager

@contextmanager
def measure_memory(title=""):
    """Print CUDA memory usage before and after the wrapped ``with`` block.

    On entry the Python garbage collector is run, the CUDA cache is emptied,
    the device is synchronized and the peak-memory counter is reset; the
    currently allocated amount is then printed as the baseline. On exit
    (also when the wrapped block raises) the current and peak allocated
    memory are printed, both as absolute values and relative to the baseline.
    All values are reported in MiB (bytes / 2**20).

    Args:
        title: Heading printed before the measurements.
    """
    print(title)
    # The baseline is taken *before* entering the try block: if any of these
    # calls fails, the error propagates as-is instead of being masked by an
    # UnboundLocalError raised from the reporting code in `finally`.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_memory_allocated = torch.cuda.memory_allocated()
    print(f"Before: {start_memory_allocated/2**20:.4f} MiB")
    try:
        yield
    finally:
        end_memory_allocated = torch.cuda.memory_allocated()
        end_max_memory_allocated = torch.cuda.max_memory_allocated()
        print(f"After:  {end_memory_allocated/2**20:.4f} MiB \t Max:({end_max_memory_allocated/2**20:.4f} MiB)")
        print(f"diff: {(end_memory_allocated-start_memory_allocated)/2**20:.4f} MiB \t Max:({(end_max_memory_allocated-start_memory_allocated)/2**20:.4f} MiB)")
        print("")

class Runner:
    """Runs ViT explanation methods and measures their peak CUDA memory.

    Three model variants (LRP, original LRP and "new") are loaded onto the
    GPU in eval mode, together with the explanation generators built on them.
    """

    # Per-attention-module attributes reset to None by every `clear_model_*`.
    _ATTN_CACHE_ATTRS = ("attn", "attn_cam", "v", "v_cam", "attn_gradients")

    def __init__(self, model_type):
        """Load the models for ``model_type`` and build the transforms.

        Args:
            model_type: ``"vit-b"`` or ``"vit-l"``.

        Raises:
            ValueError: If ``model_type`` is neither of the supported values.
        """
        if model_type == "vit-b":
            build_lrp, build_orig_lrp, build_new = vit_LRP, vit_orig_LRP, vit_new
        elif model_type == "vit-l":
            build_lrp, build_orig_lrp, build_new = (
                vit_large_LRP,
                vit_large_orig_LRP,
                vit_large_new,
            )
        else:
            raise ValueError(f"Invalid model type: {model_type}")

        self.model_LRP = build_lrp(pretrained=True).cuda()
        self.model_LRP.eval()
        self.model_orig_LRP = build_orig_lrp(pretrained=True).cuda()
        self.model_orig_LRP.eval()
        self.model_new = build_new(pretrained=True).cuda()
        self.model_new.eval()

        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.transform_without_normalize = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        self.transform = transforms.Compose([
            self.transform_without_normalize,
            self.normalize,
        ])
        self.attribution_generator = LRP(self.model_LRP)
        self.orig_lrp = LRP(self.model_orig_LRP)
        self.baselines = Baselines(self.model_new)

    def _method_table(self, cls_idx):
        """Return ``{method_name: (clear_fn, compute_fn)}`` for ``cls_idx``.

        ``clear_fn`` resets the model that the method runs on; ``compute_fn``
        takes the batched CUDA image and returns the explanation map. Model
        and generator attributes are only looked up when a ``compute_fn`` is
        actually called.
        """
        def predmap15(img):
            # The fast predmap15 variant is the only method wrapped in no_grad.
            with torch.no_grad():
                return self.model_new.predmap15_softmax_classes_batched_layer(img, cls_idx)

        return {
            "predmap15": (self.clear_model_new, predmap15),
            "predmap15_slow": (
                self.clear_model_LRP,
                lambda img: self.model_LRP.predmap15_softmax_classes_batched_layer(img, cls_idx),
            ),
            "transformer_attribution": (
                self.clear_model_LRP,
                lambda img: self.attribution_generator.generate_LRP(
                    img, start_layer=0, method="transformer_attribution", index=cls_idx
                ),
            ),
            "rollout": (
                self.clear_model_new,
                lambda img: self.baselines.generate_rollout(img, start_layer=1),
            ),
            "last_layer_attn": (
                self.clear_model_new,
                lambda img: self.baselines.generate_raw_attn(img),
            ),
            "attn_gradcam": (
                self.clear_model_new,
                lambda img: self.baselines.generate_cam_attn(img, index=cls_idx),
            ),
            "lrp": (
                self.clear_model_LRP_orig,
                lambda img: self.orig_lrp.generate_LRP(img, method="full"),
            ),
            "lrp_last_layer": (
                self.clear_model_LRP_orig,
                lambda img: self.orig_lrp.generate_LRP(img, method="last_layer"),
            ),
        }

    def run(self, method_name, image, cls_idx):
        """Run one explanation method and measure its peak CUDA memory.

        Args:
            method_name: One of the keys of the method table, e.g. ``"lrp"``,
                ``"rollout"`` or ``"predmap15"``.
            image: Image tensor of shape (C, H, W); a batch dimension is
                added and the tensor is moved to the GPU.
            cls_idx: Target class index passed to the class-specific methods.

        Returns:
            ``(xai_map_cpu, mem_diff)``: the detached explanation map on the
            CPU, and the peak allocated CUDA memory during the method minus
            the memory allocated right before it started (in bytes).

        Raises:
            ValueError: If ``method_name`` is not a known method.
        """
        methods = self._method_table(cls_idx)
        # Validate first so a typo fails fast with a clear message instead of
        # an UnboundLocalError after the GPU work has already started.
        if method_name not in methods:
            raise ValueError(f"Invalid method name: {method_name}")
        clear_model, compute = methods[method_name]

        image = image.unsqueeze(0).cuda()  # (B, C, H, W)
        clear_model()

        # Start from a clean allocator state so the peak reflects this run only.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start_memory_allocated = torch.cuda.memory_allocated()
        xai_map = compute(image)
        end_max_memory_allocated = torch.cuda.max_memory_allocated()
        mem_diff = end_max_memory_allocated - start_memory_allocated

        xai_map_cpu = xai_map.detach().cpu()
        # Drop the GPU result before returning so it does not stay allocated.
        del xai_map
        gc.collect()

        return xai_map_cpu, mem_diff

    def clear_relprop(self, obj):
        """Reset the cached tensors of every RelProp module inside ``obj``.

        ``obj.modules()`` already yields ``obj`` and all of its descendants,
        so a single pass is enough; recursing into each child would revisit
        deeply nested modules many times.
        """
        import modules.layers_ours
        import modules.layers_lrp

        relprop_types = (modules.layers_ours.RelProp, modules.layers_lrp.RelProp)
        for module in obj.modules():
            if isinstance(module, relprop_types):
                module.X = None
                module.Y = None
                module.grad_input = None
                module.grad_output = None

    def _reset_block_caches(self, model, block_attrs=(), extra_attn_attrs=()):
        """Set cached attributes on every block of ``model`` to None.

        ``block_attrs`` are reset on the block itself; ``extra_attn_attrs``
        plus the shared ``_ATTN_CACHE_ATTRS`` are reset on ``block.attn``.
        """
        for block in model.blocks:
            for name in block_attrs:
                setattr(block, name, None)
            for name in (*extra_attn_attrs, *self._ATTN_CACHE_ATTRS):
                setattr(block.attn, name, None)

    @staticmethod
    def _release_memory():
        """Return cached CUDA memory to the driver and run the Python GC."""
        torch.cuda.empty_cache()
        gc.collect()

    def clear_model_new(self):
        """Drop gradients and cached attention tensors of ``model_new``."""
        model = self.model_new
        model.zero_grad(set_to_none=True)
        self._reset_block_caches(model, extra_attn_attrs=("attention_map",))
        self._release_memory()

    def clear_model_LRP(self):
        """Drop gradients, RelProp caches and block caches of ``model_LRP``."""
        model = self.model_LRP
        model.zero_grad(set_to_none=True)
        self.clear_relprop(model)
        self._reset_block_caches(model, block_attrs=("x", "mhsa"))
        self._release_memory()

    def clear_model_LRP_orig(self):
        """Drop gradients, RelProp caches and block caches of ``model_orig_LRP``."""
        model = self.model_orig_LRP
        model.zero_grad(set_to_none=True)
        self.clear_relprop(model)
        model.inp_grad = None
        self._reset_block_caches(model)
        self._release_memory()
    


# Target class indices passed to the explanation methods.
# NOTE(review): presumably ImageNet-1k label indices matching the variable
# names -- verify against the label map used by the models.
cls_idx_tusker = 101
cls_idx_zebra = 340

# Model selection toggle: the later assignment wins, so "vit-l" is what runs.
# Comment out the second line to benchmark "vit-b" instead.
model_type = "vit-b"
model_type = "vit-l"

# Loads all three model variants onto the GPU (see Runner.__init__).
runner = Runner(model_type)
# Relative path; resolves against `base_dir` because of the os.chdir above.
# NOTE(review): the normalize step uses 3-channel statistics -- confirm this
# PNG has no alpha channel.
image_PIL = Image.open('samples/el2.png')
image = runner.transform(image_PIL) # (C, H, W)


# %%


# Explanation methods to benchmark, in execution order.
method_name_lst = [
    "lrp",
    "lrp_last_layer",
    "transformer_attribution", 
    "attn_gradcam",
    "rollout",
    "predmap15",
    # "predmap15_slow",
    "last_layer_attn",
]


# Each method is run `repeats` times; the first explanation map of every
# method is kept in `method2xai_map` for the plotting cell below.
repeats = 3
method2xai_map = {}
print(f"Repeats: {repeats}")
for method_name in method_name_lst:
    print(f"Method: {method_name}")
    mem_diff_lst = []
    xai_map_lst = []
    for _ in range(repeats):
        # Pass cls_idx_zebra instead to explain the other class.
        xai_map, mem_diff = runner.run(method_name, image, cls_idx_tusker)
        mem_diff_lst.append(mem_diff)
        # runner.run already returns a detached CPU tensor.
        xai_map_lst.append(xai_map)

    # float64 on purpose: byte counts above 2**24 (16 MiB) are not exactly
    # representable in float32, which would distort the statistics and the
    # 4-decimal MiB printout.
    mem_diff_pt = torch.tensor(mem_diff_lst, dtype=torch.float64)
    if not all(torch.allclose(xai_map_lst[0], x) for x in xai_map_lst):
        print("WARNING: XAI maps are not equal between repeats")

    print(f"first samples:\t", ", ".join(f"{v/2**20:.4f} MiB" for v in mem_diff_pt[:5]))
    print(f"Mean:\t {mem_diff_pt.mean() / 2**20:.4f} MiB")
    print(f"Std:\t {mem_diff_pt.std() / 2**20:.4f} MiB")
    print()
    method2xai_map[method_name] = xai_map_lst[0]
# %%
# Verify raw attention maps are equal
# fast_raw_attn, fast_exec_time = runner.run("last_layer_attn", image, cls_idx_tusker)
# start_time = time.time()
# slow_raw_attn = runner.orig_lrp.generate_LRP(image.unsqueeze(0).cuda(), method="last_layer_attn")
# end_time = time.time()
# slow_exec_time = end_time - start_time
# torch.allclose(fast_raw_attn, slow_raw_attn)
# print(f"Fast execution time: {fast_exec_time:.4f} seconds")
# print(f"Slow execution time: {slow_exec_time:.4f} seconds")
# %%

# fast_predmap15, fast_exec_time = runner.run("predmap15", image, cls_idx_tusker)
# slow_predmap15, slow_exec_time = runner.run("predmap15_slow", image, cls_idx_tusker)
# torch.allclose(slow_predmap15,fast_predmap15)
# %%

# %%
import math

# Lay the explanation maps out on a grid with a fixed number of columns.
ncols = 3
nrows = math.ceil(len(method2xai_map) / ncols)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols)

# Hide every axis up front: the grid usually has more cells than methods,
# and the leftover cells would otherwise render as empty framed plots.
for ax in axes.flat:
    ax.axis("off")

for ax, (method_name, xai_map_raw) in zip(axes.flat, method2xai_map.items()):
    # "lrp" is reshaped to the 224x224 input resolution; every other method
    # is reshaped to a 14x14 grid.
    if method_name == "lrp":
        xai_map = xai_map_raw.reshape(224, 224)
    else:
        xai_map = xai_map_raw.reshape(14, 14)

    ax.imshow(xai_map.cpu().numpy(), cmap="jet")
    ax.set_title(method_name)
# %%
# Bare expression as the last statement of the cell: when run as a notebook /
# interactive cell this displays the input image; as a plain script it is a no-op.
image_PIL
# %%
