from typing import Any, Dict, List, Optional, Union

import torch
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDDIMScheduler,
    CogVideoXPipeline,
    CogVideoXTransformer3DModel,
)
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer

from .utils import prepare_rotary_positional_embeddings


def load_condition_models(
    model_id: str = "THUDM/CogVideoX-5b",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
    """Load the text-conditioning components: the T5 tokenizer and T5 encoder.

    Args:
        model_id: Repository id or path containing the ``tokenizer`` and
            ``text_encoder`` subfolders.
        text_encoder_dtype: Passed as ``torch_dtype`` when loading the encoder.
        revision: Revision forwarded to both ``from_pretrained`` calls.
        cache_dir: Cache directory forwarded to both ``from_pretrained`` calls.
        **kwargs: Accepted but not used by this function.

    Returns:
        A dict with the keys ``"tokenizer"`` and ``"text_encoder"``.
    """
    # Both components are resolved against the same revision and cache location.
    hub_kwargs = {"revision": revision, "cache_dir": cache_dir}

    loaded = {}
    loaded["tokenizer"] = T5Tokenizer.from_pretrained(
        model_id, subfolder="tokenizer", **hub_kwargs
    )
    loaded["text_encoder"] = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder",
        torch_dtype=text_encoder_dtype,
        **hub_kwargs,
    )
    return loaded


def load_latent_models(
    model_id: str = "THUDM/CogVideoX-5b",
    vae_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
    """Load the CogVideoX VAE used to produce latents.

    Args:
        model_id: Repository id or path containing the ``vae`` subfolder.
        vae_dtype: Passed as ``torch_dtype`` when loading the VAE.
        revision: Revision forwarded to ``from_pretrained``.
        cache_dir: Cache directory forwarded to ``from_pretrained``.
        **kwargs: Accepted but not used by this function.

    Returns:
        A dict with the single key ``"vae"``.
    """
    load_kwargs = {
        "subfolder": "vae",
        "torch_dtype": vae_dtype,
        "revision": revision,
        "cache_dir": cache_dir,
    }
    return {"vae": AutoencoderKLCogVideoX.from_pretrained(model_id, **load_kwargs)}


def load_diffusion_models(
    model_id: str = "THUDM/CogVideoX-5b",
    transformer_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
    """Load the diffusion components: the 3D transformer and its DDIM scheduler.

    Args:
        model_id: Repository id or path containing the ``transformer`` and
            ``scheduler`` subfolders.
        transformer_dtype: Passed as ``torch_dtype`` when loading the transformer.
        revision: Revision forwarded to both ``from_pretrained`` calls.
        cache_dir: Cache directory forwarded to both ``from_pretrained`` calls.
        **kwargs: Accepted but not used by this function.

    Returns:
        A dict with the keys ``"transformer"`` and ``"scheduler"``.
    """
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        torch_dtype=transformer_dtype,
        revision=revision,
        cache_dir=cache_dir,
    )
    # Resolve the scheduler config against the same revision and cache as the
    # transformer so both components come from one consistent checkpoint.
    scheduler = CogVideoXDDIMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        revision=revision,
        cache_dir=cache_dir,
    )
    return {"transformer": transformer, "scheduler": scheduler}


