import time
from typing import Dict, Union

import openai
import anthropic
import tiktoken
import multiprocessing, os

from .my4o import get_response_config
from .my4o_2 import client as gpt_client

# def num_tokens_from_messages(message, model="gpt-3.5-turbo-0301"):
#     """Returns the number of tokens used by a list of messages."""
#     try:
#         encoding = tiktoken.encoding_for_model(model)
#     except KeyError:
#         encoding = tiktoken.get_encoding("cl100k_base")
#     if isinstance(message, list):
#         # use last message.
#         num_tokens = len(encoding.encode(message[0]["content"]))
#     else:
#         num_tokens = len(encoding.encode(message))
#     return num_tokens

def num_tokens_from_messages(message, model="gpt-3.5-turbo-0301"):
    """Count the tokens in a prompt string or in a list of chat messages.

    Args:
        message: Either a single string, or a list of message dicts that
            each carry a ``"content"`` entry.
        model: Model name used to pick the tiktoken encoding. When tiktoken
            raises ``KeyError`` for the name, ``cl100k_base`` is used.

    Returns:
        For a list, the sum of the token counts of every message's
        ``"content"`` (0 for an empty list); otherwise the token count of
        the string itself. Only the content is counted, not roles or any
        per-message overhead.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    if not isinstance(message, list):
        return len(tokenizer.encode(message))

    # Every message contributes, not just the first or the last one.
    return sum(len(tokenizer.encode(entry["content"])) for entry in message)

def create_chatgpt_config(
    message: Union[str, list],
    max_tokens: int,
    temperature: float = 1,
    batch_size: int = 1,
    system_message: str = "You are a coding assistant of our software.",
    model: str = "gpt-3.5-turbo",
) -> Dict:
    """Build the keyword-argument dict for a chat-completions request.

    Args:
        message: A single user prompt, or a list of message dicts that is
            appended after the system message as-is.
        max_tokens: Stored under ``"max_tokens"``.
        temperature: Stored under ``"temperature"``.
        batch_size: Stored under ``"n"``.
        system_message: Content of the leading ``system`` message.
        model: Stored under ``"model"``.

    Returns:
        A dict with the keys ``model``, ``max_tokens``, ``temperature``,
        ``n`` and ``messages``. The ``messages`` list is always a new list,
        so a list passed as ``message`` is not modified.
    """
    system_turn = {"role": "system", "content": system_message}
    if isinstance(message, list):
        conversation = [system_turn] + message
    else:
        conversation = [system_turn, {"role": "user", "content": message}]

    return {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "n": batch_size,
        "messages": conversation,
    }

def handler(signum, frame):
    """Raise ``Exception("end of time")`` unconditionally.

    The ``(signum, frame)`` signature matches what ``signal.signal`` passes
    to a handler; presumably this is installed as a timeout/alarm handler,
    but the registration is not visible here -- TODO confirm against callers.
    """
    # Neither argument carries information this callback needs.
    del signum, frame
    raise Exception("end of time")

def _build_chatgpt_client(model, base_url=None):
    """Return the ``openai.OpenAI`` client whose endpoint is selected by *model*."""
    if "deepseek" in model:
        return openai.OpenAI(api_key="YOUR_API_KEY", base_url="https://api.provider.com/v1")
    if model == "llama3":
        return openai.OpenAI(api_key="YOUR_API_KEY", base_url="http://127.0.0.1:7333/v1")
    if os.path.exists(model):
        # The model name is a path on disk: talk to a local server unless
        # the caller supplied an explicit endpoint.
        return openai.OpenAI(api_key="YOUR_API_KEY", base_url=base_url or "http://localhost:8000/v1")
    # client = openai.OpenAI(api_key="YOUR_API_KEY", base_url="YOUR_BASE_URL")
    # NOTE(review): this default uses https while the local-path default
    # above uses http for the same host/port -- confirm which is intended.
    return openai.OpenAI(api_key="YOUR_API_KEY", base_url=base_url or "https://localhost:8000/v1")


def _append_chatgpt_api_log(config, ret, logger):
    """Append the query and its response to the ``.apilog`` file next to the logger's log file."""
    # Relies on the first handler exposing ``baseFilename`` (as
    # ``logging.FileHandler`` does); anything else raises here.
    log_file = logger.handlers[0].baseFilename.replace(".log", ".apilog")
    with open(log_file, "a") as f:
        f.write("--Query" + "-" * 100 + "\n")
        if isinstance(config["messages"], list):
            for msg in config["messages"]:
                f.write(f">>>>>>>>>{msg['role']}:\n {msg['content']}\n")
        else:
            f.write(f">>>>>>>>>User: {config['messages']}\n")
        f.write("--Response" + "-" * 100 + "\n")
        f.write(f"{ret}\n")


