import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

import json
import argparse
from collections import defaultdict
from copy import deepcopy

from vllm import LLM, SamplingParams
from tqdm import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset

from models import LlamaDraftForCausalLM, LlamaForCausalLM
from utils import Timer
from preprocess import get_tokenizer


if __name__ == "__main__":
    # Regenerate every assistant turn of the ShareGPT conversations with
    # `--model-id` (greedy decoding) and write the result back out in
    # ShareGPT format to generated_data.json.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-id",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="The model ID",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=2048,
        help="The maximum length of the input sequence",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="tensor parallel size",
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=64,
        help="maximum number of sequences",
    )

    args = parser.parse_args()

    with Timer("Loading dataset..."):
        dataset = load_dataset(
            "Aeala/ShareGPT_Vicuna_unfiltered",
            data_files={
                "valid": "ShareGPT_V4.3_unfiltered_cleaned_split.json",
            },
            split="valid" if not args.debug else "valid[99%:]",
        )

    conversations = dataset["conversations"]
    # ShareGPT speaker tags -> chat-template roles. Turns with any other tag
    # are dropped below.
    role_mapper = {
        "human": "user",
        "gpt": "assistant",
    }
    system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."

    with Timer("Loading tokenizer..."):
        tokenizer = get_tokenizer(args.model_id)

    # message index -> [{"id": conversation key, "idx": message index}, ...]
    # for every conversation that has an assistant turn at that index.
    generated_data = defaultdict(list)
    # conversation key -> message list (system prompt first). Assistant
    # contents are overwritten in place with the model's generations.
    generated_msgs = dict()

    for ith_conv, conversation in tqdm(enumerate(conversations), total=len(conversations)):
        messages = [{
            "role": "system",
            "content": system_prompt,
        }]
        for turn in conversation:
            role = role_mapper.get(turn["from"])
            if role is None:
                # An unmapped tag would produce role=None, which would be fed
                # to the chat template and would leave the message without a
                # "from" key in the rollback below.
                continue
            if role == "assistant" and len(messages) == 1:
                # Drop assistant turns that come before any user turn.
                continue
            messages.append({
                "role": role,
                "content": turn["value"],
            })

        # Key by position rather than the dataset "id" so that duplicate ids
        # cannot overwrite each other's conversations.
        generated_msgs[ith_conv] = messages

        for idx, msg in enumerate(messages):
            if msg["role"] == "assistant":
                generated_data[idx].append({
                    "id": ith_conv,
                    "idx": idx,
                })

    llm = LLM(model=args.model_id, tensor_parallel_size=args.tp, max_num_seqs=args.max_num_seqs)

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=args.max_length,
    )

    # Handle assistant positions in increasing order so that each prompt's
    # history already contains the model's own earlier replies (written back
    # into generated_msgs by previous rounds).
    for turn_idx in sorted(generated_data):
        # (conversation key, message index, prompt text) for prompts kept
        # this round.
        batch = []
        for item in generated_data[turn_idx]:
            conv_key = item["id"]
            msg_idx = item["idx"]

            prompt = tokenizer.apply_chat_template(
                generated_msgs[conv_key][:msg_idx],
                tokenize=False,
                add_generation_prompt=True,
            )
            num_tokens = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])

            if num_tokens >= args.max_length:
                # Too long: this turn keeps its original dataset reply. Later
                # turns of the same conversation extend this prefix, so they
                # will typically be skipped as well.
                continue

            batch.append((conv_key, msg_idx, prompt))

        if not batch:
            continue

        outputs = llm.generate([p for _, _, p in batch], sampling_params)

        # Pair results with the prompts actually sent, not with
        # generated_data[turn_idx]. Skipped prompts would otherwise shift every
        # later result onto the wrong conversation.
        for (conv_key, msg_idx, _), output in zip(batch, outputs):
            generated_msgs[conv_key][msg_idx]["content"] = output.outputs[0].text

    generated_json = []
    # rollback role: add ShareGPT-style "from"/"value" keys to every message
    for ith_msg, gen_msg in enumerate(generated_msgs.values()):
        for msg in gen_msg:
            if msg["role"] == "user":
                msg["from"] = "human"
            elif msg["role"] == "assistant":
                msg["from"] = "gpt"
            msg["value"] = msg["content"]
        generated_json.append({
            "id": ith_msg,
            # drop the injected system prompt
            "conversations": gen_msg[1:],
        })

    with Timer("Saving generated data..."):
        with open("generated_data.json", "w", encoding="utf-8") as f:
            json.dump(generated_json, f, indent=2)
