import os
from pathlib import Path

from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
from torchvision import transforms
from PIL import Image


class CUB200(VisionDataset):
    """CUB-200-2011 images paired with per-image segmentation masks.

    On construction the image archive and the segmentation archive are
    downloaded and extracted under ``root`` if they are not already present,
    and the metadata files are parsed into ``self.samples``, a list of
    ``(image_path, segmentation_path, label)`` tuples for the chosen split.

    Args:
        root: Directory that holds (or will hold) the extracted dataset.
        split: ``"train"`` or ``"test"`` (case-insensitive).
        transform: Optional callable applied to the RGB ``PIL.Image``.
        segmentation_transform: Optional callable applied to the
            single-channel (mode ``"L"``) segmentation ``PIL.Image``.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """

    image_url = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz"
    seg_url = "https://data.caltech.edu/records/w9d68-gec53/files/segmentations.tgz?download=1"
    image_tgz = "CUB_200_2011.tgz"
    seg_tgz = "segmentations.tgz"
    extracted_folder = "CUB_200_2011"

    # Splits that _load_metadata knows how to select.
    _splits = ("train", "test")
    # Metadata files read by _load_metadata; all must exist to skip the download.
    _metadata_files = ("images.txt", "image_class_labels.txt", "train_test_split.txt")

    def __init__(self, root, split="train", transform=None, segmentation_transform=None):
        super().__init__(root, transform=transform)
        self.segmentation_transform = segmentation_transform
        self.split = split.lower()
        # Fail early: an unknown split would otherwise silently produce an
        # empty dataset (after a potentially long download).
        if self.split not in self._splits:
            raise ValueError(
                f"split must be one of {self._splits}, got {split!r}"
            )
        # Expand "~" so the Path-based root keeps user-directory support.
        self.root = Path(root).expanduser()
        self.dataset_dir = self.root / self.extracted_folder
        self.segmentation_dir = self.dataset_dir / "segmentations"
        self._download()
        self._load_metadata()

    def _metadata_exists(self) -> bool:
        """Return True when every metadata file read by ``_load_metadata`` exists."""
        return all((self.dataset_dir / name).exists() for name in self._metadata_files)

    def _check_exists(self) -> bool:
        """Return True when both the metadata and the segmentation folder are present.

        Segmentation paths are resolved directly under ``segmentation_dir``
        (see ``_load_metadata``), so that directory itself is what is checked.
        """
        return self._metadata_exists() and self.segmentation_dir.is_dir()

    def _download(self):
        """Download and extract whichever of the two archives is missing."""
        if not self._metadata_exists():
            # Image archive is extracted into ``root``.
            download_and_extract_archive(
                self.image_url,
                download_root=str(self.root),
                filename=self.image_tgz,
            )

        if not self.segmentation_dir.is_dir():
            # Segmentation archive is extracted inside the dataset folder.
            download_and_extract_archive(
                self.seg_url,
                download_root=str(self.dataset_dir),
                filename=self.seg_tgz,
            )

    @staticmethod
    def _read_table(path):
        """Parse a whitespace-separated ``<int id> <value>`` file into a dict.

        Blank lines are skipped; values are returned as strings.
        """
        table = {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                key, value = fields
                table[int(key)] = value
        return table

    def _load_metadata(self):
        """Populate ``self.samples`` for the selected split from the metadata files."""
        image_map = self._read_table(self.dataset_dir / "images.txt")
        # Labels in the file are shifted down by one to make them zero-based.
        labels = {
            k: int(v) - 1
            for k, v in self._read_table(self.dataset_dir / "image_class_labels.txt").items()
        }
        split_info = {
            k: int(v)
            for k, v in self._read_table(self.dataset_dir / "train_test_split.txt").items()
        }

        want_train = self.split == "train"
        self.samples = []
        for idx, rel_path in image_map.items():
            # A flag of 1 marks a training image; anything else is test.
            is_train = split_info[idx] == 1
            if is_train != want_train:
                continue
            img_path = self.dataset_dir / "images" / rel_path
            # Masks mirror the image layout; only the final extension differs.
            seg_path = (self.segmentation_dir / rel_path).with_suffix(".png")
            self.samples.append((img_path, seg_path, labels[idx]))

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            A dict with keys ``"x"`` (the image, after ``transform`` if set),
            ``"y"`` (the zero-based integer label) and ``"segmentation"``
            (the result of ``mask > 0``). If the mask is still a
            ``PIL.Image`` after the optional ``segmentation_transform``, it
            is converted to a tensor first so the comparison is well defined.
        """
        img_path, seg_path, label = self.samples[index]

        # ``convert`` returns an independent in-memory copy, so the file
        # handles can be closed immediately instead of waiting for GC.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(seg_path) as raw_mask:
            segmentation = raw_mask.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        if self.segmentation_transform is not None:
            segmentation = self.segmentation_transform(segmentation)

        # PIL images do not support ``> 0``; fall back to a tensor so the
        # mask can be thresholded even without a segmentation transform.
        if isinstance(segmentation, Image.Image):
            segmentation = transforms.PILToTensor()(segmentation)

        return {
            "x": image,
            "y": label,
            "segmentation": segmentation > 0,
        }

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

