import logging
from pathlib import Path
import os
import hashlib
import json
import secrets
from typing import Optional

import openai
from openai.types.completion_usage import PromptTokensDetails, CompletionTokensDetails
from tenacity import (
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    stop_after_delay,
    wait_exponential,
    wait_fixed,
    wait_chain,
    before_sleep_log,
)

from .model_utils.tokens import num_tokens_from_messages, get_token_len, MODEL_TOKEN_LIMITS
from azure.identity import DefaultAzureCredential, AzureCliCredential


OpenAITextResponses = list[str | None]
OpenAIMessages = list[dict[str, str]]


class OpenAICache:
    """On-disk cache of chat-completion responses, keyed by request contents.

    Each entry is stored as two JSON files inside ``cache_dir``:
    ``<key>_responses.json`` (the list of response texts, read back by ``get``)
    and ``<key>_messages.json`` (the request messages, written for inspection
    only). The directory is created on first write.
    """

    def __init__(
        self,
        cache_dir: Path,
    ):
        self._cache_dir = cache_dir

    def _get_from_cache(self, key: str) -> OpenAITextResponses | None:
        """Load the cached responses stored under ``key``.

        Returns:
            The decoded JSON list, or None on a miss. A file that is not valid
            JSON is treated as a miss and deleted.
        """
        cache_file = self._cache_dir / f"{key}_responses.json"
        # EAFP instead of exists()+read: the file may vanish in between
        # (e.g. another process deleting a corrupt entry).
        try:
            cache_file_contents = cache_file.read_text()
        except (FileNotFoundError, NotADirectoryError):
            return None

        try:
            return json.loads(cache_file_contents)
        except json.JSONDecodeError:
            logging.warning(f"Cache file at {cache_file} corrupt. Deleting")
            # Another reader may already have removed it; that is fine.
            cache_file.unlink(missing_ok=True)
            return None

    def _generate_key_from_messages(
        self,
        messages: OpenAIMessages,
        model: str,
        temperature: float,
        num_samples: int,
    ) -> str:
        """Return the MD5 hex digest identifying a request.

        The digest covers a model signature followed by every ``key:value``
        pair of every message, in order. When temperature is 0.0 and one
        sample is requested, the signature is just the model name; otherwise
        it is "model temperature num_samples".

        NOTE(review): the pieces are hashed back-to-back with no separator, so
        two different message lists whose concatenated ``key:value`` text is
        identical (a message boundary falling inside content) share a key.
        Changing the scheme would invalidate every existing cache entry, so it
        is intentionally left unchanged here.
        """
        # MD5 is used only as a content fingerprint, not for security.
        hash_obj = hashlib.md5()
        if temperature == 0.0 and num_samples == 1:
            model_signature = model
        else:
            model_signature = " ".join([
                model,
                str(temperature),
                str(num_samples),
            ])
        hash_obj.update(model_signature.encode())
        for m in messages:
            for k, v in m.items():
                hash_obj.update(f"{k}:{v}".encode())

        return hash_obj.hexdigest()

    def get(
        self,
        messages: OpenAIMessages,
        model: str,
        temperature: float,
        num_samples: int,

    ) -> OpenAITextResponses | None:
        """Return cached responses for this exact request, or None on a miss."""
        key = self._generate_key_from_messages(
            messages=messages,
            model=model,
            temperature=temperature,
            num_samples=num_samples,
        )
        return self._get_from_cache(key)

    def update(
        self,
        messages: OpenAIMessages,
        model: str,
        temperature: float,
        num_samples: int,
        responses: OpenAITextResponses,
    ):
        """Store ``responses`` (and the request ``messages``) for this request.

        Creates the cache directory if needed, so a missing directory does not
        make the write fail after a successful API call.
        """
        key = self._generate_key_from_messages(
            messages=messages,
            model=model,
            temperature=temperature,
            num_samples=num_samples,
        )
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        responses_str_to_cache = json.dumps(responses)
        (self._cache_dir / f"{key}_responses.json").write_text(responses_str_to_cache)
        # Messages are written only so cache entries can be inspected by hand.
        messages_str_to_cache = json.dumps(messages)
        (self._cache_dir / f"{key}_messages.json").write_text(messages_str_to_cache)



