#%%
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Adjust lovely_numpy's `deeper_width` display option for array summaries.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
# NOTE(review): monkey_patch() presumably swaps in lovely_tensors' compact
# tensor repr on torch.Tensor -- confirm against the library docs.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): overwrites/creates `torch.inf` with Python's float infinity;
# presumably a shim for a torch build that lacks the attribute -- confirm it
# is still needed.
torch.inf = float("Inf")
import sys
import os
# Directory two levels above this file. It is added to the import path
# (presumably so the `baselines.*` imports below resolve) and made the working
# directory (presumably so relative paths such as 'samples/el2.png' resolve).
# NOTE: relies on `__file__` being defined by whatever runs these cells.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

import torchvision
import tqdm
# %%
import baselines.ViT.ViT_LRP
import baselines.ViT.ViT_orig_LRP
import baselines.ViT.ViT_new

# Build the two pretrained ViT-B/16 (224x224 input) variants under comparison
# and move each one to the GPU.
model_LRP = baselines.ViT.ViT_LRP.vit_base_patch16_224(pretrained=True)
model_LRP = model_LRP.cuda()
# model_orig_LRP = baselines.ViT.ViT_orig_LRP.vit_base_patch16_224(pretrained=True).cuda()
model_new = baselines.ViT.ViT_new.vit_base_patch16_224(pretrained=True)
model_new = model_new.cuda()

# Put both models in evaluation mode before running them.
for model in (model_LRP, model_new):
    model.eval()

# %%
# Preprocessing pipeline: resize the shorter side to 256, center-crop to
# 224x224, convert to a float tensor in [0, 1], then normalize each of the
# three channels with mean=0.5 / std=0.5 (i.e. map values to [-1, 1]).
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# A PNG may decode as RGBA, palette or grayscale; force three channels so the
# 3-channel Normalize above (and the models) always receive a (3, H, W) tensor.
image = Image.open('samples/el2.png').convert('RGB')
tusker_zebra_image = transform(image) # (3, 224, 224)
idx = 340 # zebra


# Run the same single-image batch through both model variants for class `idx`.
# Each result is reshaped to 14x14 -- presumably one value per 16x16 patch of
# the 224x224 input (224 / 16 = 14); confirm against predmap31_batched_layer.
xmap_lrp = model_LRP.predmap31_batched_layer(tusker_zebra_image.unsqueeze(0).cuda(), idx=idx).reshape(14,14)
xmap_new = model_new.predmap31_batched_layer(tusker_zebra_image.unsqueeze(0).cuda(), idx=idx).reshape(14,14)

# Show both maps side by side for a visual comparison.
plt.matshow(xmap_new.detach().cpu().numpy())
plt.matshow(xmap_lrp.detach().cpu().numpy())

# Numerical check that the two implementations produce the same map
# (torch.allclose default tolerances).
print(f"{torch.allclose(xmap_new,xmap_lrp)=}")


# %%
import timm
import huggingface_hub.utils
# Turn off Hugging Face Hub progress bars before any weights are fetched.
huggingface_hub.utils.disable_progress_bars()
# List timm's pretrained models matching "*vit*" (the result is not stored).
timm.list_models("*vit*", pretrained=True)
# Rejected alternative checkpoint -- classifier not valid:
# model_timm = timm.create_model("vit_base_patch16_224.orig_in21k", pretrained=True)
model_timm = timm.create_model(
    "vit_base_patch16_224.orig_in21k_ft_in1k",
    pretrained=True,
).eval()
# Move the reference model to the GPU.
model_timm.cuda()
# %%
# def predmap31_timm(model, x):
# Apply the timm classifier head to every token instead of only the CLS token,
# then keep the logit of class `idx` for each non-CLS token.
tokens = model_timm.forward_features(tusker_zebra_image.unsqueeze(0).cuda())  # (1, tokens, embed_dim)
logits = model_timm.head(tokens)  # (1, tokens, 1000)
# Drop the leading CLS token and select class `idx` in a single indexing step.
x = logits[..., 1:, idx]  # (1, tokens-1)
xmap_timm = x.reshape(14, 14)
plt.matshow(xmap_timm.detach().cpu().numpy())

# Compare the timm-derived map against the LRP model's map.
print(f"{torch.allclose(xmap_timm.cuda(),xmap_lrp)=}")

# %%
