"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict
import re
import random

import torch

from lavis.datasets.datasets.base_dataset import BaseDataset
from PIL import Image


class __DisplMixin:
    """Mixin that exposes a compact, ordered preview of one dataset entry."""

    def displ_item(self, index):
        """Return an ``OrderedDict`` describing the entry at ``index``.

        The result holds, in order: ``"file"`` (the annotation's image/video
        reference), ``"caption"`` (the annotation's caption) and the visual
        returned by ``__getitem__`` under the key ``"image"`` or ``"video"``.
        """
        sample = self.__getitem__(index)
        ann = self.annotation[index]

        # Annotations without an "image" entry are treated as video annotations.
        if "image" in ann:
            visual_key = "image"
        else:
            visual_key = "video"

        preview = OrderedDict()
        preview["file"] = ann[visual_key]
        preview["caption"] = ann["caption"]
        preview[visual_key] = sample[visual_key]
        return preview


class RetrievalDataset(BaseDataset, __DisplMixin):
    """Image-text retrieval training set: one (image, caption) pair per item."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths: annotation source(s), forwarded unchanged to BaseDataset
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        # Map every distinct annotation image_id to a dense index, numbered in
        # first-seen order (0, 1, 2, ...).
        self.img_ids = {}
        for entry in self.annotation:
            self.img_ids.setdefault(entry["image_id"], len(self.img_ids))

    def __getitem__(self, index):
        """Load, convert to RGB and process the image/caption pair at ``index``."""
        entry = self.annotation[index]

        full_path = os.path.join(self.vis_root, entry["image"])
        rgb_image = Image.open(full_path).convert("RGB")

        return {
            "image": self.vis_processor(rgb_image),
            "text_input": self.text_processor(entry["caption"]),
            # dense index assigned in __init__, not the raw annotation id
            "image_id": self.img_ids[entry["image_id"]],
            "instance_id": entry["instance_id"],
        }


class RetrievalEvalDataset(BaseDataset, __DisplMixin):
    """Image-text retrieval evaluation set.

    Items are images; all captions are pre-processed once in ``__init__`` and
    exposed through the lookup tables below.

    Attributes built in ``__init__``:
        text: processed captions, flattened across all images.
        image: the ``"image"`` reference of every annotation, in order.
        txt2img: caption index -> image index.
        img2txt: image index -> list of caption indices.
        ind2ID: image index -> image reference with its file extension removed.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths: annotation source(s), forwarded unchanged to BaseDataset
        """

        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}
        self.ind2ID = {}  # added to record index to image filename mapping

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann["image"])
            self.img2txt[img_id] = []
            # Strip the real extension instead of a fixed four characters so
            # that names such as "x.jpeg" or extension-less names stay intact.
            self.ind2ID[img_id] = os.path.splitext(ann["image"])[0]
            for caption in ann["caption"]:
                self.text.append(self.text_processor(caption))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __getitem__(self, index):
        """Return the processed RGB image at ``index`` together with the index."""

        image_path = os.path.join(self.vis_root, self.annotation[index]["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        return {"image": image, "index": index}


class VideoRetrievalDataset(BaseDataset, __DisplMixin):
    """Video-text retrieval training set: one (video, caption) pair per item."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of videos.
        ann_paths: annotation source(s), forwarded unchanged to BaseDataset
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        # The annotation's "video" value doubles as the video identity; each
        # distinct value gets a dense index in first-seen order.
        self.img_ids = {}
        for entry in self.annotation:
            self.img_ids.setdefault(entry["video"], len(self.img_ids))

    def __getitem__(self, index):
        """Process the video/caption pair at ``index``.

        The visual processor receives the joined video path, not decoded frames.
        """
        entry = self.annotation[index]
        video_path = os.path.join(self.vis_root, entry["video"])

        return {
            "video": self.vis_processor(video_path),
            "text_input": self.text_processor(entry["caption"]),
            "image_id": self.img_ids[entry["video"]],
        }


class VideoRetrievalEvalDataset(BaseDataset, __DisplMixin):
    """Video-text retrieval evaluation set.

    Items are videos; captions are processed once in ``__init__`` and stored in
    ``text`` together with the ``txt2img`` / ``img2txt`` index tables.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of videos.
        ann_paths: annotation source(s), forwarded unchanged to BaseDataset
        """

        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        for video_idx, entry in enumerate(self.annotation):
            self.image.append(entry["video"])
            caption_ids = []
            self.img2txt[video_idx] = caption_ids
            for raw_caption in entry["caption"]:
                # The next caption id is always the current length of self.text.
                caption_id = len(self.text)
                self.text.append(self.text_processor(raw_caption))
                caption_ids.append(caption_id)
                self.txt2img[caption_id] = video_idx

    def __getitem__(self, index):
        """Return the processed video at ``index`` together with the index."""
        video_path = os.path.join(self.vis_root, self.annotation[index]["video"])
        return {"video": self.vis_processor(video_path), "index": index}