def _shrink_last_message_to_fit(config, error_text, logger):
    """Cut the middle out of the last message so the request fits the context window.

    The token counts are parsed from the context-length error text. Token
    usage is assumed to be roughly proportional to character count.

    Returns:
        True when ``config["messages"][-1]["content"]`` was shortened in
        place, False when the error could not be parsed or trimming the
        last message alone cannot remove the excess.
    """
    import re  # local import: ``re`` is not imported at module level

    match = re.search(
        r"maximum context length is (\d+) tokens\. However, you requested (\d+) tokens \((\d+) in the messages, (\d+) in the completion\)",
        error_text,
    )
    if not match:
        logger.info("Could not extract token info from error message.")
        return False

    max_context = int(match.group(1))
    total_requested = int(match.group(2))
    messages_tokens = int(match.group(3))
    completion_tokens = int(match.group(4))
    logger.info(f"Extracted from error: max_context={max_context}, messages_tokens={messages_tokens}, completion_tokens={completion_tokens}")

    excess_tokens = total_requested - max_context
    if excess_tokens <= 0 or messages_tokens <= 0:
        return False

    messages = config["messages"]
    last_content = messages[-1]["content"]
    total_chars = sum(len(msg["content"]) for msg in messages)
    if not isinstance(last_content, str) or total_chars == 0:
        return False

    # Only message tokens can be trimmed, so the excess is measured against
    # the message tokens rather than against messages + completion.
    cut_fraction = excess_tokens / messages_tokens
    last_share = len(last_content) / total_chars
    if cut_fraction >= last_share:
        # Even deleting the whole last message would not be enough.
        return False

    # Fraction of the last message that has to go, and hence how much of it
    # may stay; the 0.7 factor keeps a 30% safety margin.
    drop_fraction = cut_fraction / last_share
    keep_each_side = int(len(last_content) * (1 - drop_fraction) / 2 * 0.7)
    if keep_each_side <= 0:
        # A zero length would make ``text[-0:]`` return the entire string.
        return False

    messages[-1]["content"] = (
        last_content[:keep_each_side] + "\n...\n" + last_content[-keep_each_side:]
    )
    logger.info("Reduced last message to fit context length.")
    return True


def request_chatgpt_engine(config, logger, base_url=None, max_retries=10, timeout=100):
    """Send a chat-completions request, retrying on transient API errors.

    Routing by ``config["model"]``:
      * ``"gpt-4o"``: ``get_response_config(config)``; the exchange is also
        appended to an ``.apilog`` file next to the logger's log file.
      * ``"gpt-4o-2"``: ``gpt_client`` with the model renamed to ``"gpt-4o"``.
      * anything else: an ``openai.OpenAI`` client whose endpoint depends on
        the model name and ``base_url``.

    Args:
        config: Request keyword arguments (see ``create_chatgpt_config``).
            If the request overflows the context window, the content of the
            last message in ``config["messages"]`` is shortened in place.
        logger: Logger used for progress and error messages.
        base_url: Optional endpoint override for local-path and default models.
        max_retries: Maximum number of attempts.
        timeout: Currently unused; accepted for interface compatibility.

    Returns:
        The API response, or None if every attempt failed.

    Raises:
        Exception: ``"Invalid API Request:..."`` for a bad request that is
            not a context overflow, that cannot be repaired by truncation,
            or that overflows again after the single truncation retry.
    """
    ret = None
    retries = 0

    client = _build_chatgpt_client(config["model"], base_url)

    # Truncation is attempted at most once; a second overflow is fatal.
    bad_request_retry = False

    while ret is None and retries < max_retries:
        try:
            # Attempt to get the completion
            logger.info("Creating API request")

            if config["model"] == "gpt-4o":
                ret = get_response_config(config)
                _append_chatgpt_api_log(config, ret, logger)
            elif config["model"] == "gpt-4o-2":
                arguments = dict(config)
                arguments["model"] = "gpt-4o"
                ret = gpt_client.chat.completions.create(**arguments)
            else:
                ret = client.chat.completions.create(**config)

        except openai.BadRequestError as e:
            logger.info("Request invalid")
            print(e)
            logger.info(e)
            too_long = "Please reduce the length of the messages or completion." in str(e)
            if too_long and not bad_request_retry:
                print("The request is too long, please reduce the length of the messages or completion.")
                logger.info("The request is too long, please reduce the length of the messages or completion.")
                if _shrink_last_message_to_fit(config, str(e), logger):
                    bad_request_retry = True
                    # The repaired request does not consume a retry.
                    continue
            raise Exception("Invalid API Request:" + str(e)) from e
        except openai.RateLimitError as e:
            print("Rate limit exceeded. Waiting...")
            logger.info("Rate limit exceeded. Waiting...")
            print(e)
            logger.info(e)
            time.sleep(5)
        except openai.APIConnectionError as e:
            print("API connection error. Waiting...")
            logger.info("API connection error. Waiting...")
            print(e)
            logger.info(e)
            time.sleep(5)
        except openai.OpenAIError as e:
            print("Unknown error. Waiting...")
            logger.info("Unknown error. Waiting...")
            print(e)
            logger.info(e)
            time.sleep(1)

        retries += 1

    logger.info(f"API response {ret}")
    return ret

