import gc
import hashlib
import logging
import sys
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple

import torch
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    AutoModelForCausalLM,
    AutoTokenizer,
)

# Module-level logger used by the helper functions in this file.
logger = logging.getLogger(__name__)

# Dataset name identifiers.
dataset_api_bank = "api_bank"
dataset_tool_ace = "tool_ace"
dataset_when2call = "when2call"


# Default settings. Only default_batch_size is consumed in this file (as the
# default of batch_iterator); the others are presumably read by model/inference
# code elsewhere — confirm at the call sites.
default_openai_model_name = "gpt-4o-2024-11-20"
default_gemini_model_name = "gemini-2.5-flash"
default_batch_size = 8
default_temperature = 0.2
default_top_p = 0.1
# NOTE(review): looks like the name of a prompt-building function resolved by name elsewhere — confirm.
default_prompt_func = "get_prompt_for_strict_api"
default_model_max_length = 2048

# Chat format identifiers.
chat_format_template = "template"
chat_format_general = "general"

# NOTE(review): presumably the generation budget passed to the model — confirm at the call sites.
max_new_tokens = 1000

# Role names for chat messages.
system_role = "system"
user_role = "user"
ai_role = "assistant"
tool_role = "tool"


class DataMode(Enum):
    """Data modes available to callers of this module."""

    # NOTE(review): judging by the member names, these select utterance -> API call
    # versus utterance -> summary data; confirm against the code that reads them.
    UtteranceToAPICall = 1
    UtteranceToSummary = 2


@dataclass
class APICall:
    """A single API call: its name, its parameters and, optionally, its responses."""

    # Name of the API.
    api_name: str
    # Parameters of the call; presumably parameter name -> value (confirm with producers).
    params: dict
    # Responses attached to the call; None when not provided.
    responses: Optional[list[str]] = None


@dataclass
class ModelResponse:
    """A model response text tagged with an ID and a data set name."""

    # Identifier; presumably matches the ID of the record it answers (confirm with callers).
    id: str
    # Data set name; presumably one of the dataset_* constants (confirm with callers).
    data_set: str
    # Response text.
    response: str


# API call status labels.
# NOTE(review): presumably the values stored in APICallStatus.api_call_status — confirm with callers.
API_CAN_NOT_ANSWER = "cannot_answer"
API_REQUEST_FOR_INFO = "request_for_info"
API_TOOL_CALL = "tool_call"
API_INVALID_RESPONSE = "invalid_response"


@dataclass
class APICallStatus:
    """A status label together with the API calls attached to it."""

    # Status label; presumably one of the API_* constants of this module (confirm with callers).
    api_call_status: str
    # API calls attached to this status; None when not provided.
    api_calls: Optional[list[APICall]] = None

    @classmethod
    def from_dict(cls, attributes: dict) -> "APICallStatus":
        """
        Build an instance from a plain dict.
        When "api_calls" is a list, dict entries are converted to APICall and
        entries that already are APICall instances are kept as they are;
        entries of any other type are dropped.
        :param attributes: Field names mapped to values. It is not modified.
        :return: New instance of the class the method is called on.
        """
        attrs = attributes.copy()
        raw_calls = attributes.get("api_calls")
        if isinstance(raw_calls, list):
            api_calls = []
            for api_call in raw_calls:
                if isinstance(api_call, APICall):
                    # Already converted: keep it instead of silently losing it.
                    api_calls.append(api_call)
                elif isinstance(api_call, dict):
                    api_calls.append(APICall(**api_call))
            attrs["api_calls"] = api_calls
        # Use cls so that subclasses get an instance of their own type.
        return cls(**attrs)


@dataclass
class Record:
    """
    One data record: the prompt sections (each a list of strings), the
    output fields and the API call status attached to the record.
    """

    id: str
    # Data set name; presumably one of the dataset_* constants (confirm with callers).
    data_set: str
    # Prompt sections. get_raw_prompt joins them in the order
    # pre_api, api_def, post_api, conversation, ending.
    pre_api: list[str]
    api_def: list[str]
    conversation: list[str]
    ending: list[str]
    output: Optional[str]
    post_api: list[str]
    # NOTE(review): presumably the expected (ground-truth) status — confirm with callers.
    api_call: APICallStatus
    template_output: Optional[list[str]] = None
    summarize_output: Optional[str] = None

    @classmethod
    def from_dict(cls, attributes: dict) -> "Record":
        """
        Build an instance from a plain dict.
        A dict found under "api_call" is converted with APICallStatus.from_dict;
        any other value is passed through unchanged.
        :param attributes: Field names mapped to values. It is not modified.
        :return: New instance of the class the method is called on.
        """
        attrs = attributes.copy()
        api_call = attributes.get("api_call")
        if isinstance(api_call, dict):
            attrs["api_call"] = APICallStatus.from_dict(api_call)
        # Use cls so that subclasses get an instance of their own type.
        return cls(**attrs)

    def get_raw_prompt(self) -> str:
        """
        Assemble the raw prompt text.
        :return: All lines of pre_api, api_def, post_api, conversation and
            ending, in that order, joined with newlines.
        """
        return "\n".join(
            self.pre_api
            + self.api_def
            + self.post_api
            + self.conversation
            + self.ending
        )


