import json
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from data.utils import pre_caption


class cirr_split(Dataset):
    """CIRR split that yields transformed images together with their caption.

    Each annotation entry provides a ``reference`` image id, a ``caption``, a
    ``pairid`` and an ``img_set`` with its ``members``; for the non-test splits
    it also provides a ``target_hard`` image id.
    """

    def __init__(self, transform, image_root, ann_root, split, data="", max_words=30):
        """
        image_root (string): Root directory of images (e.g. CIRR/images/). It should have train, dev, and test1 folders.
        ann_root (string): directory containing the annotation JSON files
        transform (callable): applied to every PIL image returned by ``__getitem__``
        split (string): one of "train", "val" or "test"; any other value raises KeyError
        data (string): optional suffix; when non-empty the train annotations are read
            from ``cap.rc2.train.{data}.json`` instead of ``cap.rc2.train.json``
        max_words (int): forwarded to ``pre_caption`` when processing captions

        Raises AssertionError if the image folder of the split does not exist or
        if an annotated reference/target image cannot be found in it.
        """
        urls = {
            "train": "cap.rc2.train.json",
            "val": "cap.rc2.val.json",
            "test": "cap.rc2.test1.json",
        }
        if data:
            urls["train"] = f"cap.rc2.train.{data}.json"

        # Use a context manager so the annotation file handle is closed
        # deterministically instead of being leaked until garbage collection.
        ann_path = Path(ann_root, Path(urls[split]).name)
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.annotation = json.load(ann_file)
        self.transform = transform
        self.max_words = max_words
        self.split = split

        # Lookup tables keyed by pair id.
        self.pairid2ref = {ann["pairid"]: ann["reference"] for ann in self.annotation}
        self.pairid2members = {
            ann["pairid"]: ann["img_set"]["members"] for ann in self.annotation
        }
        if split != "test":
            self.pairid2tar = {
                ann["pairid"]: ann["target_hard"] for ann in self.annotation
            }
        else:
            # Test annotations are not read for a target, so there is no mapping.
            self.pairid2tar = None

        # Create a dictionary of image paths
        split_dict = {
            "train": "train",
            "val": "dev",
            "test": "test1",
        }
        self.image_root = Path(image_root) / split_dict[split]
        assert self.image_root.exists(), f"Image root {self.image_root} does not exist"
        # Train images are looked up one directory level deeper than dev/test1.
        if split == "train":
            img_pths = self.image_root.glob("*/*.png")
        else:
            img_pths = self.image_root.glob("*.png")
        # Image id == file name without extension.
        self.id2pth = {img_pth.stem: img_pth for img_pth in img_pths}

        # Fail early if any annotated image is missing on disk.
        for ann in self.annotation:
            assert (
                ann["reference"] in self.id2pth
            ), f"Path to reference {ann['reference']} not found"
            if split != "test":
                assert (
                    ann["target_hard"] in self.id2pth
                ), f"Path to target {ann['target_hard']} not found"

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        For the "test" split: ``(reference_img, caption, pairid)``.
        Otherwise: ``(reference_img, target_img, caption, pairid)``.
        Images are opened as RGB and passed through ``self.transform``; the
        caption is processed by ``pre_caption`` with ``self.max_words``.
        """
        ann = self.annotation[index]

        reference_img_pth = self.id2pth[ann["reference"]]
        reference_img = Image.open(reference_img_pth).convert("RGB")
        reference_img = self.transform(reference_img)

        caption = pre_caption(ann["caption"], self.max_words)

        if self.split == "test":
            return reference_img, caption, ann["pairid"]

        target_img_pth = self.id2pth[ann["target_hard"]]
        target_img = Image.open(target_img_pth).convert("RGB")
        target_img = self.transform(target_img)

        return reference_img, target_img, caption, ann["pairid"]


class CIRREmbsDataset(Dataset):
    """CIRR split that pairs reference images with precomputed embedding files.

    Embeddings are looked up as ``*.pth`` files under
    ``<parent of image_root>/blip-embs-{vit}/<split folder>``, mirroring the
    layout of the image folders and keyed by the same file stem.
    """

    def __init__(
        self, transform, image_root, ann_root, split, data="", max_words=30, vit="large"
    ):
        """
        image_root (string): Root directory of images (e.g. CIRR/images/). It should have train, dev, and test1 folders.
        ann_root (string): directory containing the annotation JSON files
        transform (callable): applied to the reference PIL image in ``__getitem__``
        split (string): one of "train", "val" or "test"; any other value raises KeyError
        data (string): optional suffix; when non-empty the train annotations are read
            from ``cap.rc2.train.{data}.json`` instead of ``cap.rc2.train.json``
        max_words (int): forwarded to ``pre_caption`` when processing captions
        vit (string): "base" or "large"; selects the ``blip-embs-{vit}`` folder

        Raises AssertionError if ``vit`` is invalid, if the image or embedding
        folder of the split does not exist, or if an annotated reference/target
        is missing from either folder.
        """
        assert vit in [
            "base",
            "large",
        ], f"vit should be either base or large, not {vit}"
        urls = {
            "train": "cap.rc2.train.json",
            "val": "cap.rc2.val.json",
            "test": "cap.rc2.test1.json",
        }
        if data:
            urls["train"] = f"cap.rc2.train.{data}.json"

        # Use a context manager so the annotation file handle is closed
        # deterministically instead of being leaked until garbage collection.
        ann_path = Path(ann_root, Path(urls[split]).name)
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.annotation = json.load(ann_file)
        self.transform = transform
        self.max_words = max_words
        self.split = split

        # Lookup tables keyed by pair id.
        self.pairid2ref = {ann["pairid"]: ann["reference"] for ann in self.annotation}
        self.pairid2members = {
            ann["pairid"]: ann["img_set"]["members"] for ann in self.annotation
        }
        if split != "test":
            self.pairid2tar = {
                ann["pairid"]: ann["target_hard"] for ann in self.annotation
            }
        else:
            # Test annotations are not read for a target, so there is no mapping.
            self.pairid2tar = None

        # Create a dictionary of image paths
        split_dict = {
            "train": "train",
            "val": "dev",
            "test": "test1",
        }
        self.image_root = Path(image_root) / split_dict[split]
        # Embeddings live next to the image root, not inside it.
        self.emb_root = Path(image_root).parent / f"blip-embs-{vit}" / split_dict[split]

        assert self.image_root.exists(), f"Image root {self.image_root} does not exist"
        assert self.emb_root.exists(), f"Emb root {self.emb_root} does not exist"
        # Train files are looked up one directory level deeper than dev/test1.
        if split == "train":
            img_pths = self.image_root.glob("*/*.png")
            emb_pths = self.emb_root.glob("*/*.pth")
        else:
            img_pths = self.image_root.glob("*.png")
            emb_pths = self.emb_root.glob("*.pth")
        # Image id == file name without extension, for images and embeddings alike.
        self.id2imgpth = {img_pth.stem: img_pth for img_pth in img_pths}
        self.id2embpth = {emb_pth.stem: emb_pth for emb_pth in emb_pths}

        # Fail early if any annotated image or embedding is missing on disk.
        for ann in self.annotation:
            assert (
                ann["reference"] in self.id2imgpth
            ), f"Path to reference {ann['reference']} not found in {self.image_root}"
            assert (
                ann["reference"] in self.id2embpth
            ), f"Path to reference {ann['reference']} not found in {self.emb_root}"
            if split != "test":
                assert (
                    ann["target_hard"] in self.id2imgpth
                ), f"Path to target {ann['target_hard']} not found"
                assert (
                    ann["target_hard"] in self.id2embpth
                ), f"Path to target {ann['target_hard']} not found"

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        For the "test" split:
            ``(reference_img, reference_feat, caption, pairid)``
        Otherwise:
            ``(reference_img, target_feat, caption, pairid, soft_input, soft_target)``
            where ``soft_input`` is ``ann["img_set"]["id"]`` and ``soft_target``
            is the pair id again.

        The reference image is opened as RGB and passed through
        ``self.transform``; features are loaded with ``torch.load``.
        """
        ann = self.annotation[index]

        reference_img_pth = self.id2imgpth[ann["reference"]]
        reference_img = Image.open(reference_img_pth).convert("RGB")
        reference_img = self.transform(reference_img)

        caption = pre_caption(ann["caption"], self.max_words)

        # NOTE(review): torch.load unpickles the file; only point emb_root at
        # trusted embedding files.
        if self.split == "test":
            # NOTE(review): unlike the train/val branch this result is not moved
            # with .cpu(); left as-is to preserve behavior — confirm intent.
            reference_feat = torch.load(self.id2embpth[ann["reference"]])
            return reference_img, reference_feat, caption, ann["pairid"]

        target_emb_pth = self.id2embpth[ann["target_hard"]]
        target_feat = torch.load(target_emb_pth).cpu()

        soft_input = ann["img_set"]["id"]
        soft_target = ann["pairid"]

        return (
            reference_img,
            target_feat,
            caption,
            ann["pairid"],
            soft_input,
            soft_target,
        )
