import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

from vace.models.wan.configs import WAN_CONFIGS
from vace.models.wan.modules.model import VaceWanModel
import torch
import csv
import deepspeed
from torch.utils.data import DistributedSampler, DataLoader
import argparse
import json
import os
from wan.text2video import (
    FlowUniPCMultistepScheduler,
)
from datetime import timedelta
from vae import WanVAE

from PIL import Image
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from tqdm import tqdm
import math
import torch.nn as nn
import torch.distributed as dist

from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from datetime import datetime
import random

from typing import Optional

from kaolin.render.camera import generate_rotate_translate_matrices
from einops import rearrange
from dataset_utils import (
    shift_dilate_mask,
    save_video,
    load_video,
    bbox_mask,
    center_and_resize,
    shift_mask,
    scale_video_mask_CFHW,
)
from torchvision.transforms import Compose
from torchvision import transforms
from PIL import Image
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple, Optional, Union
from pathlib import Path


class ImageDataset(Dataset):
    """Paired single-frame samples read from a JSON-lines index.

    Each line of ``jsonl_path`` is a JSON object with ``image_path``,
    ``bg_image_pth`` and ``left`` / ``right`` view records holding ``yaw``,
    ``pitch``, ``dist``, ``rx``, ``ry`` and ``iou``. Only pairs whose left AND
    right IoU exceed ``iou_threshold`` are kept.

    The image at ``image_path`` is a 2x2 grid: top-left source view, top-right
    target view, bottom-left source mask, bottom-right target mask.
    """

    # Order of the per-view fields packed into the 5-element view vectors.
    _VIEW_KEYS = ("yaw", "pitch", "dist", "rx", "ry")

    def __init__(
        self,
        jsonl_path=None,
        resolution: Tuple[int, int] = (640, 960),
        iou_threshold=0.9,
    ):
        self.resolution = resolution
        # Images are mapped from [0, 1] to [-1, 1]; masks are only resized.
        self.transform = transforms.Compose(
            [
                transforms.Resize((resolution[0], resolution[1])),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.mask_transform = transforms.Compose(
            [
                transforms.Resize((resolution[0], resolution[1])),
            ]
        )
        self.pairs = []
        self.iou_threshold = iou_threshold

        with open(jsonl_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank lines instead of crashing json.loads
                    continue
                data = json.loads(line)
                if (
                    data["left"]["iou"] > self.iou_threshold
                    and data["right"]["iou"] > self.iou_threshold
                ):
                    self.pairs.append(data)

        print(f"image: {len(self.pairs)}")

    def __len__(self):
        return len(self.pairs)

    @classmethod
    def _view_vector(cls, record) -> torch.Tensor:
        """Pack one ``left``/``right`` record into a float32 tensor of ``_VIEW_KEYS``."""
        return torch.tensor([record[k] for k in cls._VIEW_KEYS], dtype=torch.float32)

    def __getitem__(self, idx):
        entry = self.pairs[idx]

        result = self._load_aligned_quads(entry["image_path"], entry["bg_image_pth"])
        result.update(
            {
                "src_view": self._view_vector(entry["left"]),
                "tgt_view": self._view_vector(entry["right"]),
            }
        )
        return result

    def _load_aligned_quads(self, image_path, bg_path) -> Dict[str, torch.Tensor]:
        """Split the 2x2 grid image into views/masks and build model inputs.

        Returns single-frame tensors laid out as (C, F=1, H, W), plus
        centred reference crops and a ``loss_mask`` that is 0 when
        ``center_and_resize`` reports failure for either view.
        """
        # convert("RGB") guards against RGBA / grayscale / palette files, which
        # would otherwise break the 3-channel Normalize. Context managers close
        # the file handles.
        with Image.open(image_path) as grid:
            image = TF.to_tensor(grid.convert("RGB"))  # C, H, W
        with Image.open(bg_path) as bg:
            bg_img = TF.to_tensor(bg.convert("RGB"))  # C, H, W

        h, w = image.shape[1:]
        mid_h, mid_w = h // 2, w // 2

        src_img = image[:, :mid_h, :mid_w]
        tgt_img = image[:, :mid_h, mid_w:]
        src_mask = image[:, mid_h:, :mid_w]
        tgt_mask = image[:, mid_h:, mid_w:]

        bg_img = self.transform(bg_img).unsqueeze(1)  # C, F, H, W
        src_img = self.transform(src_img).unsqueeze(1)  # C, F, H, W
        tgt_img = self.transform(tgt_img).unsqueeze(1)  # C, F, H, W

        src_mask = self.mask_transform(src_mask).unsqueeze(1)  # C, F, H, W
        tgt_mask = self.mask_transform(tgt_mask).unsqueeze(1)  # C, F, H, W

        # Collapse colour channels and binarise -> (1, F, H, W).
        src_mask = (src_mask.mean(0).unsqueeze(0) > 0.1).float()
        tgt_mask = (tgt_mask.mean(0).unsqueeze(0) > 0.1).float()

        # Reference images: object pixels kept, background painted with 1.0 (white).
        src_ref_image = src_img * src_mask + 1.0 * (1 - src_mask)
        tgt_ref_image = tgt_img * tgt_mask + 1.0 * (1 - tgt_mask)

        src_mask_square, src_bbox = bbox_mask(src_mask)
        centered_src_ref_img, success_1 = center_and_resize(
            src_ref_image,
            src_bbox[0, 0],
            resolution=self.resolution,
        )

        tgt_mask_square, tgt_bbox = bbox_mask(tgt_mask)
        centered_tgt_ref_img, success_2 = center_and_resize(
            tgt_ref_image,
            tgt_bbox[0, 0],
            resolution=self.resolution,
        )

        return {
            "src_video": src_img,
            "tgt_video": tgt_img,
            "src_mask": src_mask,
            "tgt_mask": tgt_mask,
            "src_mask_square": src_mask_square,
            "tgt_mask_square": tgt_mask_square,
            "src_ref_img": centered_src_ref_img,
            "tgt_ref_img": centered_tgt_ref_img,
            "bg_video": bg_img,
            "loss_mask": (
                torch.tensor([1.0]) if success_1 & success_2 else torch.tensor([0.0])
            ),
        }


class VideoDataset(Dataset):
    """Video samples pairing a composited source clip with a target clip.

    ``camera_csv`` rows: column 1 is the target mask-video path (used as the
    join key); columns 4-9 are ``yaw, pitch, distance, rx, ry, iou`` of the
    target view.

    ``rendered_csv`` rows: columns 0, 1, 3, 4 and 5 are the paths of the
    target video, target mask video, background (``rm``) video, rendered
    image and rendered mask. Columns 6-10 are the source view
    ``yaw, pitch, distance, rx, ry``.
    """

    def __init__(
        self,
        rendered_csv=None,
        camera_csv=None,
        transform: Optional[Compose] = None,
        video_length: int = 61,
        frame_stride: int = 1,
        resolution: Tuple[int, int] = (640, 960),
        random_offset: bool = False,
        iou_threshold=0.9,
    ):
        self.video_length = video_length
        self.frame_stride = frame_stride
        self.resolution = resolution
        self.random_offset = random_offset
        self.iou_threshold = iou_threshold

        # Honour a caller-supplied transform; otherwise resize and map [0, 1] -> [-1, 1].
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.Resize((resolution[0], resolution[1])),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )
        self.transform = transform
        self.mask_transform = transforms.Compose(
            [
                transforms.Resize((resolution[0], resolution[1])),
            ]
        )
        self.pairs = []
        self.load_pairs(camera_csv, rendered_csv)

    def load_pairs(self, camera_csv, render_csv):
        """Join the two CSV indexes on the target mask path and append to ``self.pairs``."""
        data = {}
        with open(camera_csv, "r") as f:
            for line in csv.reader(f):
                mask_path = line[1]
                yaw, pitch, distance, rx, ry, iou = (float(v) for v in line[4:10])
                data[mask_path] = {
                    "tgt_view": [yaw, pitch, distance, rx, ry],
                    "iou": iou,
                }

        with open(render_csv, "r") as f:
            for line in csv.reader(f):
                video_path, mask_path, rm_path, render_img_path, render_mask_path = (
                    line[0],
                    line[1],
                    line[3],
                    line[4],
                    line[5],
                )
                yaw, pitch, distance, rx, ry = (float(v) for v in line[6:11])
                # NOTE(review): this filter is hard-coded to 0.8 and ignores
                # self.iou_threshold (default 0.9). Kept as-is to avoid changing
                # the selected data; confirm which threshold is intended.
                if mask_path in data and data[mask_path]["iou"] > 0.8:
                    self.pairs.append(
                        {
                            "tgt_video_path": video_path,
                            "tgt_mask_path": mask_path,
                            "rm_video_path": rm_path,
                            "render_img_path": render_img_path,
                            "render_mask_path": render_mask_path,
                            "src_view": [yaw, pitch, distance, rx, ry],
                            "tgt_view": data[mask_path]["tgt_view"],
                            "src_iou": data[mask_path]["iou"],
                        }
                    )

        print(f"video: {len(self.pairs)}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        entry = self.pairs[idx]
        src_view = torch.tensor(entry["src_view"])
        tgt_view = torch.tensor(entry["tgt_view"])

        # Load background, target, target-mask videos plus the rendered image/mask.
        result = self._load_aligned_quads(
            entry["tgt_video_path"],
            entry["tgt_mask_path"],
            entry["rm_video_path"],
            entry["render_img_path"],
            entry["render_mask_path"],
            src_view,
            tgt_view,
        )

        result.update(
            {
                "src_view": src_view,
                "tgt_view": tgt_view,
                "target_view_video": entry["tgt_video_path"],
            }
        )
        return result

    def _load_aligned_quads(
        self,
        tgt_video_path,
        tgt_mask_path,
        rm_video_path,
        render_img_path,
        render_mask_path,
        src_view,
        tgt_view,
    ) -> Dict[str, torch.Tensor]:
        """Load the clips and build the source composite and reference crops.

        The source video is the background video with the rendered image
        pasted inside the rendered mask (held constant over time). Tensors use
        the (C, F, H, W) layout. ``src_view`` / ``tgt_view`` are currently unused.
        """
        start_idx = 0
        frame_indices = range(
            start_idx,
            start_idx + self.video_length * self.frame_stride,
            self.frame_stride,
        )
        rm_frames = load_video(rm_video_path, frame_indices)
        tgt_frames = load_video(tgt_video_path, frame_indices)
        tgt_masks = load_video(tgt_mask_path, frame_indices)  # F, C, H, W

        rm_frames = torch.stack([self.transform(f) for f in rm_frames], dim=1)
        tgt_frames = torch.stack([self.transform(f) for f in tgt_frames], dim=1)

        tgt_masks = torch.stack([self.mask_transform(f) for f in tgt_masks], dim=1)
        tgt_masks = (tgt_masks > 0.5).float()[:1, :, :, :]  # keep one channel

        # Context managers close the files; convert("RGB") guards against
        # RGBA renders, which would break the 3-channel Normalize.
        with Image.open(render_img_path) as im:
            rendered_img = TF.to_tensor(im.convert("RGB"))  # C, H, W
        with Image.open(render_mask_path) as im:
            rendered_mask = TF.to_tensor(im)  # C, H, W

        rendered_mask = (rendered_mask > 0.5).float()

        rendered_img = self.transform(rendered_img).unsqueeze(1)  # C, F, H, W
        rendered_mask = self.mask_transform(rendered_mask).unsqueeze(1)  # C, F, H, W

        # Paste the rendered object onto every background frame.
        src_frames = rm_frames * (1 - rendered_mask) + rendered_img * rendered_mask

        # Reference images: object pixels kept, background set to 1.0 (white).
        rendered_img = rendered_img * rendered_mask + (1 - rendered_mask)
        tgt_img = tgt_frames[:, :1, :, :] * tgt_masks[:, :1, :, :] + (
            1 - tgt_masks[:, :1, :, :]
        )

        rendered_mask_square, rendered_bbox = bbox_mask(rendered_mask)
        rendered_mask_square = rendered_mask_square.repeat(1, self.video_length, 1, 1)

        centered_img_rendered, success_1 = center_and_resize(
            rendered_img,
            rendered_bbox[0, 0],
            resolution=self.resolution,
        )

        tgt_mask_square, tgt_bbox = bbox_mask(tgt_masks)
        centered_img_tgt, success_2 = center_and_resize(
            tgt_img,
            tgt_bbox[0, 0],
            resolution=self.resolution,
        )

        return {
            "src_video": src_frames,
            "bg_video": rm_frames,
            "tgt_video": tgt_frames,
            "src_mask": rendered_mask.repeat(1, self.video_length, 1, 1),
            "src_mask_square": rendered_mask_square,
            "tgt_mask": tgt_masks,
            "tgt_mask_square": tgt_mask_square,
            "tgt_ref_img": centered_img_tgt,
            "src_ref_img": centered_img_rendered,
            "loss_mask": (
                torch.tensor([1.0]) if success_1 & success_2 else torch.tensor([0.0])
            ),
        }


class CameraPoseEncoder(nn.Module):
    """Encode a (source, target) camera pair into ``num_tokens`` context tokens.

    Each camera is a tensor whose last dim is ``[yaw, pitch, dist, rx, ry]``.
    Yaw and pitch are angles fed to ``cos`` / ``sin``, i.e. radians. The camera
    sits at distance ``dist`` on a sphere around the origin, looking at the
    origin with +y up.

    The relative pose is summarised as an 8-vector:
    ``[axis-angle (3), relative translation (3), rx2 - rx1, ry2 - ry1]``.
    Component ``i`` is Fourier-encoded with ``num_bands`` frequencies and
    projected by its own MLP to a ``token_dim`` token. Because of this,
    ``num_tokens`` must not exceed 8.
    """

    def __init__(
        self,
        num_bands: int = 64,
        token_dim: int = 4096,
        normalize: bool = True,
        num_tokens: int = 8,
    ):
        super().__init__()
        self.num_bands = num_bands
        self.token_dim = token_dim
        self.normalize = normalize
        self.num_tokens = num_tokens

        freqs = 2.0 ** torch.arange(num_bands).float() * torch.pi
        self.register_buffer("freqs", freqs)  # (num_bands,)
        # Unconditional context for classifier-free guidance. Registered as a
        # non-persistent buffer so it follows .to()/.cuda() with the module but
        # stays out of state_dict (checkpoints are unchanged).
        self.register_buffer(
            "null_context",
            torch.zeros((1, self.num_tokens, token_dim), dtype=torch.bfloat16),
            persistent=False,
        )

        in_dim = 2 * num_bands
        self.param_projs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(in_dim, 4 * in_dim),
                    nn.SiLU(),
                    nn.Linear(4 * in_dim, 16 * in_dim),
                    nn.SiLU(),
                    nn.Linear(16 * in_dim, token_dim),
                )
                for _ in range(self.num_tokens)
            ]
        )

    @staticmethod
    def _eye(yaw, pitch, dist):
        """Camera centre on the sphere of radius ``dist``: (B,) x3 -> (B, 3)."""
        x = dist * torch.cos(pitch) * torch.cos(yaw)
        y = dist * torch.sin(pitch)
        z = dist * torch.cos(pitch) * torch.sin(yaw)
        return torch.stack([x, y, z], dim=1)

    def _relative_pose_features(self, R1, T1, R2, T2, d1, d2, d_rx, d_ry):
        """Build the (B, 8) relative-pose feature from two extrinsics.

        R*: (B, 3, 3) rotations; T*: (B, 3) translations; d1, d2, d_rx, d_ry: (B,).
        """
        R_rel = R2 @ R1.transpose(1, 2)  # (B, 3, 3)
        T_rel = T2 - torch.einsum("bij,bj->bi", R_rel, T1)  # (B, 3)

        # Axis-angle of R_rel: theta = arccos((tr(R) - 1) / 2),
        # axis = [R32-R23, R13-R31, R21-R12] / (2 sin theta).
        cos_theta = (R_rel.diagonal(dim1=1, dim2=2).sum(-1) - 1) * 0.5
        theta = torch.acos(cos_theta.clamp(-1 + 1e-6, 1 - 1e-6))  # (B,)
        axis = torch.stack(
            [
                R_rel[:, 2, 1] - R_rel[:, 1, 2],
                R_rel[:, 0, 2] - R_rel[:, 2, 0],
                R_rel[:, 1, 0] - R_rel[:, 0, 1],
            ],
            dim=-1,
        )  # (B, 3)
        axis = axis / (2 * torch.sin(theta).unsqueeze(-1).clamp_min(1e-6))
        aa = axis * theta.unsqueeze(-1)  # (B, 3)

        if self.normalize:
            aa = aa / torch.pi  # roughly [-1, 1]
            T_rel = 2 * T_rel / (d1 + d2).unsqueeze(-1).clamp_min(1e-6)

        # d_rx / d_ry are 1-D (B,): give them a trailing dim so cat along the
        # last axis yields (B, 8). Concatenating them unreshaped raises.
        return torch.cat(
            [aa, T_rel, d_rx.unsqueeze(-1), d_ry.unsqueeze(-1)], dim=-1
        )  # (B, 8)

    def _tokenize(self, feat):
        """Fourier-encode (B, 8) features and project each to a token: -> (B, num_tokens, token_dim)."""
        freqs = self.freqs.view(1, 1, -1)  # (1, 1, Bands)
        x = feat.unsqueeze(-1) * freqs  # (B, 8, Bands)
        sincos = torch.cat([x.sin(), x.cos()], dim=-1)  # (B, 8, 2*Bands)
        return torch.stack(
            [self.param_projs[i](sincos[:, i]) for i in range(self.num_tokens)], dim=1
        )

    def forward(
        self,
        source_camera,
        target_camera,
    ) -> torch.Tensor:
        """Return (B, num_tokens, token_dim) tokens for the source->target camera change."""
        yaw1, pitch1, d1, rx1, ry1 = source_camera.unbind(-1)
        yaw2, pitch2, d2, rx2, ry2 = target_camera.unbind(-1)

        batch = yaw1.shape[0]
        device = yaw1.device
        target = torch.zeros(1, 3, device=device).expand(batch, -1)
        up = torch.tensor([0.0, 1.0, 0.0], device=device).unsqueeze(0).expand(batch, -1)

        R1, T1 = generate_rotate_translate_matrices(
            self._eye(yaw1, pitch1, d1), target, up
        )
        R2, T2 = generate_rotate_translate_matrices(
            self._eye(yaw2, pitch2, d2), target, up
        )

        feat = self._relative_pose_features(
            R1, T1, R2, T2, d1, d2, rx2 - rx1, ry2 - ry1
        )
        return self._tokenize(feat)


