import os
import sys

import json
import argparse
from tqdm import tqdm
from transformers import SamModel, SamProcessor
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GPT2TokenizerFast
import torch

# Command-line options for the captioning run.
parser = argparse.ArgumentParser()
# Dataset name; used to build the dataset directory and image/output file names.
parser.add_argument("--dataset", type=str, default='nsd')
# Number of captions to generate for each image.
parser.add_argument("--caption_per_image", type=int, default=8)
# Index of the CUDA device to run on.
parser.add_argument("--gpu", type=int, default=0, help="gpu")

args = parser.parse_args()

# Target device string: the selected CUDA GPU when CUDA is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = f"cuda:{args.gpu}"
else:
    device = "cpu"


if __name__ == '__main__':
    # Caption every image of the selected dataset with BLIP-2 and write the
    # result to `<root_dir>/<dataset>_captions.json` as a list of
    # {'image_id': int, 'captions': [str, ...]} records.

    # Load the model and processor.
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        load_in_8bit=False,
        # Place the whole model on `device` (not the raw GPU index) so the
        # CPU fallback computed at module level is honoured here as well.
        device_map={"": device},
        torch_dtype=torch.bfloat16,
    )

    root_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}'
    # Images are addressed by index 0..total-1.
    # NOTE(review): assumes the images directory holds exactly the
    # contiguously numbered `<dataset>_image_NNNNNN.png` files -- confirm.
    total = len(os.listdir(f'{root_dir}/images'))
    batch_size = 16
    per_image = args.caption_per_image

    generate_kwargs = {
        'num_beams': 2 * per_image,
        # NOTE(review): sampling is not enabled, so temperature presumably
        # has no effect on this beam search -- confirm before relying on it.
        'temperature': 0.7,
        'num_return_sequences': per_image,
        'max_new_tokens': 77,
    }
    if per_image > 1:
        # Diverse (group) beam search: one beam group per requested caption.
        # Only requested for more than one caption, because a diversity
        # penalty with a single beam group is presumably rejected by
        # `generate` -- TODO confirm against the installed transformers.
        generate_kwargs['num_beam_groups'] = per_image
        generate_kwargs['diversity_penalty'] = 0.5

    captions = []

    for start in tqdm(range(0, total, batch_size)):
        image_ids = list(range(start, min(start + batch_size, total)))
        images = []
        for image_id in image_ids:
            image_path = f'{root_dir}/images/{args.dataset}_image_{image_id:06}.png'
            # `convert` returns a new in-memory image, so the file handle can
            # be released as soon as the conversion is done.
            with Image.open(image_path) as raw_image:
                images.append(raw_image.convert('RGB'))

        # Cast the inputs to the dtype the model was loaded with instead of a
        # hard-coded float16, which does not match the bfloat16 weights.
        inputs = processor(images, return_tensors="pt").to(device, model.dtype)
        generated_ids = model.generate(**inputs, **generate_kwargs)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

        # The decoded captions are consumed as consecutive runs of
        # `per_image` entries, one run per input image, in batch order.
        for offset, image_id in enumerate(image_ids):
            captions.append({
                'image_id': image_id,
                'captions': generated_text[offset * per_image: (offset + 1) * per_image]
            })

    output_file = f'{root_dir}/{args.dataset}_captions.json'
    with open(output_file, 'w') as f:
        json.dump(captions, f, indent=4)
