#%%
# Environment setup for this cell-based ("# %%") analysis script.
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Configure lovely_numpy / lovely_tensors and apply the lovely_tensors patch.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): (re)defines `torch.inf` as a plain Python float; presumably a
# compatibility shim for one of the libraries above -- confirm it is still needed.
torch.inf = float("Inf")
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import sys
import os
# Add the parent of this file's directory to sys.path and make it the working
# directory; this happens before the project-local `data.Imagenet` import below.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)
from PIL import Image
from data.Imagenet import Imagenet_Segmentation

# %%

# Load the precomputed tensors onto the CPU: the plotting cells below call
# `.numpy()`, which only works for CPU tensors, and mapping to CPU also lets
# the cache be read on a machine without the device it was saved from.
cache_dir = Path(__file__).resolve().parents[1] / 'cache'
cache_data = torch.load(cache_dir / "cache.pt", map_location="cpu")
# %%
all_attnmap = cache_data['all_attnmap'] # (images, layers, heads, tokens-1)
xpredmap = cache_data['xpredmap'] # (images, tokens, classes)
cls_self_attend = cache_data['cls_self_attend'] # (images, layers, heads)
# Predicted class of every image, read from the CLS token (token index 0).
preds = xpredmap[:, 0, :].argmax(dim=-1) # (images,)
# For every non-CLS token keep only the score of that image's predicted class.
# The image index and `preds` are paired element-wise, so row i uses preds[i].
predmap = xpredmap[range(xpredmap.shape[0]), 1:, preds] # (images, tokens-1)

n_layers = all_attnmap.shape[1]
n_heads = all_attnmap.shape[2]


# %%
# Dot product between the prediction map and every head's attention map,
# plus a cosine-similarity version (dot product divided by both vector norms).
pred_b = predmap[:, None, None, :]  # broadcast over the layer and head axes
corr = (pred_b * all_attnmap).sum(-1) # (images, layers, heads)
pred_norm = pred_b.norm(dim=-1)
attn_norm = all_attnmap.norm(dim=-1)
corr_normalized = corr / pred_norm / attn_norm  # (images, layers, heads)

# %%
# Image preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel with mean 0.5 and std 0.5.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Labels are resized with nearest-neighbour interpolation so that no new
# (interpolated) label values are introduced.
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), Image.NEAREST),
])
# `os.environ` is a mapping, not a callable, so it has to be indexed.
# A missing variable raises KeyError naming the variable.
imagenet_seg_path = Path(os.environ["IMAGENET_SEGMENTATION_PATH"])
ds = Imagenet_Segmentation(imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
# %%
from tqdm import tqdm
batch_size = 16
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=20,
    drop_last=False,
)
# One tensor of per-image segmentation size ratios is collected per batch.
seg_size_lst = []
for _batch_idx, (_img, seg_gt) in enumerate(tqdm(dl)):
    # _img (B, 3, H, W)
    # seg_gt (B, H, W)
    seg_gt = seg_gt.to(torch.float32)
    # Sum of each image's mask values divided by the element count of one mask.
    pixels_per_image = seg_gt[0].numel()
    seg_size_lst.append(seg_gt.sum(dim=(1, 2)) / pixels_per_image)
# %%
# Join the per-batch ratios into one vector and plot their distribution.
seg_size_ratio = torch.cat(seg_size_lst) # (N,)
ratio_values = seg_size_ratio.numpy()
plt.hist(ratio_values, bins="auto")
plt.title("Histogram of segmentation size ratio")
plt.ylabel("Number of samples")
plt.xlabel("Segmentation size ratio")
display(seg_size_ratio)


# %%
# For every image: take the best head within each layer, then pick the layer
# whose best head has the highest correlation.
best_head_per_layer = corr.amax(dim=2) # (images, layers)
most_correlated_layer = best_head_per_layer.argmax(dim=1) # (images,)

