import os
import argparse
import glob
import multiprocessing as mp
import subprocess
import time
import pdb
# import ray

## Template: equivalent single-task invocation of cli_demo.py
# CUDA_VISIBLE_DEVICES=7 python3 cli_demo.py \
# --prompt "pick up the alphabet soup and place it in the basket" \
# --image_or_video_path /path/to/finetune_dataset/libero_object_original/first_frames/pick_up_the_alphabet_soup_and_place_it_in_the_basket_trj_10.png \
# --model_path ../finetune/test_output/libero_90_original/cogvideox1.5-5b-i2v \
# --generate_type i2v \
# --output_path ./output.mp4 \
# --num_frames 81


def run_task(gpu_id, tasks, args):
    """Run ``cli_demo.py`` once per task, sequentially, pinned to one GPU.

    Each task is launched as a separate ``python3 cli_demo.py`` subprocess
    with ``CUDA_VISIBLE_DEVICES`` set to ``gpu_id``; the next task starts only
    after the previous subprocess exits. After every task one line is appended
    to ``<args.save_path>/log.txt`` recording the command, whether it exited
    successfully, and the local time.

    Args:
        gpu_id: GPU index exported to the child as ``CUDA_VISIBLE_DEVICES``.
        tasks: Iterable of ``(prompt, img_path)`` pairs.
        args: Parsed CLI namespace; reads ``save_path``, ``model_path``,
            ``generate_type`` and ``num_frames``.
    """
    import shlex  # local import: only used to render a copy-pasteable command

    # Loop-invariant: same log file and same child environment for every task.
    log_file = os.path.join(args.save_path, "log.txt")
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    for prompt, img_path in tasks:
        # Output must be named "<stem>_out.mp4": main() uses that pattern to
        # recognise already-finished tasks on restart.
        save_name = os.path.basename(img_path).split('.')[0] + "_out.mp4"
        save_path = os.path.join(args.save_path, save_name)

        # The command is a list and no shell is involved, so each element is
        # delivered to cli_demo.py verbatim. Wrapping the prompt in literal
        # quotes (as one would on a shell command line) would make the quote
        # characters part of the prompt text itself.
        cmd = [
            "python3", "cli_demo.py",
            "--prompt", prompt,
            "--image_or_video_path", img_path,
            "--model_path", args.model_path,
            "--generate_type", args.generate_type,
            "--output_path", save_path,
            "--num_frames", str(args.num_frames),
        ]
        cmd_str = " ".join(shlex.quote(part) for part in cmd)

        print(f"[GPU {gpu_id}] Running: {cmd_str}")
        result = subprocess.run(cmd, env=env)

        # Record failures explicitly instead of reporting every run as finished.
        if result.returncode == 0:
            status = "finished"
        else:
            status = f"failed with exit code {result.returncode}"
        time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        with open(log_file, 'a') as f:
            f.write(f"[GPU {gpu_id}] Running: {cmd_str} {status} at {time_now}\n")




def _output_stem(path):
    """Return the part of ``path``'s base name before its first ``.``.

    This is the key linking an input frame (``<stem>.png``) to the video
    generated for it (``<stem>_out.mp4``).
    """
    return os.path.basename(path).split('.')[0]


def _strip_out_suffix(stem):
    """Remove a single trailing ``_out`` from ``stem`` if present."""
    suffix = "_out"
    return stem[:-len(suffix)] if stem.endswith(suffix) else stem


