from copy import deepcopy
import numpy as np
import torch

from typing import List, Optional, Tuple, Union, Dict, Any
from dataclasses import dataclass

from diffusers import logging
from diffusers import DiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput

import sys
sys.path.append("/home/ubuntu/Open-Sora")
sys.path.append("/home/ubuntu/t2v-turbo")

from opensora.models.stdit import STDiT2
from opensora.models.vae import VideoAutoencoderKL
from opensora.schedulers.iddpm import IDDPM
from opensora.models.text_encoder import T5Encoder
from opensora.utils.inference_utils import prepare_multi_resolution_info

from utils.common_utils import guidance_scale_embedding, scalings_for_boundary_conditions

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class T2VTurboOpenSoraPipelineOutput(BaseOutput):
    """
    Output of `T2VTurboOpenSoraPipeline.step` when it is called with `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor`):
            Sample to feed to the model at the next sampling step. For multi-step schedules this is `denoised`
            with fresh noise mixed back in; for a single-step schedule it is `denoised` itself.
            NOTE(review): presumably shaped like the latents passed to `step` -- confirm against callers.
        denoised (`torch.FloatTensor`, *optional*):
            The denoised estimate for the current step, i.e. the boundary-condition combination of the predicted
            `x_0` and the input sample computed in `step`. The pipeline decodes the value from the last step.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None


class T2VTurboOpenSoraPipeline(DiffusionPipeline):
    """
    Few-step text-to-video sampling pipeline built from a `dit` denoiser, a `scheduler` (only its
    `alphas_cumprod` table is read here), a video `vae` and a `text_encoder`.

    `__call__` builds a linearly spaced timestep schedule, runs the multistep sampling loop and decodes the
    final denoised latents with the VAE; `step` implements a single update of that loop.
    """

    def __init__(
        self,
        dit: STDiT2,
        scheduler: IDDPM,
        vae: VideoAutoencoderKL,
        text_encoder: T5Encoder,
    ):
        """
        Register the four sub-modules and remember the device of the denoiser.

        Args:
            dit: Denoising network called as `dit(latents, timesteps, **model_kwargs)`.
            scheduler: Provides `alphas_cumprod`, indexed by timestep in `step`.
            vae: Provides `get_latent_size`, `out_channels` and `decode`.
            text_encoder: Provides `encode(prompts)` returning a dict that contains the key `"y"`.
        """
        super().__init__()

        self.register_modules(
            dit=dit,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
        )

        # Captured once, from the first parameter of `dit`.
        # NOTE(review): not refreshed anywhere in this class -- confirm it stays valid if the
        # pipeline is moved to another device after construction.
        self._device = next(dit.parameters()).device

    def prepare_latents(
        self,
        batch_size,
        frames,
        height,
        width,
        dtype,
        generator,
        latents=None,
    ):
        """
        Return the initial latents for sampling.

        Args:
            batch_size: Number of videos to generate (first dimension of the latents).
            frames, height, width: Pixel-space video size; converted with `vae.get_latent_size`.
            dtype: dtype of freshly sampled noise.
            generator: Random generator(s) forwarded to `randn_tensor`.
            latents: Optional pre-made latents. They are only moved to the pipeline device; their
                shape and dtype are not checked or changed.

        Returns:
            Latents of shape `(batch_size, vae.out_channels, *latent_size)` when sampled here,
            otherwise the caller-provided tensor on the pipeline device.
        """
        device = self._device
        input_size = (frames, height, width)
        latent_size = self.vae.get_latent_size(input_size)
        shape = (
            batch_size,
            self.vae.out_channels,
            *latent_size,
        )
        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)
        # The initial noise is returned unscaled.
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = 280,
        width: Optional[int] = 488,
        frames: int = 16,
        fps: int = 8,
        guidance_scale: float = 7.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 4,
        lcm_origin_steps: int = 200,
        use_w_embedding_cond: bool = False,
        output_type: Optional[str] = "pil",
    ):
        """
        Generate videos for `prompt` with a few-step sampling loop.

        Args:
            prompt: A prompt or a list of prompts.
            height, width, frames: Pixel-space size of the generated video.
            fps: Forwarded to `prepare_multi_resolution_info`.
            guidance_scale: Only used when `use_w_embedding_cond` is `True`, where it is embedded and
                passed to the denoiser as `w_embedding`.
            num_videos_per_prompt: Number of videos per prompt; prompts are repeated consecutively.
            generator: Seeds the initial latents and the noise injected between sampling steps.
            latents: Optional initial latents, see `prepare_latents`.
            num_inference_steps: Number of sampling steps (at least 1, at most `lcm_origin_steps`).
            lcm_origin_steps: Size of the linearly spaced schedule the sampling steps are picked
                from (between 1 and 1000).
            use_w_embedding_cond: Whether to condition the denoiser on a guidance-scale embedding.
            output_type: `"latent"` returns the final denoised latents; any other value returns the
                result of `vae.decode`. No further conversion (e.g. to PIL) is done here.

        Returns:
            The output of `vae.decode` for the final denoised latents, or those latents themselves
            when `output_type == "latent"`.

        Raises:
            ValueError: If `prompt` is neither a string nor a list, or if the step counts cannot
                produce a valid schedule.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError("prompt must be a string or a list of strings")

        # Reject step counts that would otherwise fail with a cryptic error (zero slice step,
        # division by zero) or silently collapse every timestep to -1 (when c == 0).
        num_train_timesteps = 1000
        if not 1 <= lcm_origin_steps <= num_train_timesteps:
            raise ValueError(
                f"lcm_origin_steps must be in [1, {num_train_timesteps}], got {lcm_origin_steps}"
            )
        if not 1 <= num_inference_steps <= lcm_origin_steps:
            raise ValueError(
                "num_inference_steps must be in [1, lcm_origin_steps], "
                f"got {num_inference_steps} with lcm_origin_steps={lcm_origin_steps}"
            )

        # Prepare timesteps
        # LCM Timesteps Setting:  # Linear Spacing
        c = num_train_timesteps // lcm_origin_steps
        # LCM Training Steps Schedule
        lcm_origin_timesteps = np.arange(1, lcm_origin_steps + 1) * c - 1
        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
        # LCM Inference Steps Schedule: walk the training schedule backwards, largest timestep first.
        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
        # `.copy()` is required: torch.from_numpy does not accept negatively strided views.
        self.timesteps = torch.from_numpy(timesteps.copy()).to(self.device)

        bs = batch_size * num_videos_per_prompt

        # Prepare latent variable
        latents = self.prepare_latents(
            bs,
            frames,
            height,
            width,
            self.dtype,
            generator,
            latents,
        )

        additional_kwargs = prepare_multi_resolution_info(
            "STDiT2",
            bs,
            (height, width),
            frames,
            fps,
            self._device,
            self.dtype,
        )
        # Encode input prompt
        # Repeat interleave prompt for each video
        prompt_list = [p for p in prompt for _ in range(num_videos_per_prompt)]
        model_kwargs = self.text_encoder.encode(prompt_list)
        model_kwargs["y"] = model_kwargs["y"].to(self.device, self.dtype)
        model_kwargs.update(additional_kwargs)

        # Get Guidance Scale Embedding (only built when the denoiser is conditioned on it)
        dit_model_kwargs = {}
        if use_w_embedding_cond:
            w = torch.tensor(guidance_scale).repeat(bs)
            w_embedding = guidance_scale_embedding(
                w, embedding_dim=256, dtype=self.dtype
            ).to(self._device)
            dit_model_kwargs["w_embedding"] = w_embedding

        # LCM MultiStep Sampling Loop:
        C = latents.shape[1]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                ts = torch.full((bs,), t, device=self._device, dtype=torch.long)
                model_output = self.dit(
                    latents, ts, **model_kwargs, **dit_model_kwargs
                )
                # The denoiser output is split into two chunks along the channel axis; only the
                # first chunk of C channels is used as the prediction.
                model_pred, _ = torch.split(model_output, C, dim=1)

                # compute the previous noisy sample x_t -> x_t-1
                latents, denoised = self.step(
                    model_pred,
                    i,
                    t,
                    latents,
                    return_dict=False,
                    generator=generator,
                )
                latents = latents.to(self.dtype)
                progress_bar.update()

        if output_type == "latent":
            return denoised

        return self.vae.decode(denoised.to(self.dtype), num_frames=frames)

    def step(
        self,
        model_output: torch.FloatTensor,
        timeindex: int,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ) -> Union[T2VTurboOpenSoraPipelineOutput, Tuple]:
        """
        Run one update of the multistep sampling loop.

        The model output is converted to an `x_0` estimate, combined with `sample` through the
        boundary-condition scalings, and (for multi-step schedules) re-noised to the level of the
        next timestep in `self.timesteps`. `self.timesteps` must have been set by `__call__`.

        Args:
            model_output (`torch.FloatTensor`):
                Prediction of the denoiser for `sample` at `timestep`. It is used as predicted noise
                in the `x_0` formula below.
            timeindex (`int`):
                Position of `timestep` inside `self.timesteps`; used to look up the next timestep.
            timestep (`int`):
                The current discrete timestep, used to index `scheduler.alphas_cumprod`.
            sample (`torch.FloatTensor`):
                The current noisy sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`T2VTurboOpenSoraPipelineOutput`] or a plain tuple.
            generator (`torch.Generator` or list of them, *optional*):
                Random generator(s) for the injected noise. When `None`, noise comes from the global
                torch RNG exactly as before this parameter existed.

        Returns:
            [`T2VTurboOpenSoraPipelineOutput`] or `tuple`:
                If `return_dict` is `True` an output object, otherwise the tuple
                `(prev_sample, denoised)`.
        """
        timesteps = self.timesteps
        # 1. get previous step value; on the last step there is no next entry, so the current
        # timestep is reused.
        prev_timeindex = timeindex + 1
        if prev_timeindex < len(timesteps):
            prev_timestep = timesteps[prev_timeindex]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.scheduler.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else torch.tensor(1.0)
        )

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = scalings_for_boundary_conditions(timestep)
        pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()

        # 4. Denoise model output using boundary conditions
        denoised = c_out * pred_x0 + c_skip * sample

        # 5. Sample z ~ N(0, I), For MultiStep Inference
        # Noise is not used for one-step sampling.
        if len(timesteps) > 1:
            if generator is None:
                # Unseeded path kept unchanged so runs relying on the global RNG are unaffected.
                noise = torch.randn(model_output.shape).to(model_output.device)
            else:
                # Draw from the caller's generator so the whole sampling run is reproducible,
                # not only the initial latents.
                noise = randn_tensor(
                    model_output.shape,
                    generator=generator,
                    device=model_output.device,
                    dtype=denoised.dtype,
                )
            prev_sample = (
                alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
            )
        else:
            prev_sample = denoised

        if not return_dict:
            return (prev_sample, denoised)

        return T2VTurboOpenSoraPipelineOutput(
            prev_sample=prev_sample, denoised=denoised
        )


