#%%
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Configure lovely-numpy's summary output (deeper_width=12).
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Apply the lovely-tensors monkey patch and use the same deeper_width as above.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# Force torch.inf to exist as a plain Python float.
# NOTE(review): presumably a shim for a torch build/dependency lacking torch.inf — confirm it is still needed.
torch.inf = float("Inf")
import sys
import os
# base_dir is the parent of the directory that contains this file.
base_dir = Path(__file__).parent.parent
# Make packages under base_dir importable (the `samples` / `baselines` imports below rely on it).
sys.path.append(str(base_dir))
# Relative paths used later (e.g. 'samples/el2.png') are resolved against base_dir.
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm
# %%
from samples.CLS2IDX import CLS2IDX
# from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new
import types

# Preprocessing pipelines.
#
# `transform_without_normalize` yields a displayable tensor: resize to 256,
# centre-crop to 224x224, convert to a tensor.
# `transform` additionally applies `normalize` (mean 0.5 / std 0.5 per channel).
# Alternative statistics that were tried (ImageNet defaults):
#   mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)
# NOTE(review): in the cells visible in this file only
# `transform_without_normalize` is actually applied to the image.
transform_without_normalize = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
transform = transforms.Compose([transform_without_normalize, normalize])

# %%


# initialize ViT pretrained
# (vit_new is baselines.ViT.ViT_new.vit_base_patch16_224; the LRP variant is kept for reference)
# model = vit_LRP(pretrained=True).cuda()
# Requires a CUDA device: the model is moved to the GPU right after construction.
model = vit_new(pretrained=True).cuda()
# Switch to evaluation mode before running inference.
model.eval()


# %%
# Load the sample image; the relative path is resolved against base_dir (see os.chdir above).
image = Image.open('samples/el2.png')
# Resize / centre-crop / ToTensor only (no normalization), so the tensor can be displayed directly.
# Result is (C, 224, 224) given the CenterCrop(224) in the pipeline.
tusker_zebra_image = transform_without_normalize(image)

fig, axs = plt.subplots()
# imshow wants channels last: (C, H, W) -> (H, W, C).
axs.imshow(tusker_zebra_image.permute(1, 2, 0))

def show_patch_indices(ax):
    """Write the running index of every patch onto ``ax``.

    The image is treated as a 14x14 grid of 16-pixel patches. Patches are
    numbered 0..195 in row-major order and each label is placed at pixel
    offset (7, 7) inside its patch.

    Args:
        ax: Object exposing a matplotlib-style ``text(x, y, s, **kwargs)``.
    """
    for i in range(14 * 14):
        r, c = divmod(i, 14)  # row-major: i == 14 * r + c
        ax.text(16*c+7, 16*r+7, f'{i}', color='black', fontsize=6, ha='center', va='center')

def show_grid(ax):
    """Draw black grid lines every 16 pixels on ``ax`` and hide the ticks.

    Major ticks are placed at 15.5, 31.5, ... (below 224) on both axes, the
    grid is drawn on those ticks, and tick labels and tick marks are removed
    so that only the lines remain visible.

    Args:
        ax: matplotlib-style axes showing a 224x224 image.
    """
    line_positions = np.arange(15.5, 224, 16)

    ax.grid(which='major', axis='both', linestyle='-', color='k', linewidth=1)
    for place_ticks, place_labels in (
        (ax.set_xticks, ax.set_xticklabels),
        (ax.set_yticks, ax.set_yticklabels),
    ):
        place_ticks(line_positions)
        place_labels([])
    ax.tick_params(length=0)
    
# Overlay the 16-pixel patch grid on the image shown above.
show_grid(axs)
# Optional: also label every patch with its index.
# show_patch_indices(axs)

# %%
# Draw image with patches using make_grid()
# Split both spatial axes into (14 patches, 16 pixels), move the two patch
# axes to the front and merge them into one row-major patch axis.
patches = (
    tusker_zebra_image
    .unflatten(-2, (14, 16))
    .unflatten(-1, (14, 16))  # (3, 14, 16, 14, 16)
    .permute(1, 3, 0, 2, 4)   # (14, 14, 3, 16, 16)
    .flatten(0, 1)            # (14*14, 3, 16, 16)
)

# 14 patches per row, separated by a 1-pixel black (pad_value=0) gutter.
grid_chw = torchvision.utils.make_grid(patches, nrow=14, padding=1, pad_value=0)  # (3, H, W)
image_with_patches = grid_chw.permute(1, 2, 0).cpu().numpy()  # (H, W, C)

fig, ax = plt.subplots()
ax.imshow(image_with_patches)
ax.axis('off')

