"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import random
import re
import copy

from collections import OrderedDict

from lavis.datasets.datasets.base_dataset import BaseDataset, default_collate
from lavis.datasets.datasets.retrieval_datasets import MllmuRetrievalDataset, MllmuMixedRetrievalDataset, MllmuMixedRetrievalEvalDataset

class MllmuRetrievalConsUnlearnDataset(MllmuRetrievalDataset):
    """Retrieval dataset that pairs every sample with a forget-set counterpart.

    Each sentence of a row's ``answer`` becomes one (image, caption) instance.
    ``__getitem__`` additionally returns one of:

    * ``df_text``  -- a caption drawn at random from the ``df_raw_dataset`` row
      with the same ID (samples whose ID is below 1000, "text replaced");
    * ``df_image`` -- the ``df_raw_dataset`` image whose captions contain this
      sample's caption (samples whose ID is 1000 or above, "image replaced").

    Both are ``None`` when no counterpart is known for the sample.
    """

    # Splits an answer paragraph into sentence-like pieces ending in "." or end of text.
    _SENTENCE_PATTERN = re.compile(r'\s*(.+?(?:\.|$))(?=\s|$)')
    # Samples with an ID at or above this value are treated as image-replaced.
    _IMAGE_REPLACED_MIN_ID = 1000

    def __init__(self, vis_processor, text_processor, raw_dataset, df_raw_dataset, dr_raw_dataset=None):
        """Build the annotation list and the forget-set lookup table.

        Args:
            vis_processor: callable applied to images in ``__getitem__``.
            text_processor: callable applied to captions.
            raw_dataset: mapping with a ``"train"`` split exposing the columns
                ``image``, ``ID`` and ``answer``.
            df_raw_dataset: mapping with a ``"train"`` split (rows with ``ID``,
                ``image``, ``answer``) used to build the counterparts.
            dr_raw_dataset: optional extra data; a seeded shuffle of its
                ``"train"`` split, truncated to at most the number of rows of
                ``df_raw_dataset["train"]``, is appended to the annotations.
        """
        # The parent constructor is deliberately not called: this class is
        # built from in-memory splits and sets the attributes it needs itself.
        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.annotation = []
        instance_cnt = self._append_split(raw_dataset["train"])

        # text_replaced: ID -> list of forget-set captions.
        self.ID2Df = {}
        ID2img = {}
        for row in df_raw_dataset["train"]:
            ID2img[row["ID"]] = row["image"]
            captions = self._split_captions(row["answer"])
            # Rows without a usable caption are left out so that
            # ``random.choice`` never sees an empty list.
            if captions:
                self.ID2Df[row["ID"]] = captions

        # image_replaced: map every raw and processed forget-set caption to its
        # row ID once, instead of re-processing all captions per annotation.
        # Later IDs overwrite earlier ones, i.e. the last matching row wins.
        caption2ID = {}
        for ori_ID, captions in self.ID2Df.items():
            for caption in captions:
                caption2ID[caption] = ori_ID
                caption2ID[text_processor(caption)] = ori_ID

        for ann in self.annotation:
            if int(ann["image_id"]) >= self._IMAGE_REPLACED_MIN_ID and ann["caption"] in caption2ID:
                self.ID2Df[ann["image_id"]] = ID2img[caption2ID[ann["caption"]]]

        print(f"Decomposed MLLMU-Bench_df to {instance_cnt} instances.")

        instance_cnt = 0
        if dr_raw_dataset is not None:
            # todo: make this configable
            dr_train = dr_raw_dataset["train"]
            # Never request more rows than exist; ``select`` rejects
            # out-of-range indices.
            num_selected = min(df_raw_dataset["train"].num_rows, dr_train.num_rows)
            selected_dr = dr_train.shuffle(seed=42).select(range(num_selected))
            # NOTE(review): instance ids restart at 0 here, so they overlap with
            # the ids assigned above -- confirm this is intended.
            instance_cnt = self._append_split(selected_dr)

        print(f"Decomposed MLLMU-Bench_dr to {instance_cnt} instances.")

    @classmethod
    def _split_captions(cls, answer):
        """Split ``answer`` into captions, dropping whitespace-only pieces."""
        if not answer:
            return []
        return [piece for piece in cls._SENTENCE_PATTERN.findall(answer) if piece.strip()]

    def _append_split(self, split):
        """Append one annotation per caption in ``split``; return how many were added.

        ``instance_id`` values start at 0 on every call.
        """
        instance_cnt = 0
        for image, image_id, answer in zip(split["image"], split["ID"], split["answer"]):
            if not answer:
                print(f"Omitted missing description for image {image_id}")
                continue
            for caption in self._split_captions(answer):
                self.annotation.append(
                    {"image": image, "caption": caption, "image_id": image_id, "instance_id": instance_cnt})
                instance_cnt += 1
        return instance_cnt

    def __getitem__(self, index):
        """Return the processed sample plus its forget-set text or image (or ``None``)."""
        ann = self.annotation[index]
        image = self.vis_processor(ann["image"])
        caption = self.text_processor(ann["caption"])

        df_text = df_image = None
        # Samples without a counterpart keep both fields as None instead of
        # raising KeyError.
        counterpart = self.ID2Df.get(ann["image_id"])
        if counterpart is not None:
            if int(ann["image_id"]) < self._IMAGE_REPLACED_MIN_ID:  # text_replaced
                df_text = self.text_processor(random.choice(counterpart))
            else:                                                    # image_replaced
                df_image = self.vis_processor(counterpart)

        return {
            "image": image,
            "text_input": caption,
            "image_id": int(ann["image_id"]),
            "instance_id": ann["instance_id"],
            "df_text": df_text,
            "df_image": df_image,
        }

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        """Customized collater: ``df_text`` / ``df_image`` stay plain lists.

        These fields may be ``None``, so they are removed from the samples
        before ``default_collate`` runs and re-attached afterwards.
        """
        df_text = [s.pop("df_text") for s in samples]
        df_image = [s.pop("df_image") for s in samples]

        collated_samples = default_collate(samples)
        collated_samples["df_text"] = df_text
        collated_samples["df_image"] = df_image
        return collated_samples


