import os
import zipfile
import json

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


def download_and_extract_butterfly_dataset(
    out_dir="data/butterfly",
    dataset="veeralakrishna/butterfly-dataset",
    kaggle_json=None,
):
    """Download the Leeds Butterfly dataset from Kaggle and unzip it into ``out_dir``.

    Steps are skipped when their result already exists:
    * if ``<out_dir>/leedsbutterfly`` exists, the function returns immediately;
    * if the archive ``<out_dir>/<dataset-name>.zip`` exists, the download is
      skipped and only extraction runs.
    The archive is deleted after it has been extracted.

    Args:
        out_dir: Directory that receives the archive and the extracted files.
        dataset: Kaggle dataset slug ``"<owner>/<name>"``. The downloaded
            archive is expected to be called ``<name>.zip``.
        kaggle_json: Explicit path to a Kaggle credentials file. When ``None``,
            ``./.kaggle/kaggle.json`` is tried first (legacy behaviour), then
            ``~/.kaggle/kaggle.json``.

    Raises:
        FileNotFoundError: A download is required but no credentials file exists.
    """
    extracted_dir = os.path.join(out_dir, "leedsbutterfly")
    if os.path.exists(extracted_dir):
        print("Dataset already downloaded and unzipped.")
        return

    zip_path = os.path.join(out_dir, dataset.rsplit("/", 1)[-1] + ".zip")

    if os.path.exists(zip_path):
        print("Dataset already downloaded.")
    else:
        # 1. Locate credentials. The cwd-relative path is kept first for
        #    backward compatibility; the home directory is where the error
        #    message (and the Kaggle CLI) expect the file.
        if kaggle_json is not None:
            candidates = [kaggle_json]
        else:
            candidates = [
                os.path.join(".kaggle", "kaggle.json"),
                os.path.expanduser(os.path.join("~", ".kaggle", "kaggle.json")),
            ]
        cred_path = next((p for p in candidates if os.path.exists(p)), None)
        if cred_path is None:
            if kaggle_json is not None:
                raise FileNotFoundError(f"Kaggle credentials file not found: {kaggle_json}")
            raise FileNotFoundError("Please place your kaggle.json in the .kaggle folder in your home directory.")
        with open(cred_path, encoding="utf-8") as f:
            cred = json.load(f)

        # 2. Export env vars *before* importing/authenticating the Kaggle API.
        os.environ["KAGGLE_USERNAME"] = cred["username"]
        os.environ["KAGGLE_KEY"] = cred["key"]

        # 3. Authenticate (imported lazily so the dependency is only needed
        #    when a download actually happens).
        from kaggle.api.kaggle_api_extended import KaggleApi
        api = KaggleApi()
        api.authenticate()

        # 4. Download the archive without unzipping; extraction happens below.
        os.makedirs(out_dir, exist_ok=True)
        print("Downloading dataset...")
        api.dataset_download_files(dataset, path=out_dir, unzip=False, quiet=False)

    # 5. Unzip, then remove the archive.
    print('unzipping files...')
    print("Extracting files from {}...".format(zip_path))
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(out_dir)
    print('Removing zip file...')
    os.remove(zip_path)
    print("Done. Files are in", out_dir)
    

class ButterflyDataset(Dataset):
    """
    Leeds Butterfly Dataset (v1.0)

    Expects:
        root/
            images/
                001xxx.png, 001yyy.png, 002xxx.png, ...
            # segmentations/  (not used here)

    The class of an image is the first three characters of its filename.
    Within each class, files are ordered by name; the first ``train_fraction``
    (default 85%) form the "train" split and the rest form the "test" split.

    Each item is ``(image, label)``: ``image`` is the output of ``transform``
    applied to the RGB PIL image, and ``label`` is the class's index in the
    sorted ``class_ids``.
    """

    # Lower-cased filename suffixes treated as images.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform=None,
        image_size: int = 256,
        preload: bool = True,
        train_fraction: float = 0.85,
    ):
        """
        Args:
            root: path to 'leedsbutterfly' folder
            split: "train" or "test"
            transform: callable applied to the RGB PIL image. Defaults to
                Resize((image_size, image_size)) followed by ToTensor().
            image_size: side length used by the default transform only.
            preload: if True (default), every image is loaded and transformed
                once here and cached in ``self.samples``. This means random
                augmentations are applied only once. Pass False to load and
                transform lazily on each ``__getitem__``.
            train_fraction: fraction of each class assigned to "train",
                in [0, 1].

        Raises:
            FileNotFoundError: ``root/images`` is not a directory.
            RuntimeError: the requested split contains no images.
            ValueError: ``train_fraction`` is outside [0, 1].
        """
        assert split in ("train", "test"), "split must be 'train' or 'test'"
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")
        self.root = root
        self.split = split
        self.preload = preload
        # `is not None` rather than `or`, so a falsy-but-valid transform
        # (e.g. an empty nn.Sequential) is not silently replaced.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
            ])

        img_dir = os.path.join(root, "images")
        class_to_files = self._group_by_class(img_dir)

        # Sorted class ids map to contiguous integer labels 0..K-1.
        self.class_ids = sorted(class_to_files)
        self.class_to_idx = {cid: idx for idx, cid in enumerate(self.class_ids)}

        # Per-class split by filename order (deterministic, no shuffling).
        self.paths = []
        self.labels = []
        for cid in self.class_ids:
            files = class_to_files[cid]
            n_train = int(len(files) * train_fraction)
            chosen = files[:n_train] if split == "train" else files[n_train:]
            self.paths += chosen
            self.labels += [self.class_to_idx[cid]] * len(chosen)

        if not self.paths:
            raise RuntimeError(f"No samples found for split={split} in {img_dir}")

        # Eager cache of transformed images (legacy default), or None when lazy.
        self.samples = [self._load(p) for p in self.paths] if preload else None

    @classmethod
    def _group_by_class(cls, img_dir):
        """Return {3-char filename prefix: [image paths sorted by filename]}."""
        if not os.path.isdir(img_dir):
            raise FileNotFoundError(f"Expected images folder at {img_dir}")
        class_to_files = {}
        for fname in sorted(os.listdir(img_dir)):
            if not fname.lower().endswith(cls._IMAGE_EXTENSIONS):
                continue
            class_id = fname[:3]  # "001", "002", …, "010"
            class_to_files.setdefault(class_id, []).append(os.path.join(img_dir, fname))
        return class_to_files

    def _load(self, path):
        """Open the image at `path`, convert it to RGB, and apply the transform."""
        # Context manager guarantees the underlying file handle is closed.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        if self.samples is not None:
            return self.samples[idx], label
        return self._load(self.paths[idx]), label