from tqdm import tqdm
from openai import OpenAI
import openai
import backoff, base64
import os, sys, pathlib, json, pdb
import concurrent.futures
import math


def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is opened in binary mode, read in full, base64-encoded and the
    resulting bytes are decoded as UTF-8 so the value can be embedded in a
    string (for example a ``data:`` URL).
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

class ParallelGPT:
    """Send chat-completion requests for many prompts concurrently.

    Each prompt (optionally paired with an image) is handled by its own task
    on a ``ThreadPoolExecutor``; results are returned in the order the
    prompts were given, regardless of completion order.
    """

    # Largest ``n`` sent in a single request; larger sample counts are split
    # into several requests. NOTE(review): presumably the API's per-request
    # cap on ``n`` -- confirm against the API reference.
    MAX_N_PER_REQUEST = 128

    # Exponential backoff between retries: base * 2**attempt seconds, capped.
    RETRY_BASE_DELAY = 1.0
    RETRY_MAX_DELAY = 60.0

    def __init__(self, model_id, api_key=None):
        """Create the API client.

        Args:
            model_id: Model name passed as ``model`` on every request.
            api_key: Optional explicit API key. When omitted, the module-level
                ``API_KEY`` is used; it is not defined in this class, so it
                must be provided elsewhere in the module.
        """
        self.model_id = model_id
        self.client = OpenAI(api_key=API_KEY if api_key is None else api_key)

    @staticmethod
    def _is_permanent_error(exc):
        """Return True when retrying ``exc`` cannot succeed."""
        # A TypeError comes from a bad call signature (e.g. an unknown keyword
        # argument); the same call will fail identically every time.
        if isinstance(exc, TypeError):
            return True
        # Duck-typed on ``status_code`` so HTTP status errors are recognised
        # without naming concrete exception classes. 408/409/429 are treated
        # as transient; every other 4xx is a client error that will not go
        # away on retry.
        status = getattr(exc, "status_code", None)
        return (
            isinstance(status, int)
            and 400 <= status < 500
            and status not in (408, 409, 429)
        )

    @classmethod
    def _split_sequences(cls, total):
        """Split ``total`` samples into per-request counts of at most ``MAX_N_PER_REQUEST``."""
        if total <= 0:
            return []
        full, remainder = divmod(total, cls.MAX_N_PER_REQUEST)
        return [cls.MAX_N_PER_REQUEST] * full + ([remainder] if remainder else [])

    def completion_with_backoff(self, **kwargs):
        """Call ``chat.completions.create`` and retry transient failures.

        Transient failures are retried indefinitely with capped exponential
        backoff. Errors that can never succeed (see ``_is_permanent_error``)
        are re-raised immediately instead of being retried forever.

        Args:
            **kwargs: Forwarded unchanged to ``self.client.chat.completions.create``.

        Returns:
            Whatever ``create`` returns.

        Raises:
            Exception: The original error, when it is classified as permanent.
        """
        import logging
        import time

        log = logging.getLogger(__name__)
        attempt = 0
        while True:
            try:
                return self.client.chat.completions.create(**kwargs)
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit are never swallowed by the retry loop.
            except Exception as exc:
                if self._is_permanent_error(exc):
                    raise
                # The exponent is clamped so the int -> float multiplication
                # cannot overflow after a very long run of failures.
                delay = min(
                    self.RETRY_MAX_DELAY,
                    self.RETRY_BASE_DELAY * 2 ** min(attempt, 16),
                )
                log.warning(
                    "chat completion failed (%s: %s); retry %d in %.1fs",
                    type(exc).__name__, exc, attempt + 1, delay,
                )
                if delay > 0:
                    time.sleep(delay)
                attempt += 1

    def generate(self, text, image=None, max_new_tokens=4096, temperature=0, num_return_sequences=1, progress_bar=True, **kwargs):
        """Generate ``num_return_sequences`` responses for every prompt.

        Args:
            text: A prompt string or a sequence of prompt strings.
            image: Optional image file path, or a sequence of paths with one
                entry per prompt. Each file is base64-encoded and attached to
                its prompt as a ``data:image/jpeg`` URL.
            max_new_tokens: Sent as ``max_tokens``.
            temperature: Sent as ``temperature``.
            num_return_sequences: Number of responses wanted per prompt. Counts
                above ``MAX_N_PER_REQUEST`` are split over several requests;
                a count <= 0 issues no request.
            progress_bar: Show a ``tqdm`` bar that advances as prompts finish.
            **kwargs: Extra arguments forwarded to the completion request.

        Returns:
            A dict with two lists, both in prompt order:
            ``'responses'`` holds, per prompt, the list of generated message
            contents; ``'completions'`` holds, per prompt, the completion
            object of the last request made for it (``None`` if no request
            was made).

        Raises:
            ValueError: If ``image`` is given and its length differs from
                the number of prompts.
        """
        if isinstance(text, str):
            text = [text]

        batch_sizes = self._split_sequences(num_return_sequences)

        if image is None:
            jobs = [(prompt, None) for prompt in text]
        else:
            if isinstance(image, str):
                image = [image]
            # Explicit check instead of ``assert``: asserts vanish under -O
            # and ``zip`` would then silently drop the unmatched entries.
            if len(text) != len(image):
                raise ValueError(
                    f"got {len(text)} prompts but {len(image)} images"
                )
            jobs = list(zip(text, image))

        def run_job(idx, prompt, image_path):
            # Text-only prompts send the bare string; image prompts send a
            # two-part content list (text part + image data URL).
            if image_path is None:
                content = prompt
            else:
                base64_image = encode_image(image_path)
                content = [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        },
                    },
                ]
            messages = [{"role": "user", "content": content}]

            responses = []
            completion = None  # stays None when there is nothing to request
            for n in batch_sizes:
                completion = self.completion_with_backoff(
                    model=self.model_id,
                    messages=messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    n=n,
                    **kwargs
                )
                responses += [completion.choices[i].message.content for i in range(n)]
            # idx travels with the result so prompt order can be restored.
            return (idx, responses, completion)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(run_job, idx, prompt, image_path)
                for idx, (prompt, image_path) in enumerate(jobs)
            ]
            finished = concurrent.futures.as_completed(futures)
            if progress_bar:
                finished = tqdm(finished, total=len(futures))
            outputs = [future.result() for future in finished]

        # as_completed yields in finish order; sort back into prompt order.
        outputs.sort(key=lambda item: item[0])
        return {
            'responses': [item[1] for item in outputs],
            'completions': [item[2] for item in outputs],
        }
