#%%
import time
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
lt.monkey_patch()
lt.set_config(deeper_width=12)
torch.inf = float("Inf")
import sys
import os
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm

from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new

from baselines.ViT.ViT_explanation_generator import Baselines, LRP

import gc

from contextlib import contextmanager

@contextmanager
def measure_memory(title=""):
    """Print CUDA memory usage before and after the wrapped ``with`` body.

    On entry the garbage collector is run, the CUDA cache is emptied, the
    device is synchronized and the peak-memory statistics are reset; the
    currently allocated amount is then printed as the baseline.  On exit
    (also when the body raises) the allocated and peak-allocated amounts are
    printed, both as absolute values and as differences to the baseline.

    Args:
        title: Label printed before the measurements.
    """
    print(title)
    # The baseline is taken *before* entering the ``try`` block.  If any of
    # these calls fails (e.g. CUDA is unavailable) the original error
    # propagates instead of being masked by an UnboundLocalError raised from
    # the ``finally`` block referencing a baseline that was never set.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_memory_allocated = torch.cuda.memory_allocated()
    print(f"Before: {start_memory_allocated/2**20:.4f} MiB")
    try:
        yield
    finally:
        end_memory_allocated = torch.cuda.memory_allocated()
        end_max_memory_allocated = torch.cuda.max_memory_allocated()
        print(f"After:  {end_memory_allocated/2**20:.4f} MiB \t Max:({end_max_memory_allocated/2**20:.4f} MiB)")
        print(f"diff: {(end_memory_allocated-start_memory_allocated)/2**20:.4f} MiB \t Max:({(end_max_memory_allocated-start_memory_allocated)/2**20:.4f} MiB)")
        print("")
