# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]


import abc
import json
import os
import re
import sys
import time
import traceback
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Tuple, Union

import requests
from tenacity import retry, stop_after_attempt, wait_random_exponential


class Client(abc.ABC):
    """Base class for clients that send generation requests to an HTTP server.

    Subclasses implement :meth:`_single_call`, which receives the stored
    generation kwargs plus ``prompts`` as keyword arguments and returns the
    generated outputs.

    Args:
        server_host: Host name or address of the generation server.
        server_port: Port of the generation server (default ``"5000"``).
        ssh_server: Optional SSH host to tunnel through. The ``SSH_SERVER``
            environment variable takes precedence when it is set.
        ssh_key_path: Optional SSH key path used for the tunnel. The
            ``SSH_KEY_PATH`` environment variable takes precedence when set.
        **generation_kwargs: Default keyword arguments forwarded to
            :meth:`_single_call` on every call.
    """

    def __init__(
        self,
        server_host,
        server_port="5000",
        ssh_server=None,
        ssh_key_path=None,
        **generation_kwargs,
    ):
        self.server_host = server_host
        self.server_port = server_port
        # Environment variables win; the constructor arguments are fallbacks.
        self.ssh_server = os.getenv("SSH_SERVER", ssh_server)
        self.ssh_key_path = os.getenv("SSH_KEY_PATH", ssh_key_path)
        self.generation_kwargs = generation_kwargs

    @abc.abstractmethod
    def _single_call(
        self,
        prompts,
    ):
        """Send one request for ``prompts`` and return the generated outputs."""

    def __call__(self, prompt: str, **kwargs):
        """Generate a completion for a single prompt.

        Args:
            prompt: The prompt; it is stringified and wrapped in a
                one-element list before being forwarded.
            **kwargs: Only ``others`` is honoured; when present it is
                forwarded to :meth:`_single_call` under the same name.

        Returns:
            ``{"text": outputs}`` where ``outputs`` is whatever
            :meth:`_single_call` returned.
        """
        # Work on a copy: writing per-call keys into the shared defaults would
        # leak them into later calls and lets concurrent callers overwrite
        # each other's prompt.
        request = dict(self.generation_kwargs)
        request["prompts"] = [f"{prompt}"]
        if "others" in kwargs:
            request["others"] = kwargs["others"]

        outputs = self._single_call(**request)
        return {"text": outputs}

    @retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(3))
    def _send_request(self, request, route="generate"):
        """PUT ``request`` as JSON to ``http://host:port/route``.

        The request goes through an SSH tunnel when both ``ssh_server`` and
        ``ssh_key_path`` are configured, and directly otherwise. Any exception
        triggers a retry (up to three attempts with randomized exponential
        back-off).

        Returns:
            The decoded JSON body of the response.
        """
        url = "http://{}:{}/{}".format(self.server_host, self.server_port, route)
        payload = json.dumps(request)
        headers = {"Content-Type": "application/json"}

        if self.ssh_server and self.ssh_key_path:
            # Imported lazily so the dependency is only needed for tunnelling.
            import sshtunnel_requests

            sender = sshtunnel_requests.from_url(
                f"ssh://{self.ssh_server}:22", self.ssh_key_path
            )
        else:
            sender = requests

        return sender.put(url=url, data=payload, headers=headers).json()


class TRTLLMClient(Client):
    """Client that forwards a list of prompts to the server's default route."""

    def _single_call(
        self,
        prompts,
        tokens_to_generate,
        temperature,
        top_p,
        top_k,
        random_seed,
        stop: List[str],
        max_attention_window_size=None,
    ):
        """Build the request payload and return the server's JSON reply.

        ``stop`` is sent as a single comma-joined string under
        ``stop_words_list``. ``max_attention_window_size`` is only included
        in the payload when it is truthy.
        """
        payload = dict(
            prompts=prompts,
            tokens_to_generate=tokens_to_generate,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            random_seed=random_seed,
            stop_words_list=",".join(stop),
        )
        if max_attention_window_size:
            payload["max_attention_window_size"] = max_attention_window_size

        return self._send_request(payload)


