#%%
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)
import sys
import os
# Use the parent of this file's directory as the project base: make it
# importable and the working directory so the relative paths below resolve.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

# Experiment directory to analyse; uncomment one of the alternatives to switch runs.
# exp_dir_path = Path("run/imagenet/all_layers/predmap9_all_layers/experiment_3")
# exp_dir_path = Path("run/imagenet/all_layers/predmap9_all_layers/experiment_5")
exp_dir_path = Path("run/imagenet/all_layers/attn_all_layers/experiment_4")

# Each metric is stored as its own .npy file inside the experiment directory.
pix_acc_path = exp_dir_path.joinpath("pix_acc.npy")
pix_acc = np.load(pix_acc_path)  # (images, blocks)

# iou_alternative is the IoU computed per sample; the paper reports a
# slightly different metric (see `iou` / `miou` below).
iou_path = exp_dir_path.joinpath("iou.npy")
iou_alternative = np.load(iou_path)  # (images, blocks)

ap_path = exp_dir_path.joinpath("ap.npy")
ap = np.load(ap_path)  # (images, blocks)

inter_path = exp_dir_path.joinpath("inter.npy")
inter = np.load(inter_path)  # (images, blocks, 2)

union_path = exp_dir_path.joinpath("union.npy")
union = np.load(union_path)  # (images, blocks, 2)

heatmap_path = exp_dir_path.joinpath("heatmap.npy")
heatmap = np.load(heatmap_path)  # (images, blocks, 196)

# Pool intersections and unions over all images before dividing; the spacing
# term keeps the denominator non-zero.
iou = (
    np.float64(1.0) * np.sum(inter, axis=0)
    / (np.spacing(1, dtype=np.float64) + np.sum(union, axis=0))
)  # (blocks, 2)
miou = np.mean(iou, axis=-1)  # (blocks, )

#%%
# Print a label and show a lovely_numpy summary for each per-image metric.
# Arrays are transposed to (blocks, images) before being summarised.
# Note: the "mIoU" entry shows the per-sample IoU (`iou_alternative`).
for _label, _per_image in (
    ("pix_acc", pix_acc),
    ("mIoU", iou_alternative),
    ("map", ap),
):
    print(_label)
    display(lo(_per_image.T).deeper)
# %%
# For every per-image metric, count how often each block index is the argmax
# along the last axis and draw those counts as one bar chart per metric.
for title, data in (("pix_acc", pix_acc), ("iou", iou_alternative), ("ap", ap)):
    plt.figure()
    block_idx, counts = np.unique(np.argmax(data, axis=-1), return_counts=True)
    plt.bar(block_idx, height=counts)
    plt.xlabel("block index")
    plt.ylabel("count")
    plt.title(title)
    
# %%
# Bar chart of the dataset-level mIoU, one bar per block.
fig, ax = plt.subplots()
ax.bar(np.arange(len(miou)), miou)
ax.set_xlabel("block index")
ax.set_title("mIoU")

# %%
# Sanity check: the numbers printed for the block selected below are expected
# to match the predmap9 results.
lyr_idx = 4
for _name, _scores in (("pix_acc", pix_acc), ("mAP", ap)):
    print(f"{_name}: {_scores[:, lyr_idx].mean()*100:.2f}%")
print(f"mIoU: {miou[lyr_idx]*100:.2f}%")
# Averaging the per-sample IoU instead does not match predmap9, so it stays disabled:
# print(f"mIoU: {iou_alternative[:, -2].mean()*100:.2f}%")
# %%

# Scores obtained when the best block is picked independently for every image.
print(f"max pix_acc: {np.max(pix_acc, axis=1).mean()*100:.2f}%")
print(f"max mAP: {np.max(ap, axis=1).mean()*100:.2f}%")
# Averaging the per-image IoU maxima is the wrong calculation, so it stays disabled:
# print(f"max mIoU: {iou_alternative.max(axis=1).mean()*100:.2f}%")
# Instead: pick each image's best block by per-sample IoU, gather that block's
# intersection/union, and pool over images before dividing.
iou_max_idx = np.argmax(iou_alternative, axis=1)  # (images, )
inter_max = inter[np.arange(len(inter)), iou_max_idx]  # (images, 2)
union_max = union[np.arange(len(union)), iou_max_idx]  # (images, 2)
max_iou = (
    np.float64(1.0) * inter_max.sum(axis=0)
    / (np.spacing(1, dtype=np.float64) + union_max.sum(axis=0))
)  # (2, )
max_miou = np.mean(max_iou)
print(f"max mIoU: {max_miou*100:.2f}%")

# %%

