import concurrent
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Callable, Tuple

import torch
from google import genai
from google.genai import types
from openai import OpenAI
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizer

from src.api_bank_prompt_functions import (
    api_bank_get_prompt_for_strict_api,
    api_bank_get_prompt_for_template_summarize,
    api_bank_parse_openai_tool_calling,
)
from src.tool_ace_prompt_functions import (
    tool_ace_get_prompt_for_strict_api,
    tool_ace_get_prompt_for_template_summarize,
)
from src.utility import (
    Record,
    dataset_api_bank,
    batch_iterator,
    chat_format_general,
    chat_format_template,
    DataMode,
    ModelResponse,
    dataset_tool_ace,
    dataset_when2call,
    max_new_tokens,
)
from src.when2call_prompt_functions import (
    when2call_get_prompt_for_template_summarize,
    when2call_get_prompt_for_strict_api,
)

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


def get_prompt_for_strict_api(
    record: Record, prepare_for_chat_template: bool
) -> list[dict]:
    """
    Build the prompt messages that ask for a strict, schema-based API call.

    The actual prompt construction is delegated to the builder of the data set named
    by ``record.data_set`` (API-Bank, ToolACE or When2Call).
    :param record: Input Record.
    :param prepare_for_chat_template: Shape the prompts so they can be passed to tokenizer.apply_chat_template.
    :return: Prompt messages.
    :raises RuntimeError: If ``record.data_set`` is not one of the supported data sets.
    """
    data_set = record.data_set
    if data_set == dataset_api_bank:
        builder = api_bank_get_prompt_for_strict_api
    elif data_set == dataset_tool_ace:
        builder = tool_ace_get_prompt_for_strict_api
    elif data_set == dataset_when2call:
        builder = when2call_get_prompt_for_strict_api
    else:
        raise RuntimeError(f"Unknown data set:{record.data_set}")
    return builder(record, prepare_for_chat_template)


def get_prompt_for_template_summarize(
    record: Record, prepare_for_chat_template: bool
) -> list[dict]:
    """
    Build the prompt messages that ask for a template-based API summary.

    The actual prompt construction is delegated to the builder of the data set named
    by ``record.data_set`` (API-Bank, ToolACE or When2Call).
    :param record: Input Record.
    :param prepare_for_chat_template: Shape the prompts so they can be passed to tokenizer.apply_chat_template.
    :return: Prompt messages.
    :raises RuntimeError: If ``record.data_set`` is not one of the supported data sets.
    """
    data_set = record.data_set
    if data_set == dataset_api_bank:
        builder = api_bank_get_prompt_for_template_summarize
    elif data_set == dataset_tool_ace:
        builder = tool_ace_get_prompt_for_template_summarize
    elif data_set == dataset_when2call:
        builder = when2call_get_prompt_for_template_summarize
    else:
        raise RuntimeError(f"Unknown data set:{record.data_set}")
    return builder(record, prepare_for_chat_template)


def extract_text(response, messages):
    """
    Extract the answer text from an OpenAI Responses API result.

    Only items of ``response.output`` whose ``type`` is ``"message"`` are inspected,
    and within them only content parts whose ``type`` is ``"output_text"``.
    :param response: Result object returned by ``openai_client.responses.create``.
    :param messages: The request's input messages; used only in log messages.
    :return: The first ``output_text`` found, or an empty string when the response
        contains none.
    """
    texts = []
    for output in response.output:
        if output.type != "message":
            continue
        for item in output.content:
            if item.type == "output_text":
                texts.append(item.text)

    if not texts:
        # Indexing an empty list would raise IndexError, and the caller re-raises it via
        # future.result(), aborting the whole batch run. Record an empty answer instead.
        logger.warning(f"No output text for {messages}: {response}")
        return ""
    if len(texts) > 1:
        logger.info(f"Multiple output for {messages}: {response}")
    return texts[0]


def _reasoning_effort_from_env() -> str:
    """
    Read the reasoning effort for Responses API calls from the ``reason`` env var.

    An unset or empty variable means "minimal"; a single leading "." is stripped.
    :return: One of "minimal", "low", "medium", "high".
    :raises RuntimeError: If the value is not one of the accepted levels.
    """
    reason = os.getenv("reason") or "minimal"
    reason = reason.removeprefix(".")
    if reason not in ["minimal", "low", "medium", "high"]:
        raise RuntimeError(f"Unexpected reasoning level:{reason}")
    return reason


