import random
import asyncio
import openai
import random
from apikey import api_configs

# Default model used by gen() when the caller does not pass one; set via set_model().
model_name = None

# Number of attempts gen() and get_embedding() make before giving up.
MAX_RETRIES = 3
# When true, gen() counts calls and (if log_token is set) records token usage.
LOG = True
# Running totals across calls. NOTE(review): nothing in the visible code ever
# updates `cost`, so get_cost() appears to always return 0 -- confirm.
total_prompt_tokens, total_completion_tokens, call_count, cost = 0, 0, 0, 0
# Token usage of the most recent gen() call that recorded tokens.
current_prompt_tokens, current_completion_tokens = 0, 0

def set_model(model):
    """Store *model* as the module-wide default model name."""
    global model_name
    model_name = model
    
def set_log(log):
    """Replace the module-wide ``LOG`` flag with *log*."""
    global LOG
    LOG = log
    
def need_extracted():
    """Return True when the current model name contains neither "gpt" nor "o3"."""
    current = get_model()
    return "gpt" not in current and "o3" not in current

# Maps a model name to the key in `api_configs` whose URL and API keys gen()
# uses for that model. Models not listed here fall back to "deepwisdom" in gen().
model2base = {
    "qwen3-32b": "aigcbest",
    "deepseek/deepseek-chat-v3-0324": "aigcbest",
    "o3-mini": "deepwisdom",
    "deepseek-ai/DeepSeek-R1": "aigcbest",
}


async def gen(msg, model=None, temperature=1.0, response_format="json_object", log_token=True):
    """Send a chat-completion request and return the reply text.

    Args:
        msg: Either a ready-made list of chat messages, or any other value,
            which is sent as the content of a single "user" message.
        model: Model name; falls back to the module-level ``model_name``
            when falsy.
        temperature: Sampling temperature forwarded to the API.
        response_format: Value used for the ``type`` field of the API's
            ``response_format`` argument.
        log_token: When true (and ``LOG`` is enabled), record this call's
            token usage in the module-level counters.

    Returns:
        The message content of the first choice, or ``None`` when all
        ``MAX_RETRIES`` attempts failed (the collected errors are printed).
    """
    global call_count, current_prompt_tokens, current_completion_tokens

    if not model:
        model = model_name
    # Models without an explicit mapping are served by the default endpoint.
    base = model2base.get(model, "deepwisdom")
    client = openai.AsyncClient(
        base_url=api_configs[base]["url"],
        api_key=random.choice(api_configs[base]["api_key"]),
    )
    if LOG:
        call_count += 1

    messages = msg if isinstance(msg, list) else [{"role": "user", "content": msg}]

    request_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "stop": None,
        "response_format": {"type": response_format},
    }
    # "o3-mini" is the one model for which no max_tokens is sent.
    if model != "o3-mini":
        request_kwargs["max_tokens"] = 8192 * 2

    errors = []
    # Randomised base delay so concurrent callers do not retry in lockstep.
    base_delay = random.uniform(0.1, 2)

    # The client is created per call, so close it when the call finishes.
    async with client:
        for attempt in range(MAX_RETRIES):
            try:
                async with asyncio.timeout(120 * 2):
                    response = await client.chat.completions.create(**request_kwargs)
                content = response.choices[0].message.content

                usage = response.usage
                # A missing usage report must not discard a valid reply.
                if LOG and log_token and usage is not None:
                    current_prompt_tokens = usage.prompt_tokens
                    current_completion_tokens = usage.completion_tokens
                    update_token()

                return content
            except asyncio.TimeoutError:
                errors.append("Request timeout")
            except openai.RateLimitError:
                errors.append("Rate limit error")
            except openai.APIError as e:
                errors.append(f"API error: {str(e)}")
            except Exception as e:
                errors.append(f"Error: {type(e).__name__}, {str(e)}")

            # Exponential backoff between attempts; no point waiting after the last one.
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(base_delay * (2 ** attempt))

    print(f"Error log: {errors}")
    return None


async def get_embedding(text, model="text-embedding-3-small"):
    """Request an embedding vector for *text* from the "deepwisdom" endpoint.

    Args:
        text: Input forwarded unchanged as the API's ``input`` argument.
        model: Embedding model name.

    Returns:
        The embedding of the first item in the response. When all
        ``MAX_RETRIES`` attempts fail, the collected errors are printed and
        an all-zero vector is returned instead; its length matches the
        model's default dimension when known, otherwise 1536.
    """
    # Default output sizes of the known OpenAI embedding models, used only
    # to size the zero-vector fallback.
    fallback_dims = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    base = "deepwisdom"
    client = openai.AsyncClient(
        base_url=api_configs[base]["url"],
        api_key=random.choice(api_configs[base]["api_key"]),
    )

    errors = []
    # Randomised base delay so concurrent callers do not retry in lockstep.
    base_delay = random.uniform(0.1, 2)

    # The client is created per call, so close it when the call finishes.
    async with client:
        for attempt in range(MAX_RETRIES):
            try:
                async with asyncio.timeout(60):
                    response = await client.embeddings.create(
                        model=model,
                        input=text,
                        encoding_format="float",
                    )
                return response.data[0].embedding
            except asyncio.TimeoutError:
                errors.append("Request timeout")
            except openai.RateLimitError:
                errors.append("Rate limit error")
            except openai.APIError as e:
                errors.append(f"API error: {str(e)}")
            except Exception as e:
                errors.append(f"Error: {type(e).__name__}, {str(e)}")

            # Exponential backoff between attempts; no point waiting after the last one.
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(base_delay * (2 ** attempt))

    print(f"Error log: {errors}")
    # Best-effort fallback so callers always receive a vector.
    return [0] * fallback_dims.get(model, 1536)

def get_cost():
    """Return the current value of the module-level ``cost`` counter."""
    return cost

def get_model():
    """Return the module-wide default model name (``None`` until one is set)."""
    return model_name

def update_token():
    """Add the most recent call's token counts to the running totals."""
    # Only the totals are rebound; the current_* counters are merely read.
    global total_prompt_tokens, total_completion_tokens
    total_prompt_tokens = total_prompt_tokens + current_prompt_tokens
    total_completion_tokens = total_completion_tokens + current_completion_tokens

def reset_token():
    """Zero the cumulative token totals and the call counter."""
    global total_prompt_tokens, total_completion_tokens, call_count
    total_prompt_tokens = total_completion_tokens = call_count = 0

def get_token():
    """Return the cumulative ``(prompt_tokens, completion_tokens)`` totals."""
    totals = (total_prompt_tokens, total_completion_tokens)
    return totals

def get_call_count():
    """Return the current value of the module-level ``call_count`` counter."""
    return call_count