# All artifacts of this script are written below this directory.
predmap_overview_dir = base_dir / Path(f'figures/artifacts/predmap_overview/')
predmap_overview_dir.mkdir(exist_ok=True, parents=True)
path = predmap_overview_dir / f'overview_input.png'
plt.imsave(path, image_with_patches)



# %%
# Add a batch dimension and move the image to the GPU.
# NOTE(review): `inp` is built from the un-normalized tensor (transform_without_normalize);
# `normalize` / `transform` are never applied in this file — confirm the model normalizes internally.
inp = tusker_zebra_image.unsqueeze(0).cuda()
# The trailing commas make these 1-element tuples, i.e. (101,) and (340,), not plain ints.
# Per the class list further down: 101 = tusker, 340 = zebra.
cls_idx_tusker = 101,
cls_idx_zebra = 340,
# Project-specific model API; only the `extras` dict is used afterwards.
# NOTE(review): semantics of `layer` and `apply_softmax_classes` are defined in ViT_new — see there.
_, extras = model.predmap15_batched_layer(inp, 
                                          cls_idx_tusker,
                                        #   cls_idx_zebra,
                                        # 150,
                                            return_extras=True,
                                            layer=9,
                                            apply_softmax_classes=True,
                                           )
predmap = extras['predmap'] # (B, tokens-1, classes)
# plt.matshow((predmap[0][:,[101,340]]).cpu().detach().numpy(),
        #    cmap='jet')
# predmap[0]
# plt.matshow(_.reshape(14,14).cpu().detach().numpy(),)
# %%
# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Blend a JET-coloured heat map of ``mask`` with ``img``.

    ``mask`` is scaled by 255 and cast to uint8 before colour mapping, so it
    is presumably expected to lie in [0, 1] (TODO confirm with callers). The
    coloured map (rescaled to [0, 1]) is added to ``img`` and the sum is
    divided by its maximum, so the largest value of the result is 1.

    Args:
        img: Array-like image broadcastable against the (H, W, 3) heat map.
        mask: Array-like (H, W) relevance map.

    Returns:
        Floating-point array holding the blended image.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    blended = colored + np.float32(img)
    return blended / np.max(blended)

def vis_predmap_cls(predmap, cls_idx, orig_image, show_patch_indices=False):
    """Overlay the per-patch prediction map of one class on the input image.

    The class slice of ``predmap`` is min-max scaled to [0, 1], upsampled by a
    factor of 16 with nearest-neighbour interpolation (one constant block per
    patch), blended with ``orig_image`` and drawn together with the patch grid.

    Args:
        predmap: Tensor indexed as ``predmap[..., :, cls_idx]``; the selected
            slice must contain exactly 14*14 values (one per patch).
        cls_idx: Class index used for the last axis of ``predmap`` (an int or
            a one-element tuple both work).
        orig_image: CPU tensor in (C, H, W) layout matching the 224x224
            upsampled map.
        show_patch_indices: If true, also write the row-major index of every
            patch onto the figure.

    Returns:
        Tuple ``(fig, ax)`` with the created matplotlib figure and axes.
    """
    fig, ax = plt.subplots()

    heat = predmap[..., :, cls_idx].reshape(14, 14).detach().cpu()  # (14, 14)

    # Min-max scale to [0, 1]. A constant map would divide 0 by 0 and fill the
    # overlay with NaNs, so it is rendered as an all-zero map instead.
    vmin, vmax = heat.amin(), heat.amax()
    span = vmax - vmin
    if span > 0:
        heat = (heat - vmin) / span
    else:
        heat = torch.zeros_like(heat)

    heat = torch.nn.functional.interpolate(
        heat[None, None, ...],
        scale_factor=16, mode='nearest',
    )  # (1, 1, 224, 224)
    heat = heat.squeeze().numpy()  # (224, 224)

    vis = show_cam_on_image(orig_image.permute(1, 2, 0).numpy(), heat)
    vis = np.uint8(255 * vis)
    # NOTE(review): OpenCV colour maps come back in BGR order while the photo is
    # RGB, so after this swap the heat map is RGB but the photo's R/B channels
    # are exchanged — confirm this rendering is intended.
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    ax.imshow(vis)
    show_grid(ax)

    if show_patch_indices:
        # The boolean parameter shadows the module-level function of the same
        # name, so calling `show_patch_indices(ax)` here would try to call a
        # bool and raise TypeError. Draw the labels directly instead.
        for i in range(14 * 14):
            r, c = divmod(i, 14)
            ax.text(16 * c + 7, 16 * r + 7, f'{i}', color='black',
                    fontsize=6, ha='center', va='center')

    fig.tight_layout()
    return fig, ax

