import transformers
import torch
import os
import yaml
import time
from vllm import SamplingParams

# Upper bound on the number of attempts openai_chat_completion makes per call.
API_MAX_RETRY = 16
# Seconds to sleep after a rate-limit error before the next attempt.
API_RETRY_SLEEP = 10
# Sentinel string returned instead of model text when no completion was obtained.
API_ERROR_OUTPUT = "$ERROR$"

def set_seeds(seed):
    """Apply ``seed`` to the libraries used here and force deterministic cuDNN.

    Calls ``transformers.set_seed(seed)``, disables cuDNN autotuning
    (``benchmark``), enables cuDNN deterministic mode, and stores the seed
    in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Seed value forwarded to ``transformers.set_seed`` and written,
            as a string, to ``PYTHONHASHSEED``.
    """
    transformers.set_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the interpreter reads PYTHONHASHSEED at startup, so this
    # assignment is only picked up by processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = f"{seed}"

def read_config(config_file: str) -> dict:
    """Load a YAML configuration file with the safe loader.

    Args:
        config_file: Path of the YAML file to read.

    Returns:
        The parsed document. An empty file (which the YAML loader parses
        to ``None``) yields an empty dict so callers can always use
        mapping operations such as ``.get`` on the result.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # YAML streams are Unicode; read as UTF-8 rather than the locale default.
    with open(config_file, "r", encoding="utf-8") as f:
        config_kwargs = yaml.safe_load(f)

    # An empty document parses to None; honour the declared dict return type.
    return {} if config_kwargs is None else config_kwargs

def openai_chat_completion(messages, config):
    """Request a single chat completion and return the generated text.

    The request parameters depend on ``config["model_name"]``:

    * names containing ``"o3"`` send only ``reasoning_effort``;
    * names containing ``"o1-mini"`` send no extra parameters;
    * every other model sends ``temperature`` and ``max_completion_tokens``
      (the latter taken from ``config["max_tokens"]``).

    Up to ``API_MAX_RETRY`` attempts are made. A rate-limit error sleeps
    ``API_RETRY_SLEEP`` seconds and retries; a bad-request error is logged
    and retried immediately; a missing config key or a ``TypeError`` is
    logged and ends the loop.

    Args:
        messages: Chat messages passed unchanged as the ``messages``
            argument of ``client.chat.completions.create``.
        config: Mapping with ``model_name`` and, optionally, ``base_url``,
            ``reasoning_effort``, ``temperature`` and ``max_tokens``.

    Returns:
        The ``message.content`` of the first returned choice, or
        ``API_ERROR_OUTPUT`` when no completion was obtained.

    Raises:
        KeyError: If ``base_url`` is configured but the
            ``OPENROUTER_API_KEY`` environment variable is not set.
    """
    import openai

    if config.get("base_url") is not None:
        client = openai.OpenAI(
            base_url=config["base_url"],
            api_key=os.environ["OPENROUTER_API_KEY"])
    else:
        client = openai.OpenAI()

    choice = None
    for _ in range(API_MAX_RETRY):
        # Reset per attempt so the TypeError handler never reads a stale or
        # unbound response object.
        completion = None
        try:
            model_name = config["model_name"]
            request = {"model": model_name, "messages": messages}
            if "o3" in model_name:
                request["reasoning_effort"] = config.get("reasoning_effort", None)
            elif "o1-mini" not in model_name:
                request["temperature"] = config.get("temperature", None)
                request["max_completion_tokens"] = config.get("max_tokens", None)
            completion = client.chat.completions.create(**request)
            choice = completion.choices[0]
            break
        except openai.RateLimitError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.BadRequestError as e:
            print(messages)
            print(type(e), e)
        except KeyError as e:
            # A missing config key fails identically on every attempt, so
            # report it once and stop instead of retrying.
            print(type(e), e)
            break
        except TypeError as e:
            print(type(e), e)
            # Presumably raised when the response carries an error payload
            # instead of choices; print that payload when it is available.
            error = getattr(completion, "error", None)
            if error is not None:
                try:
                    print(error["metadata"]["raw"])
                except (KeyError, TypeError):
                    print(error)
            break

    if choice is None:
        return API_ERROR_OUTPUT
    return choice.message.content


def vllm_chat_completion_batch(llm, messages, config):
    """Generate one completion per conversation with a vLLM engine.

    Each conversation is rendered to a prompt string with the engine
    tokenizer's chat template (generation prompt appended), the whole batch
    is passed to ``llm.generate``, and the results are ordered by their
    integer ``request_id``.

    Args:
        llm: Object exposing ``get_tokenizer()`` and ``generate()``.
        messages: Iterable of conversations, each given to
            ``apply_chat_template``.
        config: Mapping providing ``temperature`` and ``max_tokens`` and,
            optionally, ``logprobs``.

    Returns:
        A list of generated texts. When ``config["logprobs"]`` is set, a
        list of ``(text, logprobs[0])`` tuples instead, where
        ``logprobs[0]`` is the first entry of the first candidate's
        logprobs (presumably the first generated token -- confirm against
        the vLLM output schema).
    """
    chat_tokenizer = llm.get_tokenizer()

    prompts = []
    for conversation in messages:
        rendered = chat_tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True)
        prompts.append(rendered)

    params = SamplingParams(
        temperature=config["temperature"],
        max_tokens=config["max_tokens"],
        logprobs=config.get("logprobs"),
    )

    results = list(llm.generate(prompts, sampling_params=params))
    # Restore submission order using the numeric request id.
    results.sort(key=lambda result: int(result.request_id))

    if config.get("logprobs") is None:
        return [result.outputs[0].text for result in results]

    return [
        (result.outputs[0].text, result.outputs[0].logprobs[0])
        for result in results
    ]
