import os
import sys
import time
from functools import partial
import argparse
import numpy as np

import openai
from openai import OpenAI

from multiprocessing import Pool
import json

from tqdm import tqdm


# Read the key from the environment so real credentials never need to be
# committed to source; the placeholder is kept as the fallback value.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")


def generate(system_prompt, instance, blip_caption, coco_caption, output_dir,
             max_retries=1000, max_retry_interval=60):
    """Query the chat model for one sample and write its answer to disk.

    The user message is built by ``format_prompt`` and sent together with
    ``system_prompt`` to ``gpt-3.5-turbo-0125`` at temperature 0. Transient
    failures are retried with a linearly growing sleep that is capped at
    ``max_retry_interval`` seconds. Errors a retry cannot fix (bad credentials,
    permission, malformed request, unknown model) abort immediately.

    Args:
        system_prompt: Text sent as the system message.
        instance: Per-image record forwarded to ``format_prompt``.
        blip_caption: Per-image BLIP caption record forwarded to ``format_prompt``.
        coco_caption: Per-image COCO caption record, or ``None``.
        output_dir: Path of the *file* the answer is written to (despite the name).
        max_retries: Maximum number of API attempts.
        max_retry_interval: Upper bound, in seconds, on the sleep between attempts.

    Returns:
        None. Success is signalled only by ``output_dir`` existing with content.
    """
    text = format_prompt(instance, blip_caption, coco_caption)

    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    )

    retry_interval = 1
    for _ in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model="gpt-3.5-turbo-0125",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": text},
                ],
                stream=False,
                temperature=0,
            )
            # message.content is Optional in the v1 SDK; a None value raises
            # AttributeError here and is retried like any other error.
            answer = completion.choices[0].message.content.strip()
        except (openai.AuthenticationError, openai.PermissionDeniedError,
                openai.BadRequestError, openai.NotFoundError) as e:
            # Retrying cannot fix these, so fail fast instead of looping.
            print("Error：", e)
            print('Giving up', output_dir)
            return
        except (openai.APITimeoutError, TimeoutError):
            # The SDK signals timeouts with APITimeoutError, not builtin TimeoutError.
            print("Time Out", output_dir)
        except Exception as e:
            print("Error：", e)
        else:
            # Write to a temp file and rename, so an interrupted run never leaves a
            # partial non-empty file that a later run would treat as finished.
            tmp_path = f'{output_dir}.tmp'
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    f.write(answer)
                os.replace(tmp_path, output_dir)
            except OSError as e:
                # Report disk errors instead of re-issuing (paid) API requests.
                print("Error：", e)
                print('Giving up', output_dir)
            return

        print('Retrying....')
        retry_interval = min(retry_interval + 1, max_retry_interval)
        time.sleep(retry_interval)

    print('Giving up', output_dir)


def format_prompt(instance, blip_caption, coco_caption):
    """Build the user message describing one image for the chat model.

    The message lists every COCO caption (when available), then every BLIP
    caption, one per line, stripped of surrounding whitespace. A blank line
    follows, then one ``label: [...]`` line per annotated object. Box values
    are rounded to 2 decimals and negative values are clamped to 0.

    Args:
        instance: Dict with ``image_id`` and ``instances``; each instance has
            ``label_str`` and ``bbox``.
        blip_caption: Dict with ``image_id`` and a ``captions`` list.
        coco_caption: Dict with ``image_id`` and a ``coco_caption`` list, or
            ``None`` when no COCO captions exist.

    Returns:
        The prompt text; every line, including the last, ends with ``'\\n'``.

    Raises:
        ValueError: If the records do not share the same ``image_id``.
    """
    image_id = instance['image_id']
    # Explicit checks (not ``assert``) so they still run under ``python -O``;
    # a mismatch would pair captions with another image's boxes.
    if blip_caption['image_id'] != image_id:
        raise ValueError("The image id is not equal.")
    if coco_caption is not None and coco_caption['image_id'] != image_id:
        raise ValueError("The image id is not equal.")

    captions = list(blip_caption['captions'])
    if coco_caption is not None:
        captions = list(coco_caption['coco_caption']) + captions
    # str.strip() already removes '\n' and '\r' along with other whitespace.
    caption_text = ''.join(f'{cap.strip()}\n' for cap in captions)

    object_lines = []
    for inst in instance['instances']:
        bbox = np.maximum(np.round(inst['bbox'], 2), 0).tolist()
        object_lines.append(f"{inst['label_str']}: {bbox}\n")

    return caption_text + '\n' + ''.join(object_lines)


