import base64
import copy
import dataclasses
import io
import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Union

from openai import OpenAI
from openai.types.chat.chat_completion import Choice
from PIL import Image
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    stop_never,
    wait_exponential,
)

from llm_mcts.file_logging import get_logging_dir
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult, Model
from llm_mcts.models.custom_json_encoder import CustomJSONEncoder

# Reference: https://openai.com/api/pricing/
def _per_token_rates(**per_million_tokens: float) -> dict[str, float]:
    """Convert prices quoted per 1M tokens into per-token prices (same key order)."""
    return {field: price / 1e6 for field, price in per_million_tokens.items()}


# Per-token prices keyed by model name; each inner key names a usage counter
# (as produced by ``sum_usage_tokens``) that is multiplied by its price.
PRICING = {
    # NOTE: Alias for gpt-4o-2024-08-06 (as of 2024-12-21)
    "gpt-4o": _per_token_rates(prompt_tokens=2.5, completion_tokens=10.0),
    "gpt-4o-2024-05-13": _per_token_rates(prompt_tokens=5.0, completion_tokens=15.0),
    "gpt-4o-2024-08-06": _per_token_rates(prompt_tokens=2.5, completion_tokens=10.0),
    "gpt-4o-2024-11-20": _per_token_rates(prompt_tokens=2.5, completion_tokens=10.0),
    "gpt-4o-mini": _per_token_rates(prompt_tokens=0.15, completion_tokens=0.6),
    "gpt-4o-mini-2024-07-18": _per_token_rates(
        prompt_tokens=0.15, completion_tokens=0.6
    ),
    # NOTE: Alias for o1-2024-12-17 (as of 2024-12-21)
    "o1": _per_token_rates(prompt_tokens=15.0, completion_tokens=60.0),
    "o1-2024-12-17": _per_token_rates(prompt_tokens=15.0, completion_tokens=60.0),
    # NOTE: Alias for o1-preview-2024-09-12 (as of 2024-12-21)
    "o1-preview": _per_token_rates(prompt_tokens=15.0, completion_tokens=60.0),
    "o1-preview-2024-09-12": _per_token_rates(
        prompt_tokens=15.0, completion_tokens=60.0
    ),
    # NOTE: Alias for o1-mini-2024-09-12 (as of 2024-12-21)
    "o1-mini": _per_token_rates(prompt_tokens=3.0, completion_tokens=12.0),
    "o1-mini-2024-09-12": _per_token_rates(prompt_tokens=3.0, completion_tokens=12.0),
    # DeepSeek promotional rates, valid until 2025-02-08 16:00 (UTC). The regular
    # rates after that were 0.07 (cache hit), 0.27 (cache miss) and 1.10
    # (completion) per 1M tokens.
    # NOTE(review): the promotional window has passed; confirm the current rates.
    "deepseek-chat": _per_token_rates(
        prompt_cache_hit_tokens=0.014,
        prompt_cache_miss_tokens=0.14,
        completion_tokens=0.28,
    ),
}


def safe_get(d: dict, key: str, default: int | float = 0) -> float | int:
    if not isinstance(d, dict):
        return default
    val = d.get(key, default)
    if val is None:
        val = default
    return val


def sum_usage_tokens(left: dict, right: dict) -> dict:
    # NOTE: We referred openai==1.57.4 when implementing this function
    completion_tokens = safe_get(left, "completion_tokens") + safe_get(
        right, "completion_tokens"
    )
    prompt_tokens = safe_get(left, "prompt_tokens") + safe_get(right, "prompt_tokens")
    total_tokens = safe_get(left, "total_tokens") + safe_get(right, "total_tokens")
    completion_tokens_details = {
        "accepted_prediction_tokens": 0,
        "audio_tokens": 0,
        "reasoning_tokens": 0,
        "rejected_prediction_tokens": 0,
    }
    if left.get("completion_tokens_details"):
        for key in completion_tokens_details:
            completion_tokens_details[key] += safe_get(
                left["completion_tokens_details"], key
            )
    if right.get("completion_tokens_details"):
        for key in completion_tokens_details:
            completion_tokens_details[key] += safe_get(
                right["completion_tokens_details"], key
            )
    prompt_tokens_details = {
        "audio_tokens": 0,
        "cached_tokens": 0,
    }
    if left.get("prompt_tokens_details"):
        for key in prompt_tokens_details:
            prompt_tokens_details[key] += safe_get(left["prompt_tokens_details"], key)
    if right.get("prompt_tokens_details"):
        for key in prompt_tokens_details:
            prompt_tokens_details[key] += safe_get(right["prompt_tokens_details"], key)

    prompt_cache_hit_tokens = safe_get(left, "prompt_cache_hit_tokens") + safe_get(
        right, "prompt_cache_hit_tokens"
    )
    prompt_cache_miss_tokens = safe_get(left, "prompt_cache_miss_tokens") + safe_get(
        right, "prompt_cache_miss_tokens"
    )

    total_tokens = safe_get(left, "total_tokens") + safe_get(right, "total_tokens")
    return {
        "completion_tokens": completion_tokens,
        "prompt_tokens": prompt_tokens,
        "total_tokens": total_tokens,
        "completion_tokens_details": completion_tokens_details,
        "prompt_tokens_details": prompt_tokens_details,
        "prompt_cache_hit_tokens": prompt_cache_hit_tokens,
        "prompt_cache_miss_tokens": prompt_cache_miss_tokens,
    }


