import argparse
import glob
import json
import os
import queue
import random
import re
import sys
import time
import uuid
from multiprocessing import Process, Queue

import pandas as pd
from huggingface_hub import login
from tqdm import tqdm

from engines import VllmModel

# Short LLM name (as accepted by --llm-name) -> Hugging Face model id.
llm_ids = dict(
    mistral02="mistralai/Mistral-7B-Instruct-v0.2",
    llama="meta-llama/Llama-3.1-8B-Instruct",
)

def llm_worker_persistent(in_queue, out_queue, llm_name, save_folder, dp_rank, model_kwargs, sampling_kwargs, run_id):
    """
    A persistent worker that loads the model once and processes prompts from a queue.

    Intended to run in its own process, pinned to GPU ``dp_rank``. Message protocol:

    * After the model is built, ``"DP rank {dp_rank} ready."`` is put on ``out_queue``.
    * Each ``(batch_id, prompts)`` item read from ``in_queue`` is run through
      ``model.inference`` and ``model.postprocess``. The results are written as
      JSON to ``{save_folder}/{run_id}_{dp_rank}_{batch_id}.json`` and a
      "finished batch" message is put on ``out_queue``. An empty ``prompts``
      list writes no file but still produces a reply.
    * A ``None`` item ends the loop and the function returns normally.
    * Any exception puts a message containing ``"failed"`` on ``out_queue`` and
      is re-raised, so the process exits with a non-zero exit code.

    Args:
        in_queue: Queue of ``(batch_id, prompts)`` tuples, terminated by ``None``.
        out_queue: Queue shared by all workers for status messages.
        llm_name: Key into ``llm_ids``.
        save_folder: Directory for the per-batch output files.
        dp_rank: Data-parallel rank; also used as the visible GPU index.
        model_kwargs: Model construction options. Not modified; a ``seed`` is
            added to a private copy.
        sampling_kwargs: Sampling options forwarded to ``VllmModel``.
        run_id: Prefix shared by all output files of one run.
    """
    # Restrict this process to one GPU before the model is created below
    # (presumably so vLLM only sees that device).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(dp_rank)

    try:
        # Seed from OS entropy: with the "fork" start method every worker inherits
        # the parent's `random` state, so `random.randint` would give every rank
        # the same seed. Copy the dict instead of mutating the caller's.
        model_kwargs = {**model_kwargs, "seed": random.SystemRandom().randint(0, 1000000)}
        model = VllmModel(llm_ids[llm_name], llm_name, model_kwargs=model_kwargs, sampling_kwargs=sampling_kwargs)

        out_queue.put(f"DP rank {dp_rank} ready.")

        while True:
            data = in_queue.get()
            if data is None:  # shutdown sentinel
                break

            batch_id, prompts = data
            if not prompts:
                # Still reply: the coordinator waits for one message per rank per batch.
                out_queue.put(f"DP rank {dp_rank} finished batch {batch_id} (no prompts).")
                continue

            print(f"DP rank {dp_rank} processing {len(prompts)} prompts for batch {batch_id}")

            outputs = [model.postprocess(output) for output in model.inference(prompts)]

            final_path = os.path.join(save_folder, f"{run_id}_{dp_rank}_{batch_id}.json")
            # Write under a name that no "*.json" glob matches, then rename.
            # A crash mid-write therefore never leaves a truncated JSON file
            # that would later be picked up for merging.
            tmp_path = final_path + ".part"
            with open(tmp_path, "w") as f:
                json.dump(outputs, f)
            os.replace(tmp_path, final_path)

            out_queue.put(f"DP rank {dp_rank} finished batch {batch_id}.")
    except Exception as e:
        print(f"DP rank {dp_rank} failed with error: {e}")
        # The coordinator detects failure by the substring "failed".
        out_queue.put(f"DP rank {dp_rank} failed with error: {e}")
        raise

def _temp_file_sort_key(path, run_id=None):
    """Return a key that orders temp files in caption-generation order.

    Worker files are named ``{run_id}_{dp_rank}_{batch_id}.json``. Batches are
    dispatched one after another and each batch is split across ranks in rank
    order, so the generation order is (batch_id, dp_rank). Both fields are
    compared as integers; a plain string sort would group files by rank and put
    batch "10" before batch "2". Names that do not parse sort last, by name.

    Legacy files (no ``run_id``) look like ``<digits>[_<digits>].json``. Their
    numeric fields are compared as integers in the order they appear.
    """
    stem = os.path.basename(path)[: -len(".json")]
    if run_id:
        try:
            rank, batch = (int(field) for field in stem[len(run_id) + 1:].split("_"))
        except ValueError:  # wrong field count or non-numeric field
            return (1, (), stem)
        return (0, (batch, rank), stem)
    return (0, tuple(int(field) for field in stem.split("_")), stem)


