import warnings
from typing import Any

from langchain_community.callbacks import OpenAICallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

# Token prices per model, keyed by model name.
# Each value is an ``(input_cost, output_cost)`` pair, expressed per 1,000
# tokens. The literals are written as "price per 1M tokens / 1000" so they can
# be compared directly with provider list prices (presumably USD — confirm).
PRICING_PER_1K_TOKENS = {
    "gemini-2.0-flash": (
        0.1 / 1000,  # input (prompt) tokens
        0.4 / 1000,  # output (completion) tokens
    ),
    "claude-3-5-haiku-20241022": (
        0.8 / 1000,  # input (prompt) tokens
        4.0 / 1000,  # output (completion) tokens
    ),
}

class OutOfBudgetError(Exception):
    """Raised when LLM usage goes over its allotted budget."""

class LLMCallbackHandler(OpenAICallbackHandler):
    """Track token usage and cost of LLM calls for a single, fixed model.

    Reuses the counters and lock set up by ``OpenAICallbackHandler.__init__``
    (``total_tokens``, ``prompt_tokens``, ``completion_tokens``,
    ``successful_requests``, ``total_cost``, ``_lock``), but prices requests
    with ``PRICING_PER_1K_TOKENS`` instead of OpenAI's own price table.
    """

    def __init__(self, model_name: str) -> None:
        """Create a handler that prices every request as ``model_name``.

        Args:
            model_name: Key into ``PRICING_PER_1K_TOKENS``. A model missing
                from that table still has its tokens counted, but its
                requests add zero cost (a warning is emitted here so that
                this is not silent).
        """
        super().__init__()
        self.model_name = model_name
        if model_name not in PRICING_PER_1K_TOKENS:
            warnings.warn(
                f"No pricing known for model {model_name!r}; its requests "
                "will be counted at zero cost.",
                stacklevel=2,
            )

    @staticmethod
    def _usage_metadata(response: LLMResult) -> Any:
        """Return ``usage_metadata`` of the first generation's AIMessage, or None."""
        try:
            generation = response.generations[0][0]
        except IndexError:
            # No generations at all (outer or inner list is empty).
            return None
        if not isinstance(generation, ChatGeneration):
            return None
        message = getattr(generation, "message", None)
        if not isinstance(message, AIMessage):
            return None
        # ``usage_metadata`` exists on AIMessage in langchain-core >= 0.2.2.
        return getattr(message, "usage_metadata", None)

    def _request_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
        """Return the cost of one request; 0.0 for models with no known price."""
        input_price, output_price = PRICING_PER_1K_TOKENS.get(
            self.model_name, (0.0, 0.0)
        )
        prompt_cost = prompt_tokens / 1000 * input_price
        completion_cost = completion_tokens / 1000 * output_price
        return prompt_cost + completion_cost

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate token counts and cost for one completed LLM call.

        Token counts come from the first generation's
        ``AIMessage.usage_metadata`` when present; otherwise from
        ``response.llm_output["token_usage"]``. A response with no
        ``llm_output`` is ignored; one whose ``llm_output`` has no
        ``"token_usage"`` only increments ``successful_requests``.

        Args:
            response: The result LangChain passes when a call finishes.
            **kwargs: Unused; accepted for callback-interface compatibility.
        """
        usage_metadata = self._usage_metadata(response)
        if usage_metadata:
            total_tokens = usage_metadata["total_tokens"]
            completion_tokens = usage_metadata["output_tokens"]
            prompt_tokens = usage_metadata["input_tokens"]
        else:
            if response.llm_output is None:
                return None

            if "token_usage" not in response.llm_output:
                # The call succeeded but reported no usage: count it, bill nothing.
                with self._lock:
                    self.successful_requests += 1
                return None

            # ``or`` guards against providers that report explicit None values,
            # which would otherwise break the ``+=`` updates below.
            token_usage = response.llm_output["token_usage"] or {}
            total_tokens = token_usage.get("total_tokens") or 0
            completion_tokens = token_usage.get("completion_tokens") or 0
            prompt_tokens = token_usage.get("prompt_tokens") or 0

        cost = self._request_cost(prompt_tokens, completion_tokens)

        # Update shared counters under the inherited lock.
        with self._lock:
            self.total_cost += cost
            self.total_tokens += total_tokens
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.successful_requests += 1