class VLLMClient(Client):
    """Client that sends the first prompt to the server and returns its text."""

    def _single_call(
        self,
        prompts,
        tokens_to_generate,
        temperature,
        top_p,
        top_k,
        random_seed,
        stop: List[str],
    ):
        """Send ``prompts[0]`` with the sampling settings and return ``"text"``.

        Only the first element of ``prompts`` is used, and ``random_seed`` is
        accepted but not forwarded to the server.
        """
        # TODO: random seed is not supported?
        payload = dict(
            prompt=prompts[0],
            max_tokens=tokens_to_generate,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop=stop,
        )
        reply = self._send_request(payload)
        return reply["text"]


class OpenAIClient:
    """Chat-completions client backed by OpenAI or Azure OpenAI.

    Credentials are read from the environment. ``OPENAI_API_KEY`` selects the
    OpenAI backend; otherwise ``AZURE_API_ID`` and ``AZURE_API_SECRET``
    (together with ``AZURE_API_ENDPOINT``) select the Azure backend. Missing
    variables are treated as empty.

    Args:
        model_name: Requested model. On the Azure backend, names containing
            ``"gpt-3.5"`` / ``"gpt-4"`` are mapped to the deployments
            ``"gpt-35-turbo-16k"`` / ``"gpt-4"``.
        **generation_kwargs: Default request settings. ``__call__`` and
            ``_send_request`` read ``tokens_to_generate``, ``temperature``,
            ``random_seed``, ``top_p`` and ``stop`` from them.

    Raises:
        KeyError: If the resolved model name has no known context length.
        ValueError: If no usable credentials are configured.
    """

    def __init__(self, model_name, **generation_kwargs):
        # Context-window sizes (in tokens) per model name.
        model2length = {
            # OpenAI
            "gpt-4": 8192,
            "gpt-4-0613": 8192,
            "gpt-4-1106-preview": 128000,
            "gpt-4-0125-preview": 128000,
            "gpt-4-turbo-preview": 128000,
            "gpt-3.5-turbo-0125": 16385,
            "gpt-3.5-turbo-1106": 16385,
            "gpt-3.5-turbo-0613": 4096,
            "gpt-3.5-turbo": 16385,
            "gpt-3.5-turbo-16k": 16385,
            "gpt-3.5-turbo-16k-0613": 16385,
            # Azure
            "gpt-4-32k": 32768,
            "gpt-35-turbo-16k": 16384,
        }
        # Use .get so that only the credentials of the chosen backend need to
        # be present in the environment.
        self.openai_api_key = os.environ.get("OPENAI_API_KEY", "")
        self.azure_api_id = os.environ.get("AZURE_API_ID", "")
        self.azure_api_secret = os.environ.get("AZURE_API_SECRET", "")
        self.azure_api_endpoint = os.environ.get("AZURE_API_ENDPOINT", "")
        self.model_name = model_name

        # Azure: map to deployment names only when the Azure client is the one
        # that will actually be created.
        if self._uses_azure():
            if "gpt-3.5" in model_name:
                self.model_name = "gpt-35-turbo-16k"
            if "gpt-4" in model_name:
                self.model_name = "gpt-4"
            # The Azure "gpt-4" deployment has a larger window than OpenAI's
            # "gpt-4"; keeping both under one dict key silently dropped the
            # OpenAI value.
            model2length["gpt-4"] = 128000

        import tiktoken

        self.encoding = tiktoken.get_encoding("cl100k_base")
        self.max_length = model2length[self.model_name]
        self.generation_kwargs = generation_kwargs
        self._create_client()

    def _uses_azure(self):
        """Return True when the Azure backend is the one that will be used."""
        return bool(
            not self.openai_api_key and self.azure_api_id and self.azure_api_secret
        )

    def _create_client(
        self,
    ):
        """Create ``self.client`` for the configured backend.

        Raises:
            ValueError: If neither OpenAI nor Azure credentials are available,
                or Azure is selected without an endpoint.
        """
        from openai import AzureOpenAI, OpenAI

        # OpenAI
        if self.openai_api_key:
            self.client = OpenAI(api_key=self.openai_api_key)

        # Azure
        elif self._uses_azure():
            if not self.azure_api_endpoint:
                raise ValueError("AZURE_API_ENDPOINT must be set to use Azure OpenAI")
            # Join URL parts with "/" explicitly; os.path.join would use the
            # platform path separator.
            endpoint = self.azure_api_endpoint.rstrip("/")
            self.client = AzureOpenAI(
                api_key=self.get_azure_api_key(
                    self.azure_api_id,
                    self.azure_api_secret,
                    self.azure_api_endpoint,
                ),
                api_version="2024-02-15-preview",
                azure_endpoint=f"{endpoint}/llm/v1/azure",
            )

        else:
            raise ValueError(
                "No credentials found: set OPENAI_API_KEY, or AZURE_API_ID and "
                "AZURE_API_SECRET"
            )

    def _count_tokens(self, messages):
        """Estimate the prompt token count of a list of chat messages.

        Each message costs a fixed 3 tokens plus the encoded length of every
        value in it (and 1 extra token for a ``name`` field); 3 more tokens
        are added for the reply priming.
        """
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(self.encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    @retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(3))
    def _send_request(self, request):
        """Call the chat-completions endpoint and return the raw response.

        On failure the error is printed and re-raised so that the retry
        decorator can try again (up to three attempts). On the Azure backend
        a 401 first rebuilds the client, which refreshes the access token if
        the cached one has expired.
        """
        try:
            return self.client.chat.completions.create(
                model=self.model_name,
                messages=request["msgs"],
                max_tokens=request["tokens_to_generate"],
                temperature=request["temperature"],
                seed=request["random_seed"],
                top_p=request["top_p"],
                stop=request["stop"],
            )
        except Exception as e:
            print(f"Error occurred while calling OpenAI: {e}")
            # Not every exception carries a status code (e.g. network errors).
            if self._uses_azure() and getattr(e, "status_code", None) == 401:
                # token expired
                self._create_client()
            # Re-raise explicitly; falling through would hit an unbound
            # ``response`` variable and hide the real error.
            raise

    def __call__(
        self,
        prompt: str,
    ):
        """Generate a completion for ``prompt`` sent as a single user message.

        ``tokens_to_generate`` is reduced for this call only when the prompt
        leaves less room in the context window than requested.

        Returns:
            ``{"text": [content]}`` with the first choice's message content.
        """
        # system_msg = [{"role": "system", "content": ""}]
        system_msg = []
        user_assistant_msgs = [{"role": "user", "content": prompt}]
        msgs = system_msg + user_assistant_msgs
        openai_length = self._count_tokens(msgs)
        # Copy the defaults: a budget reduced for one long prompt must not
        # shrink the budget of every later call.
        request = dict(self.generation_kwargs)

        tokens_to_generate_new = self.max_length - openai_length
        if tokens_to_generate_new < request["tokens_to_generate"]:
            print(
                f"Reduce generate tokens from {request['tokens_to_generate']} to {tokens_to_generate_new}"
            )
            request["tokens_to_generate"] = tokens_to_generate_new

        request["msgs"] = msgs
        outputs = self._send_request(request)
        response = {"text": [outputs.choices[0].message.content]}
        return response

    def get_azure_api_key(
        self,
        p_client_id,
        p_client_secret,
        p_token_url,
        p_scope="azureopenai-readwrite",
        cache_file="azure_openai_key.json",
    ):
        """Return an access token, using a JSON cache file next to this module.

        Args:
            p_client_id: OAuth client id.
            p_client_secret: OAuth client secret.
            p_token_url: Base URL of the OAuth server.
            p_scope: Scope requested for the token.
            cache_file: Name of the cache file, stored in this module's
                directory.

        Returns:
            The ``access_token`` string.

        Raises:
            requests.HTTPError: If the token endpoint returns an error status.
        """
        file_path = Path(__file__).parent / cache_file

        # Check if the token is cached. In the cache file ``expires_in`` holds
        # an absolute expiry timestamp (see where it is written below).
        token = None
        if os.path.exists(file_path):
            try:
                with open(file_path, "r") as f:
                    cached = json.load(f)
                if time.time() <= cached["expires_in"] and "access_token" in cached:
                    token = cached
            except (json.JSONDecodeError, KeyError, TypeError):
                # Unreadable or malformed cache: fall through and renew.
                token = None

        if token is None:
            # Get a new token from the OAuth server
            response = requests.post(
                f"{p_token_url.rstrip('/')}/oauth/api/v1/ssa/default/token",
                data={
                    "grant_type": "client_credentials",
                    "client_id": p_client_id,
                    "client_secret": p_client_secret,
                    "scope": p_scope,
                },
            )
            response.raise_for_status()
            token = response.json()
            # Convert the relative lifetime into an absolute expiry time.
            token["expires_in"] += time.time()
            with open(file_path, "w") as f:
                json.dump(token, f)

        return token["access_token"]


