"""
This script is for generating the dataset of images given the csv file with the captions (or hard negative captions) and the image paths.
It uses data parallelism to distribute the generation across multiple GPUs on a single node.
Each GPU is managed by a separate process.

Example usage:
python image_gen/scaled_img_generation.py --dp-size 8 [other args]
"""
import os, sys, time, argparse, pandas as pd
import numpy as np
from multiprocessing import Process, Queue, set_start_method, get_context
from tqdm import tqdm
from pytorch_lightning import seed_everything
from full_pipeline import FullPipeline

# Per-model generation settings read by image_gen_worker. "size", "gs" and
# "steps" are forwarded to FullPipeline.generate_images. "seed" is the base
# seed: each worker adds its DP rank to it, and it is doubled first for the
# "hn" task (unless a --seed override is given).
final_t2i_hyperparams = {
    model_name: dict(zip(("size", "gs", "steps", "seed"), settings))
    for model_name, settings in (
        # model        size  gs   steps  seed
        ("SD15",      (512, 2.0, 50, 42)),
        ("SD2",       (768, 2.0, 50, 420)),
        ("SANA",      (512, 2.0, 20, 4200)),
        ("SDXLT_16b", (512, 0.0, 2, 42000)),
    )
}

def image_gen_worker(in_queue, out_queue, dp_rank, tti_name, tti_dtype, batch_size, save_size, save_folder, local_model_path, task, seed=None):
    """
    Worker process that generates images on GPU ``cuda:<dp_rank>`` for prompts read from a queue.

    Protocol with the parent process:
      * After the pipeline is loaded, puts "DP rank <r> ready." on ``out_queue``.
      * Then repeatedly reads ``(batch_id, captions, filenames)`` tuples from ``in_queue``.
        For each one it generates the images and puts "DP rank <r> finished batch <id>..."
        on ``out_queue``.
      * A ``None`` item on ``in_queue`` stops the worker.
      * On any exception, "DP rank <r> failed with error: ..." is put on ``out_queue`` and the
        exception is re-raised. This includes failures during hyperparameter lookup and
        seeding, so the parent never waits for a message that will not come.

    Args:
        in_queue: Queue of work items (see above), terminated by ``None``.
        out_queue: Queue of status strings shared by all workers.
        dp_rank: Index of this worker; selects the CUDA device and offsets the seed.
        tti_name: Key into ``final_t2i_hyperparams`` and model name passed to FullPipeline.
        tti_dtype: Forwarded to FullPipeline as ``torch_dtype``.
        batch_size: Generation batch size forwarded to ``generate_images``.
        save_size: If truthy, images are resized to ``(save_size, save_size)``.
        save_folder: Destination folder passed to ``generate_images``.
        local_model_path: Forwarded to FullPipeline as ``local_path``.
        task: "cap" or "hn"; "hn" uses a different default base seed.
        seed: Optional base seed overriding the per-model default.
    """
    try:
        hparams = final_t2i_hyperparams[tti_name]
        size = hparams["size"]
        gs = hparams["gs"]
        steps = hparams["steps"]

        # Every seed is offset by dp_rank so that each worker gets a distinct seed.
        if seed is not None:
            print(f"Overriding default seed for {tti_name} with {seed}")
            seed_everything(seed + dp_rank)
        elif task == "hn":
            # Hard-negative runs use a different base seed than caption runs.
            seed_everything(hparams["seed"] * 2 + dp_rank)
        else:
            seed_everything(hparams["seed"] + dp_rank)

        # Each worker owns its own pipeline, pinned to its own GPU.
        pipeline = FullPipeline(tti_name, parallelize=False, device=f"cuda:{dp_rank}", torch_dtype=tti_dtype, local_path=local_model_path)
        out_queue.put(f"DP rank {dp_rank} ready.")

        while True:
            data = in_queue.get()
            if data is None:  # shutdown sentinel from the parent
                break

            batch_id, captions, filenames = data
            if not captions:
                out_queue.put(f"DP rank {dp_rank} finished batch {batch_id} (no prompts).")
                continue

            print(f"DP rank {dp_rank} processing {len(captions)} prompts for batch {batch_id}")

            pipeline.generate_images(
                captions,
                filenames,
                batch_size,
                steps,
                size,
                gs,
                dest_folder=save_folder,
                # Destination subfolders are expected to exist already (the parent creates them).
                makedir=False,
                abs_path=False,
                resize=(save_size, save_size) if save_size else None
            )

            out_queue.put(f"DP rank {dp_rank} finished batch {batch_id}.")
    except Exception as e:
        print(f"DP rank {dp_rank} failed with error: {e}")
        out_queue.put(f"DP rank {dp_rank} failed with error: {e}")
        raise