def initialize_pipeline(
    model_id: str = "THUDM/CogVideoX-5b",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    transformer_dtype: torch.dtype = torch.bfloat16,
    vae_dtype: torch.dtype = torch.bfloat16,
    tokenizer: Optional[T5Tokenizer] = None,
    text_encoder: Optional[T5EncoderModel] = None,
    transformer: Optional[CogVideoXTransformer3DModel] = None,
    vae: Optional[AutoencoderKLCogVideoX] = None,
    scheduler: Optional[CogVideoXDDIMScheduler] = None,
    device: Optional[torch.device] = None,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    enable_slicing: bool = False,
    enable_tiling: bool = False,
    enable_model_cpu_offload: bool = False,
    is_training: bool = False,
    **kwargs,
) -> CogVideoXPipeline:
    """Build a ``CogVideoXPipeline``, reusing any components that were passed in.

    Args:
        model_id: Repository id or path given to ``CogVideoXPipeline.from_pretrained``.
        text_encoder_dtype: Dtype the pipeline's text encoder is cast to.
        transformer_dtype: Dtype the transformer is cast to (skipped when
            ``is_training`` is true).
        vae_dtype: Dtype the pipeline's VAE is cast to.
        tokenizer, text_encoder, transformer, vae, scheduler: Optional
            pre-loaded components; only the non-``None`` ones are forwarded to
            ``from_pretrained``.
        device: Device passed to ``enable_model_cpu_offload`` or ``pipe.to``.
        revision: Revision forwarded to ``from_pretrained``.
        cache_dir: Cache directory forwarded to ``from_pretrained``.
        enable_slicing: Call ``pipe.vae.enable_slicing()`` when true.
        enable_tiling: Call ``pipe.vae.enable_tiling()`` when true.
        enable_model_cpu_offload: Use model CPU offload instead of moving the
            whole pipeline to ``device``.
        is_training: When true, the transformer is left in its current dtype.
        **kwargs: Accepted but not used by this function.

    Returns:
        The configured pipeline.
    """
    supplied = {
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "transformer": transformer,
        "vae": vae,
        "scheduler": scheduler,
    }
    # Forward only the components the caller actually provided.
    overrides = {
        name: module for name, module in supplied.items() if module is not None
    }

    pipe = CogVideoXPipeline.from_pretrained(
        model_id, revision=revision, cache_dir=cache_dir, **overrides
    )
    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
    pipe.vae = pipe.vae.to(dtype=vae_dtype)

    # During training the transformer is expected to already be in the right
    # dtype. Casting it here while fp8 layerwise upcasting hooks are in use
    # leads to an error during the DDP optimizer step, so only cast otherwise.
    if not is_training:
        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)

    if enable_slicing:
        pipe.vae.enable_slicing()
    if enable_tiling:
        pipe.vae.enable_tiling()

    if not enable_model_cpu_offload:
        pipe.to(device=device)
        return pipe

    pipe.enable_model_cpu_offload(device=device)
    return pipe


def prepare_conditions(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 226,  # TODO: this should be configurable
    **kwargs,
):
    """Encode ``prompt`` into text-conditioning embeddings.

    Args:
        tokenizer: Tokenizer handed to ``_get_t5_prompt_embeds``.
        text_encoder: Text encoder handed to ``_get_t5_prompt_embeds``; also the
            source of the fallback device and dtype.
        prompt: A single prompt or a list of prompts.
        device: Target device; falls back to ``text_encoder.device``.
        dtype: Target dtype; falls back to ``text_encoder.dtype``.
        max_sequence_length: Token length used for padding/truncation.
        **kwargs: Accepted but not used by this function.

    Returns:
        The result of ``_get_t5_prompt_embeds``: a dict with the key
        ``"prompt_embeds"``.
    """
    if not device:
        device = text_encoder.device
    if not dtype:
        dtype = text_encoder.dtype

    embed_kwargs = {
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "prompt": prompt,
        "max_sequence_length": max_sequence_length,
        "device": device,
        "dtype": dtype,
    }
    return _get_t5_prompt_embeds(**embed_kwargs)


