from functools import partial
from models.gpt import gpt_completion_fn, gpt_completion_fn_cot, gpt_completion_fn_optimizer, gpt_nll_fn,gpt_completion_fn_selfcorrecting,gpt_completion_fn_selfprobing
from models.gpt import tokenize_fn as gpt_tokenize_fn
from models.llama import llama_completion_fn, llama_nll_fn
from models.llama import tokenize_fn as llama_tokenize_fn

from models.mistral import mistral_completion_fn, mistral_nll_fn
from models.mistral import tokenize_fn as mistral_tokenize_fn

from models.mistral_api import mistral_api_completion_fn, mistral_api_nll_fn
from models.mistral_api import tokenize_fn as mistral_api_tokenize_fn

from models.Gemini import grok_completion_fn, other_completion_fn,other_completion_fn_cot
from models.Gemini import tokenize_fn as other_tokenize_fn

# Required: text-completion sampler for every supported model
# -------------------------------------------------------------
# Maps a lookup key (the name callers pass around) to a `functools.partial`
# that binds a backend's completion function to a concrete `model=` value.
# The key and the bound model name usually coincide. They differ for:
#   - prompt variants ('-CoT', '-selfprobing', '-selfcorrecting', '-optimizer'),
#     which all bind model='gpt-3.5-turbo';
#   - the Mistral API aliases;
#   - the llama sizes;
#   - 'glm-4-long-CoT', which binds model='glm-4-long'.
#
# Every sampler is expected to accept:
#   - input_str (str): serialized input time series.
#   - steps (int): number of steps to forecast.
#   - settings (SerializerSettings): serialization settings.
#   - num_samples (int): number of completions to draw.
#   - temp (float): sampling temperature (randomness of the output).
# and to return a list of sampled completion strings.
#
# Each row is (lookup key, backend function, model name handed to it). Row
# order is the dictionary's insertion order.
completion_fns = {
    key: partial(fn, model=model)
    for key, fn, model in (
        # OpenAI models and prompting variants.
        ('text-davinci-003', gpt_completion_fn, 'text-davinci-003'),
        ('gpt-4', gpt_completion_fn, 'gpt-4'),
        ('gpt-4-1106-preview', gpt_completion_fn, 'gpt-4-1106-preview'),
        ('gpt-3.5-turbo-instruct', gpt_completion_fn, 'gpt-3.5-turbo-instruct'),
        ('gpt-3.5-turbo', gpt_completion_fn, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-CoT', gpt_completion_fn_cot, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-selfprobing', gpt_completion_fn_selfprobing, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-selfcorrecting', gpt_completion_fn_selfcorrecting, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-optimizer', gpt_completion_fn_optimizer, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-0125', gpt_completion_fn, 'gpt-3.5-turbo-0125'),
        ('gpt-3.5-turbo-16k', gpt_completion_fn, 'gpt-3.5-turbo-16k'),
        ('gpt-4o-mini', gpt_completion_fn, 'gpt-4o-mini'),
        # Mistral (the API aliases strip the 'api-' part of the key).
        ('mistral', mistral_completion_fn, 'mistral'),
        ('mistral-api-tiny', mistral_api_completion_fn, 'mistral-tiny'),
        ('mistral-api-small', mistral_api_completion_fn, 'mistral-small'),
        ('mistral-api-medium', mistral_api_completion_fn, 'mistral-medium'),
        # Llama: the backend only receives the size/variant suffix.
        ('llama-7b', llama_completion_fn, '7b'),
        ('llama-13b', llama_completion_fn, '13b'),
        ('llama-70b', llama_completion_fn, '70b'),
        ('llama-7b-chat', llama_completion_fn, '7b-chat'),
        ('llama-13b-chat', llama_completion_fn, '13b-chat'),
        ('llama-70b-chat', llama_completion_fn, '70b-chat'),
        # Everything below goes through the generic backend in models.Gemini.
        ('gemini-1.5-flash-latest', other_completion_fn, 'gemini-1.5-flash-latest'),
        ('gemini-1.5-flash-8b', other_completion_fn, 'gemini-1.5-flash-8b'),
        ('gemini-2.0-flash', other_completion_fn, 'gemini-2.0-flash'),
        ('gemini-2.0-flash-lite', other_completion_fn, 'gemini-2.0-flash-lite'),
        # NOTE(review): unlike its siblings this name has no 'gemini-' prefix;
        # confirm the backend really expects 'ge-2.5-flash'.
        ('ge-2.5-flash', other_completion_fn, 'ge-2.5-flash'),
        ('claude-3-5-haiku-20241022', other_completion_fn, 'claude-3-5-haiku-20241022'),
        ('claude-3-5-sonnet-20240620', other_completion_fn, 'claude-3-5-sonnet-20240620'),  # stable
        ('glm-4-air', other_completion_fn, 'glm-4-air'),
        ('glm-4-long', other_completion_fn, 'glm-4-long'),
        ('glm-4-long-CoT', other_completion_fn_cot, 'glm-4-long'),
        ('qwen-plus', other_completion_fn, 'qwen-plus'),
        ('qwen-turbo', other_completion_fn, 'qwen-turbo'),
        ('Qwen2.5-32B-Instruct', other_completion_fn, 'Qwen2.5-32B-Instruct'),
        ('qwen3-32b', other_completion_fn, 'qwen3-32b'),
        ('qwen3-14b', other_completion_fn, 'qwen3-14b'),
        ('qwen3-8b', other_completion_fn, 'qwen3-8b'),
        ('moonshot-v1-8k', other_completion_fn, 'moonshot-v1-8k'),
        ('moonshot-v1-32k', other_completion_fn, 'moonshot-v1-32k'),
        ('moonshot-v1-128k', other_completion_fn, 'moonshot-v1-128k'),
        ('deepseek-coder', other_completion_fn, 'deepseek-coder'),
        ('DeepSeek-R1-Distill-Qwen-1.5B', other_completion_fn, 'DeepSeek-R1-Distill-Qwen-1.5B'),
        ('DeepSeek-R1-Distill-Qwen-7B', other_completion_fn, 'DeepSeek-R1-Distill-Qwen-7B'),
        ('DeepSeek-R1-Distill-Qwen-14B', other_completion_fn, 'DeepSeek-R1-Distill-Qwen-14B'),
        ('DeepSeek-R1-Distill-Qwen-32B', other_completion_fn, 'DeepSeek-R1-Distill-Qwen-32B'),
        ('deepseek-r1', other_completion_fn, 'deepseek-r1'),
        ('deepseek-v3', other_completion_fn, 'deepseek-v3'),
        ('claude-3-opus-20240229', other_completion_fn, 'claude-3-opus-20240229'),
        ('doubao-lite-128k', other_completion_fn, 'doubao-lite-128k'),
        ('doubao-pro-128k', other_completion_fn, 'doubao-pro-128k'),
        ('ERNIE-4.0-8K', other_completion_fn, 'ERNIE-4.0-8K'),
        ('ERNIE-Lite-8K-0308', other_completion_fn, 'ERNIE-Lite-8K-0308'),
        ('SparkDesk-v3.1', other_completion_fn, 'SparkDesk-v3.1'),
        ('SparkDesk-v3.5', other_completion_fn, 'SparkDesk-v3.5'),
        ('yi-34b-chat-0205', other_completion_fn, 'yi-34b-chat-0205'),
        ('yi-34b-chat-200k', other_completion_fn, 'yi-34b-chat-200k'),
        ('sonar-reasoning', other_completion_fn, 'sonar-reasoning'),
        ('llama-3.1-sonar-huge-128k-online', other_completion_fn, 'llama-3.1-sonar-huge-128k-online'),
        ('grok-2', other_completion_fn, 'grok-2'),
        # Only this grok key is routed through the dedicated grok backend.
        ('grok-2-1212', grok_completion_fn, 'grok-2-1212'),
    )
}

# Optional: continuous negative log-likelihood per dimension (NLL/D)
# -------------------------------------------------------------
# Used only for likelihood evaluation, never for sampling. Models missing from
# this mapping have no NLL/D function registered here.
#
# Each registered function is expected to accept:
#   - input_arr (np.ndarray): transformed input history.
#   - target_arr (np.ndarray): transformed ground-truth future.
#   - settings (SerializerSettings): serialization settings.
#   - transform (callable): the data transformation (e.g. scaling); used to
#     derive the Jacobian correction.
#   - count_seps (bool): whether time-step separators count toward the NLL;
#     required when a variable number of digits is allowed.
#   - temp (float): sampling temperature.
# and to return the NLL per dimension of p(target_arr | input_arr) as a float.
#
# Each row is (lookup key, backend NLL function, model name handed to it).
nll_fns = {
    key: partial(fn, model=model)
    for key, fn, model in (
        ('text-davinci-003', gpt_nll_fn, 'text-davinci-003'),
        ('mistral', mistral_nll_fn, 'mistral'),
        ('mistral-api-tiny', mistral_api_nll_fn, 'mistral-tiny'),
        ('mistral-api-small', mistral_api_nll_fn, 'mistral-small'),
        ('mistral-api-medium', mistral_api_nll_fn, 'mistral-medium'),
        ('llama-7b', llama_nll_fn, '7b'),
        ('llama-13b', llama_nll_fn, '13b'),
        ('llama-70b', llama_nll_fn, '70b'),
        ('llama-7b-chat', llama_nll_fn, '7b-chat'),
        ('llama-13b-chat', llama_nll_fn, '13b-chat'),
        ('llama-70b-chat', llama_nll_fn, '70b-chat'),
    )
}

# Optional: tokenizer per model, consulted only for automatic input truncation.
# Each registered function is expected to take a string and return a list of
# token ids. Completion models without a row here (e.g. 'gpt-4' and
# 'gpt-4-1106-preview') get no tokenizer from this table.
#
# Each row is (lookup key, backend tokenizer, model name handed to it). Prompt
# variants reuse their base model's tokenizer; for example, every
# 'gpt-3.5-turbo-*' variant and 'glm-4-long-CoT' bind their base model name.
tokenization_fns = {
    key: partial(fn, model=model)
    for key, fn, model in (
        # OpenAI models and prompting variants.
        ('text-davinci-003', gpt_tokenize_fn, 'text-davinci-003'),
        ('gpt-3.5-turbo-instruct', gpt_tokenize_fn, 'gpt-3.5-turbo-instruct'),
        ('gpt-3.5-turbo', gpt_tokenize_fn, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-CoT', gpt_tokenize_fn, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-selfprobing', gpt_tokenize_fn, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-selfcorrecting', gpt_tokenize_fn, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-optimizer', gpt_tokenize_fn, 'gpt-3.5-turbo'),
        ('gpt-3.5-turbo-0125', gpt_tokenize_fn, 'gpt-3.5-turbo-0125'),
        ('gpt-3.5-turbo-16k', gpt_tokenize_fn, 'gpt-3.5-turbo-16k'),
        ('gpt-4o-mini', gpt_tokenize_fn, 'gpt-4o-mini'),
        # Mistral.
        ('mistral', mistral_tokenize_fn, 'mistral'),
        ('mistral-api-tiny', mistral_api_tokenize_fn, 'mistral-tiny'),
        ('mistral-api-small', mistral_api_tokenize_fn, 'mistral-small'),
        ('mistral-api-medium', mistral_api_tokenize_fn, 'mistral-medium'),
        # Llama: the backend only receives the size/variant suffix.
        ('llama-7b', llama_tokenize_fn, '7b'),
        ('llama-13b', llama_tokenize_fn, '13b'),
        ('llama-70b', llama_tokenize_fn, '70b'),
        ('llama-7b-chat', llama_tokenize_fn, '7b-chat'),
        ('llama-13b-chat', llama_tokenize_fn, '13b-chat'),
        ('llama-70b-chat', llama_tokenize_fn, '70b-chat'),
        # Everything below uses the generic tokenizer from models.Gemini.
        ('gemini-1.5-flash-latest', other_tokenize_fn, 'gemini-1.5-flash-latest'),
        ('gemini-1.5-flash-8b', other_tokenize_fn, 'gemini-1.5-flash-8b'),
        ('gemini-2.0-flash', other_tokenize_fn, 'gemini-2.0-flash'),
        ('gemini-2.0-flash-lite', other_tokenize_fn, 'gemini-2.0-flash-lite'),
        ('ge-2.5-flash', other_tokenize_fn, 'ge-2.5-flash'),
        ('claude-3-5-haiku-20241022', other_tokenize_fn, 'claude-3-5-haiku-20241022'),
        ('claude-3-5-sonnet-20240620', other_tokenize_fn, 'claude-3-5-sonnet-20240620'),
        ('glm-4-air', other_tokenize_fn, 'glm-4-air'),
        ('glm-4-long', other_tokenize_fn, 'glm-4-long'),
        ('glm-4-long-CoT', other_tokenize_fn, 'glm-4-long'),
        ('qwen-plus', other_tokenize_fn, 'qwen-plus'),
        ('qwen-turbo', other_tokenize_fn, 'qwen-turbo'),
        ('Qwen2.5-32B-Instruct', other_tokenize_fn, 'Qwen2.5-32B-Instruct'),
        ('qwen3-32b', other_tokenize_fn, 'qwen3-32b'),
        ('qwen3-14b', other_tokenize_fn, 'qwen3-14b'),
        ('qwen3-8b', other_tokenize_fn, 'qwen3-8b'),
        ('moonshot-v1-8k', other_tokenize_fn, 'moonshot-v1-8k'),
        ('moonshot-v1-32k', other_tokenize_fn, 'moonshot-v1-32k'),
        ('moonshot-v1-128k', other_tokenize_fn, 'moonshot-v1-128k'),
        ('deepseek-coder', other_tokenize_fn, 'deepseek-coder'),
        ('DeepSeek-R1-Distill-Qwen-1.5B', other_tokenize_fn, 'DeepSeek-R1-Distill-Qwen-1.5B'),
        ('DeepSeek-R1-Distill-Qwen-7B', other_tokenize_fn, 'DeepSeek-R1-Distill-Qwen-7B'),
        ('DeepSeek-R1-Distill-Qwen-14B', other_tokenize_fn, 'DeepSeek-R1-Distill-Qwen-14B'),
        ('DeepSeek-R1-Distill-Qwen-32B', other_tokenize_fn, 'DeepSeek-R1-Distill-Qwen-32B'),
        ('deepseek-r1', other_tokenize_fn, 'deepseek-r1'),
        ('deepseek-v3', other_tokenize_fn, 'deepseek-v3'),
        ('claude-3-opus-20240229', other_tokenize_fn, 'claude-3-opus-20240229'),
        ('doubao-lite-128k', other_tokenize_fn, 'doubao-lite-128k'),
        ('doubao-pro-128k', other_tokenize_fn, 'doubao-pro-128k'),
        ('ERNIE-4.0-8K', other_tokenize_fn, 'ERNIE-4.0-8K'),
        ('ERNIE-Lite-8K-0308', other_tokenize_fn, 'ERNIE-Lite-8K-0308'),
        ('SparkDesk-v3.1', other_tokenize_fn, 'SparkDesk-v3.1'),
        ('SparkDesk-v3.5', other_tokenize_fn, 'SparkDesk-v3.5'),
        ('yi-34b-chat-0205', other_tokenize_fn, 'yi-34b-chat-0205'),
        ('yi-34b-chat-200k', other_tokenize_fn, 'yi-34b-chat-200k'),
        ('sonar-reasoning', other_tokenize_fn, 'sonar-reasoning'),
        ('llama-3.1-sonar-huge-128k-online', other_tokenize_fn, 'llama-3.1-sonar-huge-128k-online'),
        ('grok-2', other_tokenize_fn, 'grok-2'),
        ('grok-2-1212', other_tokenize_fn, 'grok-2-1212'),
    )
}

# Optional: context-length budget per model, consulted only for automatic
# input truncation. 'mistral' and the 'llama-*' keys use 4096; every other
# entry uses 4097. The three groups below are merged in order, so the
# dictionary's key order matches the listing.
# NOTE(review): keys whose names advertise larger windows (e.g.
# 'gpt-3.5-turbo-16k', 'moonshot-v1-128k', 'yi-34b-chat-200k') are still
# capped at 4097. This is presumably deliberate, to keep prompts comparable
# across models; confirm before relying on it.
context_lengths = {
    # OpenAI models/variants and the Mistral API aliases.
    **dict.fromkeys(
        (
            'text-davinci-003',
            'gpt-3.5-turbo-instruct',
            'gpt-3.5-turbo',
            'gpt-3.5-turbo-CoT',
            'gpt-3.5-turbo-selfprobing',
            'gpt-3.5-turbo-selfcorrecting',
            'gpt-3.5-turbo-optimizer',
            'gpt-3.5-turbo-0125',
            'gpt-3.5-turbo-16k',
            'gpt-4o-mini',
            'mistral-api-tiny',
            'mistral-api-small',
            'mistral-api-medium',
        ),
        4097,
    ),
    # Mistral and the llama family.
    **dict.fromkeys(
        (
            'mistral',
            'llama-7b',
            'llama-13b',
            'llama-70b',
            'llama-7b-chat',
            'llama-13b-chat',
            'llama-70b-chat',
        ),
        4096,
    ),
    # Models served through the generic models.Gemini backend.
    **dict.fromkeys(
        (
            'gemini-1.5-flash-latest',
            'gemini-1.5-flash-8b',
            'gemini-2.0-flash',
            'gemini-2.0-flash-lite',
            'ge-2.5-flash',
            'claude-3-5-haiku-20241022',
            'claude-3-5-sonnet-20240620',
            'glm-4-air',
            'glm-4-long',
            'glm-4-long-CoT',
            'qwen-plus',
            'qwen-turbo',
            'Qwen2.5-32B-Instruct',
            'qwen3-32b',
            'qwen3-14b',
            'qwen3-8b',
            'moonshot-v1-8k',
            'moonshot-v1-32k',
            'moonshot-v1-128k',
            'deepseek-coder',
            'DeepSeek-R1-Distill-Qwen-1.5B',
            'DeepSeek-R1-Distill-Qwen-7B',
            'DeepSeek-R1-Distill-Qwen-14B',
            'DeepSeek-R1-Distill-Qwen-32B',
            'deepseek-r1',
            'deepseek-v3',
            'claude-3-opus-20240229',
            'doubao-lite-128k',
            'doubao-pro-128k',
            'ERNIE-4.0-8K',
            'ERNIE-Lite-8K-0308',
            'SparkDesk-v3.1',
            'SparkDesk-v3.5',
            'yi-34b-chat-0205',
            'yi-34b-chat-200k',
            'sonar-reasoning',
            'llama-3.1-sonar-huge-128k-online',
            'grok-2',
            'grok-2-1212',
        ),
        4097,
    ),
}