def merge_and_save_json(save_folder, attribute, final_json_name, delete_temp=True, run_id=None):
    """
    Merge the json files generated by each worker and save the final json.

    The merged file ``{save_folder}/{final_json_name}`` has the form
    ``{"attribute": attribute, "captions": [...]}``. ``captions`` is the
    concatenation of the temp files' JSON lists in generation order.

    Args:
        save_folder: Directory holding the temp files; the merged file is written here too.
        attribute: Value stored under the ``"attribute"`` key.
        final_json_name: File name (not a path) of the merged output.
        delete_temp: If True, delete the temp files that were merged. Files that
            could not be decoded are skipped and kept on disk for inspection.
        run_id: If given, merge the files named ``{run_id}_*.json``. Otherwise
            merge legacy files whose names match ``<digits>[_<digits>].json``.
    """
    final_path = os.path.join(save_folder, final_json_name)

    if run_id:
        candidates = glob.glob(os.path.join(save_folder, f"{run_id}_*.json"))
    else:
        legacy_name = re.compile(r"^\d+(?:_\d+)?\.json$")
        candidates = [
            path for path in glob.glob(os.path.join(save_folder, "*.json"))
            if legacy_name.match(os.path.basename(path))
        ]
    # Never read the merged output back as an input; order numerically so the
    # captions follow generation order.
    temp_files = sorted(
        (path for path in candidates if path != final_path),
        key=lambda path: _temp_file_sort_key(path, run_id),
    )

    captions = []
    merged_files = []
    for temp_file in temp_files:
        try:
            with open(temp_file, "r") as f:
                captions.extend(json.load(f))
        except json.JSONDecodeError:
            print(f"Warning: Could not decode JSON from file {temp_file}. Skipping.")
            continue
        merged_files.append(temp_file)

    output_data = {
        "attribute": attribute,
        "captions": captions
    }

    print(f"Saving all metadata to {final_path}...", end="")
    with open(final_path, "w") as f:
        json.dump(output_data, f, indent=4)
    print("Done.")

    if delete_temp:
        print("Removing individual json files...")
        for temp_file in merged_files:
            os.remove(temp_file)

