#%%
import sys
import os
from pathlib import Path
# Make the repo root (two directories above this file) importable and the working
# directory, so the `samples.*` / `baselines.*` imports below resolve against this
# repo (and, presumably, so repo-relative paths used by those modules work).
# `.resolve()` matters: `__file__` can be relative in some launch modes, in which case
# `.parent.parent` would collapse to "." rather than the repo root, and a relative
# sys.path entry would be reinterpreted after the `os.chdir` below.
base_dir = Path(__file__).resolve().parent.parent
# Prepend rather than append, so an installed package that happens to share a name
# with a repo package (e.g. a pip-installed `baselines`) cannot shadow the repo copy.
sys.path.insert(0, str(base_dir))
os.chdir(base_dir)

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from samples.CLS2IDX import CLS2IDX
import types



import lovely_tensors as lt
# Install lovely_tensors on torch tensors. The `.deeper` / `.v` helpers used in the
# cells below come from this patch, so it must run before those cells.
lt.monkey_patch()
# NOTE(review): presumably widened to 12 so `.deeper` lists all 12 heads of
# `split.flatten(1)` in the weight-inspection cell — confirm against lovely_tensors docs.
lt.set_config(deeper_width=12)
# Overwrite `torch.inf` with a float infinity, globally for this process.
# NOTE(review): presumably a compatibility shim for code (lovely_tensors or the ViT
# baselines?) that references `torch.inf` on a PyTorch build lacking it — confirm
# it is still needed on the installed PyTorch version.
torch.inf = float("Inf")



from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
from baselines.ViT.ViT_explanation_generator import LRP
# %%

# Load `vit_base_patch16_224` from `baselines.ViT.ViT_LRP` with pretrained weights and
# put it in eval (inference) mode. Fall back to the CPU when CUDA is unavailable:
# the weight inspection below only reads parameters and does not need a GPU.
# On a CUDA machine this is identical to the previous `.cuda()` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = vit_LRP(pretrained=True).to(device)
model.eval()

# %%
# Inspect the last block's attention output projection (`attn.proj`): split its
# input rows per head to see how each head's output is weighted.
heads = 12  # assumed head count for this model; validated against the weight below
head_dim = 64  # assumed per-head width; validated against the weight below
lin = model.blocks[-1].attn.proj
# Copy only the weight to the CPU. Do NOT call `lin.cpu()`: `nn.Module.cpu()` moves the
# module in place, which would leave this one layer on the CPU while the rest of
# `model` stays on its device and breaks any later forward pass.
w = lin.weight.detach().cpu()  # nn.Linear layout: (out_features, in_features)
if w.shape[1] != heads * head_dim:
    raise ValueError(
        f"attn.proj has {w.shape[1]} input features, "
        f"expected heads * head_dim = {heads * head_dim}"
    )
w = w.T  # (in_features, out_features)


# plt.matshow(lin.bias.detach().cpu().numpy().reshape(32, -1))
# Head-major split: input row h * head_dim + d is assigned to head h.
# NOTE(review): assumes the attention block concatenates head outputs head-major
# before `proj` — confirm against baselines.ViT.ViT_LRP.Attention.
split = w.reshape(heads, head_dim, -1)  # (heads, head_dim, out_features)
split.flatten(1).deeper  # lovely_tensors summary, one row per head
# split.flatten(1).sum(-1).v
# split.flatten(1).abs().mean(-1).v
# split.flatten(1).clamp(0).mean(-1).v

# .abs().sum(dim=1).sum(dim=1).v

# %%
# Toy sanity check for the per-head split used above: a (6, 6) matrix whose rows are
# grouped into 3 "heads" of 2 rows each, mirroring reshape(heads, head_dim, -1).
# Uses its own names so it does not clobber the projection weight `w` from the
# previous cell.
toy_w = torch.arange(36).reshape(6, 6)
toy_head_sums = toy_w.reshape(3, 2, 6).flatten(1).sum(-1)  # one sum per "head"
# The reshape must group *consecutive* rows into each head.
assert torch.equal(
    toy_head_sums,
    torch.stack([toy_w[2 * h : 2 * h + 2].sum() for h in range(3)]),
)
toy_head_sums
# (toy_w / 36).deeper(1)
# %%

# %%