def prepare_latents(
    vae: AutoencoderKLCogVideoX,
    image_or_video: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    generator: Optional[torch.Generator] = None,
    precompute: bool = False,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Encode an image/video batch with the VAE.

    Args:
        vae: VAE used for encoding.
        image_or_video: 4D or 5D tensor. The 5D input is permuted with
            ``(0, 2, 1, 3, 4)`` before encoding, i.e. it is expected as
            ``[B, F, C, H, W]`` so that the VAE receives ``[B, C, F, H, W]``.
        device: Device the input is moved to; falls back to ``vae.device``.
        dtype: Dtype of the returned sampled latents; falls back to
            ``vae.dtype``. Only used when ``precompute`` is false.
        generator: Generator used when sampling from the latent distribution.
        precompute: When true, return the raw ``vae._encode`` output without
            sampling or scaling.
        **kwargs: Accepted but not used by this function.

    Returns:
        A dict with the single key ``"latents"``. With ``precompute=False`` the
        value is the sampled latents permuted back with ``(0, 2, 1, 3, 4)``;
        with ``precompute=True`` it is the un-permuted ``vae._encode`` output.

    Raises:
        ValueError: If the input is not 5D after the optional 4D expansion.
    """
    device = device or vae.device
    dtype = dtype or vae.dtype

    if image_or_video.ndim == 4:
        # NOTE(review): the new axis is inserted at dim 2 and then swapped with
        # dim 1 by the permute below, so a [B, C, H, W] image would reach the
        # VAE as [B, 1, C, H, W]. Confirm the layout callers pass for 4D input.
        image_or_video = image_or_video.unsqueeze(2)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image_or_video.ndim != 5:
        raise ValueError(f"Expected 5D tensor, got {image_or_video.ndim}D tensor")

    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]

    if not precompute:
        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
        if not vae.config.invert_scale_latents:
            latents = latents * vae.config.scaling_factor
        # For training Cog 1.5, we don't need to handle the scaling factor here.
        # The CogVideoX team forgot to multiply here, so we should not do it
        # too. Invert scale latents is probably only needed for image-to-video
        # training, so `latents / scaling_factor` is deliberately not applied.
        # TODO: investigate this.
        latents = latents.to(dtype=dtype)
        return {
            "latents": latents.permute(0, 2, 1, 3, 4)
        }  # TODO: fix for CogVideoX [B, F, C, H, W]

    # Precompute path: handle vae scaling in the `train()` method directly.
    if vae.use_slicing and image_or_video.shape[0] > 1:
        # Encode one sample at a time when slicing is enabled.
        encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = vae._encode(image_or_video)
    return {"latents": h}


def post_latent_preparation(
    vae_config: Dict[str, Any],
    latents: torch.Tensor,
    patch_size_t: Optional[int] = None,
    **kwargs,
) -> torch.Tensor:
    """Scale, frame-pad and permute encoded latents.

    Args:
        vae_config: VAE configuration. It is read through attribute access
            (``.invert_scale_latents``, ``.scaling_factor``), so it must be an
            object exposing those attributes rather than a plain ``dict``.
        latents: Latents in ``[B, C, F, H, W]`` layout (the layout
            ``_pad_frames`` documents and the final permute relies on).
        patch_size_t: Temporal patch size forwarded to ``_pad_frames``.
        **kwargs: Accepted but not used by this function.

    Returns:
        A dict with the single key ``"latents"`` holding the result permuted
        with ``(0, 2, 1, 3, 4)``.
    """
    apply_scaling = not vae_config.invert_scale_latents
    if apply_scaling:
        latents = latents * vae_config.scaling_factor
    # For training Cog 1.5, the scaling factor is not handled here: the
    # CogVideoX team forgot to multiply, so this code does not either. Inverse
    # scaling (`latents / scaling_factor`) is probably only needed for
    # image-to-video training and is deliberately left out.
    # TODO: investigate this.

    padded = _pad_frames(latents, patch_size_t)
    return {"latents": padded.permute(0, 2, 1, 3, 4)}  # [B, F, C, H, W]


def collate_fn_t2v(
    batch: List[List[Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
    """Collate text-to-video samples into one batch.

    Only the first element of ``batch`` is used; it must be a list of samples,
    each holding a ``"prompt"`` and a ``"video"`` entry.

    Returns:
        A dict with ``"prompts"`` (list of the samples' prompts, in order) and
        ``"videos"`` (the samples' video tensors stacked along a new dim 0).
    """
    samples = batch[0]
    prompts = [sample["prompt"] for sample in samples]
    video_tensors = [sample["video"] for sample in samples]
    stacked_videos = torch.stack(video_tensors)
    return {"prompts": prompts, "videos": stacked_videos}


def calculate_noisy_latents(
    scheduler: CogVideoXDDIMScheduler,
    noise: torch.Tensor,
    latents: torch.Tensor,
    timesteps: torch.LongTensor,
) -> torch.Tensor:
    """Return ``latents`` with ``noise`` added for ``timesteps``.

    This is a thin wrapper that delegates to ``scheduler.add_noise``.
    """
    return scheduler.add_noise(latents, noise, timesteps)


def forward_pass(
    transformer: CogVideoXTransformer3DModel,
    scheduler: CogVideoXDDIMScheduler,
    prompt_embeds: torch.Tensor,
    latents: torch.Tensor,
    noisy_latents: torch.Tensor,
    timesteps: torch.LongTensor,
    ofs_emb: Optional[torch.Tensor] = None,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Run the transformer on noisy latents and return the denoised latents.

    Args:
        transformer: The CogVideoX transformer, possibly wrapped in an object
            exposing it as ``.module``; the config is read from the inner module
            in that case.
        scheduler: Scheduler whose ``get_velocity`` converts the model output.
        prompt_embeds: Text embeddings passed as ``encoder_hidden_states``.
        latents: Clean latents; only used to create the default ``ofs`` tensor
            with a matching dtype and device.
        noisy_latents: 5D tensor unpacked as
            ``(batch, frames, channels, height, width)``.
        timesteps: Timesteps passed to the transformer and the scheduler.
        ofs_emb: Optional ``ofs`` tensor for models configured with
            ``ofs_embed_dim``. When omitted, a ``(batch,)`` tensor filled with
            ``2.0`` is used. Ignored when ``ofs_embed_dim`` is ``None``.
        **kwargs: Accepted but not used by this function.

    Returns:
        A dict with the single key ``"latents"`` holding the denoised latents.
    """
    # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
    VAE_SPATIAL_SCALE_FACTOR = 8
    transformer_config = (
        transformer.module.config
        if hasattr(transformer, "module")
        else transformer.config
    )
    batch_size, num_frames, _, height, width = noisy_latents.shape
    rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
    rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR

    image_rotary_emb = None
    if transformer_config.use_rotary_positional_embeddings:
        image_rotary_emb = prepare_rotary_positional_embeddings(
            height=height * VAE_SPATIAL_SCALE_FACTOR,
            width=width * VAE_SPATIAL_SCALE_FACTOR,
            num_frames=num_frames,
            vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
            patch_size=transformer_config.patch_size,
            # Configs without a temporal patch size fall back to None.
            patch_size_t=getattr(transformer_config, "patch_size_t", None),
            attention_head_dim=transformer_config.attention_head_dim,
            device=transformer.device,
            base_height=rope_base_height,
            base_width=rope_base_width,
        )

    if transformer_config.ofs_embed_dim is None:
        # The model has no ofs embedding, so never pass one.
        ofs_emb = None
    elif ofs_emb is None:
        # Default only when the caller did not supply a value; previously the
        # argument was silently overwritten.
        ofs_emb = latents.new_full((batch_size,), fill_value=2.0)

    velocity = transformer(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=prompt_embeds,
        ofs=ofs_emb,
        image_rotary_emb=image_rotary_emb,
        return_dict=False,
    )[0]
    # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
    # code paths as scheduler.get_velocity(), which can be confusing to understand.
    denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)

    return {"latents": denoised_latents}


def validation(
    pipeline: CogVideoXPipeline,
    prompt: str,
    image: Optional[Image.Image] = None,
    video: Optional[List[Image.Image]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: Optional[int] = None,
    num_videos_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    **kwargs,
):
    """Generate one validation video with ``pipeline``.

    Arguments whose value is ``None`` are not passed to the pipeline, so its own
    defaults apply. ``image``, ``video`` and ``**kwargs`` are accepted but not
    used by this function.

    Returns:
        A one-element list ``[("video", frames)]`` where ``frames`` is the first
        entry of the pipeline output's ``frames``.
    """
    candidate_args = (
        ("prompt", prompt),
        ("height", height),
        ("width", width),
        ("num_frames", num_frames),
        ("num_videos_per_prompt", num_videos_per_prompt),
        ("generator", generator),
        ("return_dict", True),
        ("output_type", "pil"),
    )
    call_args = {}
    for key, value in candidate_args:
        if value is None:
            continue
        call_args[key] = value

    result = pipeline(**call_args)
    first_video = result.frames[0]
    return [("video", first_video)]


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize ``prompt`` and encode it with the T5 text encoder.

    Args:
        tokenizer: Tokenizer applied to the prompt(s).
        text_encoder: Encoder called on the token ids.
        prompt: A single prompt (wrapped into a list) or a list of prompts.
        max_sequence_length: Prompts are padded and truncated to this length.
        device: Device for the token ids and the returned embeddings.
        dtype: Dtype of the returned embeddings.

    Returns:
        A dict with the single key ``"prompt_embeds"`` holding the first element
        of the encoder output.
    """
    prompts = prompt if not isinstance(prompt, str) else [prompt]

    tokenized = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids.to(device)

    encoder_output = text_encoder(input_ids)
    embeddings = encoder_output[0].to(dtype=dtype, device=device)
    return {"prompt_embeds": embeddings}


def _pad_frames(latents: torch.Tensor, patch_size_t: int):
    if patch_size_t is None or patch_size_t == 1:
        return latents

    # `latents` should be of the following format: [B, C, F, H, W].
    # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
    latent_num_frames = latents.shape[2]
    additional_frames = patch_size_t - latent_num_frames % patch_size_t

    if additional_frames > 0:
        last_frame = latents[:, :, -1:, :, :]
        padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
        latents = torch.cat([latents, padding_frames], dim=2)

    return latents


# TODO: refactor into model specs for better re-use.
# Registry for CogVideoX text-to-video LoRA: each key names a step and maps to
# the pipeline class or the function in this module that implements it.
COGVIDEOX_T2V_LORA_CONFIG = dict(
    pipeline_cls=CogVideoXPipeline,
    load_condition_models=load_condition_models,
    load_latent_models=load_latent_models,
    load_diffusion_models=load_diffusion_models,
    initialize_pipeline=initialize_pipeline,
    prepare_conditions=prepare_conditions,
    prepare_latents=prepare_latents,
    post_latent_preparation=post_latent_preparation,
    collate_fn=collate_fn_t2v,
    calculate_noisy_latents=calculate_noisy_latents,
    forward_pass=forward_pass,
    validation=validation,
)
