#%%
# Interactive, cell-based ("# %%") comparison of several ViT-B/16
# implementations from the project's `baselines.ViT` package, all fed the
# same random input.
from pathlib import Path
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# lovely_numpy / lovely_tensors are pretty-printing helpers (per their docs);
# `deeper_width` is one of their display settings.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): rebinds torch.inf to a plain Python float infinity; presumably a
# compatibility shim for code that expects `torch.inf` — confirm it is still needed.
torch.inf = float("Inf")
import sys
import os
# Directory two levels above this file. It is appended to sys.path so the
# `baselines` package imported in the next cell resolves, and it becomes the
# working directory (presumably so relative paths used by that package
# resolve — confirm). Requires `__file__`, i.e. running as a script or in an
# interactive runner that defines it.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

import torchvision
import tqdm
# %%
import baselines.ViT.ViT_LRP
import baselines.ViT.ViT_orig_LRP
import baselines.ViT.ViT_new

# Build each variant with its own module's factory (pretrained weights), then
# move it to the default CUDA device and keep the returned module.
model_LRP = baselines.ViT.ViT_LRP.vit_base_patch16_224(pretrained=True)
model_LRP = model_LRP.cuda()
model_orig_LRP = baselines.ViT.ViT_orig_LRP.vit_base_patch16_224(pretrained=True)
model_orig_LRP = model_orig_LRP.cuda()
model_new = baselines.ViT.ViT_new.vit_base_patch16_224(pretrained=True)
model_new = model_new.cuda()

def vit_base_patch16_224(pretrained=False, *, norm_eps=None, **kwargs):
    """Build a ViT-B/16 (224px) from ``baselines.ViT.ViT_new`` classes.

    Local re-creation of a ``vit_base_patch16_224`` factory that lets this
    script control how the LayerNorm layers are constructed.

    Args:
        pretrained: If True, load weights with ``ViT_new.load_pretrained``,
            using ``model.num_classes`` and ``ViT_new._conv_filter``.
        norm_eps: ``eps`` given to every ``LayerNorm``. ``None`` (the default)
            passes no ``eps``, so LayerNorm keeps its own default. This matches
            what the function always did (1e-5 if ``ViT_new.nn`` is
            ``torch.nn``). NOTE(review): timm-style ViT factories commonly use
            1e-6; compare with ``ViT_new.vit_base_patch16_224`` when outputs
            disagree.
        **kwargs: Forwarded to ``VisionTransformer``. ``in_chans`` (default 3)
            is also forwarded to ``load_pretrained``.

    Returns:
        The constructed ``VisionTransformer``, with ``default_cfg`` set to
        ``ViT_new.default_cfgs['vit_base_patch16_224']``. It is not moved to
        any device.
    """
    vit_new = baselines.ViT.ViT_new
    # Only pass eps when the caller asked for one, so the default build is
    # byte-for-byte the same norm layer as before.
    if norm_eps is None:
        norm_layer = vit_new.partial(vit_new.nn.LayerNorm)
    else:
        norm_layer = vit_new.partial(vit_new.nn.LayerNorm, eps=norm_eps)
    model = vit_new.VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=norm_layer, **kwargs)
    model.default_cfg = vit_new.default_cfgs['vit_base_patch16_224']
    if pretrained:
        vit_new.load_pretrained(
            model, num_classes=model.num_classes,
            in_chans=kwargs.get('in_chans', 3), filter_fn=vit_new._conv_filter)
    return model

# Same ViT_new architecture, built through the local factory defined above,
# then moved to the default CUDA device.
model_new2 = vit_base_patch16_224(pretrained=True)
model_new2 = model_new2.cuda()


# Switch all four models to inference mode (disables training-only behaviour
# such as dropout) before comparing their outputs.
for model in (model_LRP, model_orig_LRP, model_new, model_new2):
    model.eval()


# %%
# Fixed-seed input so the allclose / dist / equality checks in later cells
# give the same numbers on every run. A dedicated CPU generator is used so the
# global torch RNG state is left untouched.
x = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0)).cuda()
# One output per model, in this order (later cells index res[0]..res[3]):
#   0 = model_LRP, 1 = model_orig_LRP, 2 = model_new, 3 = model_new2
# NOTE(review): autograd stays enabled (later cells call .detach()); check the
# LRP models' forward for gradient hooks before wrapping this in torch.no_grad().
res = []
for model in [model_LRP, model_orig_LRP, model_new, model_new2]:
    res.append(model(x))

# %%
tuple(torch.allclose(res[i], res[j]) for i, j in ((0, 1), (0, 2), (1, 2)))
# res[3] (model_new2) is not part of this check; e.g. torch.allclose(res[0], res[3])
# %%
tuple(torch.dist(res[a], res[b]) for a, b in ((0, 1), (0, 2), (1, 2)))


# %%

torch.eq(res[0], res[3]).sum()
# %%
plt.plot(res[2].detach().cpu().numpy().ravel())
# To overlay model_new2, also plot res[3].detach().cpu().numpy().ravel()
# (optionally with the "o" marker style).
# %%
# Optional plots: the model_new - model_new2 difference, or model_LRP's raw
# output, via plt.plot(<tensor>.detach().cpu().numpy().flatten()).
# NOTE(review): `display` is the IPython/Jupyter builtin; it is not defined
# when this file runs as a plain Python script.


display(res[0])
display(res[2])
display(torch.abs(res[2] - res[0]))
# %%