class Runner:
    """Run ViT explanation methods on a single image and measure GPU memory.

    Three pretrained ``vit_base_patch16_224`` variants are moved to the GPU
    and put in eval mode:

    * ``model_LRP``      -- wrapped by ``attribution_generator`` (``LRP``),
    * ``model_orig_LRP`` -- wrapped by ``orig_lrp`` (``LRP``),
    * ``model_new``      -- wrapped by ``baselines`` (``Baselines``).
    """

    # Method names accepted by ``run``.
    METHOD_NAMES = (
        "predmap15",
        "predmap15_slow",
        "transformer_attribution",
        "rollout",
        "last_layer_attn",
        "attn_gradcam",
        "lrp",
        "lrp_last_layer",
    )

    # Attributes reset on every block's attention layer by the
    # ``clear_model_*`` helpers (presumably cached activations/gradients).
    _ATTN_CACHE_ATTRS = ("attn", "attn_cam", "v", "v_cam", "attn_gradients")

    def __init__(self):
        self.model_LRP = vit_LRP(pretrained=True).cuda()
        self.model_LRP.eval()
        self.model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
        self.model_orig_LRP.eval()

        # Report how much GPU memory loading ``model_new`` takes.
        start_memory_allocated = torch.cuda.memory_allocated()
        self.model_new = vit_new(pretrained=True).cuda()
        end_memory_allocated = torch.cuda.memory_allocated()
        print(f"Memory allocated for model_new: {(end_memory_allocated-start_memory_allocated)/2**20:.4f} MiB")
        self.model_new.eval()

        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.transform_without_normalize = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        self.transform = transforms.Compose([
            self.transform_without_normalize,
            self.normalize,
        ])
        self.attribution_generator = LRP(self.model_LRP)
        self.orig_lrp = LRP(self.model_orig_LRP)
        self.baselines = Baselines(self.model_new)

    def run(self, method_name, image, cls_idx):
        """Run one explanation method and report its peak GPU memory.

        Args:
            method_name: One of ``METHOD_NAMES``.
            image: Image tensor of shape (C, H, W); a batch dimension is
                added and the tensor is moved to the GPU.
            cls_idx: Target class index.  Only forwarded by the methods that
                take one ("predmap15", "predmap15_slow",
                "transformer_attribution", "attn_gradcam").

        Returns:
            Tuple ``(xai_map_cpu, mem_diff)``: the detached explanation map
            on the CPU, and ``torch.cuda.max_memory_allocated()`` after the
            call minus ``torch.cuda.memory_allocated()`` before it (bytes).

        Raises:
            ValueError: If ``method_name`` is not a known method.
        """
        # Validate first: an unknown name used to fail much later with an
        # UnboundLocalError, after the image had already been moved to GPU.
        if method_name not in self.METHOD_NAMES:
            raise ValueError(
                f"Unknown method_name {method_name!r}; expected one of {self.METHOD_NAMES}"
            )

        # image (C, H, W)
        image = image.unsqueeze(0).cuda()  # (B, C, H, W)
        method_func = self._build_method_func(method_name, image, cls_idx)

        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()
        start_memory_allocated = torch.cuda.memory_allocated()
        with measure_memory("METHOD_FUNC"):
            xai_map = method_func()
        end_max_memory_allocated = torch.cuda.max_memory_allocated()
        mem_diff = end_max_memory_allocated - start_memory_allocated

        xai_map_cpu = xai_map.detach().cpu()
        # Drop the GPU reference before returning so the collector can
        # reclaim it.
        del xai_map
        gc.collect()

        return xai_map_cpu, mem_diff

    def _build_method_func(self, method_name, image, cls_idx):
        """Return a zero-argument callable that computes the explanation map.

        Some branches first clear state cached on the corresponding model so
        that leftovers of a previous run are not counted in the measurement.
        """
        if method_name == "predmap15":
            def predmap15_no_grad():
                with torch.no_grad():
                    return self.model_new.predmap15_softmax_classes_batched_layer(image, cls_idx)
            return predmap15_no_grad
        if method_name == "predmap15_slow":
            return lambda: self.model_LRP.predmap15_softmax_classes_batched_layer(image, cls_idx)
        if method_name == "transformer_attribution":
            with measure_memory("clear_model_LRP"):
                self.clear_model_LRP()
            return lambda: self.attribution_generator.generate_LRP(
                image, start_layer=0, method="transformer_attribution", index=cls_idx)
        if method_name == "rollout":
            return lambda: self.baselines.generate_rollout(image, start_layer=1)
        if method_name == "last_layer_attn":
            return lambda: self.baselines.generate_raw_attn(image)
        if method_name == "attn_gradcam":
            self.clear_model_new()
            return lambda: generate_cam_attn(self.model_new, image, index=cls_idx)
        if method_name == "lrp":
            with measure_memory("clear_model_LRP_orig"):
                self.clear_model_LRP_orig()
            return lambda: self.orig_lrp.generate_LRP(image, method="full")
        if method_name == "lrp_last_layer":
            return lambda: self.orig_lrp.generate_LRP(image, method="last_layer")
        raise ValueError(f"Unknown method_name {method_name!r}")

    def clear_relprop(self, obj):
        """Reset ``X``, ``Y``, ``grad_input`` and ``grad_output`` on every
        ``RelProp`` module found in ``obj`` (including ``obj`` itself)."""
        import modules.layers_ours
        import modules.layers_lrp

        relprop_types = (modules.layers_ours.RelProp, modules.layers_lrp.RelProp)
        # ``modules()`` already yields ``obj`` and all of its descendants, so
        # one flat pass visits everything; recursing into each entry again
        # would only revisit the same modules many times.
        for module in obj.modules():
            if isinstance(module, relprop_types):
                module.X = None
                module.Y = None
                module.grad_input = None
                module.grad_output = None

    @classmethod
    def _clear_attn_cache(cls, attn, extra_attrs=()):
        """Set the cached attributes of one attention layer to ``None``."""
        for attr_name in (*cls._ATTN_CACHE_ATTRS, *extra_attrs):
            setattr(attn, attr_name, None)

    def clear_model_LRP(self):
        """Drop gradients and cached per-block state held by ``model_LRP``."""
        model = self.model_LRP
        model.zero_grad(set_to_none=True)
        print("clear")
        self.clear_relprop(model)
        for b in model.blocks:
            b.x = None
            b.mhsa = None
            self._clear_attn_cache(b.attn)
        torch.cuda.empty_cache()
        gc.collect()

    def clear_model_LRP_orig(self):
        """Drop gradients and cached per-block state held by ``model_orig_LRP``."""
        model = self.model_orig_LRP
        model.zero_grad(set_to_none=True)
        model.inp_grad = None
        self.clear_relprop(model)
        for b in model.blocks:
            self._clear_attn_cache(b.attn)
        torch.cuda.empty_cache()
        gc.collect()

    def clear_model_new(self):
        """Drop gradients and cached attention state held by ``model_new``."""
        model = self.model_new
        model.zero_grad(set_to_none=True)
        for b in model.blocks:
            self._clear_attn_cache(b.attn, extra_attrs=("attention_map",))
        torch.cuda.empty_cache()
        gc.collect()


