"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import random

from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from lavis.datasets.datasets.retrieval_datasets import (
    RetrievalDataset,
    RetrievalEvalDataset,
    VideoRetrievalDataset,
    VideoRetrievalEvalDataset,
    MllmuRetrievalDataset,
    MllmuRetrievalEvalDataset,
    MllmuMixedRetrievalDataset,
    MllmuMixedRetrievalEvalDataset,
)
from lavis.datasets.datasets.mllmu_datasets import MllmuRetrievalConsUnlearnDataset, MllmuMatchingEvalDataset
from lavis.datasets.datasets.mllmu_flickr_mixed_datasets import MllmuFlickrMixConsUnlearnDataset, MllmuFlickrMixConsUnlearnEvalDataset
import lavis.common.utils as utils

from lavis.common.registry import registry

import datasets as ds
import json
import re
import os
import warnings


@registry.register_builder("msrvtt_retrieval")
class MSRVTTRetrievalBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``msrvtt_retrieval``.

    Only declares the dataset classes for the training / evaluation splits
    and the default configuration path; no build logic is overridden, so
    everything else comes from ``BaseDatasetBuilder``.
    """

    # Path of the default YAML configuration for this dataset.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/msrvtt/defaults_ret.yaml",
    }

    # Dataset classes for the training split and for the evaluation splits.
    train_dataset_cls = VideoRetrievalDataset
    eval_dataset_cls = VideoRetrievalEvalDataset


@registry.register_builder("didemo_retrieval")
class DiDeMoRetrievalBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``didemo_retrieval``.

    Only declares the dataset classes for the training / evaluation splits
    and the default configuration path; no build logic is overridden, so
    everything else comes from ``BaseDatasetBuilder``.
    """

    # Path of the default YAML configuration for this dataset.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/didemo/defaults_ret.yaml",
    }

    # Dataset classes for the training split and for the evaluation splits.
    train_dataset_cls = VideoRetrievalDataset
    eval_dataset_cls = VideoRetrievalEvalDataset


@registry.register_builder("coco_retrieval")
class COCORetrievalBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``coco_retrieval``.

    Only declares the dataset classes for the training / evaluation splits
    and the default configuration path; no build logic is overridden, so
    everything else comes from ``BaseDatasetBuilder``.
    """

    # Path of the default YAML configuration for this dataset.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco/defaults_ret.yaml",
    }

    # Dataset classes for the training split and for the evaluation splits.
    train_dataset_cls = RetrievalDataset
    eval_dataset_cls = RetrievalEvalDataset


@registry.register_builder("flickr30k")
class Flickr30kBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``flickr30k``.

    Only declares the dataset classes for the training / evaluation splits
    and the default configuration path; no build logic is overridden, so
    everything else comes from ``BaseDatasetBuilder``.
    """

    # Path of the default YAML configuration for this dataset.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/flickr30k/defaults.yaml",
    }

    # Dataset classes for the training split and for the evaluation splits.
    train_dataset_cls = RetrievalDataset
    eval_dataset_cls = RetrievalEvalDataset

@registry.register_builder("flickr30k_virtual")
class Flickr30kVirtualBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``flickr30k_virtual``.

    Uses the same dataset classes as the ``flickr30k`` builder but points at
    its own default configuration file. No build logic is overridden, so
    everything else comes from ``BaseDatasetBuilder``.
    """

    # Path of the default YAML configuration for this dataset.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/flickr30k_virtual/defaults.yaml",
    }

    # Dataset classes for the training split and for the evaluation splits.
    train_dataset_cls = RetrievalDataset
    eval_dataset_cls = RetrievalEvalDataset

@registry.register_builder("mllmu")
class MllmuBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``mllmu``.

    The raw data is obtained with ``datasets.load_dataset`` from the
    ``Full_Set`` configuration of ``MLLMMU/MLLMU-Bench``; the same raw
    dataset object is handed to every split that gets built.
    """

    # Path of the default YAML configuration for this dataset.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults.yaml"}

    def build_datasets(self):
        """Log the start of the build and return whatever :meth:`build` returns."""
        logging.info("Building datasets for MLLMU-Bench from Huggingface.co ...")
        return self.build()

    def build(self):
        """
        Create one dataset per split listed in the config annotations.

        Only the split names ``train``, ``val`` and ``test`` are honoured;
        any other key under ``build_info.annotations`` is ignored.

        Returns:
            dict mapping split name to ``MllmuRetrievalDataset`` (for
            ``train``) or ``MllmuRetrievalEvalDataset`` (for ``val`` / ``test``).
        """
        self.build_processors()

        annotations = self.config.build_info.annotations
        built = {}

        print("Loading dataset splits ...")
        raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "Full_Set")

        for split in annotations.keys():
            if split not in ("train", "val", "test"):
                continue

            # The training split uses the "train" processors and the training
            # dataset class; every other split uses the "eval" counterparts.
            if split == "train":
                processor_key, dataset_cls = "train", MllmuRetrievalDataset
            else:
                processor_key, dataset_cls = "eval", MllmuRetrievalEvalDataset

            vis_processor = self.vis_processors[processor_key]
            text_processor = self.text_processors[processor_key]

            built[split] = dataset_cls(vis_processor, text_processor, raw_dataset)

        return built

