# %%
from PIL import Image
import collections
from pathlib import Path
from matplotlib import ticker
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
from lovely_numpy import lo
import lovely_numpy
# Debug-printing setup: configure lovely_numpy's summary output.
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
import cv2
# Install lovely_tensors' patched tensor representations and use the same
# `deeper_width` setting as for lovely_numpy above.
lt.monkey_patch()
lt.set_config(deeper_width=12)
# NOTE(review): presumably a compatibility shim for torch builds that do not
# define `torch.inf` — confirm whether it is still needed.
torch.inf = float("Inf")
import sys
import os
# Treat the grandparent directory of this file as the project root: add it to
# sys.path so the project-local imports further down resolve, and make it the
# working directory so relative paths are anchored at the project root.
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from torchvision.datasets import ImageFolder
import torchvision
import tqdm

# %%

from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_new

# Disabled alternative: normalization with the ImageNet mean/std statistics.
# IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
# IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
# Active normalization: per-channel (x - 0.5) / 0.5.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# Resize + center crop to 224 and convert to a tensor, WITHOUT normalization.
# Kept separate so the un-normalized tensor can be displayed directly and
# normalized later, right before it is fed to the model.
transform_without_normalize = transforms.Compose([
    # transforms.Resize(256),
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
# Full preprocessing pipeline: the steps above followed by `normalize`.
transform = transforms.Compose([
    transform_without_normalize, 
    normalize
])

# %%


# initialize ViT pretrained
# The model is moved to the GPU right away, so a CUDA device is required.
model = vit_new(pretrained=True).cuda()
# Put the model into evaluation mode before generating any maps.
model.eval()


# %%


def show_cam_on_image(img, mask):
    """Blend a JET-colormapped rendering of ``mask`` with ``img``.

    Args:
        img: Image array that can be added to the (H, W, 3) heatmap; it is
            cast to float32. Callers in this file pass values in [0, 1].
        mask: Relevance map. It is multiplied by 255 and cast to uint8 before
            colormapping, so values are expected to lie in [0, 1].

    Returns:
        float32 array holding ``heatmap + img`` divided by its global maximum,
        so the largest value of the result is 1.

    Note:
        The heatmap comes from ``cv2.applyColorMap`` and therefore uses
        OpenCV's BGR channel order. ``img`` is added as-is; no channel
        conversion happens here.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    heat = np.float32(colored) / 255
    overlay = heat + np.float32(img)
    return overlay / np.max(overlay)


# %%

def gen_visualization(image, cls_idx, layer):
    """Build a heatmap overlay for ``image`` from the model's prediction map.

    Uses the module-level ``normalize``, ``model`` and ``show_cam_on_image``.

    Args:
        image: Un-normalized image tensor of shape (3, 224, 224) in RGB order
            (as produced by ``transform_without_normalize``).
        cls_idx: Class index forwarded to the model's prediction-map method.
        layer: Layer index forwarded to the model's prediction-map method.

    Returns:
        uint8 array of shape (224, 224, 3) in RGB channel order, ready for
        ``matplotlib``'s ``imshow``.
    """
    image_norm = normalize(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)
    predmap = model.predmap15_softmax_classes_batched_layer(image_norm, cls_idx, layer=layer)  # (B, tokens-1)
    xai_map = predmap.detach().cpu().reshape(1, 1, 14, 14)  # (1, 1, 14, 14)

    # Min-max normalize to [0, 1]. The range is clamped to the smallest normal
    # float so that a constant map becomes all zeros instead of NaN (0 / 0).
    map_min = xai_map.amin()
    map_range = (xai_map.amax() - map_min).clamp_min(torch.finfo(xai_map.dtype).tiny)
    xai_map = (xai_map - map_min) / map_range

    # Upsample the 14x14 patch grid to pixel resolution.
    xai_map = torch.nn.functional.interpolate(xai_map, scale_factor=16, mode='bilinear')  # (1, 1, 224, 224)
    xai_map = xai_map.squeeze().numpy()  # (224, 224)

    # cv2.applyColorMap produces a BGR heatmap. Hand the photo over in BGR as
    # well so both layers of the overlay share one channel order, then convert
    # the finished overlay to RGB once. (Previously an RGB photo was mixed with
    # the BGR heatmap and the sum was channel-swapped, which left the photo
    # with red and blue exchanged when displayed.)
    image_rgb = image.permute(1, 2, 0).numpy()
    image_bgr = np.ascontiguousarray(image_rgb[..., ::-1])
    vis = show_cam_on_image(image_bgr, xai_map)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)




# %%
# cls_wind = None
# samples_dir = base_dir / "samples"
# sample_files = list(samples_dir.glob("*.png"))
# sample_files = list(samples_dir.glob("*.jpg"))

# Root directory of the local ImageNet copy (user home is expanded).
imagenet_dir = Path("~/ramdisk/datasets/imagenet/").expanduser()
# imagenet_val_dir = Path("~/ramdisk/datasets/imagenet/val").expanduser()
# cls_wind = "n07718472" # cucumber
# cls_wind = "n02099601" # golden retriever

# samples_dir = imagenet_val_dir / cls_wind
# sample_files = list(samples_dir.iterdir())
# sample_files = [f for f in sample_files if 
#                 "38781" in f.stem
# or "6981" in f.stem
# or "11373" in f.stem
# or "12100" in f.stem
# or "26887" in f.stem
# or "37455" in f.stem
#                 ]

import torchvision.datasets
imagenet_ds = torchvision.datasets.ImageNet(
    imagenet_dir,
    split='val',
    transform=transform_without_normalize,
)

# Seed the RNG so the same permutation of the validation set is drawn each run.
torch.manual_seed(1)
rand_indices = torch.randperm(len(imagenet_ds))

# Take the first 8*10 shuffled samples as (image path, class index) pairs.
# Only the sample list is consulted here; no images are loaded yet.
data = [
    (Path(sample_path), sample_cls)
    for sample_path, sample_cls in (
        imagenet_ds.samples[idx] for idx in rand_indices[:8 * 10].tolist()
    )
]
# %%

# Interactive mode is turned off; swap the two lines below to turn it on.
# plt.ion()
plt.ioff()
# Destination for the exported pages; created (with parents) if missing.
output_dir: Path = base_dir / "figures/artifacts/supp/predicatt_explore"
output_dir.mkdir(exist_ok=True, parents=True)
# Disabled alternative font setup:
# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "Helvetica"
# })

# Global figure styling. NOTE: "text.usetex" makes matplotlib render text
# through LaTeX, so a working LaTeX installation is required.
plt.rcParams.update({
    "text.usetex": True,  # Enables latex equations
    "font.family": "cmu-serif", # Sets the correct font
    "mathtext.fontset": "cm",   # Same purpose as the line above (font choice)
    "font.size": 20,            # Set the font according to what you need
    "text.latex.preamble": r"\usepackage{amsmath}"   # You can add this to enable complicated math stuff
})
# Applied after the rcParams update above, so any key this style defines wins.
plt.style.use('tableau-colorblind10')  # You can use this to get a colorblind color palette


# Drop any figures left over from earlier cell runs. The second call closes
# the current figure, which is redundant after close("all") but harmless.
plt.close("all")
plt.close()
def create_page(data, colgroups=1):
    """Render one figure page of images next to their layer-10/11 overlays.

    Every sample occupies three adjacent axes: the input image and the
    ``gen_visualization`` overlays for ``layer=10`` and ``layer=11``.
    Uses the module-level ``transform_without_normalize``,
    ``gen_visualization`` and ``CLS2IDX``.

    Args:
        data: Sequence of ``(img_path, cls_idx)`` pairs.
        colgroups: Number of image/overlay triplets placed side by side in
            each row.

    Returns:
        The matplotlib figure. It is left open; the caller is responsible for
        saving and closing it.
    """
    import math

    def format_label(raw_label):
        # Labels may list several comma-separated names; keep only the first.
        return raw_label.split(",")[0]

    nrows = math.ceil(len(data) / colgroups)
    cols = 3 * colgroups
    # squeeze=False keeps `axes` two-dimensional even when the page has a
    # single row; otherwise the axes[r, c] indexing below raises IndexError.
    fig, axes = plt.subplots(nrows, cols, figsize=(cols, nrows), squeeze=False)

    for i, (img_path, cls_idx) in enumerate(data):
        r, group = divmod(i, colgroups)
        c = group * 3

        # Close the file handle deterministically, and let Pillow produce
        # three RGB channels for every input mode (grayscale, palette, RGBA,
        # CMYK, ...) instead of patching channel counts on the tensor.
        with Image.open(img_path) as pil_image:
            image = transform_without_normalize(pil_image.convert("RGB"))  # (3, 224, 224)

        vis10 = gen_visualization(image, cls_idx, 10)  # (224, 224, 3)
        vis11 = gen_visualization(image, cls_idx, 11)  # (224, 224, 3)
        axes[r, c + 0].imshow(image.permute(1, 2, 0).numpy())
        axes[r, c + 1].imshow(vis10)
        axes[r, c + 2].imshow(vis11)

        cls_label = format_label(str(CLS2IDX[cls_idx]))
        axes[r, c + 0].set_ylabel(cls_label, fontsize=10)
        if r == 0:
            # Column headers only on the first row. The subscripts are one
            # higher than the `layer` arguments above (presumably 1-based
            # layer numbering — TODO confirm).
            axes[r, c + 0].set_title("Image", fontsize=13)
            axes[r, c + 1].set_title(r"PredicAtt$_{11}$", fontsize=13)
            axes[r, c + 2].set_title(r"PredicAtt$_{12}$", fontsize=13)

    # Strip ticks and frames from every axis, including unused ones.
    for ax in axes.flatten():
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_frame_on(False)
        # ax.axis('off')

    # Ignore ylabels when doing tight_layout(): blank them, lay out, restore.
    ylabels = []
    for ax in axes.flat:
        ylabels.append(ax.get_ylabel())
        ax.set_ylabel('')
    fig.tight_layout(pad=0.1)
    for ax, yl in zip(axes.flat, ylabels):
        ax.set_ylabel(f'{yl}')
    return fig


# Preview cell: render the first 8 samples as one page. The returned figure is
# neither saved nor closed here.
create_page(data[:8])

# %%

def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of length ``n``.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``;
    an empty ``lst`` produces no slices.
    """
    offsets = range(0, len(lst), n)
    yield from (lst[offset:offset + n] for offset in offsets)

from svglib.svglib import svg2rlg
from reportlab.graphics import renderPDF

# Export every page of 8 samples as an SVG, then convert that SVG to a PDF
# with the same file stem.
plt.ioff()
# plt.ion()
for page_idx, page_items in enumerate(chunks(data, 8)):
    print(f"page {page_idx}")
    svg_path = output_dir / f"predicatt_explore_page{page_idx}.svg"

    page_fig = create_page(page_items)
    page_fig.savefig(svg_path, bbox_inches='tight', pad_inches=0.05)
    # Close the figure once it is on disk so figures do not pile up.
    plt.close(page_fig)

    # SVG -> PDF via svglib (parse) and reportlab (render).
    renderPDF.drawToFile(svg2rlg(svg_path), str(svg_path.with_suffix(".pdf")))

# %%

