# %%
import argparse
import collections
import time
from pathlib import Path

import lovely_numpy
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from lovely_numpy import lo
from matplotlib import ticker
from PIL import Image

# Display settings for lovely_numpy / lovely_tensors summaries (presumably only
# affects how arrays/tensors print when inspected interactively).
lovely_numpy.set_config(deeper_width=12)
import cv2
import lovely_tensors as lt

# Presumably swaps torch.Tensor's repr for lovely_tensors' compact summary.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): presumably a compatibility shim for code that expects
# ``torch.inf`` to exist on older torch releases — confirm it is still needed.
torch.inf = float("Inf")
import os
import sys

# base_dir is the parent of this script's directory. It is appended to
# sys.path so the ``baselines.*`` imports below resolve, and made the working
# directory so relative paths (e.g. the default ``samples/el2.png``) resolve
# against it no matter where the script is launched from. These statements
# must run before the ``baselines`` imports.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

import torchvision
import tqdm
from torchvision.datasets import ImageFolder

from baselines.ViT.ViT_explanation_generator import LRP, Baselines
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_LRP import vit_large_patch16_224 as vit_large_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_new import vit_large_patch16_224 as vit_large_new
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from baselines.ViT.ViT_orig_LRP import vit_large_patch16_224 as vit_large_orig_LRP


