from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from logging import getLogger
from typing import Dict, List, Optional, Tuple

import litellm
from tqdm import tqdm
from transformers import AutoTokenizer

from pcot.parser import parse_random_digit_sequence, parse_random_string

# Module-level logger named after this module.
logger = getLogger(name=__name__)

# Text appended to the thinking trace to make the model continue reasoning.
WAIT_TOKEN: str = "Wait"
# Marker appended after the thinking trace to start the answer phase.
THINKING_END: str = "<|im_start|>answer"


@lru_cache(maxsize=8)
def _load_tokenizer(tokenizer_name: str):
    """Load and memoize the (slow) Hugging Face tokenizer used for truncation.

    ``AutoTokenizer.from_pretrained`` is expensive, and ``run_budget_forced_query``
    is called once per sample; caching avoids reloading it every time. The cache
    is thread-safe, though on a cold cache concurrent callers may load it more
    than once. NOTE(review): the cached instance is shared across worker threads;
    this assumes the slow tokenizer's ``encode``/``decode`` are safe to call
    concurrently — confirm for the tokenizers in use.
    """
    return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)


def run_budget_forced_query(
    base_messages: List[Dict[str, str]],
    model: str,
    api_base: Optional[str],
    api_key: Optional[str],
    temperature: float,
    max_think_tokens: int,
    num_wait_insertion: int,
    tokenizer_name: str,
    max_context_tokens: int = 32768,
) -> Tuple[str, str, int, int, int]:
    """Run one question with budget forcing; return the final answer and token usage.

    The model first reasons inside a ``<|im_start|>think`` section. While think
    budget remains, ``WAIT_TOKEN`` is appended and the model is asked to continue,
    at most ``num_wait_insertion`` times. The accumulated prompt is then truncated
    (using the local tokenizer) to at most ``max_think_tokens`` tokens beyond the
    original prompt, ``THINKING_END`` is appended, and one more completion
    produces the answer.

    Args:
        base_messages: Chat messages; each must have "role" and "content" keys.
        model: litellm model name.
        api_base: API base URL forwarded to litellm.
        api_key: API key forwarded to litellm.
        temperature: Sampling temperature for every call.
        max_think_tokens: Token budget for the thinking phase.
        num_wait_insertion: Maximum number of "Wait" continuations.
        tokenizer_name: Hugging Face tokenizer used to measure/truncate the
            thinking prompt.
        max_context_tokens: Model context window; the answer phase gets what is
            left after the prompt, the full think budget and a 10-token margin.

    Returns:
        ``(answer_text, prompt_think, input_tokens, output_tokens, reasoning_tokens)``:
        the answer text, the full prompt sent to the answer call, the prompt
        tokens of the first thinking call, the completion tokens of the answer
        call, and the completion tokens spent thinking minus any tokens removed
        by truncation.

    Raises:
        ValueError: If no tokens remain for the answer phase.
        Errors from litellm / transformers propagate unchanged.
    """
    tokenizer = _load_tokenizer(tokenizer_name)

    def _complete(prompt: str, max_new: int, stop: List[str], num_retries: int):
        # Raw text completion; special tokens are kept so the chat markers survive.
        return litellm.text_completion(
            model=model,
            prompt=prompt,
            max_tokens=max_new,
            temperature=temperature,
            stop=stop,
            api_base=api_base,
            api_key=api_key,
            skip_special_tokens=False,
            input_cost_per_token=0,
            output_cost_per_token=0,
            num_retries=num_retries,
        )

    # Thinking must stop at the next chat marker so it never runs into a new turn.
    think_stop = ["<|im_start|>", "<|im_end|>"]

    # --- Prompt assembly ---
    header = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in base_messages
    )
    header += "<|im_start|>assistant\n"
    prompt_think = header + "<|im_start|>think"

    remaining = max_think_tokens
    logger.info("REMAINING Before Calling: %s", remaining)

    # (1) First thinking pass.
    r = _complete(prompt_think, remaining, think_stop, 2)
    input_prompt_tokens = r.usage.prompt_tokens
    remaining -= r.usage.completion_tokens
    reasoning_tokens = r.usage.completion_tokens
    prompt_think += r.choices[0].text
    logger.info("REMAINING No Wait: %s", remaining)
    logger.info("INPUT TOKEN No Wait: %s", input_prompt_tokens)

    # (2) Force further thinking by appending "Wait" while budget remains.
    for step in range(1, num_wait_insertion + 1):
        if remaining <= 0:
            break
        prompt_think += WAIT_TOKEN
        r = _complete(prompt_think, remaining, think_stop, 2)
        # The extra 2 presumably accounts for the inserted WAIT_TOKEN.
        remaining -= 2 + r.usage.completion_tokens
        logger.info("REMAINING Wait Step %s: %s", step, remaining)
        reasoning_tokens += r.usage.completion_tokens
        prompt_think += r.choices[0].text

    # --- Truncate prompt_think to the original prompt plus max_think_tokens ---
    # Uses the local tokenizer; assumes it matches the server's tokenization.
    think_limit = input_prompt_tokens + max_think_tokens
    prompt_think_ids = tokenizer.encode(prompt_think, add_special_tokens=False)
    overflow = len(prompt_think_ids) - think_limit
    if overflow > 0:
        logger.warning(
            "Truncating prompt_think from %d to %d tokens.",
            len(prompt_think_ids),
            think_limit,
        )
        # Measure the overflow before slicing so the dropped tokens are actually
        # removed from the reasoning count.
        reasoning_tokens = max(0, reasoning_tokens - overflow)
        prompt_think = tokenizer.decode(
            prompt_think_ids[:think_limit], skip_special_tokens=False
        )

    if not prompt_think.endswith("\n"):
        prompt_think += "\n"
    # Close the thinking phase manually.
    prompt_think += THINKING_END

    # (3) Answer phase on the (possibly truncated) prompt.
    answer_budget = max_context_tokens - input_prompt_tokens - max_think_tokens - 10
    if answer_budget < 1:
        raise ValueError(
            "No token budget left for the answer phase: "
            f"max_context_tokens={max_context_tokens}, "
            f"prompt_tokens={input_prompt_tokens}, max_think_tokens={max_think_tokens}"
        )
    final_r = _complete(prompt_think, answer_budget, ["<|im_end|>"], 1)

    answer_text = final_r.choices[0].text
    out_tokens = final_r.usage.completion_tokens
    return answer_text, prompt_think, input_prompt_tokens, out_tokens, reasoning_tokens


