from anthropic import Anthropic
from google import genai
from google.genai import types
from openai import OpenAI
import time


def get_client(messages, cfg):
    if 'gpt' in cfg.model.family_name or cfg.model.family_name == 'o':
        client = OpenAI(api_key=cfg.model.api_key)
    elif 'claude' in cfg.model.family_name:
        client = Anthropic(api_key=cfg.model.api_key)
    elif 'deepseek' in cfg.model.family_name:
        client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
    elif 'gemini' in cfg.model.family_name:
        client = genai.Client(api_key=cfg.model.api_key)
    elif cfg.model.family_name == 'qwen':
        client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
    else:
        raise ValueError(f'Model {cfg.model.family_name} not recognized')
    return client


def generate_response(messages, cfg):
    client = get_client(messages, cfg)
    model_name = cfg.model.name
    if 'o1' in model_name or 'o3' in model_name or 'o4' in model_name:
        # Need to follow the restrictions
        if 'o1' in model_name and len(messages)>0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + messages[0]['content']
        # TODO: add these to the hydra config
        num_tokens = 16384
        temperature = 1.0
        start_time = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_completion_tokens=num_tokens,
            temperature=temperature)
        end_time = time.time()
        print(f'It takes {model_name} {end_time - start_time:.2f}s to generate the response.')
        return response
    
    if 'claude' in model_name:
        num_tokens = 8192  # Give claude more tokens
        temperature = 0.7

        if len(messages)>0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]

        start_time = time.time() 
        if cfg.model.thinking:
            num_thinking_tokens = 12288
            response = client.messages.create(
                model=model_name,
                max_tokens=num_tokens+num_thinking_tokens,
                thinking= {"type": "enabled", "budget_tokens": num_thinking_tokens},
                system=system_prompt,
                messages=messages,
                # temperature has to be set to 1 for thinking
                temperature=1.0,
            )
        else:
            response = client.messages.create(
                model=model_name,
                max_tokens=num_tokens,
                system=system_prompt,
                messages=messages,
                temperature=temperature,
            )
        end_time = time.time()
        print(f'It takes {model_name} {end_time - start_time:.2f}s to generate the response.')
        return response

    if 'gemini' in model_name:
        start_time = time.time() 
        if len(messages)>0 and messages[0]['role'] == 'system':
            # If the first message is a system message, we need to prepend it to the user message
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + messages[0]['content']

        for message in messages:
            if message['role'] == 'assistant':
                message['role'] = 'model'
    
        chat = client.chats.create(
            model=model_name,
            history=[
                types.Content(role=message['role'], parts=[types.Part(text=message['content'])])
                for message in messages[:-1]
            ],
        )
        response = chat.send_message(message=messages[-1]['content'])
        end_time = time.time()
        print(f'It takes {model_name} {end_time - start_time:.2f}s to generate the response.')
        return response

    num_tokens = 4096
    temperature = 0.7

    start_time = time.time()
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=num_tokens,
        temperature=temperature,
        stream=('qwq' in model_name),
    )

    if 'qwq' in model_name:
        answer_content = ""
        for chunk in response:
            if chunk.choices:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'reasoning_content') and delta.reasoning_content != None:
                    # We don't need to print the reasoning content
                    pass
                else:
                    answer_content += delta.content
        response = answer_content

    end_time = time.time()
    print(f'It takes {model_name} {end_time - start_time:.2f}s to generate the response.')
    return response