# Target class indices passed to the explanation methods
# (presumably ImageNet-1k labels matching the names -- TODO confirm).
cls_idx_tusker = 101
cls_idx_zebra = 340

# Constructing the Runner loads all three pretrained models onto the GPU.
runner = Runner()
# Relative path: resolved against base_dir, which the os.chdir call at the
# top of the file makes the working directory.
image_PIL = Image.open('samples/el2.png')
image = runner.transform(image_PIL) # (C, H, W)


# %%
def orig_lrp_generate_LRP(model, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
    """Run a forward/backward/relprop pass on ``model`` for memory profiling.

    Args:
        model: Module that is callable on ``input`` and exposes
            ``relprop(cam, method=..., is_ablation=..., start_layer=..., **kwargs)``.
        input: Batched input tensor; a batch size of 1 is assumed because the
            one-hot target has a single row.
        index: Target class index.  When ``None`` the arg-max of the model
            output is used.
        method: Forwarded to ``model.relprop``.
        is_ablation: Forwarded to ``model.relprop``.
        start_layer: Forwarded to ``model.relprop``.
        **kwargs: Extra keyword arguments forwarded to ``model.relprop``;
            ``alpha`` defaults to 1 when not given.

    Returns:
        A ``(14, 14)`` CPU tensor of zeros (see the NOTE below).
    """
    output = model(input)
    # Honour caller-supplied keyword arguments instead of overwriting them.
    kwargs.setdefault("alpha", 1)
    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot_vector = one_hot
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    # Follow the device of the model output rather than hard-coding CUDA.
    one_hot = torch.sum(one_hot.to(output.device) * output)

    model.zero_grad(set_to_none=True)
    one_hot.backward(retain_graph=True)

    res = model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation,
                        start_layer=start_layer, **kwargs)
    # NOTE(review): the relevance computed above is discarded and a constant
    # placeholder is returned instead.  This looks like deliberate profiling
    # scaffolding (only the memory footprint of the pass matters here) --
    # confirm before relying on the return value.
    res = torch.zeros(14, 14)
    res_cpu = res.detach().cpu()
    # Drop local references before collecting so the memory is reclaimable.
    del res, one_hot, output
    gc.collect()

    return res_cpu


def generate_cam_attn(model, input, index=None):
    """Compute a gradient-weighted attention map from the last block.

    Runs a forward pass with ``register_hook=True``, back-propagates the
    score of the target class, and combines the attention map and attention
    gradients of ``model.blocks[-1].attn``.  Memory figures for the forward
    and backward pass are printed along the way.

    Args:
        model: Model called as ``model(x, register_hook=True)`` whose last
            block exposes ``attn.get_attn_gradients()`` and
            ``attn.get_attention_map()``.
        input: Batched input tensor; moved to the GPU before the forward pass.
        index: Target class index.  When ``None`` the arg-max of the model
            output is used.

    Returns:
        A ``(14, 14)`` CPU tensor, min-max normalised.
    """
    print("A")

    with measure_memory("generate_cam_attn-FORWARD"):
        output = model(input.cuda(), register_hook=True)  # (batch, classes)
    if index is None:
        index = np.argmax(output.cpu().data.numpy())

    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)  # (1, classes)
    one_hot[0][index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.cuda() * output)  # scalar

    model.zero_grad(set_to_none=True)
    # NOTE(review): the peak counter is not reset here, so the "max" read
    # below is the peak since the last reset (done when measure_memory was
    # entered for the forward pass) and may include the forward peak.
    start_memory_allocated = torch.cuda.memory_allocated()
    one_hot.backward(retain_graph=True)
    end_max_memory_allocated = torch.cuda.max_memory_allocated()
    mem_diff = (end_max_memory_allocated - start_memory_allocated)
    print(f"diff-backward: {mem_diff/2**20:.4} MiB")

    grad = model.blocks[-1].attn.get_attn_gradients()  # (batch, heads, tokens, tokens)
    cam = model.blocks[-1].attn.get_attention_map()  # (batch, heads, tokens, tokens)
    # Keep the first query row without its first key column, laid out as a
    # 14x14 grid per head.
    cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)  # (heads, 14, 14)
    grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)  # (heads, 14, 14)
    grad = grad.mean(dim=[1, 2], keepdim=True)  # (heads, 1, 1)
    cam = (cam * grad).mean(0).clamp(min=0)  # (14, 14)
    # NOTE(review): a constant map makes the denominator zero and yields NaN.
    cam = (cam - cam.min()) / (cam.max() - cam.min())

    cam_cpu = cam.detach().cpu()
    # Drop local GPU references before collecting.
    del cam, one_hot, grad, output
    gc.collect()

    return cam_cpu