class MllmuClassificationDataset:
    """Multiple-choice classification samples built from an MLLMU-Bench split.

    Every (image, question) pair becomes one instance whose text is the
    question followed by all of its options, and whose label is the index of
    the correct option.
    """

    def __init__(self, vis_processor, text_processor, raw_dataset):
        """Flatten ``raw_dataset["train"]`` into one annotation per image and question.

        The split may expose either an ``image`` column (one image per row) or
        an ``images`` column (several images per row).
        """
        self.class_labels = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'None of the above': 4}
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.annotation = []

        train_split = raw_dataset["train"]
        single_image = "image" in train_split.features.keys()
        image_key = "image" if single_image else "images"
        rows = zip(train_split[image_key], train_split["ID"], train_split["Classification_Task"])

        for image_field, image_id, task in rows:
            # Treat the single-image layout as a one-element group.
            image_group = [image_field] if single_image else image_field
            for picture in image_group:
                for qa in task["Image_Textual_Questions"]:
                    option_suffix = "".join(
                        f" {opt}: {opt_text}" for opt, opt_text in qa["Options"].items()
                    )
                    self.annotation.append({
                        "image": picture,
                        "sentence": qa["Question"] + option_suffix,
                        "label": qa["Correct_Answer"],
                        "image_id": image_id,
                        "instance_id": len(self.annotation),
                    })

        print(f"Decomposed MLLMU-Bench to {len(self.annotation)} classification instances.")

    def __getitem__(self, index):
        """Return the processed image, processed question text and integer label."""
        record = self.annotation[index]
        processed_image = self.vis_processor(record["image"])
        processed_text = self.text_processor(record["sentence"])
        label_index = self.class_labels[record["label"]]

        return {
            "image": processed_image,
            "text_input": processed_text,
            "label": label_index,
            "image_id": record["image_id"],
            "instance_id": record["instance_id"],
        }

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        """Batch samples with the shared ``default_collate``."""
        return default_collate(samples)