# %%
# plt.bar(
#     range(n_layers),
#     corr # (images, layers, heads)
#         .amax(dim=2) # (images, layers)
#         .argmax(dim=1) # (images,)
#         .bincount(minlength=n_layers) # (layers,)
# )
# %%
# One point per image: most correlated layer (x) against segmentation size ratio (y).
layer_per_image = most_correlated_layer.numpy()
ratio_per_image = seg_size_ratio.numpy()
plt.scatter(layer_per_image, ratio_per_image, s=1, alpha=0.1)
plt.ylabel("Segmentation size ratio")
plt.xlabel("Layer with most correlated head")
# %%
# PER LAYER
# One histogram of segmentation size ratio per layer, over the images whose
# most correlated head lives in that layer.
n_cols = 4
# Ceil division so that trailing layers are not dropped when n_layers is not a
# multiple of n_cols (and so that fewer than n_cols layers still get one row).
n_rows = -(-n_layers // n_cols)
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
flat_axes = axes.flatten()
for layer, ax in zip(range(n_layers), flat_axes):
    in_layer = most_correlated_layer == layer
    seg_size_ratio_layer = seg_size_ratio[in_layer]
    display(f"layer {layer} {seg_size_ratio_layer}")
    ax.hist(
        seg_size_ratio_layer.numpy(),
        bins="auto",
        label=f"Layer {layer}",
    )
    cnt = in_layer.sum()
    ax.set_title(f"Layer {layer}\n({cnt})")
    ax.set_xlim(0, 1)
# Hide the grid cells that have no layer assigned to them.
for ax in flat_axes[n_layers:]:
    ax.set_visible(False)
fig.suptitle("Histogram of segmentation size ratio per most correlated layer")
fig.tight_layout()


# %%
def predmap15_shortcut(all_attnmap, predmap):
    """Blend all attention heads, weighted by agreement with the prediction map.

    Every (layer, head) attention map is scored by its dot product with
    `predmap`; the scores are turned into weights with a softmax across all
    heads, the attention maps are averaged with those weights, and the
    blended map is multiplied element-wise by `predmap`.

    Args:
        all_attnmap: (images, layers, heads, tokens-1)
        predmap: (images, tokens-1)

    Returns:
        Tensor of shape (images, tokens-1).
    """
    # Treat every (layer, head) pair as one candidate map.
    head_maps = all_attnmap.flatten(1, 2)  # (images, layers*heads, tokens-1)
    pred_row = predmap[:, None, :]  # (images, 1, tokens-1)

    head_scores = (head_maps * pred_row).sum(dim=-1)  # (images, layers*heads)
    head_weights = torch.softmax(head_scores, dim=1)  # (images, layers*heads)

    blended = (head_maps * head_weights[..., None]).sum(dim=-2)  # (images, tokens-1)
    return blended * predmap

pred_xmap = predmap15_shortcut(all_attnmap, predmap)


# %%
# draw first 100 images using make_grid
import torchvision
# View each per-token map as a single-channel 14x14 image.
token_maps = pred_xmap.reshape(-1, 1, 14, 14)
grid = torchvision.utils.make_grid(token_maps[:100], nrow=10, normalize=True, scale_each=True)
# make_grid returns channels-first; imshow expects channels-last.
plt.imshow(grid.permute(1,2,0))
# %%

def calc_pred_seg(pred_xai, grid_size=14, patch_size=16):
    """Turn per-token relevance maps into binary pixel-level masks.

    Based on imagenet_seg_eval.py. Each map is reshaped to a square token
    grid, bilinearly upsampled by `patch_size`, min-max normalised per image
    and thresholded at its own mean.

    Args:
        pred_xai: floating-point tensor of shape (images, grid_size * grid_size).
        grid_size: side length of the square token grid (default 14).
        patch_size: upsampling factor applied to each grid side (default 16).

    Returns:
        Tensor of shape (images, (grid_size * patch_size) ** 2) holding 0/1
        values in `pred_xai.dtype`; 1 marks pixels above the per-image mean.
    """
    res = pred_xai.reshape(-1, 1, grid_size, grid_size)
    res = torch.nn.functional.interpolate(
        res, scale_factor=patch_size, mode='bilinear', align_corners=False
    )  # (images, 1, grid_size*patch_size, grid_size*patch_size)
    res = res.flatten(1)  # (images, pixels)

    low = res.amin(dim=-1, keepdim=True)
    high = res.amax(dim=-1, keepdim=True)
    span = high - low
    # A constant map has zero range; dividing by 1 instead avoids 0/0 = NaN.
    # The resulting mask is all zeros either way.
    span = torch.where(span > 0, span, torch.ones_like(span))
    res = (res - low) / span  # (images, pixels), values in [0, 1]

    threshold = res.mean(dim=-1, keepdim=True)  # (images, 1)
    return (res > threshold).to(pred_xai.dtype)

pred_seg = calc_pred_seg(pred_xmap) # (images, 224*224)
pred_seg


# draw first 100 images using make_grid
import torchvision
# View each flat mask as a single-channel image again.
side = 14*16
mask_images = pred_seg.reshape(-1, 1, side, side)
grid = torchvision.utils.make_grid(mask_images[:100], nrow=10, normalize=True, scale_each=True)
# make_grid returns channels-first; imshow expects channels-last.
plt.imshow(grid.permute(1,2,0))
# %%
# Fraction of pixels marked as foreground in each predicted mask.
n_pixels = 224*224
pred_obj_size = pred_seg.sum(dim=-1)/n_pixels
size_values = pred_obj_size.numpy()
plt.hist(size_values, bins="auto")
plt.xlim(0, 1)
plt.title("Histogram of predicted object size ratio")
plt.ylabel("Number of samples")
plt.xlabel("Predicted object size ratio")
display(pred_obj_size)

# %%
# PER LAYER
# One histogram of predicted object size ratio per layer, over the images
# whose most correlated head lives in that layer.
n_cols = 4
# Ceil division so that trailing layers are not dropped when n_layers is not a
# multiple of n_cols (and so that fewer than n_cols layers still get one row).
n_rows = -(-n_layers // n_cols)
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
flat_axes = axes.flatten()
for layer, ax in zip(range(n_layers), flat_axes):
    in_layer = most_correlated_layer == layer
    ax.hist(
        pred_obj_size[in_layer].numpy(),
        bins="auto",
        label=f"Layer {layer}",
    )
    cnt = in_layer.sum()
    ax.set_title(f"Layer {layer}\n({cnt})")
    ax.set_xlim(0, 1)
# Hide the grid cells that have no layer assigned to them.
for ax in flat_axes[n_layers:]:
    ax.set_visible(False)
# This figure shows the *predicted* object size, so the title says so.
fig.suptitle("Histogram of predicted object size ratio per most correlated layer")
fig.tight_layout()
# %%
# Walk the loader once more and keep every ground-truth mask as a boolean tensor.
# Each batch yields (img (B, 3, H, W), seg_gt (B, H, W)); only the mask is used.
temp = [seg_gt.to(torch.bool) for _img, seg_gt in tqdm(dl)]
full_seg_gt = torch.cat(temp, dim=0)
# %%
# Percentage of pixels on which the predicted mask equals the ground truth.
gt_flat = full_seg_gt.flatten(1)
n_correct = (pred_seg == gt_flat).sum()
pixAcc = n_correct / pred_seg.numel()*100 # (Should be 79.46%)
pixAcc
# %%
