#!/usr/bin/env python
# coding=utf-8

import os
import sys
import torch
import argparse
import numpy as np
import imageio
import decord
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from typing import Dict, Any, List
import torch.nn.functional as F
from torchvision import transforms
from omegaconf import OmegaConf
import logging
import math
from transformers import AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
import pickle
import cv2
import gc

# Make project-local packages importable when this file is run as a script:
# the directory containing this file and its two nearest ancestors are each
# pushed to the front of sys.path (skipping entries that are already present).
# This must run before the project imports that follow.
current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), 
                os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
for project_root in project_roots:
    if project_root not in sys.path:
        # insert(0, ...) means later entries in project_roots end up earlier
        # in sys.path than the ones processed before them.
        sys.path.insert(0, project_root)

from MoRe4D.models import AutoencoderKLWan, WanT5EncoderModel, WanTransformer3DModel, CLIPModel, WanTransformer3DModelDINO
from MoRe4D.pipeline import WanFunControlPipeline, WanFunInpaintPipeline
from MoRe4D.utils.lora_utils import create_network, merge_lora, unmerge_lora
from MoRe4D.utils.utils import filter_kwargs, get_image_latent, get_video_to_video_latent, save_videos_grid, get_image_to_video_latent, get_image_to_flow_video_latent
from MoRe4D.models.cache_utils import get_teacache_coefficients
from MoRe4D.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from MoRe4D.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from prompt_tuning.models.resnet import VAEEncoderadaptor, VAEDecoderadaptor
from unidepth.models import UniDepthV2old

from torch_scatter import scatter
from project_utils import project
from visualizer.gaussian_splatting import gs_render


def get_logger(name, log_level="INFO"):
    """Return the named logger, configured to write to stdout.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_level: Name of a level attribute on the ``logging`` module
            (e.g. ``"INFO"``); applied on every call.

    Returns:
        The ``logging.Logger`` instance. A stdout stream handler is attached
        only when the logger has no handlers yet, so repeated calls do not
        duplicate output.
    """
    log = logging.getLogger(name)
    log.setLevel(getattr(logging, log_level))

    # Already configured on an earlier call: nothing more to attach.
    if log.handlers:
        return log

    stream = logging.StreamHandler(sys.stdout)
    stream.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    log.addHandler(stream)
    return log


