# __all__ = ["MonitoredModel"]

from smolagents.models import ApiModel, ChatMessage
from smolagents.monitoring import TokenUsage
from typing import Generator

# 这是一个我们创建的辅助类，用于包装 LiteLLMModel
class MonitoredModel:
    def __init__(self, model: ApiModel):
        self._model = model
        self.total_token_usage = TokenUsage(input_tokens=0, output_tokens=0)

    def generate(self, *args, **kwargs) -> ChatMessage:
        """
        调用原始模型的 generate 方法，并将token用量累加到总数中。
        """
        # 调用真实模型以获取响应
        response_message = self._model.generate(*args, **kwargs)
        
        # 如果响应包含token用量信息，就将其累加到我们的总数中
        if response_message.token_usage:
            self.total_token_usage.input_tokens += response_message.token_usage.input_tokens
            self.total_token_usage.output_tokens += response_message.token_usage.output_tokens
            
        return response_message

    def generate_stream(self, *args, **kwargs) -> Generator:
        """
        包装原始模型的流式生成方法，以在末尾捕获token用量。
        """
        # 来自原始模型的流
        stream = self._model.generate_stream(*args, **kwargs)
        
        for delta in stream:
            # 如果数据块包含最终的token用量，就将其累加到我们的总数中
            if delta.token_usage:
                self.total_token_usage.input_tokens += delta.token_usage.input_tokens
                self.total_token_usage.output_tokens += delta.token_usage.output_tokens
            yield delta

    # 将任何其他属性的访问都委托给原始模型
    def __getattr__(self, name):
        return getattr(self._model, name)