#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.

import argparse
import os

import torch
import torch.distributed as dist
from accelerate.logging import get_logger

from fastvideo.utils.load import load_text_encoder

logger = get_logger(__name__)


def main(args):
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=local_rank)

    text_encoder = load_text_encoder(args.model_type,
                                     args.model_path,
                                     device=device)
    autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16
    # output_dir/validation/prompt_attention_mask
    # output_dir/validation/prompt_embed
    os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
    os.makedirs(
        os.path.join(args.output_dir, "validation", "prompt_attention_mask"),
        exist_ok=True,
    )
    os.makedirs(os.path.join(args.output_dir, "validation", "prompt_embed"),
                exist_ok=True)

    with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
        lines = file.readlines()
    prompts = [line.strip() for line in lines]
    for prompt in prompts:
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=autocast_type):
                prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                    prompt)
                file_name = prompt.split(".")[0]
                prompt_embed_path = os.path.join(args.output_dir, "validation",
                                                 "prompt_embed",
                                                 f"{file_name}.pt")
                prompt_attention_mask_path = os.path.join(
                    args.output_dir,
                    "validation",
                    "prompt_attention_mask",
                    f"{file_name}.pt",
                )
                torch.save(prompt_embeds[0], prompt_embed_path)
                torch.save(prompt_attention_mask[0],
                           prompt_attention_mask_path)
                print(f"sample {file_name} saved")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help=
        "The output directory where the model predictions and checkpoints will be written.",
    )
    args = parser.parse_args()
    main(args)