# Render and save one overlay per class of interest (tusker first, then zebra).
for target_cls, out_name in (
    (cls_idx_tusker, f'overview_predmap1.png'),
    (cls_idx_zebra, f'overview_predmap2.png'),
):
    fig, ax = vis_predmap_cls(predmap, target_cls, tusker_zebra_image)
    path = predmap_overview_dir / out_name
    fig.savefig(path, transparent=True, bbox_inches='tight', pad_inches=0.0)


# %%
# Class indices shown as columns of the prediction matrix / classification vector below
# (names as annotated by the author; they are looked up via CLS2IDX when printed).
cls_indices = [
    517, # crane
    101, # tusker
    574, # golf ball
    352, # impala
    353, # gazelle
    340, # zebra
    399, # abaya
    579, # grand piano
]
# Row-major patch indices (0..195 on the 14x14 grid) shown as rows of the prediction matrix.
patch_indices = [
    15, 16, 17, 30, 31, 45,
    75, 
    104, 105, 106, 107, 108, 121, 122
    ]

# Image patches ######################################################
# Cut the image into a row-major sequence of 16x16 patches.
patches = (
    tusker_zebra_image
    .unflatten(-2, (14, 16))
    .unflatten(-1, (14, 16))  # (3, 14, 16, 14, 16)
    .permute(1, 3, 0, 2, 4)   # (14, 14, 3, 16, 16)
    .flatten(0, 1)            # (14*14, 3, 16, 16)
)

fig, ax = plt.subplots()

# Stack the selected patches in one column, separated by a 4-pixel white gutter.
column_chw = torchvision.utils.make_grid(patches[patch_indices], nrow=1, padding=4, pad_value=1)
img_patches = column_chw.permute(1, 2, 0)  # (h, w, 3)

# Append a fully opaque alpha channel -> (h, w, 4).
opaque = torch.ones_like(img_patches[:, :, [0]])
img_patches = torch.cat([img_patches, opaque], dim=-1).cpu()

# make border transparent: every pixel equal to opaque white is zeroed, alpha
# included (this also hits pure-white pixels inside a patch, if there are any).
border = (img_patches == torch.Tensor([1, 1, 1, 1])).all(dim=-1)  # (h,w)
img_patches[border] = 0

img_patches = img_patches.numpy()
ax.imshow(img_patches)
ax.axis('off')

path = predmap_overview_dir / f'overview_patches.png'
plt.imsave(path, img_patches)

# %%

# fig, ax = plt.subplots()
# # predmap (B, tokens-1, classes)
# predmap_mat = (predmap[0, :,cls_indices][patch_indices]).cpu().detach() # (len(patch_indices), len(cls_indices))

# cax = ax.imshow(predmap_mat.numpy(),
#            cmap='jet')
# ax.grid(which='major', axis='both', linestyle='-', color='k', linewidth=1)
# ax.grid(which='major', axis='y', linestyle='-', color='w', linewidth=5)
# ax.set_xticks(np.arange(0.5, len(cls_indices), 1))
# ax.set_yticks(np.arange(0.5, len(patch_indices), 1))
# ax.set_xticklabels([])
# ax.set_yticklabels([])
# # ax.yaxis.set_major_locator(ticker.NullLocator())
# ax.tick_params(length=0)
# # fig.colorbar(cax, ax=ax)
# # , width=2, colors='r',
#             #    grid_color='r', grid_alpha=0.5)

# fig.tight_layout()
# fig.savefig('figures/patch_heatmap.png', transparent=True)

# %%
######################################################
# Normalize each column (class) of predmap and visualize
# predmap (B, tokens-1, classes)
predmap_norm = predmap[0]  # (tokens-1, classes)
# Per-class min-max scaling over all patches (reduction along dim 0).
col_min = predmap_norm.amin(dim=0, keepdim=True)
col_max = predmap_norm.amax(dim=0, keepdim=True)
predmap_norm = (predmap_norm - col_min) / (col_max - col_min)

# Keep the selected classes (columns) and patches (rows), then blow every
# scalar up to a 16x16 tile.
predmap_mat = predmap_norm[:, cls_indices][patch_indices].cpu().detach()  # (len(patch_indices), len(cls_indices))
predmap_mat = predmap_mat[..., None, None].expand(-1, -1, 16, 16)  # (len(patch_indices), len(cls_indices), 16, 16)

# Map the values to colours of the 'jet' colormap, see
# https://stackoverflow.com/questions/56977768/is-there-a-way-to-convert-scalar-values-in-an-array-to-matplotlib-colormap-indic
to_unit = plt.Normalize(predmap_mat.min(), predmap_mat.max())
predmap_mat = torch.from_numpy(plt.cm.jet(to_unit(predmap_mat)))  # (len(patch_indices), len(cls_indices), 16, 16, C)



