import torch
import argparse
import os
from datetime import datetime
import glob
import yaml
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from models_wan22.pipeline_compattn import WanCompAttnPipeline
from models_wan22.transformer_compattn import WanTransformerCompAttn3DModel

def ddp_setup():
    """Read the launcher's rank information and bind this process to a GPU.

    Returns:
        tuple: ``(rank, world_size)`` where ``rank`` is read from the
        ``LOCAL_RANK`` environment variable (default 0) and ``world_size``
        from ``WORLD_SIZE`` (default 1).

    Note:
        ``rank`` is the value of ``LOCAL_RANK``; it is also used as the CUDA
        device index for this process.
    """
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # Select the device only when CUDA is present, so calling this helper on a
    # CPU-only host does not raise before the caller has a chance to react.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    return rank, world_size


if __name__ == "__main__":
    # Script entry point: load the pipeline once, then render one video per
    # YAML file found under ./input, sharding the files across ranks.

    # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    # NOTE(review): the checkpoint actually loaded below is the hard-coded
    # Wan2.2-T2V-A14B-Diffusers path. --model_folder is parsed but unused,
    # --model_type is only used to name the output folder, and --yaml_path is
    # parsed but unused (inputs come from ./input/*.yaml). Confirm intent before
    # wiring them in.
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using Wanx")
    parser.add_argument("--model_folder", type=str, default="/Your_path_to/pretrained_models/", help="The parent directory of your weight folder.")
    parser.add_argument("--model_type", type=str, default="Wan2.1-T2V-1.3B-Diffusers", help="The Baseline model type")
    parser.add_argument("--yaml_path", type=str, default="./example/1.yaml", help="The directory of the story")
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--model_config", type=str, default="./configs/CompAttn.yaml", help="The model for this experiment")

    # Parse first so that --help and argument errors exit before any CUDA work.
    args = parser.parse_args()

    # Obtain rank / world_size
    rank, world_size = ddp_setup()
    # `rank` is the LOCAL rank (GPU index on this node) while `world_size` is
    # the GLOBAL process count. Sharding must compare against the GLOBAL rank,
    # otherwise multi-node runs skip some inputs and duplicate others. Launchers
    # such as torchrun export RANK; fall back to the local rank when it is absent.
    global_rank = int(os.environ.get("RANK", rank))

    model_type = args.model_type

    model_config_path = args.model_config
    model_name = os.path.splitext(os.path.basename(model_config_path))[0]
    current_time = datetime.now().strftime("%Y%m%d-%H")
    output_video_folder = f"./results/{model_type}/{model_name}_{current_time}"

    # Single definition of the checkpoint location used by every loader below.
    pretrained_path = "/Your_path_to/pretrained_models/Wan2.2-T2V-A14B-Diffusers"

    # The VAE is loaded in float32; both transformers and the pipeline in bfloat16.
    vae = AutoencoderKLWan.from_pretrained(pretrained_path, subfolder="vae", torch_dtype=torch.float32)
    transformer = WanTransformerCompAttn3DModel.from_pretrained(pretrained_path, torch_dtype=torch.bfloat16, subfolder='transformer')
    transformer_2 = WanTransformerCompAttn3DModel.from_pretrained(pretrained_path, torch_dtype=torch.bfloat16, subfolder='transformer_2')
    pipe = WanCompAttnPipeline.from_pretrained(pretrained_path, vae=vae, torch_dtype=torch.bfloat16)
    # Swap in the CompAttn transformer variants loaded above.
    pipe.transformer = transformer
    pipe.transformer_2 = transformer_2
    pipe.enable_model_cpu_offload(gpu_id=rank)

    negative_prompt = "色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走"

    # Every rank creates its own output folder: there is no barrier between
    # ranks, and the hour-stamped folder name is computed per process, so a
    # non-master rank cannot rely on rank 0 having created the same path.
    os.makedirs(output_video_folder, exist_ok=True)

    # glob returns files in arbitrary order; sort so every rank sees the same
    # index -> file mapping and the modulo sharding below is consistent.
    yaml_list = sorted(glob.glob("./input/*.yaml"))
    for idx, yaml_path in enumerate(yaml_list):
        # Round-robin sharding: each process handles every world_size-th file.
        if idx % world_size != global_rank:
            continue

        # Explicit UTF-8: captions may contain non-ASCII text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(yaml_path, 'r', encoding="utf-8") as file:
            configs = yaml.safe_load(file)
        yaml_name = os.path.splitext(os.path.basename(yaml_path))[0]
        output_path = os.path.join(output_video_folder, f"{yaml_name}_{args.seed}.mp4")

        prompt = configs['args']["video_caption"]
        subjects = configs['args']["protagonist"]
        frames_layout = configs['args']["frames"]

        # Reloaded for every video so each pipeline call receives a fresh dict
        # (the pipeline may mutate it — not verifiable from this file).
        with open(model_config_path, 'r', encoding="utf-8") as file:
            model_configs = yaml.safe_load(file)

        output, attn_weights_all = pipe(
            prompt=prompt,
            subject_list=subjects,
            frames_layout=frames_layout,
            model_configs=model_configs,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=81,
            guidance_scale=4.0,
            guidance_scale_2=3.0,
            num_inference_steps=40,
            generator=torch.Generator().manual_seed(args.seed),  # Set the seed for reproducibility
        )

        # The attention weights are not used here; drop the reference and
        # release cached GPU memory before exporting / starting the next video.
        del attn_weights_all

        torch.cuda.empty_cache()

        output = output.frames[0]

        export_to_video(output, output_path, fps=16)