class MllmuClassificationConsUnlearnDataset(MllmuClassificationDataset):
    """Classification dataset optionally extended with samples from a second dataset.

    The base annotations come from ``raw_dataset``; when ``dr_raw_dataset`` is
    given, a seeded shuffle of its ``"train"`` split is appended using the same
    question-plus-options sentence layout.
    """

    def __init__(self, vis_processor, text_processor, raw_dataset, dr_raw_dataset=None):
        """Build the base annotations, then append the selected ``dr`` rows.

        Args:
            vis_processor: callable applied to images in ``__getitem__``.
            text_processor: callable applied to the question sentence.
            raw_dataset: mapping with a ``"train"`` split, passed to the parent.
            dr_raw_dataset: optional mapping with a ``"train"`` split; at most
                as many rows as ``raw_dataset["train"]`` has are used.
        """
        super().__init__(vis_processor, text_processor, raw_dataset)
        if dr_raw_dataset is None:
            return

        # todo: make this configable
        dr_train = dr_raw_dataset["train"]
        # Never request more rows than exist; ``select`` rejects out-of-range
        # indices.
        num_selected = min(raw_dataset["train"].num_rows, dr_train.num_rows)
        selected_dr = dr_train.shuffle(seed=42).select(range(num_selected))

        # Same column detection as the parent class: one image per row under
        # "image", or several per row under "images".
        single_image = "image" in selected_dr.features.keys()
        image_key = "image" if single_image else "images"

        for image_field, image_id, QAs in zip(selected_dr[image_key], selected_dr["ID"],
                                              selected_dr["Classification_Task"]):
            images = [image_field] if single_image else image_field
            for image in images:
                for QA in QAs["Image_Textual_Questions"]:
                    sentence = QA["Question"] + "".join(
                        f" {opt}: {opt_text}" for opt, opt_text in QA["Options"].items())
                    self.annotation.append(
                        {"image": image, "sentence": sentence, "label": QA["Correct_Answer"], "image_id": image_id,
                         "instance_id": len(self.annotation)})
        print(f"Expanded MLLMU-Bench_with_dr to {len(self.annotation)} classification instances.")

