#!/usr/bin/env python3
import os
# First (before the library imports below), set this environment variable to disable tokenizers warning messages
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import torch.distributed as dist
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import time
import json
import contextlib
from pathlib import Path
from loguru import logger
from datetime import datetime

from hyvideo.utils.file_utils import save_videos_grid
from scaling_cache.adapter.hyvideo.config import parse_args
from scaling_cache.adapter.hyvideo.inference import HunyuanVideoSampler
from scaling_cache.utils import save_alpha_dict

def _read_prompt_file(path):
    """Return the stripped lines of the prompt file at *path*, minus trailing blank lines.

    Trailing blank lines are dropped so that a stray newline at the end of one
    prompt file does not make it look longer than its paired file.
    """
    with open(path, "r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f]
    while prompts and not prompts[-1]:
        prompts.pop()
    return prompts


def _load_prompt_pairs(dimension):
    """Return ``(origin_prompt, aug_prompt)`` pairs for one evaluation dimension.

    The original prompt (used to name output files) and the augmented prompt
    (fed to the sampler) are paired line-by-line from two parallel files whose
    paths are relative to the current working directory.

    Raises:
        ValueError: if the two files hold different numbers of prompts, since
            line-by-line pairing would then silently misalign them.
    """
    aug_path = f"../assets/prompts/prompt_aug/{dimension}.txt"
    origin_path = f"../assets/prompts/prompt/{dimension}.txt"
    aug_prompts = _read_prompt_file(aug_path)
    origin_prompts = _read_prompt_file(origin_path)
    if len(aug_prompts) != len(origin_prompts):
        raise ValueError(
            f"Prompt files are misaligned: {origin_path} has {len(origin_prompts)} "
            f"prompts but {aug_path} has {len(aug_prompts)}"
        )
    return list(zip(origin_prompts, aug_prompts))


def main():
    """Sample videos for every prompt of each evaluation dimension and save them.

    For each (original, augmented) prompt pair, the augmented prompt is sampled
    ``args.num_videos_per_prompt`` times with seeds 42, 43, ... . Videos are
    written only on local rank 0, under
    ``vbench_video_path/<task>/<dimension>/<mode>-<first_enhance>/``, and are
    named ``<original prompt>-<video index>.mp4``.

    Raises:
        ValueError: if ``args.model_base`` does not exist, or if a dimension's
            two prompt files contain different numbers of prompts.
    """
    # A missing LOCAL_RANK means a plain single-process launch; treat it as rank 0,
    # which is also what the save guard below assumes.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        print(f"Rank {os.environ.get('RANK', '0')} -> CUDA device: {torch.cuda.current_device()}")
    else:
        print(f"Rank {os.environ.get('RANK', '0')} -> CUDA unavailable")

    args = parse_args()
    print(args)
    models_root_path = Path(args.model_base)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)

    # The sampler may update args during loading; use its copy from here on.
    args = hunyuan_video_sampler.args

    # Alternative full run: ["all_dimension"]
    dimension_list = ["contrast_methods"]

    for dimension in dimension_list:
        prompt_pairs = _load_prompt_pairs(dimension)
        save_dir = Path(f"vbench_video_path/{args.task}/{dimension}/{args.mode}-{args.first_enhance}")

        for origin_prompt, aug_prompt in prompt_pairs:
            # Draw args.num_videos_per_prompt videos per prompt, one seed each.
            for index in range(args.num_videos_per_prompt):
                base_seed = 42 + index
                outputs = hunyuan_video_sampler.predict(
                    prompt=aug_prompt,
                    height=args.video_size[0],
                    width=args.video_size[1],
                    video_length=args.video_length,
                    seed=base_seed,
                    negative_prompt=args.neg_prompt,
                    infer_steps=args.infer_steps,
                    guidance_scale=args.cfg_scale,
                    num_videos_per_prompt=args.num_videos,
                    flow_shift=args.flow_shift,
                    batch_size=args.batch_size,
                    embedded_guidance_scale=args.embedded_cfg_scale,
                )
                samples = outputs['samples']

                # Every rank runs predict, but only local rank 0 writes files.
                if local_rank != 0:
                    continue

                save_dir.mkdir(parents=True, exist_ok=True)
                n_samples = len(samples)
                for i, sample in enumerate(samples):
                    # Give each returned sample its own file. Previously every sample reused one
                    # path and only the last survived. With a single sample per call this is
                    # exactly "<prompt>-<index>.mp4", as before.
                    # NOTE(review): a prompt containing "/" would be treated as a subdirectory.
                    video_idx = index * n_samples + i
                    cur_save_path = str(save_dir / f"{origin_prompt}-{video_idx}.mp4")
                    save_videos_grid(sample.unsqueeze(0), cur_save_path, fps=24)
                    logger.info(f"Sample saved to: {cur_save_path}")

if __name__ == '__main__':
    # Script entry point: run the sampling pipeline only when executed directly, not on import.
    main()
