from torchvision.transforms.functional import to_pil_image
import numpy as np
from PIL import Image

def inverse_normalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Undo a per-channel ``(x - mean) / std`` normalization.

    The input array is never modified; a new array is returned. Only the first
    ``len(mean)`` channels are transformed, any further channels are copied as-is.

    :param img: numpy array, channel last. shape (height, width, channel);
        extra leading dimensions are accepted as well.
    :param mean: per-channel mean that was subtracted by the normalization.
    :param std: per-channel std the normalization divided by.
    :return: new numpy array with the same shape as ``img``. Floating inputs
        keep their dtype, other dtypes are promoted to float64. Values are in
        [0~1] provided the image was in [0~1] before it was normalized.
    """
    img = np.asarray(img)
    # Integer inputs cannot hold the de-normalized values, so promote them.
    dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float64
    # Work on a copy: callers often pass views of tensors they still need.
    restored = img.astype(dtype, copy=True)
    mean = np.asarray(mean, dtype=dtype)
    std = np.asarray(std, dtype=dtype)
    n_channels = mean.shape[0]
    # Broadcasting over the last axis handles every channel in one expression.
    restored[..., :n_channels] = restored[..., :n_channels] * std + mean
    return restored
    
def debug(data_loader, name):
    """
    Dump the first sample of the first batch of ``data_loader`` as a PNG grid.

    Prints the batch keys, the video shape and the ids of the sample, then
    saves every keyframe of that sample into one image: each shot is a column
    (left to right), the keyframes of a shot are stacked top to bottom.
    The batch itself is left unmodified.

    :param data_loader: iterable yielding dict-like batches with at least the
        keys 'video' and 'sid'; 'vid', 'sparse_idx' and 'dense_idx' are
        printed when present. Each sample of batch['video'] is unpacked as
        (shot, keyframe, channel, height, width).
        NOTE(review): frames are assumed to be 3-channel and normalized with
        the default mean/std of ``inverse_normalize`` - confirm with the dataset.
    :param name: prefix of the output file "<name>_<sid>.png".
    :return: None. The image is written to the current working directory.
    """
    batch = next(iter(data_loader))
    print("batch keys : ", batch.keys())
    print("batch video shape : ", batch['video'].shape)
    print("sid in batch : ", len(batch['sid']))

    sample_index = 0
    print("sample index : {}".format(sample_index))
    if "vid" in batch.keys():
        print(batch['vid'][sample_index])
    print(batch['sid'][sample_index], batch['video'][sample_index].shape)
    if "sparse_idx" in batch.keys():
        print(batch['sparse_idx'][sample_index], batch['dense_idx'][sample_index])

    # .cpu() keeps this working when the batch already lives on an accelerator.
    keyframes = batch['video'][sample_index].detach().cpu().numpy()
    # (s, k, c, h, w) -> (s, k, h, w, c). The explicit copy detaches the array
    # from the tensor's memory so de-normalizing cannot alter the batch.
    keyframes = np.transpose(keyframes, (0, 1, 3, 4, 2)).copy()
    s, k, h, w, _ = keyframes.shape

    sheet = Image.new('RGB', (w * s, h * k))
    for i in range(s):
        for j in range(k):
            frame = inverse_normalize(keyframes[i, j])
            # Clip before the cast: values outside [0, 1] would wrap around in uint8.
            pixels = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            # Column i is shot i, row j is keyframe j of that shot.
            sheet.paste(Image.fromarray(pixels), (i * w, j * h))

    sheet.save("{}_{}.png".format(name, batch['sid'][sample_index]))