def calc_miou(inter, union, indices):
    """Return the mean IoU when one block is selected per sample.

    For sample ``k`` the entries ``inter[k, indices[k]]`` and
    ``union[k, indices[k]]`` are picked; the picked intersections and unions
    are summed over samples, divided element-wise, and the resulting ratios
    are averaged.

    Args:
        inter: Intersection counts; indexed as ``[sample, block]`` here
            (used in this file with shape (images, blocks, 2)).
        union: Union counts, indexed like ``inter``.
        indices: One block index per sample.

    Returns:
        The mean of the pooled intersection-over-union ratios.
    """
    rows = np.arange(inter.shape[0])
    inter_total = inter[rows, indices].sum(axis=0)  # (2, )
    union_total = union[rows, indices].sum(axis=0)  # (2, )
    # The spacing term keeps the denominator non-zero when a union sum is 0.
    eps = np.spacing(1, dtype=np.float64)
    pooled_iou = np.float64(1.0) * inter_total / (eps + union_total)
    return pooled_iou.mean()

# Each section below selects one block per image by maximising a single
# metric, then reports all three metrics for that selection.

# Optimize on pix_acc
max_pix_acc_idx = np.argmax(pix_acc, axis=1)  # (images, )
n_samples = len(max_pix_acc_idx)
max_pix_acc = pix_acc[np.arange(n_samples), max_pix_acc_idx].mean()
print(f"MAX pix_acc: {max_pix_acc*100:.2f}%")
print(f"mAP: {ap[np.arange(n_samples), max_pix_acc_idx].mean()*100:.2f}%")
print(f"mIoU: {calc_miou(inter, union, max_pix_acc_idx)*100:.2f}%")
print()

# Optimize on mAP
max_ap_idx = np.argmax(ap, axis=1)  # (images, )
n_samples = len(max_ap_idx)
# Here `max_pix_acc` holds the pixel accuracy at the mAP-selected blocks.
max_pix_acc = pix_acc[np.arange(n_samples), max_ap_idx].mean()
print(f"pix_acc: {max_pix_acc*100:.2f}%")
print(f"MAX mAP: {ap[np.arange(n_samples), max_ap_idx].mean()*100:.2f}%")
print(f"mIoU: {calc_miou(inter, union, max_ap_idx)*100:.2f}%")
print()

# Optimize on mIoU (selection uses the per-sample IoU)
max_iou_idx = np.argmax(iou_alternative, axis=1)  # (images, )
n_samples = len(max_iou_idx)
# Here `max_pix_acc` holds the pixel accuracy at the IoU-selected blocks.
max_pix_acc = pix_acc[np.arange(n_samples), max_iou_idx].mean()
print(f"pix_acc: {max_pix_acc*100:.2f}%")
print(f"mAP: {ap[np.arange(n_samples), max_iou_idx].mean()*100:.2f}%")
print(f"MAX mIoU: {calc_miou(inter, union, max_iou_idx)*100:.2f}%")


# %%
import matplotlib.pyplot as plt
import torch
import torchvision

# Show binarized heatmaps of one block, restricted to the images for which
# that block gives the best pixel accuracy.
# heatmap: (images, blocks, tokens-1)
# plt.matshow(heatmap[:200, -1].T)

# Block to inspect; swap in the commented value to look at another block.
# win_block = 4
win_block = 0
mask = max_pix_acc_idx == win_block
heatmap_show = heatmap[mask, win_block] # (n_selected, tokens-1)
heatmap_show = torch.tensor(heatmap_show)
# Binarize every heatmap at its own mean value.
heatmap_show = heatmap_show > heatmap_show.mean(dim=-1, keepdim=True) # (n_selected, tokens-1)
heatmap_show = heatmap_show.float() # (n_selected, tokens-1)

# heatmap_show = heatmap[(mask), 4] # (n_samples--, tokens-1)

# heatmap_show is already a tensor here, so reshape it directly; wrapping it in
# torch.tensor(...) again makes a needless copy and raises a UserWarning.
heatmap_show = heatmap_show.reshape(-1, 1, 14, 14) # (n_selected, 1, 14, 14)
heatmap_show = heatmap_show[:50]  # keep the grid small: at most 50 images
# heatmap_show = heatmap_show[45:46]
heatmap_show = torch.nn.functional.interpolate(heatmap_show, scale_factor=16, mode='bilinear', align_corners=False) # (<=50, 1, 224, 224)

grid = torchvision.utils.make_grid(heatmap_show, scale_each=True, normalize=True)
# Only the first channel of the grid is shown.
plt.imshow(grid[0], cmap="jet")
plt.title(f"Heatmap, layer_idx={win_block}")
# %%
# Notebook inspection: shape of the per-image best-pix_acc block indices.
max_pix_acc_idx.shape
# %%
# Pick one image among those selected by `mask` and show its heatmap for
# every block, upsampled from 14x14 to 224x224.
# mask_idx = np.nonzero(mask)[0][40]
# mask_idx = np.nonzero(mask)[0][41]
mask_idx = np.nonzero(mask)[0][5]
temp = torch.tensor(heatmap[mask_idx]).reshape(-1, 1, 14, 14)  # (blocks, 1, 14, 14)
temp = torch.nn.functional.interpolate(
    temp, scale_factor=16, mode="bilinear", align_corners=False
)  # (blocks, 1, 224, 224)
# Optional binarization of every map at its own spatial mean:
# temp = temp > temp.mean(dim=(-2,-1), keepdim=True)
temp = temp.float()  # (blocks, 1, 224, 224)