class Runner:
    """Load one ViT size in three variants and time explanation methods on them.

    Each explanation backend is bound to its own model variant:

    * ``model_LRP`` (``ViT_LRP``) -> ``LRP`` generator used for
      ``transformer_attribution``;
    * ``model_orig_LRP`` (``ViT_orig_LRP``) -> ``LRP`` generator used for
      ``full_lrp`` and ``lrp_last_layer``;
    * ``model_new`` (``ViT_new``) -> ``predicatt`` and the ``Baselines``
      methods (``rollout``, ``attn_last_layer``, ``attn_gradcam``).

    All models are moved to CUDA and put in eval mode, so a GPU is required.
    """

    # Method names accepted by ``run``; keep in sync with its dispatch table.
    METHODS = (
        "predicatt",
        "transformer_attribution",
        "rollout",
        "attn_last_layer",
        "attn_gradcam",
        "full_lrp",
        "lrp_last_layer",
    )

    def __init__(self, model_type):
        """Build the models, preprocessing transforms and explanation generators.

        Args:
            model_type: ``"vit-b"`` (``vit_base_patch16_224`` variants) or
                ``"vit-l"`` (``vit_large_patch16_224`` variants).

        Raises:
            ValueError: if ``model_type`` is not one of the above.
        """
        # Resolve constructors first so an invalid type fails before any
        # (slow) pretrained-weight loading happens.
        if model_type == "vit-b":
            builders = (vit_LRP, vit_orig_LRP, vit_new)
        elif model_type == "vit-l":
            builders = (vit_large_LRP, vit_large_orig_LRP, vit_large_new)
        else:
            raise ValueError(f"Invalid model type: {model_type}")

        def load(build):
            """Instantiate a pretrained model on CUDA in eval mode."""
            model = build(pretrained=True).cuda()
            model.eval()
            return model

        build_lrp, build_orig_lrp, build_new = builders
        self.model_LRP = load(build_lrp)
        self.model_orig_LRP = load(build_orig_lrp)
        self.model_new = load(build_new)

        # mean = std = 0.5 maps ToTensor's [0, 1] range to [-1, 1] per channel.
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.transform_without_normalize = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        self.transform = transforms.Compose(
            [self.transform_without_normalize, self.normalize]
        )
        self.attribution_generator = LRP(self.model_LRP)
        self.orig_lrp = LRP(self.model_orig_LRP)
        self.baselines = Baselines(self.model_new)

    @staticmethod
    def _synchronize():
        """Block until all queued CUDA work has finished (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def run(self, method_name, image, cls_idx):
        """Compute one explanation map with ``method_name`` and time it.

        Args:
            method_name: one of ``Runner.METHODS``.
            image: preprocessed image tensor of shape (C, H, W); a batch
                dimension is added and the result is moved to CUDA here.
            cls_idx: target class index. It is passed to ``predicatt``,
                ``transformer_attribution`` and ``attn_gradcam``; the other
                methods do not receive it.

        Returns:
            ``(xai_map, exec_time)``: the detached explanation map and the
            wall-clock seconds spent in the explanation call.

        Raises:
            ValueError: if ``method_name`` is not in ``Runner.METHODS``.
        """
        # Validate up front: previously an unknown name surfaced as an
        # UnboundLocalError on ``method_func``.
        if method_name not in self.METHODS:
            raise ValueError(
                f"Invalid method: {method_name!r} "
                f"(expected one of: {', '.join(self.METHODS)})"
            )
        batch = image.unsqueeze(0).cuda()  # (1, C, H, W)

        def predicatt():
            # Evaluated without autograd tracking.
            with torch.no_grad():
                return self.model_new.predmap15_softmax_classes_batched_layer(
                    batch, cls_idx
                )

        method_funcs = {
            "predicatt": predicatt,
            "transformer_attribution": lambda: self.attribution_generator.generate_LRP(
                batch, start_layer=0, method="transformer_attribution", index=cls_idx
            ),
            "rollout": lambda: self.baselines.generate_rollout(batch, start_layer=1),
            "attn_last_layer": lambda: self.baselines.generate_raw_attn(batch),
            "attn_gradcam": lambda: self.baselines.generate_cam_attn(batch, index=cls_idx),
            # NOTE(review): the two orig-LRP methods are not given cls_idx, unlike
            # the other class-conditional methods — confirm this is intended.
            "full_lrp": lambda: self.orig_lrp.generate_LRP(batch, method="full"),
            "lrp_last_layer": lambda: self.orig_lrp.generate_LRP(batch, method="last_layer"),
        }
        method_func = method_funcs[method_name]

        # CUDA kernels launch asynchronously, so synchronise before starting
        # and before stopping the clock; otherwise the interval can cover only
        # the kernel launches, not the GPU work. perf_counter is monotonic and
        # high-resolution, unlike time.time().
        self._synchronize()
        start_time = time.perf_counter()
        xai_map = method_func()
        self._synchronize()
        exec_time = time.perf_counter() - start_time

        return xai_map.detach(), exec_time


def main():
    """Benchmark the runtime of one explanation method from the command line.

    Loads ``--sample_image``, runs ``--method`` on it ``--warmup`` times
    without timing, then ``--repeats`` times with timing. It warns if the
    resulting maps differ between repeats and prints the first per-run
    timings plus their mean and standard deviation in milliseconds.
    """
    parser = argparse.ArgumentParser(description="Benchmark runtime")
    parser.add_argument(
        "--method",
        type=str,
        default="predicatt",
        choices=[
            "predicatt",
            "transformer_attribution",
            "rollout",
            "attn_last_layer",
            "attn_gradcam",
            "full_lrp",
            "lrp_last_layer",
        ],
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="vit-b",
        choices=["vit-b", "vit-l"],
    )
    parser.add_argument(
        "--sample_image",
        type=str,
        default="samples/el2.png",
    )
    parser.add_argument(
        "--cls_idx",
        type=int,
        default=101,  # tusker
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=0,
        help="untimed runs before measuring (e.g. to exclude one-off CUDA start-up cost)",
    )
    args = parser.parse_args()
    # At least one timed run is needed for the equality check and the stats.
    if args.repeats < 1:
        parser.error("--repeats must be at least 1")
    if args.warmup < 0:
        parser.error("--warmup must be non-negative")

    repeats = args.repeats
    warmup = args.warmup
    method_name = args.method
    model_type = args.model_type
    sample_path = args.sample_image
    cls_idx = args.cls_idx

    runner = Runner(model_type)
    # convert("RGB") guarantees 3 channels; the 3-channel Normalize would
    # otherwise fail on RGBA / greyscale / palette images. The context
    # manager closes the underlying file once the tensor has been built.
    with Image.open(sample_path) as image_PIL:
        image = runner.transform(image_PIL.convert("RGB"))  # (C, H, W)

    print(f"Repeats: {repeats}")
    print(f"Method: {method_name}")
    if warmup:
        print(f"Warmup: {warmup}")
    for _ in range(warmup):
        runner.run(method_name, image, cls_idx)

    exec_time_lst = []
    xai_map_lst = []
    for _ in range(repeats):
        # Runner.run already returns a detached map.
        xai_map, exec_time = runner.run(method_name, image, cls_idx)
        exec_time_lst.append(exec_time)
        xai_map_lst.append(xai_map)
    if not all(torch.allclose(xai_map_lst[0], x) for x in xai_map_lst[1:]):
        print("WARNING: XAI maps are not equal between repeats")

    # float64 avoids float32 rounding of the timings. The sample std is
    # undefined for a single run, so report nan without triggering a warning.
    exec_time_pt = torch.tensor(exec_time_lst, dtype=torch.float64)
    mean_ms = exec_time_pt.mean().item() * 1000
    std_ms = exec_time_pt.std().item() * 1000 if repeats > 1 else float("nan")

    print("first samples:\t", ", ".join(f"{v * 1000:.4} ms" for v in exec_time_lst[:5]))
    print(f"Mean:\t {mean_ms:.4f} ms")
    print(f"Std:\t {std_ms:.4f} ms")
    print()


if __name__ == "__main__":
    main()