def main(args):
    """Generate videos for every pending (prompt, first-frame) pair across GPUs.

    Reads ``prompts.txt`` and ``images.txt`` from ``args.data_folder``. The two
    files are line-aligned, and image paths are relative to ``data_folder``.
    It then drops tasks whose ``<stem>_out.mp4`` already exists in
    ``args.save_path``. It keeps only tasks whose stem ends in ``_<n>`` with
    ``n < args.max_traj_idx``; that attribute is optional and defaults to 20.
    The remaining tasks are split round-robin over the comma-separated GPU ids
    in ``args.device``, with one spawned ``run_task`` process per GPU. The call
    blocks until all workers exit.

    Raises:
        ValueError: if the two list files differ in length, or ``args.device``
            names no GPU. ``int()`` also raises ValueError when a pending
            frame's stem does not end in ``_<integer>``.
    """
    prompts_path = os.path.join(args.data_folder, "prompts.txt")
    with open(prompts_path, 'r') as f:
        prompts = [line.strip() for line in f]

    first_frames_path = os.path.join(args.data_folder, "images.txt")
    with open(first_frames_path, 'r') as f:
        first_frames = [line.strip() for line in f]

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(prompts) != len(first_frames):
        raise ValueError(
            f"{prompts_path} has {len(prompts)} lines but "
            f"{first_frames_path} has {len(first_frames)}; they must match"
        )
    first_frames = [os.path.join(args.data_folder, f) for f in first_frames]

    finished_videos = [x for x in os.listdir(args.save_path) if x.endswith(".mp4")]
    # Set gives O(1) membership tests. Only a *trailing* "_out" is removed:
    # str.replace would also delete "_out" inside stems such as "..._outside_...",
    # so those finished tasks would never be recognised and would be re-run.
    finished_names = {_strip_out_suffix(_output_stem(x)) for x in finished_videos}
    print(f"Finished tasks: {len(finished_videos)}")

    max_traj_idx = getattr(args, "max_traj_idx", 20)
    tasks = [
        (prompt, frame)
        for prompt, frame in zip(prompts, first_frames)
        if _output_stem(frame) not in finished_names
        # Trajectory index is the trailing "_<n>" of the stem.
        and int(_output_stem(frame).split("_")[-1]) < max_traj_idx
    ]
    print(f"Total tasks: {len(tasks)}")

    # Ignore empty entries so "0,1," or stray spaces don't crash int().
    gpu_list = [int(d) for d in args.device.split(",") if d.strip()]
    if not gpu_list:
        raise ValueError(f"--device must list at least one GPU id, got {args.device!r}")
    num_gpus = len(gpu_list)

    # Round-robin split: GPU i gets tasks i, i + num_gpus, i + 2*num_gpus, ...
    task_splits = [tasks[i::num_gpus] for i in range(num_gpus)]

    log_file = os.path.join(args.save_path, "log.txt")
    time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    with open(log_file, 'a') as f:
        f.write(f"Start time: {time_now}\n")
        f.write(f"Total tasks: {len(tasks)}\n")
        # Busiest GPU's load (ceiling); floor division under-reported uneven splits.
        f.write(f"Tasks per GPU: {max(len(s) for s in task_splits)}\n")

    ctx = mp.get_context("spawn")
    procs = []
    for gpu_id, split in zip(gpu_list, task_splits):
        p = ctx.Process(target=run_task, args=(gpu_id, split, args))
        p.start()
        procs.append((gpu_id, p))

    for gpu_id, p in procs:
        p.join()
        if p.exitcode != 0:
            print(f"[GPU {gpu_id}] worker exited with code {p.exitcode}")


if __name__ == "__main__":
    # Command-line entry point. Options are declared in a table (order kept
    # so --help lists them as before), then the output directory is created
    # and the multi-GPU driver is started.
    option_specs = (
        ('--data_folder', dict(type=str, required=True)),
        ('--model_path', dict(type=str, required=True)),
        ('--generate_type', dict(type=str, default='i2v')),
        ('--save_path', dict(type=str, required=True)),
        ('--device', dict(type=str, default='0,1,2,3')),
        ('--num_frames', dict(type=int, default=17)),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    parsed = cli.parse_args()

    # main() lists and appends to save_path, so it must exist beforehand.
    os.makedirs(parsed.save_path, exist_ok=True)
    main(parsed)

## Each GPU worker runs its tasks sequentially: one task per device at a time.