def query_openai(
    id: str,
    data_set: str,
    prompt: list[dict],
    model_name: str,
    openai_client: OpenAI,
    temperature: float,
    top_p: float,
) -> ModelResponse:
    """
    Query OpenAI API for one record and wrap the answer in a ModelResponse.

    Models whose name starts with "gpt-4" are called through Chat Completions with
    ``temperature`` and ``top_p``. Every other model is called through the Responses
    API with a reasoning effort read from the ``reason`` environment variable and low
    verbosity; ``temperature``/``top_p`` are NOT sent on that path, although they are
    still written to the log record.
    :param id: ID for the test record.
    :param data_set: The data set for the test record.
    :param prompt: Prompt messages to send to OpenAI.
    :param model_name: OpenAI model name.
    :param openai_client: OpenAI client to use.
    :param temperature: Temperature used for Chat Completions models.
    :param top_p: top_p used for Chat Completions models.
    :return: ModelResponse containing results from OpenAI.
    :raises RuntimeError: If the ``reason`` environment variable holds an unsupported level.
    """
    # Shallow copy so the request and the log record use their own list.
    messages = list(prompt)

    if model_name.startswith("gpt-4"):
        response = openai_client.chat.completions.create(
            model=model_name, messages=messages, temperature=temperature, top_p=top_p
        )
        output_text = response.choices[0].message.content
    else:
        response = openai_client.responses.create(
            model=model_name,
            input=messages,
            reasoning={"effort": _reasoning_effort_from_env()},
            text={"verbosity": "low"},
        )
        output_text = extract_text(response, messages)

    result = ModelResponse(id=id, data_set=data_set, response=output_text)
    log_obj = {
        "id": id,
        "data_set": data_set,
        "model_name": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "response": output_text,
    }
    logger.info(json.dumps(log_obj))
    return result


def get_openai_result(
    *,
    output_file: str,
    record_keys: list,
    records_with_key: dict[str, Record],
    model_name: str,
    prompt_func: Callable[[Record, bool], list[dict]],
    openai_client: OpenAI,
    batch_size: int,
    temperature: float,
    top_p: float,
):
    """
    Query OpenAI API for every record and append the answers to a file.

    Records are processed batch by batch. The requests of one batch run concurrently
    in a thread pool, and each ModelResponse is written as one JSON line as soon as its
    request finishes, so line order follows completion order rather than ``record_keys``.
    :param output_file: Output file, opened in append mode with UTF-8 encoding.
    :param record_keys: The keys of the records to send to OpenAI.
    :param records_with_key: Record ID and record object mapping.
    :param model_name: OpenAI model name.
    :param prompt_func: Function used to create the prompt.
    :param openai_client: OpenAI client to use.
    :param batch_size: Number of records per concurrent batch.
    :param temperature: Temperature used when calling OpenAI.
    :param top_p: top_p used when calling OpenAI.
    """
    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding instead
    # of depending on the platform's locale default.
    with open(output_file, "a", encoding="utf-8") as fout:
        for keys in tqdm(
            batch_iterator(record_keys, batch_size=batch_size),
            # Ceiling division so a trailing partial batch is counted (assumes
            # batch_iterator yields it rather than dropping it).
            total=(len(record_keys) + batch_size - 1) // batch_size,
        ):
            params = []
            for key in keys:
                record = records_with_key[key]
                # prepare_for_chat_template=True: the API takes chat-style messages.
                prompts = prompt_func(record, True)
                params.append(
                    {
                        "id": record.id,
                        "data_set": record.data_set,
                        "prompt": prompts,
                        "model_name": model_name,
                        "openai_client": openai_client,
                        "temperature": temperature,
                        "top_p": top_p,
                    }
                )

            with ThreadPoolExecutor() as executor:
                futures = [executor.submit(query_openai, **param) for param in params]
                for future in concurrent.futures.as_completed(futures):
                    result = future.result()
                    fout.write(f"{json.dumps(asdict(result), ensure_ascii=False)}\n")


def query_gemini(
    id: str,
    data_set: str,
    prompt: list[dict],
    model_name: str,
    gemini_client: genai.Client,
    temperature: float,
    top_p: float,
) -> ModelResponse:
    """
    Send one record's prompt to Gemini and wrap the reply in a ModelResponse.

    The prompt messages are flattened into a single newline-joined string; messages
    whose ``content`` is None contribute nothing. ``temperature`` and ``top_p`` are
    passed through the generation config, and a JSON summary of the call is logged.
    :param id: ID for the test record.
    :param data_set: The data set for the test record.
    :param prompt: Prompt messages; each must have a ``content`` key.
    :param model_name: Gemini model name.
    :param gemini_client: Gemini client to use.
    :param temperature: Temperature used when calling Gemini.
    :param top_p: top_p used when calling Gemini.
    :return: ModelResponse whose response is the ``text`` of Gemini's reply.
    """
    contents = "\n".join(
        [message["content"] for message in prompt if message["content"] is not None]
    )

    generation_config = types.GenerateContentConfig(
        temperature=temperature,
        top_p=top_p,
    )
    reply = gemini_client.models.generate_content(
        model=model_name, contents=contents, config=generation_config
    )
    output_text = reply.text

    result = ModelResponse(id=id, data_set=data_set, response=output_text)
    logger.info(
        json.dumps(
            {
                "id": id,
                "data_set": data_set,
                "model_name": model_name,
                "prompt": contents,
                "temperature": temperature,
                "top_p": top_p,
                "response": output_text,
            }
        )
    )
    return result