def load_prompts(file_path):
    """Read a text file and return its non-blank lines, whitespace-stripped.

    Args:
        file_path: Path of the text file (one prompt per line).

    Returns:
        List of stripped lines; lines that are empty after stripping are
        dropped.
    """
    with open(file_path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [text for text in stripped if text]


def load_videos(file_path):
    """Read a list file and resolve each entry relative to the file's folder.

    Args:
        file_path: ``pathlib.Path`` of a text file with one relative video
            path per line.

    Returns:
        List of ``Path`` objects, one per non-blank line, each joined onto
        ``file_path.parent``.
    """
    base_dir = file_path.parent
    resolved = []
    with open(file_path, "r") as handle:
        for raw in handle:
            entry = raw.strip()
            if entry:
                resolved.append(Path(str(base_dir / entry)))
    return resolved

class TwoStageDataset:
    """Index-based dataset of (prompt, first video frame) pairs.

    Prompts and video paths are read from two list files under ``data_root``
    and paired line by line. Samples are then ordered by video file stem.
    """

    def __init__(
        self, 
        data_root: str,
        caption_column: str,
        video_column: str,
        device: torch.device,
        max_num_frames: int, 
        height: int, 
        width: int,
        max_samples: int = None
    ):
        """Load and validate the prompt and video lists.

        Args:
            data_root: Directory containing the two list files.
            caption_column: File name (relative to ``data_root``) of the
                prompt list, one prompt per line.
            video_column: File name (relative to ``data_root``) of the video
                list, one relative path per line.
            device: Stored on the instance; not used by this class itself.
            max_num_frames: Stored on the instance; not used by this class
                itself.
            height: Target height for the resized first frame.
            width: Target width for the resized first frame.
            max_samples: Accepted for interface compatibility; currently
                unused.

        Raises:
            ValueError: If the number of prompts and videos differ, or if any
                listed video file does not exist.
        """
        super().__init__()
        self.logger = get_logger("inference", "INFO")
        data_root = Path(data_root)

        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)

        # Validate BEFORE any re-indexing: indexing prompts with video-length
        # indices would otherwise raise IndexError (too few prompts) or
        # silently truncate (too many prompts) and hide the mismatch.
        if len(self.videos) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

        # Order samples by video stem. The sort is stable, so entries with
        # equal stems keep their list-file order (a shuffle before a sort
        # would only randomise that tie order, so none is performed).
        order = sorted(range(len(self.videos)), key=lambda i: self.videos[i].stem)
        self.prompts = [self.prompts[i] for i in order]
        self.videos = [self.videos[i] for i in order]
        self.video_names = [video.stem for video in self.videos]

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width
        self.device = device

        # Maps pixel values from [0, 255] to [-1, 1].
        self.__image_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

        missing = [path for path in self.videos if not path.is_file()]
        if missing:
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {missing[0]}"
            )

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the sample at ``index``.

        A ``list`` index is returned unchanged (pass-through kept from the
        original implementation; NOTE(review): presumably for a batch
        sampler — confirm against callers).

        Returns:
            Dict with keys ``"prompt"``, ``"image"`` (first frame as a
            ``[1, C, H, W]`` float tensor scaled to [-1, 1]), ``"video_path"``
            and ``"video_fps"`` (always ``None`` here).
        """
        if isinstance(index, list):
            return index

        video = self.videos[index]
        _, image, fps = self.load_first_frame(video)

        return {
            "prompt": self.prompts[index],
            "image": self.image_transform(image),
            "video_path": video,
            "video_fps": fps
        }

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.videos)

    def load_first_frame(self, video: Path) -> tuple:
        """Decode the first frame of ``video`` and resize it.

        Returns:
            A 3-tuple ``(None, image, None)`` where ``image`` is a float32
            tensor of shape ``[1, C, H, W]`` with values in [0, 255]. The
            first and last slots are placeholders (unpacked by the caller as
            frames and fps).

        Raises:
            ValueError: If the first frame cannot be read.
        """
        cap = cv2.VideoCapture(str(video))
        try:
            success, first_frame = cap.read()
        finally:
            # Always release the capture handle, even if read() raises.
            cap.release()

        if not success:
            raise ValueError(f"Cannot read video file: {video}")

        # OpenCV decodes to BGR; convert to RGB before handing to PIL.
        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        first_frame = Image.fromarray(first_frame)
        first_frame = first_frame.resize((self.width, self.height), Image.BILINEAR)
        first_frame = np.array(first_frame).astype(np.float32)
        image = torch.from_numpy(first_frame).permute(2, 0, 1).unsqueeze(0)  # [1, C, H, W]
        return None, image, None

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """Scale an image tensor from [0, 255] to [-1, 1]."""
        return self.__image_transforms(image)



def back_project_coords(depth_map, H, W, device):
    """Back-project a depth map into per-pixel 3D points.

    The depth map is bilinearly resized to ``(H, W)``, a normalised pinhole
    intrinsic (principal point at 0.5, focal lengths derived from a fixed
    540x960 reference aspect ratio) is inverted, and each pixel ray on the
    [0, 1] x [0, 1] grid is scaled by its depth.

    Args:
        depth_map: 2D depth tensor (it is unsqueezed twice before resizing).
        H: Output height.
        W: Output width.
        device: Device for the pixel grid and intrinsic matrix.

    Returns:
        Tensor of shape ``(H, W, 3)`` holding the 3D point for every pixel.
    """
    resized_depth = F.interpolate(
        depth_map.unsqueeze(0).unsqueeze(0),
        size=(H, W),
        mode='bilinear',
        align_corners=False,
    )[0, 0]

    # Reference resolution used to derive the focal-length ratio.
    H_ori, W_ori = 540, 960
    width_limited = W_ori / W > H_ori / H
    fx = 1 if width_limited else H_ori / W_ori / (H / W)
    fy = W_ori / H_ori / (W / H) if width_limited else 1

    intrinsic = torch.tensor(
        [[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]],
        dtype=torch.float32,
    ).to(device)

    # Normalised pixel coordinates; 'xy' indexing yields (H, W) grids.
    cols = torch.linspace(0, 1, W, device=device)
    rows = torch.linspace(0, 1, H, device=device)
    grid_u, grid_v = torch.meshgrid(cols, rows, indexing='xy')
    homogeneous = torch.stack([grid_u, grid_v, torch.ones_like(grid_u)], dim=-1)

    # The inverse is computed on CPU and moved back to the target device.
    K_inv = torch.inverse(intrinsic.cpu()).to(device)
    rays = homogeneous @ K_inv.T

    return rays * resized_depth.unsqueeze(-1)


def inverse_flow_norm_transform_no_diff(rel_flow, first_frame_coords):
    """Undo the per-sample scale normalisation of a relative 3D flow.

    For each batch item a single scale is computed as the largest
    per-channel extent (max - min) of the first-frame coordinates; a zero
    scale is replaced by 1. The first frame is divided by that scale, added
    to ``rel_flow``, and the sum is multiplied back by the scale.

    Args:
        rel_flow: Tensor of shape ``[B, 3, F, H, W]``.
        first_frame_coords: Tensor of shape ``[B, 3, >=1, H, W]`` (only frame
            0 is used). A batch of size 1 is broadcast to ``B``.

    Returns:
        Tuple ``(recovered_flow, diff)`` where ``recovered_flow`` has shape
        ``[B, 3, F, H, W]`` and ``diff`` has shape ``[B, 3]`` (the per-sample
        scale repeated for the three channels).
    """
    # Work in [B, F, C, H, W] so the first frame broadcasts over frames.
    rel_flow = rel_flow.permute((0, 2, 1, 3, 4))
    B, _, _, H, W = rel_flow.shape

    first_frame_coords = first_frame_coords[:, :, 0, :, :].permute((0, 2, 3, 1))
    if first_frame_coords.dim() == 4 and first_frame_coords.size(0) == 1:
        first_frame_coords = first_frame_coords.expand(B, H, W, 3)

    frame0 = first_frame_coords.permute(0, 3, 1, 2)  # [B, 3, H, W]
    frame0_flat = frame0.reshape(B, 3, -1)
    extent = frame0_flat.max(dim=2).values - frame0_flat.min(dim=2).values  # [B, 3]

    # One scale per batch item, repeated across the 3 channels. Repeating
    # along a new trailing axis keeps each row tied to its own batch item;
    # tiling the flat [B] vector would interleave scales across samples
    # whenever B > 1.
    diff = extent.max(dim=1)[0].unsqueeze(1).repeat(1, 3)  # [B, 3]

    # Guard against division by zero for degenerate (constant) frames.
    diff = torch.where(diff == 0, torch.ones_like(diff), diff)

    frame0_normalized = frame0 / diff.view(B, 3, 1, 1)
    recovered_flow_normalized = rel_flow + frame0_normalized.unsqueeze(1)

    recovered_flow = recovered_flow_normalized * diff.view(B, 1, 3, 1, 1)
    recovered_flow = recovered_flow.permute((0, 2, 1, 3, 4))
    return recovered_flow, diff


def render_with_project(world_points, extrinsic, intrinsic, colors, H, W, device):
    """Rasterise coloured 3D points into an H x W image by point projection.

    Points are projected with the project-local ``project`` helper, filtered
    to those landing inside the unit square with non-negative depth, and
    resolved per pixel by keeping the point(s) with the smallest depth.

    Returns:
        Tuple ``(image_proj, mask)``: ``image_proj`` is a uint8 numpy array
        and ``mask`` is a boolean numpy array that is True where the three
        colour channels sum to zero.
    """

    # NOTE(review): assumes `project` returns normalised 2D coords in [0, 1]
    # plus a per-point depth — confirm against project_utils.
    predicted_2D, depth_2D = project(world_points, extrinsic, intrinsic)
    
    # Keep points inside the [0, 1] x [0, 1] image square with depth >= 0.
    mask = (predicted_2D[..., 0] >= 0) * (predicted_2D[..., 0] <= 1) * \
           (predicted_2D[..., 1] >= 0) * (predicted_2D[..., 1] <= 1) * (depth_2D >= 0)
    
    if mask.sum() > 0:
        color_pc = colors[mask, :]
        depth_2D = depth_2D[mask]
        idx_pc = predicted_2D[mask, :]
        # Linear pixel index laid out as x * H + y (column-major); this is
        # why the image is reshaped to (W, H, 3) and transposed below.
        idx_xy = (idx_pc[:, 0]*W).floor().clamp(0, W-1) * H + (idx_pc[:, 1]*H).floor().clamp(0, H-1)
        
        # Per-pixel minimum depth: start every occupied pixel at the global
        # max depth, then reduce with 'amin' over the points that hit it.
        unique_indices, inverse_indices = torch.unique(idx_xy, return_inverse=True)
        min_depth = torch.ones_like(unique_indices, dtype=depth_2D.dtype)*depth_2D.max()
        min_depth.index_reduce_(0, inverse_indices, depth_2D, 'amin')
        # Keep only points whose depth equals the minimum at their pixel.
        mask_depth = (depth_2D == min_depth[inverse_indices])
        
        color_pc = color_pc[mask_depth, :]
        idx_xy = idx_xy[mask_depth]
        
        # Points tied at the minimum depth are averaged per pixel.
        color_image = scatter(color_pc, idx_xy.long(), dim=0, reduce="mean")
        # Pad with zeros when the scatter output is shorter than H*W
        # (presumably it ends at the largest occupied index — TODO confirm
        # against torch_scatter's default output size).
        if len(color_image) < H*W:
            color_image = torch.cat([color_image, torch.zeros((H*W-len(color_image), 3), device=device)], dim=0)
        
        color_image = color_image.reshape(W, H, 3).transpose(0, 1)
    else:
        # Nothing projects into view: return an all-zero image.
        color_image = torch.zeros((H, W, 3), device=device)
    
    image_proj = color_image.cpu().numpy().astype(np.uint8)
    # Pixels whose channels sum to zero are reported as holes; a pixel that
    # was actually rendered pure black is indistinguishable from a hole here.
    mask = (image_proj.sum(-1) == 0)
    
    return image_proj, mask


def render_with_gs(world_points, extrinsic, intrinsic, colors, H, W, device, scale=0.0001):
    """Render coloured 3D points through the project-local ``gs_render``.

    Args:
        world_points: Point positions forwarded to ``gs_render``.
        extrinsic: Camera extrinsic forwarded to ``gs_render``.
        intrinsic: Camera intrinsic forwarded to ``gs_render``.
        colors: Per-point colours; divided by 255 when their maximum exceeds
            1.0, otherwise only cast to float.
        H: Image height.
        W: Image width.
        device: Device for the helper tensors created here.
        scale: Value repeated three times to form the scale vector.

    Returns:
        uint8 numpy image taken from the first rendered output, permuted from
        channel-first to channel-last and multiplied by 255.
    """
    scale_vec = torch.Tensor([scale] * 3).to(device)
    # NOTE(review): [0, 0, 0, 1] is presumably an identity quaternion in the
    # renderer's convention — confirm against gs_render.
    rotation_vec = torch.Tensor([0.0, 0.0, 0.0, 1.0]).to(device)

    if colors.max() > 1.0:
        rgb = colors.float() / 255.0
    else:
        rgb = colors.float()

    # One value of 1.0 per pixel-derived point (H*W entries).
    per_point_ones = torch.ones((H * W,)).to(device)

    rendered = gs_render(
        intrinsic,
        extrinsic,
        [H, W],
        world_points,
        scale_vec,
        rotation_vec,
        rgb,
        per_point_ones,
    )

    frame = rendered[0].permute(1, 2, 0).detach().cpu() * 255
    return frame.numpy().astype(np.uint8)

def generate_x_moving_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a translation-only camera path that sways along the x axis.

    The x offset rises for the first quarter of the frames, falls through the
    middle half, and rises back towards zero in the last quarter. Rotation is
    always identity. ``center`` is accepted but not used.

    Args:
        center: Unused.
        n_frames: Number of extrinsics to generate.
        radius_base: Amplitude scale of the sway.
        z_progress: When False, every offset is 0.0.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    first_cut = n_frames // 4
    second_cut = 3 * n_frames // 4

    def offset_at(step):
        # Piecewise-linear offset; disabled entirely when z_progress is off.
        if not z_progress:
            return 0.0
        if step < first_cut:
            return radius_base * step / n_frames
        if step < second_cut:
            return 0.5 * radius_base - radius_base * (step) / n_frames
        return -radius_base + radius_base * (step) / n_frames

    poses = []
    for step in range(n_frames):
        pose = np.eye(4)
        pose[:3, 3] = np.array([offset_at(step), 0.0, 0.0])
        poses.append(torch.from_numpy(pose).float())
    return poses

def generate_mix2_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a translation-only camera path that moves along z, x and y.

    Per frame ``i``: ``z = radius_base * i / n_frames`` (0 when
    ``z_progress`` is False), ``x = -0.2 * z`` and
    ``y = -1.5 * z * cos(pi * i / n_frames / 2)``. Rotation is identity.

    NOTE(review): the previous implementation also computed a look-at
    rotation towards ``center`` every frame but never wrote it into the
    extrinsic. That dead computation is removed here; the emitted matrices
    are unchanged. If a look-at orientation was intended, it must be added
    deliberately since it would change rendered output.

    Args:
        center: Unused (kept for a uniform generator signature).
        n_frames: Number of extrinsics to generate.
        radius_base: Scale of the total z travel.
        z_progress: When False, the camera stays at the origin.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    extrinsics = []
    for i in range(n_frames):
        cam_z = radius_base * i / n_frames if z_progress else 0.0
        cam_x = -cam_z * 0.2
        cam_y = -cam_z * 1.5 * math.cos(math.pi * i / n_frames / 2)

        extrinsic = np.eye(4)
        extrinsic[:3, 3] = np.array([cam_x, cam_y, cam_z])

        extrinsics.append(torch.from_numpy(extrinsic).float())

    return extrinsics

def generate_mix1_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a translation-only camera path with a cosine sway in x.

    Per frame ``i``: ``z = 0.15 * radius_base * i / n_frames`` (0 when
    ``z_progress`` is False), ``x = 0.15 * radius_base *
    cos(2 * pi * i / n_frames)`` and ``y = -0.1 * z``. Rotation is identity.

    NOTE(review): the previous implementation also computed a look-at
    rotation towards ``center`` every frame but never wrote it into the
    extrinsic. That dead computation is removed here; the emitted matrices
    are unchanged. If a look-at orientation was intended, it must be added
    deliberately since it would change rendered output.

    Args:
        center: Unused (kept for a uniform generator signature).
        n_frames: Number of extrinsics to generate.
        radius_base: Scale of the motion.
        z_progress: When False, z (and therefore y) stays at 0; the x sway is
            unaffected.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    extrinsics = []
    for i in range(n_frames):
        cam_z = 0.15*radius_base * i / n_frames if z_progress else 0.0
        cam_x = math.cos(math.pi * i / n_frames * 2)*radius_base * 0.15
        cam_y = -cam_z * 0.1

        extrinsic = np.eye(4)
        extrinsic[:3, 3] = np.array([cam_x, cam_y, cam_z])

        extrinsics.append(torch.from_numpy(extrinsic).float())

    return extrinsics

def generate_left_right_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a translation-only left/right camera sway along the x axis.

    The x offset increases during the first quarter of the frames, decreases
    through the middle half, and increases again in the final quarter.
    Rotation is identity and ``center`` is not used.

    Args:
        center: Unused.
        n_frames: Number of extrinsics to generate.
        radius_base: Amplitude scale of the sway.
        z_progress: When False, every offset is 0.0.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    quarter = n_frames // 4
    three_quarters = 3 * n_frames // 4

    trajectory = []
    for idx in range(n_frames):
        shift = 0.0
        if z_progress:
            if idx < quarter:
                shift = radius_base * idx / n_frames
            elif idx < three_quarters:
                shift = 0.5 * radius_base - radius_base * (idx) / n_frames
            else:
                shift = -radius_base + radius_base * (idx) / n_frames

        matrix = np.eye(4)
        matrix[:3, 3] = np.array([shift, 0.0, 0.0])
        trajectory.append(torch.from_numpy(matrix).float())

    return trajectory

def generate_y_moving_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a translation-only camera path that sways along the y axis.

    The y offset follows a three-segment piecewise-linear profile (up for the
    first quarter, down through the middle half, up again for the last
    quarter). Rotation is identity and ``center`` is not used.

    Args:
        center: Unused.
        n_frames: Number of extrinsics to generate.
        radius_base: Amplitude scale of the sway.
        z_progress: When False, every offset is 0.0.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    poses = []
    for frame in range(n_frames):
        if not z_progress:
            cam_y = 0.0
        elif frame < n_frames // 4:
            cam_y = radius_base * frame / n_frames
        elif frame < 3 * n_frames // 4:
            cam_y = 0.5 * radius_base - radius_base * (frame) / n_frames
        else:
            cam_y = -radius_base + radius_base * (frame) / n_frames

        pose = np.eye(4)
        pose[:3, 3] = np.array([0.0, cam_y, 0.0])
        poses.append(torch.from_numpy(pose).float())

    return poses


def generate_circle_rotating_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a camera path that circles in the x/y plane while facing ``center``.

    Frame ``i`` sits at angle ``2 * pi * i / n_frames`` on a circle of radius
    ``radius_base``; z advances by ``3 * radius_base * i / n_frames`` when
    ``z_progress`` is True. The rotation block has columns
    ``[-right, up, forward]`` built from the direction to ``center`` and the
    world up vector ``(0, 1, 0)``.

    Args:
        center: Indexable with at least three components (look-at target).
        n_frames: Number of extrinsics to generate.
        radius_base: Circle radius.
        z_progress: When False, the camera stays at z = 0.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    world_up = np.array([0, 1, 0])
    poses = []

    for step in range(n_frames):
        theta = 2 * math.pi * step / n_frames
        position = np.array([
            radius_base * math.cos(theta),
            radius_base * math.sin(theta),
            3 * radius_base * step / n_frames if z_progress else 0.0,
        ])

        # Orthonormal camera frame looking from `position` towards `center`.
        look = np.array([center[0], center[1], center[2]]) - position
        look = look / np.linalg.norm(look)
        side = np.cross(look, world_up)
        side = side / np.linalg.norm(side)
        true_up = np.cross(side, look)
        true_up = true_up / np.linalg.norm(true_up)

        pose = np.eye(4)
        pose[:3, :3] = np.column_stack((-side, true_up, look))
        pose[:3, 3] = position
        poses.append(torch.from_numpy(pose).float())

    return poses



def generate_camera_rotate_trajectory(center, n_frames, rotate_max_degree=30, z_progress=True):
    """Build a rotation-only camera path that rolls about the z axis.

    The camera stays at the origin. The roll angle of frame ``i`` is
    ``radians(rotate_max_degree) * sin(pi * i / n_frames)``. ``center`` and
    ``z_progress`` are accepted but not used.

    Args:
        center: Unused.
        n_frames: Number of extrinsics to generate.
        rotate_max_degree: Peak roll angle in degrees.
        z_progress: Unused.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    peak = math.radians(rotate_max_degree)
    poses = []

    for step in range(n_frames):
        theta = peak * math.sin(math.pi * step / n_frames)
        roll = np.array([
            [math.cos(theta), -math.sin(theta), 0],
            [math.sin(theta), math.cos(theta), 0],
            [0, 0, 1]
        ])

        # np.eye(4) already has a zero translation, i.e. camera at origin.
        pose = np.eye(4)
        pose[:3, :3] = roll @ np.eye(3)
        poses.append(torch.from_numpy(pose).float())

    return poses