grid = torchvision.utils.make_grid(temp, scale_each=True, normalize=True)
# Only the first channel of the grid is shown.
plt.imshow(grid[0], cmap="jet")
# %%
from PIL import Image
import torchvision.transforms as transforms
from data.Imagenet import Imagenet_Segmentation

# Defined but currently left out of the image pipeline (see below).
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Images: resize to 224x224 and convert to a tensor; normalization is disabled.
test_img_trans = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # normalize,
    ]
)
# Labels: resize only, with nearest-neighbour resampling.
test_lbl_trans = transforms.Compose(
    [
        transforms.Resize((224, 224), Image.NEAREST),
    ]
)

# The dataset location comes from the environment (None when the variable is unset).
imagenet_seg_path = os.getenv("IMAGENET_SEGMENTATION_PATH")
ds = Imagenet_Segmentation(
    imagenet_seg_path,
    transform=test_img_trans,
    target_transform=test_lbl_trans,
)
print(mask_idx)
img, gt = ds[mask_idx]
# Assumes img is channel-first (C, H, W); imshow expects channel-last - TODO confirm.
plt.imshow(img.permute(1, 2, 0))

# %%
# Display the current `gt` label as a matrix plot.
plt.matshow(gt)
# %%
# For every dataset item, store the label sum divided by the number of label
# pixels (assumes `gt` is a 2-D mask whose sum counts foreground - TODO confirm).
ratios_gt = torch.zeros(len(ds))
for i, (img, gt) in enumerate(ds):
    if not i % 100:
        print(i + 1)  # progress output every 100 items
    ratios_gt[i] = gt.sum() / (gt.shape[0] * gt.shape[1])
# %%
# Notebook inspection: show the per-item ground-truth ratios.
ratios_gt
# %%
import lovely_tensors as lt
# Apply lovely_tensors' monkey patch to torch tensors (see the lovely_tensors docs).
lt.monkey_patch()

# %%
# 3 x 4 grid: panel i shows the histogram of ground-truth ratios of the images
# whose best-pix_acc block is i.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

for i, ax in enumerate(axes.ravel()):
    win_block = i
    temp = ratios_gt[max_pix_acc_idx == win_block]
    # Bin with numpy first, then draw the precomputed counts as weights.
    counts, bins = np.histogram(temp)
    ax.hist(bins[:-1], bins=bins, weights=counts)
    ax.set(
        xlim=[0, 1],
        ylim=[0, 350],  # remove this entry for automatic y-limits
        title=f"{i}: mean={temp.mean():.2f}, std={temp.std():.2f}",
    )
    # Alternative drawing: plt.stairs(counts, bins)

# %%
# Histogram of the ground-truth ratio over the whole dataset.
# (A stale `win_block = i` copied from the per-block cell was removed: it was
# unused here and failed when that cell had not been run.)
temp = ratios_gt
# Bin with numpy first, then draw the precomputed counts as weights.
counts, bins = np.histogram(temp)
# plt.stairs(counts, bins)
plt.figure()
plt.hist(bins[:-1], bins, weights=counts)
plt.xlim([0,1])
plt.title(f"mean={temp.mean():.2f}, std={temp.std():.2f}")
plt.xlabel("segmentation size ratio")

# %%
# FOCUS/NORMALIZED ENTROPY
# Rebuild the per-block maps of image `mask_idx` here instead of reading the
# scratch name `temp`: the histogram cells rebind `temp` to 1-D ratio tensors,
# on which the two-axis sum below raises.
temp2 = torch.tensor(heatmap[mask_idx]).reshape(-1, 1, 14, 14)  # (blocks, 1, 14, 14)
temp2 = torch.nn.functional.interpolate(temp2, scale_factor=16, mode='bilinear', align_corners=False)
temp2 = temp2.float()  # (blocks, 1, 224, 224)
# Scale every map to sum to 1 over its two spatial axes.
temp2 = temp2 / temp2.sum(dim=(-1,-2), keepdim=True)
# focus = 1 - entropy / log2(number of pixels); the pixel count is read from
# the tensor shape instead of being hard-coded.
# NOTE: an entry equal to 0 yields 0 * log2(0) = NaN in the entropy sum.
n_pixels = temp2.shape[-2] * temp2.shape[-1]
focus = 1-(-temp2 * temp2.log2()).sum(dim=(-1,-2))/torch.tensor([float(n_pixels)]).log2()
# One bar per block; the bar count follows the data rather than a fixed 12.
plt.bar(range(focus.shape[0]), focus.squeeze())
# %%
##############################

import torch
import lovely_tensors as lt

lt.monkey_patch()

# Scratch cell: upsample n random 14x14 single-channel maps by a factor of 16.
n = 10
x = torch.rand((n, 1, 14, 14))
y = torch.nn.functional.interpolate(
    x, scale_factor=16, mode="bilinear", align_corners=False
)  # (n, 1, 224, 224)