class MllmuRetrievalDataset(BaseDataset):
    """Retrieval training set built from the ``"train"`` split of ``raw_dataset``.

    Each row's ``"answer"`` text is split into sentence-like captions, and every
    caption becomes its own (image, caption) training instance.
    """

    def __init__(self, vis_processor, text_processor, raw_dataset):
        """
        vis_processor: callable applied to each image in ``__getitem__``
        text_processor: callable applied to each caption in ``__getitem__``
        raw_dataset: mapping whose ``"train"`` entry provides the columns
            ``"image"``, ``"ID"`` and ``"answer"``

        BaseDataset.__init__ is deliberately not called; the attributes used by
        this class are set here directly.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.img_ids = {}  # image ID -> row index; only used to count images
        self.annotation = []

        # A caption is the shortest run of text ending in a period (or at end
        # of string) that is followed by whitespace or end of string.
        sentence_pattern = r'\s*(.+?(?:\.|$))(?=\s|$)'

        train_split = raw_dataset["train"]
        rows = zip(train_split["image"], train_split["ID"], train_split["answer"])
        for row_idx, (image, image_id, answer) in enumerate(rows):
            # Recorded even for rows that end up contributing no captions.
            self.img_ids[image_id] = row_idx
            if not answer:
                print(f"Omitted missing description for image {image_id}")
                continue
            for caption in re.findall(sentence_pattern, answer):
                self.annotation.append(
                    {
                        "image": image,
                        "caption": caption,
                        "image_id": image_id,
                        # instance ids are simply the running annotation count
                        "instance_id": len(self.annotation),
                    }
                )
        print(f"Decomposed MLLMU-Bench to {len(self.annotation)} instances.")

    def __getitem__(self, index):
        """Return the processed image/caption pair stored at ``index``."""
        record = self.annotation[index]
        return {
            "image": self.vis_processor(record["image"]),
            "text_input": self.text_processor(record["caption"]),
            "image_id": int(record["image_id"]),
            "instance_id": record["instance_id"],
        }

    def __len__(self):
        """Number of (image, caption) instances."""
        return len(self.annotation)


class MllmuRetrievalEvalDataset(BaseDataset):
    """Retrieval evaluation set built from the ``"train"`` split of ``raw_dataset``.

    One annotation is created per row that has a non-empty ``"answer"``; the
    answer is split into sentence-like captions that are all attached to that
    row's image.
    """

    def __init__(self, vis_processor, text_processor, raw_dataset, df_IDs=None):
        """
        vis_processor: callable applied to each image in ``__getitem__``
        text_processor: callable applied to every caption during construction
        raw_dataset: mapping whose ``"train"`` entry exposes ``features`` and the
            columns ``"ID"``, ``"answer"`` and either ``"images"`` (a sequence
            of images per row) or ``"image"`` (one image per row)
        df_IDs: optional iterable of image IDs; when given, ``self.df_img_inds``
            holds the position of each ID's first annotation. An ID with no
            annotation raises ``ValueError``. The attribute is not created when
            ``df_IDs`` is None.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.annotation = []

        # A caption is the shortest run of text ending in a period (or at end
        # of string) that is followed by whitespace or end of string.
        sentence_pattern = r'\s*(.+?(?:\.|$))(?=\s|$)'

        train_split = raw_dataset["train"]
        has_image_lists = "images" in train_split.features.keys()
        visual_column = "images" if has_image_lists else "image"

        next_instance_id = 0
        rows = zip(train_split[visual_column], train_split["ID"], train_split["answer"])
        for visual, image_id, answer in rows:
            if not answer:
                print(f"Omitted missing description for image {image_id}")
                continue
            captions = re.findall(sentence_pattern, answer)
            # For an "images" row only the first image is kept (an empty
            # sequence adds nothing); an "image" row always adds one entry.
            candidates = visual if has_image_lists else (visual,)
            for image in candidates:
                self.annotation.append(
                    {
                        "image": image,
                        "captions": captions,
                        "image_id": image_id,
                        "instance_id": next_instance_id,
                    }
                )
                # instance ids advance by the number of captions per image
                next_instance_id += len(captions)
                break  # TODO: fast implement to prevent repeated text in eval
        print(f"Decomposed MLLMU-Bench to {len(self.annotation)} test samples.")

        if df_IDs is not None:
            all_image_IDs = [ann["image_id"] for ann in self.annotation]
            self.df_img_inds = [all_image_IDs.index(image_ID) for image_ID in df_IDs]

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        for img_idx, ann in enumerate(self.annotation):
            self.image.append(ann["image"])
            caption_ids = []
            self.img2txt[img_idx] = caption_ids
            for raw_caption in ann["captions"]:
                # The next caption id is always the current length of self.text.
                caption_id = len(self.text)
                self.text.append(self.text_processor(raw_caption))
                caption_ids.append(caption_id)
                self.txt2img[caption_id] = img_idx

    def __getitem__(self, index):
        """Return the processed image at ``index``, the index and its integer ID."""
        record = self.annotation[index]
        return {
            "image": self.vis_processor(record["image"]),
            "index": index,
            "image_id": int(record["image_id"]),
        }

