"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import datasets as ds
from lavis.common.registry import registry
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from lavis.datasets.datasets.nlvr_datasets import NLVRDataset, NLVREvalDataset
from lavis.datasets.datasets.snli_ve_datasets import SNLIVisualEntialmentDataset
from lavis.datasets.datasets.mllmu_datasets import MllmuClassificationDataset, MllmuClassificationConsUnlearnDataset, MllmuClassificationEvalMiniDataset


@registry.register_builder("nlvr")
class NLVRBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``nlvr``.

    Purely declarative: it points the base builder at the NLVR default
    config file and names the dataset classes assigned to
    ``train_dataset_cls`` and ``eval_dataset_cls``.
    """

    # Config path keyed by config type.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/nlvr/defaults.yaml"}

    eval_dataset_cls = NLVREvalDataset
    train_dataset_cls = NLVRDataset


@registry.register_builder("snli_ve")
class SNLIVisualEntailmentBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``snli_ve``.

    Purely declarative: the same dataset class is assigned to both
    ``train_dataset_cls`` and ``eval_dataset_cls``.
    """

    # Config path keyed by config type.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/snli_ve/defaults.yaml"}

    # One class serves both slots.
    train_dataset_cls = eval_dataset_cls = SNLIVisualEntialmentDataset

@registry.register_builder("mllmu_classification")
class MllmuClassificationBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``mllmu_classification``.

    The raw data is fetched with ``datasets.load_dataset`` from the
    ``MLLMMU/MLLMU-Bench`` hub repository. The ``annotations`` section of the
    build config is consulted only for its keys, i.e. to decide which of the
    ``train`` / ``val`` / ``test`` splits to create.

    The hub repository id and config names are class attributes so that a
    variant builder can override them without copying ``build()``.
    """

    DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults_classification.yaml"}

    # Hub repository and the config names used for the train / eval splits.
    HUB_REPO_ID = "MLLMMU/MLLMU-Bench"
    TRAIN_CONFIG_NAME = "Full_Set"
    EVAL_CONFIG_NAME = "Test_Set"

    def build_datasets(self):
        """Log the start of the build and delegate to ``build()``.

        Returns:
            dict: split name -> dataset object, as produced by ``build()``.
        """
        logging.info("Building datasets for MLLMU-Bench from Huggingface.co ...")
        return self.build()

    def build(self):
        """Create one dataset object per configured split.

        Keys of ``config.build_info.annotations`` other than ``train``,
        ``val`` and ``test`` are ignored. The ``train`` split is wrapped in
        ``MllmuClassificationDataset`` with the "train" processors; every
        other split is wrapped in ``MllmuClassificationEvalMiniDataset`` with
        the "eval" processors.

        Returns:
            dict: split name -> dataset object.
        """
        self.build_processors()

        # TODO: maybe the annotations section can be dropped; only its keys
        # (the split names) are used here.
        split_names = self.config.build_info.annotations.keys()

        # Each hub config is loaded at most once per build() call, so the
        # "val" and "test" splits share one loaded eval set instead of
        # triggering two identical load_dataset calls.
        raw_by_config = {}

        def load_raw(config_name):
            if config_name not in raw_by_config:
                raw_by_config[config_name] = ds.load_dataset(self.HUB_REPO_ID, config_name)
            return raw_by_config[config_name]

        datasets = dict()
        for split in split_names:
            if split not in ("train", "val", "test"):
                continue

            is_train = split == "train"

            # Processors are registered under "train" / "eval", not per split.
            processor_key = "train" if is_train else "eval"
            vis_processor = self.vis_processors[processor_key]
            text_processor = self.text_processors[processor_key]

            if is_train:
                datasets[split] = MllmuClassificationDataset(
                    vis_processor, text_processor, load_raw(self.TRAIN_CONFIG_NAME)
                )
            else:
                datasets[split] = MllmuClassificationEvalMiniDataset(
                    vis_processor, text_processor, load_raw(self.EVAL_CONFIG_NAME)
                )

        return datasets


@registry.register_builder("mllmu_classification_cons_unlearn")
class MllmuClassificationConsUnlearnBuilder(MllmuClassificationBuilder):
    """Dataset builder registered as ``mllmu_classification_cons_unlearn``.

    Split layout produced by ``build()``:

    * ``train`` -- ``MllmuClassificationConsUnlearnDataset`` built from the
      dataset stored on disk at ``UNLEARN_DATASET_PATH`` together with the
      hub config ``RETAIN_CONFIG_NAME``.
    * ``val``   -- ``MllmuClassificationDataset`` over the hub config
      ``FORGET_CONFIG_NAME``.
    * ``test``  -- ``MllmuClassificationEvalMiniDataset`` over the hub config
      ``TEST_CONFIG_NAME``.

    The data sources are class attributes so that a variant (for example a
    different forget/retain ratio) can override them without copying
    ``build()``.
    """

    DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults_classification_cons_unlearn.yaml"}

    # Hub repository and config names.
    HUB_REPO_ID = "MLLMMU/MLLMU-Bench"
    FORGET_CONFIG_NAME = "forget_5"
    RETAIN_CONFIG_NAME = "retain_95"
    TEST_CONFIG_NAME = "Test_Set"
    # Dataset read with datasets.load_from_disk for the train split.
    UNLEARN_DATASET_PATH = "./output/virtual_unlearn_datasets/mllmu_unlearn_5"

    def build(self):
        """Create one dataset object per configured split.

        Keys of ``config.build_info.annotations`` other than ``train``,
        ``val`` and ``test`` are ignored. Raw data is loaded only for the
        splits that are actually built, so a config without, say, a ``train``
        entry never touches the on-disk dataset or the retain set.

        Returns:
            dict: split name -> dataset object.
        """
        self.build_processors()

        # TODO: maybe the annotations section can be dropped; only its keys
        # (the split names) are used here.
        split_names = self.config.build_info.annotations.keys()

        datasets = dict()
        for split in split_names:
            if split not in ("train", "val", "test"):
                continue

            is_train = split == "train"

            # Processors are registered under "train" / "eval", not per split.
            processor_key = "train" if is_train else "eval"
            vis_processor = self.vis_processors[processor_key]
            text_processor = self.text_processors[processor_key]

            if is_train:
                unlearn_raw = ds.load_from_disk(self.UNLEARN_DATASET_PATH)
                retain_raw = ds.load_dataset(self.HUB_REPO_ID, self.RETAIN_CONFIG_NAME)
                datasets[split] = MllmuClassificationConsUnlearnDataset(
                    vis_processor, text_processor, unlearn_raw, retain_raw
                )
            elif split == "val":
                forget_raw = ds.load_dataset(self.HUB_REPO_ID, self.FORGET_CONFIG_NAME)
                datasets[split] = MllmuClassificationDataset(
                    vis_processor, text_processor, forget_raw
                )
            else:
                test_raw = ds.load_dataset(self.HUB_REPO_ID, self.TEST_CONFIG_NAME)
                datasets[split] = MllmuClassificationEvalMiniDataset(
                    vis_processor, text_processor, test_raw
                )

        return datasets

# @registry.register_builder("mllmu_classification_f5")
# class MllmuClassificationForget5Builder(MllmuClassificationBuilder):
#
#     def build(self):
#         """
#         Create by split datasets inheriting torch.utils.data.Datasets.
#
#         # build() can be dataset-specific. Overwrite to customize.
#         """
#         self.build_processors()
#
#         build_info = self.config.build_info
#
#         ann_info = build_info.annotations  # todo: maybe can remove info
#
#         datasets = dict()
#         for split in ann_info.keys():
#             if split not in ["train", "val", "test"]:
#                 continue
#
#             is_train = split == "train"
#
#             # processors
#             vis_processor = (
#                 self.vis_processors["train"]
#                 if is_train
#                 else self.vis_processors["eval"]
#             )
#             text_processor = (
#                 self.text_processors["train"]
#                 if is_train
#                 else self.text_processors["eval"]
#             )
#
#             if is_train:
#                 raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "forget_5")
#                 datasets[split] = MllmuClassificationDataset(vis_processor, text_processor, raw_dataset)
#             else:
#                 raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "Test_Set")
#                 datasets[split] = MllmuClassificationEvalMiniDataset(vis_processor, text_processor, raw_dataset)
#
#         return datasets
#
#     DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults_classification.yaml"}
#
#
# @registry.register_builder("mllmu_classification_f10")
# class MllmuClassificationForget10Builder(MllmuClassificationBuilder):
#
#     def build(self):
#         """
#         Create by split datasets inheriting torch.utils.data.Datasets.
#
#         # build() can be dataset-specific. Overwrite to customize.
#         """
#         self.build_processors()
#
#         build_info = self.config.build_info
#
#         ann_info = build_info.annotations  # todo: maybe can remove info
#
#         datasets = dict()
#         for split in ann_info.keys():
#             if split not in ["train", "val", "test"]:
#                 continue
#
#             is_train = split == "train"
#
#             # processors
#             vis_processor = (
#                 self.vis_processors["train"]
#                 if is_train
#                 else self.vis_processors["eval"]
#             )
#             text_processor = (
#                 self.text_processors["train"]
#                 if is_train
#                 else self.text_processors["eval"]
#             )
#
#             if is_train:
#                 raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "forget_10")
#                 datasets[split] = MllmuClassificationDataset(vis_processor, text_processor, raw_dataset)
#             else:
#                 raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "Test_Set")
#                 datasets[split] = MllmuClassificationEvalMiniDataset(vis_processor, text_processor, raw_dataset)
#
#         return datasets
#
#     DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults_classification.yaml"}
#
#
# @registry.register_builder("mllmu_classification_f15")
# class MllmuClassificationForget15Builder(MllmuClassificationBuilder):
#
#     def build(self):
#         """
#         Create by split datasets inheriting torch.utils.data.Datasets.
#
#         # build() can be dataset-specific. Overwrite to customize.
#         """
#         self.build_processors()
#
#         build_info = self.config.build_info
#
#         ann_info = build_info.annotations  # todo: maybe can remove info
#
#         datasets = dict()
#         for split in ann_info.keys():
#             if split not in ["train", "val", "test"]:
#                 continue
#
#             is_train = split == "train"
#
#             # processors
#             vis_processor = (
#                 self.vis_processors["train"]
#                 if is_train
#                 else self.vis_processors["eval"]
#             )
#             text_processor = (
#                 self.text_processors["train"]
#                 if is_train
#                 else self.text_processors["eval"]
#             )
#
#             if is_train:
#                 raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "forget_15")
#                 datasets[split] = MllmuClassificationDataset(vis_processor, text_processor, raw_dataset)
#             else:
#                 raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "Test_Set")
#                 datasets[split] = MllmuClassificationEvalMiniDataset(vis_processor, text_processor, raw_dataset)
#
#         return datasets
#
#     DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults_classification.yaml"}