#%%
from pathlib import Path
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Set lovely_numpy's `deeper_width` option to 12.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
# Apply lovely_tensors' monkey patch and give it the same `deeper_width`
# setting (presumably this changes how tensors are printed - see the library docs).
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Define `torch.inf` explicitly.
# NOTE(review): looks like a shim for torch builds that lack `torch.inf` - confirm.
torch.inf = float("Inf")
import sys
import os
# Use the grandparent directory of this file as the project root: put it on
# sys.path and make it the working directory, so the project-local imports
# further down (e.g. `baselines.ViT...`) can be resolved from there.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%

from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
import types

# ImageNet per-channel statistics; only referenced by the commented-out
# alternative normalization below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Active normalization: mean 0.5 and std 0.5 on each of the three channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
# normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

# Preprocessing applied to every image, in order.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(_preprocessing_steps)

# %%


# initialize ViT pretrained
# model = vit_LRP(pretrained=True).cuda()
# Use the GPU when one is available and fall back to the CPU otherwise, so
# that building the model does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vit_new(pretrained=True).to(device)
# Switch to evaluation mode; this script only runs inference.
model.eval()

def my_forward(self, x):
    """Run the model and return head outputs computed after every block.

    The input is embedded with ``self.patch_embed``, a class token is
    prepended and ``self.pos_embed`` is added. The tokens then pass through
    ``self.blocks`` one by one; the output of *each* block is kept, passed
    through ``self.norm``, reduced to its first token (the prepended class
    token) and fed to ``self.head``.

    Args:
        self: model exposing ``patch_embed``, ``cls_token``, ``pos_embed``,
            ``blocks``, ``norm`` and ``head``.
        x: input batch; its first dimension is the batch size ``B``.

    Returns:
        Tensor whose two leading dimensions are ``(B, len(self.blocks))``,
        followed by whatever ``self.head`` produces per token (class logits,
        going by the original shape comments).
    """
    batch = x.shape[0]
    depth = len(self.blocks)

    patches = self.patch_embed(x)  # (B, tokens-1, embed_dim)
    # Class-token expansion idea credited to Phil Wang in the original code.
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, patches), dim=1) + self.pos_embed  # (B, tokens, embed_dim)

    # Keep the token sequence produced by every block, not just the last one.
    per_block = []
    for block in self.blocks:
        tokens = block(tokens)  # (B, tokens, embed_dim)
        per_block.append(tokens)
    stacked = torch.stack(per_block, dim=1)  # (B, layers, tokens, embed_dim)

    # Fold the layer axis into the batch axis so norm/head see ordinary batches.
    normed = self.norm(stacked.flatten(0, 1))  # (B*layers, tokens, embed_dim)
    cls_features = normed[..., 0, :]  # (B*layers, embed_dim)
    logits = self.head(cls_features)  # (B*layers, classes)
    return logits.unflatten(0, (batch, depth))  # (B, layers, classes)

# Bind my_forward to this model instance as a method, so that
# model.my_forward(x) passes `model` as `self`. Only the `my_forward`
# attribute is assigned; the model's regular forward is not replaced.
model.my_forward = types.MethodType(my_forward, model)

# batch_size = 16
# imagenet_path = Path("~/imagenet")
# imagenet_ds = torchvision.datasets.ImageNet(imagenet_path, split='val', transform=transform)
# sample_loader = torch.utils.data.DataLoader(
#     imagenet_ds,
#     batch_size=batch_size,
#     shuffle=False,
#     num_workers=7
# )

# %%
from PIL import Image
import collections
from label_str_to_imagenet_classes import label_str_to_imagenet_classes
# One dataset entry: the image's file name and the tag (subdirectory) it is in.
ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))