def _split_evenly(items, parts):
    """Split ``items`` into ``parts`` contiguous slices whose lengths differ by at most one.

    Concatenating the slices in order gives back ``items``.
    """
    base, extra = divmod(len(items), parts)
    chunks = []
    start = 0
    for index in range(parts):
        end = start + base + (1 if index < extra else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


def _size_suffix(total):
    """Return a compact label for ``total``, e.g. ``"12M"``, ``"250k"``, ``"900"``.

    Values are truncated, not rounded (12_500_000 -> ``"12M"``).
    """
    if total >= 1_000_000:
        return f"{total // 1_000_000}M"
    if total >= 1_000:
        return f"{total // 1_000}k"
    return str(total)


def _next_reply(out_queue, procs, poll_seconds=30.0):
    """Return the next worker status message from ``out_queue``.

    Blocks like ``out_queue.get()``, but wakes every ``poll_seconds`` to check
    whether a worker process has exited. Workers only exit after the ``None``
    sentinel (sent once all replies are collected) or after a failure. A dead
    worker while replies are pending therefore means it died without
    reporting (e.g. killed by the OS). In that case a synthetic message
    containing "failed" is returned instead of blocking forever.
    """
    while True:
        try:
            return out_queue.get(timeout=poll_seconds)
        except queue.Empty:
            for dp_rank, proc in enumerate(procs):
                if not proc.is_alive():
                    return (f"DP rank {dp_rank} failed: process exited with code "
                            f"{proc.exitcode} without reporting.")


def _terminate_alive(procs):
    """Terminate every worker process that is still running."""
    for proc in procs:
        if proc.is_alive():
            proc.terminate()


def main(args):
    """Generate captions for one attribute with ``args.dp_size`` data-parallel workers.

    Steps:

    1. Validate the arguments and log in to the Hugging Face Hub with ``$HF_TOKEN``.
    2. Load the concept bank, a JSON list that is cycled over.
    3. Start one ``llm_worker_persistent`` process per rank and feed it batches of prompts.
    4. Merge the per-batch files into ``metadata_{attribute}_{size}.json`` in
       ``args.save_folder``.

    Terminates via ``sys.exit``: code 0 on success, non-zero if any worker
    failed (in which case nothing is merged).

    Raises:
        ValueError: if ``args.attribute`` is unknown, ``--dp-size`` or
            ``--batch-size-per-gpu`` is below 1, or the concept bank is empty.
    """
    attributes = ["background", "color", "concept", "lighting", "material", "perspective", "position", "style"]
    # Raise rather than assert: asserts are stripped under `python -O`.
    if args.attribute not in attributes:
        raise ValueError(f"Invalid attribute: {args.attribute}. Must be one of {attributes}.")
    dp_size = args.dp_size
    if dp_size < 1 or args.batch_size_per_gpu < 1:
        raise ValueError("--dp-size and --batch-size-per-gpu must both be at least 1.")
    start_time = time.time()
    # Prefix for this run's temp files, so the merge ignores leftovers from other runs.
    run_id = uuid.uuid4().hex

    login(os.environ["HF_TOKEN"])
    with open(args.concept_path, "r") as f:
        concept_bank = json.load(f)
    if not concept_bank:
        raise ValueError(f"Concept bank {args.concept_path} is empty.")

    # Used here only to build prompts (load_model=False; presumably no weights
    # are loaded). The workers construct their own models.
    model = VllmModel(model_id=llm_ids[args.llm_name], model_type=args.llm_name, load_model=False)
    os.makedirs(args.save_folder, exist_ok=True)

    model_kwargs = {
        "max_model_len": 4096,
        "dtype": "float16",
        "gpu_memory_utilization": 0.90
    }
    sampling_kwargs = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_tokens": 256,
        "presence_penalty": 1.0,
        "frequency_penalty": 1.0,
    }

    batch_size = args.batch_size_per_gpu * dp_size
    num_batches = (args.num_captions + batch_size - 1) // batch_size  # ceiling division
    print(f"Total number of prompts: {args.num_captions}, processing in {num_batches} batches of size {batch_size}")

    in_queues = [Queue() for _ in range(dp_size)]
    out_queue = Queue()

    procs = []
    for dp_rank in range(dp_size):
        proc = Process(target=llm_worker_persistent,
                       args=(in_queues[dp_rank], out_queue, args.llm_name, args.save_folder, dp_rank, model_kwargs, sampling_kwargs, run_id))
        proc.start()
        procs.append(proc)

    # Wait until every worker has loaded its model; abort on the first failure.
    for _ in range(dp_size):
        msg = _next_reply(out_queue, procs)
        print(msg)
        if "failed" in msg:
            _terminate_alive(procs)
            sys.exit(1)

    exit_code = 0
    for batch_id in tqdm(range(num_batches), desc="Processing batches"):
        batch_start = batch_id * batch_size
        batch_end = min(batch_start + batch_size, args.num_captions)
        prompts = [
            model.caption_prompt(concept_bank[i % len(concept_bank)], args.attribute)
            for i in range(batch_start, batch_end)
        ]

        # Even split so a short final batch is spread over all GPUs.
        for dp_rank, chunk in enumerate(_split_evenly(prompts, dp_size)):
            in_queues[dp_rank].put((batch_id, chunk))

        # One reply per rank, including ranks that received an empty chunk.
        for _ in range(dp_size):
            msg = _next_reply(out_queue, procs)
            if "failed" in msg:
                print(msg)
                exit_code = 1

        if exit_code:
            break

    # `None` is the shutdown sentinel for the worker loop.
    for in_queue in in_queues:
        in_queue.put(None)

    for proc in procs:
        proc.join()
        if proc.exitcode:
            exit_code = proc.exitcode

    if exit_code:
        print("A worker process failed. Terminating.")
        _terminate_alive(procs)
        sys.exit(exit_code)

    final_json_name = f"metadata_{args.attribute}_{_size_suffix(args.num_captions)}.json"
    merge_and_save_json(args.save_folder, args.attribute, final_json_name, args.delete_temp, run_id)

    total_time = time.time() - start_time
    print("Total time taken:", total_time, "seconds")
    sys.exit(exit_code)

def parse_bool(value):
    """Parse a command-line boolean such as ``"true"``, ``"False"``, ``"1"`` or ``"no"``.

    Used as an argparse ``type=``. ``type=bool`` would call ``bool(str)``, which
    is True for every non-empty string, so ``--delete-temp False`` would still
    delete.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data Parallel Inference. Generate captions using vllm.")
    # `choices` gives a clear CLI error instead of a KeyError inside a worker.
    parser.add_argument("--llm-name", type=str, default="mistral02", choices=sorted(llm_ids),
                        help="Name of the LLM to use")
    parser.add_argument("--save-folder", type=str, required=True,
                        help="Path to the folder where the json with the generated captions will be saved.")
    parser.add_argument("--dp-size", type=int, default=8,
                        help="Data parallel size, number of GPUs on a single node to load one model instance each and split the prompts between GPUs.")
    parser.add_argument("--concept-path", type=str, required=True,
                        help="Path to the concept bank.")
    parser.add_argument("--num-captions", type=int, default=12500000,
                        help="Number of captions to generate.")
    parser.add_argument("--delete-temp", type=parse_bool, default=True,
                        help="Whether to delete the temporary json files after merging them into one (true/false).")
    parser.add_argument("--attribute", type=str, required=True,
                        help="Attribute to use for all prompts.")
    parser.add_argument("--batch-size-per-gpu", type=int, default=156250,
                        help="Number of prompts to process in a batch on each GPU.")
    args = parser.parse_args()
    main(args)