from typing import Dict, List, Optional, Union, Any, Literal, Tuple
import math
from pathlib import Path
from omegaconf import OmegaConf
from tqdm import tqdm
import random
import gc
import inspect
import numpy as np
import torch
import diffusion
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from einops import rearrange
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXVideoTransformer3DModel
from transformers import T5EncoderModel, T5Tokenizer
import pytorch_lightning as pl
import diffusion
from diffusion.model.modular.layers import RMSNorm, MappingFeedForwardBlock
from torchvision.transforms.v2.functional import resize as tv_resize
from torch.nn.functional import pad
import torch.nn.functional as F
import av

from diffusion.model.video_model_finetuning.lora_modules import DataProvider, SimpleLoraLinear

from diffusion.utils import load_model_inference_direct

# from diffusion.utils import load_model_inference
from diffusion.model.motion_representation import MotionRepresentationLearner


def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    bitrate: int = 3500,
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """Encode a frame array (and an optional audio track) into a video file with PyAV.

    Args:
        filename: Output path, opened with ``av.open(filename, mode="w")``.
        video_array: Frames indexed as ``(T, H, W, C)`` with RGB channels. Values are cast
            (not rescaled) to ``uint8``.
        fps: Frame rate; floats are rounded to the nearest integer before use.
        video_codec: Encoder name. ``"libx264rgb"`` is written with pixel format ``rgb24``,
            every other codec with ``yuv420p``.
        bitrate: Target video bitrate in kbit/s.
        options: Extra encoder options for the video stream.
        audio_array: Optional audio samples indexed as ``(channels, samples)``. More than one
            channel is written with a stereo layout, otherwise mono.
        audio_fps: Audio sample rate.
        audio_codec: Audio encoder name, used when ``audio_array`` is given.
        audio_options: Extra encoder options for the audio stream.
    """
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)

    # PyAV may raise for float frame rates with a fractional part, so pass an int instead.
    if isinstance(fps, float):
        fps = int(round(fps))

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        # `bitrate` is in kbit/s; the codec context takes an integer number of bit/s.
        stream.codec_context.bit_rate = int(bitrate * 1000)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # PyAV sample format name -> numpy dtype expected by AudioFrame.from_ndarray.
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)
            # Flush the audio encoder.
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush the video encoder.
        for packet in stream.encode():
            container.mux(packet)


from torchvision.utils import make_grid


def make_video_grid(videos, *args, **kwargs):
    """Tile several equally long videos into one grid video, frame by frame.

    Each video is indexed frame-first (``video[i]`` is one frame). For every frame index, the
    corresponding frames of all videos are combined with ``make_grid``; any extra positional or
    keyword arguments are forwarded to it. The per-frame grids are stacked along a new leading
    dimension.
    """
    frame_count = videos[0].shape[0]
    grid_frames = []
    for frame_idx in range(frame_count):
        same_time_frames = [clip[frame_idx] for clip in videos]
        grid_frames.append(make_grid(same_time_frames, *args, **kwargs))
    return torch.stack(grid_frames)


def add_text_to_video(
    video, text, xy=(0, 0), size=25, color=(224, 224, 224), font="/usr/share/fonts/dejavu/DejaVuSans.ttf"
):
    """Draw the same text onto every frame of a channels-first video.

    Args:
        video: numpy array or torch tensor indexed ``(frames, channels, height, width)``;
            each frame must be convertible by PIL as an ``"RGB"`` image.
        text: String to draw.
        xy: Top-left position of the text.
        size: Font size.
        color: Text fill color.
        font: Path of a TrueType font file.

    Returns:
        The annotated video in the same layout. Tensors are returned as tensors on the
        original device; other inputs come back as numpy arrays.
    """
    from PIL import Image, ImageFont, ImageDraw

    typeface = ImageFont.truetype(font, size)

    source_device = None
    if isinstance(video, torch.Tensor):
        source_device = video.device
        video = video.detach().cpu().numpy()

    annotated = []
    # PIL works on channels-last images, hence the transpose to (frames, H, W, C).
    for frame in video.transpose(0, 2, 3, 1):
        canvas = Image.fromarray(frame, mode="RGB")
        ImageDraw.Draw(canvas).text(xy=xy, text=text, fill=color, font=typeface)
        annotated.append(np.asarray(canvas))

    result = np.stack(annotated).transpose(0, 3, 1, 2)
    if source_device is not None:
        result = torch.from_numpy(result).to(source_device)
    return result