def create_anthropic_config(
    message: str,
    max_tokens: int,
    temperature: float = 1,
    batch_size: int = 1,
    system_message: str = "You are a helpful assistant.",
    model: str = "claude-2.1",
    tools: list = None,
) -> Dict:
    """Build the keyword-argument dict for an Anthropic messages request.

    Args:
        message: A prompt string, which is wrapped in a single ``user``
            message with one text block; or a list, which is used as the
            ``messages`` value as-is (the same list object, not a copy).
        max_tokens: Stored under ``"max_tokens"``.
        temperature: Stored under ``"temperature"``.
        batch_size: Accepted but not used by this function.
        system_message: Accepted but not used by this function.
        model: Stored under ``"model"``.
        tools: Stored under ``"tools"`` only when truthy.

    Returns:
        A dict with the keys ``model``, ``temperature``, ``max_tokens``,
        ``messages`` and, optionally, ``tools``.
    """
    if isinstance(message, list):
        conversation = message
    else:
        text_block = {"type": "text", "text": message}
        conversation = [{"role": "user", "content": [text_block]}]

    config = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "messages": conversation,
    }
    if tools:
        config["tools"] = tools
    return config

def request_anthropic_engine(
    config, logger, max_retries=40, timeout=500, prompt_cache=False
):
    """Send an Anthropic messages request, retrying on any exception.

    Args:
        config: Request keyword arguments (see ``create_anthropic_config``).
            With ``prompt_cache`` enabled, the first content block of the
            first message is mutated to carry a ``cache_control`` entry.
        logger: Logger used for errors, warnings and the final response.
        max_retries: Maximum number of attempts.
        timeout: Not enforced on the request. It only selects which warning
            is logged after a failure, based on how long the attempt took.
        prompt_cache: When true, the request goes through
            ``client.beta.prompt_caching.messages.create``.

    Returns:
        The API response, or None if every attempt failed.
    """
    ret = None
    retries = 0

    client = anthropic.Anthropic(api_key="YOUR_API_KEY", base_url="YOUR_BASE_URL")

    while ret is None and retries < max_retries:
        start_time = time.time()
        try:
            if prompt_cache:
                # following best practice to cache mainly the reused content at the beginning
                # this includes any tools, system messages (which is already handled since we try to cache the first message)
                config["messages"][0]["content"][0]["cache_control"] = {
                    "type": "ephemeral"
                }
                ret = client.beta.prompt_caching.messages.create(**config)
            else:
                ret = client.messages.create(**config)
        except Exception:
            # Every failure is retried, including non-API errors such as a
            # malformed ``config``.
            logger.error("Unknown error. Waiting...", exc_info=True)
            # Check if the timeout has been exceeded
            if time.time() - start_time >= timeout:
                logger.warning("Request timed out. Retrying...")
            else:
                logger.warning("Retrying after an unknown error...")
            # Linear back-off (the first retry is immediate). Skipped after
            # the final attempt, where sleeping would only delay returning.
            if retries + 1 < max_retries:
                time.sleep(10 * retries)
        retries += 1
    logger.info(ret)

    return ret

# def request_anthropic_engine(
#     config, logger, max_retries=40, timeout=500, prompt_cache=False
# ):
#     ret = None
#     retries = 0
#
#     client = openai.OpenAI(api_key="YOUR_API_KEY", base_url="YOUR_BASE_URL")
#
#     while ret is None and retries < max_retries:
#         try:
#             # Attempt to get the completion
#             logger.info("Creating API request")
#
#             ret = client.chat.completions.create(**config)
#
#         except openai.OpenAIError as e:
#             if isinstance(e, openai.BadRequestError):
#                 logger.info("Request invalid")
#                 print(e)
#                 logger.info(e)
#                 raise Exception("Invalid API Request")
#             elif isinstance(e, openai.RateLimitError):
#                 print("Rate limit exceeded. Waiting...")
#                 logger.info("Rate limit exceeded. Waiting...")
#                 print(e)
#                 logger.info(e)
#                 time.sleep(10 * retries)
#             elif isinstance(e, openai.APIConnectionError):
#                 print("API connection error. Waiting...")
#                 logger.info("API connection error. Waiting...")
#                 print(e)
#                 logger.info(e)
#                 time.sleep(10 * retries)
#             else:
#                 print("Unknown error. Waiting...")
#                 logger.info("Unknown error. Waiting...")
#                 print(e)
#                 logger.info(e)
#                 time.sleep(10 * retries)
#
#         retries += 1
#
#     logger.info(f"API response {ret}")
#     return ret