import os
import json

from dataset.base import BaseDataset

class LLaVAPretrainDataset(BaseDataset):
    """Image/question/answer samples read from a LLaVA annotation JSON file.

    The annotation file is parsed once at construction time and kept in
    ``self.ann``; ``get_data`` turns every entry that references an image
    into a flat ``img_path`` / ``question`` / ``label`` record.

    NOTE(review): the class is named "Pretrain" but the default annotation
    file is ``llava_v1_5_mix665k.json`` -- confirm this is the intended file.
    """

    def __init__(
        self,
        ann_path: str = "./data/llava_v1_5_mix665k.json",
        img_root: str = "./data/",
    ) -> None:
        """Load the annotation file into memory.

        Args:
            ann_path: Path of the annotation JSON file. The default is the
                location that used to be hard-coded.
            img_root: Directory that each entry's ``image`` value is joined
                onto to build ``img_path``.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        super().__init__()
        # Use a context manager so the handle is closed deterministically
        # (it used to be left open until garbage collection), and decode as
        # UTF-8 explicitly instead of relying on the platform default.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.ann = json.load(ann_file)
        self.img_root = img_root

    def get_data(self) -> tuple:
        """Build one sample per annotation entry that has an ``image`` key.

        Only the first two conversation turns of an entry are used: turn 0
        supplies the question (with the ``<image>`` placeholder removed and
        surrounding whitespace stripped) and turn 1 supplies the label. Any
        later turns are ignored, and entries without an ``image`` key are
        skipped.

        Returns:
            A ``(data, [])`` tuple where ``data`` is a list of dicts with the
            keys ``img_path``, ``question`` and ``label``. The second element
            is always an empty list; presumably a placeholder expected by
            the ``BaseDataset`` interface -- TODO confirm.

        Raises:
            KeyError, IndexError: If an image entry has no ``conversations``
                list or fewer than two turns.
        """
        data = []
        for ins in self.ann:
            if 'image' not in ins:
                # Entries without an image reference are not usable here.
                continue
            turns = ins['conversations']
            data.append(
                {
                    "img_path": os.path.join(self.img_root, ins['image']),
                    "question": turns[0]['value'].replace("<image>", "").strip(),
                    "label": turns[1]['value'],
                }
            )
        return data, []