#%%
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
lt.monkey_patch()
lt.set_config(deeper_width=12)
torch.inf = float("Inf")
import sys
import os
# Resolve the script location to an absolute path *before* changing directory:
# with a relative `__file__`, `base_dir` (and the `sys.path` entry derived from
# it) would be relative and change meaning once `os.chdir` has run.
base_dir = Path(__file__).resolve().parent.parent  # directory two levels above this file
sys.path.append(str(base_dir))  # let imports be resolved from base_dir
os.chdir(base_dir)  # subsequent relative paths are taken from base_dir

# %%

# The cache file lives in `cache/cache.pt`, two levels above this file.
cache_dir = Path(__file__).resolve().parents[1] / 'cache'
# map_location="cpu": the cells below call `.numpy()` on these tensors, which
# only works for CPU tensors, so force CPU whatever device they were saved from.
# NOTE(review): torch.load unpickles the file -- only load caches you trust.
cache_data = torch.load(cache_dir / "cache.pt", map_location="cpu")
# %%
# Unpack the cached tensors (shapes as annotated by the original author).
all_attnmap = cache_data['all_attnmap']  # (images, layers, heads, tokens-1)
xpredmap = cache_data['xpredmap']  # (images, tokens, classes)
cls_self_attend = cache_data['cls_self_attend']  # (images, layers, heads)

# Predicted class per image: argmax over classes at token 0 (the CLS token).
preds = xpredmap[:, 0, :].argmax(dim=-1)  # (images,)

# For each image keep, at every remaining token, the score of its predicted class.
predmap = xpredmap[torch.arange(xpredmap.shape[0]), 1:, preds]  # (images, tokens-1)


# %%
# all_attnmap (images, layers, heads, tokens-1)
# predmap (images, tokens-1)

n_layers = all_attnmap.shape[1]
plt.close("all")

# "Correlation" here is the inner product, over the token axis, between each
# image's predmap and every (layer, head) attention map of that image.
# predmap[:, None, None, :] is (images, 1, 1, tokens-1) and broadcasts over
# the layer and head axes.
corr = (predmap[:, None, None, :] * all_attnmap).sum(-1)  # (images, layers, heads)

# Dividing by both L2 norms (predmap first, then attention map) turns the
# inner product into a cosine similarity.
corr_normalized = (
    corr
    / predmap[:, None, None, :].norm(dim=-1)
    / all_attnmap.norm(dim=-1)
)  # (images, layers, heads)