class SiScoreDataset(ImageFolder):
    """Dataset over a directory laid out as ``<dataset_path>/<tag>/<image>``.

    Every immediate subdirectory of ``dataset_path`` is a tag, and every
    entry inside a tag directory is treated as an image. The tag is mapped
    to an integer class index through ``label_str_to_imagenet_classes``.

    NOTE(review): the ``ImageFolder`` initializer is not invoked, so only
    ``__getitem__`` and ``__len__`` defined here are usable; base-class
    attributes such as ``root`` or ``classes`` are not populated.
    """

    def __init__(self, dataset_path, transform):
        """Index all images below ``dataset_path``.

        Args:
            dataset_path: directory containing one subdirectory per tag.
            transform: callable applied to each RGB ``PIL.Image`` in
                ``__getitem__``.
        """
        self._transform = transform
        self._dataset_path = Path(dataset_path)
        # Path.iterdir() yields entries in arbitrary order; sort so that a
        # given index always refers to the same image across runs/machines.
        self._tag_list = sorted(
            entry.name for entry in self._dataset_path.iterdir() if entry.is_dir()
        )
        self._all_images = []
        for tag in self._tag_list:
            tag_dir = self._dataset_path / tag
            # Store only the file name. The full path is rebuilt in
            # __getitem__; storing the full path and joining it again
            # duplicated the prefix whenever dataset_path was relative.
            self._all_images.extend(
                ImageItem(file.name, tag) for file in sorted(tag_dir.iterdir())
            )

    def __getitem__(self, item):
        """Return ``(transformed_image, class_index)`` for position ``item``.

        Raises:
            IndexError: if ``item`` is out of range.
            KeyError: if the image's tag is missing from
                ``label_str_to_imagenet_classes``.
        """
        image_item = self._all_images[item]
        image_path = self._dataset_path / image_item.tag / image_item.image_name
        # Image.open keeps the file open until the image is closed. convert()
        # returns a separate image, so the file can be released right after.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self._transform(image)

        class_name = int(label_str_to_imagenet_classes[image_item.tag])

        return image, class_name

    def __len__(self):
        """Return the number of indexed images."""
        return len(self._all_images)
    
# Location of the SI-Score images (placeholder; point it at a local copy).
path = "/path/to/si-score/dataset"
batch_size = 16
ds = SiScoreDataset(path, transform=transform)

# Unshuffled batches of `batch_size` samples, loaded with 7 workers.
sample_loader = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=batch_size,
    num_workers=7,
    shuffle=False,
)



# Count, for every transformer block, how many samples would be classified
# correctly if the model's norm + head were applied to that block's output.
n_layers = len(model.blocks)
# Follow whatever device the model lives on instead of assuming CUDA, so the
# loop also works when the model was placed on the CPU.
model_device = next(model.parameters()).device
correct_preds_per_layer = torch.zeros(n_layers, device=model_device)

for i, (img, lbl) in enumerate(tqdm.tqdm(sample_loader)):
    img = img.to(model_device)
    labels = lbl.to(model_device)  # (batch, )
    with torch.no_grad():
        preds = model.my_forward(img)  # (batch, layer, classes)
    preds_clxidx = preds.argmax(dim=-1)  # (batch, layer)

    # Broadcast the labels over the layer axis and compare element-wise.
    temp = preds_clxidx == labels.unsqueeze(-1)  # (batch, layer)
    temp = temp.sum(dim=0)  # (layer, )
    # No detach needed: the predictions were produced under no_grad.
    correct_preds_per_layer += temp
    # Uncomment to stop after a few batches while debugging:
    # if i == 10:
        # break
# %% 
# Bar chart: one bar per block, height = number of correct predictions.
layer_ids = range(n_layers)
correct_counts = correct_preds_per_layer.cpu().numpy()
plt.bar(layer_ids, correct_counts)
plt.title("Correct Predictions per Layer")
plt.xlabel("Layer")
plt.ylabel("Correct Predictions")
# %%

# Show the last dataset sample. Load it once (the bare `ds[-1]` expression
# loaded it a second time without using the result).
last_image, last_label = ds[-1]
# The tensor went through `normalize`, so undo it before plotting: imshow
# expects float RGB values in [0, 1] and clips anything outside that range.
channel_mean = torch.tensor(normalize.mean).view(-1, 1, 1)
channel_std = torch.tensor(normalize.std).view(-1, 1, 1)
displayable = (last_image * channel_std + channel_mean).clamp(0, 1)
# (channels, height, width) -> (height, width, channels) for matplotlib.
plt.imshow(displayable.permute(1, 2, 0).numpy())
# %%
