import base64
import copy
import dataclasses
import io
import json
import logging
import os
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Union

from openai import APIConnectionError, OpenAI
from openai.types.chat.chat_completion import Choice
from PIL import Image
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    stop_never,
    wait_exponential,
)

from llm_mcts.file_logging import get_logging_dir
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult, Model
from llm_mcts.models.custom_json_encoder import CustomJSONEncoder

def _per_token(**per_million_tokens: float) -> Dict[str, float]:
    """Turn prices quoted per one million tokens into per-token prices."""
    return {kind: price / 1e6 for kind, price in per_million_tokens.items()}


# Per-token prices, keyed by model name, then by usage counter name.
# Reference: https://openai.com/api/pricing/
PRICING = {
    # NOTE: Alias for gpt-4o-2024-08-06 (at 2025-02-05)
    "gpt-4o": _per_token(prompt_tokens=2.5, completion_tokens=10.0),
    "gpt-4o-2024-05-13": _per_token(prompt_tokens=5.0, completion_tokens=15.0),
    "gpt-4o-2024-08-06": _per_token(prompt_tokens=2.5, completion_tokens=10.0),
    "gpt-4o-2024-11-20": _per_token(prompt_tokens=2.5, completion_tokens=10.0),
    "gpt-4o-mini": _per_token(prompt_tokens=0.15, completion_tokens=0.6),
    "gpt-4o-mini-2024-07-18": _per_token(prompt_tokens=0.15, completion_tokens=0.6),
    # NOTE: Alias for o1-2024-12-17 (at 2025-02-05)
    "o1": _per_token(prompt_tokens=15.0, completion_tokens=60.0),
    "o1-2024-12-17": _per_token(prompt_tokens=15.0, completion_tokens=60.0),
    # NOTE: Alias for o1-preview-2024-09-12 (at 2025-02-05)
    "o1-preview": _per_token(prompt_tokens=15.0, completion_tokens=60.0),
    "o1-preview-2024-09-12": _per_token(prompt_tokens=15.0, completion_tokens=60.0),
    # NOTE: Alias for o1-mini-2024-09-12 (at 2025-02-05)
    "o1-mini": _per_token(prompt_tokens=3.0, completion_tokens=12.0),
    "o1-mini-2024-09-12": _per_token(prompt_tokens=3.0, completion_tokens=12.0),
    # NOTE: Alias for o3-mini-2025-01-31 (at 2025-02-05)
    "o3-mini": _per_token(prompt_tokens=1.1, completion_tokens=4.4),
    "o3-mini-2025-01-31": _per_token(prompt_tokens=1.1, completion_tokens=4.4),
    # DeepSeek prices valid until 2025-02-08 16:00 (UTC). The trailing comments
    # presumably give the regular prices after that date -- confirm before use.
    "deepseek-chat": _per_token(
        prompt_cache_hit_tokens=0.014,  # 0.07 / 1e6
        prompt_cache_miss_tokens=0.14,  # 0.27 / 1e6
        completion_tokens=0.28,  # 1.10 / 1e6
    ),
    # OpenRouter models: to keep logging-dir names valid, "/" in the model id is
    # written as "_", and the key carries an "openrouter_" prefix.
    # Prices are rough: the actual price depends on the backend OpenRouter picks.
    # deepseek/deepseek-r1
    "openrouter_deepseek_deepseek-r1": _per_token(
        prompt_tokens=0.8, completion_tokens=2.4
    ),
    # deepseek/deepseek-chat
    "openrouter_deepseek_deepseek-chat": _per_token(
        prompt_tokens=0.5, completion_tokens=0.9
    ),
}

# Priced models whose names start with "o1"/"o3"; try_generate sends them
# `reasoning_effort` instead of `temperature`.
OPENAI_REASONING_MODELS = {
    model_name for model_name in PRICING if model_name.startswith(("o1", "o3"))
}


def safe_get(d: dict, key: str, default: int | float = 0) -> float | int:
    if not isinstance(d, dict):
        return default
    val = d.get(key, default)
    if val is None:
        val = default
    return val


