import torch

def flatten_grid(x, grid_size=[2, 2]):
    """Lay the rows of a tile grid side by side as one horizontal strip.

    The input is cut along its height into bands of ``H // grid_size[0]``
    pixels, and the bands are joined left to right along the width axis.

    Args:
        x: 4-D tensor laid out as B x C x H x W.
        grid_size: two-element ``[rows, cols]`` sequence; only the row count
            is used here.

    Returns:
        Tensor of shape B x C x (H // rows) x (W * rows) when H is divisible
        by ``rows``.
    """
    # The 4-way / 2-way unpacking doubles as a shape check on the arguments.
    _, _, height, _ = x.size()
    n_rows, _ = grid_size

    band_h = height // n_rows
    row_bands = x.split(band_h, dim=2)
    return torch.cat(row_bands, dim=3)

def unflatten_grid(x, grid_size=[2,2]):
    """Fold a horizontal strip of tiles back into a ``rows x cols`` grid.

    The strip is cut along its width into ``grid_size[0]`` equal segments,
    one per grid row, and the segments are stacked top to bottom along the
    height axis. For a strip produced by ``flatten_grid`` with the same
    ``grid_size`` this restores the original layout.

    Args:
        x: 4-D tensor laid out as B x C x H x W, where W spans every tile
            placed side by side.
        grid_size: two-element ``[rows, cols]`` sequence describing the grid
            to rebuild.

    Returns:
        Tensor of shape B x C x (H * rows) x (W // rows) when W is divisible
        by ``rows``.
    """
    # The 4-way / 2-way unpacking doubles as a shape check on the arguments.
    _, _, _, strip_w = x.size()
    n_rows, _ = grid_size

    # One output row holds `cols` tiles, i.e. 1/rows of the whole strip.
    # Dividing by the column count instead is only right for square grids;
    # for any other shape it produces a cols x rows layout.
    row_w = strip_w // n_rows

    row_segments = torch.split(x, row_w, dim=3)
    return torch.cat(row_segments, dim=-2)
    
def prepare_key_grid_latents(latents_video, latent_grid_size=[2,2], key_grid_size=[3,3], rand_indices=None):
    """Build a key-frame grid from tiles picked out of a batch of latent grids.

    Each element of ``latents_video`` is flattened into a horizontal strip of
    tiles with ``flatten_grid``, the strips are joined end to end, the tiles
    named by ``rand_indices`` are gathered in the given order, and the result
    is folded into a grid with ``unflatten_grid``.

    Args:
        latents_video: tensor iterated along dim 0; every element is
            unsqueezed to B=1 and handed to ``flatten_grid``, so the input is
            expected to be T x C x H x W.
        latent_grid_size: ``[rows, cols]`` of the tile grid inside each
            element of ``latents_video``.
        key_grid_size: ``[rows, cols]`` of the grid to assemble.
        rand_indices: iterable of 0-based tile positions in the joined strip
            (tile ``k`` covers columns ``k * tile_w`` to ``(k + 1) * tile_w``).
            When ``None``, ``key_grid_size[0] * key_grid_size[1]`` distinct
            positions are drawn with ``torch.randperm`` and used in the order
            drawn. NOTE(review): distinct/unsorted sampling is an assumption
            about the intended default — confirm against callers.

    Returns:
        ``(keyframe_grid, rand_indices)``: the assembled grid tensor and the
        indices that were used (the caller's object when one was supplied,
        otherwise the freshly drawn list of ints).

    Raises:
        ValueError: if indices must be drawn but the strip holds fewer tiles
            than the key grid needs.
    """
    tile_w = latents_video.size(-1) // latent_grid_size[1]

    strips = [flatten_grid(frame.unsqueeze(0), latent_grid_size) for frame in latents_video]
    long_flatten = torch.cat(strips, dim=-1)

    if rand_indices is None:
        # Previously the None default fell straight into the loop below and
        # failed with "NoneType is not iterable".
        n_tiles = long_flatten.size(-1) // tile_w
        n_keys = key_grid_size[0] * key_grid_size[1]
        if n_keys > n_tiles:
            raise ValueError(
                f"key grid needs {n_keys} tiles but only {n_tiles} are available"
            )
        rand_indices = torch.randperm(n_tiles)[:n_keys].tolist()

    picked_tiles = [
        long_flatten[..., ind * tile_w:(ind + 1) * tile_w] for ind in rand_indices
    ]
    keyframe_grid = unflatten_grid(torch.cat(picked_tiles, dim=-1), key_grid_size)
    return keyframe_grid, rand_indices

    
def pil_grid_to_frames(pil_grid, grid_size=[2,2]):
    """Cut a grid image into its individual frames, row by row.

    Args:
        pil_grid: image object exposing ``size`` as ``(width, height)`` plus
            ``crop(box)`` and ``resize(size)`` (named for a PIL image).
        grid_size: ``[rows, cols]`` of the grid; index 0 divides the height
            and index 1 divides the width.

    Returns:
        List of frames in row-major order. Every frame is cropped from the
        grid and passed through ``resize``: tiles with either side under 512
        keep their size, otherwise the longer side is scaled to 512 and both
        sides are rounded down to a multiple of 8.
    """
    grid_w, grid_h = pil_grid.size
    tile_w = grid_w // grid_size[1]
    tile_h = grid_h // grid_size[0]

    # Original note: resize to the same w, h as the source videos.
    if min(tile_w, tile_h) >= 512:
        # Scale by the longer side so that it lands on 512.
        scale = max(tile_w, tile_h) / 512
        resized_img_w = int(tile_w / scale // 8) * 8
        resized_img_h = int(tile_h / scale // 8) * 8
    else:
        resized_img_w, resized_img_h = tile_w, tile_h
    print(f'Saved_resized_img {resized_img_w,resized_img_h}')

    target_size = (resized_img_w, resized_img_h)
    return [
        pil_grid.crop(
            (col * tile_w, row * tile_h, (col + 1) * tile_w, (row + 1) * tile_h)
        ).resize(target_size)
        for row in range(grid_size[0])
        for col in range(grid_size[1])
    ]
    

if __name__ == '__main__':
    # Scratch snippet that only runs when the file is executed directly:
    # draws a 1 x 3 float tensor of random integers with low=0, high=5.
    # Nothing in the visible part of the file reads `a` afterwards.
    a = torch.randint(low=0, high=5, size=(1, 3), dtype=torch.float)

    
    