def expand_tensor_dims(tensor, ndim):
    """Append trailing singleton dims until ``tensor`` has at least ``ndim`` dimensions.

    Useful for broadcasting a per-sample vector such as ``(B,)`` against ``(B, C, F, H, W)``.
    Tensors that already have ``ndim`` or more dimensions are returned unchanged.
    """
    extra = max(ndim - len(tensor.shape), 0)
    for _ in range(extra):
        tensor = tensor.unsqueeze(-1)
    return tensor


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize ``prompt`` and encode it with the T5 text encoder.

    Prompts are padded or truncated to exactly ``max_sequence_length`` tokens; longer prompts
    are silently truncated.

    Args:
        tokenizer: T5 tokenizer.
        text_encoder: T5 encoder; its first output is used as the embeddings.
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: Number of copies of each prompt's embedding and mask to emit.
        max_sequence_length: Token length after padding/truncation.
        device: Device of the returned tensors.
        dtype: Dtype of the returned embeddings (defaults to ``text_encoder.dtype``).

    Returns:
        ``(prompt_embeds, prompt_attention_mask)`` with leading dimension
        ``len(prompt) * num_videos_per_prompt`` and sequence length ``max_sequence_length``.
        The mask is boolean. Copies of the same prompt are adjacent in both tensors, so row
        ``i`` of the mask always belongs to row ``i`` of the embeddings.
    """
    dtype = dtype or text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_attention_mask = text_inputs.attention_mask.bool().to(device)

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Duplicate each prompt's embeddings (mps-friendly repeat + view) -> order [p0, p0, p1, p1, ...].
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    # Duplicate the mask the same way so it stays aligned with the embeddings
    # (a plain `.repeat(n, 1)` would tile it as [m0, m1, m0, m1, ...]).
    prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, -1)

    return prompt_embeds, prompt_attention_mask


def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Produce T5 embeddings for the prompt and, under CFG, for the negative prompt.

    Precomputed ``prompt_embeds`` / ``negative_prompt_embeds`` (with their masks) are passed
    through unchanged. Missing embeddings are computed with ``_get_t5_prompt_embeds``.
    Negative prompts are only encoded when ``do_classifier_free_guidance`` is set. An absent
    negative prompt becomes ``""``, and a single string negative prompt is repeated for every
    prompt.

    Returns:
        ``(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
        negative_prompt_attention_mask)``. The negative entries are whatever was passed in
        when no negative encoding happens.

    Raises:
        TypeError: if ``negative_prompt`` and ``prompt`` are of different types.
        ValueError: if the negative prompt count differs from the prompt batch size.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = prompt_embeds.shape[0] if prompt is None else len(prompt)

    if prompt_embeds is None:
        prompt_embeds, prompt_attention_mask = _get_t5_prompt_embeds(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Nothing more to encode: no guidance requested, or negatives already supplied.
    if not do_classifier_free_guidance or negative_prompt_embeds is not None:
        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * batch_size

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds, negative_prompt_attention_mask = _get_t5_prompt_embeds(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask


def _normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    # Normalize latents across the channel dimension [B, C, F, H, W]
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = (latents - latents_mean) * scaling_factor / latents_std
    return latents


def _denormalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    # Denormalize latents across the channel dimension [B, C, F, H, W]
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / scaling_factor + latents_mean
    return latents


def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
    # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
    # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
    # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
    batch_size, num_channels, num_frames, height, width = latents.shape
    post_patch_num_frames = num_frames // patch_size_t
    post_patch_height = height // patch_size
    post_patch_width = width // patch_size
    latents = latents.reshape(
        batch_size,
        -1,
        post_patch_num_frames,
        patch_size_t,
        post_patch_height,
        patch_size,
        post_patch_width,
        patch_size,
    )
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return latents


def _unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
    # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
    # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
    # what happens in the `_pack_latents` method.
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate a timestep shift ``mu`` from the sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``.
    Lengths outside that range are extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.

    At most one of ``timesteps`` / ``sigmas`` may be given as a custom schedule. In that case
    the scheduler's ``set_timesteps`` must accept the matching keyword, and the step count is
    taken from the resulting schedule. Otherwise ``num_inference_steps`` is used as given.
    Extra ``kwargs`` are forwarded to ``set_timesteps``.

    Raises:
        ValueError: if both custom schedules are passed, or the scheduler does not support
            the requested one.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        custom = ("timesteps", timesteps, "timestep")
    elif sigmas is not None:
        custom = ("sigmas", sigmas, "sigmas")
    else:
        custom = None

    if custom is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    arg_name, schedule_values, label = custom
    if arg_name not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(**{arg_name: schedule_values}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)


def getattr_recursive(obj: Any, path: str) -> Any:
    """Resolve a dotted path such as ``"blocks.3.attn"`` starting from ``obj``.

    Components made only of decimal digits are used as indices (``obj[int(part)]``, e.g.
    into an ``nn.ModuleList``); all others are attribute lookups. An empty path returns
    ``obj`` itself, which is what the parent of a top-level module resolves to.

    Raises:
        AttributeError / IndexError / KeyError: from the failing lookup.
    """
    if not path:
        return obj
    for part in path.split("."):
        # isdecimal (not isnumeric) matches exactly the strings int() can parse.
        obj = obj[int(part)] if part.isdecimal() else getattr(obj, part)
    return obj

class SimpleMapper(nn.Module):
    """Map vectors of width ``d_model`` through a stack of feed-forward blocks.

    Pipeline: linear projection -> RMSNorm -> ``n_layers`` ``MappingFeedForwardBlock``s
    (hidden width ``d_ff``) -> RMSNorm. Input and output widths are both ``d_model``.
    """

    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        # Submodules are created in this order so parameter names and init order stay fixed.
        self.in_proj = nn.Linear(d_model, d_model)
        self.in_norm = RMSNorm(d_model)
        self.blocks = nn.ModuleList(
            MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)
        )
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        hidden = self.in_norm(self.in_proj(x))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.out_norm(hidden)


class AggregateMapper(nn.Module):
    """Summarize a sequence of motion embeddings with a transformer driven by a learned CLS token.

    NOTE(review): `Transformer` and `get_1d_sincos_embed` are neither imported nor defined in
    the visible part of this file. Unless they are provided elsewhere, constructing this class
    (used when `aggregate_mapping=True`) raises NameError. Confirm the intended imports.
    NOTE(review): `forward` requires a `pts` argument, but
    `LTXVideoMotionAdapter.get_conditioning` calls `self.cond_lora_mapper(motion_embs)` with a
    single argument. Confirm the call site.
    """

    def __init__(self, motion_dim, depth=2):
        super().__init__()
        self.in_proj = nn.Linear(motion_dim, motion_dim)
        # assumes the project Transformer treats in/out_features=None as "no input/output
        # projection" and pos_enc="none" as "no positional encoding" — TODO confirm.
        self.transformer = Transformer(
            in_features=None,
            out_features=None,
            width=motion_dim,
            depth=depth,
            layer_params={"pos_enc": "none"},
            # patch_params = {'in_proj': False, 'out_proj': False}
        )
        # Single learned query token, repeated per batch element in `forward`.
        self.cls_token = nn.Embedding(1, motion_dim)

    def forward(self, x, pts):
        # x: presumably (bs, fs, ds) motion embeddings; pts: presumably per-frame timestamps
        # used for the sinusoidal embedding (scaled by pos_multiplier=1000) — TODO confirm.
        # bs, fs, ds = x.shape
        x = self.in_proj(x)
        x = x + get_1d_sincos_embed(x.shape[-1], pts, pos_multiplier=1000)
        # The CLS token is the main input and x is passed as extra tokens; per the inline
        # comment the result is (bs, 1, ds).
        x = self.transformer(self.cls_token.weight.unsqueeze(0).repeat(x.shape[0], 1, 1), extra_tokens=x)  # bs, 1, ds
        return x

class LTXVideoMotionAdapter(pl.LightningModule):
    def __init__(
        self,
        dismo_config_dir,
        dismo_ckpt_path,
        aggregate_mapping=False,
        lora_config=None,
        # motion_adapter_config = None,
        train_with_source_frame=True,
        train_with_text_prompts=False,
        cond_lora_dropout_p=0.05,
        # cond_cross_dropout_p = 0.05,
        caption_dropout_p=1,
        caption_dropout_technique="empty",
        compile=False,
        hf_path="Lightricks/LTX-Video",
    ):
        """Build the frozen LTX-Video backbone, the frozen motion model and the trainable adapter parts.

        Args:
            dismo_config_dir: Config directory passed to ``load_model_inference_direct`` for the
                motion model stored in ``self.dismo``.
            dismo_ckpt_path: Checkpoint path passed to ``load_model_inference_direct``.
            aggregate_mapping: If True, use ``AggregateMapper`` instead of ``SimpleMapper`` to map
                motion embeddings to LoRA conditioning vectors.
            lora_config: Keyword arguments forwarded to ``init_lora_modules``. If None, no LoRA
                modules, mapper, CFG token or no-motion token are created.
            train_with_source_frame: Stored on ``self``; read by ``training_step``.
            train_with_text_prompts: Stored on ``self``; read by ``training_step``.
            cond_lora_dropout_p: Training-time probability of replacing a sample's LoRA
                conditioning with a learned CFG token. The token is only created when this is > 0.
            caption_dropout_p: Stored on ``self``; probability used by ``training_step`` to replace
                a caption with ``""`` (the default 1 always replaces it).
            caption_dropout_technique: Stored on ``self``; not read by any code visible here.
            compile: Currently ignored: it is overwritten with False below.
            hf_path: Hugging Face repo id or local path with ``vae``, ``tokenizer``,
                ``text_encoder`` and ``transformer`` subfolders.
        """
        super().__init__()
        # Load required pre-trained modules, disable gradient computations, and set to eval mode
        self.vae: AutoencoderKLLTXVideo = AutoencoderKLLTXVideo.from_pretrained(
            hf_path, subfolder="vae", torch_dtype=torch.bfloat16
        )
        self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(hf_path, subfolder="tokenizer")
        self.text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(
            hf_path, subfolder="text_encoder", torch_dtype=torch.bfloat16
        )
        self.transformer: LTXVideoTransformer3DModel = LTXVideoTransformer3DModel.from_pretrained(
            hf_path, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        # Only the default schedule constants are kept (used to sample training sigmas).
        scheduler = FlowMatchEulerDiscreteScheduler()
        self.num_train_timesteps = scheduler.config.num_train_timesteps
        self.sigmas = scheduler.sigmas.clone()

        # Freeze the VAE and text encoder. Replacing `.train` with a no-op keeps them in eval mode
        # when a parent module later calls `.train(mode)` on its children.
        self.vae.requires_grad_(False)
        self.vae.eval()
        self.vae.train = lambda x: x
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()
        self.text_encoder.train = lambda x: x
        # Base transformer weights are frozen but the module stays in train mode (presumably so
        # training-time behavior of the injected LoRA layers is active — confirm).
        self.transformer.requires_grad_(False)
        self.transformer.train()

        # NOTE(review): this overwrites the `compile` argument, so torch.compile is never applied.
        compile = False
        if compile:
            self.transformer = torch.compile(self.transformer, fullgraph=True, mode="default")
        self.dismo: MotionRepresentationLearner = None

        # Clears `model.reconstructor.checkpoint` in the motion model config before it is loaded
        # (presumably so no separate reconstructor checkpoint is loaded).
        def cfg_cb(cfg):
            cfg.model.reconstructor.checkpoint = None

        self.dismo, _ = load_model_inference_direct(
            dismo_config_dir, dismo_ckpt_path, config_callback=cfg_cb, initialize_deepspeed=False, strict=False
        )
        # The motion model is frozen and pinned to eval mode the same way as the VAE.
        self.dismo.requires_grad_(False)
        self.dismo.eval()
        self.dismo.train = lambda x: x

        # Mapping network
        # Width of the mapper inputs: twice the motion model's projection width, because
        # `get_conditioning` concatenates pairs of motion embeddings.
        motion_dim = self.dismo.sequence_embedder.mid_split.proj.out_features * 2
        # Shared object through which per-step conditioning reaches the SimpleLoraLinear layers.
        self.data_provider = DataProvider()

        # State-dict keys (relative to this module) of the parameters intended to be trained.
        self.trainable_targets = set()

        # Create and insert lora modules
        self.has_conditional_lora = False
        self.aggregate_mapping = aggregate_mapping
        if lora_config is not None:
            if not self.aggregate_mapping:
                self.cond_lora_mapper = SimpleMapper(n_layers=2, d_model=motion_dim, d_ff=3 * motion_dim, dropout=0)
            else:
                self.cond_lora_mapper = AggregateMapper(motion_dim, depth=4)
            self.init_lora_modules(c_dim=motion_dim, **lora_config)
            for k in self.cond_lora_mapper.state_dict().keys():
                self.trainable_targets.add("cond_lora_mapper." + k)
            self.cond_lora_dropout_p = cond_lora_dropout_p
            if cond_lora_dropout_p > 0:
                # Learned replacement for dropped LoRA conditioning (classifier-free guidance).
                self.cond_lora_cfg_token = torch.nn.Embedding(1, motion_dim)
                self.trainable_targets.add("cond_lora_cfg_token.weight")

            # Learned token prepended to the motion embeddings in `get_conditioning`.
            # NOTE(review): unlike the parameters above it is not added to `trainable_targets` —
            # confirm whether it is meant to be trained.
            self.no_motion_emb_token = torch.nn.Embedding(1, motion_dim // 2)

        # Misc
        self.train_with_source_frame = train_with_source_frame
        self.train_with_text_prompts = train_with_text_prompts
        self.caption_dropout_p = caption_dropout_p
        self.caption_dropout_technique = caption_dropout_technique

    def init_lora_modules(
        self,
        c_dim: int,
        conditional_targets: Optional[list] = None,
        unconditional_targets: Optional[list] = None,
        ignore_targets: Optional[list] = None,
        **kwargs,
    ):
        """Replace targeted linear layers of ``self.transformer`` with ``SimpleLoraLinear`` wrappers.

        A state-dict entry is targeted when its name contains a substring from
        ``conditional_targets`` or ``unconditional_targets`` and none from ``ignore_targets``.
        For each targeted ``<module>.weight``, the wrapper is built from the module's
        ``in_features``/``out_features``. The weight (and ``<module>.bias`` if present) is loaded
        into the wrapper's ``W``, and the wrapper replaces ``<module>`` on its parent. All
        wrapper parameters other than ``W.weight``/``W.bias`` are added to
        ``self.trainable_targets``.

        Args:
            c_dim: Width of the conditioning vectors consumed by conditional LoRA layers.
            conditional_targets: Substrings selecting layers that get conditional LoRAs.
            unconditional_targets: Substrings selecting layers that get unconditional LoRAs;
                they take precedence over ``conditional_targets``.
            ignore_targets: Substrings of paths that are never wrapped.
            **kwargs: Extra keyword arguments forwarded to ``SimpleLoraLinear``.
        """
        # None defaults avoid shared mutable default lists; list() also accepts tuples/ListConfig.
        conditional_targets = list(conditional_targets or [])
        unconditional_targets = list(unconditional_targets or [])
        ignore_targets = list(ignore_targets or [])
        lora_targets = conditional_targets + unconditional_targets

        has_conditional_loras = False
        # Snapshot of the original parameter names; modules are replaced while iterating over it.
        sd = self.transformer.state_dict()
        for path, w in sd.items():
            if any(t in path for t in ignore_targets):
                continue

            if not any(t in path for t in lora_targets):
                continue

            if path.endswith("bias"):
                # Biases are loaded together with their weight below.
                continue

            parts = path.split(".")
            parent_path = ".".join(parts[:-2])
            target_path = ".".join(parts[:-1])
            target_name = parts[-2]
            # A top-level module ("<name>.weight") has an empty parent path: its parent is the transformer.
            parent_module = getattr_recursive(self.transformer, parent_path) if parent_path else self.transformer
            target_module = getattr_recursive(self.transformer, target_path)

            has_bias_term = f"{target_path}.bias" in sd
            conditional = any(t in path for t in conditional_targets)
            conditional = conditional and not any(t in path for t in unconditional_targets)
            has_conditional_loras = has_conditional_loras or conditional

            lora_module = SimpleLoraLinear(
                out_features=target_module.out_features,
                in_features=target_module.in_features,
                c_dim=c_dim,
                data_provider=self.data_provider,
                with_conditioning=conditional,
                base_bias=has_bias_term,
                lora_bias=False,
                frozen_weights_dtype=w.dtype,
                target_path=target_path,
                depth=None,
                **kwargs,
            )
            W_sd = {"weight": w}
            if has_bias_term:
                W_sd["bias"] = sd[f"{target_path}.bias"]
            lora_module.W.load_state_dict(W_sd)

            setattr(parent_module, target_name, lora_module)

            # Everything except the copied base weights is trainable.
            lora_keys = set(lora_module.state_dict().keys())
            lora_keys.remove("W.weight")
            if has_bias_term:
                lora_keys.remove("W.bias")

            for lora_key in lora_keys:
                self.trainable_targets.add("transformer." + target_path + "." + lora_key)

        self.has_conditional_lora = has_conditional_loras

    def get_conditioning(self, videos=None, motion_embs=None, video_fps=24, motion_fps=6):
        """Turn a driving video (or precomputed motion embeddings) into LoRA conditioning tokens.

        Args:
            videos: Driving video indexed ``(b, f, c, h, w)``, presumably with values in
                ``[0, 255]`` (it is divided by 255). Only used when ``motion_embs`` is None.
                It must contain enough frames to sample 8 motion frames spaced
                ``video_fps / motion_fps`` frames apart.
            motion_embs: Optional precomputed motion embeddings in the same ``(b, t, d)`` layout
                produced from ``videos`` below. ``t`` must be even and ``d`` must equal the width
                of ``self.no_motion_emb_token``.
            video_fps: Frame rate of ``videos``.
            motion_fps: Rate at which motion frames are sampled from ``videos``.

        Returns:
            ``(None, cond_lora)``. The first slot is the cross-attention conditioning, which is
            not produced here. ``cond_lora`` is ``self.cond_lora_mapper`` applied to
            ``(b, t // 2, 2 * d)`` paired embeddings.

        Raises:
            ValueError: if ``videos`` is too short to sample the motion frames.
        """
        if motion_embs is None:
            # videos: b, f, c, h, w
            bs = videos.shape[0]
            n_frames = videos.shape[1]
            n_motion_frames = 8  # number of frames the motion extractor consumes

            # Indices of frames spaced video_fps / motion_fps apart (0, 4, ..., 28 at the defaults).
            motion_frame_idcs = torch.arange(n_motion_frames).mul(video_fps / motion_fps).round().long()
            if motion_frame_idcs[-1] >= n_frames:
                raise ValueError(f"not enough video frames for extracting {n_motion_frames} motion frames")

            # Gather motion frames, resize to 256x256 and map [0, 255] -> [-1, 1].
            motion_frames = videos[:, motion_frame_idcs].float().div(255)
            motion_frames = F.interpolate(
                motion_frames.flatten(0, 1), size=(256, 256), mode="bilinear", align_corners=False
            ).view(bs, n_motion_frames, 3, 256, 256)
            motion_frames = motion_frames.sub(0.5).mul(2)
            motion_frames = rearrange(motion_frames, "b f c h w -> b c f h w")

            # Run the frozen motion model in its parameter dtype, then return to float.
            dtype = motion_frames.dtype
            dismo_dtype = self.dismo.sequence_embedder.mid_split.proj.weight.dtype
            with torch.no_grad():
                motion_embs = self.dismo.get_motion_embeddings(motion_frames.to(dismo_dtype), training_enabled=False)
            motion_embs = motion_embs.detach().permute(0, 2, 1).to(dtype)

        # Batch size comes from the embeddings, so both entry paths (and passing both arguments) work.
        bs = motion_embs.shape[0]

        # Drop the last embedding and prepend the learned "no motion" token, then pair token i with
        # token i + t/2 along the feature axis: (b, t, d) -> (b, t/2, 2d).
        motion_embs = motion_embs[:, :-1]
        no_motion = self.no_motion_emb_token.weight[None, :, :].expand(bs, 1, -1)
        motion_embs = torch.cat([no_motion, motion_embs], dim=1)
        motion_embs = motion_embs.reshape(bs, 2, motion_embs.shape[1] // 2, -1)
        motion_embs = motion_embs.permute(0, 2, 1, 3).flatten(-2, -1)
        cond_lora = self.cond_lora_mapper(motion_embs)

        return None, cond_lora

    def training_step(self, batch):
        """Run one flow-matching training step and return the scalar loss.

        Steps:
            1. Extract motion conditioning.
            2. Trim the clip to ``8k + 1`` frames.
            3. Apply conditioning and caption dropout.
            4. Encode prompts and VAE latents.
            5. Noise the latents at a sampled sigma, optionally keeping the first latent frame clean.
            6. Predict with the transformer.
            7. Regress onto ``noise - latents``.

        Args:
            batch: Mapping whose ``batch["frames"]["data"]`` is a sequence of per-sample videos
                stacked to ``(b, f, c, h, w)``. Values are presumably in ``[0, 255]``, since they
                are mapped to ``[-1, 1]`` via ``/127.5 - 1``. Optionally holds a ``"txt"`` or
                ``"caption"`` entry with one caption per sample.

        Returns:
            Scalar MSE loss, also logged as ``"loss"``.
        """
        videos = torch.stack(batch["frames"]["data"])
        # videos = batch["x"].permute(0, 2, 1, 3, 4)  # ---> b, f, c, h, w

        # 1. Motion conditioning, computed from the untrimmed clip.
        cond_cross, cond_lora = self.get_conditioning(videos=videos)

        # 1.1 Trim the clip to 8k + 1 frames.
        n_frames = (((videos.shape[1] - 1) // 8) * 8) + 1
        videos = videos[:, :n_frames]

        # `cond_cross_dropout_p` is never set in __init__; default to 0 rather than raising
        # should a cross-attention conditioning ever be returned.
        cond_cross_dropout_p = getattr(self, "cond_cross_dropout_p", 0.0)
        if cond_cross is not None and cond_cross_dropout_p > 0:
            bs, fs, _ = cond_cross.shape
            drop_idcs = torch.rand(bs * fs) < cond_cross_dropout_p
            cond_cross = rearrange(cond_cross, "b f d -> (b f) d")
            cond_cross[drop_idcs] = self.cond_cross_cfg_token.weight.to(cond_cross.dtype)
            cond_cross = rearrange(cond_cross, "(b f) d -> b f d", b=bs, f=fs)

        if cond_lora is not None and self.cond_lora_dropout_p > 0:
            # Replace the whole conditioning sequence of randomly chosen samples with the learned
            # CFG token. torch.where avoids an in-place write into the mapper's autograd output.
            drop = torch.rand(cond_lora.shape[0]) < self.cond_lora_dropout_p
            drop = drop.to(cond_lora.device)[:, None, None]
            cfg_token = self.cond_lora_cfg_token.weight.to(cond_lora.dtype)  # (1, d), broadcasts
            cond_lora = torch.where(drop, cfg_token, cond_lora)

        # 2. Text prompts (with caption dropout).
        prompts = self._collect_prompts(batch, videos.shape[0])
        prompt_embeds, prompt_attention_mask, _, _ = encode_prompt(
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            prompt=prompts,
            negative_prompt=None,
            do_classifier_free_guidance=False,
            num_videos_per_prompt=1,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            prompt_attention_mask=None,
            negative_prompt_attention_mask=None,
            max_sequence_length=128,
            device=videos.device,
        )

        # 3. Video latents: (b, f, c, h, w) -> (b, c, f, h, w), [0, 255] -> [-1, 1].
        latents = self.vae.encode(
            videos.permute(0, 2, 1, 3, 4).float().div(127.5).sub(1).to(self.vae.dtype)
        ).latent_dist.sample()
        latents = latents.to(dtype=self.vae.dtype)
        _, _, latent_num_frames, latent_height, latent_width = latents.shape
        latents = _normalize_latents(latents, self.vae.latents_mean, self.vae.latents_std)

        # 4. Sigmas drawn uniformly from the scheduler's training schedule.
        u = torch.rand(size=(latents.shape[0],), device=latents.device)
        indices = (u * self.num_train_timesteps).long()
        sigmas = self.sigmas.clone().to(device=latents.device, dtype=torch.float32)[indices]
        timesteps = (sigmas * 1000.0).long()

        # 5. Noisy latents: x_t = (1 - sigma) * x_0 + sigma * noise.
        noise = torch.randn(
            latents.shape,
            device=latents.device,
            dtype=self.transformer.dtype,
        )
        sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        noisy_latents = noisy_latents.to(latents.dtype)

        # 6. (optional) Keep the first latent frame clean and give its tokens timestep 0.
        if self.train_with_source_frame:
            conditioning_mask = torch.zeros(
                (latents.shape[0], 1, latent_num_frames, latent_height, latent_width),
                device=latents.device,
                dtype=latents.dtype,
            )
            conditioning_mask[:, :, 0] = 1.0
            noisy_latents = latents * conditioning_mask + noisy_latents * (1 - conditioning_mask)
            conditioning_mask = _pack_latents(
                conditioning_mask, self.transformer.config.patch_size, self.transformer.config.patch_size_t
            ).squeeze(-1)
            timesteps = timesteps.unsqueeze(-1) * (1 - conditioning_mask)

        noisy_latents = _pack_latents(
            noisy_latents,
            patch_size=self.transformer.config.patch_size,
            patch_size_t=self.transformer.config.patch_size_t,
        ).squeeze(-1)

        # 7. Forward model. RoPE scales use a fixed 24 fps and 8x temporal / 32x spatial ratios.
        video_fps = 24
        latent_fps = video_fps / 8
        spatial_compression_ratio = 32
        rope_interpolation_scale = [1 / latent_fps, spatial_compression_ratio, spatial_compression_ratio]

        self.data_provider.set(
            cond_cross=cond_cross,
            cond_cross_fps=video_fps,
            cond_lora=cond_lora,
            latent_num_frames=latent_num_frames,
            latent_height=latent_height,
            latent_width=latent_width,
            latent_fps=video_fps / 8,
        )
        try:
            pred = self.transformer(
                hidden_states=noisy_latents,
                encoder_hidden_states=prompt_embeds,
                timestep=timesteps,
                encoder_attention_mask=prompt_attention_mask,
                num_frames=latent_num_frames,
                height=latent_height,
                width=latent_width,
                rope_interpolation_scale=rope_interpolation_scale,
                return_dict=False,
            )[0]
        finally:
            # Always clear the per-step conditioning, even if the forward pass raises.
            self.data_provider.reset()

        pred = _unpack_latents(
            pred,
            latent_num_frames,
            latent_height,
            latent_width,
            self.transformer.config.patch_size,
            self.transformer.config.patch_size_t,
        )

        # 8. Flow-matching loss; the clean source frame (if kept) is excluded.
        target = noise - latents
        if self.train_with_source_frame:
            pred = pred[:, :, 1:]
            target = target[:, :, 1:]
        loss = (pred.float() - target.float()).pow(2)
        # Average over all but the batch dimension, then over the batch.
        loss = loss.mean(list(range(1, loss.ndim)))
        loss = loss.mean()

        self.log("loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    def _collect_prompts(self, batch, batch_size):
        """Return one prompt string per sample, applying caption dropout.

        Captions are read from ``batch["txt"]`` or, failing that, ``batch["caption"]``. Each
        caption may be a ``str``, a ``bytes`` object, or a tuple/list of byte values. A caption
        is replaced with ``""`` with probability ``self.caption_dropout_p``. If text prompts are
        disabled or no caption key exists, every prompt is ``""``.
        """
        txt_key = next((k for k in ("txt", "caption") if k in batch), None)
        if not self.train_with_text_prompts or txt_key is None:
            return ["" for _ in range(batch_size)]

        prompts = []
        for txt in batch[txt_key]:
            # Strict `<` so that a dropout probability of 0 never drops a caption.
            if random.random() < self.caption_dropout_p:
                prompts.append("")
                continue
            if isinstance(txt, str):
                prompts.append(txt)
                continue
            if isinstance(txt, (tuple, list)):
                txt = bytes(txt)
            prompts.append(txt.decode())
        return prompts

    @torch.no_grad()
    def sample(
        self,
        num_samples=None,
        driving_video=None,
        motion_embeddings=None,
        # driving_pts=None,
        src_images=None,
        width=None,
        height=None,
        text_prompts=None,
        negative_text_prompt=None,
        same_noise_per_sample=True,
        n_steps=50,
        motion_cfg_scale=1.0,
        text_cfg_scale=1.0,
        generator=None,
        progress_bar=True,
        motion_callback=None,
        num_frames=None,
    ) -> torch.Tensor:
        """Generate videos, optionally transferring motion from a driving signal.

        Motion comes from at most one of ``driving_video`` / ``motion_embeddings``. Without either,
        the ``cond_*_cfg_token`` placeholder embeddings are used as the motion condition. When
        ``src_images`` is given, the first latent frame is pinned to its VAE encoding.

        Args:
            num_samples: Batch size to use when neither ``src_images`` nor ``text_prompts`` has
                more than one entry. Defaults to 1.
            driving_video: A single (unbatched) video. A batch dim is added here, and the time
                axis is its first dim. Its length is rounded down to 8k + 1 frames.
            motion_embeddings: Precomputed motion embeddings for a single video, with time on
                the first dim. Each embedding is counted as 4 frames.
            src_images: Source frame(s) with pixel values in [0, 255]. Either a single image
                (3 dims) or a batch (4 dims) with height and width as the last two dims. The
                layout is presumably (b, c, h, w); confirm against callers.
            width, height: Output resolution when ``src_images`` is not given
                (defaults 704 x 512).
            text_prompts: A prompt string or a list of prompts. ``None`` means the empty prompt.
            negative_text_prompt: Negative prompt forwarded to ``encode_prompt``.
            same_noise_per_sample: Share one initial noise tensor across the batch.
            n_steps: Number of flow-matching Euler steps.
            motion_cfg_scale, text_cfg_scale: Guidance scales. 1.0 disables the respective
                guidance term.
            generator: Optional ``torch.Generator`` used for the initial noise.
            progress_bar: Show a tqdm progress bar over the denoising steps.
            motion_callback: Unused here.
            num_frames: Output length for motion-free sampling (default 81, rounded down to
                8k + 1). Ignored when a driving signal is given, because the length is then
                derived from that signal.

        Returns:
            ``uint8`` videos of shape (b, f, c, h, w).
        """
        # 0. Normalise inputs; derive batch size, resolution and frame count
        if src_images is not None and len(src_images.shape) == 3:
            src_images = src_images[None, :]

        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]

        if text_prompts is not None and src_images is not None:
            assert len(text_prompts) == 1 or src_images.shape[0] == 1 or len(text_prompts) == src_images.shape[0]

        bs = num_samples or 1
        if src_images is not None and src_images.shape[0] > 1:
            bs = src_images.shape[0]
        elif text_prompts is not None and len(text_prompts) > 1:
            bs = len(text_prompts)

        width = (src_images.shape[-1] if src_images is not None else width) or 704
        height = (src_images.shape[-2] if src_images is not None else height) or 512

        assert (
            motion_embeddings is None or driving_video is None
        ), "driving_video and motion_embeddings cannot be set at the same time"

        video_fps = 24
        avg_video_fps = video_fps
        if driving_video is not None:
            driving_video = driving_video.unsqueeze(0)
            num_frames = (((driving_video.shape[1] - 1) // 8) * 8) + 1
        elif motion_embeddings is not None:
            motion_embeddings = motion_embeddings.unsqueeze(0)
            # each motion embedding accounts for 4 frames
            num_frames = ((((motion_embeddings.shape[1] * 4) - 1) // 8) * 8) + 1
        else:
            # There is no driving signal to derive the length from, so use the caller's value
            # (default 81), rounded down to 8k + 1 like the branches above. Previously
            # num_frames was never assigned here, and motion-free sampling raised NameError.
            num_frames = (((num_frames or 81) - 1) // 8) * 8 + 1

        patch_size = self.transformer.config.patch_size
        patch_size_t = self.transformer.config.patch_size_t

        latent_height = height // 32
        latent_width = width // 32
        latent_num_frames = (num_frames - 1) // 8 + 1

        avg_latent_fps = avg_video_fps / 8
        spatial_compression_ratio = 32
        rope_interpolation_scale = [1 / avg_latent_fps, spatial_compression_ratio, spatial_compression_ratio]

        # 1. Embed motion (batch of one; tiled to the full batch in step 4)
        if driving_video is not None or motion_embeddings is not None:
            cond_cross, cond_lora = self.get_conditioning(videos=driving_video, motion_embs=motion_embeddings)
        else:
            # Use batch 1, matching the driving signal above, which has a batch of one. These
            # tensors are tiled by the batch size in step 4, so expanding to bs here produced
            # bs * bs rows.
            cond_lora = (
                self.cond_lora_cfg_token.weight.unsqueeze(0).expand(1, latent_num_frames, -1)
                if self.has_conditional_lora
                else None
            )
            cond_cross = (
                self.cond_cross_cfg_token.weight.unsqueeze(0).expand(1, num_frames, -1)
                if self.has_motion_adapter
                else None
            )

        # 2. Embed prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = encode_prompt(
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            prompt=text_prompts or [""],
            negative_prompt=negative_text_prompt,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            prompt_attention_mask=None,
            negative_prompt_attention_mask=None,
            max_sequence_length=128,
            device=self.device,
        )

        def _tile_batch(x):
            # Repeat along dim 0 only, so the rank is preserved. repeat(bs, 1, 1) turned a 2-D
            # (1, seq) attention mask into (bs, 1, seq), and diffusers' LTXVideoTransformer3DModel
            # only converts 2-D encoder masks into an additive attention bias.
            return x.repeat(bs, *([1] * (x.ndim - 1)))

        if prompt_embeds.shape[0] == 1 and bs > 1:
            prompt_embeds = _tile_batch(prompt_embeds)
            prompt_attention_mask = _tile_batch(prompt_attention_mask)
        if negative_prompt_embeds.shape[0] == 1 and bs > 1:
            negative_prompt_embeds = _tile_batch(negative_prompt_embeds)
            negative_prompt_attention_mask = _tile_batch(negative_prompt_attention_mask)

        # 3. Compute latents of all videos including their source frame
        init_latents = torch.randn(
            (1 if same_noise_per_sample else bs, 128, latent_num_frames, latent_height, latent_width),
            device=self.device,
            dtype=self.transformer.dtype,
            generator=generator,
        )
        if same_noise_per_sample:
            init_latents = init_latents.repeat(bs, 1, 1, 1, 1)

        conditioning_mask = None

        if src_images is not None:
            # pixels [0, 255] -> [-1, 1], with a singleton time axis inserted for the VAE
            src_frame_latents = self.vae.encode(
                src_images.unsqueeze(2).float().div(127.5).sub(1).to(self.vae.dtype)
            ).latent_dist.sample()
            src_frame_latents = src_frame_latents.to(dtype=self.vae.dtype)
            src_frame_latents = _normalize_latents(src_frame_latents, self.vae.latents_mean, self.vae.latents_std)
            src_frame_latents = src_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
            # 1 on the first latent frame (kept fixed to the source), 0 elsewhere
            conditioning_mask = torch.zeros(
                (bs, 1, latent_num_frames, latent_height, latent_width), device=self.device, dtype=self.vae.dtype
            )
            conditioning_mask[:, :, 0] = 1.0
            init_latents = src_frame_latents * conditioning_mask + init_latents * (1 - conditioning_mask)
            conditioning_mask = _pack_latents(conditioning_mask, patch_size, patch_size_t).squeeze(-1)

        init_latents = _pack_latents(init_latents, patch_size, patch_size_t)

        # 4. Sampling loop
        apply_cfg = motion_cfg_scale != 1 or text_cfg_scale != 1
        scheduler = FlowMatchEulerDiscreteScheduler()
        num_inference_steps = n_steps
        video_sequence_length = latent_num_frames * latent_height * latent_width
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        mu = calculate_shift(
            video_sequence_length,
            scheduler.config.get("base_image_seq_len", 256),
            scheduler.config.get("max_image_seq_len", 4096),
            scheduler.config.get("base_shift", 0.5),
            scheduler.config.get("max_shift", 1.16),
        )
        # timesteps = [1000, 980, 960, ...]
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            init_latents.device,
            None,
            sigmas=sigmas,
            mu=mu,
        )

        n_rows = init_latents.shape[0]
        cond_cross_ = cond_cross.repeat(n_rows, 1, 1) if cond_cross is not None else None
        cond_lora_ = cond_lora.repeat(n_rows, 1, 1) if cond_lora is not None else None

        if apply_cfg:
            # Batch layout: [motion, text], [motion, -], [-, text]
            conditioning_mask_ = (
                torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])
                if conditioning_mask is not None
                else None
            )
            prompt_embeds_ = torch.cat([prompt_embeds, negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask_ = torch.cat(
                [prompt_attention_mask, negative_prompt_attention_mask, prompt_attention_mask], dim=0
            )
            if cond_cross_ is not None:
                neg_cond_cross = self.cond_cross_cfg_token.weight.unsqueeze(0).expand(*cond_cross_.shape)
                cond_cross_ = torch.cat([cond_cross_, cond_cross_, neg_cond_cross])
            if cond_lora_ is not None:
                neg_cond_lora = self.cond_lora_cfg_token.weight.unsqueeze(0).expand(*cond_lora_.shape)
                cond_lora_ = torch.cat([cond_lora_, cond_lora_, neg_cond_lora])
        else:
            conditioning_mask_ = conditioning_mask.clone() if conditioning_mask is not None else None
            prompt_embeds_ = prompt_embeds.clone()
            prompt_attention_mask_ = prompt_attention_mask.clone()

        latents = init_latents.clone()

        self.data_provider.set(
            cond_cross=cond_cross_,
            cond_cross_fps=video_fps,
            cond_lora=cond_lora_,
            latent_num_frames=latent_num_frames,
            latent_height=latent_height,
            latent_width=latent_width,
            latent_fps=video_fps / 8,
        )
        try:
            for t in tqdm(timesteps, total=len(timesteps), disable=not progress_bar):
                latent_model_input = torch.cat([latents] * 3) if apply_cfg else latents
                latent_model_input = latent_model_input.to(prompt_embeds_.dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])
                if conditioning_mask_ is not None:
                    # tokens of the pinned source frame get timestep 0
                    timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask_)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds_,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask_,
                    num_frames=latent_num_frames,
                    height=latent_height,
                    width=latent_width,
                    rope_interpolation_scale=rope_interpolation_scale,
                    # attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.float()

                if apply_cfg:
                    noise_pred_motion_text, noise_pred_motion_uncond, noise_pred_uncond_text = noise_pred.chunk(3)
                    if motion_cfg_scale == 1:
                        noise_pred = noise_pred_motion_uncond + text_cfg_scale * (
                            noise_pred_motion_text - noise_pred_motion_uncond
                        )
                    elif text_cfg_scale == 1:
                        noise_pred = noise_pred_uncond_text + motion_cfg_scale * (
                            noise_pred_motion_text - noise_pred_uncond_text
                        )
                    else:
                        # Same convention as the single-scale branches (a scale of 1 adds no
                        # guidance), so the result is continuous in both scales.
                        noise_pred = (
                            noise_pred_motion_text
                            + (text_cfg_scale - 1) * (noise_pred_motion_text - noise_pred_motion_uncond)
                            + (motion_cfg_scale - 1) * (noise_pred_motion_text - noise_pred_uncond_text)
                        )

                # compute the previous noisy sample x_t -> x_t-1
                noise_pred = _unpack_latents(
                    noise_pred, latent_num_frames, latent_height, latent_width, patch_size, patch_size_t
                )
                latents = _unpack_latents(
                    latents, latent_num_frames, latent_height, latent_width, patch_size, patch_size_t
                )

                if conditioning_mask_ is not None:
                    # only denoise frames after the pinned source frame
                    noise_pred = noise_pred[:, :, 1:]
                    noise_latents = latents[:, :, 1:]
                    pred_latents = scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
                    latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
                else:
                    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                latents = _pack_latents(latents, patch_size, patch_size_t)
        finally:
            # Clear the conditioning even if sampling fails part-way.
            self.data_provider.reset()

        # decode generated latents to pixels
        latents = _unpack_latents(latents, latent_num_frames, latent_height, latent_width, patch_size, patch_size_t)
        latents = _denormalize_latents(
            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
        )
        latents = latents.to(prompt_embeds.dtype)

        gen_videos = self.vae.decode(latents, None, return_dict=False)[0]
        gen_videos = gen_videos.float().add(1).mul(127.5).round().clamp(0, 255).byte()
        gen_videos = gen_videos.permute(0, 2, 1, 3, 4)  # b c f h w -> b f c h w
        return gen_videos

    def validation_step(self, batch, batch_idx):
        """Render a motion-transfer comparison video for one validation batch and save it.

        The first frame of every video in the batch is used as a source image. Each video in turn
        drives the motion for all of those source images, once per motion CFG scale. Each row of
        the final grid is the (padded) driving video followed by its generations. The grid is
        downscaled so its larger spatial side is at most 1920 px, cropped to even dimensions, and
        written on global rank 0 to ``<logger.save_dir>/videos/<global_step>_<batch_idx>.mp4``.
        """
        videos = torch.stack(batch["frames"]["data"])
        bs = videos.shape[0]
        # bs = batch["x"].shape[0]
        # videos = batch["x"].permute(0, 2, 1, 3, 4)  # ---> b, f, c, h, w
        src_images = videos[:, 0]
        motion_cfg_scales = [1, 4]

        txt_key = next((k for k in ["txt", "caption"] if k in batch), None)

        def _to_text(txt):
            # str captions are used as-is (previously .decode() raised on them); bytes and
            # sequences of byte values are decoded.
            if isinstance(txt, str):
                return txt
            if isinstance(txt, (tuple, list)):
                txt = bytes(txt)
            return txt.decode()

        if self.train_with_text_prompts and txt_key is not None:
            prompts = [_to_text(txt) for txt in batch[txt_key]]
        else:
            prompts = None

        rows = []
        for driving_idx in tqdm(range(bs), total=bs, desc="Transferring motion"):
            entries = []
            for motion_cfg_scale in motion_cfg_scales:
                cfg_entries = (
                    self.sample(
                        driving_video=videos[driving_idx],
                        src_images=src_images,
                        text_prompts=prompts,
                        n_steps=50,
                        motion_cfg_scale=motion_cfg_scale,
                        generator=torch.Generator(src_images.device).manual_seed(driving_idx),
                        progress_bar=False,
                    )
                    .detach()
                    .cpu()
                )  # b f c h w
                try:
                    cfg_entries = [
                        add_text_to_video(
                            entry, text=f"cfg={motion_cfg_scale:.1f}", xy=(20, 20), size=50, color=(255, 0, 0)
                        )
                        for entry in cfg_entries
                    ]
                except Exception:
                    # Labelling is best-effort: fall back to the unlabelled videos. Unlike the old
                    # bare `except:`, this no longer swallows KeyboardInterrupt / SystemExit.
                    cfg_entries = list(cfg_entries)
                entries.append(cfg_entries)

            # One grid per source image, combining its generations across all CFG scales.
            cfg_grids = [
                make_video_grid(list(per_source), nrow=len(entries), padding=5) for per_source in zip(*entries)
            ]
            row = make_video_grid(cfg_grids, nrow=len(cfg_grids), padding=15)
            driving_vid = videos[driving_idx].detach().cpu()
            # Pad the driving video to the row height (centred) and add a 10 px gap on its right.
            height_diff = row.shape[-2] - driving_vid.shape[-2]
            driving_vid = pad(driving_vid, (0, 10, math.floor(height_diff / 2), math.ceil(height_diff / 2)))
            driving_vid = driving_vid[: row.shape[0]]
            row = torch.cat([driving_vid, row], dim=-1)
            rows.append(row)

        video_out = make_video_grid(rows, nrow=1, padding=10)

        # resize for easier handling afterwards
        ratio = 1920 / max(video_out.shape[2], video_out.shape[3])
        if ratio < 1:
            video_out = tv_resize(
                video_out, size=(int(video_out.shape[2] * ratio), int(video_out.shape[3] * ratio)), antialias=True
            )
        # crop to even width/height
        if video_out.shape[-1] % 2 == 1:
            video_out = video_out[:, :, :, :-1]
        if video_out.shape[-2] % 2 == 1:
            video_out = video_out[:, :, :-1]

        video_out = video_out.detach().cpu()

        # store video
        is_rank_zero = self.trainer.global_rank == 0
        if is_rank_zero:
            save_dir = Path(self.logger.save_dir)
            save_dir = save_dir.joinpath("videos")
            save_dir.mkdir(parents=True, exist_ok=True)
            write_video(
                filename=str(save_dir.joinpath(f"{self.trainer.global_step}_{batch_idx}.mp4")),
                video_array=video_out.permute(0, 2, 3, 1),
                fps=24,
                bitrate=5000,
            )
            print(f"Saved video to {save_dir}")

    def on_validation_start(self):
        """Lightning hook: free memory before validation begins.

        Runs the Python garbage collector first, then releases PyTorch's cached but unused
        CUDA memory.
        """
        for free_memory in (gc.collect, torch.cuda.empty_cache):
            free_memory()

    def on_validation_end(self):
        """Lightning hook: free memory left over from validation sampling.

        Runs the Python garbage collector first, then releases PyTorch's cached but unused
        CUDA memory.
        """
        for free_memory in (gc.collect, torch.cuda.empty_cache):
            free_memory()

    def on_save_checkpoint(self, checkpoint):
        """Drop every state-dict entry not listed in ``self.trainable_targets``, in place.

        Only the trained weights end up in the saved checkpoint. Other checkpoint fields are
        left untouched.
        """
        state_dict = checkpoint["state_dict"]
        # Collect the keys first; a dict must not change size while it is being iterated.
        frozen_keys = [key for key in state_dict if key not in self.trainable_targets]
        for key in frozen_keys:
            state_dict.pop(key)
    
    def configure_optimizers(self):
        """Build the AdamW optimizer and the per-step LR schedule for the trainable parameters.

        Only parameters whose names are in ``self.trainable_targets`` are optimized. The schedule
        is a sequence of up to three phases, driven by ``self.lr_schedule_params``:

        1. Linear warmup from 1e-4x to 1x over ``warmup_steps`` (if > 0).
        2. Constant LR until ``decay_start`` (if > 0), or otherwise until
           ``self.max_training_steps``.
        3. Decay from ``decay_start`` (never earlier than the end of warmup) to
           ``self.max_training_steps``, using ``decay_scheme``: ``"cosine"`` or ``"linear"``
           down to ``lr_end_factor``x, or ``"steplr"`` with ``step_size`` / ``gamma``.

        The LR stays constant when ``lr_schedule_params`` is ``None`` or no phase applies.

        Returns:
            ``([optimizer], [scheduler_config])`` for Lightning, where the scheduler is a
            ``SequentialLR`` stepped every optimizer step.

        Raises:
            ValueError: If ``decay_scheme`` is not one of the supported schemes.
        """
        trainable_params = [p for k, p in self.named_parameters() if k in self.trainable_targets]

        opt = torch.optim.AdamW(
            trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.99), amsgrad=False
        )

        # (scheduler, number of steps the phase lasts), in execution order. Milestones are derived
        # from these durations, so there is always exactly one fewer milestone than schedulers.
        # The old bookkeeping subtracted decay_start from itself, which always dropped the
        # constant phase when decay was enabled and made SequentialLR raise.
        phases = []
        params = self.lr_schedule_params
        if params is not None:
            warmup_steps = max(params.get("warmup_steps") or 0, 0)
            if warmup_steps > 0:
                phases.append(
                    (
                        lr_scheduler.LinearLR(opt, start_factor=0.0001, end_factor=1.0, total_iters=warmup_steps),
                        warmup_steps,
                    )
                )

            decay_start = params.get("decay_start") or 0
            decay = decay_start > 0

            # The constant phase runs from the end of warmup until decay starts (or training ends).
            constant_end = max(decay_start, warmup_steps) if decay else self.max_training_steps
            constant_steps = constant_end - warmup_steps
            if constant_steps > 0:
                phases.append(
                    (lr_scheduler.ConstantLR(opt, factor=1.0, total_iters=constant_steps), constant_steps)
                )

            if decay:
                assert "decay_scheme" in params
                scheme = params["decay_scheme"]
                decay_steps = self.max_training_steps - constant_end
                if scheme == "cosine":
                    decay_scheduler = lr_scheduler.CosineAnnealingLR(
                        opt, T_max=decay_steps, eta_min=self.learning_rate * params["lr_end_factor"]
                    )
                elif scheme == "linear":
                    decay_scheduler = lr_scheduler.LinearLR(
                        opt,
                        start_factor=1.0,
                        end_factor=params["lr_end_factor"],
                        total_iters=decay_steps,
                    )
                elif scheme == "steplr":
                    decay_scheduler = lr_scheduler.StepLR(opt, step_size=params["step_size"], gamma=params["gamma"])
                else:
                    raise ValueError(f"Unsupported decay_scheme: {scheme!r}")
                phases.append((decay_scheduler, decay_steps))

        if not phases:
            # SequentialLR needs at least one scheduler. A factor of 1.0 keeps the LR unchanged.
            phases.append((lr_scheduler.ConstantLR(opt, factor=1.0, total_iters=1), 1))

        # Each milestone is the step at which a phase ends and the next one begins.
        milestones = []
        elapsed = 0
        for _, steps in phases[:-1]:
            elapsed += steps
            milestones.append(elapsed)

        return [opt], [
            {
                "scheduler": lr_scheduler.SequentialLR(opt, [sched for sched, _ in phases], milestones),
                "interval": "step",
                "name": "learning_rate",
            }
        ]


from ltx_train import instantiate_from_config


def load_ltx(ltx_dir, ltx_ckpt, dismo_config_dir=None, freeze=True):
    """Instantiate the model from a training-run directory and load a checkpoint into it.

    Args:
        ltx_dir: Run directory containing ``configs/config.yaml`` and
            ``checkpoints/<ltx_ckpt>.ckpt``.
        ltx_ckpt: Checkpoint file name without the ``.ckpt`` suffix.
        dismo_config_dir: If given, relocate the config's ``dismo_config_dir`` to this directory.
            ``dismo_ckpt_path`` is rewritten to keep its path relative to that config dir.
        freeze: Keep the model permanently in eval mode (``train()`` becomes a no-op that
            still returns the module) and disable gradients for all parameters.

    Returns:
        The model in eval mode.

    Raises:
        AssertionError: If the checkpoint or config file is missing, or the checkpoint contains
            keys the model does not expect.
    """
    ltx_dir = Path(ltx_dir)
    ckpt_path = ltx_dir / "checkpoints" / f"{ltx_ckpt}.ckpt"
    cfg_path = ltx_dir / "configs" / "config.yaml"
    assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}"
    assert cfg_path.exists(), f"Config not found: {cfg_path}"
    cfg = OmegaConf.load(cfg_path)
    if dismo_config_dir is not None:
        dismo_rel_ckpt_path = Path(cfg.model.params.dismo_ckpt_path).relative_to(cfg.model.params.dismo_config_dir)
        cfg.model.params.dismo_config_dir = str(Path(dismo_config_dir))
        cfg.model.params.dismo_ckpt_path = str(Path(dismo_config_dir).joinpath(dismo_rel_ckpt_path))
    ltx_video: LTXVideoMotionAdapter = instantiate_from_config(cfg.model)
    # NOTE(security): weights_only=False unpickles arbitrary objects, so only load trusted checkpoints.
    ckpt = torch.load(str(ckpt_path), weights_only=False, map_location="cpu")
    ltx_video.on_load_checkpoint(ckpt)
    # Missing keys are tolerated. Presumably this is because checkpoints only hold the trainable
    # weights (on_save_checkpoint strips the rest); verify for checkpoints from other sources.
    _missing_keys, unexpected_keys = ltx_video.load_state_dict(ckpt["state_dict"], strict=False)
    assert len(unexpected_keys) == 0, f"Unexpected keys in checkpoint: {unexpected_keys}"
    ltx_video = ltx_video.eval()
    if freeze:
        # Pin eval mode with a no-op that matches nn.Module.train(mode=True) and returns the
        # module. The old `lambda x: x` raised TypeError on `model.train()`, and made
        # `model.eval()` return False instead of the module.
        ltx_video.train = lambda mode=True: ltx_video
        ltx_video.requires_grad_(False)
    return ltx_video