def _terminate_workers(procs):
    """Terminate every process in ``procs`` that is still running."""
    for proc in procs:
        if proc.is_alive():
            proc.terminate()


def _split_into_chunks(captions, filenames, n_chunks):
    """Split the aligned caption/filename lists into ``n_chunks`` contiguous chunks.

    Each chunk gets ``len(captions) // n_chunks`` items; the last chunk also takes
    the remainder. Returns a list of ``(chunk_captions, chunk_filenames)`` tuples.
    """
    per_chunk = len(captions) // n_chunks
    chunks = []
    for i in range(n_chunks):
        start = i * per_chunk
        end = start + per_chunk if i < n_chunks - 1 else len(captions)
        chunks.append((captions[start:end], filenames[start:end]))
    return chunks


def _next_worker_message(out_queue, procs, failures_seen, poll_interval=30.0):
    """Return the next status string put on ``out_queue`` by a worker.

    The queue is polled every ``poll_interval`` seconds instead of being read with a
    blocking ``get()``. If more workers have exited with a non-zero code than the
    ``failures_seen`` failure messages already received, some worker died without
    reporting (e.g. it was killed by the OS). In that case a synthetic message
    containing "failed" is returned, so the caller does not wait forever.
    """
    import queue  # multiprocessing queues raise queue.Empty on timeout

    while True:
        try:
            return out_queue.get(timeout=poll_interval)
        except queue.Empty:
            pass
        crashed = [rank for rank, proc in enumerate(procs) if proc.exitcode not in (None, 0)]
        if len(crashed) > failures_seen:
            # A worker that reported its error right before exiting may have flushed its
            # message just after the timeout; re-check once before assuming a silent death.
            try:
                return out_queue.get_nowait()
            except queue.Empty:
                return (f"Workers with DP ranks {crashed} exited with non-zero codes; "
                        f"at least one failed without reporting.")