logger = logging.getLogger(__name__)

# Import-time side effect: configure the root logger at INFO level so that
# INFO messages (such as the retry notices logged before each backoff sleep)
# are emitted.
# NOTE(review): configuring the root logger from a library module can
# interfere with the host application's logging setup; consider moving this
# to the program entry point.
logging.basicConfig(level=logging.INFO)


def will_retry_generate_failure(e: BaseException) -> bool:
    """Decide whether a failed API call should be retried.

    Returns True only when ``e`` carries an HTTP ``status_code`` of 429, 500
    or 503. Exceptions without a usable status code (missing, ``None``, or
    not convertible to ``int``) are treated as non-retryable instead of
    raising from inside the retry predicate.
    """
    # Error codes that are likely to be resolved by retrying
    # 429 - Rate limit reached for requests
    # 429 - You exceeded your current quota, please check your plan and billing details
    # 500 - The server had an error while processing your request
    # 503 - The engine is currently overloaded, please try again later
    # OpenAI API Reference: https://platform.openai.com/docs/guides/error-codes
    # DeepSeek API Reference: https://api-docs.deepseek.com/quick_start/error_codes
    # NOTE(review): a quota-exhaustion 429 will not clear by waiting; when
    # OPENAI_MAX_RETRY is unset the caller retries forever (stop_never).
    raw_status = getattr(e, "status_code", None)
    if raw_status is None:
        return False
    try:
        status_code = int(raw_status)
    except (TypeError, ValueError):
        return False
    if status_code not in (429, 500, 503):
        return False
    # Same logger object as the module-level ``logger`` (same name).
    logging.getLogger(__name__).warning(
        "Hit an API error with status code %d, retrying...", status_code
    )
    return True


# Retry policy: exponential backoff capped at 8 seconds, i.e. sleeps of
# 1s, 2s, 4s, 8s, 8s, 8s... between attempts. When OPENAI_MAX_RETRY is set
# (read once, when this module is imported) at most that many attempts are
# made; otherwise retries never stop.
@retry(
    wait=wait_exponential(multiplier=1, min=1, max=8),
    before_sleep=before_sleep_log(logger, logging.INFO),
    stop=(
        stop_never
        if os.environ.get("OPENAI_MAX_RETRY") is None
        else stop_after_attempt(int(os.environ["OPENAI_MAX_RETRY"]))
    ),
    # JSONDecodeError is retried because the DeepSeek API sometimes returns
    # malformed JSON; other errors only when will_retry_generate_failure
    # accepts their status code.
    retry=(
        retry_if_exception_type(json.JSONDecodeError)
        | retry_if_exception(will_retry_generate_failure)
    ),
)
def try_generate(api_model: "OpenAIAPIModel", messages, request_samples):
    """Make one chat-completion call requesting ``request_samples`` samples.

    Models whose name starts with ``"o1"`` are always sent ``temperature=1``;
    all other models use ``api_model.temperature``.
    """
    # NOTE(review): presumably o1 models reject non-default temperatures — confirm.
    sampling_temperature = (
        1 if api_model.model.startswith("o1") else api_model.temperature
    )
    completions_endpoint = api_model.client.chat.completions
    return completions_endpoint.create(
        messages=messages,
        model=api_model.model,
        n=request_samples,
        temperature=sampling_temperature,
    )


