# %%
import argparse
import gc
import os
import sys
from pathlib import Path

import lovely_numpy
import torch
import torchvision.transforms as transforms
import tqdm
from lovely_numpy import lo
from PIL import Image

# Pretty-printing config for interactive inspection: lovely_numpy summaries for
# arrays and (after monkey_patch) lovely_tensors summaries for torch.Tensor.
lovely_numpy.set_config(deeper_width=12)

import lovely_tensors as lt

lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): shim that (re)defines `torch.inf` as float infinity, presumably
# for code that expects the attribute on torch versions lacking it — confirm.
torch.inf = float("Inf")

# Repository root = the parent of this script's directory. It is appended to
# sys.path so the project packages imported below (`modules`, `baselines`)
# resolve, and made the process-wide working directory so relative paths such
# as the default `--sample_image samples/el2.png` resolve against it.
# These lines must stay before those project imports.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

import modules.layers_lrp
import modules.layers_ours
from baselines.ViT.ViT_explanation_generator import LRP, Baselines
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_LRP import vit_large_patch16_224 as vit_large_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_new import vit_large_patch16_224 as vit_large_new
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from baselines.ViT.ViT_orig_LRP import vit_large_patch16_224 as vit_large_orig_LRP


class Runner:
    """Load the ViT variants and measure the peak CUDA memory of one
    explanation method at a time.

    Three copies of the backbone are loaded because the methods are
    implemented on different model classes:

    * ``model_LRP`` (``ViT_LRP``), wrapped by ``LRP`` as
      ``attribution_generator``: used by ``transformer_attribution``.
    * ``model_orig_LRP`` (``ViT_orig_LRP``), wrapped by ``LRP`` as
      ``orig_lrp``: used by ``full_lrp`` and ``lrp_last_layer``.
    * ``model_new`` (``ViT_new``), wrapped by ``Baselines`` as ``baselines``:
      used by ``rollout``, ``attn_last_layer`` and ``attn_gradcam``; it also
      provides ``predicatt`` directly.

    All models are moved to CUDA and put in eval mode, so a GPU is required.
    """

    def __init__(self, model_type):
        """Load the pretrained models and set up input preprocessing.

        Args:
            model_type: ``"vit-b"`` (``vit_base_patch16_224`` variants) or
                ``"vit-l"`` (``vit_large_patch16_224`` variants).

        Raises:
            ValueError: if ``model_type`` is neither of the above.
        """
        # Resolve the constructors first so an invalid model type fails
        # before any pretrained weights are loaded onto the GPU.
        if model_type == "vit-b":
            make_lrp, make_orig_lrp, make_new = vit_LRP, vit_orig_LRP, vit_new
        elif model_type == "vit-l":
            make_lrp, make_orig_lrp, make_new = (
                vit_large_LRP,
                vit_large_orig_LRP,
                vit_large_new,
            )
        else:
            raise ValueError(f"Invalid model type: {model_type}")

        self.model_LRP = make_lrp(pretrained=True).cuda()
        self.model_LRP.eval()
        self.model_orig_LRP = make_orig_lrp(pretrained=True).cuda()
        self.model_orig_LRP.eval()
        self.model_new = make_new(pretrained=True).cuda()
        self.model_new.eval()

        # ToTensor yields values in [0, 1]; mean = std = 0.5 maps them to [-1, 1].
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.transform_without_normalize = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        self.transform = transforms.Compose(
            [self.transform_without_normalize, self.normalize]
        )
        self.attribution_generator = LRP(self.model_LRP)
        self.orig_lrp = LRP(self.model_orig_LRP)
        self.baselines = Baselines(self.model_new)

    def run(self, method_name, image, cls_idx):
        """Run one explanation method on ``image`` and measure its CUDA memory.

        Before measuring, state cached on the relevant model by earlier runs
        is cleared. Then the CUDA peak-memory counter is reset, and the method
        is timed from the current allocation level.

        Args:
            method_name: one of ``"predicatt"``, ``"transformer_attribution"``,
                ``"rollout"``, ``"attn_last_layer"``, ``"attn_gradcam"``,
                ``"full_lrp"``, ``"lrp_last_layer"``.
            image: a single image tensor, (C, H, W); a batch dim is added here.
            cls_idx: target class for the class-conditional methods
                (``rollout`` and ``attn_last_layer`` do not use it).

        Returns:
            ``(xai_map_cpu, mem_diff)``. ``xai_map_cpu`` is the explanation
            map, detached and moved to the CPU. ``mem_diff`` is the peak bytes
            allocated during the call minus the bytes allocated just before it.

        Raises:
            ValueError: if ``method_name`` is not recognised.
        """
        image = image.unsqueeze(0).cuda()  # (B, C, H, W)
        if method_name == "predicatt":
            self.clear_model_new()

            def method_func():
                # predicatt is computed without autograd.
                with torch.no_grad():
                    return self.model_new.predmap15_softmax_classes_batched_layer(
                        image, cls_idx
                    )

        elif method_name == "transformer_attribution":
            self.clear_model_LRP()
            method_func = lambda: self.attribution_generator.generate_LRP(
                image, start_layer=0, method="transformer_attribution", index=cls_idx
            )
        elif method_name == "rollout":
            self.clear_model_new()
            method_func = lambda: self.baselines.generate_rollout(image, start_layer=1)
        elif method_name == "attn_last_layer":
            self.clear_model_new()
            method_func = lambda: self.baselines.generate_raw_attn(image)
        elif method_name == "attn_gradcam":
            self.clear_model_new()
            method_func = lambda: self.baselines.generate_cam_attn(image, index=cls_idx)
        elif method_name == "full_lrp":
            self.clear_model_LRP_orig()
            # Pass the target class like the other class-conditional methods;
            # without it the explained class is chosen by generate_LRP itself.
            method_func = lambda: self.orig_lrp.generate_LRP(
                image, method="full", index=cls_idx
            )
        elif method_name == "lrp_last_layer":
            self.clear_model_LRP_orig()
            method_func = lambda: self.orig_lrp.generate_LRP(
                image, method="last_layer", index=cls_idx
            )
        else:
            raise ValueError(f"Invalid method: {method_name}")

        # Start from a settled device and a fresh peak counter, so the peak
        # reflects only allocations made while method_func runs.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start_memory_allocated = torch.cuda.memory_allocated()
        xai_map = method_func()
        end_max_memory_allocated = torch.cuda.max_memory_allocated()

        mem_diff = end_max_memory_allocated - start_memory_allocated

        xai_map_cpu = xai_map.detach().cpu()
        # Drop the GPU copy before the next run measures its baseline.
        del xai_map
        gc.collect()

        return xai_map_cpu, mem_diff

    def clear_relprop(self, obj):
        """Drop cached tensors on every RelProp layer in ``obj`` (inclusive).

        The cleared attributes are ``X``, ``Y``, ``grad_input`` and
        ``grad_output``. ``obj.modules()`` already yields ``obj`` and all of
        its descendants, so one pass visits every module exactly once.
        Recursing into each child as well would re-walk each subtree once
        per ancestor.
        """
        relprop_types = (modules.layers_ours.RelProp, modules.layers_lrp.RelProp)
        for m in obj.modules():
            if isinstance(m, relprop_types):
                m.X = None
                m.Y = None
                m.grad_input = None
                m.grad_output = None

    @staticmethod
    def _clear_attn_caches(attn):
        """Reset the attention tensors that all three model variants cache."""
        attn.attn = None
        attn.attn_cam = None
        attn.v = None
        attn.v_cam = None
        attn.attn_gradients = None

    @staticmethod
    def _release_cuda_memory():
        """Collect Python garbage, then return unused cached blocks to CUDA.

        ``gc.collect()`` runs first so tensors freed by the collector are
        already back in the caching allocator when ``empty_cache()`` runs.
        """
        gc.collect()
        torch.cuda.empty_cache()

    def clear_model_new(self):
        """Clear gradients and per-block attention state on ``model_new``."""
        model = self.model_new
        model.zero_grad(set_to_none=True)
        for b in model.blocks:
            b.attn.attention_map = None
            self._clear_attn_caches(b.attn)
        self._release_cuda_memory()

    def clear_model_LRP(self):
        """Clear gradients, RelProp caches and per-block state on ``model_LRP``."""
        model = self.model_LRP
        model.zero_grad(set_to_none=True)
        self.clear_relprop(model)
        for b in model.blocks:
            b.x = None
            b.mhsa = None
            self._clear_attn_caches(b.attn)
        self._release_cuda_memory()

    def clear_model_LRP_orig(self):
        """Clear gradients, RelProp caches and per-block state on ``model_orig_LRP``."""
        model = self.model_orig_LRP
        model.zero_grad(set_to_none=True)
        self.clear_relprop(model)
        model.inp_grad = None
        for b in model.blocks:
            self._clear_attn_caches(b.attn)
        self._release_cuda_memory()