def _load_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main():
    """Generate one GPT answer file per dataset sample, in parallel.

    Reads the system prompt and the per-sample instance/caption JSON files.
    It then dispatches ``generate`` to a process pool for every sample whose
    output file is missing or empty. Existing non-empty outputs are skipped, so
    the script can be re-run to resume an interrupted job.

    Raises:
        ValueError: If the instance and caption files have different lengths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='nsd')
    parser.add_argument("--prompt_type", type=str, default='conversation')
    parser.add_argument(
        "--data_root",
        type=str,
        default='/mnt/NSD_dataset/datasets',
        help="Directory that contains one sub-directory per dataset.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=32,
        help="Number of worker processes issuing API requests.",
    )
    args = parser.parse_args()

    root_dir = f'{args.data_root}/{args.dataset}'

    # NOTE: this relative path is resolved against the current working
    # directory, not against this script's location.
    prompt_path = f'../../playground/data/prompts/{args.prompt_type}/system_message.txt'
    with open(prompt_path, 'r', encoding='utf-8') as f:
        system_prompt = f.read()

    instances = _load_json(f'{root_dir}/{args.dataset}_instances.json')
    blip_captions = _load_json(f'{root_dir}/{args.dataset}_captions.json')

    coco_path = f'{root_dir}/{args.dataset}_coco_captions.json'
    if os.path.isfile(coco_path):
        coco_captions = _load_json(coco_path)
        if len(coco_captions) != len(blip_captions):
            raise ValueError("The number of coco and blip captions are not equal.")
    else:
        # COCO captions are optional; pair every sample with None instead.
        coco_captions = [None] * len(blip_captions)

    if len(instances) != len(blip_captions):
        raise ValueError("The number of instances and captions are not equal.")

    output_dir = f'{root_dir}/nsd_gpt_conversation/{args.dataset}_{args.prompt_type}'
    os.makedirs(output_dir, exist_ok=True)

    # Pool is listed last so it is torn down before the progress bar closes.
    with tqdm(total=len(instances)) as pbar, Pool(args.num_workers) as pool:

        def on_error(exc, fname):
            # Called by the pool when generate() raises; without an
            # error_callback the exception would be silently discarded.
            print('Failed', fname, exc)
            pbar.update()

        samples = zip(instances, blip_captions, coco_captions)
        for index, (instance, blip_caption, coco_caption) in enumerate(samples):
            output_fname = f'{output_dir}/{args.dataset}_{args.prompt_type}_{index:06}.txt'

            # Resume support: skip samples that already have a non-empty answer.
            if os.path.isfile(output_fname) and os.stat(output_fname).st_size != 0:
                pbar.update()
                continue

            pool.apply_async(
                func=generate,
                args=(system_prompt, instance, blip_caption, coco_caption, output_fname),
                callback=lambda _result: pbar.update(),
                error_callback=partial(on_error, fname=output_fname),
            )

        # Pool.__exit__ only terminates workers, so close and join explicitly
        # to wait for every queued task to finish.
        pool.close()
        pool.join()


# Run the pipeline only when executed as a script. The guard also stops
# multiprocessing worker processes from starting a new run when they import
# this module, as they do under the "spawn" start method.
if __name__ == '__main__':
    main()
