import random
import re
import os
from PIL import Image
from torch.utils.data._utils.collate import default_collate


class MllmuFlickrMixConsUnlearnDataset:
    """Training dataset mixing "virtual" MLLMU samples with extra ``dr`` samples.

    Every virtual sample has exactly one modality replaced, decided by its ID:

    * ``int(ID) >= offset``: the image was replaced. The original image is
      looked up in ``mllmu_split.annotation`` by caption and kept as
      ``df_image``.
    * ``int(ID) < offset``: the text was replaced. The first MLLMU caption
      recorded for that ID is kept as ``df_text``.

    ``dr`` samples are appended afterwards with both ``df_*`` fields ``None``.

    Args:
        virtual_ds: Column-style mapping indexed with ``"image"``, ``"ID"``
            and ``"answer"``; the three columns are iterated in lockstep.
        dr_ds: Iterable of dicts with ``"image"``, ``"text_input"`` and
            ``"image_id"`` keys.
        mllmu_split: Object exposing ``annotation`` (dicts with
            ``"image_id"``, ``"caption"`` and ``"image"``), ``vis_processor``
            and ``text_processor``.
        flickr_split: Stored on the instance; not otherwise used by this class.
        offset: IDs greater than or equal to this value mark image-replaced
            virtual samples.
    """

    # Splits an answer into sentence-like captions (each ends at "." or at the
    # end of the string, and must be followed by whitespace or end of string).
    _CAPTION_RE = re.compile(r'\s*(.+?(?:\.|$))(?=\s|$)')

    def __init__(self, virtual_ds, dr_ds, mllmu_split, flickr_split,
                 offset=30000):
        self.virtual = virtual_ds
        self.dr = dr_ds
        self.mllmu = mllmu_split
        self.flickr = flickr_split
        self.offset = offset

        # Map image ID -> original MLLMU caption.
        # TODO: fast implementation, only the first caption per image is kept.
        self.ID2answer = {}
        for ex in self.mllmu.annotation:
            self.ID2answer.setdefault(ex["image_id"], ex["caption"])

        # Combined annotations. Each entry holds: image, caption, image_id,
        # df_image, df_text, instance_id and is_dr.
        self.annotation = []
        inst_id = 0
        columns = zip(self.virtual["image"], self.virtual["ID"],
                      self.virtual["answer"])
        for img, img_id, ans in columns:
            captions = self._CAPTION_RE.findall(ans)

            # The replaced modality is encoded in the ID range.
            if int(img_id) >= offset:
                # Image replaced: remember the original image.
                assert len(captions) == 1, f"Assert only 1 caption, but found {len(captions)}."
                df_image = self._find_original_image(captions[0])
                df_text = None
            else:
                # Text replaced: remember the original text.
                assert img_id in self.ID2answer
                df_image = None
                df_text = self.ID2answer.get(img_id)

            for cap in captions:
                self.annotation.append({
                    "image": img,
                    "caption": cap,
                    "image_id": img_id,
                    "df_image": df_image,
                    "df_text": df_text,
                    "instance_id": inst_id,
                    # Explicit origin flag: a failed original-image lookup
                    # leaves both df_* fields None, which must not make this
                    # sample look like a dr sample.
                    "is_dr": False,
                })
                inst_id += 1

        # Append the dr samples after the virtual ones.
        for ann in self.dr:
            self.annotation.append({
                "image": ann["image"],
                "caption": ann["text_input"],
                "image_id": ann["image_id"],
                "df_image": None,
                "df_text": None,
                "instance_id": inst_id,
                "is_dr": True,
            })
            inst_id += 1

    def _find_original_image(self, orig_ans):
        """Return the image of the first MLLMU annotation containing ``orig_ans``.

        The match is a substring test against each annotation's ``"caption"``.
        Prints a warning and returns ``None`` when nothing matches.
        """
        for ex in self.mllmu.annotation:
            if orig_ans in ex["caption"]:
                return ex["image"]
        print(f"Warning: Fail to found original image for caption '{orig_ans}' in mllmu")
        return None

    def __len__(self):
        """Return the number of annotations (virtual captions plus dr samples)."""
        return len(self.annotation)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        The result has the keys ``image``, ``text_input``, ``image_id`` (cast
        to ``int``), ``instance_id``, ``df_image`` and ``df_text``. The two
        ``df_*`` values are ``None`` when the annotation carries no original
        image / text.
        """
        ann = self.annotation[idx]
        # Entries lacking the explicit flag (e.g. appended by external code)
        # fall back to the previous "both df fields are None" heuristic.
        is_dr = ann.get(
            "is_dr", ann["df_image"] is None and ann["df_text"] is None
        )
        # dr images are returned as stored, without the visual processor
        # (presumably already processed by the dr dataset - TODO confirm).
        img = ann["image"] if is_dr else self.mllmu.vis_processor(ann["image"])
        txt = self.mllmu.text_processor(ann["caption"])

        df_i = None
        if ann["df_image"] is not None:
            df_i = self.mllmu.vis_processor(ann["df_image"])
        df_t = None
        if ann["df_text"] is not None:
            df_t = self.mllmu.text_processor(ann["df_text"])

        return {
            "image": img,
            "text_input": txt,
            "image_id": int(ann["image_id"]),
            "instance_id": ann["instance_id"],
            "df_image": df_i,
            "df_text": df_t,
        }

    def collater(self, samples):
        """Collate ``samples`` into a batch.

        ``df_image`` and ``df_text`` may be ``None`` for some samples, so they
        are kept as plain per-sample lists; every other field goes through
        ``default_collate``. The dicts in ``samples`` are left unmodified.
        """
        passthrough = ("df_image", "df_text")
        df_images = [s["df_image"] for s in samples]
        df_texts = [s["df_text"] for s in samples]
        collatable = [
            {key: value for key, value in s.items() if key not in passthrough}
            for s in samples
        ]
        batch = default_collate(collatable)
        batch["df_image"] = df_images
        batch["df_text"] = df_texts
        return batch


class MllmuFlickrMixConsUnlearnEvalDataset:
    """Retrieval-style eval dataset built from three sources, in this order.

    1) Flickr test split: its ``image`` / ``text`` lists and the
       ``img2txt`` / ``txt2img`` / ``ind2ID`` mappings are copied.
    2) Replaced-modality pairs: appended as image/caption pairs with full
       index mappings; their image indices are recorded in ``df_img_inds``.
    3) Virtual train samples: appended as unpaired extras. Image-replaced
       samples contribute only their image, text-replaced samples only their
       captions; no ``img2txt`` / ``txt2img`` / ``ind2ID`` entries are created
       for them.

    Args:
        vis_processor: Callable applied to the image in ``__getitem__``.
        text_processor: Callable applied to every caption added here.
        vis_root: Directory joined with images that are given as string paths.
        flickr_test: Object exposing ``image``, ``text``, ``img2txt``,
            ``txt2img`` and ``ind2ID``.
        virtual_train: Column-style mapping indexed with ``"image"``,
            ``"ID"`` and ``"answer"``.
        replaced_pairs: Iterable of dicts with ``"image"``, ``"image_id"`` and
            ``"caption"`` (a single string or an iterable of strings).
        offset: Virtual IDs greater than or equal to this value are treated
            as image-replaced samples.
    """

    # Splits an answer into sentence-like captions (each ends at "." or at the
    # end of the string, and must be followed by whitespace or end of string).
    _CAPTION_RE = re.compile(r'\s*(.+?(?:\.|$))(?=\s|$)')

    def __init__(self, vis_processor, text_processor, vis_root,
                 flickr_test, virtual_train, replaced_pairs, offset=30000):
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.vis_root = vis_root
        self.offset = offset

        # 1) Flickr test. Copy the containers so that appending below never
        #    alters the source dataset (the inner txt-index lists of existing
        #    img2txt entries are shared but never modified here).
        assert hasattr(flickr_test, "ind2ID")
        self.image = list(flickr_test.image)      # image paths or PIL images
        self.text = list(flickr_test.text)        # processed text inputs
        self.img2txt = dict(flickr_test.img2txt)  # image_idx -> [txt_idx, ...]
        self.txt2img = dict(flickr_test.txt2img)  # txt_idx -> image_idx
        self.ind2ID = dict(flickr_test.ind2ID)    # image_idx -> annotation image_id

        # 2) Replaced-modality (df) pairs, remembering where their images land.
        self.df_img_inds = []
        for pair in replaced_pairs:
            self.df_img_inds.append(len(self.image))
            self._add_sample(pair["image"], pair["image_id"], pair["caption"])

        # 3) Virtual train samples: unpaired extras, so no index mappings.
        columns = zip(virtual_train["image"], virtual_train["ID"],
                      virtual_train["answer"])
        for img, sample_id, ans in columns:
            if int(sample_id) >= offset:
                # Image replaced: contributes its image only.
                self.image.append(img)
            else:
                # Text replaced: contributes its captions only, run through
                # the same text processor as every other entry in self.text.
                for caption in self._CAPTION_RE.findall(ans):
                    self.text.append(self.text_processor(caption))

    def _add_sample(self, img_obj, image_id, captions):
        """Append one image with its captions and register all index mappings."""
        if isinstance(captions, str):
            # A bare string would otherwise be iterated character by character.
            captions = [captions]
        img_idx = len(self.image)
        self.image.append(img_obj)
        self.ind2ID[img_idx] = image_id
        txt_indices = []
        self.img2txt[img_idx] = txt_indices
        for cap in captions:
            txt_idx = len(self.text)  # index the caption is about to receive
            self.text.append(self.text_processor(cap))
            txt_indices.append(txt_idx)
            self.txt2img[txt_idx] = img_idx

    def __len__(self):
        """Return the number of images, including unpaired virtual images."""
        return len(self.image)

    def __getitem__(self, index):
        """Return ``{"image": processed_image, "index": index}``.

        Images stored as strings are treated as paths relative to
        ``vis_root`` and loaded as RGB; any other object is handed to the
        visual processor as is.
        """
        img_obj = self.image[index]
        if isinstance(img_obj, str):
            full = os.path.join(self.vis_root, img_obj)
            # convert() returns a new in-memory image, so the file can be
            # closed as soon as the conversion is done.
            with Image.open(full) as opened:
                img_pil = opened.convert("RGB")
        else:
            img_pil = img_obj
        img = self.vis_processor(img_pil)
        return {"image": img, "index": index}