class MllmuMixedRetrievalDataset(BaseDataset):
    """Shuffled mixture of a retrieval training set and an MLLMU training set.

    Selection is image-level: a subset of image IDs is chosen from each source
    and every annotation belonging to a chosen image is included. MLLMU items
    have their ``image_id`` / ``instance_id`` shifted so that they do not
    collide with the retrieval ids.
    """

    def __init__(self, retrieval_dataset, mllmu_dataset, retrieval_samples, mllmu_samples, mllmu_img_indices=None):
        """
        Args:
            retrieval_dataset (RetrievalDataset): source of retrieval samples.
            mllmu_dataset (MllmuRetrievalDataset): source of MLLMU samples.
            retrieval_samples (int): number of retrieval image IDs to draw.
            mllmu_samples (int): number of MLLMU image IDs to draw; ignored
                when ``mllmu_img_indices`` is given.
            mllmu_img_indices: optional explicit collection of MLLMU image IDs
                to use instead of random sampling.

        Raises:
            ValueError: propagated from ``random.sample`` when more IDs are
                requested than are available.
        """
        self.retrieval_dataset = retrieval_dataset
        self.mllmu_dataset = mllmu_dataset

        # image-level select; the order of the random calls below is kept
        # stable so seeded runs reproduce the same mixture.
        self.retrieval_img_indices = random.sample(list(retrieval_dataset.img_ids.keys()), retrieval_samples)
        self.retrieval_indices = self._select_by_image_id(
            self.retrieval_dataset.annotation, self.retrieval_img_indices
        )

        self.mllmu_img_indices = random.sample(list(mllmu_dataset.img_ids.keys()), mllmu_samples) \
            if mllmu_img_indices is None else mllmu_img_indices
        print(self.mllmu_img_indices)
        self.mllmu_indices = self._select_by_image_id(
            self.mllmu_dataset.annotation, self.mllmu_img_indices
        )

        self._calculate_id_offsets()

        self.samples = []
        self.samples.extend([('retrieval', idx) for idx in self.retrieval_indices])
        self.samples.extend([('mllmu', idx) for idx in self.mllmu_indices])
        random.shuffle(self.samples)

    @staticmethod
    def _select_by_image_id(annotations, image_ids):
        """Return positions of annotations whose ``"image_id"`` is in ``image_ids``."""
        # A set gives O(1) membership tests instead of scanning the id list
        # once per annotation.
        try:
            wanted = set(image_ids)
        except TypeError:
            # Unhashable ids: fall back to the original linear membership test.
            wanted = image_ids
        return [pos for pos, ann in enumerate(annotations) if ann["image_id"] in wanted]

    def _calculate_id_offsets(self):
        """Compute the shifts applied to MLLMU image and instance ids."""
        if self.retrieval_dataset.img_ids:
            self.mllmu_image_offset = max(self.retrieval_dataset.img_ids.values()) + 1
        else:
            self.mllmu_image_offset = 0

        self.mllmu_instance_offset = len(self.retrieval_dataset)

    def __len__(self):
        """Total number of mixed samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return the sample at ``index`` from whichever source it was drawn."""
        dataset_type, original_idx = self.samples[index]

        if dataset_type == 'retrieval':
            return self.retrieval_dataset[original_idx]

        data = self.mllmu_dataset[original_idx]

        # Shift MLLMU ids past the retrieval id range.
        data['image_id'] += self.mllmu_image_offset
        data['instance_id'] = str(int(data['instance_id']) + self.mllmu_instance_offset)

        return data

    def set_processors(self, vis_processor, text_processor):
        """Forward the processors to both wrapped datasets."""
        self.retrieval_dataset.set_processors(vis_processor, text_processor)
        self.mllmu_dataset.set_processors(vis_processor, text_processor)


