import os
import warnings
import multiprocessing
from functools import partial
from src.utils.api_lib.claude import complete_text_claude
from src.utils.api_lib.gemini import complete_text_gemini
from src.utils.api_lib.gpt import get_gpt_output
from src.utils.api_lib.huggingface import generate_text_hf


def _int_from_env(name, default):
    """Return the environment variable `name` as an int, or `default` if it is unset."""
    return int(os.environ.get(name, default))


# Retry budget and pause between retries for each hosted API backend.
# Every value can be overridden through the environment variable of the same name.
MAX_OPENAI_RETRY = _int_from_env("MAX_OPENAI_RETRY", 100)
OPENAI_SLEEP_TIME = _int_from_env("OPENAI_SLEEP_TIME", 2)
MAX_CLAUDE_RETRY = _int_from_env("MAX_CLAUDE_RETRY", 100)
CLAUDE_SLEEP_TIME = _int_from_env("CLAUDE_SLEEP_TIME", 0)
MAX_GEMINI_RETRY = _int_from_env("MAX_GEMINI_RETRY", 100)
GEMINI_SLEEP_TIME = _int_from_env("GEMINI_SLEEP_TIME", 2)

# Default upper bound on worker processes used by the parallel helpers.
LLM_PARALLEL_NODES = _int_from_env("LLM_PARALLEL_NODES", 5)

# Names of the text completion models this module knows about.
# A model missing from this set is still attempted; it only triggers a warning.
registered_text_completion_llms = {
    # GPT family (OpenAI)
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-1106-preview",
    "gpt-4-0125-preview",
    "gpt-4-turbo-preview",
    "gpt-4-turbo",
    "gpt-4-turbo-2024-04-09",
    "gpt-5.2",
    # o-series (OpenAI)
    "o1",
    "o1-mini",
    "o1-preview",
    "o3",
    "o3-mini",
    "o4-mini",
    # Claude (Anthropic)
    "claude-2.1",
    "claude-3-opus-20240229",
    "claude-3-sonnet-20240229",
    "claude-3-haiku-20240307",
    "claude-3-5-sonnet-20240620",
    "claude-3-5-sonnet-20241022",
    "claude-3-5-haiku-20241022",
    "claude-sonnet-4-5-20250929",
    # Gemini (Google)
    "gemini-pro",
    "gemini-1.5-pro",
    "gemini-1.5-pro-latest",
    "gemini-2.0-pro",
    "gemini-3-pro",
}


def parallel_func(func, n_max_nodes=LLM_PARALLEL_NODES):
    """
    Build a wrapper that applies `func` to every element of a list in parallel.

    Args:
        func (callable): The function to apply. It receives one element of the
            input list as its first positional argument. It is dispatched through
            a `multiprocessing.Pool`, so it must be usable from worker processes.
        n_max_nodes (int): Maximum number of parallel processes.

    Returns:
        callable: A wrapper `wrapper(inputs, **kwargs)` that returns a list with
            one result per input, in input order. `kwargs` are forwarded
            unchanged to every call of `func`. An empty `inputs` yields `[]`.
    """
    def _parallel_func(inputs: list, **kwargs):
        # A pool cannot be created with zero processes (it raises ValueError),
        # so an empty input is answered directly without spawning workers.
        if len(inputs) == 0:
            return []
        partial_func = partial(func, **kwargs)
        # Never start more workers than there are items to process.
        processes = min(len(inputs), n_max_nodes)
        with multiprocessing.Pool(processes=processes) as pool:
            results = pool.map(partial_func, inputs)
        return results
    return _parallel_func


def get_llm_output(message,
                   model="gpt-4o",
                   max_new_tokens=4096,
                   temperature=1,
                   json_object=False,
                   **generation_kwargs):
    """
    A general function to complete a prompt using the specified model.

    The backend is chosen from the model name:
      * names containing "gpt" or starting with "o1"/"o3"/"o4" -> `get_gpt_output`
      * names containing "claude" -> `complete_text_claude`
      * names containing "gemini" -> `complete_text_gemini`
      * anything else -> `generate_text_hf`

    Args:
        message (str or list): The input message or a list of message dicts.
        model (str): The model to use for completion.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        json_object (bool): Whether to output in JSON format.
        **generation_kwargs: Extra keyword arguments forwarded unchanged to the
            selected backend function.

    Returns:
        str: The completed text generated by the model.

    Warns:
        UserWarning: If `model` is not in `registered_text_completion_llms`.
            The request is still attempted.

    Raises:
        ValueError: If the fallback `generate_text_hf` backend raises; the
            original exception is attached as the cause. Errors raised by the
            GPT, Claude and Gemini backends propagate unchanged.
    """
    if model not in registered_text_completion_llms:
        warnings.warn(f"Model {model} is not registered. You may still be able to use it.")

    kwargs = {
        'message': message,
        'model': model,
        'max_new_tokens': max_new_tokens,
        'temperature': temperature,
        'json_object': json_object,
    }

    # NOTE: the "gpt" test is a substring match, so any model name containing
    # "gpt" is sent to the GPT backend rather than the fallback.
    if 'gpt' in model or model.startswith(('o1', 'o3', 'o4')):
        kwargs.update({'max_retry': MAX_OPENAI_RETRY, 'sleep_time': OPENAI_SLEEP_TIME})
        return get_gpt_output(**kwargs, **generation_kwargs)
    if 'claude' in model:
        kwargs.update({'max_retry': MAX_CLAUDE_RETRY, 'sleep_time': CLAUDE_SLEEP_TIME})
        return complete_text_claude(**kwargs, **generation_kwargs)
    if 'gemini' in model:
        kwargs.update({'max_retry': MAX_GEMINI_RETRY, 'sleep_time': GEMINI_SLEEP_TIME})
        return complete_text_gemini(**kwargs, **generation_kwargs)

    # Fallback backend: no retry settings are passed here.
    try:
        return generate_text_hf(**kwargs, **generation_kwargs)
    except Exception as e:
        # Chain explicitly so the underlying failure stays visible as __cause__.
        raise ValueError(f"Model {model} failed: {e}") from e

# Parallel functions for text completion.
# Each wrapper takes a list of inputs plus shared keyword arguments and returns
# one result per input, using at most LLM_PARALLEL_NODES worker processes.
complete_texts_claude = parallel_func(complete_text_claude)
# Gemini counterpart, added so every backend has a parallel variant.
complete_texts_gemini = parallel_func(complete_text_gemini)
complete_texts_hf = parallel_func(generate_text_hf)
get_gpt_outputs = parallel_func(get_gpt_output)
get_llm_outputs = parallel_func(get_llm_output)

