import torch, torchvision
import math
from PIL import Image, ImageSequence

def prepare_latents(args, latents_dir, scheduler=None):
    """Build the initial noisy latent queue from a saved clean video latent.

    The tensor stored at ``{latents_dir}/{args.num_inference_steps}.pt`` is
    loaded and, for every queue slot, one frame ``x_0`` is noised as
    ``sqrt(alpha) * x_0 + sqrt(1 - alpha) * noise`` where
    ``alpha = scheduler.alphas_cumprod[t]`` and ``t`` is read from
    ``scheduler.timesteps``.

    Slot layout (concatenated along dim 2, in this order):

    * only when ``args.lookahead_denoising`` is truthy:
      ``args.video_length // 2`` slots built from frame 0 at
      ``scheduler.timesteps[-1]``;
    * ``args.num_inference_steps`` slots, where slot ``i`` uses
      ``scheduler.timesteps[num_inference_steps - i - 1]`` and frame
      ``max(0, i - (num_inference_steps - video_length))``.

    Args:
        args: Object exposing ``lookahead_denoising``,
            ``num_inference_steps`` and ``video_length``.
        latents_dir: Directory containing ``<num_inference_steps>.pt``.
        scheduler: Object exposing ``timesteps`` and ``alphas_cumprod``.
            Required despite the ``None`` default kept for signature
            compatibility.

    Returns:
        A tensor holding all noised slots concatenated along dim 2.

    Raises:
        ValueError: If ``scheduler`` is ``None``.
    """
    if scheduler is None:
        # Fail before the (potentially large) latents file is loaded.
        raise ValueError(
            "prepare_latents requires a scheduler providing "
            "'timesteps' and 'alphas_cumprod'"
        )

    num_steps = args.num_inference_steps
    # Presumably laid out as (batch, channels, frames, ...); only "dim 2 is
    # the frame axis" is relied upon here.
    video = torch.load(f"{latents_dir}/{num_steps}.pt")

    # Plan every queue slot as (position in scheduler.timesteps, frame index).
    plan = []
    if args.lookahead_denoising:
        # Lookahead prefix: extra copies of the first frame at the timestep
        # stored last in scheduler.timesteps.
        plan.extend([(-1, 0)] * (args.video_length // 2))

    # Slots before this offset are clamped to frame 0 by the max() below.
    frame_offset = num_steps - args.video_length
    plan.extend(
        (num_steps - i - 1, max(0, i - frame_offset))
        for i in range(num_steps)
    )

    latents_list = []
    for step_pos, frame_idx in plan:
        t = scheduler.timesteps[step_pos]
        alpha = scheduler.alphas_cumprod[t]
        beta = 1 - alpha
        # List index keeps the frame axis so slots can be concatenated on it.
        x_0 = video[:, :, [frame_idx]]
        # Fresh noise is drawn per slot, in queue order.
        latents_list.append(
            alpha ** 0.5 * x_0 + beta ** 0.5 * torch.randn_like(x_0)
        )

    return torch.cat(latents_list, dim=2)

def shift_latents(latents, i, args, scheduler=None):
    """Advance the latent queue by one slot along dim 2, in place.

    Every slot takes the content of its successor, and the final slot is
    refilled with standard-normal noise scaled by
    ``scheduler.init_noise_sigma``.

    Args:
        latents: Tensor to update; it is modified in place.
        i: Accepted for interface compatibility; not used here.
        args: Accepted for interface compatibility; not used here.
        scheduler: Object exposing ``init_noise_sigma``. It is dereferenced
            unconditionally, so the ``None`` default is not usable.

    Returns:
        The same ``latents`` tensor object that was passed in.
    """
    # Snapshot slots 1..end first: source and destination slices overlap.
    successors = latents[:, :, 1:].clone()
    latents[:, :, :-1] = successors

    # Refill the now-stale last slot with freshly scaled noise.
    tail_noise = torch.randn_like(latents[:, :, -1])
    latents[:, :, -1] = tail_noise * scheduler.init_noise_sigma

    return latents