def main():
    """Benchmark the peak CUDA memory of one explanation method from the CLI.

    Loads the requested ViT variant and runs ``--method`` on
    ``--sample_image`` for class ``--cls_idx``, ``--repeats`` times. It warns
    if the explanation maps differ between repeats, then prints the first few
    measurements plus their mean and standard deviation in MiB.
    """
    parser = argparse.ArgumentParser(description="Benchmark memory")
    parser.add_argument(
        "--method",
        type=str,
        default="predicatt",
        # Mirrors the methods dispatched in Runner.run; validating here fails
        # fast instead of after all models have been loaded onto the GPU.
        choices=[
            "predicatt",
            "transformer_attribution",
            "rollout",
            "attn_last_layer",
            "attn_gradcam",
            "full_lrp",
            "lrp_last_layer",
        ],
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="vit-b",
        choices=["vit-b", "vit-l"],
    )
    parser.add_argument(
        "--sample_image",
        type=str,
        default="samples/el2.png",
    )
    parser.add_argument(
        "--cls_idx",
        type=int,
        default=101,  # tusker
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=3,
    )
    args = parser.parse_args()
    # The summary below needs at least one measurement and one map.
    if args.repeats < 1:
        parser.error("--repeats must be at least 1")

    repeats = args.repeats
    method_name = args.method
    model_type = args.model_type
    sample_path = args.sample_image
    cls_idx = args.cls_idx

    runner = Runner(model_type)
    # convert("RGB") drops any alpha/palette/grayscale mode so the 3-channel
    # Normalize in Runner.transform always applies. The context manager
    # closes the underlying file.
    with Image.open(sample_path) as image_pil:
        image = runner.transform(image_pil.convert("RGB"))  # (C, H, W)

    print(f"Repeats: {repeats}")
    print(f"Method: {method_name}")
    mem_diff_lst = []
    xai_map_lst = []
    for _ in range(repeats):
        xai_map, mem_diff = runner.run(method_name, image, cls_idx)
        mem_diff_lst.append(mem_diff)
        xai_map_lst.append(xai_map)  # already detached and on the CPU

    # Every repeat should produce the same map; report once if not.
    if not all(torch.allclose(xai_map_lst[0], x) for x in xai_map_lst[1:]):
        print("WARNING: XAI maps are not equal between repeats")

    mem_diff_pt = torch.tensor(mem_diff_lst, dtype=torch.float32)  # bytes
    print(
        "first samples:\t",
        ", ".join(f"{v / 2**20:.4f} MiB" for v in mem_diff_pt[:5].tolist()),
    )
    print(f"Mean:\t {mem_diff_pt.mean().item() / 2**20:.4f} MiB")
    # Sample (Bessel-corrected) std: NaN when --repeats is 1.
    print(f"Std:\t {mem_diff_pt.std().item() / 2**20:.4f} MiB")


if __name__ == "__main__":
    main()
