# Copyright 2023 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time as time

import anthropic
import openai
from utils.conversation import get_conv_template
from openai import OpenAI


# API setting constants
# Maximum number of attempts each chat_completion_* function makes per request.
API_MAX_RETRY = 25
# Seconds passed to time.sleep() after a failed attempt before trying again.
API_RETRY_SLEEP = 10
# Sentinel string returned when no attempt produced a response.
API_ERROR_OUTPUT = "$ERROR$"

# also uses ArenaHard code
# noqa https://github.com/lm-sys/arena-hard/blob/51c04e5a6449e920c01d4159f56a051216af6bd9/utils.py#L166
def chat_completion_anthropic(model, message_list, temperature, max_tokens, api_key):
    """Request one completion from the Anthropic Messages API, with retries.

    Args:
        model: Model identifier forwarded to ``messages.create``.
        message_list: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            A leading ``"system"`` message, if present, is removed from the
            list and sent through the separate ``system`` parameter instead.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum number of tokens to generate.
        api_key: Anthropic API key used to build the client.

    Returns:
        The text of the first content block of the response with surrounding
        whitespace stripped, or ``API_ERROR_OUTPUT`` if every one of the
        ``API_MAX_RETRY`` attempts raised ``anthropic.APIError``.
    """
    # Default to the full conversation; previously ``messages`` was only bound
    # when a system message was present, so system-less conversations crashed
    # with UnboundLocalError (which the retry loop does not catch).
    sys_msg = ""
    messages = message_list
    if message_list and message_list[0]["role"] == "system":
        sys_msg = message_list[0]["content"]
        messages = message_list[1:]

    request_kwargs = {
        "model": model,
        "messages": messages,
        "stop_sequences": [anthropic.HUMAN_PROMPT],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    # Only send ``system`` when there is an actual system prompt.
    if sys_msg:
        request_kwargs["system"] = sys_msg

    # The client does not change between attempts, so build it once.
    client = anthropic.Anthropic(api_key=api_key)

    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            response = client.messages.create(**request_kwargs)
            output = response.content[0].text
            break
        except anthropic.APIError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
    return output.strip()


def chat_completion_openai(model, message_list, temperature, max_tokens, api_key):
    """Request one completion from the OpenAI Chat Completions API, with retries.

    Args:
        model: Model identifier forwarded to ``chat.completions.create``.
        message_list: Chat messages as ``{"role": ..., "content": ...}`` dicts,
            passed to the API unchanged.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum number of tokens to generate.
        api_key: OpenAI API key used to build the client.

    Returns:
        The ``content`` of the first choice's message, or ``API_ERROR_OUTPUT``
        if every one of the ``API_MAX_RETRY`` attempts raised an OpenAI API
        error. NOTE(review): the value is returned as-is (not stripped), and
        the SDK may report ``None`` content for some responses -- confirm
        callers handle that.
    """
    client = OpenAI(api_key=api_key)
    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            response = client.chat.completions.create(
                model=model, messages=message_list, n=1, temperature=temperature, max_tokens=max_tokens
            )
            output = response.choices[0].message.content
            break
        # Subclasses must be listed before ``openai.APIError``: it is their
        # common base class, so catching it first made the more specific
        # handlers below unreachable.
        except openai.RateLimitError as e:
            # Handle rate limit error (we recommend using exponential backoff)
            print(f"OpenAI API request exceeded rate limit: {e}")
            time.sleep(API_RETRY_SLEEP)

        except openai.APIConnectionError as e:
            # Handle connection error here
            print(f"Failed to connect to OpenAI API: {e}")
            time.sleep(API_RETRY_SLEEP)

        except openai.APIError as e:
            # Any other API error: log and retry
            print(f"OpenAI API returned an API Error: {e}")
            time.sleep(API_RETRY_SLEEP)

    return output
