import os
import json
import base64
import requests
import time
import sys
import io
from PIL import Image


class RateLimiter:
    """Sliding-window rate limiter that blocks the caller when the window is full.

    At most ``max_calls`` calls are allowed within any ``period``-second
    window. Timestamps are taken from ``time.monotonic()`` so that system
    clock adjustments cannot stall expiry or distort the computed sleep time.
    """

    def __init__(self, max_calls: int, period: float):
        """Create a limiter allowing ``max_calls`` calls per ``period`` seconds."""
        self.max_calls = max_calls
        self.period = period  # period in seconds (e.g., 60 seconds for a minute)
        # time.monotonic() readings of the calls still inside the window,
        # oldest first (entries are only ever appended).
        self.call_timestamps = []

    def _prune(self, now: float) -> None:
        """Drop recorded calls that are ``period`` seconds old or older at ``now``."""
        self.call_timestamps = [
            t for t in self.call_timestamps if now - t < self.period
        ]

    def wait(self):
        """Record one call, sleeping first if the current window is already full.

        When ``max_calls`` calls are already inside the window, a message is
        printed and the thread sleeps until the oldest recorded call expires.
        """
        now = time.monotonic()
        self._prune(now)
        if len(self.call_timestamps) >= self.max_calls:
            # Time left until the oldest call in the window ages out.
            wait_time = self.period - (now - self.call_timestamps[0])
            print(f"Rate limit reached. Waiting {wait_time:.2f} seconds.")
            time.sleep(wait_time)
            # Entries that expired during the sleep must not count against
            # the call being recorded below.
            self._prune(time.monotonic())
        self.call_timestamps.append(time.monotonic())

    def average_calls_per_minute(self) -> float:
        """Returns the average number of calls per minute based on the current window."""
        self._prune(time.monotonic())
        # Scale the in-window count from "per period" to "per 60 seconds".
        return len(self.call_timestamps) * (60 / self.period)


# Named chat-completions endpoint URLs, keyed by a short label.
APIBASES = dict(
    OFFICIAL="https://api.openai.com/v1/chat/completions",
)


def GPT_context_window(model):
    """Return the context-window length (presumably in tokens) for ``model``.

    Model names that are not listed in the table fall back to 128000.
    """
    fallback = 128000
    # Models grouped by the window size they share.
    sizes = {
        8192: ("gpt-4", "gpt-4-0613"),
        128000: (
            "gpt-4-turbo-preview",
            "gpt-4-1106-preview",
            "gpt-4-0125-preview",
            "gpt-4-vision-preview",
            "gpt-4-turbo",
            "gpt-4-turbo-2024-04-09",
        ),
        16385: ("gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106"),
        4096: ("gpt-3.5-turbo-instruct",),
    }
    window_by_model = {
        name: window for window, names in sizes.items() for name in names
    }
    return window_by_model.get(model, fallback)


