import time

from openai import OpenAI


class ApiCompletion:
    """Chat-completion backend that calls an OpenAI-compatible HTTP API.

    Exposes ``to`` / ``complete`` methods, presumably so it can stand in for a
    local model object elsewhere in the project (callers not visible here).
    """

    # Statuses worth retrying in addition to 5xx: request timeout, conflict,
    # rate limit. Mirrors the statuses the openai-python client retries itself.
    _RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})

    def __init__(self, model_name, base_url, api_key, max_input_length=None):
        """Create the underlying OpenAI client.

        Args:
            model_name: Model identifier sent with every request.
            base_url: Base URL of the OpenAI-compatible endpoint.
            api_key: API key handed to the client.
            max_input_length: Stored on the instance but not used by this
                class; kept for interface compatibility.
        """
        self.model_name = model_name
        self.max_input_length = max_input_length
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

    def to(self, device):
        """No-op device move (remote API); returns ``self`` so ``m = m.to(d)`` works."""
        return self

    @staticmethod
    def _request_parameters(kwargs):
        """Select the supported sampling parameters from ``kwargs``, with defaults."""
        parameters = {
            key: kwargs.get(key, default)
            for key, default in (("max_tokens", None), ("temperature", 0))
        }
        if "gen_length" in kwargs:
            # ``gen_length`` is an alias that takes precedence over ``max_tokens``.
            parameters["max_tokens"] = kwargs["gen_length"]
        return parameters

    @classmethod
    def _is_retryable(cls, error):
        """Return False for HTTP client errors that would fail again unchanged."""
        status = getattr(error, "status_code", None)
        if not isinstance(status, int):
            # No HTTP status: connection errors, timeouts, malformed responses...
            return True
        return status in cls._RETRYABLE_STATUS_CODES or status >= 500

    def _complete_one(self, chat, parameters, max_retry, retry_delay, max_retry_delay):
        """Request a single completion with retries; return "" if every attempt fails."""
        delay = retry_delay
        for attempt in range(max_retry):
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=chat,
                    **parameters,
                )
                # The API may return ``content=None``; keep results uniformly str.
                return response.choices[0].message.content or ""
            except Exception as e:  # best-effort: exhausted attempts yield ""
                if not self._is_retryable(e):
                    print(f"Non-retryable error, giving up: {e}")
                    break
                print(f"Retrying {attempt} ...")
                print(e)
                if delay > 0 and attempt + 1 < max_retry:
                    time.sleep(min(delay, max_retry_delay))
                    delay = min(delay * 2, max_retry_delay)
        return ""

    def complete(
        self,
        chats,
        max_retry=30,
        *,
        retry_delay=0.0,
        max_retry_delay=60.0,
        **kwargs,
    ):
        """Generate one completion per chat.

        Args:
            chats: Iterable of message lists; each is sent as ``messages`` to
                ``client.chat.completions.create``.
            max_retry: Maximum attempts per chat. A chat whose attempts are
                exhausted (or that hits a non-retryable HTTP error) yields "".
            retry_delay: Seconds to wait after the first failed attempt; the
                wait doubles after each further failure. 0 (the default)
                retries immediately.
            max_retry_delay: Upper bound, in seconds, on any single wait.
            **kwargs: Only ``max_tokens`` (default None), ``temperature``
                (default 0) and the alias ``gen_length`` (overrides
                ``max_tokens``) are used; other keys are ignored.

        Returns:
            ``(texts, extras)``: ``texts`` holds exactly one string per chat
            ("" on failure); ``extras`` is a list of ``None`` of the same
            length — presumably a placeholder for per-completion data other
            backends return; confirm against callers.
        """
        parameters = self._request_parameters(kwargs)
        res = [
            self._complete_one(chat, parameters, max_retry, retry_delay, max_retry_delay)
            for chat in chats
        ]
        return res, [None] * len(res)