# Methods to benchmark in this cell; comment entries in or out to select them.
method_name_lst = [
    # "lrp",
    # "lrp_last_layer",
    "transformer_attribution",
    # "predmap15",
    # "attn_gradcam",
    # "rollout",
    # "predmap15_slow",
    # "last_layer_attn",
]


# Number of times each method is run on the same image.
repeats = 4
# Maps method name -> explanation map of the first repeat (used for plotting).
method2xai_map = {}
print(f"Repeats: {repeats}")
for method_name in method_name_lst:
    print(f"Method: {method_name}")
    mem_diff_lst = []
    xai_map_lst = []
    for _ in range(repeats):
        print(f"BEFORE RUN: alloc: {torch.cuda.memory_allocated()/2**20:.4f} MiB")
        print(f"reserved: {torch.cuda.memory_reserved()/2**20:.4f} MiB")
        xai_map, mem_diff = runner.run(method_name, image, cls_idx_tusker)
        print(f"AFTER RUN: {torch.cuda.memory_allocated()/2**20:.4f} MiB")
        # xai_map, exec_time = runner.run(method_name, image, cls_idx_zebra)
        mem_diff_lst.append(mem_diff)
        # print(f"Execution time: {exec_time:.4f} seconds")
        xai_map = xai_map.detach()
        xai_map_lst.append(xai_map)
        # NOTE(review): this always clears model_orig_LRP, whatever method was
        # just run; for methods that use another model this cleanup is
        # probably a no-op -- confirm this is intended.
        with measure_memory("cleanup"):
            runner.clear_model_LRP_orig()
            # runner.clear_model_new()
        print("=====================")
    mem_diff_pt = torch.tensor(mem_diff_lst, dtype=torch.float32)
    # Sanity check: every repeat should produce the same map as the first one.
    if not all(torch.allclose(xai_map_lst[0], x) for x in xai_map_lst):
        print("WARNING: XAI maps are not equal between repeats")

    print(f"first samples:\t", ", ".join(f"{v/2**20:.4} MiB" for v in mem_diff_pt[:5]))
    # Mean/std skip the first repeat (presumably treated as warm-up).
    print(f"Mean:\t {mem_diff_pt[1:].mean() / 2**20:.4f} MiB")
    print(f"Std:\t {mem_diff_pt[1:].std() / 2**20:.4f} MiB")
    print()
    method2xai_map[method_name] = xai_map_lst[0]
# %%
# Verify raw attention maps are equal
# fast_raw_attn, fast_exec_time = runner.run("last_layer_attn", image, cls_idx_tusker)
# start_time = time.time()
# slow_raw_attn = runner.orig_lrp.generate_LRP(image.unsqueeze(0).cuda(), method="last_layer_attn")
# end_time = time.time()
# slow_exec_time = end_time - start_time
# torch.allclose(fast_raw_attn, slow_raw_attn)
# print(f"Fast execution time: {fast_exec_time:.4f} seconds")
# print(f"Slow execution time: {slow_exec_time:.4f} seconds")
# %%

# fast_predmap15, fast_exec_time = runner.run("predmap15", image, cls_idx_tusker)
# slow_predmap15, slow_exec_time = runner.run("predmap15_slow", image, cls_idx_tusker)
# torch.allclose(slow_predmap15,fast_predmap15)
# %%

# %%
import math

# Lay the explanation maps out on a grid with three columns.
ncols = 3
# Keep at least one row so plt.subplots does not fail when no method was run.
nrows = max(1, math.ceil(len(method2xai_map) / ncols))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
for ax, (method_name, xai_map_raw) in zip(axes.flat, method2xai_map.items()):
    # "lrp" maps are reshaped to image resolution, all others to the 14x14
    # grid.
    side = 224 if method_name == "lrp" else 14
    xai_map = xai_map_raw.reshape(side, side)

    ax.imshow(xai_map.cpu().numpy(), cmap="jet")
    ax.set_title(method_name)
    ax.axis("off")
