#!/usr/bin/env python3
import os
# First, set environment variables to disable tokenizers warning messages
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import torch.distributed as dist
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import time
import json
import contextlib
from pathlib import Path
from loguru import logger
from datetime import datetime

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import parse_args
from scaling_cache.adapter.hyvideo.inference import HunyuanVideoSampler
from scaling_cache.utils.common import save_alpha_dict

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # 关键修复：每个进程绑定到不同GPU
    # 验证设备绑定
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    print(f"Rank {os.environ['RANK']} -> CUDA device: {torch.cuda.current_device()}")
    args = parse_args()
    print(args)
    models_root_path = Path(args.model_base)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")
    
    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    
    # Get the updated args
    args = hunyuan_video_sampler.args

    dimension_list = ["Human_Anatomy"]
    sample_num = 40

    for dimension in dimension_list:
        with open(f"./assets/prompts/prompt_aug/{dimension}.txt", 'r') as f:
            aug_prompt_list = f.readlines()
        with open(f"./assets/prompts/prompt/{dimension}.txt", 'r') as f:
            origin_prompt_list = f.readlines()

        origin_prompt_list = [prompt.strip() for prompt in origin_prompt_list][:sample_num]
        aug_prompt_list = [prompt.strip() for prompt in aug_prompt_list][:sample_num]

        for idx, origin_prompt in enumerate(origin_prompt_list):
            # sample 5 videos for each prompt
            aug_prompt = aug_prompt_list[idx]
            for index in range(args.num_videos_per_prompt):
                # perform sampling
                base_seed = 42 + index

                outputs = hunyuan_video_sampler.predict(
                    prompt=aug_prompt, 
                    height=args.video_size[0],
                    width=args.video_size[1],
                    video_length=args.video_length,
                    seed=base_seed,
                    negative_prompt=args.neg_prompt,
                    infer_steps=args.infer_steps,
                    guidance_scale=args.cfg_scale,
                    num_videos_per_prompt=args.num_videos,
                    flow_shift=args.flow_shift,
                    batch_size=args.batch_size,
                    embedded_guidance_scale=args.embedded_cfg_scale
                )
                if dist.get_rank() == 0:
                    samples = outputs['samples']
                    save_dir = f'vbench_video_path/{args.task}/{dimension}/{args.mode}/'
                    os.makedirs(save_dir, exist_ok=True)
                    cur_save_path = f'{save_dir}/{origin_prompt}-{index}.mp4'

                    for i, sample in enumerate(samples):
                        sample = samples[i].unsqueeze(0)
                        save_videos_grid(sample, cur_save_path, fps=24)
                        logger.info(f"Sample saved to: {cur_save_path}")
                
                dist.barrier()

                #重新加载
                hunyuan_video_sampler.pipeline.transformer.to(device)
                hunyuan_video_sampler.pipeline.text_encoder.to(device)
                hunyuan_video_sampler.pipeline.text_encoder_2.to(device)

            if hunyuan_video_sampler.pipeline.transformer.cache_dic['update_alpha']:
                save_alpha_dict(hunyuan_video_sampler.pipeline.transformer.cache_dic, args.task)
    
if __name__ == "__main__":
    # Run main() only when this file is executed directly, not when imported.
    main()