@registry.register_builder("mllmu_retrieval_cons_unlearn")
class MllmuRetrievalConsUnlearnBuilder(MllmuBuilder):
    """Dataset builder registered under the name ``mllmu_retrieval_cons_unlearn``.

    For a fixed forget rate (currently 5) it builds:

    * a training dataset, stored under both ``"train"`` and ``"train_<rate>"``;
    * an evaluation dataset stored under ``"full+train_<rate>"``.
    """

    def build_datasets(self):
        """Log the start of the build and return whatever :meth:`build` returns."""
        logging.info("Building datasets for MLLMU-Bench for constructive unlearning...")
        return self.build()

    def build(self):
        """
        Create the training and evaluation datasets for constructive unlearning.

        Returns:
            dict mapping split name to dataset. Training splits are
            ``MllmuRetrievalConsUnlearnDataset`` instances, all other supported
            splits are ``MllmuRetrievalEvalDataset`` instances. Unsupported
            split names are reported and left out of the result.
        """
        self.build_processors()

        forget_rate = 5
        splits = [f"train_{forget_rate}", f"full+train_{forget_rate}"]

        datasets = dict()

        # All raw data is loaded once, before the loop, so that no branch
        # depends on another split having been processed first.
        df_raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", f"forget_{forget_rate}")
        df_IDs = df_raw_dataset["train"]["ID"]
        # NOTE(review): this looks like a placeholder path - confirm the real
        # on-disk location before running.
        virtual_dataset = ds.load_from_disk(f"path_to_mllmu_unlearn_{forget_rate}_5captions_text_mixed")
        dr_raw_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", f"retain_{100 - forget_rate}")

        for split in splits:
            is_train = split.startswith("train")

            # processors
            processor_key = "train" if is_train else "eval"
            vis_processor = self.vis_processors[processor_key]
            text_processor = self.text_processors[processor_key]

            if is_train:
                # The same training dataset object is exposed under the generic
                # "train" key and under the rate-specific split name.
                datasets["train"] = datasets[split] = MllmuRetrievalConsUnlearnDataset(
                    vis_processor, text_processor, virtual_dataset, df_raw_dataset, dr_raw_dataset
                )
                continue

            if split == "df":
                raw_dataset = df_raw_dataset
            elif split.startswith("full+train"):
                raw_dataset = self._build_full_plus_virtual(df_raw_dataset, virtual_dataset)
            else:
                print(f"Warning: Split {split} is not supported. Building for this split is skipped.")
                # Actually skip: without this, an eval dataset would be built
                # from a stale (or undefined) raw dataset.
                continue

            datasets[split] = MllmuRetrievalEvalDataset(vis_processor, text_processor, raw_dataset, df_IDs)

        return datasets

    @staticmethod
    def _build_full_plus_virtual(df_raw_dataset, virtual_dataset):
        """Return a ``DatasetDict`` whose ``train`` split is Full_Set plus extra virtual rows.

        Args:
            df_raw_dataset: dataset dict whose ``train`` row count marks where
                the virtual data is cut in two.
            virtual_dataset: dataset dict whose ``train`` split supplies the
                rows appended to Full_Set.

        Raises:
            AssertionError: if the number of extracted answers does not match
                the number of rows in the second part of the virtual data.
        """
        full_dataset = ds.load_dataset("MLLMMU/MLLMU-Bench", "Full_Set")
        virtual_train = virtual_dataset["train"]

        # The first `split_threshold` rows and the remaining rows are handled
        # separately; presumably text-replaced vs. image-replaced samples, as
        # the variable names suggest - TODO confirm against the data layout.
        split_threshold = df_raw_dataset["train"].num_rows
        text_replaced = virtual_train.select(range(split_threshold))
        image_replaced = virtual_train.select(range(split_threshold, virtual_train.num_rows))

        answer = []
        for ans in text_replaced["answer"]:
            # The last character is dropped before splitting (original note:
            # "-1 to omit the blank"); the regex yields one match per sentence.
            answer += re.findall(r'\s*(.+?(?:\.|$))(?=\s|$)', ans[:-1])
        # The final answer is duplicated once; presumably to make the count
        # match the remaining rows - TODO confirm.
        answer.append(answer[-1])
        assert image_replaced.num_rows == len(answer), f"Rows: {image_replaced.num_rows}, Num_answers: {len(answer)}"

        # Swap the answer column of the second part for the extracted answers.
        all_replaced = image_replaced.remove_columns(["answer"])
        all_replaced = all_replaced.add_column("answer", answer)
        return ds.DatasetDict({"train": ds.concatenate_datasets([full_dataset["train"], all_replaced])})

    DATASET_CONFIG_DICT = {"default": "configs/datasets/mllmu/defaults_cons_unlearn.yaml"}