class GeminiClient:
    """Client for Google Gemini models via ``google.generativeai``.

    Args:
        model_name: One of the model names in the length table below.
        **generation_kwargs: Must provide ``tokens_to_generate``, ``stop``,
            ``temperature``, ``top_p`` and ``top_k``; they are baked into a
            ``GenerationConfig`` used for every call.

    Raises:
        KeyError: If ``GEMINI_API_KEY`` is unset or ``model_name`` is unknown.
        AssertionError: If ``tokens_to_generate`` is not below the model's
            output limit.
    """

    def __init__(self, model_name, **generation_kwargs):
        # model name -> (max input tokens, max output tokens)
        model2length = {
            "gemini-1.0-pro-latest": (30720, 2048),
            "gemini-1.5-pro-latest": (1048576, 8192),
        }

        self.model_name = model_name
        self.model = self._initialize_model()
        self.max_input_length = model2length[model_name][0]
        self.max_output_length = model2length[model_name][1]
        # Raised explicitly (same exception type as the former ``assert``) so
        # the check carries a real message and is not stripped under ``-O``.
        if generation_kwargs["tokens_to_generate"] >= self.max_output_length:
            raise AssertionError(
                f"tokens_to_generate exceeds {self.max_output_length}"
            )

        import google.generativeai as genai

        self.config = genai.GenerationConfig(
            candidate_count=1,
            stop_sequences=generation_kwargs["stop"],
            max_output_tokens=generation_kwargs["tokens_to_generate"],
            temperature=generation_kwargs["temperature"],
            top_p=generation_kwargs["top_p"],
            top_k=generation_kwargs["top_k"],
        )

        from google.generativeai.types import HarmBlockThreshold, HarmCategory

        # Every listed harm category is set to BLOCK_NONE.
        self.safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

    @retry(wait=wait_random_exponential(min=60, max=60), stop=stop_after_attempt(3))
    def _send_request(self, request):
        """Call ``generate_content`` and return the raw response.

        The traceback is printed and the exception re-raised so that the
        retry decorator actually retries (up to three attempts). Returning
        ``None`` from the handler would count as success and disable it.
        """
        try:
            return self.model.generate_content(
                request["prompt"],
                generation_config=request["config"],
                safety_settings=self.safety_settings,
            )
        except Exception:
            traceback.print_exc()
            raise

    def __call__(
        self,
        prompt: str,
    ):
        """Generate a completion for ``prompt``.

        Returns:
            ``{"text": [text]}`` with the first part of the first candidate,
            or ``{"text": []}`` when the request failed or the response held
            no usable text.

        Raises:
            AssertionError: If the prompt's token count is not below the
                model's input limit.
        """
        if self.model.count_tokens(prompt).total_tokens >= self.max_input_length:
            raise AssertionError(f"input length exceeds {self.max_input_length}")

        request = {
            "prompt": prompt,
            "config": self.config,
        }

        try:
            outputs = self._send_request(request)
        except Exception:
            # All attempts failed (tracebacks were already printed); stay
            # best-effort and report an empty result instead of raising.
            return {"text": []}

        try:
            response = {"text": [outputs.candidates[0].content.parts[0].text]}
        except Exception:
            # The response had no candidate, part or text to extract.
            response = {"text": []}
            print(outputs)
            traceback.print_exc()

        return response

    def _initialize_model(self):
        """Configure the SDK from ``GEMINI_API_KEY`` and build the model."""
        import google.generativeai as genai

        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        return genai.GenerativeModel(self.model_name)