# Hide the grid cells that received no map instead of showing empty axes.
for ax in axes.flat[len(method2xai_map):]:
    ax.axis("off")
# %%

from baselines.ViT.ViT_new import VisionTransformer
import gc
def clear_attnmap(model: VisionTransformer):
    """Reset the stored attention map and attention gradients of every block.

    For each entry of ``model.blocks`` the ``attention_map`` and
    ``attn_gradients`` attributes of its ``attn`` layer are set to ``None``;
    afterwards the CUDA cache is emptied and the garbage collector is run.
    """
    for block in model.blocks:
        attention = block.attn
        attention.attention_map = None
        attention.attn_gradients = None
    torch.cuda.empty_cache()
    gc.collect()


# Holds references to objects created in later cells (models are appended to
# it), presumably so they stay alive across cell re-runs.
lst=[]
# %%

################ MODEL TIMM
# import timm
# import huggingface_hub.utils
# huggingface_hub.utils.disable_progress_bars()
# timm.list_models("*vit*", pretrained=True)
# # model_timm = timm.create_model("vit_base_patch16_224.orig_in21k", pretrained=True) # classifier not valid

# # start_memory_allocated = torch.cuda.memory_allocated()
# # print(f"Before: {start_memory_allocated:_}")
# # model_new = vit_new(pretrained=True).cuda()
# model_timm = timm.create_model("vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=True)
# model_timm.eval()
# model_timm.cuda()

# with torch.no_grad():
#     with measure_memory("forward"):
#         res = model_timm(image.unsqueeze(0).cuda())

# lst.append(model_timm)
# end_memory_allocated = torch.cuda.memory_allocated()
# print(f"After:  {end_memory_allocated:_}")
# print(f"Memory allocated for model_new: {end_memory_allocated-start_memory_allocated:_}")

# %%

# start_memory_allocated = torch.cuda.memory_allocated()
# print(f"Before: {start_memory_allocated:_}")
# Measure how much GPU memory loading a fresh vit_new model takes.
gc.collect()
with measure_memory("Model loading"):
    model_new = vit_new(pretrained=True).cuda()
    # model_lrp = vit_LRP(pretrained=True).cuda()
# Keep a reference to the freshly loaded model in the module-level list.
lst.append(model_new)
# lst.append(model_lrp)
# end_memory_allocated = torch.cuda.memory_allocated()
# print(f"After:  {end_memory_allocated:_}")
# print(f"Memory allocated for model_new: {end_memory_allocated-start_memory_allocated:_}")
# %%

def clear_model_LRP(model):
    """Reset ``inp_grad`` and the per-block attention state of ``model``.

    Sets ``attn``, ``attn_cam``, ``v``, ``v_cam`` and ``attn_gradients`` on
    each block's ``attn`` layer to ``None``, then empties the CUDA cache and
    runs the garbage collector.
    """
    print("clear")
    model.inp_grad = None
    cached_fields = ("attn", "attn_cam", "v", "v_cam", "attn_gradients")
    for block in model.blocks:
        for field in cached_fields:
            setattr(block.attn, field, None)
    torch.cuda.empty_cache()
    gc.collect()

# Step-by-step memory profile of one forward/backward pass through model_new.
# Each stage is wrapped in measure_memory so its footprint is printed
# separately.

# Start from a clean state: no gradients, no stored attention maps.
with measure_memory("cleanup"):
    model_new.zero_grad(set_to_none=True)
    clear_attnmap(model_new)
# clear_model_LRP(model_lrp)
# with torch.no_grad():
with measure_memory("forward"):
    # with torch.no_grad():
    res = model_new(image.unsqueeze(0).cuda())

with measure_memory("backward"):
    # res.sum().backward(retain_graph=True)
    res.sum().backward(retain_graph=True)

    # res = model_lrp(image.unsqueeze(0).cuda())
    # res = model_timm(image.unsqueeze(0).cuda())
    # res = None
# lst.append(res)

# Release state one piece at a time to see what each piece was holding:
# first the output tensor, ...
with measure_memory("cleanup-res"):
    res = None
    gc.collect()

# ... then the parameter gradients, ...
with measure_memory("cleanup-zero_grad"):
    model_new.zero_grad(set_to_none=True)

# ... and finally the attention maps/gradients stored on the blocks.
with measure_memory("cleanup-attnmap"):
    clear_attnmap(model_new)

# %%