@registry.register_builder("mllmu_flickr30k_mix")
class MllmuFlickrMixBuilder(Flickr30kBuilder, MllmuBuilder):
    """Dataset builder registered under the name ``mllmu_flickr30k_mix``.

    Builds the Flickr30k splits (via the ``build`` that ``Flickr30kBuilder``
    inherits) and the MLLMU splits (via ``MllmuBuilder.build``), then wraps
    each matching pair of splits in a mixed dataset.
    """

    # Dataset classes for the Flickr30k part of the mix.
    train_dataset_cls = RetrievalDataset
    eval_dataset_cls = RetrievalEvalDataset
    # Wrapper classes that combine a Flickr30k split with an MLLMU split.
    mix_train_dataset_cls = MllmuMixedRetrievalDataset
    mix_eval_dataset_cls = MllmuMixedRetrievalEvalDataset

    def build_datasets(self):
        """
        Build the mixed dataset for each supported split in the config.

        Only ``train``, ``val`` and ``test`` keys of ``build_info.annotations``
        are considered. Evaluation splits receive the already-built mixed
        training dataset, so ``train`` must appear before them in the config.

        Returns:
            dict mapping split name to the mixed dataset for that split.

        Raises:
            AssertionError: if an evaluation split is reached before the
                training split has been built.
        """
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building flickr30K dataset...")
        flickr = Flickr30kBuilder.build(self)

        logging.info("Building MLLMMU dataset...")
        mllmu = MllmuBuilder.build(self)

        build_info = self.config.build_info
        n_flickr = build_info.flickr_samples
        n_mllmu = build_info.mllmu_samples

        ann_info = build_info.annotations

        # Fixed list of 50 ID strings passed as `mllmu_img_indices` to the mixed
        # training dataset. Defined once here instead of on every loop pass.
        # Presumably the MLLMU forget-set IDs - TODO confirm.
        df_indices = [
            '001', '011', '030', '036', '041', '044', '055', '058', '079', '083',
            '092', '102', '123', '124', '135', '153', '158', '166', '186', '192',
            '249', '250', '257', '260', '272', '277', '291', '307', '312', '321',
            '328', '369', '370', '375', '391', '399', '414', '418', '426', '431',
            '446', '449', '450', '468', '471', '473', '476', '486', '489', '498',
        ]

        datasets = dict()
        for split in ann_info.keys():
            if split not in ("train", "val", "test"):
                continue

            # create datasets
            if split == "train":
                datasets[split] = self.mix_train_dataset_cls(
                    flickr[split], mllmu[split], n_flickr, n_mllmu, mllmu_img_indices=df_indices
                )
            else:
                assert "train" in datasets, (
                    f"split '{split}' needs the mixed train dataset; "
                    "list 'train' before it in build_info.annotations"
                )
                datasets[split] = self.mix_eval_dataset_cls(flickr[split], mllmu[split], datasets["train"])

        return datasets

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/mllmu_flickr_mix/defaults.yaml"
    }