def get_gemini_result(
    *,
    output_file: str,
    record_keys: list,
    records_with_key: dict[str, Record],
    model_name: str,
    prompt_func: Callable[[Record, bool], list[dict]],
    gemini_client: genai.Client,
    batch_size: int,
    temperature: float,
    top_p: float,
):
    """
    Query Gemini API for every record and append the answers to a file.

    Records are processed batch by batch. Records whose prompt list is empty are
    logged and skipped (no output line). The remaining requests of a batch run
    concurrently in a thread pool, and each ModelResponse is written as one JSON line
    as soon as its request finishes, so line order follows completion order.
    :param output_file: Output file, opened in append mode with UTF-8 encoding.
    :param record_keys: The keys of the records to send to Gemini.
    :param records_with_key: Record ID and record object mapping.
    :param model_name: Gemini model name.
    :param prompt_func: Function used to create the prompt.
    :param gemini_client: Gemini client to use.
    :param batch_size: Number of records per concurrent batch.
    :param temperature: Temperature used when calling Gemini.
    :param top_p: top_p used when calling Gemini.
    """
    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding instead
    # of depending on the platform's locale default.
    with open(output_file, "a", encoding="utf-8") as fout:
        for keys in tqdm(
            batch_iterator(record_keys, batch_size=batch_size),
            # Ceiling division so a trailing partial batch is counted (assumes
            # batch_iterator yields it rather than dropping it).
            total=(len(record_keys) + batch_size - 1) // batch_size,
        ):
            params = []
            for key in keys:
                record = records_with_key[key]
                # prepare_for_chat_template=False: query_gemini flattens the messages.
                prompts = prompt_func(record, False)
                if len(prompts) == 0:
                    logger.info(f"Empty prompt for record:{record}")
                    continue
                params.append(
                    {
                        "id": record.id,
                        "data_set": record.data_set,
                        "prompt": prompts,
                        "model_name": model_name,
                        "gemini_client": gemini_client,
                        "temperature": temperature,
                        "top_p": top_p,
                    }
                )

            with ThreadPoolExecutor() as executor:
                futures = [executor.submit(query_gemini, **param) for param in params]
                for future in concurrent.futures.as_completed(futures):
                    result = future.result()
                    fout.write(f"{json.dumps(asdict(result), ensure_ascii=False)}\n")


def prompts_to_string(
    prompts: list[dict], chat_format: str, tokenizer: PreTrainedTokenizer
) -> str:
    """
    Convert a list of prompt messages into a single model-input string.

    With ``chat_format_template`` the tokenizer's chat template renders the messages
    (untokenized). With ``chat_format_general`` the message contents are joined with
    newlines and a trailing newline is appended; messages whose ``content`` is None
    are skipped.
    :param prompts: List of json based prompts (dicts with a ``content`` key).
    :param chat_format: General chat messages or format that need apply_chat_template.
    :param tokenizer: Tokenizer to use (only needed for ``chat_format_template``).
    :return: Prompt string.
    :raises RuntimeError: If ``chat_format`` is not a supported format.
    """
    if chat_format == chat_format_template:
        return str(tokenizer.apply_chat_template(prompts, tokenize=False))
    if chat_format == chat_format_general:
        # Messages built with prepare_for_chat_template=False may carry content None
        # (the Gemini path filters these too); str.join would raise TypeError on them.
        lines = [x["content"] for x in prompts if x["content"] is not None]
        return "\n".join(lines) + "\n"
    raise RuntimeError(f"Unknown chat format:{chat_format}")


# Number of batches whose generation parameters, prompts and outputs
# get_hf_llm_result dumps to the log.
debug_max_sample: int = 5
# NOTE(review): not referenced in the code visible here; presumably a cap on new
# tokens relative to prompt length — confirm at its call site.
max_new_token_ratio: float = 0.2