class OpenAIAPIModel(Model):
    """``Model`` implementation backed by the OpenAI chat-completions API.

    Model names starting with ``"deepseek"`` are sent to DeepSeek's
    OpenAI-compatible endpoint instead. Each processed request writes a usage
    log (unless the ``DISABLE_LOG_USAGE`` environment variable is set) and
    one response log per sample into ``logging_dir``.
    """

    def __init__(
        self,
        model: str = "gpt-4o-2024-08-06",
        temperature: float = 0.7,
        logging_dir: Optional[Union[str, Path]] = None,
    ) -> None:
        """Create the API client and make sure the logging directory exists.

        Args:
            model: Model name passed to the API. Names starting with
                ``"deepseek"`` use ``https://api.deepseek.com`` as base URL.
            temperature: Sampling temperature (``try_generate`` overrides it
                with 1 for models whose name starts with ``"o1"``).
            logging_dir: Directory for usage/response logs; a ``str`` is
                converted to ``Path``. Defaults to ``get_logging_dir(os.getpid())``.
        """
        self.model = self.model_name = model
        if model.startswith("deepseek"):
            # Prefer a dedicated DeepSeek key, falling back to OPENAI_API_KEY.
            api_key = (
                os.environ["DEEPSEEK_API_KEY"]
                if "DEEPSEEK_API_KEY" in os.environ
                else os.environ["OPENAI_API_KEY"]
            )
            self.client = OpenAI(base_url="https://api.deepseek.com", api_key=api_key)
        else:
            self.client = OpenAI()
        self.temperature = temperature

        self.logging_dir = (
            get_logging_dir(os.getpid()) if logging_dir is None else Path(logging_dir)
        )
        if not self.logging_dir.exists():
            self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Number of requests processed so far; part of every log file name.
        self.call_count = 0

    def _encode_image(self, image: Image.Image) -> str:
        """Serialize ``image`` as JPEG and return it base64-encoded as text."""
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str

    def _process_message(self, request: GenerationRequest) -> List[dict]:
        """Convert ``request.messages`` into API message dicts.

        The request is deep-copied first so the caller's messages are not
        mutated. List contents are turned into typed parts: ``str`` items
        become ``text`` parts, ``PIL`` images become base64 JPEG
        ``image_url`` data URLs, and items of any other type are dropped.
        """
        request_copy = copy.deepcopy(request)
        messages = []
        for message in request_copy.messages:
            # if the message contains image tokens, the content is a list
            # image should be encoded as base64
            if isinstance(message.content, list):
                full_content = []
                for content in message.content:
                    if isinstance(content, str):
                        full_content.append(
                            {
                                "type": "text",
                                "text": content,
                            }
                        )
                    elif isinstance(content, Image.Image):
                        # Convert to RGB before encoding as JPEG.
                        full_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{self._encode_image(content.convert('RGB'))}",
                                },
                            }
                        )
                message.content = full_content
            messages.append(dataclasses.asdict(message))
        return messages

    def _sample_choices(
        self, messages: List[dict], num_samples: int, log_usage: bool
    ) -> tuple[List[Choice], dict]:
        """Request ``num_samples`` completions for ``messages`` in batches.

        Returns the collected choices and the accumulated usage dict (left
        empty when ``log_usage`` is false).
        """
        # DeepSeek is queried one sample per call; other models are batched
        # up to 128 samples per call through the ``n`` parameter.
        max_batch = 1 if self.model_name.startswith("deepseek") else 128
        choices: List[Choice] = []
        usage_data: dict = {}
        remaining_samples = num_samples
        while remaining_samples > 0:
            request_samples = min(remaining_samples, max_batch)
            chat_completion = try_generate(self, messages, request_samples)
            choices.extend(chat_completion.choices)
            remaining_samples -= request_samples
            if not log_usage:
                continue
            if chat_completion.usage is None:
                # Keep the (already paid for) samples; only the accounting is lost.
                logger.warning(
                    "No usage data returned for model %s; usage log will be incomplete.",
                    self.model_name,
                )
                continue
            usage_data = sum_usage_tokens(usage_data, chat_completion.usage.model_dump())
        return choices, usage_data

    def _write_usage_log(self, usage_data: dict) -> None:
        """Add a per-counter price breakdown to ``usage_data`` and write it to disk.

        Models missing from ``PRICING`` are logged without a ``"price"``
        entry instead of raising after the API calls have been made.
        """
        pricing = PRICING.get(self.model_name)
        if pricing is None:
            logger.warning(
                "No pricing information for model %s; logging usage without price.",
                self.model_name,
            )
        else:
            # Missing counters (e.g. no API call was made) count as zero.
            usage_data["price"] = {
                key: int(usage_data.get(key, 0)) * rate for key, rate in pricing.items()
            }
            usage_data["price"]["total"] = sum(usage_data["price"].values())
        (self.logging_dir / f"{self.model_name}_{self.call_count}.txt").write_text(
            json.dumps(usage_data, indent=4, cls=CustomJSONEncoder)
        )

    def generate(
        self, requests: Sequence[GenerationRequest], num_samples: int = 1
    ) -> Iterable[GenerationResult]:
        """Generate ``num_samples`` completions for every request.

        Returns a flat list of ``GenerationResult`` objects, ``num_samples``
        per request, in request order. Each result is also written as JSON to
        ``response_<model>_<sample index>_<call_count>.log``.

        Raises:
            AssertionError: if the API returns a different number of samples
                than requested.
        """
        log_usage = not os.getenv("DISABLE_LOG_USAGE")
        results: List[GenerationResult] = []
        for request in requests:
            messages = self._process_message(request)
            choices, usage_data = self._sample_choices(messages, num_samples, log_usage)
            if log_usage:
                self._write_usage_log(usage_data)

            self.call_count += 1
            # Explicit raise so the check survives ``python -O``.
            if len(choices) != num_samples:
                raise AssertionError(
                    f"OpenAI API returned {len(choices)} samples, expected {num_samples}"
                )
            for i, choice in enumerate(choices):
                result = GenerationResult(
                    request=request,
                    generation=choice.message.content,
                )
                results.append(result)
                # Log this request's result (``results`` spans all requests).
                (
                    self.logging_dir
                    / f"response_{self.model_name}_{i}_{self.call_count}.log"
                ).write_text(
                    json.dumps(
                        dataclasses.asdict(result), indent=4, cls=CustomJSONEncoder
                    )
                )
        return results
