"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict
import json
from torch.utils.data import Dataset
from lavis.datasets.datasets.base_dataset import BaseDataset
from PIL import Image
from torch.utils.data.dataloader import default_collate
# from lavis.datasets.datasets.base_prompt import *
from lavis.common.registry import registry

class __DisplMixin:
    """Mixin adding a ``displ_item`` helper for inspecting a dataset sample."""

    def displ_item(self, index):
        """Return an ordered summary (file, caption, processed image) of a sample.

        The host class must provide ``__getitem__`` returning a dict with an
        ``"image"`` entry, and an ``annotation`` sequence whose records contain
        ``"caption"`` plus either an ``"image"`` or an ``"image_path"`` key.

        Args:
            index: position of the sample in ``self.annotation``.

        Returns:
            OrderedDict with keys ``"file"``, ``"caption"`` and ``"image"``.

        Raises:
            KeyError: if the annotation has neither ``"image"`` nor
                ``"image_path"``, or lacks ``"caption"``.
        """
        sample, ann = self.__getitem__(index), self.annotation[index]

        # RefineDataset stores the path under "image_path" rather than
        # "image"; accept either so displ_item works for both layouts.
        file_path = ann["image"] if "image" in ann else ann["image_path"]

        return OrderedDict(
            {
                "file": file_path,
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )
 
class RefineDataset(Dataset, __DisplMixin):
    """Image-instruction dataset read from JSON annotation files.

    Each annotation file must contain a JSON array of objects carrying the
    keys ``"image"`` (path of the image file), ``"text_input"`` (the
    instruction) and ``"text_output"`` (the target text). Records are
    re-keyed internally as ``image_path`` / ``instruction`` / ``caption``
    and tagged with a string ``instance_id`` equal to their position.
    """

    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=None
    ):
        """
        Args:
            vis_processor: callable applied to each loaded RGB ``PIL.Image``.
                Must be set (here or via ``set_processors``) before indexing.
            text_processor: callable applied to the instruction and caption
                strings. Must be set before indexing.
            vis_root: currently unused; image paths are opened exactly as
                they appear in the annotation files.
            ann_paths: iterable of JSON annotation file paths. ``None`` (the
                default) means no files, i.e. an empty dataset.
        """
        self.annotation = []
        # ``None`` rather than a mutable ``[]`` default so no list object is
        # shared between calls.
        for ann_path in ann_paths or []:
            # Context manager closes the file even if parsing fails.
            with open(ann_path, "r") as fp:
                records = json.load(fp)
            for ann in records:
                self.annotation.append(
                    {
                        "image_path": ann["image"],
                        "instruction": ann["text_input"],
                        "caption": ann["text_output"],
                    }
                )

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        """Batch a list of sample dicts with torch's ``default_collate``."""
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        """Replace the visual and text processors used by ``__getitem__``."""
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        """Tag every annotation with its index (as a string) under ``key``."""
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)

    @staticmethod
    def _strip_image_token(text):
        """Remove the ``<image>`` placeholder together with its adjoining newline.

        ``str.strip`` must not be used for this: it removes any of the given
        *characters* from both ends, so e.g. ``"Describe the image"`` would
        lose its trailing ``"image"``. Only the exact token sequences
        ``"<image>\\n"`` and ``"\\n<image>"`` are removed here.
        """
        return text.replace("<image>\n", "").replace("\n<image>", "")

    def __getitem__(self, index):
        """Load, preprocess and return the sample at ``index``.

        Returns:
            dict with ``"image"`` (output of ``vis_processor``),
            ``"text_input"`` / ``"text_output"`` (outputs of
            ``text_processor``), ``"image_path"`` and ``"instance_id"``.
        """
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        image_path = ann["image_path"]
        # convert() returns a new image, so the source file can be closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.vis_processor(image)

        instruction_text = self._strip_image_token(ann["instruction"])
        caption_text = ann["caption"]

        instruction = self.text_processor(instruction_text)
        caption = self.text_processor(caption_text)

        return {
            "image": image,
            "text_input": instruction,
            "text_output": caption,
            "image_path": image_path,
            "instance_id": ann["instance_id"],
        }

