import os
import tarfile

import numpy as np
import requests
from PIL import Image
from scipy.io import loadmat
from torch.utils.data import Dataset
from torchvision import transforms

def _download_file(url: str, dest: str, timeout: float) -> None:
    """Stream ``url`` into ``dest`` via a temporary ``dest + '.part'`` file.

    The final ``os.replace`` happens only after the whole body was written, so
    an interrupted download never leaves a truncated file at ``dest`` that the
    caller's existence checks would mistake for a complete one.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    except BaseException:
        # Clean up the partial file (also on Ctrl-C), then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _extract_tar(archive: str, dest: str) -> None:
    """Extract the gzipped tar ``archive`` into ``dest``, refusing members that escape it."""
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # PEP 706 extraction filter (3.12+, backported to security releases);
            # see the tarfile docs for exactly what the "data" filter rejects.
            tar.extractall(path=dest, filter="data")
            return
        # Fallback for older interpreters: reject members resolving outside dest.
        base = os.path.realpath(dest)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(base, member.name))
            if os.path.commonpath([base, target]) != base:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r} outside {dest}"
                )
        tar.extractall(path=dest)


def download_and_extract_flower17_dataset(
    root: str = "data/flowers17", timeout: float = 60.0
):
    """Download and extract the Oxford 17 Flowers images and splits into ``root``.

    On success ``root`` contains ``jpg/`` (extracted from ``17flowers.tgz``)
    and ``datasplits.mat``. Each step is skipped when its output is already
    present, so repeated calls do no redundant work. The archive is not
    downloaded at all when ``jpg/`` already exists.

    Args:
        root: Target directory; created if missing.
        timeout: Passed as ``timeout`` to every ``requests.get`` call (seconds).

    Raises:
        requests.RequestException: A download failed (e.g. ``HTTPError`` from
            ``raise_for_status`` or a timeout).
        tarfile.TarError: The archive is unreadable or contains a member that
            would be written outside ``root``.
    """
    os.makedirs(root, exist_ok=True)

    # 1) + 2) Images: only fetch/extract the archive if jpg/ is missing.
    images_url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz"
    images_tgz = os.path.join(root, "17flowers.tgz")
    extract_dir = os.path.join(root, "jpg")
    if os.path.isdir(extract_dir):
        print("Images already extracted.")
    else:
        if os.path.exists(images_tgz):
            print("Images archive already exists.")
        else:
            print("Downloading 17flowers.tgz...")
            _download_file(images_url, images_tgz, timeout)
        print("Extracting images...")
        # The archive is expected to unpack into root/jpg/.
        _extract_tar(images_tgz, root)

    # 3) Split definitions.
    splits_url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/17/datasplits.mat"
    splits_mat = os.path.join(root, "datasplits.mat")
    if os.path.exists(splits_mat):
        print("datasplits.mat already exists.")
    else:
        print("Downloading datasplits.mat...")
        _download_file(splits_url, splits_mat, timeout)

    print("Dataset ready in", root)




class Flower17Dataset(Dataset):
    """Oxford 17 Category Flower dataset read from a local directory.

    Expects ``root/datasplits.mat`` and ``root/jpg/image_XXXX.jpg``, which is
    the layout produced by ``download_and_extract_flower17_dataset``. Items
    are ``(transform(rgb_image), label)``, where ``label`` is derived from the
    1-based image number assuming contiguous blocks of 80 images per class
    (labels 0..16).
    """

    _SPLITS = ("train", "val", "trainval", "test")
    _SPLIT_IDS = (1, 2, 3)
    _IMAGES_PER_CLASS = 80  # image_0001..0080 -> class 0, 0081..0160 -> class 1, ...
    # Which datasplits.mat variable prefixes make up each split.
    _SPLIT_PREFIXES = {
        "train": ("trn",),
        "val": ("val",),
        "test": ("tst",),
        "trainval": ("trn", "val"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        split_id: int = 1,
        transform=None,
        image_size: int = 256
    ):
        """Index the images belonging to one split.

        Args:
            root: Directory containing ``datasplits.mat`` and ``jpg/``.
            split: One of ``'train'``, ``'val'``, ``'trainval'``, ``'test'``.
            split_id: Which of the three official splits to use (1, 2 or 3).
            transform: Callable applied to each PIL image. Defaults to resizing
                to ``(image_size, image_size)`` followed by ``ToTensor()``.
            image_size: Side length used only by the default transform.

        Raises:
            ValueError: ``split`` or ``split_id`` is not one of the allowed values.
            KeyError: The requested split variable is missing from the .mat file.
            FileNotFoundError: ``jpg/`` or an expected image file is missing.
            RuntimeError: The split contains no images.
        """
        # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}")
        if split_id not in self._SPLIT_IDS:
            raise ValueError("split_id must be 1, 2 or 3")

        self.root = root
        self.transform = transform or transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

        mat = loadmat(
            os.path.join(root, "datasplits.mat"),
            squeeze_me=True,
            struct_as_record=False,
        )
        indices = self._split_indices(mat, split, split_id)

        img_dir = os.path.join(root, "jpg")
        if not os.path.isdir(img_dir):
            raise FileNotFoundError(f"No 'jpg/' folder at {img_dir}")

        self.samples = []
        self.labels = []
        for index in indices:  # 1-based image numbers
            path = os.path.join(img_dir, f"image_{index:04d}.jpg")
            if not os.path.isfile(path):
                raise FileNotFoundError(f"Expected image at {path}")
            self.samples.append(path)
            self.labels.append((index - 1) // self._IMAGES_PER_CLASS)

        if not self.samples:
            raise RuntimeError(f"No images found for split={split}{split_id}")

    @classmethod
    def _split_indices(cls, mat, split, split_id):
        """Return the 1-based image numbers for ``split``/``split_id`` as a list of ints."""
        indices = []
        for prefix in cls._SPLIT_PREFIXES[split]:
            key = f"{prefix}{split_id}"
            if key not in mat:
                raise KeyError(f"Split '{key}' not found in datasplits.mat")
            # squeeze_me=True turns a one-element array into a scalar;
            # atleast_1d restores an iterable in that case.
            indices.extend(int(i) for i in np.atleast_1d(mat[key]).ravel())
        return indices

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``self.transform``, and return ``(image, label)``."""
        # Context manager guarantees the underlying file handle is closed;
        # convert() produces a new, fully loaded image first.
        with Image.open(self.samples[idx]) as im:
            img = im.convert("RGB")
        return self.transform(img), self.labels[idx]