import torch
import numpy as np
import os
import  torchvision.transforms.functional as TF
from torch.nn import functional as F
from utils.process_mask import process_mask
import PIL
def create_background_frame(video: torch.Tensor, patch_size: int = 60) -> torch.Tensor:
    """
    Build a background video by tiling the top-left corner of the first frame.

    A ``patch_size`` x ``patch_size`` patch is cut from the top-left corner of
    ``video[0]``, tiled to cover the whole frame, cropped to the input's spatial
    size, and repeated for every frame. Pixel values are copied unchanged, so the
    output keeps the input's dtype, device and value range. Depth is ignored.

    Args:
        video: Video tensor of shape [T, C, H, W] (e.g. [49, 3, 480, 720]).
        patch_size: Side length of the corner patch that gets tiled. The default
            of 60 tiles a 480x720 frame exactly (8 x 12 patches).
    Returns:
        Tensor of shape [T, C, H, W] in which every frame is the tiled patch.
    Raises:
        ValueError: if ``video`` has no frames or ``patch_size`` is not positive.
    """
    if video.shape[0] == 0:
        raise ValueError("video must contain at least one frame")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")

    num_frames, _, height, width = video.shape
    # Slicing clamps at the frame border, so for frames smaller than patch_size
    # the patch is the whole (smaller) frame.
    patch = video[0, :, :patch_size, :patch_size]
    # Ceil-divide by patch_size so the tiling covers H/W even when they are not
    # multiples of the patch. When a dimension is smaller than patch_size this
    # yields 1 repetition, which matches the clamped patch above.
    reps_h = -(-height // patch_size)
    reps_w = -(-width // patch_size)
    tiled = patch.repeat(1, reps_h, reps_w)[:, :height, :width]  # [C, H, W]
    # Use the same background for every frame.
    return tiled.unsqueeze(0).repeat(num_frames, 1, 1, 1)

def fill_background(video_tensor: torch.Tensor, mask_path: str, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Replace masked-out pixels of a video with a background tiled from its first frame.

    The raw mask at ``mask_path`` is area-resized to 480x720 and binarized, so that
    only pixels still equal to 1 after area averaging stay 1. It is then passed
    through ``process_mask``. The processed mask is cached next to the input as
    ``<mask_path without extension>_processed.pt`` and reused on later calls.
    The cache is only checked for existence and is not invalidated if the raw
    mask changes.

    Wherever the processed mask is 0, the pixel comes from
    ``create_background_frame(video)``. Everywhere else the original pixel is kept.

    Args:
        video_tensor: Either a [T, C, H, W] tensor or a sequence of PIL images.
            PIL frames are converted with ``TF.pil_to_tensor`` (uint8, not rescaled).
            The spatial size must match the 480x720 mask.
        mask_path: Path to a ``torch.save``d mask tensor. Presumably [T, 1, h, w];
            'area' interpolation requires a 4-D tensor. TODO confirm the layout
            with the mask producer.
        device: Device the mask and the video are placed on.
        dtype: dtype the raw mask is cast to before resizing.
    Returns:
        Video tensor in which pixels with mask == 0 are replaced by background.
    """
    print('[processing mask...]')
    # splitext strips only the final extension. split(".")[0] truncated at the
    # first dot anywhere in the path (e.g. "./masks/a.pt" -> "_processed.pt").
    process_mask_path = f'{os.path.splitext(mask_path)[0]}_processed.pt'
    if os.path.exists(process_mask_path):
        # The cache may have been written from another device; map it onto ours.
        mask_resized = torch.load(process_mask_path, map_location=device)
    else:
        # Only load and resize the raw mask when there is no cached result.
        mask = torch.load(mask_path).to(device=device, dtype=dtype)
        mask_resized = F.interpolate(mask, size=(480, 720), mode='area')  # T x 1 x 480 x 720
        # Binarize: any pixel not fully covered after area averaging becomes 0.
        mask_resized = torch.where(mask_resized < 1, 0., 1.)
        mask_resized = process_mask(mask_resized)
        torch.save(mask_resized, process_mask_path)

    print('[done processing mask...]')
    print("[processing video...]")
    # isinstance (not type ==) so PIL subclasses such as PngImageFile are accepted.
    if isinstance(video_tensor[0], PIL.Image.Image):
        frames = [TF.pil_to_tensor(frame) for frame in video_tensor]
        video_tensor = torch.stack(frames, dim=0)  # T x C x H x W
    # The mask lives on `device`, so the video must be there as well for torch.where.
    video_tensor = video_tensor.to(device=device)

    filled_video_tensor = create_background_frame(video_tensor)
    # A single-channel mask ([T, 1, H, W], assumed) broadcasts over the video's channels.
    return torch.where(mask_resized == 0.0, filled_video_tensor, video_tensor)
