from copy import deepcopy
import numpy as np
import torch
import torchvision

from typing import List, Optional, Tuple, Union, Dict, Any
from dataclasses import dataclass

from diffusers import logging
from diffusers import DiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput

import sys

sys.path.append("/home/ubuntu/Open-Sora")
sys.path.append("/home/ubuntu/t2v-turbo")

from ode_solver.ddim_solver import DDIMSolver
from opensora.models.stdit import STDiT2
from opensora.models.vae import VideoAutoencoderKL
from opensora.schedulers.iddpm import IDDPM
from opensora.models.text_encoder import T5Encoder
from opensora.utils.inference_utils import prepare_multi_resolution_info

from utils.common_utils import (
    get_predicted_noise,
    get_predicted_original_sample,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMOpenSoraPipelineOutput(BaseOutput):
    """
    Output container holding a sample and, optionally, its denoised estimate.

    NOTE(review): the visible code never constructs this class —
    `DDIMOpenSoraPipeline.__call__` returns a bare tensor. Confirm whether it is
    still needed by other callers.

    Args:
        prev_sample (`torch.FloatTensor`):
            Presumably the computed sample `(x_{t-1})` of the previous timestep, to be used as the next
            model input in a denoising loop — TODO confirm against the code that fills it in.
        denoised (`torch.FloatTensor`, *optional*, defaults to `None`):
            Presumably the predicted denoised sample `(x_{0})` for the current timestep — TODO confirm.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None


class DDIMOpenSoraPipeline(DiffusionPipeline):
    """Text-to-video pipeline that samples an Open-Sora DiT with a DDIM solver.

    The pipeline owns four components (registered through
    ``register_modules``): the denoising transformer ``dit``, the noise
    ``scheduler`` (only its ``alphas_cumprod`` is read here), the video ``vae``
    and the ``text_encoder``. All state used during sampling is taken from
    these registered components, so the class can be imported and used from
    other modules without relying on module-level globals.
    """

    def __init__(
        self,
        dit: STDiT2,
        scheduler: IDDPM,
        vae: VideoAutoencoderKL,
        text_encoder: T5Encoder,
    ):
        """Register the components and remember the device the DiT lives on.

        Args:
            dit: Denoising transformer called as ``dit(latents, ts, **kwargs)``.
            scheduler: Object exposing ``alphas_cumprod``.
            vae: Video autoencoder providing ``get_latent_size``,
                ``out_channels`` and ``decode``.
            text_encoder: Encoder providing ``encode`` and ``null``.
        """
        super().__init__()

        self.register_modules(
            dit=dit,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
        )

        # NOTE: captured once at construction time; it is not refreshed if the
        # DiT is moved to another device afterwards.
        self._device = next(dit.parameters()).device

    def prepare_latents(
        self,
        batch_size,
        frames,
        height,
        width,
        dtype,
        generator,
        latents=None,
    ):
        """Return the initial latents for sampling.

        Args:
            batch_size: Number of videos to sample.
            frames: Number of frames in pixel space.
            height: Frame height in pixel space.
            width: Frame width in pixel space.
            dtype: dtype used when fresh noise is drawn.
            generator: Optional ``torch.Generator`` (or list) for the noise.
            latents: Optional pre-made latents. They are only moved to the
                pipeline device; their shape and dtype are not checked.

        Returns:
            Latents of shape ``(batch_size, vae.out_channels, *latent_size)``
            when sampled here, otherwise the given ``latents`` on the device.
        """
        device = self._device
        input_size = (frames, height, width)
        latent_size = self.vae.get_latent_size(input_size)
        shape = (
            batch_size,
            self.vae.out_channels,
            *latent_size,
        )
        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)
        # No rescaling of the initial noise: the required standard deviation is 1.0.
        return latents

    def _predict_x0_and_noise(
        self,
        latents,
        ts,
        model_kwargs,
        num_channels,
        alpha_schedule,
        sigma_schedule,
    ):
        """Run the DiT once and derive ``(pred_x0, pred_noise)`` from its output."""
        model_output = self.dit(latents, ts, **model_kwargs)
        # The output is split into chunks of ``num_channels`` along dim 1 and
        # must yield exactly two chunks; only the first one is used.
        teacher_output, _ = torch.split(model_output, num_channels, dim=1)
        pred_x0 = get_predicted_original_sample(
            teacher_output,
            ts,
            latents,
            "epsilon",
            alpha_schedule,
            sigma_schedule,
        )
        pred_noise = get_predicted_noise(
            teacher_output,
            ts,
            latents,
            "epsilon",
            alpha_schedule,
            sigma_schedule,
        )
        return pred_x0, pred_noise

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = 280,
        width: Optional[int] = 488,
        frames: int = 16,
        fps: int = 8,
        guidance_scale: float = 7,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 4,
        output_type: Optional[str] = "pil",
    ):
        """Generate videos for ``prompt`` with guided DDIM sampling.

        Args:
            prompt: A single prompt or a list of prompts.
            height: Frame height in pixels.
            width: Frame width in pixels.
            frames: Number of frames to generate.
            fps: Frame rate passed to the multi-resolution conditioning.
            guidance_scale: Weight ``w`` in
                ``cond + w * (cond - uncond)``, applied to both the predicted
                clean sample and the predicted noise.
            num_videos_per_prompt: Videos sampled per prompt.
            generator: Optional RNG(s) for the initial noise.
            latents: Optional initial latents (see ``prepare_latents``).
            num_inference_steps: Number of DDIM steps.
            output_type: ``"latent"`` returns the final latents; any other
                value returns ``vae.decode(...)`` unchanged (no further
                conversion is done here).

        Returns:
            The decoded videos, or the latents when ``output_type == "latent"``.

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        elif not isinstance(prompt, list):
            raise ValueError("prompt must be a string or a list of strings")
        batch_size = len(prompt)

        # Everything below is read from the pipeline itself rather than from
        # module-level globals, so the class also works when imported.
        device = self._device
        dtype = self.dtype
        alphas_cumprod = self.scheduler.alphas_cumprod

        ddim_solver = DDIMSolver(
            alphas_cumprod.cpu().numpy(),
            ddim_timesteps=num_inference_steps,
            use_scale=False,
        )
        ddim_solver.to(device, dtype)
        alpha_schedule = torch.sqrt(alphas_cumprod).to(device)
        sigma_schedule = torch.sqrt(1 - alphas_cumprod).to(device)

        bs = batch_size * num_videos_per_prompt

        # Prepare latent variable
        latents = self.prepare_latents(
            bs,
            frames,
            height,
            width,
            dtype,
            generator,
            latents,
        )

        additional_kwargs = prepare_multi_resolution_info(
            "STDiT2",
            bs,
            (height, width),
            frames,
            fps,
            device,
            dtype,
        )

        # Encode the prompts, repeating each one once per requested video.
        prompt_list = [p for p in prompt for _ in range(num_videos_per_prompt)]
        model_kwargs = self.text_encoder.encode(prompt_list)
        model_kwargs["y"] = model_kwargs["y"].to(device, dtype)
        model_kwargs.update(additional_kwargs)

        # Unconditional branch: null text embedding and no mask.
        uncond_model_kwargs = {
            "y": self.text_encoder.null(bs).to(device, dtype),
            "mask": None,
        }
        uncond_model_kwargs.update(additional_kwargs)

        # Multi-step DDIM sampling loop, walking the solver timesteps from the
        # last index down to 0.
        num_channels = latents.shape[1]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(num_inference_steps):
                index = len(ddim_solver.ddim_timesteps) - i - 1
                ts = torch.full(
                    (bs,),
                    ddim_solver.ddim_timesteps[index],
                    device=device,
                    dtype=torch.long,
                )
                index = torch.full((bs,), index, device=device, dtype=torch.long)

                cond_pred_x0, cond_pred_noise = self._predict_x0_and_noise(
                    latents,
                    ts,
                    model_kwargs,
                    num_channels,
                    alpha_schedule,
                    sigma_schedule,
                )
                uncond_pred_x0, uncond_pred_noise = self._predict_x0_and_noise(
                    latents,
                    ts,
                    uncond_model_kwargs,
                    num_channels,
                    alpha_schedule,
                    sigma_schedule,
                )

                # Guidance: push the conditional estimate away from the
                # unconditional one by ``guidance_scale``.
                pred_x0 = cond_pred_x0 + guidance_scale * (
                    cond_pred_x0 - uncond_pred_x0
                )
                pred_noise = cond_pred_noise + guidance_scale * (
                    cond_pred_noise - uncond_pred_noise
                )

                latents = ddim_solver.ddim_step(pred_x0, pred_noise, index).to(dtype)
                progress_bar.update()

        if output_type != "latent":
            videos = self.vae.decode(latents.to(dtype), num_frames=frames)
        else:
            videos = latents

        return videos


if __name__ == "__main__":
    # Demo script: build the Open-Sora components from the config, wrap them in
    # the pipeline and write one mp4 per validation prompt.
    from mmengine.config import Config
    from opensora.registry import MODELS, SCHEDULERS, build_module

    # NOTE: `device`, `weight_dtype`, `scheduler` and `text_encoder` are kept as
    # module-level names on purpose; do not rename them.
    device = torch.device("cuda:0")
    weight_dtype = torch.float16
    cfg = Config.fromfile("/home/ubuntu/t2v-turbo/configs/opensora-v1-1.py")

    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # Video autoencoder, in eval mode and half precision on the GPU.
    vae = build_module(cfg.vae, MODELS).eval()
    vae.to(device, weight_dtype)

    # (frames, height, width) in pixel space -> latent grid size for the DiT.
    input_size = (16, 280, 488)
    latent_size = vae.get_latent_size(input_size)

    # Text encoder; only the wrapped T5 model is moved/cast here.
    text_encoder = build_module(cfg.text_encoder, MODELS)
    text_encoder.t5.model.to(device, torch.float16)

    # Denoising transformer sized to match the VAE and the text encoder.
    dit = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        enable_sequence_parallelism=False,
    )
    dit = dit.eval().to(device, weight_dtype)
    # The text encoder uses the DiT's caption embedder.
    text_encoder.y_embedder = dit.y_embedder

    pipeline = DDIMOpenSoraPipeline(
        dit=dit,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
    )
    pipeline = pipeline.to(device, weight_dtype)
    generator = torch.Generator(device=device).manual_seed(42)

    validation_prompts = [
        # Long, descriptive prompts.
        "A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.",
        "A majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.",
        "A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene.  Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene.  In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings.  The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains.",
        "The vibrant beauty of a sunflower field. The sunflowers, with their bright yellow petals and dark brown centers, are in full bloom, creating a stunning contrast against the green leaves and stems. The sunflowers are arranged in neat rows, creating a sense of order and symmetry. The sun is shining brightly, casting a warm glow on the flowers and highlighting their intricate details. The video is shot from a low angle, looking up at the sunflowers, which adds a sense of grandeur and awe to the scene. The sunflowers are the main focus of the video, with no other objects or people present. The video is a celebration of nature's beauty and the simple joy of a sunny day in the countryside.",
        "A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene. The video is shot from a slightly elevated angle, providing a comprehensive view of the turtle's surroundings. The overall style of the video is calm and peaceful, capturing the beauty and tranquility of the underwater world.",
        "A vibrant underwater scene. A group of blue fish, with yellow fins, are swimming around a coral reef. The coral reef is a mix of brown and green, providing a natural habitat for the fish. The water is a deep blue, indicating a depth of around 30 feet. The fish are swimming in a circular pattern around the coral reef, indicating a sense of motion and activity. The overall scene is a beautiful representation of marine life.",
        "A bustling city street at night, filled with the glow of car headlights and the ambient light of streetlights. The scene is a blur of motion, with cars speeding by and pedestrians navigating the crosswalks. The cityscape is a mix of towering buildings and illuminated signs, creating a vibrant and dynamic atmosphere. The perspective of the video is from a high angle, providing a bird's eye view of the street and its surroundings. The overall style of the video is dynamic and energetic, capturing the essence of urban life at night.",
        "A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road.",
        "The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements.",
        "A serene night scene in a forested area. The first frame shows a tranquil lake reflecting the star-filled sky above. The second frame reveals a beautiful sunset, casting a warm glow over the landscape. The third frame showcases the night sky, filled with stars and a vibrant Milky Way galaxy. The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. The style of the video is naturalistic, emphasizing the beauty of the night sky and the peacefulness of the forest.",

        # Short prompts.
        "An astronaut riding a horse.",
        "Darth vader surfing in waves.",
        "Robot dancing in times square.",
        "Clown fish swimming through the coral reef.",
        "A child excitedly swings on a rusty swing set, laughter filling the air.",
        "With the style of van gogh, A young couple dances under the moonlight by the lake.",
        "A young woman with glasses is jogging in the park wearing a pink headband.",
        "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
    ]

    # NOTE(review): never appended to or read in the visible code.
    video_logs = []

    for prompt_idx, prompt_text in enumerate(validation_prompts):
        with torch.autocast("cuda"):
            decoded = pipeline(
                prompt=prompt_text,
                frames=16,
                num_inference_steps=100,
                num_videos_per_prompt=1,
                generator=generator,
            )

            # Map from [-1, 1] to [0, 1], then to 8-bit pixel values.
            unit_range = (decoded.clamp(-1.0, 1.0) + 1.0) / 2.0
            as_uint8 = (unit_range * 255).to(torch.uint8)

            # Swap dims 1 and 2 and hand the result over to the CPU as numpy.
            clip_array = as_uint8.permute(0, 2, 1, 3, 4).cpu().numpy()

            # First video of the batch, reordered so the old dim 1 comes last.
            first_clip = torch.from_numpy(clip_array)[0].permute(0, 2, 3, 1)

            torchvision.io.write_video(
                f"video_{prompt_idx}.mp4",
                first_clip,
                fps=8,
                video_codec="h264",
                options={"crf": "10"},
            )