class OpenAIWrapper:
    def __init__(
        self,
        model: str = "gpt-4-vision-preview",  # Using a model that supports vision capabilities
        api_base="https://api.openai.com/v1/chat/completions",
        retry: int = 5,
        wait: int = 5,
        key=os.environ.get("OPENAI_API_KEY", ""),
        verbose: bool = False,
        system_prompt: str = None,
        temperature: float = 0.9,
        timeout: int = 10,
        max_tokens: int = 32768,
        img_size: int = 512,
        img_detail: str = "low",
        use_azure: bool = False,
        **kwargs,
    ):
        self.model = model
        self.cur_idx = 0
        self.fail_msg = "Failed to obtain answer via API. "
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.use_azure = use_azure
        self.key = key
        assert img_size > 0 or img_size == -1
        self.img_size = img_size
        assert img_detail in ["high", "low"]
        self.img_detail = img_detail
        self.timeout = timeout
        self.api_base = api_base
        self.rate_limiter = None
        if "free" in model:  # OpenRouter supports 20 calls per 60 seconds
            self.rate_limiter = RateLimiter(max_calls=15, period=60)
        else:
            self.rate_limiter = RateLimiter(max_calls=300, period=60)

    def _determine_output_format_and_mime(
        self, original_pil_format: str | None
    ) -> tuple[str, str]:
        """
        Determines the PIL format string to save as and its corresponding MIME type.
        Prefers common original web formats; otherwise, defaults to PNG.
        """
        pil_save_format = "PNG"  # Default PIL format string for saving
        output_mime_type = "image/png"  # Default MIME type

        if original_pil_format:
            fmt_upper = original_pil_format.upper()
            if fmt_upper == "JPEG":
                return "JPEG", "image/jpeg"
            elif fmt_upper == "PNG":
                return "PNG", "image/png"
            elif fmt_upper == "GIF":
                return "GIF", "image/gif"
            elif fmt_upper == "WEBP":
                return "WEBP", "image/webp"

        return pil_save_format, output_mime_type

    def encode_image_to_base64(self, img_path: str, target_size: int) -> str:
        """
        Encode image as base64 string for API input.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        # with Image.open(img_path) as img:
        #     original_width, original_height = img.size
        #     original_pil_format = img.format  # E.g., "JPEG", "PNG", None
        #     pil_save_format, output_mime_type = self._determine_output_format_and_mime(
        #         original_pil_format
        #     )

        #     needs_resize = False
        #     if original_width > target_size or original_height > target_size:
        #         needs_resize = True
        #         if original_width >= original_height:
        #             new_width = target_size
        #             new_height = int(original_height * (target_size / original_width))
        #         else:
        #             new_height = target_size
        #             new_width = int(original_width * (target_size / original_height))
        #         new_width = max(1, new_width)
        #         new_height = max(1, new_height)

        #     img_byte_arr = io.BytesIO()
        #     img_format = img.format if img.format else "PNG"
        #     if needs_resize:
        #         # Pillow versions 9.1.0+ use Image.Resampling.LANCZOS
        #         # Older versions use Image.LANCZOS
        #         try:
        #             resampling_filter = Image.Resampling.LANCZOS
        #         except AttributeError:
        #             resampling_filter = Image.LANCZOS
        #         resized_img = img.resize((new_width, new_height), resampling_filter)
        #         resized_img.save(img_byte_arr, format=img_format)
        #     else:
        #         img.save(img_byte_arr, format=img_format)

        #     img_byte_arr = img_byte_arr.getvalue()

        # encoded_image = base64.b64encode(img_byte_arr).decode("utf-8")

        with open(img_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
        return encoded_image

    def get_prediction(
        self, image_path=None, prompt=None, passages=None, passage_prompt=None
    ) -> str:
        """
        Get prediction for a question/image pair.
        """
        if self.rate_limiter:
            self.rate_limiter.wait()

        # Initialize the list for message content
        content_list = []

        # Add passages (if any) at the beginning of the content list
        if passages:
            for passage in passages[::-1]:
                if isinstance(passage, dict):
                    if "image_path" in passage:
                        passage_image_path = passage["image_path"]
                        b64_passage = self.encode_image_to_base64(
                            passage_image_path, target_size=self.img_size
                        )
                        img_struct = dict(
                            url=f"data:image/jpeg;base64,{b64_passage}",
                            detail=self.img_detail,
                        )
                        content_list.append(
                            dict(type="image_url", image_url=img_struct)
                        )
                    if "caption" in passage:
                        passage_text = passage["caption"]
                        content_list.append(dict(type="text", text=f"{passage_text}"))
                if isinstance(passage, tuple) and len(passage) == 2:  # Image-text pair
                    passage_image_path, passage_text = passage
                    b64_passage = self.encode_image_to_base64(
                        passage_image_path, target_size=self.img_size
                    )
                    img_struct = dict(
                        url=f"data:image/jpeg;base64,{b64_passage}",
                        detail=self.img_detail,
                    )
                    content_list.append(dict(type="image_url", image_url=img_struct))
                    content_list.append(
                        dict(type="text", text=f"Passage: {passage_text}")
                    )
                elif isinstance(passage, str):  # Text-only passage
                    content_list.append(dict(type="text", text=f"Passage: {passage}"))
        if passage_prompt:
            content_list.append(dict(type="text", text=passage_prompt))

        # Encode the query image
        if not isinstance(image_path, list):
            image_path = [image_path]
        for img in image_path:
            b64 = self.encode_image_to_base64(img, target_size=self.img_size)
            img_struct = dict(
                url=f"data:image/jpeg;base64,{b64}", detail=self.img_detail
            )
            content_list.append(dict(type="image_url", image_url=img_struct))
        if prompt:
            content_list.append(dict(type="text", text=prompt))

        # Prepare the input messages for the API request
        input_msgs = [{"role": "user", "content": content_list}]

        # Call the model to get the prediction
        response = self.generate_inner(input_msgs)

        # Return the model's response
        ret_code, answer, _ = response
        if ret_code != 0:
            return self.fail_msg
        return answer

    def generate_inner(self, input_msgs) -> str:
        context_window = GPT_context_window(self.model)

        # Create the payload for the OpenAI API request
        payload = dict(
            model=self.model,
            messages=input_msgs,
            max_tokens=self.max_tokens,
            n=1,
            temperature=self.temperature,
        )
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.key}",
        }
        # if "free" in self.model:
        #    payload["provider"] = {'order': ['Chutes']}

        # Sending the request to the API
        response = requests.post(
            self.api_base,
            headers=headers,
            data=json.dumps(payload),
            timeout=self.timeout * 1.1,
        )

        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            if "error" in resp_struct:  # For OpenRouter, exit on error
                print(f"Error: {resp_struct['error']}")
                sys.exit()
            answer = resp_struct["choices"][0]["message"]["content"].strip()
        except Exception as err:
            print(f"{type(err)}: {err}")
            print(response.text if hasattr(response, "text") else response)
        return ret_code, answer, response