class OpenAIHandler:
    """
    OpenAIHandler manages interactions with OpenAI / Azure OpenAI chat models.
    It keeps a message memory, trims what it sends to the model's context
    window, optionally caches responses on disk, and retries failed requests
    (rate limits and other HTTP-status errors).

    Arguments:
        model (str): Identifier for the GPT model being used. Must be a key of
            MODEL_TOKEN_LIMITS.
        temperature (float): Sampling temperature to control the randomness of the output.
            Not sent for models whose name contains "o1".
        api_key (str | list[str] | None): API key, or a pool of keys, for the OpenAI or
            Azure service. Required unless ``use_api_token`` is True. With a pool, a key
            is chosen at random per request, and keys rejected with an authentication
            error are removed from the handler's own copy of the pool.
        system_message (str): Initial system message to start the conversation.
        conversational (bool): If True, the (trimmed) memory is sent as the conversation;
            otherwise only the system message and the latest user message are sent.
        stop (str | list[str] | None): Stop sequence(s) forwarded to the API.
        retry_interval (int): Base delay interval between retries.
        max_retry_attempts (int): Maximum number of retry attempts.
        retry_timeout (int): Maximum time to keep retrying before giving up.
        retry_interval_sequence (list[int] | None): Custom sequence of intervals between
            retries; when given it replaces ``retry_interval``/``exponential_backoff``.
        max_retry_interval (int): Maximum delay interval between retries.
        exponential_backoff (bool): Whether to use exponential backoff for retries.
        api_type (str): Type of API to use ('open_ai' or 'azure').
        api_base (str): Base URL for the API service. Used as the Azure endpoint;
            not forwarded to the client for 'open_ai'.
        api_version (str | None): API version to use (Azure).
        api_token (str | None): Azure AD bearer token used when ``use_api_token`` is True.
            If omitted, one is obtained from DefaultAzureCredential.
        cache_dir (Path | None): Directory of the on-disk response cache; None disables caching.
        use_api_token (bool): Authenticate with an Azure AD token instead of an API key.
        pin_first_n (int): Number of messages at the start of memory that are always
            kept (pinned) in conversational mode.
    """

    _SUPPORTED_MODELS_MAX_TOKENS = MODEL_TOKEN_LIMITS

    _SUPPORTED_API_TYPES = {
        "open_ai",
        "azure"
    }

    # Scope requested for Azure AD tokens used against Azure OpenAI.
    _AZURE_SCOPE = 'https://cognitiveservices.azure.com/.default'

    def __init__(
        self,
        model: str,
        temperature: float,
        api_key: str | list[str] | None,
        system_message: str = "",
        conversational: bool = False,
        stop: str | list[str] | None = None,
        retry_interval: int = 8,
        max_retry_attempts: int = 50,
        retry_timeout: int = 2048,
        retry_interval_sequence: list[int] | None = None,
        max_retry_interval: int = 128,
        exponential_backoff: bool = True,
        api_type: str = "open_ai",
        api_base: str = "https://api.openai.com/v1",
        api_version: str | None = None,
        api_token: str | None = None,
        cache_dir: Path | None = None,
        use_api_token: bool = False,
        pin_first_n: int = 0,
    ):
        if model not in self._SUPPORTED_MODELS_MAX_TOKENS:
            raise ValueError(f"Given model: {model} is not supported")

        if api_type not in self._SUPPORTED_API_TYPES:
            raise ValueError(f"Given API type: {api_type} is not supported")

        if isinstance(api_key, list) and len(api_key) == 0:
            raise ValueError("api_key parameter should not be an empty list.")

        self._model = model
        self._temperature = temperature
        self._stop = stop

        self._max_retry_attempts = max_retry_attempts
        self._retry_timeout = retry_timeout
        self._retry_interval_sequence = retry_interval_sequence
        self._retry_interval = retry_interval
        self._max_retry_interval = max_retry_interval

        self._api_type = api_type
        self._api_base = api_base
        self._api_version = api_version

        self._conversational = conversational
        self._pin_first_n = pin_first_n
        self.use_api_token = use_api_token

        if self.use_api_token:
            # Bearer token sent on every request; refreshed after an AuthenticationError.
            self._token = api_token if api_token else DefaultAzureCredential().get_token(self._AZURE_SCOPE).token
            self._api_key = None
        else:
            if api_key is None:
                raise ValueError("api_key parameter should not be None.")
            # Copy a key pool: invalid keys are removed from it later, and that
            # must not mutate the list the caller passed in.
            self._api_key = list(api_key) if isinstance(api_key, list) else api_key

        self._cache = OpenAICache(cache_dir) if cache_dir is not None else None

        if self._retry_interval_sequence is None:
            if exponential_backoff:
                self._wait_config = wait_exponential(
                    multiplier=2,
                    min=self._retry_interval,
                    max=self._max_retry_interval,
                )
            else:
                self._wait_config = wait_fixed(self._retry_interval)
        else:
            self._wait_config = wait_chain(
                *[wait_fixed(interval) for interval in self._retry_interval_sequence]
            )

        # "o1" models receive the initial instruction under the "developer" role.
        if "o1" in model:
            self._system_message = {
                "role": "developer",
                "content": system_message,
            }
        else:
            self._system_message = {
                "role": "system",
                "content": system_message,
            }
        self._memory: list[dict[str, str]] = []

        self._num_prompt_tokens = 0
        self._num_completion_tokens = 0

    def _prepare_messages(self) -> OpenAIMessages:
        """Build the message list to send, respecting the model's context limit.

        Conversational mode: the system message plus the first ``pin_first_n``
        memory entries are always kept; then as many of the most recent
        remaining entries as fit are added, in chronological order.
        Non-conversational mode: only the system message and the latest memory
        entry are sent.

        Raises:
            ValueError: if a pinned message, or the newest unpinned message,
                does not fit in the context window.
        """
        max_tokens = self._SUPPORTED_MODELS_MAX_TOKENS[self._model]

        if not self._conversational:
            messages = [self._system_message, self._memory[-1]]
            tc = num_tokens_from_messages(messages, self._model)
            if tc > max_tokens:
                raise ValueError(
                    f"Could not fit message in context window with tokens: {tc} > {max_tokens}"
                )
            return messages

        pinned_mem = [self._system_message]
        for i, m in enumerate(self._memory[:self._pin_first_n]):
            tc = num_tokens_from_messages(pinned_mem + [m], self._model)
            if tc > max_tokens:
                raise ValueError(
                    f"Could not fit pinned message number {i + 1} in context window with tokens: {tc} > {max_tokens}"
                )
            pinned_mem.append(m)

        # Walk unpinned memory newest-first; `recent` is kept newest-first and
        # flipped to chronological order whenever it is counted or returned.
        recent = []
        for m in reversed(self._memory[self._pin_first_n:]):
            tc = num_tokens_from_messages(pinned_mem + [m] + recent[::-1], self._model)
            if tc > max_tokens:
                if not recent:
                    # Even the newest message alone does not fit.
                    raise ValueError(
                        f"Could not fit message in context window with tokens: {tc} > {max_tokens}"
                    )
                break
            recent.append(m)

        return pinned_mem + recent[::-1]

    def _get_from_cache(self, messages: OpenAIMessages, n: int) -> OpenAITextResponses | None:
        """Return cached responses for ``messages`` with the current model/temperature, if any."""
        if self._cache:
            return self._cache.get(
                messages=messages,
                model=self._model,
                temperature=self._temperature,
                num_samples=n,
            )
        return None

    def _update_cache(self, 
            messages: OpenAIMessages,
            model: str,
            temperature: float,
            num_samples: int,
            responses: OpenAITextResponses
        ):
        """Store ``responses`` in the cache when caching is enabled."""
        if self._cache:
            self._cache.update(
                messages=messages,
                model=model,
                temperature=temperature,
                num_samples=num_samples,
                responses=responses,
            )

    @staticmethod
    def _mask_secret(secret: str | None) -> str:
        """Return a log-safe rendering of a credential (at most its last 4 chars)."""
        if not secret or len(secret) <= 8:
            return "****"
        return f"...{secret[-4:]}"

    def _select_key(self) -> str:
        """Return the credential for the next client.

        A random key from the pool, the single configured key, the stored Azure
        AD token in token mode, or as a last resort a fresh Azure CLI token.
        """
        if isinstance(self._api_key, list):
            return secrets.choice(self._api_key)
        if self._api_key:
            return self._api_key
        if self.use_api_token:
            return self._token
        return AzureCliCredential().get_token(self._AZURE_SCOPE).token

    def _create_client(self) -> openai.OpenAI | openai.AzureOpenAI :
        """Instantiate a fresh API client with a freshly selected credential."""
        selected_key = self._select_key()

        if self._api_type == "open_ai":
            return openai.OpenAI(api_key=selected_key)
        if self._api_type == "azure":
            if self.use_api_token:
                # Azure AD tokens belong in the "Authorization: Bearer" header
                # (azure_ad_token), not the "api-key" header used for keys.
                return openai.AzureOpenAI(
                    azure_ad_token=selected_key,
                    azure_deployment=self._model.split("_")[0],
                    azure_endpoint=self._api_base,
                    api_version=self._api_version,
                )
            # A key string must go to api_key; azure_ad_token_provider expects a
            # callable and would fail when invoked with a str.
            return openai.AzureOpenAI(
                api_key=selected_key,
                azure_endpoint=self._api_base,
                api_version=self._api_version,
            )
        raise ValueError(f"Unsupported value for API type: {self._api_type}")

    def _create_with_key_rotation(self, input_messages: OpenAIMessages, n: int):
        """Send one completion request, dropping pooled keys that fail authentication.

        Raises:
            RuntimeError: if every key in the pool was rejected.
            openai.AuthenticationError: for a single key/token. In token mode the
                token is refreshed first, so the caller's retry can succeed.
        """
        while True:
            client = self._create_client()
            try:
                if "o1" in self._model:
                    # Reasoning models don't support temperature
                    return client.chat.completions.create(
                        messages=input_messages, # type: ignore
                        model=self._model,
                        n=n,
                        stop=self._stop,
                    )
                return client.chat.completions.create(
                    messages=input_messages, # type: ignore
                    model=self._model,
                    n=n,
                    temperature=self._temperature,
                    stop=self._stop,
                )
            except openai.AuthenticationError:
                if not isinstance(self._api_key, list):
                    logging.warning(f"API key {self._mask_secret(client.api_key)} is invalid.")
                    if self.use_api_token:
                        self._token = DefaultAzureCredential().get_token(self._AZURE_SCOPE).token
                        logging.info("Reset the token!")
                    raise
                self._api_key.remove(client.api_key)
                if not self._api_key:
                    raise RuntimeError("All provided API keys found invalid.")
                logging.warning(f"API key {self._mask_secret(client.api_key)} is invalid. Removing from key set.")

    def _request_completion(self, input_messages: OpenAIMessages, n: int):
        """Call the API under the configured retry policy and return the raw response."""
        api_response = None
        for attempt in Retrying(
            # APIStatusError is the base of every HTTP-status error (including
            # AuthenticationError), so any non-2xx response is retried; this is
            # what lets a refreshed Azure AD token be used on the next attempt.
            # NOTE(review): connection errors/timeouts are not retried here and
            # rely on the SDK's own retries.
            retry=retry_if_exception_type((openai.InternalServerError, openai.RateLimitError, openai.APIStatusError)),
            stop=stop_after_delay(self._retry_timeout)
            | stop_after_attempt(self._max_retry_attempts),
            wait=self._wait_config,
            reraise=True,
            before_sleep=before_sleep_log(
                logging.getLogger(__name__), logging.WARNING
            ),
        ):
            with attempt:
                logging.info(f"Initiating request to {self._model}")
                api_response = self._create_with_key_rotation(input_messages, n)
        return api_response

    def _estimate_token_counts(self, messages: OpenAIMessages, responses: OpenAITextResponses) -> tuple[int, int]:
        """Locally estimate (prompt, completion) token counts; None responses count as 0."""
        num_prompt_tokens = num_tokens_from_messages(messages, model=self._model)
        num_completion_tokens = sum(get_token_len(r) for r in responses if r is not None)
        return num_prompt_tokens, num_completion_tokens

    @property
    def _num_total_tokens(self) -> int:
        return self._num_completion_tokens + self._num_prompt_tokens

    def get_num_tokens(self) -> tuple[int, int, int]:
        """Return accumulated (prompt, completion, total) token counts."""
        return self._num_prompt_tokens, self._num_completion_tokens, self._num_total_tokens

    def set_num_tokens(self, num_prompt_tokens: int, _num_completion_tokens: int):
        """Overwrite the accumulated prompt and completion token counters."""
        self._num_prompt_tokens = num_prompt_tokens
        self._num_completion_tokens = _num_completion_tokens

    def get_responses(self, prompt: str, n: int = 1, return_token_count: bool = False) -> OpenAITextResponses | tuple[OpenAITextResponses, dict[str, int | Optional[PromptTokensDetails] | Optional[CompletionTokensDetails]]]:
        """
        Generates responses from the OpenAI API for the given prompt.

        The prompt is appended to memory as a user message and the cache (if
        configured) is consulted before calling the API. The first response is
        appended to memory as the assistant message. If anything fails before a
        response is obtained, the user message is removed from memory again and
        the exception propagates.

        Arguments:
            prompt (str): User message to send.
            n (int): Number of completions to generate. Must be 1 in conversational mode.
            return_token_count (bool): If True, also return a dict with "prompt_tokens",
                "completion_tokens", "prompt_tokens_details" and "completion_tokens_details".

        Returns:
            A list of text responses (each a string or None), or a
            ``(responses, token_counts)`` tuple when ``return_token_count`` is True.
            On a cache hit (or when the API reports no usage) the counts are
            local estimates and the details are None.

        Raises:
            ValueError: if n != 1 in conversational mode, or the messages do not
                fit the context window.
        """
        if self._conversational and n != 1:
            raise ValueError(
                "Number of completions (n) must be 1 for conversational mode"
            )

        # add user query to memory
        self._memory.append({"role": "user", "content": prompt})
        try:
            input_messages = self._prepare_messages()
            cached_responses = self._get_from_cache(input_messages, n=n)

            if cached_responses is not None:
                responses = cached_responses
                num_prompt_tokens, num_completion_tokens = self._estimate_token_counts(input_messages, responses)
                prompt_tokens_details = None
                completion_tokens_details = None
            else:
                api_response = self._request_completion(input_messages, n)
                if api_response is None:
                    # Defensive: the retry loop produced no response object.
                    responses = [None]
                    num_prompt_tokens = num_completion_tokens = 0
                    prompt_tokens_details = completion_tokens_details = None
                else:
                    responses = [c.message.content for c in api_response.choices]
                    usage = api_response.usage
                    logging.info(f"Usage: {usage}")
                    if usage is None:
                        num_prompt_tokens, num_completion_tokens = self._estimate_token_counts(input_messages, responses)
                        prompt_tokens_details = completion_tokens_details = None
                    else:
                        num_prompt_tokens = usage.prompt_tokens
                        num_completion_tokens = usage.completion_tokens
                        prompt_tokens_details = usage.prompt_tokens_details
                        completion_tokens_details = usage.completion_tokens_details

                    self._update_cache(
                        messages=input_messages,
                        model=self._model,
                        temperature=self._temperature,
                        num_samples=n,
                        responses=responses,
                    )
        except BaseException:
            # Do not leave an unanswered user message in memory.
            self._memory.pop()
            raise

        first_response = responses[0] if responses else None
        self._memory.append({"role": "assistant", "content": first_response if first_response is not None else ""})

        self._num_prompt_tokens += num_prompt_tokens
        self._num_completion_tokens += num_completion_tokens

        if return_token_count:
            return responses, {"prompt_tokens": num_prompt_tokens, "completion_tokens": num_completion_tokens, "prompt_tokens_details": prompt_tokens_details, "completion_tokens_details": completion_tokens_details}
        return responses

    def reset_memory(self):
        """Resets the conversation memory for the handler to empty"""
        self._memory = []

    def get_memory(self) -> list[dict[str, str]]:
        """Return a shallow copy of the memory (the message dicts are shared)."""
        return self._memory.copy()

    def set_memory(self, memory: list[dict[str, str]]):
        """Replace the memory with a shallow copy of ``memory``."""
        self._memory = memory.copy()