# %%
def plot_corr_hist(z):
    """Draw histograms of `z` at three granularities.

    Three figures are created: all values pooled, one panel per layer, and
    one panel per (layer, head). Every panel uses the global min/max of `z`
    as its x-range.

    Args:
        z: tensor indexed as (images, layers, heads); `.numpy()` is called on
            it, so it has to be a CPU tensor.
    """
    n_layers = z.shape[1]
    n_heads = z.shape[2]
    # The x-range is the same for every panel: compute it once.
    x_min, x_max = float(z.min()), float(z.max())

    # ALL
    fig, ax = plt.subplots()
    ax.hist(z.flatten().numpy(), bins="auto")
    ax.set_title("attention map - predmap correlation (all layers)")
    ax.set_xlim(x_min, x_max)

    # PER LAYER
    n_cols = 4
    # Ceil division so that no layer is dropped when n_layers is not a
    # multiple of n_cols (and the grid never has zero rows).
    n_rows = -(-n_layers // n_cols)
    # squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i >= n_layers:
            ax.set_visible(False)  # surplus panel of the last row
            continue
        ax.hist(z[:, i].flatten().numpy(), bins="auto")
        ax.set_title(f"Layer {i}")
        ax.set_xlim(x_min, x_max)
    fig.suptitle("attention map - predmap correlation")
    fig.tight_layout()

    # PER HEAD PER LAYER (one row per layer, one column per head)
    fig, axes = plt.subplots(
        nrows=n_layers,
        ncols=n_heads,
        figsize=(n_heads * 2, n_layers * 2),
        squeeze=False,
    )
    for layer in range(n_layers):
        for head in range(n_heads):
            ax = axes[layer, head]
            ax.hist(z[:, layer, head].flatten().numpy(), bins="auto")
            ax.set_title(f"L{layer} H{head}")
            ax.set_xlim(x_min, x_max)
    fig.suptitle("attention map - predmap correlation")
    fig.tight_layout()

# %%
# Histograms of the raw inner-product scores.
print("Unnormalized correlation")
plot_corr_hist(z=corr)

# %%
# Histograms of the scores after division by both L2 norms.
print("Normalized correlation")
plot_corr_hist(z=corr_normalized)


# %%

def plot_norm_histogram(all_attnmap, p=2):
    """Display and histogram the Lp norm of every attention map.

    The norm is taken over the last axis. Three figures are created: all
    values pooled, one panel per layer, and one panel per (layer, head).
    Every panel uses the global min/max of the norms as its x-range.

    Args:
        all_attnmap: tensor indexed as (images, layers, heads, tokens-1);
            `.numpy()` is called on the norms, so it has to be a CPU tensor.
        p: order of the norm.

    NOTE: `display` is not imported in this file (presumably the
    IPython/Jupyter builtin) and `.deeper` presumably comes from the
    lovely_tensors monkey patch applied at the top of the file -- confirm.
    """
    n_layers = all_attnmap.shape[1]
    n_heads = all_attnmap.shape[2]
    z = all_attnmap.norm(dim=-1, p=p)  # (images, layers, heads)
    display(z.permute(1,2,0).deeper)
    # The x-range is the same for every panel: compute it once.
    x_min, x_max = float(z.min()), float(z.max())

    # ALL
    fig, ax = plt.subplots()
    ax.hist(z.flatten().numpy(), bins="auto")
    ax.set_title("attention map norm (all layers)")
    ax.set_xlim(x_min, x_max)

    # PER LAYER
    n_cols = 4
    # Ceil division so that no layer is dropped when n_layers is not a
    # multiple of n_cols (and the grid never has zero rows).
    n_rows = -(-n_layers // n_cols)
    # squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i >= n_layers:
            ax.set_visible(False)  # surplus panel of the last row
            continue
        ax.hist(z[:, i].flatten().numpy(), bins="auto")
        ax.set_title(f"Layer {i}")
        ax.set_xlim(x_min, x_max)
    fig.suptitle("attention map norm (per layer)")
    fig.tight_layout()

    # PER HEAD PER LAYER (one row per layer, one column per head)
    fig, axes = plt.subplots(
        nrows=n_layers,
        ncols=n_heads,
        figsize=(n_heads * 2, n_layers * 2),
        squeeze=False,
    )
    for layer in range(n_layers):
        for head in range(n_heads):
            ax = axes[layer, head]
            ax.hist(z[:, layer, head].flatten().numpy(), bins="auto")
            ax.set_title(f"L{layer} H{head}")
            ax.set_xlim(x_min, x_max)
    fig.suptitle("attention map norm (per head)")
    fig.tight_layout()

# Histogram the L1 norm of every attention map (use p=2 for the L2 norm instead).
plot_norm_histogram(all_attnmap=all_attnmap, p=1)


# %%
# Plot correlation vs attnmap norm
def plot_attnmap_norm_vs_corr(corr, all_attnmap):
    """Plot 2-D histograms of attention-map L2 norm against correlation.

    Three figures are created: all (image, layer, head) triples pooled, one
    panel per layer, and one panel per (layer, head). Only the per-head
    panels are forced onto the global x/y ranges.

    Args:
        corr: tensor indexed as (images, layers, heads).
        all_attnmap: tensor indexed as (images, layers, heads, tokens-1);
            its first three dimensions must match `corr`.

    `.numpy()` is called on both inputs, so they have to be CPU tensors.
    """
    assert corr.shape == all_attnmap.shape[:3], "corr must match the leading dims of all_attnmap"
    n_layers = all_attnmap.shape[1]
    n_heads = all_attnmap.shape[2]
    all_attnmap_norm = all_attnmap.norm(dim=-1, p=2)  # (images, layers, heads)

    # ALL
    fig, ax = plt.subplots()
    ax.hist2d(corr.flatten().numpy(), all_attnmap_norm.flatten().numpy(), bins=100)
    ax.set_xlabel("correlation")
    ax.set_ylabel("attnmap norm")
    ax.set_title("attnmap norm vs correlation (all layers)")

    # PER LAYER
    n_cols = 4
    # Ceil division so that no layer is dropped when n_layers is not a
    # multiple of n_cols (and the grid never has zero rows).
    n_rows = -(-n_layers // n_cols)
    # squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i >= n_layers:
            ax.set_visible(False)  # surplus panel of the last row
            continue
        xx = corr[:, i].flatten().numpy()
        yy = all_attnmap_norm[:, i].flatten().numpy()
        ax.hist2d(xx, yy, bins=100)
        ax.set_title(f"Layer {i}")
    fig.suptitle("attnmap norm vs correlation (per layer)")
    fig.tight_layout()

    # PER HEAD PER LAYER (one row per layer, one column per head)
    # Global ranges are loop-invariant: compute them once.
    x_min, x_max = float(corr.min()), float(corr.max())
    y_min, y_max = float(all_attnmap_norm.min()), float(all_attnmap_norm.max())
    fig, axes = plt.subplots(
        nrows=n_layers,
        ncols=n_heads,
        figsize=(n_heads * 2, n_layers * 2),
        squeeze=False,
    )
    for layer in range(n_layers):
        for head in range(n_heads):
            ax = axes[layer, head]
            xx = corr[:, layer, head].flatten().numpy()
            yy = all_attnmap_norm[:, layer, head].flatten().numpy()
            ax.hist2d(xx, yy, bins=100)
            ax.set_title(f"L{layer} H{head}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
    # Title now matches what is plotted (it used to be a copy of the
    # norm-histogram title).
    fig.suptitle("attnmap norm vs correlation (per head)")
    fig.tight_layout()

# %%
# Attention-map norm against the raw inner-product scores.
print("Unnormalized correlation")
plot_attnmap_norm_vs_corr(corr=corr, all_attnmap=all_attnmap)

# %%
# Attention-map norm against the scores divided by both L2 norms.
print("Normalized correlation")
plot_attnmap_norm_vs_corr(corr=corr_normalized, all_attnmap=all_attnmap)

# %%
# Plot correlation vs attnmap norm
# Layer-0 scatter: normalized correlation vs attention-map L2 norm, one point
# per (image, head).
z = corr_normalized  # (images, layers, heads)
# Index layer 0 directly: no reshape to a fixed 12x12 layout is needed, so
# this works for any number of layers and heads.
xx = z[:, 0, :].flatten().numpy()
yy = all_attnmap[:, 0, :].norm(dim=-1).flatten().numpy()
plt.scatter(xx, yy, s=1)
plt.xlabel("correlation")
plt.ylabel("attnmap norm")


# %%
# Plot attnmap l1 norm vs attnmap l2 norm
def plot_attnmap_norm_l1_vs_l2(all_attnmap):
    """Scatter the L1 norm of every attention map against its L2 norm.

    Both norms are taken over the last axis. Three figures are created: all
    (image, layer, head) triples pooled, one panel per layer, and one panel
    per (layer, head). Per-layer and per-head panels use the global ranges.

    Args:
        all_attnmap: tensor indexed as (images, layers, heads, tokens-1);
            `.numpy()` is called on the norms, so it has to be a CPU tensor.
    """
    n_layers = all_attnmap.shape[1]
    n_heads = all_attnmap.shape[2]
    all_attnmap_l2_norm = all_attnmap.norm(dim=-1, p=2)  # (images, layers, heads)
    all_attnmap_l1_norm = all_attnmap.norm(dim=-1, p=1)  # (images, layers, heads)
    # Global ranges are loop-invariant: compute them once.
    l2_min, l2_max = float(all_attnmap_l2_norm.min()), float(all_attnmap_l2_norm.max())
    l1_min, l1_max = float(all_attnmap_l1_norm.min()), float(all_attnmap_l1_norm.max())

    # ALL
    fig, ax = plt.subplots()
    ax.scatter(all_attnmap_l2_norm.flatten().numpy(), all_attnmap_l1_norm.flatten().numpy(), s=1)
    ax.set_xlabel("attnmap l2 norm")
    ax.set_ylabel("attnmap l1 norm")
    ax.set_title("attnmap l1 norm vs attnmap l2 norm (all layers)")

    # PER LAYER
    n_cols = 4
    # Ceil division so that no layer is dropped when n_layers is not a
    # multiple of n_cols (and the grid never has zero rows).
    n_rows = -(-n_layers // n_cols)
    # squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i >= n_layers:
            ax.set_visible(False)  # surplus panel of the last row
            continue
        xx = all_attnmap_l2_norm[:, i].flatten().numpy()
        yy = all_attnmap_l1_norm[:, i].flatten().numpy()
        ax.scatter(xx, yy, s=1)
        ax.set_title(f"Layer {i}")
        ax.set_xlim(l2_min, l2_max)
        ax.set_ylim(l1_min, l1_max)
    # Figure title added for parity with the per-head figure.
    fig.suptitle("attention map l1 norm vs l2 norm (per layer)")
    fig.tight_layout()

    # PER HEAD PER LAYER (one row per layer, one column per head)
    fig, axes = plt.subplots(
        nrows=n_layers,
        ncols=n_heads,
        figsize=(n_heads * 2, n_layers * 2),
        squeeze=False,
    )
    for layer in range(n_layers):
        for head in range(n_heads):
            ax = axes[layer, head]
            xx = all_attnmap_l2_norm[:, layer, head].flatten().numpy()
            yy = all_attnmap_l1_norm[:, layer, head].flatten().numpy()
            ax.scatter(xx, yy, s=1)
            ax.set_title(f"L{layer} H{head}")
            ax.set_xlim(l2_min, l2_max)
            # Upper limit padded by 0.1, presumably so points at the maximum
            # are not drawn on the frame.
            ax.set_ylim(l1_min, l1_max + 0.1)
    fig.suptitle("attention map l1 norm vs l2 norm (per head)")
    fig.tight_layout()

# Compare the L1 and L2 norms of the cached attention maps.
plot_attnmap_norm_l1_vs_l2(all_attnmap=all_attnmap)


# %%
# Plot predmap norms

def plot_predmap_norm_histogram(predmap, p=2):
    """Display and histogram the per-image Lp norm of `predmap`.

    Args:
        predmap: tensor indexed as (images, tokens-1); the norm is taken over
            the last axis.
        p: order of the norm (also shown in the figure title).

    NOTE: `display` is not imported in this file; it is presumably the
    IPython/Jupyter builtin -- confirm before running as a plain script.
    """
    norms = predmap.norm(p=p, dim=-1)  # (images,)
    display(norms)

    _, ax = plt.subplots()
    ax.hist(norms.flatten().numpy(), bins="auto")
    ax.set(
        xlabel="predmap norm",
        ylabel="Number of samples",
        title=f"Histogram of predmap L{p} norm",
    )

# Histogram of the L2 norm of each image's predmap.
plot_predmap_norm_histogram(predmap=predmap, p=2)

# %%
# CLS_SELF_ATTN
# cls_self_attend (images, layers, heads)
plt.close("all")
# NOTE: `display` is not imported in this file; presumably the IPython/Jupyter
# builtin -- confirm before running as a plain script.
display(cls_self_attend.permute(1,2,0).deeper)

# One histogram of the CLS self-attention score per layer (all heads pooled).
n_layers = cls_self_attend.shape[1]
n_cols = 3
# Ceil division so that no layer is dropped when n_layers is not a multiple
# of n_cols (and the grid never has zero rows).
n_rows = -(-n_layers // n_cols)
# squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid.
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)

for i, ax in enumerate(axes.flatten()):
    if i >= n_layers:
        ax.set_visible(False)  # surplus panel of the last row
        continue
    ax.hist(cls_self_attend[:, i].flatten().numpy(), bins="auto")
    ax.set_title(f"Layer {i}")
fig.suptitle("CLS self-attention score")
fig.tight_layout()

# All layers and heads pooled; log-scaled counts.
fig, ax = plt.subplots()
ax.hist(cls_self_attend.flatten().numpy(), bins="auto")
ax.set_title("CLS self-attention score (all layers)")
ax.set_yscale('log')

# %%
# Mean and std (over images) of the CLS self-attention score, shown as
# (layer, head) heat maps restricted to the last `n_shown` layers.
n_shown = 4
n_layers_total = cls_self_attend.shape[1]
first_shown = max(n_layers_total - n_shown, 0)
# matshow numbers rows from 0; these are the layer indices actually plotted.
shown_layer_labels = [str(layer) for layer in range(first_shown, n_layers_total)]

fig, axes = plt.subplots(1, 2)
panels = (
    ("Average CLS self-attend", cls_self_attend.mean(0)),  # (layers, heads)
    ("Std CLS self-attend", cls_self_attend.std(0)),  # (layers, heads)
)
for ax, (title, stat) in zip(axes, panels):
    cax = ax.matshow(stat[-n_shown:, :])
    ax.set_title(title)
    ax.set_xlabel("heads")
    ax.set_ylabel("layers")
    # Label each row with its true layer index instead of 0..n_shown-1.
    ax.set_yticks(range(len(shown_layer_labels)))
    ax.set_yticklabels(shown_layer_labels)
    fig.colorbar(cax, ax=ax, orientation='horizontal')
fig.tight_layout()

# %%
import torchvision

# Show the (layers, heads) CLS self-attention maps of up to 25 randomly chosen
# images as a grid with 5 maps per row.
# Truncate the permutation *before* indexing: same selection, but without
# materialising a shuffled copy of the whole tensor.
sample_idx = torch.randperm(len(cls_self_attend))[:25]
grid = torchvision.utils.make_grid(
    cls_self_attend[sample_idx].unsqueeze(1),  # (samples, 1, layers, heads)
    nrow=5,
)
# make_grid returns a channel-first image; move channels last and keep
# channel 0 for display.
plt.matshow(grid.permute(1, 2, 0)[..., 0])
# cax=plt.matshow(cls_self_attend[1000])
# plt.colorbar(cax)

# %%
# For each image, the "winning" layer is the one containing the head with the
# largest (1 - CLS self-attention), i.e. the smallest CLS self-attention.
winning_layer = (
    (1 - cls_self_attend)  # (images, layers, heads)
    .amax(dim=2)  # (images, layers)
    .argmax(dim=1)  # (images,)
)
# minlength follows the actual layer count rather than a fixed 12, so the
# x-axis always covers every layer.
plt.plot(winning_layer.bincount(minlength=cls_self_attend.shape[1]))  # (layers,)
plt.yscale('log')
plt.title("winning layer count")
plt.xlabel("layer index")
plt.ylabel("samples count")



# %%
# predmap (images, tokens-1)
# all_attnmap (images, layers, heads, tokens-1)
# corr = (predmap.unsqueeze(1).unsqueeze(2) * all_attnmap).sum(-1) # (images, layers, heads)
# corr = corr.flatten(1).softmax(dim=-1).unflatten(1, (12, 12)) # (images, layers, heads)

                          
# Mean and std of the correlation over images, as (layer, head) heat maps.
# Swap `corr` for `corr_normalized` to look at the normalized scores instead.
avg_corr = corr.mean(0)  # (layers, heads)
std_corr = corr.std(0)  # (layers, heads)

fig, axes = plt.subplots(1, 2)
heatmaps = (avg_corr, std_corr)
titles = ("Average correlation", "Std correlation")
for ax, heatmap, title in zip(axes, heatmaps, titles):
    cax = ax.matshow(heatmap)
    ax.set_title(title)
    ax.set_xlabel("heads")
    ax.set_ylabel("layers")
    fig.colorbar(cax, ax=ax, orientation='horizontal')
fig.tight_layout()

# %%
# plt.plot(
#     corr # (images, layers, heads)
#         .amax(dim=2) # (images, layers)
#         .mean(dim=0) # (layers,)
# )

# plt.plot(
#     corr # (images, layers, heads)
#         .amax(dim=2) # (images, layers)
#         .argmax(dim=1) # (images,)
#         .bincount(minlength=12) # (layers,)
# )
n_layers, n_heads = corr.shape[1:3]

# For each image: take the best head of every layer, then the layer whose best
# head scores highest; finally count how often each layer wins.
best_head_score = corr.amax(dim=2)  # (images, layers)
best_layer = best_head_score.argmax(dim=1)  # (images,)
best_layer_counts = best_layer.bincount(minlength=n_layers)  # (layers,)

plt.bar(range(n_layers), best_layer_counts)
plt.title("layer with most correlated head")
plt.xlabel("layer index")
plt.ylabel("samples count")


# %%

# Same winner count as the previous cell, computed from the normalized scores.
best_head_score_normalized = corr_normalized.amax(dim=2)  # (images, layers)
best_layer_normalized = best_head_score_normalized.argmax(dim=1)  # (images,)
best_layer_counts_normalized = best_layer_normalized.bincount(minlength=n_layers)  # (layers,)

plt.bar(range(n_layers), best_layer_counts_normalized)
plt.title("layer with most correlated head (normalized)")
plt.xlabel("layer index")
plt.ylabel("samples count")

# %%

# Sort unique values by their counts
# Predicted classes ordered from most to least frequent.
unique, counts = preds.unique(return_counts=True)
counts, sorted_idx = counts.sort(descending=True)  # sorted counts and the order that produced them
unique = unique[sorted_idx]
# Bare expression: shown as the cell output in a notebook.
unique[:10], counts[:10]

# mask = preds == 895
# predmap = predmap[mask]
# all_attnmap = all_attnmap[mask]

# for clsidx, cnt in zip(unique, counts):
#     print(f"class {clsidx} count: {cnt}")
#     mask = preds == clsidx
#     curr_predmap = predmap[mask]
#     curr_all_attnmap = all_attnmap[mask]
#     corr = (curr_predmap.unsqueeze(1).unsqueeze(2) * curr_all_attnmap).sum(-1) # (images, layers, heads)
#     winning_layer = corr.amax(dim=-1).argmax(dim=-1)
#     print(f"winning layer: {winning_layer.mode().values}")
#     print("-------------------")
#     # break

# %%
# Mean/std correlation heat maps for each of the `n` most frequently predicted
# classes (relies on `unique`/`counts` being sorted by descending count).
n = 10
for clsidx, cnt in zip(unique[:n], counts[:n]):
    print(f"class {clsidx} count: {cnt}")
    mask = preds == clsidx
    curr_predmap = predmap[mask]
    curr_all_attnmap = all_attnmap[mask]
    # Class-local names on purpose: the global `corr`, `avg_corr` and
    # `std_corr` (all images) stay intact, so earlier cells can be re-run
    # after this one.
    class_corr = (curr_predmap.unsqueeze(1).unsqueeze(2) * curr_all_attnmap).sum(-1)  # (images, layers, heads)
    class_avg_corr = class_corr.mean(0)  # (layers, heads)
    class_std_corr = class_corr.std(0)  # (layers, heads)

    fig, axes = plt.subplots(1, 2)
    panels = (
        ("Average correlation", class_avg_corr),
        ("Std correlation", class_std_corr),
    )
    for ax, (title, heatmap) in zip(axes, panels):
        cax = ax.matshow(heatmap)
        ax.set_title(title)
        ax.set_xlabel("heads")
        ax.set_ylabel("layers")
        fig.colorbar(cax, ax=ax, orientation='horizontal')
    # Identify the class on the figure; otherwise the figures are
    # indistinguishable from one another.
    fig.suptitle(f"class {clsidx} ({cnt} samples)")
    fig.tight_layout()


# %%