def main(args):
    """Generate one image per caption of ``args.csv_path`` with ``args.dp_size`` GPU workers.

    Steps:
      1. Read the tab-separated file and pick the ``caption`` (task "cap") or
         ``hn_caption`` (task "hn") column, plus ``image_path``.
      2. Create ``<save_folder>/<tti_name>/<task>[_<seed>]`` and every directory that
         appears in the ``image_path`` values. Workers write with ``makedir=False``, so
         these directories must exist beforehand.
      3. With ``--exist-ok``, drop captions whose image file already exists.
      4. Split the remaining prompts into ``dp_size`` contiguous chunks and give one chunk
         to each spawned worker process.

    Returns ``None`` early when there is nothing to generate. Otherwise it terminates the
    interpreter with ``sys.exit``: code 0 on success, non-zero if any worker failed.
    """
    start_time = time.time()

    df = pd.read_csv(args.csv_path, sep="\t")
    if args.task == "cap":
        captions = df["caption"].tolist()
    elif args.task == "hn":
        captions = df["hn_caption"].tolist()
    else:
        raise ValueError(f"Unknown task {args.task!r}; expected 'cap' or 'hn'.")
    filenames = df["image_path"].tolist()
    del df

    # Output layout: <save_folder>/<tti_name>/<task> or <task>_<seed> when a seed override is set.
    task_dir = f"{args.task}_{args.seed}" if args.seed is not None else args.task
    save_folder = os.path.join(args.save_folder, args.tti_name, task_dir)
    os.makedirs(save_folder, exist_ok=True)

    print(f"Generating all subfolders of {save_folder} listed in image_path values...")
    # Use the full parent directory of each path. Paths with no directory component
    # (dirname == "") need nothing; creating a folder named after the file would collide with it.
    subfolders = {os.path.dirname(f) for f in filenames}
    subfolders.discard("")
    for subfolder in sorted(subfolders):
        os.makedirs(os.path.join(save_folder, subfolder), exist_ok=True)
    print("Done.")

    if args.exist_ok:
        print("Filtering out captions which image already exist...")
        idxs = [i for i, f in tqdm(enumerate(filenames), total=len(filenames))
                if not os.path.exists(os.path.join(save_folder, f))]
        print(f"Keeping {len(idxs)} out of {len(filenames)} images.")
        captions = [captions[i] for i in idxs]
        filenames = [filenames[i] for i in idxs]

    if len(captions) == 0:
        print("No captions to generate images for. Images already generated. Exiting.")
        return
    print(f"Number of captions to generate images for: {len(captions)}")
    print(f"Number of filenames of images to generate: {len(filenames)}")

    num_prompts = len(captions)
    data_chunks = _split_into_chunks(captions, filenames, args.dp_size)

    ctx = get_context("spawn")
    in_queues = [ctx.Queue() for _ in range(args.dp_size)]
    out_queue = ctx.Queue()

    procs = []
    for dp_rank in range(args.dp_size):
        proc = ctx.Process(target=image_gen_worker,
                           args=(in_queues[dp_rank], out_queue, dp_rank, args.tti_name, args.tti_dtype, args.batch_size, args.save_size, save_folder, args.local_model_path, args.task, args.seed))
        proc.start()
        procs.append(proc)

    # Phase 1: wait until every worker has loaded its pipeline; abort on the first failure.
    for _ in range(args.dp_size):
        msg = _next_worker_message(out_queue, procs, failures_seen=0)
        print(msg)
        if "failed" in msg:
            _terminate_workers(procs)
            sys.exit(1)

    # Phase 2: send one chunk per worker (the rank doubles as batch id) and wait for all of them.
    # Healthy workers are allowed to finish even if another one fails, so their images are kept.
    for dp_rank, (chunk_captions, chunk_filenames) in enumerate(data_chunks):
        in_queues[dp_rank].put((dp_rank, chunk_captions, chunk_filenames))

    failures = 0
    for _ in range(args.dp_size):
        msg = _next_worker_message(out_queue, procs, failures_seen=failures)
        if "failed" in msg:
            print(msg)
            failures += 1

    if failures:
        _terminate_workers(procs)
        sys.exit(1)

    # Phase 3: send the shutdown sentinel and collect exit codes.
    for in_queue in in_queues:
        in_queue.put(None)

    exit_code = 0
    for proc in procs:
        proc.join()
        if proc.exitcode:
            exit_code = proc.exitcode

    if exit_code:
        print("A worker process failed. Terminating.")
        _terminate_workers(procs)
        sys.exit(exit_code)

    total_time = time.time() - start_time
    print("Total time taken:", total_time, "seconds")
    if num_prompts > 0:
        print(f"Average time per image: {total_time/num_prompts}s")
    sys.exit(exit_code)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate images from a given csv file with data parallelism.")
    parser.add_argument("--save-folder", type=str, required=True,
                        help="Absolute path to the folder where the images will be saved.")
    parser.add_argument("--csv-path", type=str, required=True,
                        help="Path where the csv with the captions is stored.")
    parser.add_argument("--batch-size", type=int, required=True,
                        help="Batch size for the image generation for each worker.")
    parser.add_argument("--tti-name", choices=["SD15", "SD2", "SANA", "SDXLT_16b"], required=True,
                        help="Name of the Text to image model to use. Generates a subfolder with this name in the save folder.")
    parser.add_argument("--tti-dtype", choices=["float16", "bfloat16"], required=True,
                        help="Data type of the Text to image model to use. Can be either float16 or bfloat16.")
    parser.add_argument("--dp-size", type=int, required=True,
                        help="Data parallel size, number of GPUs on a single node to use.")
    parser.add_argument("--save-size", type=int, default=256,
                        help="Resolution (size x size) of the images to save.")
    parser.add_argument("--task", choices=["cap", "hn"], default="cap",
                        help="'cap' for generating images based on original captions, 'hn' for generating images based on hard negative captions. Generates a subfolder with this name in the save_folder/tti_name.")
    parser.add_argument("--local-model-path", default=None, type=str,
                        help="If set, loads the diffusion model from local checkpoints, from HF Hub otherwise.")
    parser.add_argument("--exist-ok", default=False, action="store_true",
                        help="If set, filters out captions which images are already present in the correct folder")
    parser.add_argument("--seed", default=None, type=int,
                        help="Overrides the default base seed for the t2i model")
    args = parser.parse_args()

    assert args.task in ["cap", "hn"], f"task {args.task} must be either 'cap' or 'hn'"
    assert os.path.exists(args.csv_path), f"csv file {args.csv_path} does not exist"
    assert args.csv_path.endswith(".csv"), f"csv file {args.csv_path} must be a .csv file"
    set_start_method("spawn", force=True)
    main(args)