import base64
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from threading import Lock
from typing import List, Tuple, Union

import numpy as np
from accelerate import Accelerator, DistributedType
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from tqdm import tqdm

try:
    from decord import VideoReader, cpu
except ImportError:
    pass

from dotenv import load_dotenv
from loguru import logger as eval_logger
from openai import AzureOpenAI, OpenAI
from PIL import Image

# Load variables from a local .env file (if one is found) at import time, so
# that the os.getenv lookups for API keys and endpoints in
# ParallelOpenAICompatible.__init__ can see them.
load_dotenv(verbose=True)


@register_model("parallel_openai_compatible")
class ParallelOpenAICompatible(lmms):
    """Chat-completions backend that issues requests concurrently from a thread pool.

    Each evaluation request is turned into a single-turn chat-completions call
    (text plus optional base64-encoded images / videos) against one OpenAI or
    Azure OpenAI client. Requests are processed by ``max_workers`` threads and
    the results are returned in the order of the incoming requests.

    Credentials and endpoints are read from the environment:
    ``OPENAI_API_KEY`` / ``OPENAI_API_BASE`` or, when ``azure_openai`` is set,
    ``AZURE_OPENAI_API_KEY`` / ``AZURE_OPENAI_API_BASE`` /
    ``AZURE_OPENAI_API_VERSION``.

    In continual mode, responses are persisted to a JSON file keyed by
    ``"{task}___{split}___{doc_id}"`` so an interrupted run can resume without
    re-requesting documents that already have a non-empty response.

    NOTE(review): every process writes the same cache file; with more than one
    process the files may overwrite each other -- confirm before relying on
    continual mode in multi-process runs.
    """

    # Substrings used to classify string visuals (matched case-insensitively,
    # in this priority order).
    _RAW_VIDEO_MARKERS = (".mp4",)
    _FRAME_VIDEO_MARKERS = (".avi", ".mov", ".flv", ".wmv")
    _IMAGE_MARKERS = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp")
    # Pillow image modes that can be written as PNG without conversion.
    _PNG_MODES = ("1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA")
    # Defaults / limits for the number of generated tokens.
    _DEFAULT_MAX_NEW_TOKENS = 1024
    _MAX_NEW_TOKENS_CAP = 4096

    def __init__(
        self,
        model_version: str = "gpt-5.2",
        timeout: int = 10,
        max_retries: int = 3,
        max_size_in_mb: int = 10,
        continual_mode: bool = False,
        response_persistent_folder: Union[str, None] = None,
        azure_openai: bool = False,
        max_num_frames: int = 10,
        max_workers: int = 8,
        **kwargs,
    ) -> None:
        """Create the API client and set up optional response caching.

        Args:
            model_version: Model name sent in every request.
            timeout: Seconds to sleep between failed attempts (it is not passed
                to the client as a network timeout).
            max_retries: Number of attempts per request before giving up.
            max_size_in_mb: Upper bound for the PNG size of an encoded image;
                larger images are downscaled.
            continual_mode: Persist responses to disk and reuse them on resume.
            response_persistent_folder: Folder for the cache file; required
                when ``continual_mode`` is enabled.
            azure_openai: Use ``AzureOpenAI`` instead of ``OpenAI``.
            max_num_frames: Frames sampled from videos that are sent as images.
            max_workers: Size of the request thread pool.
            **kwargs: Only ``default_headers`` is read (OpenAI client only).

        Raises:
            ValueError: If continual mode is enabled without a folder.
        """
        super().__init__()
        self.model_version = model_version
        self.timeout = timeout
        self.max_retries = max_retries
        # Some models have a limit on the size of the image.
        self.max_size_in_mb = max_size_in_mb
        self.continual_mode = continual_mode
        self.max_num_frames = max_num_frames
        self.max_workers = max_workers

        if self.continual_mode:
            if response_persistent_folder is None:
                raise ValueError(
                    "Continual mode requires a persistent path for the response. Please provide a valid path."
                )

            self.response_persistent_folder = response_persistent_folder
            self.response_persistent_file = os.path.join(
                self.response_persistent_folder, f"{self.model_version}_response.json"
            )
            # Create the directory of the cache *file*: a model name such as
            # "org/model" puts the file in a sub-directory of the folder.
            os.makedirs(
                os.path.dirname(self.response_persistent_file) or ".", exist_ok=True
            )

            if os.path.exists(self.response_persistent_file):
                with open(self.response_persistent_file, "r", encoding="utf-8") as f:
                    self.response_cache = json.load(f)
                self.cache_mode = "resume"
            else:
                self.response_cache = {}
                self.cache_mode = "start"

        if azure_openai:
            self.client = AzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
                api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
            )
        else:
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_API_BASE"),
                default_headers=kwargs.get("default_headers"),
            )

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
            ], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported."
            if accelerator.is_local_main_process:
                eval_logger.info(
                    f"Using {accelerator.num_processes} devices with data parallelism"
                )
        self.accelerator = accelerator
        self._rank = accelerator.local_process_index
        self._world_size = accelerator.num_processes
        self.device = accelerator.device

        # Guards response_cache, which worker threads read while the
        # collecting thread updates it.
        self.cache_lock = Lock() if self.continual_mode else None

    @staticmethod
    def _png_bytes(img) -> bytes:
        """Serialize a PIL image to PNG and return the raw bytes."""
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        return buffer.getvalue()

    def encode_image(self, image: Union[Image.Image, str]) -> str:
        """Return ``image`` as a base64-encoded PNG within the size limit.

        Args:
            image: A PIL image or a path to an image file.

        Returns:
            The base64 string of the PNG. The image is repeatedly scaled to
            75% until the PNG fits in ``max_size_in_mb`` or either side would
            drop to 100 pixels or less.
        """
        max_size = self.max_size_in_mb * 1024 * 1024  # limit in bytes
        if isinstance(image, str):
            # convert() loads the pixel data, so the file can be closed right away.
            with Image.open(image) as opened:
                img = opened.convert("RGB")
        else:
            img = image.copy()
            if img.mode not in self._PNG_MODES:
                # e.g. CMYK cannot be written as PNG.
                img = img.convert("RGB")

        byte_data = self._png_bytes(img)

        # If image is too large, resize it while maintaining aspect ratio
        while len(byte_data) > max_size and img.size[0] > 100 and img.size[1] > 100:
            new_size = (int(img.size[0] * 0.75), int(img.size[1] * 0.75))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
            byte_data = self._png_bytes(img)

        return base64.b64encode(byte_data).decode("utf-8")

    def encode_video(self, video_path, for_get_frames_num):
        """Sample frames from a video and return them as base64-encoded PNGs.

        Args:
            video_path: Path of the video file, read with decord at 448x448.
            for_get_frames_num: Number of uniformly spaced frames to sample;
                the last frame is appended if the sampling did not hit it.

        Returns:
            A list of base64 strings, one per sampled frame (empty when the
            video has no frames).

        Raises:
            ImportError: If the optional ``decord`` dependency is not installed.
        """
        # The module-level decord import is optional and fails silently.
        if "VideoReader" not in globals() or "cpu" not in globals():
            raise ImportError(
                "decord is required to extract video frames. Please install decord."
            )

        vr = VideoReader(video_path, ctx=cpu(0), width=448, height=448)
        total_frame_num = len(vr)
        if total_frame_num == 0:
            # linspace(0, -1, ...) would yield negative (invalid) frame indices.
            eval_logger.warning(f"Video has no frames: {video_path}")
            return []

        uniform_sampled_frames = np.linspace(
            0, total_frame_num - 1, for_get_frames_num, dtype=int
        )

        # Ensure the last frame is included
        if total_frame_num - 1 not in uniform_sampled_frames:
            uniform_sampled_frames = np.append(
                uniform_sampled_frames, total_frame_num - 1
            )

        frames = vr.get_batch(uniform_sampled_frames.tolist()).asnumpy()
        return [
            base64.b64encode(self._png_bytes(Image.fromarray(frame))).decode("utf-8")
            for frame in frames
        ]

    def video_to_base64(self, video_path):
        """Return the raw bytes of the file at ``video_path`` as a base64 string."""
        with open(video_path, "rb") as video_file:
            return base64.b64encode(video_file.read()).decode("utf-8")

    def flatten(self, input):
        """Flatten one level of nesting: an iterable of iterables becomes a list."""
        return [item for group in input for item in group]

    def _collect_visuals(self, doc_to_visual, doc_id, task, split):
        """Encode the visuals of one document.

        Returns:
            ``(imgs, videos)``: base64 PNG images (including sampled video
            frames) and base64-encoded raw video files, in document order.
        """
        # task_dict is not set in this class; presumably populated by the
        # evaluation harness before generation -- verify against caller.
        visuals = doc_to_visual(self.task_dict[task][split][doc_id])
        imgs = []  # multiple images or frames for video
        videos = []
        if visuals is None:
            return imgs, videos

        for visual in self.flatten([visuals]):
            if isinstance(visual, Image.Image):
                imgs.append(self.encode_image(visual))
            elif isinstance(visual, str):
                lowered = visual.lower()
                if any(marker in lowered for marker in self._RAW_VIDEO_MARKERS):
                    # encode video file to base64 data
                    videos.append(self.video_to_base64(visual))
                elif any(marker in lowered for marker in self._FRAME_VIDEO_MARKERS):
                    imgs.extend(self.encode_video(visual, self.max_num_frames))
                elif any(marker in lowered for marker in self._IMAGE_MARKERS):
                    imgs.append(self.encode_image(visual))
                else:
                    eval_logger.warning(
                        f"Skipping visual with unrecognized file type: {visual}"
                    )
        return imgs, videos

    def _build_payload(self, contexts, gen_kwargs, imgs, videos):
        """Build the keyword arguments for ``chat.completions.create``."""
        content = [{"type": "text", "text": contexts}]
        content.extend(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img}"},
            }
            for img in imgs
        )
        content.extend(
            {
                "type": "video_url",
                "video_url": {"url": f"data:video/mp4;base64,{video}"},
            }
            for video in videos
        )
        payload = {
            "model": self.model_version,
            "messages": [{"role": "user", "content": content}],
        }

        # Read generation parameters without mutating the caller's dict, which
        # may be shared between requests running on other threads.
        gen_kwargs = gen_kwargs or {}
        max_new_tokens = min(
            gen_kwargs.get("max_new_tokens", self._DEFAULT_MAX_NEW_TOKENS),
            self._MAX_NEW_TOKENS_CAP,
        )

        # NOTE(review): only model names containing "o1"/"o3" get the
        # reasoning-model parameters; confirm whether other models (including
        # the default model_version) need the same treatment.
        if "o1" in self.model_version or "o3" in self.model_version:
            payload["reasoning_effort"] = "medium"
            payload["response_format"] = {"type": "text"}
            payload["max_completion_tokens"] = max_new_tokens
        else:
            payload["max_tokens"] = max_new_tokens
            payload["temperature"] = gen_kwargs.get("temperature", 0)
        return payload

    def _request_completion(self, payload):
        """Call the API with retries; return the reply text or "" on failure."""
        response_text = ""
        for attempt in range(self.max_retries):
            try:
                response = self.client.chat.completions.create(**payload)
                # content may be None; results are expected to be strings.
                response_text = response.choices[0].message.content or ""
                break
            except Exception as e:
                error_msg = str(e)
                eval_logger.info(
                    f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}"
                )

                if attempt == self.max_retries - 1:
                    eval_logger.error(
                        f"All {self.max_retries} attempts failed. Last error: {error_msg}"
                    )
                else:
                    time.sleep(self.timeout)
        return response_text

    def _process_single_request(self, request_data):
        """Process a single request (runs on a worker thread).

        Args:
            request_data: ``(contexts, gen_kwargs, doc_to_visual, doc_id,
                task, split)`` as stored in a request's ``args``.

        Returns:
            ``(response_text, doc_uuid, response_text)``; ``doc_uuid`` is the
            cache key in continual mode and ``None`` otherwise. The response
            is "" when all attempts failed.
        """
        contexts, gen_kwargs, doc_to_visual, doc_id, task, split = request_data
        doc_uuid = f"{task}___{split}___{doc_id}" if self.continual_mode else None

        # Check cache first if resuming; empty cached responses are retried.
        if doc_uuid is not None and self.cache_mode == "resume":
            with self.cache_lock:
                cached = self.response_cache.get(doc_uuid)
            if cached:
                return cached, doc_uuid, cached

        imgs, videos = self._collect_visuals(doc_to_visual, doc_id, task, split)
        payload = self._build_payload(contexts, gen_kwargs, imgs, videos)
        response_text = self._request_completion(payload)
        return response_text, doc_uuid, response_text

    def _persist_response(self, doc_uuid, response):
        """Store a response in the cache and rewrite the cache file atomically."""
        with self.cache_lock:
            if (
                doc_uuid in self.response_cache
                and self.response_cache[doc_uuid] == response
            ):
                # Cache hit or unchanged value: nothing new to write.
                return
            self.response_cache[doc_uuid] = response
            # Write to a temporary file and rename it, so an interrupted run
            # never leaves a truncated JSON file behind.
            tmp_path = f"{self.response_persistent_file}.{os.getpid()}.tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.response_cache, f)
            os.replace(tmp_path, self.response_persistent_file)

    def generate_until(self, requests) -> List[str]:
        """Generate a response for every request using the thread pool.

        Args:
            requests: Request objects whose ``args`` hold the tuple expected
                by ``_process_single_request``.

        Returns:
            Responses in the same order as ``requests``; a request that raised
            or exhausted its retries yields "".
        """
        request_data_list = [req.args for req in requests]
        results = [""] * len(request_data_list)

        pbar = tqdm(
            total=len(requests), disable=(self.rank != 0), desc="Model Responding"
        )
        try:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_index = {
                    executor.submit(self._process_single_request, req_data): idx
                    for idx, req_data in enumerate(request_data_list)
                }

                # Collect results as they complete
                for future in as_completed(future_to_index):
                    index = future_to_index[future]
                    try:
                        response_text, doc_uuid, _ = future.result()
                    except Exception as exc:
                        eval_logger.error(
                            f"Request {index} generated an exception: {exc}"
                        )
                        results[index] = ""
                    else:
                        results[index] = response_text
                        if self.continual_mode and doc_uuid:
                            try:
                                self._persist_response(doc_uuid, response_text)
                            except OSError as exc:
                                # A cache write failure must not discard a
                                # response that was obtained successfully.
                                eval_logger.error(
                                    f"Failed to persist response for {doc_uuid}: {exc}"
                                )
                    finally:
                        pbar.update(1)
        finally:
            pbar.close()
        return results

    def generate_until_multi_round(self, requests) -> List[str]:
        """Not implemented for this backend."""
        raise NotImplementedError(
            "TODO: Implement multi-round generation for OpenAI compatible models"
        )

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not implemented for this backend."""
        raise NotImplementedError(
            "TODO: Implement loglikelihood for OpenAI compatible models"
        )