class MllmuClassificationEvalMiniDataset:
    """Classification samples built from the first fifth of an evaluation split.

    Each row carries several images; every (image, question) pair becomes one
    instance whose text is the question followed by all of its options.
    """

    def __init__(self, vis_processor, text_processor, raw_dataset):
        """Slice the leading fifth of ``raw_dataset["train"]`` and flatten it."""
        self.class_labels = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'None of the above': 4}
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.annotation = []

        train_split = raw_dataset["train"]
        # Slicing yields a column-oriented view holding only the first fifth of the rows.
        subset = train_split[:len(train_split) // 5]
        rows = zip(subset["images"], subset["ID"], subset["Classification_Task"])

        for image_group, image_id, task in rows:
            for picture in image_group:
                for qa in task["Image_Textual_Questions"]:
                    option_suffix = "".join(
                        f" {opt}: {opt_text}" for opt, opt_text in qa["Options"].items()
                    )
                    self.annotation.append({
                        "image": picture,
                        "sentence": qa["Question"] + option_suffix,
                        "label": qa["Correct_Answer"],
                        "image_id": image_id,
                        "instance_id": len(self.annotation),
                    })
        print(f"Decomposed MLLMU-Bench to {len(self.annotation)} classification instances.")

    def __getitem__(self, index):
        """Return the processed image, processed question text and integer label."""
        record = self.annotation[index]
        processed_image = self.vis_processor(record["image"])
        processed_text = self.text_processor(record["sentence"])
        label_index = self.class_labels[record["label"]]

        return {
            "image": processed_image,
            "text_input": processed_text,
            "label": label_index,
            "image_id": record["image_id"],
            "instance_id": record["instance_id"],
        }

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        """Batch samples with the shared ``default_collate``."""
        return default_collate(samples)

class MllmuMatchingDataset(BaseDataset):
    """Image-text matching samples: every answer option of every question is one caption."""

    def __init__(self, vis_processor, text_processor, raw_dataset):
        """Flatten ``raw_dataset["train"]`` into one annotation per (image, option text).

        Args:
            vis_processor: callable applied to images in ``__getitem__``.
            text_processor: callable applied to captions in ``__getitem__``.
            raw_dataset: mapping with a ``"train"`` split exposing the columns
                ``image``, ``ID`` and ``Classification_Task``.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.annotation = []
        instance_cnt = 0
        train_split = raw_dataset["train"]
        for image, image_id, cls_ann in zip(train_split["image"], train_split["ID"],
                                            train_split["Classification_Task"]):
            questions = cls_ann["Image_Textual_Questions"] if cls_ann else None
            if not questions:
                # Covers both a missing annotation and an empty question list.
                print(f"Omitted missing description for image {image_id}")
                continue
            for item in questions:
                options = item["Options"]
                # Options is a letter -> text mapping elsewhere in this file;
                # the caption must be the option text, not the letter key.
                option_texts = options.values() if isinstance(options, dict) else options
                for opt_text in option_texts:
                    self.annotation.append(
                        {"image": image, "caption": opt_text, "image_id": image_id, "instance_id": instance_cnt})
                    instance_cnt += 1
        print(f"Decomposed MLLMU-Bench to {instance_cnt} instances.")

    def __getitem__(self, index):
        """Return the processed image and caption together with their ids."""
        ann = self.annotation[index]
        image = self.vis_processor(ann["image"])
        caption = self.text_processor(ann["caption"])
        return {
            "image": image,
            "text_input": caption,
            "image_id": int(ann["image_id"]),
            "instance_id": ann["instance_id"],
        }

    def __len__(self):
        return len(self.annotation)


class MllmuMatchingEvalDataset(BaseDataset):
    """Evaluation set that matches each image against all of its answer options.

    Attributes set by ``__init__``:
        annotation: one entry per image with its option texts (``caption``),
            the per-question correct option indices (``correct_inds``) and
            ``image_id``.
        text: every processed option text, in annotation order.
        image: the raw image of every annotation.
        img2txt: annotation index -> list of indices into ``text``.
        txt2img: index into ``text`` -> annotation index.
        df_img_inds: only when ``df_IDs`` is given -- annotation index of each
            requested image ID.
    """

    def __init__(self, vis_processor, text_processor, raw_dataset, df_IDs=None):
        """Build the annotations and the image/text index maps.

        Args:
            vis_processor: callable applied to images in ``__getitem__``.
            text_processor: callable applied to every option text.
            raw_dataset: mapping with a ``"train"`` split exposing the columns
                ``image``, ``ID`` and ``Classification_Task``.
            df_IDs: optional iterable of image IDs to locate in the annotations.

        Raises:
            ValueError: if an ID in ``df_IDs`` has no annotation.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.annotation = []
        opt2ind = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'None of the above': 4}

        train_split = raw_dataset["train"]
        for image, image_id, cls_ann in zip(train_split["image"], train_split["ID"],
                                            train_split["Classification_Task"]):
            questions = cls_ann["Image_Textual_Questions"] if cls_ann else None
            if not questions:  # annotation of image_id '159' in mllmu is empty
                print(f"Omitted missing annotation for image {image_id}")
                continue
            correct_inds = []
            options = []
            for item in questions:
                options += item["Options"].values()
                correct_inds.append(opt2ind[item["Correct_Answer"]])
            self.annotation.append(
                {"image": image, "caption": options, "correct_inds": correct_inds, "image_id": image_id})
        print(f"Decomposed MLLMU-Bench to {len(self.annotation)} test samples.")

        if df_IDs is not None:
            # One pass to index IDs (first occurrence wins, like list.index)
            # instead of a linear scan per requested ID.
            first_index = {}
            for position, ann in enumerate(self.annotation):
                first_index.setdefault(ann["image_id"], position)
            self.df_img_inds = []
            for image_ID in df_IDs:
                if image_ID not in first_index:
                    raise ValueError(f"{image_ID!r} is not among the annotated image IDs")
                self.df_img_inds.append(first_index[image_ID])

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann["image"])
            self.img2txt[img_id] = []
            for caption in ann["caption"]:
                self.text.append(self.text_processor(caption))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __getitem__(self, index):
        """Return the processed image with its annotation index and integer image ID."""
        image = self.vis_processor(self.annotation[index]["image"])
        return {"image": image, "index": index, "image_id": int(self.annotation[index]["image_id"])}

