"""
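Demo: batch video generation with EnvCapWanPipeline.

Input: a YAML model config (--model_config) and the prompt file ./all_prompts_out.json.
Output: one video per prompt entry, named after the prompt key and written to
./StoryEval_result/<model_type>/<yaml name>_<YYYYmmdd-HH>/videos/.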
"""
import torch
import argparse
import ast
import os
import glob
import json
import yaml
from datetime import datetime
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from models_2_1.pipeline_envcap import EnvCapWanPipeline
from models_2_1.transformer_envcap import WanEnvCapTransformer3DModel
from utils.visualization import visual_attention

def ddp_setup():
    """Read the process layout from the environment and select the CUDA device.

    ``LOCAL_RANK`` (default 0) and ``WORLD_SIZE`` (default 1) are read from the
    environment, and the current CUDA device is set to the local rank.

    Returns:
        tuple[int, int]: ``(rank, world_size)``.
    """
    env = os.environ
    local_rank = int(env["LOCAL_RANK"]) if "LOCAL_RANK" in env else 0
    num_procs = int(env["WORLD_SIZE"]) if "WORLD_SIZE" in env else 1
    # Select this process's GPU so later CUDA work defaults to it.
    torch.cuda.set_device(local_rank)
    return local_rank, num_procs

if __name__ == "__main__":
    # Local import: only this entry point needs it (source snapshot copy below).
    import subprocess

    # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    supported_models = ("Wan2.1-T2V-1.3B-Diffusers", "Wan2.1-T2V-14B-Diffusers")

    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using Wanx")
    parser.add_argument("--model_type", type=str, default="Wan2.1-T2V-14B-Diffusers", help="The Baseline model type")
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--model_config", type=str, default="./configs/TS-Attn.yaml", help="The model for this experiment")
    # NOTE(review): the two attention options are parsed but not used anywhere
    # in this script; they are kept so existing command lines keep working.
    parser.add_argument("--attn_timestep", type=int, default=5, help="The timestep for attention visualization")
    parser.add_argument("--attn_layer", type=int, default=10, help="The layer index for attention visualization")

    args = parser.parse_args()

    model_type = args.model_type
    # Fail fast on an unknown model instead of hitting a NameError on `pipe`
    # after the output folders have already been created.
    if model_type not in supported_models:
        parser.error(
            f"unsupported --model_type {model_type!r}; expected one of: {', '.join(supported_models)}"
        )

    # Obtain rank / world_size
    rank, world_size = ddp_setup()
    is_master = (rank == 0)

    device = torch.device(f"cuda:{rank}")

    # Both supported checkpoints are loaded and configured identically, so the
    # setup is shared rather than duplicated per model type.
    model_id = os.path.join("/xxx/", model_type)
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
    scheduler = UniPCMultistepScheduler(
        prediction_type='flow_prediction',
        use_flow_sigmas=True,
        num_train_timesteps=1000,
        flow_shift=flow_shift,
    )
    transformer = WanEnvCapTransformer3DModel.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, subfolder='transformer'
    )
    pipe = EnvCapWanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.transformer = transformer
    pipe.scheduler = scheduler
    pipe.to(device)

    height = 480
    width = 832
    num_frames = 81
    guidance_scale = 5.0

    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    model_config_path = args.model_config
    model_name = os.path.splitext(os.path.basename(model_config_path))[0]

    # NOTE(review): every process computes this timestamp independently (hour
    # resolution); ranks started across an hour boundary would use different
    # output folders.
    current_time = datetime.now().strftime("%Y%m%d-%H")

    output_folder = f"./StoryEval_result/{model_type}/{model_name}_{current_time}"
    output_video_folder = os.path.join(output_folder, "videos")
    # The original referenced this folder without ever defining it (NameError
    # on the master rank). NOTE(review): the "attn" sub-folder name is an
    # assumption; nothing in this script writes into it yet.
    output_attn_folder = os.path.join(output_folder, "attn")

    # Every rank creates the video folder (idempotent) so non-master ranks
    # cannot race ahead of the master and write into a missing directory.
    os.makedirs(output_video_folder, exist_ok=True)

    if is_master:
        os.makedirs(output_attn_folder, exist_ok=True)
        # Best-effort snapshot of the model sources next to the results. An
        # argument list avoids shell interpretation of the path components;
        # check=False keeps the original "ignore failure" behavior.
        subprocess.run(
            ["cp", "-r", "./models", os.path.join(output_folder, "models")],
            check=False,
        )

    with open(model_config_path, 'r') as file:
        model_configs = yaml.safe_load(file)

    json_path = "./all_prompts_out.json"
    with open(json_path, 'r', encoding='utf-8') as f:
        all_prompts = json.load(f)

    prompt_items = list(all_prompts.items())

    for idx, (name, data) in enumerate(prompt_items):
        # Round-robin sharding: each rank handles the entries whose index maps to it.
        if idx % world_size != rank:
            continue

        prompt = data["prompt"]
        event_list = data["motion"]
        # "event_range" is stored as a Python-literal string; accept an
        # already-decoded value as well instead of crashing in literal_eval.
        raw_event_range = data["event_range"]
        if isinstance(raw_event_range, str):
            event_range = ast.literal_eval(raw_event_range)
        else:
            event_range = raw_event_range
        # The pipeline is called with the subject wrapped in a single-element list.
        subject = [data["subject"]]

        output, _ = pipe(
            prompt=prompt,
            event_list=event_list,
            event_range=event_range,
            subject=subject,
            model_configs=model_configs,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(args.seed),  # Set the seed for reproducibility
        )

        frames = output.frames[0]

        # Keys that already carry an extension are used unchanged; bare keys
        # get ".mp4" so the video writer can determine the container format.
        video_name = name if os.path.splitext(name)[1] else f"{name}.mp4"
        output_path = os.path.join(output_video_folder, video_name)
        export_to_video(frames, output_path, fps=16)

    if is_master:
        # NOTE(review): there is no barrier here; this only means the master
        # rank finished its own share of the prompts.
        print("=== All processes finished ===")