import logging
import os
import time
from typing import Tuple, Dict

import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
from google.generativeai.types import GenerationConfig

from src.llm_clients.base_llm_client import BaseLLMClient

class GeminiClient(BaseLLMClient):
    """LLM client backed by a Gemini model through ``google.generativeai``."""

    # Retry notices go through logging rather than stdout, so library callers
    # control where (and whether) they appear.
    _logger = logging.getLogger(__name__)

    def __init__(self, model: str = "gemini-2.5-pro"):
        """Configure the Gemini SDK and create a handle for ``model``.

        Args:
            model: Name of the Gemini model to use.
        """
        # Taken from GEMINI_API_KEY; when the variable is unset, None is passed
        # straight through to genai.configure.
        api_key = os.getenv("GEMINI_API_KEY")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model)

    def count_tokens(self, text: str) -> int:
        """Count the tokens in ``text`` using Gemini's tokenizer.

        Delegates to ``GenerativeModel.count_tokens``, which in the
        google.generativeai SDK is a request to the Gemini API.
        """
        return self.model.count_tokens(text).total_tokens

    def call(
        self,
        prompt: str,
        *,
        max_retries: int = 10,
        initial_backoff: float = 1.0,
    ) -> Tuple[str, Dict[str, int]]:
        """Call the Gemini API and return the response along with token counts.

        Generation runs with ``temperature=0.0``. Any attempt that fails with
        ``ResourceExhausted`` is retried after an exponentially growing sleep.

        Args:
            prompt: Text prompt sent to the model.
            max_retries: Total number of attempts before giving up (>= 1).
            initial_backoff: Seconds to sleep after the first failed attempt;
                doubled after each further failure.

        Returns:
            ``(text, usage)`` where ``usage`` maps ``"input_tokens"``,
            ``"output_tokens"`` and ``"total_tokens"`` to counts obtained via
            :meth:`count_tokens`.

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            ResourceExhausted: If the final attempt is still rejected.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        backoff = initial_backoff
        # Counted lazily inside the try so a rate-limited count request is
        # retried too, but at most once per call: retries only repeat the parts
        # that have not succeeded yet.
        input_tokens = None
        for attempt in range(1, max_retries + 1):
            try:
                if input_tokens is None:
                    input_tokens = self.count_tokens(prompt)
                response = self.model.generate_content(
                    prompt,
                    generation_config=GenerationConfig(temperature=0.0),
                )
                returned_msg = response.text  # Generated text
                # Nothing to count in an empty reply; skip the remote call
                # (the API may also reject empty content).
                output_tokens = self.count_tokens(returned_msg) if returned_msg else 0
            except ResourceExhausted:
                if attempt == max_retries:
                    raise
                self._logger.warning(
                    "Resource exhausted on attempt %d/%d, retrying in %s seconds...",
                    attempt,
                    max_retries,
                    backoff,
                )
                time.sleep(backoff)
                backoff *= 2
            else:
                return returned_msg, {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                }
            

if __name__ == "__main__":
    # Manual smoke check: builds a client (which reads GEMINI_API_KEY) and
    # sends one live prompt, then shows the reply followed by the token usage.
    gemini = GeminiClient()
    answer, usage = gemini.call("What is the capital of France?")
    for item in (answer, usage):
        print(item)
