import os
import PIL.Image
import google.generativeai as genai


class GeminiWrapper:
    """Thin wrapper around a ``google.generativeai`` model for image/text prompts.

    The wrapper assembles a flat list of PIL images and strings from the
    caller's passages, images and prompts, sends it to the model, and returns
    the stripped response text, or ``fail_msg`` when anything goes wrong.
    """

    def __init__(
        self,
        model: str = "gemini-2.0-flash-exp",
        key: str = None,
        retry: int = 5,
        wait: int = 5,
        verbose: bool = False,
        system_prompt: str = None,
        temperature: float = 0.9,
        timeout: int = 60,
        max_tokens: int = 1024,
        **kwargs,
    ):
        """Initialize the GeminiWrapper with model and API settings.

        Args:
            model: Name of the Gemini model to instantiate.
            key: API key. When empty or ``None`` the ``GEMINI_API_KEY``
                environment variable is used instead.
            retry: Accepted but currently not used by this class.
            wait: Accepted but currently not used by this class.
            verbose: When true, error details are printed to stdout.
            system_prompt: Stored on the instance for compatibility; it is
                not sent to the model yet.
            temperature: Sampling temperature passed in the generation config.
            timeout: Stored on the instance; it is not currently passed to
                the API call.
            max_tokens: Passed as ``max_output_tokens`` in the generation
                config.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If no API key is given and ``GEMINI_API_KEY`` is
                unset or empty.
        """
        self.model_name = model
        # Use the provided key or fall back to the environment variable.
        self.key = key or os.environ.get("GEMINI_API_KEY", "")
        if not self.key:
            raise ValueError("API key is required for Gemini API.")

        # Configure the Gemini API with the key.
        genai.configure(api_key=self.key)
        self.model = genai.GenerativeModel(self.model_name)

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.verbose = verbose
        self.fail_msg = "Failed to obtain answer via API."
        # Note: system_prompt is included for compatibility but not used yet.
        self.system_prompt = system_prompt

    def _append_image(self, contents, path) -> bool:
        """Open ``path`` with PIL and append the image to ``contents``.

        Returns ``True`` on success and ``False`` when the file does not
        exist (the path is printed in verbose mode).
        """
        try:
            contents.append(PIL.Image.open(path))
        except FileNotFoundError:
            if self.verbose:
                print(f"Image file not found: {path}")
            return False
        return True

    def get_prediction(
        self, image_path=None, prompt=None, passages=None, passage_prompt=None
    ) -> str:
        """Generate a prediction based on image and text inputs.

        The request contents are assembled in this order: passages,
        ``passage_prompt``, query images, ``prompt``.

        Args:
            image_path: A single image path, a list/tuple of paths, or
                ``None`` for a text-only request. ``None`` entries inside a
                list are skipped.
            prompt: Final text prompt, appended after the images.
            passages: Optional iterable of context items. Each item may be a
                dict (optional ``"image_path"`` and ``"caption"`` keys), a
                ``(image_path, text)`` tuple, or a plain string.
            passage_prompt: Optional text appended right after the passages.

        Returns:
            The stripped model answer, or ``self.fail_msg`` if an image file
            is missing, no content was supplied, or the API call failed.
        """
        contents = []

        # Normalise the query images to a list, dropping absent entries so
        # that text-only requests do not try to open ``None``.
        if isinstance(image_path, (list, tuple)):
            query_images = [p for p in image_path if p is not None]
        elif image_path is None:
            query_images = []
        else:
            query_images = [image_path]

        # Process passages (image-text pairs or text-only) in the given order.
        for passage in passages or []:
            if isinstance(passage, dict):
                if "image_path" in passage:
                    if not self._append_image(contents, passage["image_path"]):
                        return self.fail_msg
                if "caption" in passage:
                    contents.append(f"Caption: {passage['caption']}")
            elif isinstance(passage, tuple) and len(passage) == 2:
                passage_image_path, passage_text = passage
                if not self._append_image(contents, passage_image_path):
                    return self.fail_msg
                contents.append(f"Passage: {passage_text}")
            elif isinstance(passage, str):
                contents.append(f"Passage: {passage}")
            elif self.verbose:
                # Unsupported shapes are skipped; say so when debugging.
                print(f"Skipping unsupported passage: {passage!r}")

        # Add passage prompt if provided.
        if passage_prompt:
            contents.append(passage_prompt)

        for img_path in query_images:
            if not self._append_image(contents, img_path):
                return self.fail_msg

        if prompt:
            contents.append(prompt)

        # Nothing to send: fail fast instead of issuing an empty request.
        if not contents:
            if self.verbose:
                print("No prompt, image or passage content was provided.")
            return self.fail_msg

        ret_code, answer, _ = self.generate_inner(contents)
        if ret_code != 0:
            return self.fail_msg
        return answer

    def generate_inner(self, input_contents):
        """Internal method to call the Gemini API and handle the response.

        Args:
            input_contents: List of PIL images and strings to send.

        Returns:
            A ``(ret_code, answer, response)`` tuple. On success ``ret_code``
            is ``0``, ``answer`` is the stripped response text and
            ``response`` is the raw API response. On any exception
            ``ret_code`` is ``1``, ``answer`` is ``self.fail_msg`` and
            ``response`` is ``None``.
        """
        generation_config = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_tokens,
        }
        try:
            response = self.model.generate_content(
                contents=input_contents,
                generation_config=generation_config,
            )
            # Accessing ``.text`` can itself raise, so it stays inside the
            # ``try`` with the API call.
            answer = response.text.strip()
        except Exception as e:
            # Deliberate catch-all boundary: callers only see the ret_code.
            if self.verbose:
                print(f"Error during API call: {e}")
            return 1, self.fail_msg, None

        return 0, answer, response
