"""AI21 client."""
import copy
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import requests

from manifest.clients.client import Client

logger = logging.getLogger(__name__)

# Engine names accepted by AI21Client.connect (anything else raises ValueError).
AI21_ENGINES = {"j1-jumbo", "j1-grande", "j1-large"}

# Table of supported request parameters, keyed by the user-facing name.
# Each value is a pair: (field name sent to the AI21 API, default value).
AI21_PARAMS = dict(
    engine=("engine", "j1-large"),
    temperature=("temperature", 1.0),
    max_tokens=("maxTokens", 10),
    top_k_return=("topKReturn", 0),
    n=("numResults", 1),
    top_p=("topP", 1.0),
    stop_sequences=("stopSequences", []),
)


class AI21Client(Client):
    """Client for the AI21 Studio completion endpoint."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect to the AI21 server.

        The ``AI21_API_KEY`` environment variable takes precedence;
        ``connection_str`` is used as the API key only when that variable is
        not set.

        Args:
            connection_str: fallback API key.
            client_args: client arguments. Keys listed in ``AI21_PARAMS`` are
                popped from this dict (the caller's dict is mutated) and
                stored on the client as default request parameters; any other
                keys are left in place.

        Raises:
            ValueError: if no API key is available, or if the engine is not
                one of ``AI21_ENGINES``.
        """
        if client_args is None:
            client_args = {}
        # Taken from https://studio.ai21.com/docs/api/
        self.host = "https://api.ai21.com/studio/v1"
        self.api_key = os.environ.get("AI21_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "AI21 API key not set. Set AI21_API_KEY environment "
                "variable or pass through `connection_str`."
            )

        for key, (_, default) in AI21_PARAMS.items():
            # Copy the default so mutable defaults (the ``stop_sequences``
            # list) are not one object shared by every client and the
            # module-level table; mutating one client's value must not leak.
            setattr(self, key, client_args.pop(key, copy.copy(default)))
        if getattr(self, "engine") not in AI21_ENGINES:
            raise ValueError(
                f"Invalid engine {getattr(self, 'engine')}. Must be {AI21_ENGINES}."
            )

    def close(self) -> None:
        """
        Close the client.

        No-op: requests are sent with one-off ``requests.post`` calls, so the
        client holds no session or connection to release.
        """

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params: the fixed model name ``"ai21"`` and the client's
            default engine.
        """
        return {"model_name": "ai21", "engine": getattr(self, "engine")}

    def get_model_inputs(self) -> List:
        """
        Get allowable model inputs.

        Returns:
            the user-facing parameter names accepted in ``client_args`` and
            ``request_args`` (the keys of ``AI21_PARAMS``).
        """
        return list(AI21_PARAMS.keys())

    def format_response(
        self, response: Dict, engine: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Format response to dict.

        Reshapes an AI21 completion payload into a ``text_completion`` dict
        with one ``choices`` entry per item of ``response["completions"]``,
        each holding the generated text and per-token log probabilities with
        their character offsets.

        Args:
            response: parsed JSON body from the AI21 ``complete`` endpoint.
            engine: engine that produced the response, reported as
                ``model``. Defaults to the client's engine from ``connect``.

        Returns:
            response as dict.
        """
        if engine is None:
            engine = getattr(self, "engine")
        return {
            "object": "text_completion",
            "model": engine,
            "choices": [
                {
                    "text": item["data"]["text"],
                    "logprobs": [
                        {
                            "token": tok["generatedToken"]["token"],
                            "logprob": tok["generatedToken"]["logprob"],
                            "start": tok["textRange"]["start"],
                            "end": tok["textRange"]["end"],
                        }
                        for tok in item["data"]["tokens"]
                    ],
                }
                for item in response["completions"]
            ],
        }

    def get_request(
        self, query: str, request_args: Optional[Dict[str, Any]] = None
    ) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.
            request_args: per-request overrides keyed by the user-facing names
                in ``AI21_PARAMS``. Recognized keys are popped from this dict
                (the caller's dict is mutated); unspecified parameters fall
                back to the values set on the client in ``connect``.

        Returns:
            request function that takes no input.
            request parameters as dict (API field names, including the
            resolved ``engine``).

        Raises:
            ValueError: if the resolved engine is not one of ``AI21_ENGINES``.
        """
        if request_args is None:
            request_args = {}
        request_params: Dict[str, Any] = {"prompt": query}
        for key, (api_name, _) in AI21_PARAMS.items():
            request_params[api_name] = request_args.pop(key, getattr(self, key))

        # The engine may be overridden per request. Use the resolved value for
        # the endpoint URL and the reported model so the request actually sent
        # matches the parameters returned to the caller.
        engine = request_params["engine"]
        if engine not in AI21_ENGINES:
            raise ValueError(f"Invalid engine {engine}. Must be {AI21_ENGINES}.")

        def _run_completion() -> Dict:
            post_str = self.host + "/" + engine + "/complete"
            res = requests.post(
                post_str,
                headers={"Authorization": f"Bearer {self.api_key}"},
                json=request_params,
            )
            # Surface HTTP errors (bad key, bad params, rate limits) directly
            # rather than failing inside format_response on an error payload.
            res.raise_for_status()
            return self.format_response(res.json(), engine)

        return _run_completion, request_params

    def get_choice_logit_request(
        self,
        query: str,
        gold_choices: List[str],
        request_args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function for choosing max choices.

        Args:
            query: query string.
            gold_choices: choices for model to choose from via max logits.
            request_args: unused; accepted for interface compatibility.

        Returns:
            request function that takes no input.
            request parameters as dict.

        Raises:
            NotImplementedError: always; this client does not implement
                choice-logit requests.
        """
        raise NotImplementedError("AI21 does not support choice logit request.")