from torch.utils.data import Subset

@registry.register_builder("mllmu_flickr30k_mix_cons_unlearn")
class MllmuFlickrMixConsUnlearnBuilder(Flickr30kBuilder, MllmuBuilder):
    """
    Build a mixed Flickr30k + MLLMU dataset with unlearning/forgetting support.

    The training split combines externally prepared "virtual" data (loaded
    from ``config.virtual_path``) with an equally sized, seeded random subset
    of the Flickr30k training split. Validation is the plain Flickr30k ``val``
    split and the test split is a custom evaluation dataset.
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/mllmu_flickr_mix/cons_unlearn.yaml"
    }

    # Dataset classes for the Flickr30k part of the mix.
    train_dataset_cls = RetrievalDataset
    eval_dataset_cls = RetrievalEvalDataset
    unlearn_mix_cls = None    # placeholder

    def build_datasets(self):
        """
        Assemble the ``train`` / ``val`` / ``test`` datasets.

        Returns:
            dict with keys ``"train"`` (``MllmuFlickrMixConsUnlearnDataset``),
            ``"val"`` (the Flickr30k ``val`` split as built) and ``"test"``
            (``MllmuFlickrMixConsUnlearnEvalDataset``).
        """
        logging.info("Building Flickr30k + MLLMU-Unlearn mixed dataset...")

        # Original splits of both sources.
        flickr_splits = Flickr30kBuilder.build(self)
        mllmu_splits = MllmuBuilder.build(self)

        # Externally prepared virtual/unlearned data.
        virtual_splits = ds.load_from_disk(self.config.virtual_path)
        virtual_train = virtual_splits["train"]

        # IDs below this value are treated as "replaced" pairs further down.
        id_offset = 30000

        # Seeded random subset of the Flickr30k training split, with as many
        # samples as there are virtual rows.
        flickr_train = flickr_splits["train"]
        rng = random.Random(self.config.seed)
        retain_indices = rng.sample(range(len(flickr_train)), len(virtual_train))
        retain_subset = Subset(flickr_train, retain_indices)

        # ----- Custom test split -----
        # Group every MLLMU training caption under its image id.
        captions_by_id = {}
        for ann in mllmu_splits["train"].annotation:
            captions_by_id.setdefault(ann["image_id"], []).append(ann["caption"])

        flickr_test = flickr_splits["test"]

        # One entry per virtual row whose ID is below the offset, carrying all
        # captions recorded for that ID.
        replaced_pairs = [
            {"image": image, "image_id": image_id, "caption": captions_by_id[image_id]}
            for image, image_id, _answer in zip(
                virtual_train["image"], virtual_train["ID"], virtual_train["answer"]
            )
            if int(image_id) < id_offset
        ]

        test_dataset = MllmuFlickrMixConsUnlearnEvalDataset(
            vis_processor=mllmu_splits["train"].vis_processor,
            text_processor=mllmu_splits["train"].text_processor,
            vis_root=flickr_test.vis_root,
            flickr_test=flickr_test,
            virtual_train=virtual_train,
            replaced_pairs=replaced_pairs,
        )

        train_dataset = MllmuFlickrMixConsUnlearnDataset(
            virtual_train,
            retain_subset,
            mllmu_split=mllmu_splits["train"],
            flickr_split=flickr_train,
            offset=id_offset,
        )

        return {
            "train": train_dataset,
            "val": flickr_splits["val"],
            "test": test_dataset,
        }

    def _find_in_mllmu_image(self, caption, annotations):
        """Return the ``"image"`` of the first annotation whose caption equals ``caption``.

        Raises:
            ValueError: if no annotation carries that caption.
        """
        hit = next((entry for entry in annotations if entry["caption"] == caption), None)
        if hit is None:
            raise ValueError(f"Image not found for caption '{caption}'")
        return hit["image"]