"""
Generates hard-negative captions from the curated captions, concepts and attributes produced by balanced_sampling.py,
and stores a tab-separated csv ready for training.
Columns: image_path, caption, hn_caption (concept and attribute are not needed for training).
The image_path is created arbitrarily and will be the relative path to the image during training.
"""
import os, json, random, argparse, uuid, glob
from huggingface_hub import login
from multiprocessing import Process, Queue
from tqdm import tqdm
import pandas as pd
from engines import VllmModel
import time

# Short LLM name (the --llm-name CLI value) -> model id handed to VllmModel.
llm_ids = dict(
    mistral02="mistralai/Mistral-7B-Instruct-v0.2",
    llama="meta-llama/Llama-3.1-8B-Instruct",
)

def _temp_csv_sort_key(path):
    """Sort key ordering ``{run_id}_{dp_rank}_{batch_id}.csv`` shards by (batch_id, dp_rank)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    parts = stem.rsplit("_", 2)
    try:
        return (0, int(parts[2]), int(parts[1]), path)
    except (IndexError, ValueError):
        # Unexpected file name: keep it, but place it after the well-formed shards.
        return (1, 0, 0, path)


def merge_and_save_csv(save_folder, final_csv_path, delete_temp, run_id):
    """
    Merge the csv files generated by each worker and save the final csv.

    Collects every ``{run_id}_*.csv`` file in ``save_folder``, which the workers write as
    ``{run_id}_{dp_rank}_{batch_id}.csv``. It concatenates them ordered by (batch_id,
    dp_rank), which is the order in which the prompts were dispatched. The result is
    written tab-separated to ``final_csv_path``.

    Args:
        save_folder: folder containing the per-worker csv shards.
        final_csv_path: destination path of the merged csv.
        delete_temp: if true, the shard files are removed after the merged csv is written.
        run_id: prefix identifying the shards of this run.
    """
    # Escape both parts so that '[', '*' or '?' in the folder name or run id are matched literally.
    pattern = os.path.join(glob.escape(save_folder), f"{glob.escape(run_id)}_*.csv")
    # glob returns files in arbitrary order; sort them so the merged rows are reproducible.
    temp_files = sorted(glob.glob(pattern), key=_temp_csv_sort_key)

    # Read everything as plain strings: pandas would otherwise turn captions such as
    # "NA", "null" or "nan" into NaN and write them back as empty fields.
    # Concatenate once at the end: concatenating inside the loop re-copies the growing frame each time.
    frames = [
        pd.read_csv(temp_file, sep="\t", dtype=str, keep_default_na=False)
        for temp_file in tqdm(temp_files)
    ]
    df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    total_samples = len(df)
    print(f"Total samples: {total_samples}")

    print(f"Saving complete csv dataset to {final_csv_path}...", end="")
    df.to_csv(final_csv_path, sep="\t", index=False)
    print("Done.")

    if delete_temp:
        # The merged csv is already on disk, so the shards are no longer needed.
        print("Removing individual csv files...")
        for temp_file in temp_files:
            os.remove(temp_file)

def llm_worker_persistent(in_queue, out_queue, llm_name, save_folder, dp_rank, model_kwargs, sampling_kwargs, run_id):
    """
    A persistent worker that loads the model once and processes prompts from a queue.

    Intended to run in its own process, one per GPU (``dp_rank``).

    Each item read from ``in_queue`` is ``(batch_id, args_list)``. ``args_list`` is a list
    of ``(caption, concept, attribute)`` tuples, and ``None`` stops the worker. For every
    non-empty batch the results are written tab-separated to
    ``{save_folder}/{run_id}_{dp_rank}_{batch_id}.csv`` with columns image_path, caption
    and hn_caption.

    Exactly one status string is put on ``out_queue`` after the model is loaded, and one
    after every batch. On error a message containing "failed" is sent instead, and the
    exception is re-raised so the process exits with a non-zero code.
    """
    # Restrict this process to a single GPU; this has to happen before the model initializes CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(dp_rank)

    try:
        # Use OS entropy rather than the module-level `random` state. Under the "fork" start
        # method every worker inherits the parent's state, so random.randint would give all
        # workers the same seed. Build a new dict instead of mutating the caller's object.
        seed = random.SystemRandom().randint(0, 1000000)
        model_kwargs = {**model_kwargs, "seed": seed}
        print(f"DP rank {dp_rank} using seed {seed}.")
        model = VllmModel(llm_ids[llm_name], llm_name, model_kwargs=model_kwargs, sampling_kwargs=sampling_kwargs)

        out_queue.put(f"DP rank {dp_rank} ready.")

        while True:
            data = in_queue.get()
            if data is None:
                break  # shutdown sentinel from the main process

            batch_id, args_list = data
            if not args_list:
                # Still report back: the main process waits for one message per worker per batch.
                out_queue.put(f"DP rank {dp_rank} finished batch {batch_id} (no prompts).")
                continue

            print(f"DP rank {dp_rank} processing {len(args_list)} prompts for batch {batch_id}")
            prompts = [model.hard_negative_prompt(concept, attribute, caption) for caption, concept, attribute in args_list]
            outputs = model.inference(prompts)
            outputs = [model.postprocess(output) for output in outputs]

            filename = f"{run_id}_{dp_rank}_{batch_id}.csv"
            # Placeholder image paths, unique per (rank, batch, position).
            image_paths = [f"{dp_rank}_{batch_id}/{i}.jpg" for i in range(len(args_list))]
            captions = [caption for caption, _, _ in args_list]
            df = pd.DataFrame({
                "image_path": image_paths,
                "caption": captions,
                "hn_caption": outputs
            })
            df.to_csv(os.path.join(save_folder, filename), index=False, sep="\t")
            out_queue.put(f"DP rank {dp_rank} finished batch {batch_id}.")

    except Exception as e:
        print(f"DP rank {dp_rank} failed with error: {e}")
        out_queue.put(f"DP rank {dp_rank} failed with error: {e}")
        raise  # bare raise keeps the original traceback intact

def load_json(filepath):
    """
    Load and return the JSON document stored at ``filepath``.

    The file is decoded as UTF-8 explicitly rather than with the platform's locale
    encoding. Otherwise captions with non-ASCII text could fail to load, or load garbled,
    on machines with a non-UTF-8 default locale.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def _split_evenly(items, num_parts):
    """Split ``items`` into ``num_parts`` contiguous slices whose lengths differ by at most one."""
    total = len(items)
    return [items[total * k // num_parts:total * (k + 1) // num_parts] for k in range(num_parts)]


def _iter_worker_messages(out_queue, procs, expected, poll_timeout=60.0):
    """
    Yield up to ``expected`` status messages from ``out_queue``.

    Polls with a timeout instead of blocking forever. If a worker process has died
    without reporting (e.g. killed by the OS), a synthetic message containing "failed"
    is yielded for it and iteration stops. Otherwise the caller would wait on the queue
    indefinitely. ``procs`` must be indexed by DP rank.
    """
    import queue

    received = 0
    while received < expected:
        try:
            msg = out_queue.get(timeout=poll_timeout)
        except queue.Empty:
            dead = [(rank, p) for rank, p in enumerate(procs) if not p.is_alive()]
            if not dead:
                continue  # everyone is still working; keep waiting
            for rank, p in dead:
                yield f"DP rank {rank} failed: process exited unexpectedly with exit code {p.exitcode}."
            return
        received += 1
        yield msg


def _terminate_alive(procs):
    """Terminate every worker process that is still running."""
    for p in procs:
        if p.is_alive():
            p.terminate()


def main(args):
    """
    Generate hard-negative captions for the curated captions and save the training csv.

    Loads ``captions_{N}.json``, ``concepts_{N}.json`` and ``attribute_indices_{N}.json``
    from ``args.final_captions_folder``. It then starts ``args.dp_size`` persistent worker
    processes (one per GPU) and feeds them the prompts batch by batch. Finally it merges
    the per-worker csv shards into ``captions_and_hn_{N}.csv`` inside ``args.save_folder``.

    Exits the process with a non-zero status if any worker fails.
    """
    import sys

    num_final_captions = args.num_final_captions
    curated_captions_folder = args.final_captions_folder
    save_folder = args.save_folder
    dp_size = args.dp_size
    llm_name = args.llm_name
    delete_temp = args.delete_temp

    login(os.environ["HF_TOKEN"])
    os.makedirs(save_folder, exist_ok=True)
    curated_captions_path = os.path.join(curated_captions_folder, f"captions_{num_final_captions}.json")
    curated_concepts_path = os.path.join(curated_captions_folder, f"concepts_{num_final_captions}.json")
    curated_attribute_indices_path = os.path.join(curated_captions_folder, f"attribute_indices_{num_final_captions}.json")

    assert os.path.exists(curated_captions_path), f"Curated captions file not found at {curated_captions_path}"
    assert os.path.exists(curated_concepts_path), f"Curated concepts file not found at {curated_concepts_path}"
    assert os.path.exists(curated_attribute_indices_path), f"Curated attribute indices file not found at {curated_attribute_indices_path}"

    curated_captions = load_json(curated_captions_path)
    curated_concepts = load_json(curated_concepts_path)
    # JSON object keys are always strings; convert them back to integer caption indices.
    curated_attribute_indices = {int(k): v for k, v in load_json(curated_attribute_indices_path).items()}

    print(f"Loaded {len(curated_captions)} curated captions.")
    print(f"Loaded {len(curated_concepts)} curated concepts.")
    print(f"Loaded {len(curated_attribute_indices)} curated attribute indices.")
    assert len(curated_captions) == len(curated_concepts), "Curated captions and concepts must have the same length."

    start_time = time.time()
    run_id = uuid.uuid4().hex

    # NOTE(review): this instance is never used below; kept only in case the
    # constructor has side effects that matter — confirm and drop if not.
    model = VllmModel(model_id=llm_ids[llm_name], model_type=llm_name, load_model=False)
    model_kwargs = {
        "max_model_len": 4096,
        "dtype": "float16",
        "gpu_memory_utilization": 0.90
    }
    sampling_kwargs = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_tokens": 256
    }

    batch_size = args.batch_size_per_gpu * dp_size
    num_batches = (len(curated_captions) + batch_size - 1) // batch_size
    print(f"Total number of prompts: {len(curated_captions)}, processing in {num_batches} batches of size {batch_size}")

    in_queues = [Queue() for _ in range(dp_size)]
    out_queue = Queue()

    procs = []
    for dp_rank in range(dp_size):
        proc = Process(target=llm_worker_persistent,
                       args=(in_queues[dp_rank], out_queue, llm_name, save_folder, dp_rank, model_kwargs, sampling_kwargs, run_id))
        proc.start()
        procs.append(proc)

    # Wait until every worker has loaded its model; abort everything on the first failure.
    for msg in _iter_worker_messages(out_queue, procs, dp_size):
        print(msg)
        if "failed" in msg:
            _terminate_alive(procs)
            sys.exit(1)

    exit_code = 0
    # Carried across batches: an attribute stays in effect until the next start index in
    # curated_attribute_indices.
    attribute = None
    for batch_id in tqdm(range(num_batches), desc="Processing batches"):
        batch_start = batch_id * batch_size
        batch_end = min(batch_start + batch_size, len(curated_captions))

        args_list = []
        for i in range(batch_start, batch_end):
            if i in curated_attribute_indices:
                attribute = curated_attribute_indices[i]  # a new attribute group starts here
            args_list.append((curated_captions[i], curated_concepts[i], attribute))

        # Contiguous, balanced chunks; the merge step relies on (batch_id, dp_rank) order.
        for dp_rank, chunk in enumerate(_split_evenly(args_list, dp_size)):
            in_queues[dp_rank].put((batch_id, chunk))

        # Every worker reports exactly once per batch (or dies).
        for msg in _iter_worker_messages(out_queue, procs, dp_size):
            if "failed" in msg:
                print(msg)
                exit_code = 1

        if exit_code:
            break

    # Shutdown sentinel: each worker leaves its loop after finishing its current batch.
    for in_queue in in_queues:
        in_queue.put(None)

    for proc in procs:
        proc.join()
        if proc.exitcode:
            exit_code = proc.exitcode

    if exit_code:
        print("A worker process failed. Terminating.")
        _terminate_alive(procs)
        sys.exit(exit_code)

    final_csv_path = os.path.join(save_folder, f"captions_and_hn_{num_final_captions}.csv")
    print("Merging results...")
    merge_and_save_csv(save_folder, final_csv_path, delete_temp, run_id)
    print(f"Finished in {time.time() - start_time:.1f}s.")


def str2bool(value):
    """
    Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0".

    ``argparse`` with ``type=bool`` treats every non-empty string, including "False",
    as True. This converter is used instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data Parallel Inference. Generate captions using vllm.")
    parser.add_argument("--llm-name", type=str, default="llama",
                        help="Name of the LLM to use")
    parser.add_argument("--final-captions-folder", type=str, required=True,
                        help="Path to the folder where the json with the final curated captions, concepts and attributes are stored.")
    parser.add_argument("--save-folder", type=str, required=True,
                        help="Path to the folder where the generated csv files will be saved.")
    parser.add_argument("--dp-size", type=int, default=8,
                        help="Data parallel size, number of GPUs on a single node to load one model instance each and split the prompts between GPUs.")
    # str2bool (not type=bool) so that "--delete-temp False" actually disables deletion;
    # a bare "--delete-temp" still means True.
    parser.add_argument("--delete-temp", type=str2bool, nargs="?", const=True, default=True,
                        help="Whether to delete the temporary csv files after merging them into one (true/false).")
    parser.add_argument("--batch-size-per-gpu", type=int, default=150000,
                        help="Number of prompts to process in a batch on each GPU.")
    parser.add_argument("--num-final-captions", type=int, default=12000000,
                        help="number of final captions. Used to load the correct caption json files.")
    args = parser.parse_args()
    main(args)
