class CC12MCLIP5KSIGLIPValRealImageDataset(RealImageDataset):
    """CC12M (CLIP5K / SIGLIP) validation split yielding RealImageDatapoints.

    Annotations are read from the pickle file configured under the
    ``CC12M_CLIP5K_SIGLIP`` entry of ``paths.yaml`` (located next to this
    module). Images are read from a job-local folder,
    ``/tmp/job<SLURM_JOB_ID>/eval/<image_id>.jpg``; nothing in this class
    copies them there, so they must already be present.
    """

    # NOTE(review): the base class __init__ is not called here; confirm that
    # RealImageDataset does not need any initialisation.
    def __init__(self, img_size=(256, 256), complexity=1):
        """Load the annotations and set up the image transform.

        Args:
            img_size: Target size passed to ``Image.resize``. Either a
                two-element tuple/list or a single value that is used for
                both dimensions.
            complexity: Key selecting one annotation subset inside the
                unpickled annotation object.

        Raises:
            ValueError: If ``img_size`` is a sequence whose length is not 2.
            KeyError: If ``SLURM_JOB_ID`` is not set in the environment, or
                the expected entries are missing from ``paths.yaml`` / the
                annotation object.
        """
        self.complexity = complexity

        self.IMG_ROOT_PATH, self.ANNOT_VAL_PATH = self.get_dataset_path()

        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point ANNOT_VAL at trusted annotation files.
        with open(self.ANNOT_VAL_PATH, "rb") as f:
            all_labels = pickle.load(f)

        # Keep only the requested subset: a mapping image_id -> annotation.
        self.labels = all_labels[self.complexity]
        self.image_ids = list(self.labels.keys())

        # Lists are normalised to tuples as well; previously a list such as
        # [256, 256] was duplicated into ([256, 256], [256, 256]) and only
        # failed later inside resize().
        if isinstance(img_size, (tuple, list)):
            self.img_size = tuple(img_size)
        else:
            self.img_size = (img_size, img_size)
        if len(self.img_size) != 2:
            raise ValueError(
                f"img_size must have exactly 2 elements, got {self.img_size!r}"
            )

        # ToTensor followed by t * 2 - 1 (maps [0, 1] to [-1, 1], assuming
        # ToTensor yields values in [0, 1]).
        self.to_tensor = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Lambda(lambda t: (t * 2) - 1)]
        )

        # Job-local folder the evaluation images are read from.
        self.targeted_folder = f"/tmp/job{os.environ['SLURM_JOB_ID']}"

    def get_dataset_path(self) -> tuple[str, str]:
        """Return ``(IMG_ROOT_VAL, ANNOT_VAL)`` read from ``paths.yaml``.

        The YAML file is looked up in the directory containing this module
        and must have a ``CC12M_CLIP5K_SIGLIP`` section with both keys.
        """
        config_path = os.path.join(os.path.dirname(__file__), "paths.yaml")

        with open(config_path) as f:
            data = yaml.safe_load(f)

        paths = data["CC12M_CLIP5K_SIGLIP"]
        return paths["IMG_ROOT_VAL"], paths["ANNOT_VAL"]

    def __getitem__(self, idx) -> RealImageDatapoint:
        """Load, resize and tensorise the image at position ``idx``.

        Returns:
            A ``RealImageDatapoint`` whose ``image`` is the transformed
            tensor and whose ``class_label`` is the annotation's ``"caps"``
            entry.

        Raises:
            SystemExit: If the image file is missing from the job-local
                folder (treated as a failed data copy; exit status 1).
            Exception: Any other loading error is logged and re-raised.
        """
        image_id = self.image_ids[idx]
        item = self.labels[image_id]
        img_path = f"{self.targeted_folder}/eval/{image_id}.jpg"

        try:
            img = pil_loader_v2(Path(img_path), max_size=self.img_size)
        except FileNotFoundError:
            print("DATA COPY failed")
            # Hard stop, as before, but with a failing exit status and
            # without relying on the site-provided exit() helper.
            raise SystemExit(1)
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            # Re-raise: continuing would hit an unbound `img` below and
            # hide the real cause behind an UnboundLocalError.
            raise

        img_resized = img.resize(self.img_size).convert("RGB")
        img_tensor = self.to_tensor(img_resized)

        return RealImageDatapoint(image=img_tensor,
                                  class_label=item["caps"])

    def __len__(self) -> int:
        """Return the number of images in the selected annotation subset."""
        return len(self.image_ids)