###########################################################################################################################
###########################################################################################################################
# Add the patch pixels behind the prediction matrix.
# patches # (14*14, 3, 16, 16)
# patches2 = patches[patch_indices] # (len(patch_indices), 3, 16, 16)
# patches2 = patches2.unsqueeze(1) # (len(patch_indices), 1, 3, 16, 16)
# patches2 = patches2.permute(0,1,3,4,2) # (len(patch_indices), 1, 16, 16, 3)
# temp = torch.ones_like(patches2[...,[0]]) # (len(patch_indices), 1, 16, 16, 1)
# patches2 = torch.cat([patches2, temp], dim=-1) # (len(patch_indices), 1, 16, 16, 4)
# w = 0.5
# predmap_mat = w*predmap_mat + (1-w)*patches2
###########################################################################################################################
###########################################################################################################################
col_padding = 2
row_padding = 2

# One horizontal strip of coloured tiles per selected patch.
row_lst = []
for predmap_mat_row in predmap_mat:  # (len(cls_indices), 16, 16, C)
    predmap_mat_row = predmap_mat_row.permute(0, 3, 1, 2)  # (len(cls_indices), C, 16, 16)
    row = torchvision.utils.make_grid(
        predmap_mat_row, nrow=len(cls_indices), padding=col_padding, pad_value=1
    )
    # Drop the outer frame added by make_grid (top/bottom and left/right);
    # only the gutters between neighbouring tiles remain.
    row = row[:, col_padding:-col_padding, col_padding:-col_padding]
    row_lst.append(row)

# Stack the strips vertically, again separated by white padding.
rows = torch.stack(row_lst, dim=0)  # (len(patch_indices), C, 16, strip_width)
img_predmat = torchvision.utils.make_grid(rows, nrow=1, padding=row_padding, pad_value=1)
img_predmat = img_predmat.permute(1, 2, 0).contiguous().cpu()  # (h,w,c)

# make border transparent: zero every pixel that equals opaque white
border = (img_predmat == torch.Tensor([1, 1, 1, 1])).all(dim=-1)  # (h,w)
img_predmat[border] = 0

img_predmat = img_predmat.numpy()
plt.matshow(img_predmat)
path = predmap_overview_dir / f'overview_pred_mat.png'
plt.imsave(path, img_predmat)

# %%
######################################################
# Visualize classification vector
xpredmap = extras["xpredmap"]  # (B, tokens, classes)

# Token 0 of every batch element, softmax-normalised with temperature 10.
# NOTE(review): token 0 is presumably the class token — confirm in the model code.
pred_cls_vec = (xpredmap[:, 0] / 10).softmax(dim=-1)  # (B, classes)

# First batch element, selected classes only; each score becomes a 16x16 tile.
pred_cls_vec = pred_cls_vec.detach().cpu()[0, cls_indices]  # (len(cls_indices),)
pred_cls_vec = pred_cls_vec[..., None, None].expand(-1, 16, 16)  # (len(cls_indices), 16, 16)

# Colour the tiles with the 'jet' colormap.
to_unit = plt.Normalize(pred_cls_vec.min(), pred_cls_vec.max())
pred_cls_vec = torch.from_numpy(plt.cm.jet(to_unit(pred_cls_vec)))  # (len(cls_indices), 16, 16, C)

# Lay the tiles out in one row with white padding in between and around.
img_pred_cls_vec = torchvision.utils.make_grid(
    pred_cls_vec.permute(0, 3, 1, 2),  # (len(cls_indices), C, 16, 16)
    nrow=len(cls_indices), padding=col_padding, pad_value=1,
)
img_pred_cls_vec = img_pred_cls_vec.permute(1, 2, 0).contiguous().cpu()  # (h,w,c)

# make border transparent: zero every pixel that equals opaque white
border = (img_pred_cls_vec == torch.Tensor([1, 1, 1, 1])).all(dim=-1)  # (h,w)
img_pred_cls_vec[border] = 0

img_pred_cls_vec = img_pred_cls_vec.numpy()
plt.matshow(img_pred_cls_vec)
path = predmap_overview_dir / f'overview_pred_cls_vec.png'
plt.imsave(path, img_pred_cls_vec)


# %%
# Print classes ######################################################
# Look up the name of every selected class and list it next to its raw
# (not softmax-normalised) score taken from token 0 of the first batch element.
cls_names = [CLS2IDX[idx] for idx in cls_indices]
xpredmap = extras["xpredmap"]  # (B, tokens, classes)
pred_cls_vec = xpredmap[0, 0]  # (classes)
pred_scores = [
    f'{pred_cls_vec[idx]: >6.2f}:\t {name} '
    for name, idx in zip(cls_names, cls_indices)
]
res = "\n".join(pred_scores)
print(res)

# Keep a copy of the printed table next to the figures.
path = predmap_overview_dir / f'overview_pred_cls_scores.txt'
path.write_text(res)

# %%


# %%
