import os
import fire
import json
from typing import Literal

from model.llm import LLMS, SamplingParams
from utils import load_jsonl, save_jsonl


def tag_prompts(
    model_name_or_path: str,
    data_dir: str,
    datasets: list[str],
    cache_dir: str | None = None,
    engine: Literal["huggingface", "vllm"] = "huggingface",
    chunk_size: int = -1,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_new_tokens: int = 4096,
    **engine_kwargs,
) -> None:
    """Tag every prompt of the given datasets with an LLM and save the result.

    For each dataset, records are read from
    ``<data_dir>/<dataset>/instructions.jsonl``, a tagging prompt is built from
    each record's ``"prompt"`` field, and the parsed tags are stored under a new
    ``"tags"`` key. The augmented records are written next to the input as
    ``<data_dir>/<dataset>/instructions_tag.jsonl``.

    Args:
        model_name_or_path: Passed as the first argument to the selected engine.
        data_dir: Root directory containing one sub-directory per dataset.
        datasets: Dataset sub-directory names. A single name given as a bare
            string is accepted and treated as a one-element list.
        cache_dir: If given, ``<cache_dir>/<dataset>`` is forwarded to
            ``llm.generate`` as its ``cache_dir``; otherwise ``None`` is passed.
        engine: Key used to pick the backend from ``LLMS``.
        chunk_size: Forwarded unchanged to ``llm.generate``. The meaning of the
            default ``-1`` is defined by the engine (presumably "no chunking" —
            confirm against the engine implementation).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Passed to ``SamplingParams`` as ``max_tokens``.
        **engine_kwargs: Extra keyword arguments forwarded to the engine.
    """
    # A lone dataset name (e.g. `--datasets foo` on the command line) can arrive
    # as a plain string; iterating it would walk over its characters.
    if isinstance(datasets, str):
        datasets = [datasets]

    # Load the model once and reuse it for every dataset.
    llm = LLMS[engine](model_name_or_path, **engine_kwargs)
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens)

    for dataset in datasets:
        print(f"Processing dataset: {dataset}")
        dataset_dir = os.path.join(data_dir, dataset)
        data_path = os.path.join(dataset_dir, "instructions.jsonl")
        data = load_jsonl(data_path)
        prompts = [make_prompt(item["prompt"]) for item in data]
        responses = llm.generate(
            prompts,
            sampling_params,
            chunk_size=chunk_size,
            cache_dir=os.path.join(cache_dir, dataset) if cache_dir is not None else None,
        )
        results = [
            {
                **item,
                # Only the first element of each response is parsed
                # (presumably the first sampled completion — confirm).
                "tags": parse_output(response[0]),
            }
            for item, response in zip(data, responses)
        ]

        # Build the output path explicitly: a textual ".jsonl" replacement would
        # also rewrite any ".jsonl" occurring in data_dir or the dataset name.
        save_jsonl(results, os.path.join(dataset_dir, "instructions_tag.jsonl"))


def make_prompt(query: str) -> str:
    """Wrap *query* in the tagging instruction and a USER/ASSISTANT chat template.

    The returned string is the system preamble, one USER turn carrying the
    tagging instruction plus the query, and an open ``ASSISTANT:`` turn for the
    model to complete.
    """
    system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    separators = (" ", "</s>")
    instruction = f'Please identify tags of user intentions in the following user query and provide an explanation for each tag. Please response in the JSON format {{"tag": str, "explanation": str}}.\n User query: {query}'
    turns = (("USER", instruction), ("ASSISTANT", None))

    pieces = [system, separators[0]]
    for index, (speaker, text) in enumerate(turns):
        if not text:
            # A turn without content is left open for the model to fill in.
            pieces.append(speaker + ":")
            continue
        # Completed turns end with the separator matching their position.
        pieces.append(speaker + ": " + text + separators[index % 2])
    return "".join(pieces)


def parse_output(output: str) -> list[str]:
    """Extract the tag names from one model response.

    The response is first parsed as JSON. Both a list of
    ``{"tag": ..., "explanation": ...}`` objects and a single such object are
    accepted. If the text is not valid JSON (e.g. truncated output or extra
    prose around the JSON), every value that follows a ``"tag": "`` marker is
    salvaged instead.

    Args:
        output: Raw text generated by the model.

    Returns:
        The tag names in order of appearance; an empty list when none is found.

    Raises:
        Exception: Any non-decoding error (for instance an item without a
            ``"tag"`` key) is re-raised after printing the offending output.
    """
    try:
        parsed = json.loads(output)
        # The prompt shows a single object as the expected format, so a bare
        # object must be handled as well as a list of objects.
        if isinstance(parsed, dict):
            parsed = [parsed]
        return [item["tag"] for item in parsed]
    except json.JSONDecodeError:
        # Split on the marker *including* the opening quote so that each chunk
        # starts with the tag text; the tag ends at the next double quote.
        marker = '"tag": "'
        return [chunk.split('"', 1)[0] for chunk in output.split(marker)[1:]]
    except Exception:
        # Show the problematic output to ease debugging, then propagate.
        print(output)
        raise


# Script entry point: expose tag_prompts as a command-line interface via Fire,
# which maps command-line arguments onto the function's parameters.
if __name__ == "__main__":
    fire.Fire(tag_prompts)