class MllmuVqaDataset(BaseDataset):
    """VQA-style samples: one (image, question) pair per MLLMU-Bench question."""

    def __init__(self, vis_processor, text_processor, raw_dataset):
        """Flatten ``raw_dataset["train"]`` into one annotation per question."""
        self.vis_processor = vis_processor
        self.text_processor = text_processor

        train_split = raw_dataset["train"]
        self.annotation = [
            {"image": picture, "question": qa["Question"], "answer": qa["Options"]}
            for picture, task in zip(train_split["image"], train_split["Classification_Task"])
            for qa in task["Image_Textual_Questions"]
        ]
        print(f"Decomposed MLLMU-Bench to {len(self.annotation)} IT-QA pairs.")

    def __getitem__(self, index):
        """Return the processed image and question with equally weighted answers."""
        record = self.annotation[index]
        processed_image = self.vis_processor(record["image"])
        processed_question = self.text_processor(record["question"])

        # NOTE(review): if "Options" is a mapping, iterating it yields its keys
        # (the option letters), not the option texts -- confirm this is intended.
        candidates = record["answer"]
        weight_by_answer = {}
        for candidate in candidates:
            # Each occurrence contributes an equal share of the total weight.
            weight_by_answer[candidate] = weight_by_answer.get(candidate, 0) + 1 / len(candidates)

        return {
            "image": processed_image,
            "text_input": processed_question,
            "answers": list(weight_by_answer),
            "weights": list(weight_by_answer.values()),
        }

    def __len__(self):
        return len(self.annotation)


class MixedMllmuDataset(MllmuMixedRetrievalDataset):
    """View of the mixed dataset that exposes only its MLLMU samples."""
    # todo: flickr30k is not need to load

    def __len__(self):
        """Number of MLLMU samples in the mixed dataset."""
        return len(self.mllmu_indices)

    def __getitem__(self, index):
        """Return the ``index``-th MLLMU sample with ids shifted by the mixed-dataset offsets."""
        sample = self.mllmu_dataset[self.mllmu_indices[index]]

        sample['image_id'] += self.mllmu_image_offset
        shifted_instance = int(sample['instance_id']) + self.mllmu_instance_offset
        sample['instance_id'] = str(shifted_instance)
        return sample

class MixedMllmuEvalDataset(MllmuMixedRetrievalEvalDataset):
    """Only returns mllmu samples in mixed dataset."""
    # todo: flickr30k is not need to load

    # Number of leading entries of ``self.image`` that are skipped.
    # todo: for simplicity -- presumably the non-MLLMU (Flickr30k) part of the
    # mixed dataset; confirm against the parent class.
    _NUM_SKIPPED_IMAGES = 1000

    def __len__(self):
        """Number of images after the skipped leading block (never negative)."""
        return max(0, len(self.image) - self._NUM_SKIPPED_IMAGES)

    def __getitem__(self, index):
        """Return the ``index``-th image after the skipped block.

        ``index`` in the returned dict is the position in the full mixed
        dataset. Negative indices count from the end of this view.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            # Without this, a negative index would land inside the skipped block.
            index += size
        if not 0 <= index < size:
            raise IndexError("MixedMllmuEvalDataset index out of range")

        mixed_index = index + self._NUM_SKIPPED_IMAGES
        image = self.mllmu_eval_ds.vis_processor(self.image[mixed_index])
        return {"image": image, "index": mixed_index, "image_id": self.index2ID[mixed_index]}