class MllmuMixedRetrievalEvalDataset(BaseDataset):
    """Evaluation set combining a retrieval eval set with MLLMU training samples.

    The retrieval images/captions come first (indices unchanged); the MLLMU
    annotations that were actually drawn into ``train_ds`` are appended after
    them, with one image entry per distinct MLLMU ``image_id``.

    Attributes built in ``__init__``:
        text, image, txt2img, img2txt: combined evaluation tables.
        added_image_id: MLLMU ``image_id`` -> index in ``image``.
        index2ID: inverse of ``added_image_id``.
        image_id_offset, text_id_offset: offsets derived from the retrieval
            eval tables (not used by this class itself).
    """

    def __init__(self, retrieval_eval_ds, mllmu_eval_ds, train_ds):
        """
        Args:
            retrieval_eval_ds (RetrievalEvalDataset): Initialized retrieval eval dataset
            mllmu_eval_ds (MllmuRetrievalEvalDataset): MLLMU eval dataset; only its
                      ``vis_processor`` is used here
            train_ds (MllmuMixedRetrievalDataset): Training dataset used in mixed training
                      (to identify which samples were actually included)
        """
        # Store original datasets
        self.retrieval_ds = retrieval_eval_ds
        self.mllmu_eval_ds = mllmu_eval_ds
        # Calculate ID offsets based on training data
        self._calculate_id_offsets(train_ds)
        # Build combined dataset structures
        self._build_combined_dataset(train_ds)

    def _calculate_id_offsets(self, train_ds):
        """Calculate ID offsets based on training dataset configuration"""
        # First image id past both the training offset and the retrieval eval images
        if self.retrieval_ds.img2txt:
            self.image_id_offset = max(
                train_ds.mllmu_image_offset,
                max(self.retrieval_ds.img2txt.keys()) + 1,
            )
        else:
            self.image_id_offset = 0
        # First text id past the retrieval eval captions
        max_retrieval_txt = max(self.retrieval_ds.txt2img.keys()) if self.retrieval_ds.txt2img else -1
        self.text_id_offset = max_retrieval_txt + 1

    def _build_combined_dataset(self, train_ds):
        """Combine datasets while maintaining evaluation structures"""
        # Annotations of the MLLMU samples that were drawn into the training mix
        train_mllmu_anns = [
            train_ds.mllmu_dataset.annotation[original_idx]
            for dataset_type, original_idx in train_ds.samples
            if dataset_type == 'mllmu'
        ]
        # Initialize combined structures
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        self._add_retrieval_data()

        self._add_mllmu_data(train_mllmu_anns)

    def _add_retrieval_data(self):
        """Incorporate retrieval dataset as-is"""
        self.text.extend(self.retrieval_ds.text)
        self.image.extend(self.retrieval_ds.image)
        self.txt2img.update(self.retrieval_ds.txt2img)
        self.img2txt.update(self.retrieval_ds.img2txt)

    def _add_mllmu_data(self, train_anns):
        """Append MLLMU annotations, one image entry per distinct image_id"""
        self.added_image_id = dict()  # image_id in ann: idx in sim matrix
        for ann in train_anns:
            new_txt_id = len(self.text)
            # NOTE(review): the caption is stored as-is, without a text
            # processor; confirm this is intended.
            self.text.append(ann["caption"])

            source_id = ann["image_id"]
            if source_id not in self.added_image_id:
                # First caption for this image: register the image itself.
                new_img_id = len(self.image)
                self.image.append(ann["image"])
                self.added_image_id[source_id] = new_img_id
                self.img2txt[new_img_id] = []
            else:
                new_img_id = self.added_image_id[source_id]

            self.img2txt[new_img_id].append(new_txt_id)
            self.txt2img[new_txt_id] = new_img_id

        self.index2ID = {ind: ID for ID, ind in self.added_image_id.items()}

    def __len__(self):
        """Number of images (retrieval + MLLMU)."""
        return len(self.image)

    def __getitem__(self, index):
        """Return the processed image at ``index`` together with the index.

        Retrieval entries are file references that must be opened from disk;
        MLLMU entries are handed directly to the MLLMU visual processor.
        """
        # Retrieval images occupy the first len(retrieval_ds.image) slots, so
        # use that as the boundary rather than a fixed count of 1000.
        if index < len(self.retrieval_ds.image):
            image_path = os.path.join(self.retrieval_ds.vis_root, self.image[index])
            image = Image.open(image_path).convert("RGB")
            image = self.retrieval_ds.vis_processor(image)
        else:
            image = self.mllmu_eval_ds.vis_processor(self.image[index])
        return {"image": image, "index": index}

