from utils import *
from configs import *
import google.generativeai as palm
import openai
from tqdm import tqdm
import warnings
import concurrent.futures
from requests.exceptions import ReadTimeout
import os
import random
import time
# Show every warning occurrence (e.g. each retry message) instead of deduplicating.
warnings.simplefilter(action="always")

class LmOutput:
    """Result of one model call: the completion text plus a usage figure.

    ``str()`` of an instance is just its completion text.
    """

    def __init__(self, text, usage):
        # text: completion returned by the model; usage: cost figure for the call
        # (its unit depends on which provider produced it).
        self.text, self.usage = text, usage

    def __str__(self):
        """Return the completion text unchanged."""
        return self.text

class TimeoutException(Exception):
    """Raised by ``run_with_timeout`` when the wrapped call does not finish in time."""

def run_with_timeout(func, timeout, *args, **kwargs):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(func, *args, **kwargs)
        try:
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            raise TimeoutException("Function exceeded the timeout limit")

class LM:
    """Send prompts to an OpenAI or PaLM text-completion model.

    The provider is chosen from the prefix of ``args.model`` (``"gpt..."`` or
    ``"palm..."``). The concrete engine name is fixed inside each provider
    method, so the rest of ``args.model`` is not forwarded to the API.
    """

    # Seconds to wait for one OpenAI request before treating it as failed.
    _OPENAI_TIMEOUT_SECONDS = 60
    # palm_completion gives up (returning an empty completion) once the number
    # of failed attempts exceeds this value.
    _PALM_MAX_FAILURES = 3000

    def __init__(self, args):
        """Store the model name and the per-completion output-length limit.

        Args:
            args: Config object exposing ``model`` and ``lm_output_max_length``.
        """
        self.model = args.model
        self.max_length = args.lm_output_max_length

    @staticmethod
    def _backoff_delay(attempt, base, cap=30.0):
        """Return an exponential retry delay of ``base ** attempt`` seconds, clamped to ``cap``.

        ``min`` makes the delay saturate at ``cap``; a modulo would wrap it back
        to near-zero values once ``base ** attempt`` passes ``cap``.
        """
        try:
            return min(cap, base ** attempt)
        except OverflowError:
            # float ** int raises (rather than returning inf) once the result
            # exceeds the float range; the delay is saturated by then anyway.
            return cap

    def process(self, prompts, desc="LM Completion"):
        """Complete each prompt in order and collect the cleaned outputs.

        Args:
            prompts: Iterable of prompt strings.
            desc: Progress-bar label passed to ``tqdm``.

        Returns:
            dict with ``"output_list"`` (one completion per prompt, cut at the
            first ``"Question:"`` and stripped) and ``"api_usage"`` (sum of
            ``LmOutput.usage`` over all calls).
        """
        api_usage = 0
        output_list = []
        for prompt in tqdm(prompts, desc=desc, leave=True):
            output = self.completion(input_text=prompt)
            # Keep only the text before any "Question:" the model appended
            # (presumably a continuation of a few-shot prompt format).
            output_list.append(output.text.split("Question:")[0].strip())
            api_usage += output.usage
        return {"output_list": output_list, "api_usage": api_usage}

    def completion(self, input_text, temperature=0.0):
        """Dispatch one prompt to the provider selected by ``self.model``.

        Args:
            input_text: Prompt string.
            temperature: Sampling temperature; defaults to greedy (0.0).

        Returns:
            LmOutput with the completion text and usage.

        Raises:
            NotImplementedError: if ``self.model`` starts with neither "gpt" nor "palm".
        """
        if self.model.startswith("gpt"):
            return self.openai_completion(input_text, temperature=temperature)
        if self.model.startswith("palm"):
            return self.palm_completion(input_text, temperature=temperature)
        raise NotImplementedError(f"Unsupported model: {self.model}")

    def openai_completion(self, input_text, temperature=0.0):
        """Complete ``input_text`` with OpenAI ``text-davinci-003`` (legacy Completion API).

        Every exception, including a request exceeding the timeout, is reported
        as a warning and the request is retried with capped exponential backoff.
        There is no attempt limit, so this retries until a call succeeds.

        Returns:
            LmOutput whose ``usage`` is the response's ``total_tokens``.
        """
        openai.api_key = os.getenv("OPENAI_API_KEY")
        request_kwargs = {
            'model': "text-davinci-003",
            'prompt': input_text,
            'max_tokens': self.max_length,
            'temperature': temperature,
            'stop': None,
        }
        attempt = 0
        while True:
            try:
                response = run_with_timeout(
                    openai.Completion.create, self._OPENAI_TIMEOUT_SECONDS, **request_kwargs
                )
                return LmOutput(
                    response["choices"][0]["text"], response["usage"]["total_tokens"]
                )
            except Exception as e:
                attempt += 1
                warnings.warn(f"Error type: {type(e)}, Error message: {e}, Attempt: {attempt}")
                time.sleep(self._backoff_delay(attempt, 1.2))

    def palm_completion(self, input_text, temperature=0.0):
        """Complete ``input_text`` with PaLM ``models/text-bison-001``.

        Failures are warned about and retried with capped exponential backoff.
        After more than ``_PALM_MAX_FAILURES`` failures an empty completion is
        returned. If the API returns no text and reports filters, a fixed
        placeholder string is returned instead. If it returns no text and no
        filters, the completion is empty.

        Returns:
            LmOutput whose ``usage`` is approximated as the character count of
            prompt plus completion (not tokens).
        """
        palm.configure(api_key=os.getenv("PALM_API_KEY"))
        engine = "models/text-bison-001"
        attempt = 0
        while True:
            try:
                response = palm.generate_text(
                    model=engine,
                    prompt=input_text,
                    temperature=temperature,
                    max_output_tokens=self.max_length,
                    # PALM_SAFETY_SETTINGS presumably comes from `configs`; verify.
                    safety_settings=PALM_SAFETY_SETTINGS,
                )
                output_text = response.result
                if output_text is None:
                    if len(response.filters) > 0:
                        warnings.warn("This sample is considered harmful content.")
                        output_text = "This sample is considered harmful content."
                    else:
                        # No text and no filter reason: use an empty completion
                        # rather than letting None reach len() below.
                        output_text = ""
                break
            except Exception as e:
                attempt += 1
                warnings.warn(f"Error type: {type(e)}, Error message: {e}, Attempt: {attempt}")
                time.sleep(self._backoff_delay(attempt, 1.5))
                if attempt > self._PALM_MAX_FAILURES:
                    output_text = ""
                    break
        return LmOutput(output_text, len(input_text) + len(output_text))