def get_hf_llm_result(
    *,
    output_file: str,
    record_keys: list,
    records_with_key: dict[str, Record],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    prompt_func: Callable[[Record, bool], list[dict]],
    batch_size: int,
    temperature: float,
    top_p: float,
    chat_format: str,
    do_sample: bool,
):
    """
    Query a HuggingFace compatible LLM for every record and append the answers to a file.

    Each batch of prompts is tokenized with padding, passed to ``model.generate`` and
    only the newly generated tokens are decoded. One ModelResponse per record is
    written as a JSON line, in ``record_keys`` order. The generation parameters,
    prompts and generations of the first ``debug_max_sample`` batches are also logged.
    :param output_file: Output file, opened in append mode with UTF-8 encoding.
    :param record_keys: The keys of the records to send to the model.
    :param records_with_key: Record ID and record object mapping.
    :param model: The model to use.
    :param tokenizer: The tokenizer to use.
    :param prompt_func: Function used to create the prompt.
    :param batch_size: Number of records per generate call.
    :param temperature: Temperature passed to generate.
    :param top_p: top_p passed to generate.
    :param chat_format: General chat messages or format that need apply_chat_template.
    :param do_sample: Do sample or not when doing LLM generation.
    """
    if len(record_keys) == 0:
        return

    gen_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    use_chat_template = chat_format == chat_format_template
    logged_batches = 0
    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding instead
    # of depending on the platform's locale default.
    with open(output_file, "a", encoding="utf-8") as fout:
        for keys in tqdm(
            batch_iterator(record_keys, batch_size=batch_size),
            # Ceiling division so a trailing partial batch is counted (assumes
            # batch_iterator yields it rather than dropping it).
            total=(len(record_keys) + batch_size - 1) // batch_size,
        ):
            records = [records_with_key[key] for key in keys]
            model_inputs = [
                prompts_to_string(
                    prompt_func(record, use_chat_template), chat_format, tokenizer
                )
                for record in records
            ]

            inputs = tokenizer(
                model_inputs, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)
            # For decoder-only models (which the slicing below assumes) generate()
            # returns each padded prompt row followed by the new tokens, so generation
            # starts at the padded width for every row. Counting non-pad tokens instead
            # cuts into the prompt under left padding, or when the prompt itself
            # contains the pad id.
            prompt_width = inputs.input_ids.shape[1]

            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)

            decoded = [
                tokenizer.decode(row[prompt_width:], skip_special_tokens=True).strip()
                for row in outputs
            ]

            if logged_batches < debug_max_sample:
                logger.info("============parameters==========")
                logger.info(gen_kwargs)
                for model_input, text in zip(model_inputs, decoded):
                    logger.info("============prompt==========")
                    logger.info(model_input)
                    logger.info("============generated==========")
                    logger.info(text)
                logged_batches += 1

            # Use each record's own data_set: a batch may mix data sets.
            for record, model_input, output_text in zip(records, model_inputs, decoded):
                result = ModelResponse(
                    id=record.id, data_set=record.data_set, response=output_text
                )
                log_obj = {
                    "id": record.id,
                    "data_set": record.data_set,
                    "model_name": model.name_or_path,
                    "prompt": model_input,
                    "temperature": temperature,
                    "top_p": top_p,
                    "response": output_text,
                }
                logger.info(json.dumps(log_obj))
                fout.write(f"{json.dumps(asdict(result), ensure_ascii=False)}\n")


def prepare_source_and_target(
    mode: DataMode,
    truth: dict[str, Record],
    tokenizer: PreTrainedTokenizer,
    chat_format: str,
) -> Tuple[list, list]:
    """
    Create aligned source/target strings for LLM finetuning.

    ``None`` entries in ``truth`` are skipped. Each target is terminated with the
    tokenizer's EOS token:
    - ``DataMode.UtteranceToAPICall``: source from the strict-API prompt, target is
      ``record.output``.
    - ``DataMode.UtteranceToSummary``: source from the template-summarize prompt,
      target is ``record.template_output`` joined with newlines.
    :param mode: Mode of data generation.
    :param truth: Dict of truth id and truth.
    :param tokenizer: Tokenizer to use.
    :param chat_format: General chat messages or format that need apply_chat_template.
    :return: Tuple of (sources, targets), index-aligned.
    :raises NotImplementedError: If ``mode`` is not one of the two supported modes.
    """
    sources = []
    targets = []
    use_chat_template = chat_format == chat_format_template
    for value in truth.values():
        if value is None:
            continue
        if mode == DataMode.UtteranceToAPICall:
            prompts = get_prompt_for_strict_api(value, use_chat_template)
            target = value.output
        elif mode == DataMode.UtteranceToSummary:
            prompts = get_prompt_for_template_summarize(value, use_chat_template)
            target = "\n".join(value.template_output)
        else:
            raise NotImplementedError(f"Unsupported data mode:{mode}")
        targets.append(f"{target}{tokenizer.eos_token}")
        sources.append(prompts_to_string(prompts, chat_format, tokenizer))

    return sources, targets