@dataclass
class EvalCountResult:
    """Four integer counters: predicted and correct counts for APIs and for parameters."""

    # NOTE(review): presumably filled in when comparing predicted API calls with the
    # expected ones — confirm against the evaluation code.
    predicted_api_num: int
    predicted_param_num: int
    correct_api_num: int
    correct_param_num: int


class Status(str, Enum):
    """
    Status labels for an evaluation result (see EvalResult.status).
    Members also subclass str, so each one compares equal to its string value.
    """

    # NOTE(review): the exact condition behind each label is decided by the
    # evaluation code, which is not visible here — confirm there.
    Correct = "Correct"
    NoPrediction = "NoPrediction"
    ReasoningFailure = "ReasoningFailure"
    SchemaNotMatch = "SchemaNotMatch"
    IncorrectAPI = "IncorrectAPI"
    IncorrectParam = "IncorrectParam"
    IncorrectParamValue = "IncorrectParamValue"
    IncorrectAPIStatus = "IncorrectAPIStatus"


@dataclass
class EvalResult:
    """
    Evaluation result for one record: the record, the predicted API calls,
    a Status and the API/parameter counters.
    """

    id: str
    # Record the prediction is evaluated against.
    truth: Record
    # Predicted API calls; may be None.
    prediction: Optional[list[APICall]]
    status: Status
    # Counters for APIs and parameters: predicted, correct and truth totals.
    # NOTE(review): exact counting rules live in the evaluation code — confirm there.
    predicted_api_num: int
    correct_api_num: int
    truth_api_num: int
    predicted_param_num: int
    correct_param_num: int
    truth_param_num: int


class Section(Enum):
    """Section markers; Unknown (0) plus five named sections."""

    # NOTE(review): the names mirror the Record fields pre_api, api_def, post_api,
    # conversation and ending; presumably used while splitting a raw prompt into
    # those fields — confirm against the parsing code.
    Unknown = 0
    PreAPI = 1
    APIDef = 2
    PostAPI = 3
    Conversation = 4
    Ending = 5


def setup_logging(write_stdout: bool = False):
    """
    Configure the root logger at INFO level and quiet noisy libraries.
    :param write_stdout: When True, log records are streamed to STDOUT;
        otherwise basicConfig is given an empty handler list.
    """
    config = {"level": logging.INFO}
    if write_stdout:
        config["stream"] = sys.stdout
    else:
        config["handlers"] = []
    logging.basicConfig(**config)

    # These libraries only get through at WARNING and above.
    for noisy_logger in ("httpx", "google_genai"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)


def batch_iterator(iterable, batch_size=default_batch_size):
    """
    Yield the items of an iterable as lists of batch_size items.
    The last list may be shorter; nothing is yielded for an empty iterable.
    """
    pending = []
    for item in iterable:
        pending.append(item)
        if len(pending) != batch_size:
            continue
        yield pending
        pending = []
    # Flush the trailing partial batch, if any.
    if len(pending) > 0:
        yield pending


def get_device():
    """
    Pick a Torch device, preferring MPS, then CUDA, then CPU.
    :return: The selected torch.device.
    """
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def prepare_model(
    model_name: str,
    model_max_length: int,
    device: Optional[str],
    torch_dtype: Optional[torch.dtype] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a causal LM and its tokenizer, then move the model to a device.
    :param model_name: Name or path of the model.
    :param model_max_length: Max length handed to the tokenizer.
    :param device: Device to use; a falsy value selects one via get_device().
    :param torch_dtype: Dtype handed to the model loader.
    :return: Tuple of the model (in eval mode) and the tokenizer.
    """
    logger.info(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)

    tokenizer_options = {
        "model_max_length": model_max_length,
        "padding_side": "left",
        "use_fast": False,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_options)

    # Fall back to the EOS token for any special token the tokenizer lacks.
    if not tokenizer.pad_token:
        logger.info(f"Using eos_token {tokenizer.eos_token} as pad_token")
        tokenizer.pad_token = tokenizer.eos_token
    if not tokenizer.bos_token:
        logger.info(f"Using eos_token {tokenizer.eos_token} as bos_token")
        tokenizer.bos_token = tokenizer.eos_token

    device = device or get_device()
    logger.info(f"Using device: {device}")
    model.to(device)
    model.eval()
    return model, tokenizer


def house_clean():
    """
    Run the Python GC and release cached CUDA memory.

    The CUDA calls are skipped when CUDA is unavailable:
    torch.cuda.ipc_collect() initializes the CUDA runtime and may raise on
    hosts without it, such as the CPU or MPS devices get_device() can select.
    """
    logger.info("House cleaning...")
    # Collect first so that tensors kept alive only by reference cycles are
    # freed before the allocator is asked to release its cached blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def generate_id(text: str, length: int = 32) -> str:
    """
    Derive a deterministic ID from the text.
    :param text: Input text; hashed as UTF-8 with SHA-256.
    :param length: Number of leading hex characters to keep.
    :return: Prefix of the hex digest, as a string.
    """
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    return digest[:length]
