from model.dual_unet_pipeline import DualUnetVideoDiffusionPipeline
from diffusers.utils import load_image
from utils import save_video, read_video_to_array, get_name
import torch
import os
import argparse


def create_dual_unet_pipeline(
    svd_pretrained_path: str,
    track_unet_path: str,
    torch_dtype: torch.dtype = torch.float32,
):
    """
    Load a DualUnetVideoDiffusionPipeline and switch on its memory savers.

    The pipeline is built with ``from_pretrained_with_two_unets``; afterwards
    attention slicing, xformers memory-efficient attention and model CPU
    offload are enabled, in that order.

    Args:
        svd_pretrained_path: Path to SVD pretrained model
        track_unet_path: Path to trained track unet
        torch_dtype: Data type forwarded to the loader

    Returns:
        The configured DualUnetVideoDiffusionPipeline instance
    """
    pipe = DualUnetVideoDiffusionPipeline.from_pretrained_with_two_unets(
        pretrained_model_name_or_path=svd_pretrained_path,
        track_unet_path=track_unet_path,
        torch_dtype=torch_dtype,
    )

    # Invoke the toggles one after another; the sequence matches the
    # original call order and must not be rearranged.
    memory_toggles = (
        "enable_attention_slicing",
        "enable_xformers_memory_efficient_attention",
        "enable_model_cpu_offload",
    )
    for toggle_name in memory_toggles:
        getattr(pipe, toggle_name)()

    return pipe


def run(ref_video_path, image_path, *, pipe=None, output_dir=None, fps=7, seed=42):
    """
    Run one inference from a reference video and a condition image and save it.

    The reference video is read with ``read_video_to_array`` and the image
    with ``load_image``. The pipeline is asked for as many frames as the
    first dimension of the reference video array, and the first entry of the
    returned ``frames`` is written to
    ``{output_dir}/{ref_name}->{tar_name}.mp4``.

    Args:
        ref_video_path: Path to the reference video
        image_path: Path to the condition image
        pipe: Pipeline object to call. When ``None`` the module-level
            ``pipeline`` (created in the ``__main__`` block) is used.
        output_dir: Directory for the result, created if missing. When
            ``None`` the module-level ``save_dir`` is used.
        fps: Value passed as ``fps`` both to the pipeline and to ``save_video``
        seed: Seed given to ``torch.manual_seed`` for the generator

    Raises:
        NameError: If ``pipe`` / ``output_dir`` is omitted and the matching
            module-level global has not been defined.
    """
    import time

    # Resolve the collaborators up front: existing two-argument callers keep
    # working through the script-level globals, and a missing global fails
    # here rather than after the (expensive) inference has already run.
    if pipe is None:
        pipe = pipeline
    if output_dir is None:
        output_dir = save_dir

    print("🎬 Load Data...")

    ref_video = read_video_to_array(ref_video_path)
    condition_image = load_image(image_path)

    print(f"📹 Condition Video: {ref_video.shape}")
    print(f"🖼️  Condition Image: {condition_image.size}")

    # Generate exactly as many frames as the reference video provides.
    infer_num_frames = ref_video.shape[0]

    print("🚀 Start Inference...")
    generator = torch.manual_seed(seed)
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long inference.
    start_time = time.perf_counter()

    with torch.no_grad():
        result = pipe(
            condition_video=ref_video,
            condition_image=condition_image,
            height=384,
            width=512,
            num_frames=infer_num_frames,
            num_inference_steps=25,
            fps=fps,
            motion_bucket_id=127,
            noise_aug_strength=0.02,
            decode_chunk_size=8,
            generator=generator,
            # max_guidance_scale=1.0,
        ).frames[0]

    end_time = time.perf_counter()
    print(f"⏱️  Inference Time: {end_time - start_time:.2f}s")

    ref_name = get_name(ref_video_path)
    tar_name = get_name(image_path)

    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): ">" is not a legal filename character on Windows; the
    # name is kept unchanged so existing outputs/tools stay compatible.
    output_path = f"{output_dir}/{ref_name}->{tar_name}.mp4"
    save_video(result, output_path, fps=fps)
    print(f"✅ Inference Completed! Video saved to: {output_path}")



if __name__ == "__main__":
    # Script entry point: parse CLI options, build the pipeline, run once.
    parser = argparse.ArgumentParser(description="Run DualUnetVideoDiffusionPipeline inference")
    parser.add_argument(
        "--ref_video_path",
        type=str,
        required=True,
        help="Path to reference video (.mp4)",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to condition image",
    )
    parser.add_argument(
        "--svd_pretrained_path",
        type=str,
        default="./checkpoints/stabilityai--stable-video-diffusion-img2vid-xt",
        help="Path to SVD pretrained model",
    )
    parser.add_argument(
        "--track_unet_path",
        type=str,
        default="./checkpoints/motion_perception/unet",
        help="Path to trained track unet",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./results",
        help="Directory to save results",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["fp32", "fp16"],
        default="fp16",
        help="Torch dtype",
    )
    args = parser.parse_args()

    # argparse restricts --dtype to these two keys, so the lookup cannot miss.
    infer_torch_dtype = {"fp32": torch.float32, "fp16": torch.float16}[args.dtype]

    # `pipeline` and `save_dir` are read by run() as module-level names, so
    # they must keep exactly these names.
    pipeline = create_dual_unet_pipeline(
        svd_pretrained_path=args.svd_pretrained_path,
        track_unet_path=args.track_unet_path,
        torch_dtype=infer_torch_dtype,
    )
    save_dir = args.save_dir

    run(args.ref_video_path, args.image_path)