def sum_usage_tokens(left: dict, right: dict) -> dict:
    """Return the field-wise sum of two chat-completion ``usage`` dicts.

    Missing or ``None`` counters count as 0 (via ``safe_get``). The result always
    has every top-level counter and both nested ``*_details`` dicts. It can
    therefore be passed back in as ``left`` when accumulating usage over several
    API calls. DeepSeek's ``prompt_cache_hit_tokens`` / ``prompt_cache_miss_tokens``
    are summed as well.

    Args:
        left: Usage accumulated so far (may be an empty dict).
        right: Usage of one more response, e.g. ``completion.usage.model_dump()``.

    Returns:
        A new dict; neither input is modified.
    """
    # NOTE: Field names follow the usage schema of openai==1.57.4.
    # Non-dict inputs (e.g. None) are treated as "no usage" rather than crashing.
    left = left if isinstance(left, dict) else {}
    right = right if isinstance(right, dict) else {}

    def _add(key: str) -> int | float:
        # Sum one top-level counter.
        return safe_get(left, key) + safe_get(right, key)

    def _add_details(key: str, fields: Sequence[str]) -> dict:
        # Sum the listed counters of a nested details dict. safe_get yields 0
        # when the details entry itself is missing, None or not a dict.
        left_details = left.get(key)
        right_details = right.get(key)
        return {
            field: safe_get(left_details, field) + safe_get(right_details, field)
            for field in fields
        }

    return {
        "completion_tokens": _add("completion_tokens"),
        "prompt_tokens": _add("prompt_tokens"),
        "total_tokens": _add("total_tokens"),
        "completion_tokens_details": _add_details(
            "completion_tokens_details",
            (
                "accepted_prediction_tokens",
                "audio_tokens",
                "reasoning_tokens",
                "rejected_prediction_tokens",
            ),
        ),
        "prompt_tokens_details": _add_details(
            "prompt_tokens_details", ("audio_tokens", "cached_tokens")
        ),
        "prompt_cache_hit_tokens": _add("prompt_cache_hit_tokens"),
        "prompt_cache_miss_tokens": _add("prompt_cache_miss_tokens"),
    }


# NOTE: basicConfig configures the *root* logger (INFO level) at import time.
# It does nothing if the root logger already has handlers; otherwise it affects
# every program that imports this module.
logging.basicConfig(level=logging.INFO)

# Module-level logger (e.g. handed to tenacity's before_sleep_log below).
logger = logging.getLogger(__name__)


def will_retry_generate_failure(e: BaseException) -> bool:
    """Return True when ``e`` carries an HTTP status code worth retrying.

    Used as the ``retry_if_exception`` predicate of ``try_generate``. Only
    exceptions with a ``status_code`` attribute that converts to ``int`` are
    considered. Anything else, including ``status_code=None``, is not retried.
    """
    # Error codes that are likely to be resolved by retrying
    # 429 - Rate limit reached for requests
    # 429 - You exceeded your current quota, please check your plan and billing details
    # 500 - The server had an error while processing your request
    # 503 - The engine is currently overloaded, please try again later
    # OpenAI API Reference: https://platform.openai.com/docs/guides/error-codes
    # DeepSeek API Reference: https://api-docs.deepseek.com/quick_start/error_codes
    raw_status = getattr(e, "status_code", None)
    if raw_status is None:
        return False
    try:
        status_code = int(raw_status)
    except (TypeError, ValueError):
        # An exception escaping a tenacity predicate would replace the original
        # API error, so an unparsable status simply means "don't retry".
        return False
    if status_code not in (429, 500, 503):
        return False
    print(f"Hit an API error with status code {status_code}, retrying...")
    return True


def to_openai_client_model_name(model_name: str) -> str:
    """Map this module's model key to the model id sent to the API.

    Keys of the form ``openrouter_<vendor>_<model>`` store "/" as "_". The prefix
    is stripped and every remaining "_" is turned back into "/". Any other name
    is returned unchanged.

    NOTE(review): an OpenRouter model id that itself contains "_" would be
    mangled by this round trip -- confirm none are used.
    """
    prefix = "openrouter_"
    if model_name.startswith(prefix):
        return model_name.removeprefix(prefix).replace("_", "/")
    return model_name


