#%%
import sys
import os
from pathlib import Path
# Repository root: the parent of the directory that contains this script.
base_dir = Path(__file__).parent.parent
# Must run before the project-local imports below (`samples`, `baselines`);
# presumably those packages live under base_dir - confirm against the repo layout.
sys.path.append(str(base_dir))
# Make base_dir the working directory so relative paths resolve against it.
os.chdir(base_dir)

from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from samples.CLS2IDX import CLS2IDX
import types

import lovely_tensors as lt
# Called purely for its side effect at import time.
# NOTE(review): presumably this changes how torch tensors are rendered when
# inspected in interactive cells - confirm against the lovely_tensors docs.
lt.monkey_patch()

from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP

from torchvision.datasets import ImageNet

# %%

# Build the pretrained ViT-LRP model, then move it to the GPU and put it in
# evaluation mode, one step at a time.
model = vit_LRP(pretrained=True)
model = model.cuda()
model = model.eval()

# %%
# Number of images fed to the model per step.
batch_size = 2

# Per-channel normalisation: (x - 0.5) / 0.5 on each of the three channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing pipeline: resize to 224x224, convert to tensor, normalise.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# Root directory of the ImageNet dataset, taken from the IMAGENET_PATH
# environment variable. `os.environ` is a mapping, so it must be subscripted
# (calling it raises TypeError). A missing variable raises KeyError naming it.
imagenet_validation_path = os.environ["IMAGENET_PATH"]

# Validation split, preprocessed with `transform`.
imagenet_ds = ImageNet(imagenet_validation_path, split='val', transform=transform)

# Deterministic (unshuffled) iteration over the validation set in small batches.
sample_loader = torch.utils.data.DataLoader(
    imagenet_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

# %%

# For every validation sample, find the layer holding the head whose projection
# onto the predicted class is largest, and count how often each layer wins.
n_layers = len(model.blocks)
counters = torch.zeros(n_layers).cuda()
for i, (img, target) in enumerate(tqdm(sample_loader)):
    _, extras = model.predmap15_batched_layer(
        img.cuda(),
        layer=11,
        idx=None,
        return_extras=True,
    )

    # Predicted class per sample, read from token 0 of the prediction map.
    xpredmap = extras['xpredmap'].detach()  # (batch, tokens, classes)
    preds = torch.argmax(xpredmap[:, 0, :], dim=-1)  # (batch,)

    # Keep, for each sample, only the projections onto its predicted class.
    projections = extras['projections'].detach()  # (batch, layers, heads, classes)
    B = projections.shape[0]
    projections = projections[range(B), ..., preds]  # (batch, layers, heads)

    # Strongest head per layer, then the layer with the strongest head overall.
    best_head_per_layer = torch.amax(projections, dim=2)  # (batch, layers)
    winning_layer = torch.argmax(best_head_per_layer, dim=1)  # (batch,)

    # Histogram of winning layers for this batch, added to the running totals.
    curr_counters = torch.bincount(winning_layer, minlength=n_layers)  # (layers,)
    counters += curr_counters

# %%

# Bar chart: for each layer, the number of samples for which that layer won.
samples_per_layer = counters.cpu().numpy()
plt.bar(x=range(n_layers), height=samples_per_layer)
plt.title("layer with most correlated head")
plt.xlabel("layer index")
plt.ylabel("samples count")

# %%