def generate_forward_backward_trajectory(center, n_frames, radius_base=0.3, z_progress=True):
    """Build a translation-only camera path that moves back and forth in z.

    The z offset follows a three-segment piecewise-linear profile (increasing
    for the first quarter of the frames, decreasing through the middle half,
    increasing again for the last quarter). Rotation is identity and
    ``center`` is not used.

    Args:
        center: Unused.
        n_frames: Number of extrinsics to generate.
        radius_base: Amplitude scale of the motion.
        z_progress: When False, every offset is 0.0.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    def depth_offset(k):
        if not z_progress:
            return 0.0
        if k < n_frames // 4:
            return radius_base * k / n_frames
        if k < 3 * n_frames // 4:
            return 0.5 * radius_base - radius_base * (k) / n_frames
        return -radius_base + radius_base * (k) / n_frames

    result = []
    for k in range(n_frames):
        mat = np.eye(4)
        mat[:3, 3] = np.array([0.0, 0.0, depth_offset(k)])
        result.append(torch.from_numpy(mat).float())
    return result


def generate_surrounding_trajectory(center, n_frames):
    """Build a camera path that orbits ``center`` in the x/z plane.

    The orbit radius is the x/z distance of ``center`` from the origin. The
    angle of frame ``i`` is ``-pi * i / n_frames / 4 - atan2(center[2],
    center[0])``, so the sweep covers up to a quarter of pi in the negative
    direction. Each camera faces ``center``; the rotation block has columns
    ``[-right, up, forward]``.

    Args:
        center: Array-like 3D point supporting subtraction with a numpy array.
        n_frames: Number of extrinsics to generate.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    world_up = np.array([0, 1, 0])
    poses = []

    for step in range(n_frames):
        theta = -math.pi * step / n_frames / 4 - math.atan2(center[2], center[0])
        orbit_radius = math.sqrt(center[0]**2 + center[2]**2)
        position = np.array([
            center[0] + orbit_radius * math.cos(theta),
            0,
            center[2] + orbit_radius * math.sin(theta),
        ])

        # Look-at frame from the camera position towards the centre.
        look = center - position
        look = look / np.linalg.norm(look)
        side = np.cross(look, world_up)
        side = side / np.linalg.norm(side)
        true_up = np.cross(side, look)
        true_up = true_up / np.linalg.norm(true_up)

        pose = np.eye(4)
        pose[:3, :3] = np.column_stack((-side, true_up, look))
        pose[:3, 3] = position
        poses.append(torch.from_numpy(pose).float())

    return poses


