#import os
from google import genai
from google.genai.types import GenerateContentConfig
from types import SimpleNamespace


class GeminiClient:
    """
    A client for interacting with the Google Gemini API, either on Vertex AI
    (``project_id``) or with an API key (``api_key``), with support for log
    probabilities.

    ``ask`` returns a lightweight object exposing ``.message.content`` and
    ``.logprobs.content`` (a shape resembling an OpenAI chat-completion choice).

    To run the code on Vertex AI, you should first:
    1. brew install google-cloud-sdk
    2. gcloud auth application-default login
    3. gcloud config set project <your-project-id>
    reference page: https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/logprobs/intro_logprobs.ipynb
    """

    # Gemini request contents only accept "user" and "model" roles; translate the
    # OpenAI-style "assistant" role so callers may use either convention.
    _ROLE_MAP = {"assistant": "model"}

    def __init__(self, api_key=None, project_id=None, location="global", model="gemini-2.0-flash"):
        """Create the underlying ``genai.Client``.

        Args:
            api_key: Gemini API key; only used when ``project_id`` is not given.
            project_id: Google Cloud project id. When set, Vertex AI mode is used
                and takes precedence over ``api_key``.
            location: Vertex AI location (not passed to the client in API-key mode).
            model: Model name sent with every request.

        Raises:
            ValueError: if neither ``project_id`` nor ``api_key`` is provided.
        """
        self.model = model
        self.location = location
        # List of {"role": ..., "content": ...} dicts describing the conversation.
        self.chat_history = []

        if project_id:  # Vertex AI mode
            self.client = genai.Client(vertexai=True, project=project_id, location=location)
        elif api_key:  # API key mode
            self.client = genai.Client(api_key=api_key)
        else:
            raise ValueError("Either project_id (for Vertex AI) or api_key must be provided.")

    def add_message(self, role, content):
        """Append one ``{"role", "content"}`` message to the conversation history."""
        self.chat_history.append({"role": role, "content": content})

    @classmethod
    def _to_gemini_contents(cls, history):
        """Convert ``{"role", "content"}`` messages into Gemini request contents.

        Returns a ``(system_instruction, contents)`` pair. "system" messages are
        joined (newline-separated) into one instruction string, or None when there
        are none, because Gemini does not accept a "system" role inside contents.
        """
        system_parts = []
        contents = []
        for msg in history:
            role = msg["role"]
            if role == "system":
                system_parts.append(msg["content"])
                continue
            contents.append({
                "role": cls._ROLE_MAP.get(role, role),
                "parts": [{"text": msg["content"]}],
            })
        system_instruction = "\n".join(system_parts) if system_parts else None
        return system_instruction, contents

    @staticmethod
    def _extract_logprobs(candidate):
        """Return ``[SimpleNamespace(token=..., logprob=...)]`` for a candidate's chosen tokens.

        Returns an empty list when the candidate is None or carries no logprobs.
        """
        result = getattr(candidate, "logprobs_result", None)
        if not result:
            return []
        # chosen_candidates is optional in the SDK's response model.
        return [
            SimpleNamespace(token=chosen.token, logprob=chosen.log_probability)
            for chosen in (result.chosen_candidates or [])
        ]

    def ask(
        self,
        user_input: str,
        temperature: float = 0,
        remember_history: bool = False,
        max_tokens: int = None,
        logprobs: bool = False,
        top_logprobs: int = None,
        **extra_kwargs
    ):
        """Send ``user_input`` to the model and return the reply.

        Args:
            user_input: The user prompt.
            temperature: Sampling temperature.
            remember_history: If True, append the prompt to the running
                conversation and send the whole history; otherwise start a new
                conversation containing only this prompt.
            max_tokens: Maximum output tokens (None leaves the API default).
            logprobs: Request per-token log probabilities.
            top_logprobs: Number of top alternatives per token to request; only
                sent when ``logprobs`` is True.
            **extra_kwargs: Additional ``GenerateContentConfig`` fields.

        Returns:
            SimpleNamespace with ``message.content`` (the response text, which
            may be None) and ``logprobs.content`` (a list of objects with
            ``token`` and ``logprob``; empty unless logprobs were requested and
            returned).
        """
        if remember_history:
            self.add_message("user", user_input)
        else:
            self.chat_history = [{"role": "user", "content": user_input}]

        # Send the full conversation, not just the latest prompt, so that
        # remember_history actually gives the model context.
        system_instruction, contents = self._to_gemini_contents(self.chat_history)
        config_kwargs = dict(extra_kwargs)
        if system_instruction is not None:
            # An explicit system_instruction passed by the caller wins.
            config_kwargs.setdefault("system_instruction", system_instruction)

        generation_config = GenerateContentConfig(
            response_mime_type="text/plain",
            temperature=temperature,
            max_output_tokens=max_tokens,
            response_logprobs=logprobs,
            logprobs=top_logprobs if logprobs else None,
            **config_kwargs
        )

        response = self.client.models.generate_content(
            model=self.model,
            contents=contents,
            config=generation_config,
        )

        # candidates can be empty/None (e.g. a blocked prompt); avoid IndexError.
        candidates = response.candidates or []
        gemini_candidate = candidates[0] if candidates else None
        content = response.text

        # Record the reply so the next remember_history call continues the dialogue.
        if content:
            self.add_message("model", content)

        transformed_logprobs_content = (
            self._extract_logprobs(gemini_candidate) if logprobs else []
        )

        message = SimpleNamespace(content=content)
        logprobs_obj = SimpleNamespace(content=transformed_logprobs_content)
        return SimpleNamespace(message=message, logprobs=logprobs_obj)

    def count_tokens(self, text_input: str) -> int:
        """Return the model's token count for ``text_input``.

        Best-effort: on any API error the error is printed and 0 is returned.
        """
        try:
            response = self.client.models.count_tokens(
                model=self.model,
                contents=text_input
            )
            # total_tokens is optional in the SDK response; keep the int contract.
            return response.total_tokens or 0
        except Exception as e:
            print(f"Error counting tokens: {e}")
            return 0

    def reset(self):
        """Clear the conversation history."""
        self.chat_history = []

if __name__ == "__main__":
    import os

    # Read credentials from the environment rather than hard-coding them. The
    # previous hard-coded empty project id is falsy, so GeminiClient always
    # raised ValueError and the examples never ran.
    project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    try:
        # project_id (Vertex AI) takes precedence over api_key when both are set.
        client = GeminiClient(api_key=api_key, project_id=project_id, model="gemini-2.0-flash")

        # example 1
        response = client.ask("What is the capital of France?")
        print(response.message.content)

        # example 2 with logprobs
        response = client.ask("The sky is", logprobs=True, top_logprobs=1)
        print("Response Text:", response.message.content)

        if response.logprobs and response.logprobs.content:
            total_logprob = 0
            for token_logprob in response.logprobs.content:
                print(f"  - Token '{token_logprob.token}' has log-probability: {token_logprob.logprob:.4f}")
                total_logprob += token_logprob.logprob
            print(f"Total log-probability for the response: {total_logprob:.4f}")
        else:
            print("No logprobs returned.")

    except Exception as e:
        # Top-level boundary of a demo script: report and exit.
        print(f"An error occurred: {e}")