def _success_record(
    text: str, prompt_think: str, tok_in: int, tok_out: int, tok_reason: int
) -> dict[str, str | int | None]:
    """Build the per-sample result dict for a successful budget-forced query."""
    return {
        "text": text,
        "prompt_think": prompt_think,
        "input_tokens": tok_in,
        "output_tokens": tok_out,
        "reasoning_tokens": tok_reason,
        "error": None,
    }


def _error_record(exc: Exception) -> dict[str, str | int | None]:
    """Build the per-sample result dict for a failed budget-forced query."""
    error_msg = f"BUDGET_FORCING_ERROR: {exc}"
    return {
        "text": error_msg,
        "prompt_think": error_msg,
        "input_tokens": None,
        "output_tokens": None,
        "reasoning_tokens": None,
        "error": error_msg,
    }


def call_llm_budget_force(
    litellm_model_name: str,
    parallel_workers: int,
    api_base: Optional[str],
    api_key: Optional[str],
    temperature: float,
    max_think_tokens: int,
    num_wait_insertion: int,
    num_samples: int,
    chat_messages: list[dict[str, str]],
    tokenizer_name: str,
    followup_prompt: str,
    is_sequential_request: bool,
    only_int: bool,
) -> tuple[int, int, list[dict[str, str | int | None]], int]:
    """Draw ``num_samples`` budget-forced completions for ``chat_messages``.

    Sequential mode runs samples one at a time. Each successfully parsed result
    (a random string, or a random digit sequence when ``only_int``) is appended to
    the last message of a private copy of ``chat_messages``, so later samples see
    earlier outputs. The caller's messages are not modified. Otherwise samples run
    concurrently on a thread pool with ``parallel_workers`` workers.

    Args:
        litellm_model_name: Model name forwarded to ``run_budget_forced_query``.
        parallel_workers: Thread-pool size for the concurrent mode.
        api_base: API base URL forwarded to litellm.
        api_key: API key forwarded to litellm.
        temperature: Sampling temperature.
        max_think_tokens: Thinking budget per sample.
        num_wait_insertion: Maximum "Wait" continuations per sample.
        num_samples: Number of samples to draw.
        chat_messages: Base chat messages ("role"/"content" dicts).
        tokenizer_name: Tokenizer used for think-budget truncation.
        followup_prompt: Currently unused; kept for interface compatibility.
        is_sequential_request: Select the sequential, history-accumulating mode.
        only_int: In sequential mode, parse digit sequences instead of strings.

    Returns:
        ``(total_input_tokens, total_output_tokens, responses, api_errors)``.
        Each response dict has keys text, prompt_think, input_tokens,
        output_tokens, reasoning_tokens and error. Failed queries count in
        ``api_errors`` and get an error record. In concurrent mode, responses
        are in completion order.
    """
    bf_responses: list[dict[str, str | int | None]] = []
    api_errors = 0
    total_input_tokens = 0
    total_output_tokens = 0

    if is_sequential_request:
        # Private copy: the history is appended to the last message below.
        messages = [dict(m) for m in chat_messages]
        if only_int:
            parse = parse_random_digit_sequence
            # NOTE(review): this prompt asks for "a new random string" although it
            # concerns digit sequences; left unchanged to keep prompts reproducible.
            history_header = "\n\nYou generated random digit sequences in the previous turns. Please generate a new random string.\n\nPrevious Random Digit Sequences:"
        else:
            parse = parse_random_string
            history_header = "\n\nYou generated random strings in the previous turns. Please generate a new random string.\n\nPrevious Random Strings:"

        previous_outputs: list = []
        for i in range(num_samples):
            logger.info("CHAT MESSAGES: %s", messages)
            try:
                text, prompt_think, tok_in, tok_out, tok_reason = (
                    run_budget_forced_query(
                        base_messages=messages,
                        model=litellm_model_name,
                        api_base=api_base,
                        api_key=api_key,
                        temperature=temperature,
                        max_think_tokens=max_think_tokens,
                        num_wait_insertion=num_wait_insertion,
                        tokenizer_name=tokenizer_name,
                    )
                )
            except Exception as e:
                logger.exception("Error in budget forcing query for sample %d: %s", i, e)
                api_errors += 1
                bf_responses.append(_error_record(e))
                continue

            bf_responses.append(
                _success_record(text, prompt_think, tok_in, tok_out, tok_reason)
            )
            total_input_tokens += tok_in if tok_in is not None else 0
            total_output_tokens += tok_out if tok_out is not None else 0

            # A parse failure must not turn an already-recorded success into an error.
            try:
                parsed = parse(text)
            except Exception:
                logger.exception("Failed to parse the output of sample %d", i)
                continue
            if parsed is None:
                continue

            previous_outputs.append(parsed)
            # Add the header exactly once, right before the first recorded output.
            if len(previous_outputs) == 1:
                messages[-1]["content"] += history_header
            messages[-1]["content"] += f"\n{parsed}"

        return total_input_tokens, total_output_tokens, bf_responses, api_errors

    with ThreadPoolExecutor(max_workers=parallel_workers) as pool:
        # Map each future back to its sample index for accurate error logs.
        future_to_index = {
            pool.submit(
                run_budget_forced_query,
                chat_messages,
                litellm_model_name,
                api_base,
                api_key,
                temperature,
                max_think_tokens,
                num_wait_insertion,
                tokenizer_name,
            ): idx
            for idx in range(num_samples)
        }
        for fut in tqdm(
            as_completed(future_to_index),
            total=len(future_to_index),
            desc="Budget Forcing Queries",
        ):
            idx = future_to_index[fut]
            try:
                text, prompt_think, tok_in, tok_out, tok_reason = fut.result()
            except Exception as e:
                logger.exception(
                    "Error in budget forcing query for sample %d: %s", idx, e
                )
                api_errors += 1
                bf_responses.append(_error_record(e))
                continue
            bf_responses.append(
                _success_record(text, prompt_think, tok_in, tok_out, tok_reason)
            )
            total_input_tokens += tok_in if tok_in is not None else 0
            total_output_tokens += tok_out if tok_out is not None else 0

    return total_input_tokens, total_output_tokens, bf_responses, api_errors