class ModelWrapper(nn.Module):
    """Bundle the VACE-Wan backbone with a camera-pose encoder.

    The camera encoder's tokens are passed to the backbone as ``context``
    (in place of text embeddings). ``mask_projection`` maps 128-channel mask
    features to 64 channels; the training loop applies it before packing latents.
    """

    def __init__(
        self,
        config,
        checkpoint_dir,
    ):
        super().__init__()
        self.config = config

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        self.camera_encoder = CameraPoseEncoder(
            num_bands=64,
            token_dim=4096,
            num_tokens=8,
        )

        self.vae_stride = config.vae_stride  # 4,8,8
        self.patch_size = config.patch_size  # 1,2,2

        self.model = VaceWanModel.from_pretrained(checkpoint_dir)

        self.mask_projection = nn.Linear(128, 64)

        self.scheduler = UniPCMultistepScheduler(
            num_train_timesteps=1000,
            use_flow_sigmas=True,
            flow_shift=1.0,
        )

    def _seq_len(self, x0):
        """Number of spatio-temporal tokens of latent ``x0`` (B, C, F, H, W) after patchify."""
        return math.ceil(
            x0.shape[2]
            * x0.shape[3]
            * x0.shape[4]
            / self.patch_size[1]
            / self.patch_size[2]
        )

    def decode_latent(self, zs, ref_images=None, vae=None):
        """Decode latents, first dropping the leading reference frames.

        Args:
            zs: sequence of latents indexed as (C, F, H, W).
            ref_images: optional per-latent sequence of reference images; for
                each non-None entry, ``len(refs)`` leading frames are removed.
            vae: decoder to use. Defaults to ``self.vae``, which this class
                never sets, so callers should pass one explicitly.
        """
        if vae is None:
            vae = self.vae
        if ref_images is None:
            ref_images = [None] * len(zs)
        else:
            assert len(zs) == len(ref_images)

        trimmed_zs = [
            z if refs is None else z[:, len(refs) :, :, :]
            for z, refs in zip(zs, ref_images)
        ]
        return vae.decode(trimmed_zs)

    def forward(
        self,
        z,
        x0,
        source_camera,
        target_camera,
        context_scale=1.0,
        cfg_drop_prob=0.1,
    ):
        """One flow-matching training step.

        Samples integer timesteps t in [0, 1000) and interpolates
        ``x_t = (t/1000) * noise + (1 - t/1000) * x0``. The regression target
        is the velocity ``noise - x0``.

        Returns:
            (v_gt, v_pred, x_t, t/1000 reshaped to (B, 1, 1, 1, 1)).
        """
        context = self.camera_encoder(source_camera, target_camera)

        # Classifier-free guidance training: swap the camera context of a random
        # cfg_drop_prob fraction of samples for the null context. torch.where
        # avoids an in-place masked write with a mismatched dtype/device.
        drop_mask = torch.rand(x0.shape[0], device=x0.device) < cfg_drop_prob
        null_context = self.camera_encoder.null_context.to(
            device=context.device, dtype=context.dtype
        )
        context = torch.where(
            drop_mask.to(context.device).view(-1, 1, 1), null_context, context
        )

        noise = torch.randn_like(x0)
        seq_len = self._seq_len(x0)

        with torch.autocast(device_type="cuda", dtype=self.param_dtype):
            timesteps = torch.randint(low=0, high=1000, size=(x0.shape[0],)).to(
                x0.device
            )
            norm_t = (timesteps / 1000.0).view(-1, 1, 1, 1, 1)
            latents = norm_t * noise + (1 - norm_t) * x0

            v_gt = noise - x0

            v = self.model(
                latents,
                t=timesteps,
                vace_context=z,
                vace_context_scale=context_scale,
                context=context,
                seq_len=seq_len,
            )
            v = torch.stack(v, dim=0)

        return v_gt, v, latents, norm_t

    def validate(
        self,
        z,
        x0,
        source_camera,
        target_camera,
        input_frames,
        target_frames,
        input_ref_images,
        input_masks,
        context_scale=1.0,
        base_guide_scale=0.0,
        step=1,
        validate_dir=None,
        vae=None,
        swap=False,
        src_mask=None,
        tgt_mask=None,
    ):
        """Sample with CFG (40 UniPC steps), decode, and save a comparison video.

        Each local rank uses ``guide_scale = base_guide_scale + LOCAL_RANK``,
        so multi-GPU validation sweeps several guidance scales at once. The
        saved video tiles ref image | src mask | tgt mask | input | target |
        generated along the width. ``input_masks`` is currently unused.
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        guide_scale = base_guide_scale + local_rank
        batch_size = x0.shape[0]
        seq_len = self._seq_len(x0)

        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            context = self.camera_encoder(source_camera, target_camera)

            latents = torch.randn(
                x0.shape,
                dtype=torch.bfloat16,
                device=x0.device,
            )

            arg_c = {"context": context, "seq_len": seq_len}
            arg_null = {
                "context": self.camera_encoder.null_context.repeat(
                    batch_size, 1, 1
                ).to(x0.device),
                "seq_len": seq_len,
            }

            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=1000,
                shift=1,
                use_dynamic_shifting=False,
            )
            sample_scheduler.set_timesteps(40, device=x0.device, shift=16.0)

            for t in tqdm(sample_scheduler.timesteps):
                timestep = t.reshape(1).to(x0.device).repeat(batch_size)

                noise_pred_cond = torch.stack(
                    self.model(
                        latents,
                        t=timestep,
                        vace_context=z,
                        vace_context_scale=context_scale,
                        **arg_c,
                    ),
                    dim=0,
                )
                noise_pred_uncond = torch.stack(
                    self.model(
                        latents,
                        t=timestep,
                        vace_context=z,
                        vace_context_scale=context_scale,
                        **arg_null,
                    ),
                    dim=0,
                )

                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                latents = sample_scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    return_dict=False,
                )[0]

        # Drop the first latent frame (presumably the prepended reference latent,
        # cf. encode_frames) before decoding. no_grad keeps autograd from
        # recording the VAE decoder graph during validation.
        with torch.no_grad():
            edited_videos = vae.decode(latents[:, :, 1:])

        rank = int(os.environ["RANK"])
        suffix = "_swap" if swap else ""
        save_videos_path = (
            f"{validate_dir}/step_{step}_rank_{rank}_guide_{guide_scale}{suffix}.mp4"
        )

        # Stack batch entries vertically: (B, C, F, H, W) -> (C, F, B*H, W).
        original_videos = rearrange(input_frames, "b c f h w -> c f (b h) w")
        target_videos = rearrange(target_frames, "b c f h w -> c f (b h) w")
        edited_videos = rearrange(edited_videos, "b c f h w -> c f (b h) w")

        input_ref_images = rearrange(
            input_ref_images, "b c f h w -> c f (b h) w"
        ).repeat(1, original_videos.shape[1], 1, 1)
        src_mask = rearrange(src_mask, "b c f h w -> c f (b h) w").repeat(3, 1, 1, 1)
        tgt_mask = rearrange(tgt_mask, "b c f h w -> c f (b h) w").repeat(3, 1, 1, 1)

        video = torch.cat(
            [
                input_ref_images,
                src_mask,
                tgt_mask,
                original_videos,
                target_videos,
                edited_videos,
            ],
            dim=3,
        )

        save_video(video, save_videos_path)


def save_video(x, path, fps=8):
    """Write a (C, F, H, W) tensor with values in [-1, 1] to a video file.

    Values are clipped to [-1, 1] and mapped to [0, 1] before being handed
    to ``export_to_video``. The parent directory is created if needed.
    Note: this definition shadows ``save_video`` imported from ``dataset_utils``.
    """
    frames = x.permute(1, 2, 3, 0).float().cpu().numpy().clip(-1, 1)  # F, H, W, C
    frames = (frames + 1) / 2
    out_dir = os.path.dirname(path)
    # A bare filename has dirname "" and os.makedirs("") raises, so skip it.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    export_to_video(frames, path, fps=fps)


def encode_frames(
    vae,
    frames: torch.Tensor,
    ref_images: Optional[torch.Tensor] = None,
    masks: Optional[torch.Tensor] = None,
):
    """Encode frames (and optional reference images) with ``vae``.

    With ``masks``, the frames are split into an inactive part
    ``frames * (1 - masks)`` and a reactive part ``frames * masks``. Both are
    encoded in a single batched call and their latents are concatenated
    along dim 1.

    With ``ref_images``, their encoding is prepended along dim 2. When
    ``masks`` is also given, the reference encoding is zero-padded on dim 1 to
    match the doubled channel count.

    Returns:
        (latents, ref_enc): ``ref_enc`` is the raw reference encoding, or
        ``None`` when ``ref_images`` is None. Previously this path raised
        UnboundLocalError.
    """
    ref_enc = None

    if masks is not None:
        inactive_imgs = frames * (1 - masks)
        reactive_imgs = frames * masks
        # One encode call for both halves, then split back along the batch dim.
        latents = vae.encode(torch.cat([inactive_imgs, reactive_imgs], dim=0))
        half = len(latents) // 2
        latents = torch.cat([latents[:half], latents[half:]], dim=1)
    else:
        latents = vae.encode(frames)

    if ref_images is not None:
        ref_enc = vae.encode(ref_images)
        if masks is None:
            ref_lat = ref_enc
        else:
            # Pad to match the doubled channel count of the masked latents.
            ref_lat = torch.cat([ref_enc, torch.zeros_like(ref_enc)], dim=1)
        latents = torch.cat([ref_lat, latents], dim=2)

    return latents, ref_enc


def pack_latent(z, m):
    """Pair each latent with its mask tensor and stack them along dim 0.

    Iterates ``z`` and ``m`` in lockstep (stopping at the shorter one) and
    returns a list of ``torch.cat((latent, mask), dim=0)`` results.
    """
    return [torch.cat(pair, dim=0) for pair in zip(z, m)]

def build_train_loader(
    world_size,
    rank,
    mix_ratio,
    image_batch_size=16,
    video_batch_size=3,
    *,
    jsonl_path=None,
    rendered_csv="",
    camera_csv="",
    resolution=(640, 960),
    video_length=61,
    num_workers=4,
):
    """Build the image and video DataLoaders used for mixed training.

    Each dataset gets its own DistributedSampler (shuffle, drop_last) so
    every rank reads its own shard.

    Args:
        world_size, rank: distributed topology passed to the samplers.
        mix_ratio: expected fraction of steps drawn from the image loader;
            used only to size ``steps_per_epoch``. Must lie in [0, 1].
        image_batch_size, video_batch_size: per-rank batch sizes.
        jsonl_path: index for ImageDataset.
        rendered_csv, camera_csv: indexes for VideoDataset.
        resolution: (H, W) passed to both datasets.
        video_length: frames per video sample.
        num_workers: DataLoader worker processes per loader.

        The path defaults are placeholders (``open(None)`` / ``open("")``
        fail), so callers must supply real index files.

    Returns:
        (loader_img, loader_vid, steps_per_epoch), where
        ``steps_per_epoch = max(1, int(mix_ratio*len(img) + (1-mix_ratio)*len(vid)))``.

    Raises:
        ValueError: if ``mix_ratio`` is outside [0, 1].
    """
    if not 0.0 <= mix_ratio <= 1.0:
        raise ValueError(f"mix_ratio must be in [0, 1], got {mix_ratio}")

    dataset_img = ImageDataset(
        jsonl_path=jsonl_path,
        resolution=resolution,
    )
    dataset_vid = VideoDataset(
        rendered_csv=rendered_csv,
        camera_csv=camera_csv,
        resolution=resolution,
        video_length=video_length,
    )

    def _loader(dataset, batch_size):
        # Separate DistributedSampler per dataset so each is sharded independently.
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=False,
            drop_last=True,
        )

    loader_img = _loader(dataset_img, image_batch_size)
    loader_vid = _loader(dataset_vid, video_batch_size)

    steps_per_epoch = max(
        1, int(mix_ratio * len(loader_img) + (1.0 - mix_ratio) * len(loader_vid))
    )
    return loader_img, loader_vid, steps_per_epoch


def _cycle(loader):
    """Yield batches from ``loader`` forever, re-iterating it each time it is exhausted.

    Unlike ``itertools.cycle`` this does not cache batches, so a DataLoader
    is re-iterated (and its sampler re-drawn) on every pass.
    """
    while True:
        yield from loader


def mixed_batch_iterator(loader_a, loader_b, prob_a, steps, epoch_seed=None):
    """Yield exactly ``steps`` batches, each drawn from ``loader_a`` or ``loader_b``.

    At every step a ``random.Random(epoch_seed)`` draw picks ``loader_a``
    with probability ``prob_a`` and ``loader_b`` otherwise. Both loaders are
    cycled endlessly, so ``steps`` may exceed their lengths. Only the batch
    itself is yielded; the source loader is not reported.
    """
    rng = random.Random(epoch_seed)
    source_a = _cycle(loader_a)
    source_b = _cycle(loader_b)
    for _ in range(steps):
        yield next(source_a if rng.random() < prob_a else source_b)


# This script's file name without its extension; used to tag the run's
# validation/checkpoint directories and the TensorBoard job name.
SCRIPT_STEM = Path(__file__).resolve().with_suffix("").name


def main():
    """Train the camera-conditioned VACE-Wan model with DeepSpeed.

    Steps:
    - Read the DeepSpeed JSON config.
    - Build the model and VAE (optionally resuming from ``--resume``).
    - Select trainable parameters by ``--stage``.
    - Run a validation sanity check.
    - Run the mixed image/video training loop. Every ``--validation_step``
      optimizer steps, save a checkpoint (rank 0) and run validation.

    Expects the ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` environment
    variables set by the distributed launcher.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)  # injected by the DeepSpeed launcher
    parser.add_argument("--ds_config", type=str)
    parser.add_argument("--validation_step", type=int, default=300)
    parser.add_argument("--epoches", type=int, default=1000)
    parser.add_argument("--frames", type=int, default=61)
    parser.add_argument("--repeat", type=int, default=4)
    parser.add_argument("--task", type=str, default="test")
    parser.add_argument("--resume", type=str, default="None")
    parser.add_argument(
        "--stage",
        type=int,
        choices=(1, 2),
        default=2,
        help="1: train vace + cross_attn + text_embedding params; 2: train vace params only",
    )

    args = parser.parse_args()
    with open(args.ds_config, "r") as f:
        ds_config = json.load(f)

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed(
        dist_backend="nccl", init_method="env://", timeout=timedelta(minutes=20)
    )

    timestamp = datetime.now().strftime("%m_%d_%H")
    run_name = f"{timestamp}_{SCRIPT_STEM}_{args.task}"
    validate_dir = f"validation/{run_name}"
    checkpoint_dir = f"checkpoint/{run_name}"
    # setdefault: configs without a "tensorboard" section no longer raise KeyError.
    ds_config.setdefault("tensorboard", {})["job_name"] = (
        f"object_manipulation/{run_name}"
    )

    os.makedirs(validate_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    cfg = WAN_CONFIGS["vace-1.3B"]

    model = ModelWrapper(
        config=cfg,
        checkpoint_dir="./Wan2.1-1.3B-8blk",
    )

    base_step = 0
    if args.resume != "None":
        print(f"loading from {args.resume}")
        state_dict = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(state_dict=state_dict, strict=True)
        # Checkpoints below are written as "<step>.pth"; recover the step offset.
        base_step = int(args.resume.split("/")[-1].split(".pth")[0])

    model.train()

    vae = WanVAE(
        vae_pth=os.path.join("./Wan2.1-T2V-1.3B", cfg.vae_checkpoint),
        dtype=torch.bfloat16,
    )
    vae.to(torch.bfloat16)
    vae.to(local_rank)

    # Parameter freezing. Stage 1 also trains cross-attention and text-embedding
    # params; stage 2 (the default) trains only the vace params.
    if args.stage == 1:
        trainable_keys = ("vace", "cross_attn", "text_embedding")
    else:
        trainable_keys = ("vace",)
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in trainable_keys)

    # The camera encoder and mask projection are always trained.
    for module in (model.camera_encoder, model.mask_projection):
        for param in module.parameters():
            param.requires_grad = True

    model.camera_encoder.freqs.requires_grad = False
    total_param = sum(p.numel() for p in model.parameters())
    trainable_param = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total params: {total_param}, trainable params: {trainable_param}")
    if rank == 0:
        for name, param in model.named_parameters():
            print("require" if param.requires_grad else "not require", name)

    model_engine, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config=ds_config,
    )

    loader_gem_val, loader_pex_val, steps_per_epoch = build_train_loader(
        world_size=world_size,
        rank=rank,
        mix_ratio=0.5,
        image_batch_size=1,
        video_batch_size=1,
    )
    mix_iter_val = mixed_batch_iterator(
        loader_gem_val, loader_pex_val, 0.5, steps_per_epoch, epoch_seed=42 + rank
    )

    # sanity check
    print("*" * 20, "sanity check", "*" * 20)
    run_validation(model_engine, mix_iter_val, vae, validate_dir, base_step)

    loader_gem, loader_pex, steps_per_epoch = build_train_loader(
        world_size=world_size,
        rank=rank,
        mix_ratio=0.5,
        image_batch_size=16,
        video_batch_size=3,
    )

    for epoch in range(args.epoches):
        # Without set_epoch, DistributedSampler replays the same shuffle every epoch.
        loader_gem.sampler.set_epoch(epoch)
        loader_pex.sampler.set_epoch(epoch)

        mix_iter = mixed_batch_iterator(
            loader_gem, loader_pex, 0.5, steps_per_epoch, epoch_seed=42 + epoch + rank
        )

        model_engine.train()
        for batch in mix_iter:
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                if random.random() < 0.5:
                    encoded = vae_encode(vae, batch)
                else:
                    encoded = vae_encode(vae, batch, swap=True)
            z0, m, x0, source_camera, target_camera, _, _, loss_mask = encoded

            # Several optimizer steps on the same encoded batch (fresh noise/t each time).
            for _ in range(args.repeat):
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    m0 = model_engine.module.mask_projection(
                        m.permute(0, 2, 3, 4, 1)
                    ).permute(0, 4, 1, 2, 3)
                    z = pack_latent(z0, m0)
                    v_gt, v, _, _ = model_engine(z, x0, source_camera, target_camera)
                per_sample = ((v_gt - v) ** 2).mean(dim=(1, 2, 3, 4))  # (B,)
                # Flatten the mask: a collated (B, 1) mask would otherwise
                # broadcast against (B,) into a (B, B) matrix and mix samples.
                loss = (per_sample * loss_mask.reshape(-1).to(per_sample)).mean()

                if local_rank == 0:
                    print(f"loss: {loss.item():.6f}")
                    print(f"loss_mask: {loss_mask}")

                model_engine.backward(loss)
                model_engine.step()

                if model_engine.global_steps != 0 and (
                    model_engine.global_steps % args.validation_step == 0
                ):
                    if rank == 0:
                        torch.save(
                            model_engine.module.state_dict(),
                            f"{checkpoint_dir}/{base_step+model_engine.global_steps}.pth",
                        )

                    run_validation(
                        model_engine, mix_iter_val, vae, validate_dir, base_step
                    )


def _encode_mask_list(masks, ref_images):
    """Fold each mask's 8x8 pixel blocks into 64 channels and downsample time ~4x.

    Args:
        masks: iterable of per-sample masks shaped (C, T, H, W). Only channel 0
            is used. H and W must be multiples of 16, otherwise ``view`` raises.
        ref_images: iterable of the same length as ``masks``. An entry that is
            not None causes one all-zero frame to be prepended on the time axis.

    Returns:
        Tensor of shape (N, 64, T' [+1], H // 8, W // 8) with T' = ceil(T / 4).
    """
    encoded = []
    for mask, refs in zip(masks, ref_images):
        _, depth, height, width = mask.shape
        new_depth = (depth + 3) // 4
        lat_h = 2 * (int(height) // 16)
        lat_w = 2 * (int(width) // 16)

        # (T, H, W) -> (T, H/8, 8, W/8, 8) -> (8, 8, T, H/8, W/8) -> (64, T, H/8, W/8):
        # channel index = 8 * row_in_block + col_in_block.
        folded = mask[0].view(depth, lat_h, 8, lat_w, 8)
        folded = folded.permute(2, 4, 0, 1, 3).reshape(64, depth, lat_h, lat_w)

        # Temporal downsample only (spatial size is already at latent resolution).
        folded = F.interpolate(
            folded.unsqueeze(0),
            size=(new_depth, lat_h, lat_w),
            mode="nearest-exact",
        ).squeeze(0)

        if refs is not None:
            # Reserve an empty mask frame for the reference-image latent slot.
            folded = torch.cat((torch.zeros_like(folded[:, :1]), folded), dim=1)
        encoded.append(folded)
    return torch.stack(encoded, dim=0)


def encode_two_masks(src_masks, tgt_masks, ref_images=None):
    """Encode source and target masks to latent resolution and concatenate them.

    Args:
        src_masks: batch of source masks, each (C, T, H, W); only channel 0 is used.
        tgt_masks: batch of target masks with the same per-sample layout.
        ref_images: optional batch of reference images, one per sample. When
            given, every encoded mask gets a leading zero frame on the time axis.

    Returns:
        Tensor (B, 128, T', H // 8, W // 8): channels 0-63 come from
        ``src_masks`` and 64-127 from ``tgt_masks``.

    Raises:
        ValueError: if the source, target and reference batches differ in length.
    """
    if ref_images is None:
        ref_images = [None] * len(src_masks)
    if len(src_masks) != len(ref_images) or len(tgt_masks) != len(ref_images):
        raise ValueError(
            "src_masks, tgt_masks and ref_images must have the same length, got "
            f"{len(src_masks)}, {len(tgt_masks)} and {len(ref_images)}"
        )

    encoded_src = _encode_mask_list(src_masks, ref_images)
    encoded_tgt = _encode_mask_list(tgt_masks, ref_images)
    return torch.cat([encoded_src, encoded_tgt], dim=1)


def vae_encode(vae, batch, swap=False, mask_projection=None):
    """Encode one training batch, randomly picking one of three training tasks.

    Direction: by default the model is conditioned on the source clip and
    predicts the target clip. With ``swap=True`` the roles of the two clips
    (videos, masks, reference images and cameras) are exchanged.

    Task sampling (``random.random()``):
        * < 0.8: main task, inputs unchanged.
        * < 0.9: reference inpainting with camera control. The conditioning
          video is replaced by ``bg_video`` and the source mask is zeroed.
        * else:  object removal. The target square mask is zeroed, the target
          video is replaced by ``bg_video``, and the target camera's rx/ry
          (columns 3 and 4) are set to 2.0.

    Args:
        vae: VAE exposing ``encode``; it is also passed to ``encode_frames``.
        batch: dict of batched tensors (``src_video``, ``tgt_video``, ``bg_video``,
            ``src_mask``/``tgt_mask``, ``*_mask_square``, ``*_ref_img``,
            ``src_view``/``tgt_view``, ``loss_mask``). All are moved to CUDA.
        swap: reverse the source/target direction.
        mask_projection: unused; kept for interface compatibility.

    Returns:
        (z0, m0, x0, source_camera, target_camera, input_frames, target_frames,
        loss_mask).
    """
    if not swap:
        input_frames = batch["src_video"].cuda()
        src_masks = batch["src_mask"].cuda()
        tgt_masks_square = batch["tgt_mask_square"].cuda()
        input_ref_images = batch["src_ref_img"].cuda()
        target_frames = batch["tgt_video"].cuda()
        source_camera = batch["src_view"].cuda()
        target_camera = batch["tgt_view"].cuda()
    else:
        input_frames = batch["tgt_video"].cuda()
        src_masks = batch["tgt_mask"].cuda()
        tgt_masks_square = batch["src_mask_square"].cuda()
        input_ref_images = batch["tgt_ref_img"].cuda()
        target_frames = batch["src_video"].cuda()
        # Cameras follow the frames they describe, matching vae_encode_val.
        source_camera = batch["tgt_view"].cuda()
        target_camera = batch["src_view"].cuda()
    bg_frames = batch["bg_video"].cuda()

    task_draw = random.random()
    if task_draw < 0.8:
        # Main task: novel-view generation from the unmodified inputs.
        pass
    elif task_draw < 0.9:
        # Reference inpainting with camera control: condition on the background
        # only, with no source-object region.
        input_frames = bg_frames
        src_masks = torch.zeros_like(src_masks)
    else:
        # Object removal: the target contains only the background and no object.
        tgt_masks_square = torch.zeros_like(tgt_masks_square)
        target_frames = bg_frames
        # Clone so the in-place edit cannot alias a tensor owned by ``batch``
        # (``.cuda()`` returns the same tensor if it is already on the GPU).
        target_camera = target_camera.clone()
        # Views are [yaw, pitch, dist, rx, ry], so rx/ry are columns 3 and 4.
        # NOTE(review): 2.0 presumably pushes the object's screen position fully
        # out of frame; confirm against the camera conditioning.
        target_camera[:, 3] = 2.0
        target_camera[:, 4] = 2.0

    loss_mask = batch["loss_mask"].cuda()

    # Union of source and target regions, used when encoding the input frames.
    combined_mask = ((src_masks > 0.1) | (tgt_masks_square > 0.1)).float()

    with torch.no_grad():
        z0, x0_image = encode_frames(
            vae, input_frames, input_ref_images, masks=combined_mask
        )
        m0 = encode_two_masks(src_masks, tgt_masks_square, input_ref_images)
        x0 = vae.encode(target_frames)
        # Prepend the reference-image latent along the time axis.
        x0 = torch.cat((x0_image, x0), dim=2)

    return (
        z0,
        m0,
        x0,
        source_camera,
        target_camera,
        input_frames,
        target_frames,
        loss_mask,
    )


def get_random_views(
    base_view: torch.Tensor,
    num: int = 2,
    dyaw=0.4,
    dpitch=0.4,
    ddist=0.4,
    drx=0.2,
    dry=0.2,
) -> torch.Tensor:
    """Jitter a single camera view into ``num`` random nearby views.

    Each of the five components ``[yaw, pitch, dist, rx, ry]`` of ``base_view``
    (shape (1, 5)) is perturbed by an independent uniform offset in
    ``[-d, d]`` for its half-width ``d``. Pitch is then clamped to ``[0, pi]``
    and yaw is wrapped into ``[-pi, pi)``.

    Returns:
        Tensor of shape (num, 5) on the same device as ``base_view``.
    """
    assert base_view.shape == (1, 5)
    device = base_view.device
    half_widths = (dyaw, dpitch, ddist, drx, dry)

    # Draws happen strictly in yaw, pitch, dist, rx, ry order, so the RNG stream
    # is consumed component by component.
    yaws, pitchs, dists, rxs, rys = (
        torch.empty(num, device=device).uniform_(-half, half) + center
        for center, half in zip(base_view[0], half_widths)
    )

    pitchs = pitchs.clamp(0.0, torch.pi)
    yaws = torch.remainder(yaws + torch.pi, 2 * torch.pi) - torch.pi
    return torch.stack([yaws, pitchs, dists, rxs, rys], dim=1)


def _shift2d_trunc(x: torch.Tensor, sy: int, sx: int) -> torch.Tensor:
    """Zero-padded translation (no wrap). x:(C,H,W)->(C,H,W)."""
    C, H, W = x.shape
    out = x.new_zeros(C, H, W)

    y_src0 = max(0, -sy)
    x_src0 = max(0, -sx)
    y_src1 = min(H, H - sy)
    x_src1 = min(W, W - sx)

    h = y_src1 - y_src0
    w = x_src1 - x_src0
    if h > 0 and w > 0:
        y_dst0 = max(0, sy)
        x_dst0 = max(0, sx)
        out[:, y_dst0 : y_dst0 + h, x_dst0 : x_dst0 + w] = x[
            :, y_src0 : y_src0 + h, x_src0 : x_src0 + w
        ]
    return out


def shift_mask_trunc_CFHW(
    src_mask_CFHW: torch.Tensor,  # (C,F,H,W)
    rx_src: float,
    ry_src: float,
    rx_tgt: float,
    ry_tgt: float,
) -> torch.Tensor:
    """Translate a (C, F, H, W) mask by the target-minus-source screen offset.

    The offset is ``(rx_tgt - rx_src) * W`` pixels horizontally and
    ``(ry_tgt - ry_src) * H`` pixels vertically (rounded to int). Out-of-bounds
    content is truncated and vacated pixels are zero. Every frame and channel
    receives the same shift.

    Returns:
        Tensor with the same shape as ``src_mask_CFHW``.
    """
    # ``num_frames`` rather than ``F`` to avoid shadowing torch.nn.functional.
    C, num_frames, H, W = src_mask_CFHW.shape
    sx = int(round((rx_tgt - rx_src) * W))
    sy = int(round((ry_tgt - ry_src) * H))

    # The shift is identical for every (channel, frame) slice, so fold C and F
    # into one leading axis and shift once instead of looping per frame. This
    # also handles F == 0, where stacking an empty list would raise.
    flat = src_mask_CFHW.reshape(C * num_frames, H, W)
    return _shift2d_trunc(flat, sy, sx).reshape(C, num_frames, H, W)


def bbox2d_from_CFHW(mask_CFHW: torch.Tensor):
    """Return the 2D bounding box of all mask pixels above 0.5.

    The occupancy is the union over the channel and frame axes of a
    (C, F, H, W) mask. Returns ``(ymin, xmin, ymax, xmax)`` as inclusive
    Python ints, or ``None`` when no pixel exceeds the threshold.
    """
    footprint = (mask_CFHW > 0.5).flatten(0, 1).any(dim=0)  # (H, W) bool
    rows = footprint.any(dim=1).nonzero().flatten()
    if rows.numel() == 0:
        return None
    cols = footprint.any(dim=0).nonzero().flatten()
    # nonzero() yields ascending indices, so the ends are the extremes.
    return int(rows[0]), int(cols[0]), int(rows[-1]), int(cols[-1])


def touches_edge(bbox, H: int, W: int, margin: int = 0):
    """Check whether an inclusive ``(ymin, xmin, ymax, xmax)`` box reaches the border.

    A side counts as touching when it lies within ``margin`` pixels of the
    corresponding image edge of an ``H`` x ``W`` image.
    """
    y_min, x_min, y_max, x_max = bbox
    last_row = H - 1 - margin
    last_col = W - 1 - margin
    return (
        y_min <= margin
        or x_min <= margin
        or y_max >= last_row
        or x_max >= last_col
    )


@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vae_encode_val(
    vae,
    batch,
    swap: bool = False,
    num_rand_views: int = 5,
    dyaw=0.4,
    dpitch=0.4,
    ddist=0.4,
    drx=0.2,
    dry=0.2,
    try_scale_by_distance: bool = True,
    mask_projection=None,
):
    """Build a validation batch from the dataset target view plus random views.

    Only the first sample of ``batch`` is used. Its target camera is jittered
    ``num_rand_views`` times with ``get_random_views``. For each jittered view
    the source square mask is shifted by the rx/ry difference and, optionally,
    rescaled by distance to form the target mask. A view falls back to the
    original target camera/mask when its distance is <= 0, or when the mask is
    empty or touches the image border. The source side is then tiled to
    ``B = 1 + num_rand_views`` and encoded.

    Args:
        vae: VAE exposing ``encode``; it is also passed to ``encode_frames``.
        batch: dict of batched tensors (videos/masks (B, C, F, H, W), views (B, 5)
            laid out as ``[yaw, pitch, dist, rx, ry]``).
        swap: use the target clip as the source (cameras swapped accordingly).
        num_rand_views: number of jittered target views to add.
        dyaw, dpitch, ddist, drx, dry: jitter half-widths per camera component.
        try_scale_by_distance: rescale shifted masks with ``scale_video_mask_CFHW``.
        mask_projection: module projecting encoded masks (channels-last); required.

    Returns:
        (z, x0, src_cams, tgt_cams, src_video, tgt_video, ref_img, src_mask,
        tgt_masks, combined_mask), each with leading dimension B.

    Raises:
        ValueError: if ``mask_projection`` is None.
    """
    if mask_projection is None:
        # Fail fast, before any expensive encoding, instead of a late TypeError.
        raise ValueError("vae_encode_val requires a mask_projection module")

    # Fall back to GPU 0 when not launched by a distributed launcher.
    local_rank = os.environ.get("LOCAL_RANK", "0")
    device = f"cuda:{local_rank}"

    # Keep only the first sample; tensors stay 5D (1, C, F, H, W), views (1, 5).
    if not swap:
        src_vid = batch["src_video"][:1].to(device)
        src_mask = batch["src_mask"][:1].to(device)
        src_sq = batch["src_mask_square"][:1].to(device)
        tgt_sq = batch["tgt_mask_square"][:1].to(device)
        ref_img = batch["src_ref_img"][:1].to(device)
        tgt_vid = batch["tgt_video"][:1].to(device)
        src_cam = batch["src_view"][:1].to(device)
        tgt_cam = batch["tgt_view"][:1].to(device)
    else:
        src_vid = batch["tgt_video"][:1].to(device)
        src_mask = batch["tgt_mask"][:1].to(device)
        src_sq = batch["tgt_mask_square"][:1].to(device)
        tgt_sq = batch["src_mask_square"][:1].to(device)
        ref_img = batch["tgt_ref_img"][:1].to(device)
        tgt_vid = batch["src_video"][:1].to(device)
        # swap cameras to match frames/masks
        src_cam = batch["tgt_view"][:1].to(device)
        tgt_cam = batch["src_view"][:1].to(device)

    rand_views = get_random_views(
        tgt_cam,
        num=num_rand_views,
        dyaw=dyaw,
        dpitch=dpitch,
        ddist=ddist,
        drx=drx,
        dry=dry,
    )  # (K, 5)

    # The dataset target view is always the first entry.
    tgt_cams = [tgt_cam[0].clone()]
    tgt_masks = [tgt_sq[0].clone()]

    rx_src, ry_src = float(src_cam[0, 3]), float(src_cam[0, 4])
    dist_src = float(src_cam[0, 2])

    # Only the spatial size is needed; avoid shadowing torch.nn.functional (F).
    H, W = src_sq.shape[-2:]

    for k in range(num_rand_views):
        yaw, pitch, dist_t, rx_t, ry_t = [float(v) for v in rand_views[k]]

        # 1) invalid distance -> fallback to the dataset target view.
        if dist_t <= 0.0:
            tgt_cams.append(tgt_cam[0].clone())
            tgt_masks.append(tgt_sq[0].clone())
            continue

        # 2) shift the source square mask to the new screen position (truncating).
        cand = shift_mask_trunc_CFHW(src_sq[0], rx_src, ry_src, rx_t, ry_t)

        # 3) optional, best-effort rescale by camera distance.
        if try_scale_by_distance:
            try:
                cand = scale_video_mask_CFHW(cand, dist_src, dist_t)
            except Exception as err:  # best effort: keep the unscaled mask
                warnings.warn(
                    f"scale_video_mask_CFHW failed ({type(err).__name__}); "
                    "using the unscaled shifted mask",
                    RuntimeWarning,
                )

        # 4) reject empty masks or masks touching any border.
        bbox = bbox2d_from_CFHW(cand)
        if (bbox is None) or touches_edge(bbox, H, W, margin=0):
            tgt_cams.append(tgt_cam[0].clone())
            tgt_masks.append(tgt_sq[0].clone())
        else:
            # Match the dataset camera dtype so no precision is lost in stacking.
            tgt_cams.append(
                torch.tensor(
                    [yaw, pitch, dist_t, rx_t, ry_t],
                    device=device,
                    dtype=tgt_cam.dtype,
                )
            )
            tgt_masks.append(cand)

    # (B, 5) and (B, C, F, H, W) with B = 1 + K.
    all_tgt_cam = torch.stack(tgt_cams, dim=0)
    all_tgt_masks = torch.stack(tgt_masks, dim=0)

    B = all_tgt_cam.shape[0]

    def _tile(x):
        # Repeat a (1, C, F, H, W) tensor B times along the batch axis.
        return x.repeat(B, 1, 1, 1, 1)

    src_vid_B = _tile(src_vid)
    src_mask_B = _tile(src_mask)
    ref_img_B = _tile(ref_img)
    tgt_vid_B = _tile(tgt_vid)
    all_src_cam = src_cam.repeat(B, 1)

    # Union of source and (per-view) target regions, used while encoding.
    combined_mask_B = ((src_mask_B > 0.1) | (all_tgt_masks > 0.1)).float()

    with torch.no_grad():
        z0, x0_image = encode_frames(
            vae, src_vid_B, ref_img_B, masks=combined_mask_B
        )
        m0 = encode_two_masks(src_mask_B, all_tgt_masks, ref_img_B)
        x0 = vae.encode(tgt_vid_B)
        x0 = torch.cat((x0_image, x0), dim=2)

        # Project mask channels (moved last for the projection, then back).
        m0 = mask_projection(m0.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
        z = pack_latent(z0, m0)

    return (
        z,
        x0,
        all_src_cam,
        all_tgt_cam,
        src_vid_B,
        tgt_vid_B,
        ref_img_B,
        src_mask_B,
        all_tgt_masks,
        combined_mask_B,
    )


def run_validation(model_engine, mix_iter_val, vae, validate_dir, base_step):
    """Validate on one batch in both directions, then return to training mode.

    The first batch of ``mix_iter_val`` is encoded with ``vae_encode_val`` twice:
    once as-is and once with source/target swapped. Each encoding is passed to
    ``model_engine.module.validate``. All ranks then sync on a barrier, the CUDA
    cache is cleared and the engine is put back into train mode.
    """
    model_engine.eval()

    for batch in mix_iter_val:
        for swap in (False, True):
            # ``swap`` is forwarded only when True, exactly as the two explicit
            # calls did before (validate's own default is not relied on).
            direction_kwargs = {"swap": True} if swap else {}

            encoded = vae_encode_val(
                vae,
                batch,
                mask_projection=model_engine.module.mask_projection,
                **direction_kwargs,
            )
            (
                z,
                x0,
                source_camera,
                target_camera,
                input_frames,
                target_frames,
                input_ref_images,
                src_mask,
                tgt_mask,
                input_masks,
            ) = encoded

            model_engine.module.validate(
                z,
                x0,
                source_camera,
                target_camera,
                input_frames,
                target_frames,
                input_ref_images,
                input_masks,
                step=model_engine.global_steps + base_step,
                validate_dir=validate_dir,
                vae=vae,
                src_mask=src_mask,
                tgt_mask=tgt_mask,
                **direction_kwargs,
            )
        # Only a single batch is validated per call.
        break

    dist.barrier()
    torch.cuda.empty_cache()
    model_engine.train()


if __name__ == "__main__":
    # Script entry point: invoke main() only when run directly, not on import.
    main()