
import os

from pathlib import Path

from dotenv import load_dotenv

from tenacity import (

    retry,

    stop_after_attempt,

    wait_random_exponential,

)  

from openai import OpenAI

from transformers import AutoTokenizer

import logging 

import openai 



# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)



# Load environment variables from the dotenv file located at ``dotenv_path``.
# NOTE(review): the path literal is empty in this copy of the file — confirm
# the intended location.
dotenv_path = Path('')

load_dotenv(dotenv_path=dotenv_path)

# Read the API key from a file, dropping surrounding whitespace/newlines.
# NOTE(review): the file-name and mode literals are empty here — confirm the
# real values before running.
with open("", "") as f:

    api_key = f.read().strip()

# Read the organization identifier the same way (same caveat about the
# empty file-name and mode literals).
with open("", "") as f:

    organization = f.read().strip()

    

# Module-level OpenAI client built from the credentials read above; the
# request helpers in this module issue their calls through it.
client = OpenAI(

    api_key=api_key,

    organization=organization

)





def chat_gpt(prompt):
    """Send ``prompt`` as a single chat message and return the reply text.

    Issues one chat-completion request through the module-level ``client``
    and reads the first returned choice.

    Args:
        prompt: Text placed in the request's single message.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when the response carries no content.
    """
    response = client.chat.completions.create(
        model="",
        messages=[{"": "", "": prompt}],
    )
    # ``message.content`` is nullable in the chat-completions response
    # schema; fall back to "" so callers always receive a ``str`` instead of
    # hitting AttributeError on ``None.strip()``.
    content = response.choices[0].message.content
    return content.strip() if content is not None else ""





# Retry failed calls with randomized exponential backoff (wait bounded by
# min=2 and max=60 seconds), giving up after 15 attempts.
# ``reraise=True``: once the attempts are exhausted the last exception is
# raised to the caller. A ``retry_error_callback`` must not be used for
# filtering exceptions: tenacity invokes it with a ``RetryCallState`` after
# the final attempt and hands its return value back to the caller as if it
# were the response.
# NOTE(review): every exception type is retried here. If only
# ``openai.InternalServerError`` should be retried, pass a predicate through
# the ``retry=`` argument (e.g. tenacity's ``retry_if_exception_type``).
@retry(
    wait=wait_random_exponential(min=2, max=60),
    stop=stop_after_attempt(15),
    reraise=True,
)
def chatcompletions_with_backoff(model, messages, n, **kwargs):
    """Request ``n`` completions for ``messages``, retrying on failure.

    Models named in the list below are sent to ``client.completions.create``
    with the first message's ``''`` entry as the prompt; every other model is
    sent to ``client.chat.completions.create`` with ``messages`` unchanged.

    Args:
        model: Model identifier forwarded to the API.
        messages: List of message dicts; only ``messages[0]`` is read on the
            ``client.completions`` path.
        n: Number of completions to request.
        **kwargs: Extra request parameters forwarded verbatim.

    Returns:
        The response object returned by the selected ``create`` call.

    Raises:
        Exception: Whatever the final attempt raised, once all 15 attempts
            have failed.
    """
    if model in ["", "", ""]:
        return client.completions.create(
            model=model,
            prompt=messages[0][''],
            n=n,
            **kwargs)

    return client.chat.completions.create(
        model=model,
        messages=messages,
        n=n,
        **kwargs)



class GPTModel:
    """Text-generation wrapper around ``chatcompletions_with_backoff``.

    ``generate`` takes keyword arguments named like a local text-generation
    call and turns them into one or more API requests.
    """

    def __init__(self, model_name="", **kwargs):
        """Store the model name and load a tokenizer.

        Args:
            model_name: Model identifier sent with every request.
            **kwargs: Accepted for call compatibility and ignored.
        """
        self.model_name = model_name
        # Loaded eagerly; ``generate`` itself does not use the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained("")

    def generate(self, prompt: str,
                 num_samples=1,
                 max_new_tokens=50,
                 top_k=None,
                 top_p=1.0,
                 do_sample=True,
                 return_full_text=False,
                 temperature=1.0,
                 return_dict_in_generate=False,
                 batch_size=-1):
        """Collect ``num_samples`` completions for ``prompt``.

        Requests are issued in batches of at most ``batch_size`` completions
        until enough choices have been gathered. A request that ends in
        ``openai.InternalServerError`` is logged and issued again.

        Args:
            prompt: Prompt text; must be a ``str``.
            num_samples: Total number of completions to return.
            max_new_tokens: Forwarded to the API as ``max_tokens``.
            top_k: Unsupported; any truthy value raises.
            top_p: Forwarded to the API unchanged.
            do_sample: When false, ``temperature`` is sent as ``0``.
            return_full_text: Unsupported; a truthy value raises.
            temperature: Sampling temperature used when ``do_sample`` is true.
            return_dict_in_generate: Unsupported; a truthy value raises.
            batch_size: Maximum completions requested per API call; ``-1``
                means "request all ``num_samples`` at once".

        Returns:
            A list of whitespace-stripped completion strings, one per
            collected choice (empty when ``num_samples`` is not positive).
            A chat choice without content yields ``""``.

        Raises:
            NotImplementedError: If ``return_dict_in_generate``,
                ``return_full_text`` or ``top_k`` is truthy.
            ValueError: If ``prompt`` is not a ``str``, or if completions are
                requested with a ``batch_size`` below 1 (other than ``-1``).
        """
        if return_dict_in_generate or return_full_text or top_k:
            raise NotImplementedError("")

        if not isinstance(prompt, str):
            raise ValueError("")

        if batch_size == -1:
            batch_size = num_samples

        # A batch size below 1 would ask the API for zero (or a negative
        # number of) completions on every call, so the collection loop could
        # never make progress. Fail fast instead.
        if num_samples > 0 and batch_size < 1:
            raise ValueError(
                f"batch_size must be -1 or a positive integer, got {batch_size}"
            )

        message = {"": "",
                   "": prompt}
        completions = []

        while len(completions) < num_samples:
            # The final batch may be smaller than ``batch_size``.
            this_batch_size = min(batch_size, num_samples - len(completions))

            try:
                response = chatcompletions_with_backoff(
                    model=self.model_name,
                    messages=[message],
                    n=this_batch_size,
                    max_tokens=max_new_tokens,
                    top_p=top_p,
                    temperature=temperature if do_sample else 0,
                )
            except openai.InternalServerError as e:
                # Best-effort: log and issue the same batch again.
                logging.error(f"Internal Server Error: {e}")
                continue

            completions.extend(response.choices)

        # These models return choices with a ``text`` field; all others
        # return choices with a nested chat ``message``.
        if self.model_name in ["", "", ""]:
            return [choice.text.strip() for choice in completions]

        # ``message.content`` is nullable in the chat-completions response
        # schema; map a missing value to "" so the result stays a list of
        # ``str`` with one entry per requested sample.
        return [(choice.message.content or "").strip() for choice in completions]
