import re
import tiktoken


# Previously we had a bug with ```3``` where 3 was parsed as a language tag.
# Only permit language tags that start with a letter.
# Group 1 is the fenced content (without the optional language tag).
CODE_BLOCK_PATTERN = re.compile(r"```(?:[a-zA-Z][\w+-]*)?\s*\n?([\s\S]*?)```")

# A complete <think ...>...</think> reasoning block (tags matched
# case-insensitively). Group 1 is the text between the tags.
THINK_BLOCK_PATTERN = re.compile(r"<think\b[^>]*>([\s\S]*?)</think>", re.IGNORECASE)

# An opening <think ...> tag followed by everything up to the end of the
# string. Group 1 is the text after the tag. Used for responses whose
# reasoning block was never closed.
OPEN_THINK_FALLBACK = re.compile(r"<think\b[^>]*>([\s\S]*)\Z", re.IGNORECASE)


def normalize_response(llm_result: str) -> str:
    """
    Extract the code (or the answer) from an LLM response.

    LLMs usually respond with ```python, ```sql, etc. This helps
    extract the responses from the last block.

    Complete <think>...</think> blocks are removed *before* anything else,
    so neither the reasoning text nor a code fence written inside the
    reasoning can be mistaken for the final answer.

    Resolution order, applied to the text left after removing complete
    think blocks:

    1. If there is at least one triple-backtick block, return the content
       of the last one.
    2. Otherwise, if an opening <think> tag was never closed, return the
       text that follows it (best-effort fallback for such output).
    3. Otherwise return the remaining text.

    Args:
        llm_result: Raw text produced by the model.

    Returns:
        The extracted code or answer with surrounding whitespace stripped.
        This is an empty string when nothing is left outside the
        reasoning blocks.
    """
    # Drop finished reasoning first; everything below only looks at what
    # the model produced outside of it.
    visible = THINK_BLOCK_PATTERN.sub("", llm_result)

    code_matches = CODE_BLOCK_PATTERN.findall(visible)
    if code_matches:
        # The last fenced block is taken as the final answer.
        return code_matches[-1].strip()

    # Any <think> tag still present has no closing tag, because all
    # complete blocks were removed above.
    open_think = OPEN_THINK_FALLBACK.search(visible)
    if open_think:
        return open_think.group(1).strip()

    return visible.strip()


def estimate_token_count(text: str) -> int:
    """
    Estimates the number of tokens in the text.

    THIS IS JUST AN ESTIMATE based on OPENAI GPT-4 tokenization.

    Text that happens to contain special-token markers (for example
    ``<|endoftext|>``) is counted as ordinary text instead of raising.

    Args:
        text: The text to measure.

    Returns:
        The number of tokens the GPT-4 encoding produces for ``text``.
    """
    encoding = tiktoken.encoding_for_model("gpt-4")
    # By default encode() raises ValueError when the input contains text
    # that looks like a special token. For a size estimate of arbitrary
    # text, treat such markers as plain text.
    tokens = encoding.encode(text, disallowed_special=())
    return len(tokens)