# The sleep durations before the retries are 1s, 2s, 4s, 8s, 8s, 8s...
@retry(
    wait=wait_exponential(multiplier=1, min=1, max=8),
    before_sleep=before_sleep_log(logger, logging.INFO),
    # OPENAI_MAX_RETRY (read once, at import) caps the total number of attempts;
    # when it is unset, retries never stop.
    stop=(
        stop_never
        if os.environ.get("OPENAI_MAX_RETRY") is None
        else stop_after_attempt(int(os.environ["OPENAI_MAX_RETRY"]))
    ),
    retry=(
        # deepseek api sometimes raises JSONDecodeError
        retry_if_exception_type(json.JSONDecodeError)
        # network-level failure reaching the API
        | retry_if_exception_type(APIConnectionError)
        # retryable HTTP status codes (see will_retry_generate_failure)
        | retry_if_exception(will_retry_generate_failure)
    ),
)
def try_generate(api_model: "OpenAIAPIModel", messages, request_samples):
    """Request ``request_samples`` chat completions for ``messages``.

    Models in ``OPENAI_REASONING_MODELS`` get ``reasoning_effort`` (env
    ``OPENAI_REASONING_EFFORT``, default "medium") and no temperature. All
    other models get ``api_model.temperature``.

    NOTE(review): o1-mini / o1-preview are in OPENAI_REASONING_MODELS; confirm
    those endpoints accept ``reasoning_effort``.
    """
    create_kwargs = {
        "messages": messages,
        "model": to_openai_client_model_name(api_model.model),
        "n": request_samples,
    }
    if api_model.model in OPENAI_REASONING_MODELS:
        create_kwargs["reasoning_effort"] = os.environ.get(
            "OPENAI_REASONING_EFFORT", "medium"
        )
    else:
        create_kwargs["temperature"] = api_model.temperature
    return api_model.client.chat.completions.create(**create_kwargs)


