from fifo_utils.dir_utils import set_directory
from zeroscope_worker import get_pipeline, run_base, run_fifo
from argparse import ArgumentParser
import torch
import os
from zeroscope_fifo.models import UNet3DConditionModel


def load_prompts(prompt_file, num_processes, rank):
    """Read the prompts in *prompt_file* and return the share owned by *rank*.

    Every line is stripped of surrounding whitespace and lines that are empty
    after stripping are dropped. The remaining prompts are split round-robin:
    the process with index ``rank`` receives the prompts at positions
    ``rank, rank + num_processes, rank + 2 * num_processes, ...``.

    Args:
        prompt_file: Path to a text file holding one prompt per line.
        num_processes: Number of processes sharing the prompts (slice step).
        rank: Index of the current process (slice start).

    Returns:
        A list of the stripped, non-empty prompts assigned to ``rank``.
    """
    # The context manager closes the file on every path, including an empty
    # file (no loop iterations) and an error raised while reading.
    with open(prompt_file, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        prompt_list = [prompt for prompt in stripped_lines if prompt]

    return prompt_list[rank::num_processes]

if __name__ == "__main__":
    # Script entry point: parse the command line, build the pipeline once, then
    # for every prompt assigned to this rank run the base generation (only if
    # its latents are not cached yet) followed by the FIFO generation.
    parser = ArgumentParser()

    # general arguments
    parser.add_argument("--seed", type=int, default=321)
    parser.add_argument("--model_dir", type=str, default="zeroscope_models", help="directory to save the model from huggingface")
    parser.add_argument("--video_length", type=int, default=24, help="f in paper")
    parser.add_argument("--num_partitions", type=int, default=4, help="n in paper")
    parser.add_argument("--num_inference_steps", type=int, default=24, help="number of inference steps, it will be f * n forcedly")
    parser.add_argument("--prompt_file", "-p", type=str, default="prompts/test_prompts.txt", help="path to the prompt file")
    parser.add_argument("--new_video_length", "-l", type=int, default=100, help="N in paper; desired length of the output video")
    parser.add_argument("--num_processes", type=int, default=1, help="number of processes if you want to run the prompts in multiple gpus")
    parser.add_argument("--rank", type=int, default=0, help="rank of the process(0~num_processes-1)")
    parser.add_argument("--height", type=int, default=320, help="height of the output video")
    parser.add_argument("--width", type=int, default=576, help="width of the output video")
    parser.add_argument("--save_frames", action="store_true", default=False, help="save generated frames for each step")
    # store_false with default=True: the option is on unless the flag is
    # given, so the help text has to describe the flag as a switch-off.
    parser.add_argument("--lookahead_denoising", "-ld", action="store_false", default=True, help="disable lookahead denoising (enabled by default)")
    parser.add_argument("--eta", "-e", type=float, default=0.5, help="ddim eta for sampling")

    args = parser.parse_args()

    # Only the prompts belonging to this rank are processed by this process.
    prompts = load_prompts(args.prompt_file, args.num_processes, args.rank)

    # Any --num_inference_steps given on the command line is overridden here:
    # the step count is always video_length * num_partitions (f * n).
    args.num_inference_steps = args.video_length * args.num_partitions

    # get pipeline
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = get_pipeline(device, cache_dir=args.model_dir, use_device_map=False)

    # set random seed
    # NOTE(review): the truthiness test means `--seed 0` yields generator=None
    # instead of a generator seeded with 0 -- confirm this is intended.
    generator = torch.Generator().manual_seed(args.seed) if args.seed else None

    for prompt in prompts:
        # set output dir
        base_dir, latents_dir = set_directory(args, prompt)[:2]
        directories = {
            "base_dir": base_dir,
            "latents_dir": latents_dir,
        }

        # generate first N frames to prepare the latents
        # The base run is skipped only when both the first (0.pt) and the last
        # ({num_inference_steps}.pt) latent files already exist in latents_dir.
        first_latent = os.path.join(latents_dir, "0.pt")
        last_latent = os.path.join(latents_dir, f"{args.num_inference_steps}.pt")
        is_run_base = not (os.path.exists(last_latent) and os.path.exists(first_latent))

        if is_run_base:
            run_base(args, pipe, directories, generator, prompt)

        # generate longer video through fifo method
        run_fifo(args, pipe, directories, generator, prompt)
    