"""Prompt building and post-processing utilities.

Extracted from the monolithic script to separate concerns.
"""

from __future__ import annotations

import re
from typing import Any, Dict


def build_chat(tokenizer, prompt: str, model_name: str) -> Any:
    """Wrap a raw prompt in the chat format expected by ``model_name``.

    The model family is detected by substring match on ``model_name``. The
    checks run in order, so more specific names (``chatglm3``) are tested
    before their prefixes (``chatglm``). Unrecognized models get ``prompt``
    back unchanged.

    Args:
        tokenizer: HF tokenizer instance. Only the chatglm and Llama 3
            branches call methods on it.
        prompt: Raw user prompt.
        model_name: Model identifier string.

    Returns:
        The formatted prompt string. Exception: for ``chatglm3`` models this
        is whatever ``tokenizer.build_chat_input`` returns (presumably an
        already-tokenized encoding rather than a ``str`` — confirm against
        the caller), so callers must handle that case separately.

    Time/Space: O(|prompt|) for the string-formatting branches. The
    tokenizer-backed branches depend on the tokenizer implementation.
    """
    if "chatglm3" in model_name:
        prompt = tokenizer.build_chat_input(prompt)
    elif "chatglm" in model_name:
        prompt = tokenizer.build_prompt(prompt)
    elif "longchat" in model_name or "vicuna" in model_name:
        from fastchat.model import get_conversation_template  # type: ignore

        conv = get_conversation_template("vicuna")
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    elif "llama" in model_name:
        # Require an explicit "llama3" / "llama-3" / "llama_3" version tag.
        # A bare ``"3" in model_name`` also fires on sizes and other versions
        # (e.g. "llama2-13b", "codellama-34b"). The (?!\d) lookahead keeps
        # "llama-30b" / "llama-34b" from matching.
        if re.search(r"llama[-_ ]?3(?!\d)", model_name, re.IGNORECASE):
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = f"[INST]{prompt}[/INST]"
    elif "mistral" in model_name:
        prompt = f"[INST]{prompt}[/INST]"
    elif "xgen" in model_name:
        header = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        )
        prompt = header + f" ### Human: {prompt}\n###"
    elif "internlm" in model_name:
        prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
    return prompt


def post_process(response: str, model_name: str) -> str:
    """Strip model-specific artifacts from a decoded generation.

    Args:
        response: Text decoded from the model's output tokens.
        model_name: Model identifier; matched by substring, xgen first.

    Returns:
        For xgen models, the whitespace-stripped text with every
        ``"Assistant:"`` removed (stripping happens before removal). For
        internlm models, the text before the first ``"<eoa>"`` marker, or the
        whole text if there is none. Otherwise ``response`` unchanged.

    Time/Space: O(|response|).
    """
    if "xgen" in model_name:
        trimmed = response.strip()
        return trimmed.replace("Assistant:", "")
    if "internlm" in model_name:
        before_marker, _marker, _rest = response.partition("<eoa>")
        return before_marker
    return response


def build_inputs(model, tokenizer, prompt: str) -> Dict[str, Any]:
    """Tokenize ``prompt`` into generation-ready tensors on the model's device.

    If the tokenizer defines a (truthy) ``chat_template``, the prompt becomes
    a user turn after a fixed system message and is rendered with the
    generation prompt appended. Otherwise the prompt is tokenized directly
    with special tokens.

    Args:
        model: HF model. Only ``model.device`` is read.
        tokenizer: HF tokenizer.
        prompt: Input prompt string.

    Returns:
        Dict mapping input names (at least ``"input_ids"``) to tensors moved
        to ``model.device``. When the chat-template call yields a bare
        tensor, an all-ones ``"attention_mask"`` of the same shape is added.
    """
    if getattr(tokenizer, "chat_template", None):
        msgs = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        enc = tokenizer.apply_chat_template(
            msgs, add_generation_prompt=True, tokenize=True, return_tensors="pt"
        )
        # Without return_dict=True, apply_chat_template returns a bare
        # input_ids tensor, which has no .items(). Wrap it so both tokenizer
        # paths return a mapping. The mask is all ones because this is a
        # single, unpadded sequence.
        if not hasattr(enc, "items"):
            enc = {"input_ids": enc, "attention_mask": enc.new_ones(enc.shape)}
    else:
        enc = tokenizer(prompt, add_special_tokens=True, return_tensors="pt")
    return {k: v.to(model.device) for k, v in enc.items()}


def setup_eos_pad(model, tokenizer) -> None:
    """Set EOS and pad token IDs on ``model.generation_config``.

    Collects the tokenizer's own ``eos_token_id`` (an int or a sequence of
    ints). It then adds the ids of any of ``<|im_end|>``, ``<|endoftext|>``
    and ``<|eot_id|>`` that are present in the tokenizer's vocabulary. All
    collected ids are registered as EOS ids, deduplicated in first-seen
    order, so the tokenizer's own EOS stays first. If the tokenizer has no
    pad token, that first EOS id is used as the pad id.

    Looking up the extra tokens is best-effort. A tokenizer whose
    ``get_vocab`` or ``convert_tokens_to_ids`` raises simply contributes no
    extra ids. ``generation_config`` is left untouched for any field that
    has no value to set.

    Args:
        model: HF model whose ``generation_config`` is mutated in place.
        tokenizer: HF tokenizer providing EOS/pad ids and the vocabulary.
    """
    eos = tokenizer.eos_token_id
    eos_ids = [eos] if isinstance(eos, int) else list(eos or [])

    try:
        # get_vocab() materializes the full vocabulary: build it once, not
        # once per candidate token.
        vocab = tokenizer.get_vocab()
    except Exception:  # best-effort: some tokenizers don't implement it
        vocab = {}

    for tok in ("<|im_end|>", "<|endoftext|>", "<|eot_id|>"):
        if tok not in vocab:
            continue
        try:
            tid = tokenizer.convert_tokens_to_ids(tok)
            if tid is not None:
                eos_ids.append(int(tid))
        except Exception:  # best-effort: skip tokens that cannot be resolved
            pass

    # Order-preserving dedupe: sorting would move the tokenizer's primary EOS
    # and make the pad fallback below pick the numerically smallest id.
    eos_ids = list(dict.fromkeys(i for i in eos_ids if i is not None))

    pad_id = tokenizer.pad_token_id
    if pad_id is None and eos_ids:
        pad_id = eos_ids[0]

    if eos_ids:
        model.generation_config.eos_token_id = eos_ids
    if pad_id is not None:
        model.generation_config.pad_token_id = pad_id