class OpenAIAPIModel(Model):
    """Chat-completion backend for OpenAI, DeepSeek and OpenRouter models.

    The endpoint is picked from the model-name prefix:
    - ``deepseek*`` uses https://api.deepseek.com.
    - ``openrouter_*`` uses OpenRouter, with "_" standing in for "/" in the model id.
    - Anything else uses a default ``OpenAI()`` client.

    Each generated response is written as JSON under ``logging_dir``. Token
    usage (plus a cost estimate for priced models) is also written there unless
    ``DISABLE_LOG_USAGE`` is set.
    """

    def __init__(
        self,
        model: str = "gpt-4o-2024-08-06",
        temperature: float = 0.7,
        logging_dir: Optional[Path] = None,
    ) -> None:
        """
        Args:
            model: Model key. It selects the endpoint and is looked up in
                ``PRICING`` / ``OPENAI_REASONING_MODELS``.
            temperature: Sampling temperature (not sent to reasoning models).
            logging_dir: Directory for usage/response logs (a ``str`` is also
                accepted). Defaults to ``get_logging_dir(os.getpid())``. Created
                if missing.
        """
        if model.startswith("deepseek"):
            # Prefer the provider-specific key, fall back to OPENAI_API_KEY.
            api_key = (
                os.environ["DEEPSEEK_API_KEY"]
                if "DEEPSEEK_API_KEY" in os.environ
                else os.environ["OPENAI_API_KEY"]
            )
            self.client = OpenAI(base_url="https://api.deepseek.com", api_key=api_key)
        elif model.startswith("openrouter_"):
            # Prefer the provider-specific key, fall back to OPENAI_API_KEY.
            api_key = (
                os.environ["OPENROUTER_API_KEY"]
                if "OPENROUTER_API_KEY" in os.environ
                else os.environ["OPENAI_API_KEY"]
            )
            self.client = OpenAI(
                base_url="https://openrouter.ai/api/v1", api_key=api_key
            )
        else:
            self.client = OpenAI()

        self.model = self.model_name = model
        self.temperature = temperature

        # Path() normalises a str argument; exist_ok makes mkdir a no-op when
        # the directory is already there.
        self.logging_dir = Path(
            get_logging_dir(os.getpid()) if logging_dir is None else logging_dir
        )
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Incremented once per request in generate(); used in log file names.
        self.call_count = 0

    def _encode_image(self, image: Image.Image) -> str:
        """Return ``image`` saved as JPEG and base64-encoded as an ASCII string."""
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str

    def _process_message(
        self, request: GenerationRequest
    ) -> List[Dict[str, Union[str, Image.Image]]]:
        """Convert ``request.messages`` into plain dicts for the chat API.

        A list-valued message content is turned into typed parts:
        - ``str`` items become text parts.
        - PIL images become base64 JPEG data URLs (converted to RGB first).
        - Items of any other type are dropped.

        The request is deep-copied first, so the caller's object is not mutated.
        """
        request_copy = copy.deepcopy(request)
        messages = []
        for message in request_copy.messages:
            # if the message contains image tokens, the content is a list
            # image should be encoded as base64
            if isinstance(message.content, list):
                full_content = []
                for content in message.content:
                    if isinstance(content, str):
                        full_content.append(
                            {
                                "type": "text",
                                "text": content,
                            }
                        )
                    elif isinstance(content, Image.Image):
                        full_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{self._encode_image(content.convert('RGB'))}",
                                },
                            }
                        )
                message.content = full_content
            messages.append(dataclasses.asdict(message))
        return messages

    def _sample_choices(
        self, messages: List[dict], num_samples: int
    ) -> tuple[List[Choice], dict]:
        """Call the API until ``num_samples`` choices have been collected.

        Returns:
            The collected choices, and the usage summed over all calls. The
            usage dict stays empty when ``DISABLE_LOG_USAGE`` is set.
        """
        # DeepSeek and OpenRouter models are sampled one completion per call;
        # otherwise up to 128 completions are requested per call via ``n``.
        per_call_limit = (
            1
            if self.model_name.startswith("deepseek")
            or self.model_name.startswith("openrouter")
            else 128
        )
        choices: List[Choice] = []
        usage_data: dict = {}
        remaining_samples = num_samples
        while remaining_samples > 0:
            request_samples = min(remaining_samples, per_call_limit)
            chat_completion = try_generate(self, messages, request_samples)
            choices.extend(chat_completion.choices)
            remaining_samples -= request_samples
            if not os.getenv("DISABLE_LOG_USAGE"):
                assert chat_completion.usage
                usage_data = sum_usage_tokens(
                    usage_data, chat_completion.usage.model_dump()
                )
        return choices, usage_data

    def _write_usage_log(self, usage_data: dict) -> None:
        """Add a cost estimate (when the model is priced) and dump usage as JSON."""
        pricing = PRICING.get(self.model_name)
        if pricing is None:
            # Not fatal: raising here would throw away generations that were
            # already produced (and paid for).
            logger.warning(
                "No PRICING entry for model %s; writing usage without a price.",
                self.model_name,
            )
        else:
            # safe_get: counters may be absent (e.g. num_samples == 0) or None.
            price = {
                key: int(safe_get(usage_data, key)) * rate
                for key, rate in pricing.items()
            }
            price["total"] = sum(price.values())
            usage_data["price"] = price
        (
            self.logging_dir
            / f"{self.model_name.replace('/', '_')}_{self.call_count}.txt"
        ).write_text(json.dumps(usage_data, indent=4, cls=CustomJSONEncoder))

    def generate(
        self, requests: Sequence[GenerationRequest], num_samples: int = 1
    ) -> Iterable[GenerationResult]:
        """Generate ``num_samples`` completions for each request.

        Returns:
            A list with ``num_samples`` results per request, in request order.

        Side effects (``<model>`` has "/" replaced by "_"; ``n`` is ``call_count``
        before the request):
        - Unless ``DISABLE_LOG_USAGE`` is set, writes usage to ``<model>_<n>.txt``.
        - Writes each result to ``response_<model>_<i>_<n + 1>.log``.
        """
        results: List[GenerationResult] = []
        for request in requests:
            messages = self._process_message(request)
            choices, usage_data = self._sample_choices(messages, num_samples)

            if not os.getenv("DISABLE_LOG_USAGE"):
                self._write_usage_log(usage_data)

            # NOTE(review): call_count is bumped between the usage log and the
            # response logs, so their file-name counters differ by one. Kept
            # as-is in case downstream tooling depends on these names.
            self.call_count += 1
            assert (
                len(choices) == num_samples
            ), f"OpenAI API returned {len(choices)} samples, expected {num_samples}"
            for i in range(num_samples):
                result = GenerationResult(
                    request=request,
                    generation=choices[i].message.content,
                )
                results.append(result)
                # Log the result just built. Indexing ``results[i]`` would pick
                # an earlier request's result once several requests are passed.
                (
                    self.logging_dir
                    / f"response_{self.model_name.replace('/', '_')}_{i}_{self.call_count}.log"
                ).write_text(
                    json.dumps(
                        dataclasses.asdict(result), indent=4, cls=CustomJSONEncoder
                    )
                )
        return results
