import time
from typing import Dict, List, Optional

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from fair_gpt.generation.utils import (
    CHAT_FEW_SHOT_PRE_PROMPT,
    CHAT_ZERO_SHOT,
    PRETRAINED_FEW_SHOT,
    SYSTEM_PROMPT,
)


def build_prompt_chat(
    few_shots: List[Dict[str, str]],
) -> List[dict]:
    """Build the chat message list for biography generation.

    The conversation is a system message, one user message and a partial
    assistant message that already contains ``"Biography:\\n"``. The trailing
    assistant turn is meant to be continued by the model; ``build_prompt``
    renders it with ``continue_final_message=True``.

    Args:
        few_shots: Example biographies; each dict is read via its
            ``"content"`` key. When empty, the zero-shot user prompt
            ``CHAT_ZERO_SHOT`` is used instead of the few-shot preamble.

    Returns:
        The list of ``{"role": ..., "content": ...}`` message dicts.
    """
    if not few_shots:
        user_prompt = CHAT_ZERO_SHOT
    else:
        # Each example becomes its own "Biography:" section, separated from the
        # preceding text by a blank line.
        user_prompt = CHAT_FEW_SHOT_PRE_PROMPT + "".join(
            f"\n\nBiography:\n{shot['content']}" for shot in few_shots
        )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": "Biography:\n"},
    ]


def build_prompt_pretrained(few_shots: List[Dict[str, str]]) -> str:
    """Return the plain-text few-shot prompt for a base (non-chat) model.

    The ``PRETRAINED_FEW_SHOT`` preamble is followed by one "Biography:"
    section per example (read from each dict's ``"content"`` key), in input
    order, each preceded by a blank line.
    """
    sections = [f"\n\nBiography:\n{example['content']}" for example in few_shots]
    return PRETRAINED_FEW_SHOT + "".join(sections)


def build_prompt(
    few_shots: List[Dict[str, str]],
    chat: bool = True,
    tokenizer: Optional[AutoTokenizer] = None,
) -> str:
    """Render the biography-generation prompt as a single (untokenized) string.

    Args:
        few_shots: Example biographies; each dict is read via its
            ``"content"`` key.
        chat: If True, build chat messages with ``build_prompt_chat`` and
            render them through the tokenizer's chat template; otherwise
            return the plain-text prompt from ``build_prompt_pretrained``.
        tokenizer: Tokenizer providing ``apply_chat_template``. Required when
            ``chat`` is True; ignored otherwise.

    Returns:
        The prompt text.

    Raises:
        ValueError: If ``chat`` is True and ``tokenizer`` is None.
    """
    if not chat:
        return build_prompt_pretrained(few_shots)
    if tokenizer is None:
        # chat=True is the default, so fail with an actionable message rather
        # than an AttributeError on None.apply_chat_template.
        raise ValueError("build_prompt(chat=True) requires a tokenizer")
    messages = build_prompt_chat(few_shots)
    # continue_final_message=True keeps the trailing assistant turn open so the
    # model continues the pre-filled text instead of starting a new turn.
    return tokenizer.apply_chat_template(
        messages, continue_final_message=True, tokenize=False
    )


# ----------------------------
# vLLM runner
# ----------------------------
def generate_with_vllm(
    llm: LLM,
    prompt_token_ids: List[List[int]],
    sampling_params: SamplingParams,
    chat: bool = True,
    lora=None,
):
    """Run one ``llm.generate`` call and return the first completion per prompt.

    Args:
        llm: The vLLM engine to generate with.
        prompt_token_ids: Prompts forwarded as ``prompts=`` to ``llm.generate``.
            NOTE(review): despite the name and annotation, the ``__main__``
            driver passes prompt strings here; confirm which form other
            callers rely on.
        sampling_params: Sampling configuration applied to every prompt.
        chat: Unused; kept so existing callers passing it keep working.
        lora: Optional LoRA request forwarded as ``lora_request``.

    Returns:
        A list with the text of the first completion of each request.
    """
    print("Len of prompt_token_ids:", len(prompt_token_ids))
    if not prompt_token_ids:
        # Nothing to generate; skip the engine call entirely.
        return []
    outputs = llm.generate(
        prompts=prompt_token_ids, sampling_params=sampling_params, lora_request=lora
    )
    # Only the first sample of each request is kept (relevant if n > 1).
    return [out.outputs[0].text for out in outputs]


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument(
        "--num_prompts",
        type=int,
        default=50,
        help="Number of copies of the prompt to generate for (one request each).",
    )
    args = parser.parse_args()

    model_path = f"models_dir/{args.model_name}"
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    few_shot_lists = [
        {
            "content": "John Doe (born January 1, 1970) is an American software engineer and entrepreneur. He is best known for founding Tech Innovations, a leading technology company specializing in AI and machine learning solutions. Doe graduated from MIT with a degree in Computer Science and has been recognized for his contributions to the tech industry with several awards."
        },
        {
            "content": "Jane Smith (born February 2, 1980) is a British author and journalist. She has written several best-selling novels, including 'The Silent Echo' and 'Whispers in the Wind.' Smith's work often explores themes of identity and human connection. In addition to her novels, she has contributed to various newspapers and magazines."
        },
    ]

    prompt = build_prompt(few_shot_lists, tokenizer=tokenizer, chat=True)

    print("Prompt:")
    print(prompt)

    prompt_lists = [prompt] * args.num_prompts
    sampling_params = SamplingParams(
        temperature=0.7, max_tokens=150, top_p=0.9, stop=["\n\n"]
    )
    llm = LLM(model_path)

    # Each prompt goes through its own generate() call, so this measures
    # sequential single-request latency rather than batched throughput.
    # Only the last call's outputs are kept for printing.
    text_list: List[str] = []
    start_time = time.perf_counter()
    for p in prompt_lists:
        text_list = generate_with_vllm(llm, [p], sampling_params, chat=True)
    end_time = time.perf_counter()
    print(text_list)
    print(f"Time taken for generation: {end_time - start_time} seconds")