if __name__ == "__main__":
    # Manual smoke test: build the Open-Sora modules from the config, assemble the pipeline and
    # sample one video per validation prompt. Needs a CUDA device and the hard-coded paths below.
    from mmengine.config import Config
    from opensora.registry import MODELS, SCHEDULERS, build_module

    device = torch.device("cuda:0")
    weight_dtype = torch.float16
    cfg = Config.fromfile("/home/ubuntu/t2v-turbo/configs/opensora-v1-1.py")

    scheduler = build_module(cfg.scheduler, SCHEDULERS)
    vae = build_module(cfg.vae, MODELS).eval()
    vae.to(device, weight_dtype)

    # (frames, height, width) of the videos to generate.
    latent_size = vae.get_latent_size((16, 280, 488))

    text_encoder = build_module(cfg.text_encoder, MODELS)
    text_encoder.t5.model.to(device, weight_dtype)

    # Keyword arguments shared by the teacher and the student denoiser.
    shared_dit_kwargs = dict(
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        enable_sequence_parallelism=False,
    )

    teacher_dit = build_module(cfg.model, MODELS, **shared_dit_kwargs).eval()

    # The student uses the teacher's config plus a time-conditioning projection, and is not
    # loaded from a checkpoint: its weights are copied from the teacher just below.
    student_cfg = deepcopy(cfg.model)
    student_cfg.time_cond_proj_dim = 256
    student_cfg.from_pretrained = None

    dit = build_module(student_cfg, MODELS, **shared_dit_kwargs).eval()
    # strict=False: the two state dicts are not required to have identical keys.
    dit.load_state_dict(teacher_dit.state_dict(), strict=False)
    del teacher_dit

    dit.to(device, weight_dtype)
    text_encoder.y_embedder = dit.y_embedder

    pipeline = T2VTurboOpenSoraPipeline(
        dit=dit,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
    ).to(device, weight_dtype)
    generator = torch.Generator(device=device).manual_seed(42)

    validation_prompts = [
        "An astronaut riding a horse.",
        "Darth vader surfing in waves.",
        "Robot dancing in times square.",
        "Clown fish swimming through the coral reef.",
        "A child excitedly swings on a rusty swing set, laughter filling the air.",
        "With the style of van gogh, A young couple dances under the moonlight by the lake.",
        "A young woman with glasses is jogging in the park wearing a pink headband.",
        "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
    ]

    video_logs = []

    for prompt in validation_prompts:
        with torch.autocast("cuda"):
            decoded = pipeline(
                prompt=prompt,
                frames=16,
                num_inference_steps=8,
                num_videos_per_prompt=1,
                generator=generator,
            )
            # Map [-1, 1] to [0, 1], then to uint8 with axes 1 and 2 swapped.
            unit_range = (decoded.clamp(-1.0, 1.0) + 1.0) / 2.0
            as_uint8 = (unit_range * 255).to(torch.uint8)
            videos = as_uint8.permute(0, 2, 1, 3, 4).cpu().numpy()
        video_logs.append({"validation_prompt": prompt, "videos": videos})
