import json
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset

from data.utils import pre_caption


class FashionIQEmbsDataset(Dataset):
    """FashionIQ dataset that pairs candidate images with precomputed target embeddings.

    Each annotation holds a ``candidate`` image id, a pair of relative
    ``captions`` and, outside the test split, a ``target`` image id. Training
    items return the target's precomputed embedding (read from the
    ``blip-embs-{vit}`` directory) rather than the target's pixels.
    """

    # Added to ``index`` to form ``soft_targets`` in training items.
    # NOTE(review): presumably keeps target ids disjoint from input ids for the
    # consumer of these values; confirm before changing.
    _SOFT_TARGET_OFFSET = 100_000

    def __init__(
        self, transform, image_root, ann_root, split, data="", max_words=30, vit="large"
    ):
        """
        Args:
            transform: callable applied to each candidate ``PIL.Image`` (RGB).
            image_root (str | Path): directory holding the ``*.png`` images, keyed
                by file stem. Embeddings are read from the sibling directory
                ``<parent of image_root>/blip-embs-{vit}`` (``*.pth`` files).
            ann_root (str | Path): directory containing the annotation file
                ``cap.{data}.{split}.json`` and the id list
                ``split.{data}.{split}.json``.
            split (str): one of ``"train"``, ``"val"``, ``"test"``.
            data (str): subset name used in the annotation file names
                (presumably a FashionIQ category such as ``"dress"``).
            max_words (int): caption length limit forwarded to ``pre_caption``.
            vit (str): ``"base"`` or ``"large"``; selects the embedding directory.

        Raises:
            AssertionError: on an invalid ``vit``/``split``, or when an annotation
                file, a root directory, or a referenced image/embedding is missing.
        """
        assert vit in [
            "base",
            "large",
        ], f"vit should be either base or large, not {vit}"
        # Validate the split before deriving file names from it, so a bad value
        # is reported as such instead of as a missing annotation file.
        assert split in [
            "train",
            "val",
            "test",
        ], f"split should be train, test or val, not {split}"

        # ``.name`` keeps only the final component, so ``data``/``split`` cannot
        # redirect the lookup outside ``ann_root``.
        annotation_path = Path(ann_root, Path(f"cap.{data}.{split}.json").name)
        assert (
            annotation_path.exists()
        ), f"Annotation file {annotation_path} does not exist"

        tar_path = Path(ann_root, Path(f"split.{data}.{split}.json").name)
        assert tar_path.exists(), f"tar file {tar_path} does not exist"
        # NOTE(review): not used inside this class; presumably the retrieval
        # gallery ids consumed by the evaluation code.
        self.target_ids = self._load_json(tar_path)

        self.annotation = self._load_json(annotation_path)
        self.transform = transform
        self.max_words = max_words
        self.split = split
        self.data = data
        self.pairid2ref = {
            pair_id: ann["candidate"] for pair_id, ann in enumerate(self.annotation)
        }
        if split != "test":
            self.pairid2tar = {
                pair_id: ann["target"] for pair_id, ann in enumerate(self.annotation)
            }
        else:
            # Test annotations carry no ``target`` field.
            self.pairid2tar = None

        self.image_root = Path(image_root)
        self.emb_root = Path(image_root).parent / f"blip-embs-{vit}"

        assert self.image_root.exists(), f"Image root {self.image_root} does not exist"
        assert self.emb_root.exists(), f"Emb root {self.emb_root} does not exist"

        # Map file stem (the image id used in the annotations) -> file path.
        self.id2imgpth = {pth.stem: pth for pth in self.image_root.glob("*.png")}
        self.id2embpth = {pth.stem: pth for pth in self.emb_root.glob("*.pth")}

        # Fail fast at construction rather than mid-epoch in __getitem__.
        for ann in self.annotation:
            assert (
                ann["candidate"] in self.id2imgpth
            ), f"Path to candidate {ann['candidate']} not found in {self.image_root}"
            assert (
                ann["candidate"] in self.id2embpth
            ), f"Path to candidate {ann['candidate']} not found in {self.emb_root}"
            if split != "test":
                assert (
                    ann["target"] in self.id2imgpth
                ), f"Path to target {ann['target']} not found in {self.image_root}"
                assert (
                    ann["target"] in self.id2embpth
                ), f"Path to target {ann['target']} not found in {self.emb_root}"

    @staticmethod
    def _load_json(path):
        """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __len__(self):
        """Return the number of annotations (one item per annotation)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return the sample at *index*.

        For ``"val"``/``"test"``: ``(candidate_img, caption, index)``.
        For ``"train"``: ``(candidate_img, target_feat, caption, index,
        soft_inputs, soft_targets)``, where ``target_feat`` is the target's
        precomputed embedding loaded on CPU, ``soft_inputs == index`` and
        ``soft_targets == index + _SOFT_TARGET_OFFSET``.
        """
        ann = self.annotation[index]

        # The context manager closes the file; ``convert`` returns a loaded copy.
        with Image.open(self.id2imgpth[ann["candidate"]]) as img:
            candidate_img = img.convert("RGB")
        candidate_img = self.transform(candidate_img)

        # Exactly two relative captions per annotation are expected.
        cap1, cap2 = ann["captions"]
        caption = pre_caption(f"{cap1} and {cap2}", self.max_words)

        if self.split in ["val", "test"]:
            return candidate_img, caption, index

        # map_location="cpu" avoids needing CUDA (and GPU allocations inside
        # DataLoader workers) when an embedding was saved from a GPU tensor.
        target_feat = torch.load(
            self.id2embpth[ann["target"]], map_location="cpu"
        ).cpu()

        soft_inputs = index
        soft_targets = index + self._SOFT_TARGET_OFFSET
        return candidate_img, target_feat, caption, index, soft_inputs, soft_targets