def generate_anti_surrounding_trajectory(center, n_frames):
    """Build a camera path that orbits ``center`` in the x/z plane (reverse sweep).

    Same construction as the "surrounding" orbit but with the per-frame angle
    increment added instead of subtracted:
    ``+pi * i / n_frames / 4 - atan2(center[2], center[0])``. Each camera
    faces ``center``; the rotation block has columns
    ``[-right, up, forward]``.

    Args:
        center: Array-like 3D point supporting subtraction with a numpy array.
        n_frames: Number of extrinsics to generate.

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    def unit(vec):
        # Scale a vector to unit length.
        return vec / np.linalg.norm(vec)

    trajectory = []
    for idx in range(n_frames):
        phi = +math.pi * idx / n_frames / 4 - math.atan2(center[2], center[0])
        dist = math.sqrt(center[0]**2 + center[2]**2)
        eye_x = center[0] + dist * math.cos(phi)
        eye_z = center[2] + dist * math.sin(phi)
        eye = np.array([eye_x, 0, eye_z])

        ahead = unit(center - eye)
        sideways = unit(np.cross(ahead, np.array([0, 1, 0])))
        upward = unit(np.cross(sideways, ahead))

        matrix = np.eye(4)
        matrix[:3, :3] = np.array([-sideways, upward, ahead]).T
        matrix[:3, 3] = eye
        trajectory.append(torch.from_numpy(matrix).float())

    return trajectory


def generate_circular_trajectory(center, n_frames, radius_base=0.3, pitch_angle=45):
    """Build a camera path that arcs in the y/z plane while facing ``center``.

    The pitch of frame ``i`` is ``i / n_frames * radians(pitch_angle)``; the
    camera sits at ``(0, -r * sin(pitch), r - r * cos(pitch))`` with
    ``r = radius_base``. The rotation block has columns
    ``[-right, up, forward]`` built from the direction to ``center`` and the
    world up vector ``(0, 1, 0)``.

    Args:
        center: Indexable with at least three components (look-at target).
        n_frames: Number of extrinsics to generate.
        radius_base: Arc radius.
        pitch_angle: Pitch range in degrees (approached but not reached, as
            ``i`` stops at ``n_frames - 1``).

    Returns:
        List of ``n_frames`` float32 ``4x4`` tensors.
    """
    world_up = np.array([0, 1, 0])
    poses = []

    for step in range(n_frames):
        pitch = step / n_frames * math.radians(pitch_angle)
        position = np.array([
            0,
            -radius_base * math.sin(pitch),
            radius_base - radius_base * math.cos(pitch),
        ])

        look = np.array([center[0], center[1], center[2]]) - position
        look = look / np.linalg.norm(look)
        side = np.cross(look, world_up)
        side = side / np.linalg.norm(side)
        true_up = np.cross(side, look)
        true_up = true_up / np.linalg.norm(true_up)

        pose = np.eye(4)
        pose[:3, :3] = np.column_stack((-side, true_up, look))
        pose[:3, 3] = position
        poses.append(torch.from_numpy(pose).float())

    return poses


def generate_static_trajectory(n_frames):
    """Return ``n_frames`` references to one shared identity extrinsic.

    Every list entry is the same float32 ``4x4`` identity tensor object, so
    an in-place edit of one entry is visible through all of them.
    """
    identity_pose = torch.eye(4).float()
    poses = []
    for _ in range(n_frames):
        poses.append(identity_pose)
    return poses


def render_trajectory(coords_data, colors, trajectory_type, n_frames, H, W, device):
    """Render a point-cloud sequence along a named camera trajectory.

    A normalised intrinsic is built from a fixed 540x960 reference aspect
    ratio, the camera path for ``trajectory_type`` is generated around the
    mean of the first frame's points, and every frame is rendered twice:
    with ``render_with_gs`` and with ``render_with_project``. A failure in
    either renderer is printed and replaced by a blank frame (and an
    all-True mask for the projection renderer).

    Args:
        coords_data: Point coordinates indexed as ``[0, :, frame]``; each
            frame slice is permuted and flattened to ``(-1, 3)``.
        colors: Per-point colours; only ``colors[0]`` is used.
        trajectory_type: One of the supported trajectory names.
        n_frames: Number of camera poses to generate; rendering stops at
            ``min(n_frames, coords_data.shape[2])``.
        H: Render height.
        W: Render width.
        device: Device for the intrinsic and per-frame extrinsics.

    Returns:
        Tuple ``(gs_frames, project_frames, project_masks)`` of lists.

    Raises:
        ValueError: If ``trajectory_type`` is not recognised.
    """
    # Focal lengths relative to the fixed reference resolution.
    H_ori, W_ori = 540, 960
    width_limited = W_ori / W > H_ori / H
    fx = 1 if width_limited else H_ori / W_ori / (H / W)
    fy = W_ori / H_ori / (W / H) if width_limited else 1

    intrinsic = torch.Tensor([
        [fx, 0, 0.5],
        [0, fy, 0.5],
        [0, 0, 1]
    ]).to(device)

    # Scene centre = mean of the first frame's points.
    center = coords_data[0, :, 0].permute(1, 2, 0).reshape(-1, 3).mean(dim=0).cpu().numpy()
    print(trajectory_type)

    # Ordered (name, builder) table. Builders are lazy so the depth-scaled
    # radius is only evaluated for the selected trajectory.
    builders = (
        ("surrounding", lambda: generate_surrounding_trajectory(center, n_frames)),
        ("anti-surrounding", lambda: generate_anti_surrounding_trajectory(center, n_frames)),
        ("mix1", lambda: generate_mix1_trajectory(center, n_frames, 1 * abs(center[2]), z_progress=True)),
        ("mix2", lambda: generate_mix2_trajectory(center, n_frames, 0.2 * abs(center[2]), z_progress=True)),
        ("circular", lambda: generate_circular_trajectory(center, n_frames, 1.0 * abs(center[2]))),
        ("static", lambda: generate_static_trajectory(n_frames)),
        ("forward_backward", lambda: generate_forward_backward_trajectory(center, n_frames, 0.4 * abs(center[2]), z_progress=True)),
        ("y_moving", lambda: generate_y_moving_trajectory(center, n_frames, 0.3 * abs(center[2]), z_progress=True)),
        ("x_moving", lambda: generate_x_moving_trajectory(center, n_frames, 0.4 * abs(center[2]), z_progress=True)),
        ("circle_rotating", lambda: generate_circle_rotating_trajectory(center, n_frames, 0.05 * abs(center[2]), z_progress=True)),
        ("camera_rotate", lambda: generate_camera_rotate_trajectory(center, n_frames, rotate_max_degree=30, z_progress=True)),
    )
    for known_name, build in builders:
        if trajectory_type == known_name:
            extrinsics = build()
            break
    else:
        raise ValueError(f"Unknown trajectory type: {trajectory_type}")

    gs_frames, project_frames, project_masks = [], [], []
    frame_count = min(n_frames, coords_data.shape[2])

    for frame_idx in range(frame_count):
        world_points = coords_data[0, :, frame_idx].permute(1, 2, 0).reshape(-1, 3)
        extrinsic = extrinsics[frame_idx].to(device)

        # Gaussian-splatting render; best-effort with a black fallback frame.
        try:
            gs_frames.append(
                render_with_gs(world_points, extrinsic, intrinsic, colors[0], H, W, device, 0.0001)
            )
        except Exception as e:
            print(f"GS rendering failed for frame {frame_idx}: {e}")
            gs_frames.append(np.zeros((H, W, 3), dtype=np.uint8))

        # Point-projection render; fallback is a black frame, fully masked.
        try:
            project_image, project_mask = render_with_project(world_points, extrinsic, intrinsic, colors[0], H, W, device)
            project_frames.append(project_image)
            project_masks.append(project_mask)
        except Exception as e:
            print(f"Project rendering failed for frame {frame_idx}: {e}")
            project_frames.append(np.zeros((H, W, 3), dtype=np.uint8))
            project_masks.append(np.ones((H, W), dtype=bool))

    return gs_frames, project_frames, project_masks


def process_stage1_all_samples(args, dataset, stage1_pipeline, depth_model, decoder_prompt, logger, only_render=False):
    """Run stage 1 for every dataset sample and render all camera trajectories.

    For each sample, the point-cloud sequence is either generated (depth
    estimation, stage-1 pipeline, decoder adaptor, de-normalisation) or, when
    ``only_render`` is True, loaded from previously saved data. It is then
    rendered for every trajectory type and written as mp4 files under
    ``<output_dir>/stage1_render_results/seed_<seed>/``.

    Any exception raised while processing a sample is printed with its
    traceback and that sample is skipped.

    Args:
        args: Namespace providing at least ``video_height``, ``video_width``,
            ``output_dir``, ``seed``, ``video_num_frames``, ``fps``,
            ``guidance_scale``, ``num_inference_steps`` and
            ``normalize_track_z`` (others are read via ``hasattr``/``getattr``).
        dataset: Indexable dataset whose items expose ``"prompt"``,
            ``"image"`` and ``"video_path"``.
        stage1_pipeline: Generation pipeline called once per sample.
        depth_model: Model whose ``infer(...)["depth"]`` provides the depth.
        decoder_prompt: Module applied to the pipeline output.
        logger: Not used in this function body.
        only_render: Skip generation and render from saved point-cloud data.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.bfloat16
    H, W = args.video_height, args.video_width
    
    stage1_dir = os.path.join(args.output_dir, "stage1_render_results", f"seed_{args.seed}")
    os.makedirs(stage1_dir, exist_ok=True)

    trajectory_types = ["mix1", "mix2", "surrounding", "anti-surrounding", "circular", "forward_backward", "y_moving", "x_moving", "circle_rotating", "static","camera_rotate"]

    
    # Pre-create one output folder per (trajectory, renderer) plus a mask folder.
    for traj_type in trajectory_types:
        for render_type in ["gs", "project"]:
            os.makedirs(os.path.join(stage1_dir, f"{traj_type}_{render_type}"), exist_ok=True)
        os.makedirs(os.path.join(stage1_dir, f"{traj_type}_masks"), exist_ok=True)
    
    
    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Stage1 Processing"):
            try:
                if not only_render:
                    # The generator is re-seeded with the same seed for every sample.
                    generator = torch.Generator(device=device).manual_seed(args.seed)
                    sample = dataset[i]
                    prompt = sample["prompt"]
                    image = sample["image"].to(device).to(weight_dtype)
                    video_path = sample["video_path"]  
                    video_name =  video_path.stem
                    parent_folder_name = video_path.parent.name
                    # Prefix with the parent folder name to build the output name.
                    video_name = f"{parent_folder_name}_{video_name}"
                    
                    # Round the frame count down to k * temporal_compression_ratio + 1.
                    video_length = args.video_num_frames
                    video_length = int((video_length - 1) // stage1_pipeline.vae.config.temporal_compression_ratio * stage1_pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
                    latent_frames = (video_length - 1) // stage1_pipeline.vae.config.temporal_compression_ratio + 1
                    
                    if hasattr(args, 'enable_riflex') and args.enable_riflex:
                        stage1_pipeline.transformer.enable_riflex(k=args.riflex_k, L_test=latent_frames)
                    
                    # Map the image from [-1, 1] back to [0, 1].
                    input_image = (image + 1.0) / 2.0  
                    input_image = input_image.squeeze(0)  # [C, H, W]
                    
                    # The first frame is round-tripped through a temporary PNG
                    # because the latent helpers below take a file path. The
                    # f-string always evaluates to "temp_0.png".
                    temp_image_path = os.path.join(args.output_dir, f"temp_{0}.png")
                    temp_image = (input_image.float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(temp_image).save(temp_image_path)
                    
                    ref_image = get_image_latent(temp_image_path, sample_size=[H, W])
                    
                    # NOTE(review): create_control_video_from_image is not defined
                    # in the visible part of this file — confirm where it lives.
                    control_video = create_control_video_from_image(temp_image_path, video_length, [H, W], args.fps)
                    control_video = control_video.repeat(1, 1, video_length, 1, 1)
                    
                    # Uniform mid-grey placeholder image passed as clip_image.
                    clip_image = Image.new("RGB", (W, H), (127, 127, 127))

                    os.remove(temp_image_path)
                    
                    depth_pred = depth_model.infer(image.to(torch.float32))["depth"].to(device)
                    
                    first_frame_coords = back_project_coords(depth_pred.squeeze(), H, W, device)
                    first_frame_coords = first_frame_coords.permute(2, 0, 1).unsqueeze(0).unsqueeze(2)  # [B, C, F, H, W]
                    print("first_frame_coords",first_frame_coords.shape)
                    # Take the z channel and repeat it 3x along the channel axis.
                    depth_pixel_values = first_frame_coords[:,2,:,:].unsqueeze(1).repeat(1,3,1,1,1)
                    # Sanitise depth: clamp, replace inf / near-zero values by 1,
                    # then min-max normalise to [-1, 1].
                    depth_pixel_values = torch.clamp(depth_pixel_values, min=0.0, max=10000.0)
                    depth_pixel_values[depth_pixel_values == float('inf')] = 1
                    depth_pixel_values[depth_pixel_values == float('-inf')] = 1
                    depth_pixel_values[depth_pixel_values < 1e-5] = 1
                    depth_min = torch.min(depth_pixel_values)
                    depth_max = torch.max(depth_pixel_values)
                    depth_pixel_values = 2 * (depth_pixel_values - depth_min) / (depth_max - depth_min + 1e-8) - 1
                    print(depth_min, depth_max)
                    sample_result = stage1_pipeline(
                        prompt, 
                        num_frames=video_length,
                        negative_prompt=args.negative_prompt if hasattr(args, 'negative_prompt') else "",
                        height=H,
                        width=W,
                        generator=generator,
                        guidance_scale=args.guidance_scale,
                        num_inference_steps=args.num_inference_steps,
                        control_video=control_video,
                        control_camera_video=None,
                        ref_image=ref_image,
                        start_image=None,
                        clip_image=clip_image,
                        shift=getattr(args, 'shift', 3.0),
                        output_type="no_normalize",
                        depth_image = depth_pixel_values,
                    ).videos
                    decoder_prompt = decoder_prompt.to(device)
                    recon_video = decoder_prompt(sample_result.to(torch.bfloat16).to(device)).float()  # [B, C, F, H, W]
                    # Save a visualisation of the decoded output, mapped from [-1, 1] to [0, 1].
                    recon_video_vis = (recon_video.squeeze(0) + 1)/2  # [C, F, H, W]
                    recon_video_vis = recon_video_vis.permute(1, 2, 3, 0).cpu().numpy()  # [F, H, W, C]
                    recon_dir = os.path.join(args.output_dir,"recon")
                    os.makedirs(recon_dir, exist_ok=True)

                    imageio.mimwrite(os.path.join(recon_dir, f"{video_name}_recon.mp4"), recon_video_vis, fps=8)

                    if args.normalize_track_z:
                        recon_video = recon_video.permute((0, 2, 1, 3, 4))[0].float()
                        # NOTE(review): `F` here is a local frame count, which makes
                        # `F` a local name throughout this function (it is not used
                        # as torch.nn.functional in this body).
                        F, C, H, W = recon_video.shape
                        
                        # NOTE(review): H and W are rebound to hard-coded 368x512 and
                        # stay rebound for the rendering below and for every later
                        # loop iteration, regardless of args.video_height/width.
                        # The reference size here (720x960) also differs from the
                        # 540x960 used in back_project_coords — confirm intent.
                        H, W = 368, 512
                        H_ori, W_ori = 720, 960
                        
                        if W_ori / W > H_ori / H: 
                            tmp_fx = 1
                            tmp_fy = W_ori / H_ori / (W / H)
                        else:
                            tmp_fy = 1
                            tmp_fx = H_ori / W_ori / (H / W)
                        # frame0 is a view of first_frame_coords, so the in-place
                        # fixes below (NaN / 0 / inf depth -> 1.0) also modify
                        # first_frame_coords.
                        frame0 = first_frame_coords[0,:,0]  # [3, H, W]
                        frame0[2,:,:][torch.isnan(frame0[2,:,:])] = 1.0
                        frame0[2,:,:][frame0[2,:,:]==0] = 1.0
                        frame0[2,:,:][torch.isinf(frame0[2,:,:])] = 1.0
                        
                        # Scale each flow channel by the first-frame depth
                        # (divided by the focal ratio for x and y).
                        current_x_norm = frame0[2,:,:] / tmp_fx
                        current_y_norm = frame0[2,:,:] / tmp_fy
                        recon_video=recon_video.cpu()
                        recon_video[:,0, :, :] = recon_video[:,0, :, :] * current_x_norm.float().cpu().numpy()
                        recon_video[:,1, :, :] = recon_video[:,1, :, :] * current_y_norm.float().cpu().numpy()
                        recon_video[:,2, :, :] = recon_video[:,2, :, :] * frame0[2:3, :, :].float().cpu().numpy()
                        # reconstructions += frame0.permute(1, 2, 0).unsqueeze(0).float().cpu().numpy()
                        # targets += frame0.permute(1, 2, 0).unsqueeze(0).float().cpu().numpy()
                        first_frame_expanded = frame0.unsqueeze(1).unsqueeze(0)
                        # print(first_frame_expanded.shape)
                        # NOTE(review): `.cuda()` is unconditional here although
                        # `device` above may be CPU.
                        recon_flow = (recon_video.cuda().unsqueeze(0).permute((0, 2, 1, 3, 4)) + first_frame_expanded)
                    else:
                        recon_flow, diff = inverse_flow_norm_transform_no_diff(
                            recon_video,
                            first_frame_coords,
                        )
                    
                    # Per-point colours: image mapped to [0, 1], flattened to
                    # one RGB triple per pixel, then quantised to uint8.
                    color = (image + 1) / 2  
                    color = color.reshape(color.shape[0], 3, -1).permute(0, 2, 1)  # [B, H*W, 3]
                    colors = (color * 255).clamp(0, 255).to(torch.uint8)
                    
                    # Frame 0 comes from the back-projected depth; later frames
                    # come from the recovered flow.
                    coords_data = torch.cat([first_frame_coords, recon_flow[:,:,1:]], dim=2)  # [B, C, F, H, W]
                    
                    # NOTE(review): save_pointcloud_data / load_pointcloud_data are
                    # not defined in the visible part of this file.
                    save_pointcloud_data(coords_data,colors,video_name,args.output_dir, args.seed)
                else:
                    # Render-only mode: reuse point-cloud data saved by an earlier run.
                    video_path = dataset[i]["video_path"]
                    video_name =  video_path.stem
                    parent_folder_name = video_path.parent.name
                    video_name = f"{parent_folder_name}_{video_name}"
                    coords_data, colors = load_pointcloud_data(video_name, args.output_dir)
                    video_length = coords_data.shape[2]
                       
                for traj_type in trajectory_types:
                    
                    gs_frames, project_frames, project_masks = render_trajectory(
                        coords_data, colors, traj_type, video_length, H, W, device
                    )
                    
                    # Each write below is best-effort: a failure silently skips
                    # the remaining outputs for this trajectory type.
                    gs_video_path = os.path.join(stage1_dir, f"{traj_type}_gs", f"{video_name}_render.mp4")
                    try:
                        imageio.mimwrite(gs_video_path, gs_frames, fps=8)
                    except Exception as e:
                        continue
                    
                    project_video_path = os.path.join(stage1_dir, f"{traj_type}_project", f"{video_name}_render.mp4")
                    try:
                        imageio.mimwrite(project_video_path, project_frames, fps=8)
                    except Exception as e:
                        continue
                    
                    mask_video_path = os.path.join(stage1_dir, f"{traj_type}_masks", f"{video_name}_mask.mp4")
                    try:
                        # Boolean masks are written as 0/255 greyscale frames.
                        mask_frames_uint8 = [(mask * 255).astype(np.uint8) for mask in project_masks]
                        imageio.mimwrite(mask_video_path, mask_frames_uint8, fps=8)
                    except Exception as e:
                        continue
                
                    
            except Exception as e:
                # Per-sample failure: print the traceback and move on.
                import traceback
                traceback.print_exc()
                continue
    


def cleanup_stage1_models(stage1_pipeline, depth_model, decoder_prompt, logger):
    """Move the Stage 1 models to the CPU and release cached accelerator memory.

    Every argument that is not ``None`` has its ``.to("cpu")`` method called,
    then the garbage collector is run and the CUDA caching allocator is
    emptied.

    Note:
        This function can only drop its *own* references. The objects stay
        alive for as long as the caller still holds them, so the caller has to
        delete its variables as well if host memory should be reclaimed.

    Args:
        stage1_pipeline: Stage 1 pipeline object, or ``None``.
        depth_model: Depth model, or ``None``.
        decoder_prompt: Decoder adaptor, or ``None``.
        logger: Logger used to report completion (needs ``.info``).
    """
    # Move the models to the CPU, in the same order as they are passed in.
    for model in (stage1_pipeline, depth_model, decoder_prompt):
        if model is not None:
            model.to("cpu")

    # Drop the local references before collecting garbage.
    del model, stage1_pipeline, depth_model, decoder_prompt

    gc.collect()
    torch.cuda.empty_cache()

    logger.info("Stage1模型清理完成!")


def process_stage2_all_samples(args, dataset, stage2_pipeline, logger):
    """Run Stage 2 (video completion) for every sample and trajectory type.

    For each dataset sample and each trajectory type, the Stage 1 render
    ``<traj>_gs/<name>_render.mp4`` and mask ``<traj>_masks/<name>_mask.mp4``
    under ``<output_dir>/stage1_render_results/seed_<seed>`` are passed to
    ``stage2_pipeline``; the result is written to
    ``<output_dir>/stage2_completion_results/seed_<seed>/<traj>/<name>.mp4``.
    A trajectory is skipped when its output already exists or when one of its
    Stage 1 inputs is missing.

    A failure while handling a sample is logged and the remaining trajectory
    types of that sample are skipped; processing continues with the next
    sample.

    Args:
        args: Parsed command-line namespace. Reads ``video_height``,
            ``video_width``, ``video_num_frames``, ``output_dir``, ``seed``,
            ``stage2_negative_prompt``, ``stage2_guidance_scale`` and
            ``stage2_num_inference_steps``.
        dataset: Indexable collection whose items provide ``"prompt"`` and
            ``"video_path"`` (used as a ``pathlib.Path``).
        stage2_pipeline: Callable pipeline exposing
            ``vae.config.temporal_compression_ratio`` and returning an object
            with a ``videos`` attribute.
        logger: Logger used to report per-sample failures.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    H, W = args.video_height, args.video_width

    stage1_dir = os.path.join(args.output_dir, "stage1_render_results", f"seed_{args.seed}")
    stage2_dir = os.path.join(args.output_dir, "stage2_completion_results", f"seed_{args.seed}")
    os.makedirs(stage2_dir, exist_ok=True)
    trajectory_types = ["mix1", "mix2", "surrounding", "anti-surrounding", "circular", "forward_backward", "y_moving", "x_moving", "circle_rotating", "static","camera_rotate"]

    for traj_type in trajectory_types:
        os.makedirs(os.path.join(stage2_dir, traj_type), exist_ok=True)

    # Snap the requested frame count down to ``k * ratio + 1`` frames; the value
    # is the same for every sample, so compute it once.
    video_length = args.video_num_frames
    if video_length != 1:
        ratio = stage2_pipeline.vae.config.temporal_compression_ratio
        calc_video_length = int((video_length - 1) // ratio * ratio) + 1
    else:
        calc_video_length = 1

    for i in tqdm(range(len(dataset)), desc="Stage2 Processing"):
        video_name = None
        try:
            sample = dataset[i]
            prompt = sample["prompt"]
            video_path = sample["video_path"]
            video_name = f"{video_path.parent.name}_{video_path.stem}"

            for traj_type in trajectory_types:
                output_video_path = os.path.join(stage2_dir, traj_type, f"{video_name}.mp4")
                if os.path.exists(output_video_path):
                    continue

                gs_video_path = os.path.join(stage1_dir, f"{traj_type}_gs", f"{video_name}_render.mp4")
                mask_video_path = os.path.join(stage1_dir, f"{traj_type}_masks", f"{video_name}_mask.mp4")

                if not (os.path.exists(gs_video_path) and os.path.exists(mask_video_path)):
                    continue

                # A fresh generator per trajectory keeps every run seeded identically.
                generator = torch.Generator(device=device).manual_seed(args.seed + 1)

                with torch.no_grad():
                    input_video, input_video_mask, _, _ = get_video_to_video_latent(
                        gs_video_path,
                        video_length=calc_video_length,
                        sample_size=[H, W],
                        fps=8,
                        validation_video_mask=mask_video_path,
                        ref_image=None,
                        video_mask=True
                    )

                    completed = stage2_pipeline(
                        prompt,
                        num_frames=calc_video_length,
                        negative_prompt=args.stage2_negative_prompt,
                        height=H,
                        width=W,
                        generator=generator,
                        guidance_scale=args.stage2_guidance_scale,
                        num_inference_steps=args.stage2_num_inference_steps,
                        video=input_video,
                        mask_video=input_video_mask,
                        shift=3,
                    ).videos

                    save_videos_grid(completed, output_video_path, fps=8)

        except Exception:
            # Best effort: keep going with the next sample, but leave a trace.
            logger.exception("Stage2 failed for sample %d (%s); skipping", i, video_name)
            continue
    


from safetensors.torch import load_file
def merge_safetensors(model_dir):
    """Merge every ``*.safetensors`` file found in ``model_dir`` into one state dict.

    Files are visited in sorted name order so that the result (and, should two
    files define the same key, which one wins) does not depend on the
    arbitrary ordering returned by ``os.listdir``.

    Args:
        model_dir: Directory holding the shard files. Sub-directories are not
            searched.

    Returns:
        dict: Merged mapping built by updating with each file's contents in
        turn; empty when the directory holds no ``.safetensors`` file.
    """
    state_dict = {}
    for filename in sorted(os.listdir(model_dir)):
        if not filename.endswith(".safetensors"):
            continue
        partial_state_dict = load_file(os.path.join(model_dir, filename))
        state_dict.update(partial_state_dict)
        print(f"Loaded {filename} with {len(partial_state_dict)} keys.")
    return state_dict
def load_stage1_models(args, logger):
    """Build the Stage 1 control pipeline and its auxiliary models.

    Loads the transformer, VAE, tokenizer, text encoder and CLIP image encoder
    from ``args.pretrained_model_path`` (sub-folders come from the YAML config
    at ``args.config_path``), optionally overrides transformer / VAE weights
    from ``args.transformer_path`` / ``args.vae_path``, assembles a
    ``WanFunControlPipeline``, applies the requested GPU memory mode, TeaCache,
    CFG-skip and LoRA options, and finally loads the depth model and the
    optional decoder adaptor.

    When ``args.use_depth`` is truthy and the transformer's patch embedding
    has 48 input channels, it is replaced by a 64-channel ``Conv3d`` whose
    first 48 input channels copy the old weights and whose remaining channels
    are zero.

    Args:
        args: Parsed command-line namespace.
        logger: Logger used for progress messages.

    Returns:
        tuple: ``(pipeline, depth_model, decoder_prompt)``. ``decoder_prompt``
        is ``None`` when ``args.vae_ckpt_dir`` is empty or does not contain
        ``decoder_prompt/pytorch_model.bin``.
    """
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = OmegaConf.load(args.config_path)

    transformer = WanTransformer3DModelDINO.from_pretrained(
        os.path.join(args.pretrained_model_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
        low_cpu_mem_usage=True,
        transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
        torch_dtype=weight_dtype,
    )
    if args.use_depth:
        old_conv = transformer.patch_embedding
        old_w = old_conv.weight.data.clone()   # [out_c, in_c, d,h,w]
        old_b = old_conv.bias.data.clone() if old_conv.bias is not None else None

        if old_w.shape[1] == 48:
            from torch import nn

            out_c = old_w.shape[0]
            # Use dedicated names here: rebinding ``device`` would redirect
            # every later ``.to(device)`` call to the weight's device.
            conv_device = old_w.device
            conv_dtype = old_w.dtype
            new_conv = nn.Conv3d(
                in_channels=64,
                out_channels=out_c,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                bias=(old_b is not None),
                dtype=weight_dtype,
                device=conv_device
            )

            # Extra input channels start at zero, so they contribute nothing
            # until weights for them are loaded.
            new_w = torch.zeros((out_c, 64, *old_w.shape[2:]), device=conv_device, dtype=conv_dtype)
            new_w[:, :48, :, :, :] = old_w

            new_conv.weight.data.copy_(new_w)
            if old_b is not None:
                new_conv.bias.data.copy_(old_b)

            new_conv.to(device=conv_device, dtype=conv_dtype)

            transformer.patch_embedding = new_conv

    if hasattr(args, 'transformer_path') and args.transformer_path is not None:
        if os.path.isdir(args.transformer_path):
            state_dict = merge_safetensors(args.transformer_path)
        elif args.transformer_path.endswith("safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(args.transformer_path)
        else:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(args.transformer_path, map_location="cpu")
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        m, u = transformer.load_state_dict(state_dict, strict=False)
        logger.info(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

    vae = AutoencoderKLWan.from_pretrained(
        os.path.join(args.pretrained_model_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
        additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
    ).to(weight_dtype)

    if hasattr(args, 'vae_path') and args.vae_path is not None:
        if args.vae_path.endswith("safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(args.vae_path)
        else:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(args.vae_path, map_location="cpu")
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        m, u = vae.load_state_dict(state_dict, strict=False)
        logger.info(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

    tokenizer = AutoTokenizer.from_pretrained(
        os.path.join(args.pretrained_model_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
    )
    text_encoder = WanT5EncoderModel.from_pretrained(
        os.path.join(args.pretrained_model_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
        additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
        low_cpu_mem_usage=True,
        torch_dtype=weight_dtype,
    )
    text_encoder = text_encoder.eval()

    clip_image_encoder = CLIPModel.from_pretrained(
        os.path.join(args.pretrained_model_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
    ).to(weight_dtype)
    clip_image_encoder = clip_image_encoder.eval()

    scheduler_dict = {
        "Flow": FlowMatchEulerDiscreteScheduler,
        "Flow_Unipc": FlowUniPCMultistepScheduler,
        "Flow_DPM++": FlowDPMSolverMultistepScheduler,
    }
    sampler_name = getattr(args, 'sampler_name', 'Flow')
    Choosen_Scheduler = scheduler_dict[sampler_name]

    # For these two samplers the configured scheduler shift is forced to 1.
    if sampler_name == "Flow_Unipc" or sampler_name == "Flow_DPM++":
        config['scheduler_kwargs']['shift'] = 1

    scheduler = Choosen_Scheduler(
        **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
    )
    pipeline = WanFunControlPipeline(
        transformer=transformer,
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        scheduler=scheduler,
        clip_image_encoder=clip_image_encoder
    )
    if hasattr(args, 'compile_dit') and args.compile_dit:
        for i in range(len(pipeline.transformer.blocks)):
            pipeline.transformer.blocks[i] = torch.compile(pipeline.transformer.blocks[i])
        logger.info("DIT编译完成")

    gpu_memory_mode = getattr(args, 'gpu_memory_mode', 'model_full_load')
    if gpu_memory_mode == "sequential_cpu_offload":
        from MoRe4D.utils.fp8_optimization import replace_parameters_by_name
        replace_parameters_by_name(transformer, ["modulation",], device=device)
        transformer.freqs = transformer.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif gpu_memory_mode == "model_cpu_offload_and_qfloat8":
        from MoRe4D.utils.fp8_optimization import convert_model_weight_to_float8, convert_weight_dtype_wrapper
        convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
        convert_weight_dtype_wrapper(transformer, weight_dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif gpu_memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    elif gpu_memory_mode == "model_full_load_and_qfloat8":
        from MoRe4D.utils.fp8_optimization import convert_model_weight_to_float8, convert_weight_dtype_wrapper
        convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
        convert_weight_dtype_wrapper(transformer, weight_dtype)
        pipeline.to(device=device)
    else:
        pipeline.to(device=device)

    if hasattr(args, 'enable_teacache') and args.enable_teacache:
        coefficients = get_teacache_coefficients(args.pretrained_model_path)
        if coefficients is not None:
            teacache_threshold = getattr(args, 'teacache_threshold', 0.10)
            num_skip_start_steps = getattr(args, 'num_skip_start_steps', 5)
            teacache_offload = getattr(args, 'teacache_offload', False)
            pipeline.transformer.enable_teacache(
                coefficients, args.num_inference_steps, teacache_threshold,
                num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
            )

    if hasattr(args, 'cfg_skip_ratio') and args.cfg_skip_ratio > 0:
        pipeline.transformer.enable_cfg_skip(args.cfg_skip_ratio, args.num_inference_steps)

    if args.lora_path is not None:
        pipeline = merge_lora(pipeline, args.lora_path, args.lora_weight, device=device)

    # "xxx" is a placeholder checkpoint identifier left in the source.
    depth_model = UniDepthV2old.from_pretrained("xxx")
    depth_model.to(device)
    depth_model.eval()

    decoder_prompt = None
    if args.vae_ckpt_dir:
        decoder_ckpt = os.path.join(args.vae_ckpt_dir, "decoder_prompt", "pytorch_model.bin")
        if os.path.exists(decoder_ckpt):
            decoder_prompt = VAEDecoderadaptor().to(device).to(weight_dtype)
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            decoder_state_dict = torch.load(decoder_ckpt, map_location="cpu")
            decoder_prompt.load_state_dict(decoder_state_dict, strict=True)
            decoder_prompt.eval()
            del decoder_state_dict

    return pipeline, depth_model, decoder_prompt


def load_stage2_models(args, logger):
    """Assemble the Stage 2 inpainting pipeline in bfloat16.

    Components are loaded from ``args.stage2_model_path`` using the sub-folder
    names given in the YAML config at ``args.stage2_config_path``. The
    pipeline is then placed according to ``args.gpu_memory_mode`` (only the
    two plain CPU-offload modes are handled specially; every other mode loads
    the whole pipeline onto the device) and, when ``args.stage2_lora_path`` is
    set, the LoRA weights are merged in.

    Args:
        args: Parsed command-line namespace.
        logger: Logger instance (accepted for signature symmetry; not used).

    Returns:
        The ready-to-use ``WanFunInpaintPipeline``.
    """
    dtype = torch.bfloat16
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_root = args.stage2_model_path
    cfg = OmegaConf.load(args.stage2_config_path)

    def component_dir(section, key, fallback):
        # Resolve a component folder below the model root from the config.
        return os.path.join(model_root, cfg[section].get(key, fallback))

    transformer = WanTransformer3DModel.from_pretrained(
        component_dir('transformer_additional_kwargs', 'transformer_subpath', 'transformer'),
        transformer_additional_kwargs=OmegaConf.to_container(cfg['transformer_additional_kwargs']),
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
    )

    vae = AutoencoderKLWan.from_pretrained(
        component_dir('vae_kwargs', 'vae_subpath', 'vae'),
        additional_kwargs=OmegaConf.to_container(cfg['vae_kwargs']),
    ).to(dtype)

    tokenizer = AutoTokenizer.from_pretrained(
        component_dir('text_encoder_kwargs', 'tokenizer_subpath', 'tokenizer'),
    )

    text_encoder = WanT5EncoderModel.from_pretrained(
        component_dir('text_encoder_kwargs', 'text_encoder_subpath', 'text_encoder'),
        additional_kwargs=OmegaConf.to_container(cfg['text_encoder_kwargs']),
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
    ).eval()

    clip_image_encoder = CLIPModel.from_pretrained(
        component_dir('image_encoder_kwargs', 'image_encoder_subpath', 'image_encoder'),
    ).to(dtype).eval()

    scheduler_kwargs = OmegaConf.to_container(cfg['scheduler_kwargs'])
    scheduler = FlowMatchEulerDiscreteScheduler(
        **filter_kwargs(FlowMatchEulerDiscreteScheduler, scheduler_kwargs)
    )

    pipeline = WanFunInpaintPipeline(
        transformer=transformer,
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        scheduler=scheduler,
        clip_image_encoder=clip_image_encoder,
    )

    memory_mode = args.gpu_memory_mode
    if memory_mode == "sequential_cpu_offload":
        from MoRe4D.utils.fp8_optimization import replace_parameters_by_name
        replace_parameters_by_name(transformer, ["modulation",], device=target_device)
        transformer.freqs = transformer.freqs.to(device=target_device)
        pipeline.enable_sequential_cpu_offload(device=target_device)
    elif memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=target_device)
    else:
        pipeline.to(device=target_device)

    lora_path = args.stage2_lora_path
    if lora_path is not None:
        pipeline = merge_lora(pipeline, lora_path, args.stage2_lora_weight, device=target_device)

    return pipeline


def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``false`` or ``1``/``0`` to ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the two-stage inference script.

    ``--use_depth`` and ``--run_stage2_render`` default to ``True`` and accept
    an optional boolean value, so they can be switched off with e.g.
    ``--use_depth false``; passing the bare flag keeps them ``True``.

    Returns:
        argparse.Namespace: The parsed options (read from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description="")

    # Stage 1 options
    parser.add_argument("--pretrained_model_path", type=str,
                        default="models/Wan2.1-Fun-V1.1-14B-Control",
                        help="")
    # Without a type the value stayed a (truthy) string, so "False" enabled it.
    parser.add_argument("--use_depth", type=_str_to_bool, nargs="?",
                        const=True, default=True,
                        help="")
    parser.add_argument("--only_render", action='store_true',
                        help="")
    parser.add_argument("--vae_ckpt_dir", type=str,
                        default="xxx",
                        help="")
    parser.add_argument("--lora_path", type=str,
                        default=None,
                        help="")
    parser.add_argument("--transformer_path", type=str,
                        default=None,
                        help="")
    parser.add_argument("--vae_path", type=str,
                        default=None,
                        help="")
    parser.add_argument("--config_path", type=str,
                        default="config/wan2.1/wan_civitai.yaml",
                        help="")
    parser.add_argument("--normalize_track_z", action='store_true',
                        help="")
    parser.add_argument("--sampler_name", type=str, default="Flow",
                        choices=["Flow", "Flow_Unipc", "Flow_DPM++"],
                        help="")
    parser.add_argument("--shift", type=float, default=3.0,
                        help="")
    parser.add_argument("--enable_teacache", action='store_true', default=False,
                        help="")
    parser.add_argument("--teacache_threshold", type=float, default=0.10,
                        help="")
    parser.add_argument("--num_skip_start_steps", type=int, default=5,
                        help="")
    parser.add_argument("--teacache_offload", action='store_true',
                        help="")
    parser.add_argument("--cfg_skip_ratio", type=float, default=0.0,
                        help="")
    parser.add_argument("--compile_dit", action='store_true',
                        help="")
    parser.add_argument("--enable_riflex", action='store_true',
                        help="")
    parser.add_argument("--riflex_k", type=int, default=6,
                        help="")
    parser.add_argument("--negative_prompt", type=str,
                        default="",
                        help="")
    parser.add_argument("--gpu_memory_mode", type=str, default="model_full_load",
                        choices=["model_full_load", "model_full_load_and_qfloat8", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                        help="")
    parser.add_argument("--fps", type=int, default=8,
                        help="")

    # Stage 2 options
    parser.add_argument("--stage2_model_path", type=str,
                        default="models/Wan2.1-Fun-V1.1-14B-InP",
                        help="")
    parser.add_argument("--stage2_config_path", type=str,
                        default="config/wan2.1/wan_civitai.yaml",
                        help="")
    parser.add_argument("--stage2_lora_path", type=str,
                        default="xxx.safetensors",
                        help="")
    parser.add_argument("--stage2_lora_weight", type=float, default=0.55,
                        help="")
    parser.add_argument("--stage2_negative_prompt", type=str,
                        default="",
                        help="")
    parser.add_argument("--stage2_guidance_scale", type=float, default=6.0,
                        help="")
    parser.add_argument("--stage2_num_inference_steps", type=int, default=50,
                        help="")

    # General options
    parser.add_argument("--data_path", type=str,
                        default="/xxx",
                        help="")
    parser.add_argument("--prompt_file_name", type=str,
                        default="prompts_demo.txt",
                        help="")
    parser.add_argument("--video_file_name", type=str,
                        default="videos_demo.txt",
                        help="")
    parser.add_argument("--output_dir", type=str,
                        default="output_dir/infer",
                        help="")
    parser.add_argument("--video_height", type=int, default=368,
                        help="")
    parser.add_argument("--video_width", type=int, default=512,
                        help="")
    parser.add_argument("--video_num_frames", type=int, default=49,
                        help="")
    parser.add_argument("--mixed_precision", type=str, default="bf16",
                        choices=["no", "fp16", "bf16"],
                        help="")
    parser.add_argument("--lora_weight", type=float, default=0.55,
                        help="")
    parser.add_argument("--guidance_scale", type=float, default=6.0,
                        help="")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="")
    parser.add_argument("--seed", type=int, default=43,
                        help="")
    parser.add_argument("--original_clip", action='store_true',
                        help="")
    parser.add_argument("--all_mask", action='store_true',
                        help="")
    parser.add_argument("--run_stage1", action='store_true', default=False,
                        help="")
    # ``store_true`` with ``default=True`` could never be turned off.
    parser.add_argument("--run_stage2_render", type=_str_to_bool, nargs="?",
                        const=True, default=True,
                        help="")
    parser.add_argument("--run_stage2_complete", action='store_true', default=False,
                        help="")
    parser.add_argument("--max_samples", type=int, default=800,
                        help="")

    args = parser.parse_args()
    return args


def create_control_video_from_image(image_path, video_length, sample_size, fps=16):
    """Return the video produced by ``get_video_to_video_latent`` for ``image_path``.

    The helper is called without a reference image; of its four results only
    the first one is returned, the others are discarded.

    Args:
        image_path: Path forwarded to ``get_video_to_video_latent``.
        video_length: Forwarded as ``video_length``.
        sample_size: Forwarded as ``sample_size``.
        fps: Forwarded as ``fps``.

    Returns:
        The first element of the helper's result.
    """
    latent_outputs = get_video_to_video_latent(
        image_path,
        video_length=video_length,
        sample_size=sample_size,
        fps=fps,
        ref_image=None,
    )
    control_video, _unused_mask, _, _ = latent_outputs
    return control_video
def load_pointcloud_data(video_name, input_dir, device='cpu', seed=None, height=368, width=512):
    """Load per-frame point cloud text files back into dense tensors.

    Each file ``<video_name>_frame_<idx>.txt`` holds one row per pixel with
    xyz coordinates in columns 0-2 and colour in columns 3-5, in row-major
    ``(height, width)`` order.

    Directory lookup:
        * ``seed`` given: files are read from ``<input_dir>/pts/seed_<seed>``
          (the layout written by ``save_pointcloud_data``).
        * ``seed`` omitted: ``<input_dir>/pts`` is searched first; when it
          holds no matching file, its ``seed_*`` sub-directories are searched
          and used if exactly one of them contains the video.

    Args:
        video_name: File name prefix of the video.
        input_dir: Root directory containing the ``pts`` folder.
        device: Device on which the returned tensors are allocated.
        seed: Optional seed selecting the ``seed_<seed>`` sub-directory.
        height: Frame height in pixels (defaults to the script default 368).
        width: Frame width in pixels (defaults to the script default 512).

    Returns:
        tuple: ``(recon_flow, colors)`` with shapes
        ``[1, 3, num_frames, height, width]`` and ``[1, 3, height, width]``;
        ``colors`` comes from the first frame only.

    Raises:
        ValueError: If no file is found, the video exists under several seed
            directories while ``seed`` is omitted, or a file does not contain
            ``height * width`` rows with at least six columns.
    """
    pts_root = os.path.join(input_dir, "pts")
    prefix = f"{video_name}_frame_"

    def list_frames(directory):
        # Sorted so that the zero-padded frame index defines the frame order.
        return sorted(
            name for name in os.listdir(directory)
            if name.startswith(prefix) and name.endswith(".txt")
        )

    if seed is not None:
        pts_dir = os.path.join(pts_root, f"seed_{seed}")
        frame_files = list_frames(pts_dir)
    else:
        pts_dir = pts_root
        frame_files = list_frames(pts_dir)
        if not frame_files:
            # Nothing in the flat layout: look into the per-seed folders.
            hits = []
            for entry in sorted(os.listdir(pts_root)):
                candidate = os.path.join(pts_root, entry)
                if entry.startswith("seed_") and os.path.isdir(candidate):
                    found = list_frames(candidate)
                    if found:
                        hits.append((candidate, found))
            if len(hits) > 1:
                raise ValueError(
                    f"Point cloud files for video {video_name} exist under several "
                    f"seed directories in {pts_root}; pass `seed` explicitly"
                )
            if hits:
                pts_dir, frame_files = hits[0]

    if not frame_files:
        raise ValueError(f"No point cloud files found for video {video_name} in directory {pts_dir}")

    num_frames = len(frame_files)
    expected_points = height * width

    recon_flow = torch.zeros(1, 3, num_frames, height, width, device=device)
    colors = torch.zeros(1, 3, height, width, device=device)

    for frame_idx, frame_file in enumerate(frame_files):
        frame_path = os.path.join(pts_dir, frame_file)
        frame_data = np.loadtxt(frame_path, ndmin=2)

        if frame_data.shape[0] != expected_points or frame_data.shape[1] < 6:
            raise ValueError(
                f"{frame_path} has shape {frame_data.shape}, expected "
                f"({expected_points}, >=6) for a {height}x{width} frame"
            )

        coords = torch.from_numpy(frame_data[:, :3].reshape(height, width, 3)).permute(2, 0, 1)
        recon_flow[0, :, frame_idx, :, :] = coords.to(device)

        # Colours are taken from the first frame only.
        if frame_idx == 0:
            rgb = torch.from_numpy(frame_data[:, 3:6].reshape(height, width, 3)).permute(2, 0, 1)
            colors[0, :, :, :] = rgb.to(device)

    return recon_flow, colors


def save_pointcloud_data(recon_flow, colors, video_name, output_dir, seed):
    """Write one text file per frame with xyz coordinates and colours per point.

    Files are written to ``<output_dir>/pts/seed_<seed>/`` and named
    ``<video_name>_frame_<idx:04d>.txt``. Every row holds the three
    coordinates of one pixel followed by three colour values; the same colour
    block is used for every frame.

    Args:
        recon_flow: Tensor indexed as ``[batch, 3, frame, height, width]``;
            only batch element 0 is saved.
        colors: Colour tensor; only batch element 0 is used. A channel-first
            ``[3, height, width]`` layout is converted to channel-last before
            being flattened to rows of three values; any other layout is
            flattened as is.
        video_name: Prefix of the written file names.
        output_dir: Root output directory.
        seed: Seed used to name the ``seed_<seed>`` sub-directory.
    """
    print("saving pointcloud data...")
    pts_dir = os.path.join(output_dir, "pts", f"seed_{seed}")
    print(pts_dir)
    os.makedirs(pts_dir, exist_ok=True)

    _, _, num_frames, height, width = recon_flow.shape

    color_grid = colors[0]
    grid_shape = tuple(color_grid.shape)
    # Channel-first colours must become channel-last first, otherwise each row
    # would mix values of different pixels. A 3x3x3 grid is ambiguous and is
    # left untouched.
    if grid_shape == (3, height, width) and grid_shape != (height, width, 3):
        color_grid = color_grid.permute(1, 2, 0)
    frame_colors = color_grid.reshape(-1, 3).cpu().float()  # [H*W, 3]

    for frame_idx in range(num_frames):
        frame_coords = recon_flow[0, :, frame_idx].permute(1, 2, 0).reshape(-1, 3)  # [H*W, 3]

        pointcloud_data = torch.cat([frame_coords.cpu().float(), frame_colors], dim=1)

        pts_file = os.path.join(pts_dir, f"{video_name}_frame_{frame_idx:04d}.txt")
        np.savetxt(pts_file, pointcloud_data.numpy())
        
def main():
    """Entry point of the two-stage inference script.

    Parses the command line, builds the dataset, then runs Stage 1 when
    ``--run_stage1`` is set and Stage 2 completion when
    ``--run_stage2_complete`` is set. The models of each stage are released
    before moving on.
    """
    args = parse_args()
    logger = get_logger("two_stage_pipeline")

    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = TwoStageDataset(
        data_root=args.data_path,
        caption_column=args.prompt_file_name,
        video_column=args.video_file_name,
        device=device,
        max_num_frames=args.video_num_frames,
        height=args.video_height,
        width=args.video_width,
        max_samples=args.max_samples,
    )

    if args.run_stage1:
        logger.info("=" * 50)
        logger.info("Stage1:")
        logger.info("=" * 50)

        stage1_pipeline, depth_model, decoder_prompt = load_stage1_models(args, logger)

        process_stage1_all_samples(args, dataset, stage1_pipeline, depth_model, decoder_prompt, logger, only_render=args.only_render)

        cleanup_stage1_models(stage1_pipeline, depth_model, decoder_prompt, logger)

        # The cleanup helper can only delete its own local names; drop the
        # references held here too so the models can actually be freed before
        # Stage 2 loads its own.
        del stage1_pipeline, depth_model, decoder_prompt
        gc.collect()

        logger.info("Stage1 Done!")

    if args.run_stage2_complete:
        logger.info("=" * 50)
        logger.info("Stage2:")
        logger.info("=" * 50)

        stage2_pipeline = load_stage2_models(args, logger)

        process_stage2_all_samples(args, dataset, stage2_pipeline, logger)

        if args.stage2_lora_path is not None:
            stage2_pipeline = unmerge_lora(stage2_pipeline, args.stage2_lora_path, args.stage2_lora_weight, device=device)

        stage2_pipeline.to("cpu")
        del stage2_pipeline
        gc.collect()
        torch.cuda.empty_cache()

    logger.info("=" * 50)
    logger.info(f"Pipeline done! Save in {args.output_dir}")
    logger.info("=" * 50)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()