
import json
import os
import sys
import time
import argparse

from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import torch

from transformers import DetrImageProcessor, DetrForObjectDetection, GPT2TokenizerFast

from llava.datasets.fmri_vit3d_datasets import fMRIViT3dDataset
from llava.train import DataArguments

# Directory holding one sub-directory per dataset. Each sub-directory is
# expected to contain `<dataset>_captions.json`, and this script writes
# `<dataset>_gpt2_tokens.json` next to it.
DEFAULT_DATA_ROOT = '/mnt/NSD_dataset/datasets'

parser = argparse.ArgumentParser(
    description='Pre-tokenize dataset captions with the GPT-2 tokenizer of '
                'nlpconnect/vit-gpt2-image-captioning and save them as JSON.',
)
parser.add_argument(
    "--dataset",
    type=str,
    default='nsd',
    help="Dataset name; used both as the sub-directory and as the file prefix.",
)
parser.add_argument(
    "--data_root",
    type=str,
    default=DEFAULT_DATA_ROOT,
    help="Directory that contains the per-dataset sub-directories.",
)
parser.add_argument(
    "--max_length",
    type=int,
    default=77,
    help="Every caption is padded/truncated to exactly this many tokens.",
)

# NOTE(review): arguments are parsed at import time. This is kept so that code
# reading the module-level `args` keeps working, but it means importing this
# module from another program parses *that* program's argv.
args = parser.parse_args()


def tokenize_captions(tokenizer, captions, max_length=77):
    """Tokenize every caption to a fixed length.

    Args:
        tokenizer: Hugging Face tokenizer with a pad token configured.
        captions: Iterable of caption strings.
        max_length: Target length. Longer captions are truncated and shorter
            ones are padded up to it.

    Returns:
        A list with one dict per caption, holding ``input_ids`` and
        ``attention_mask`` as nested Python lists. Because a single string is
        encoded with ``return_tensors='pt'``, each list carries a leading
        batch dimension of 1. That dimension is kept so the output format
        matches earlier runs of this script.
    """
    tokens = []
    for caption in captions:
        encoded = tokenizer(
            caption,
            return_tensors='pt',
            padding='max_length',
            max_length=max_length,
            truncation=True,
        )
        tokens.append({
            'input_ids': encoded['input_ids'].tolist(),
            'attention_mask': encoded['attention_mask'].tolist(),
        })
    return tokens


if __name__ == '__main__':
    dataset_dir = os.path.join(args.data_root, args.dataset)
    captions_path = os.path.join(dataset_dir, f'{args.dataset}_captions.json')
    output_path = os.path.join(dataset_dir, f'{args.dataset}_gpt2_tokens.json')

    # Use a context manager so the input file handle is closed promptly.
    with open(captions_path, 'r', encoding='utf-8') as f:
        source = json.load(f)

    tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    # padding='max_length' raises if the tokenizer has no pad token, which is
    # the default for GPT-2 tokenizers. Fall back to EOS, as in the usual
    # GPT-2 recipe. This is a no-op when a pad token is already configured.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Each input item is expected to provide 'image_id' and a list of
    # 'captions'. The output keeps the image id alongside the tokens.
    result = [
        {
            'image_id': item['image_id'],
            'tokens': tokenize_captions(tokenizer, item['captions'], args.max_length),
        }
        for item in tqdm(source)
    ]

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4)

