import os
from os.path import join
import random
import xml.etree.ElementTree as ET

import torchvision.transforms
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms

class Imagenet(Dataset):
    """
    ImageNet validation split as a map-style ``Dataset``.

    Labels come from one XML file per image under
    ``<dir>/Annotations/CLS-LOC/val``, each shaped like::

      <annotation>
        <filename>ILSVRC2012_val_00000001</filename>
        <object>
          <name>n01440764</name>
          ...
        </object>
      </annotation>

    Only the first ``<object>``'s ``<name>`` is used as the label, and
    ``<filename>`` may be given with or without the ``.JPEG`` extension.
    To avoid re-parsing every XML file on each run, the parsed
    ``(filename, class_id)`` pairs are cached in
    ``<dir>/Annotations/val_annotations.txt`` (one ``"<filename> <class_id>"``
    pair per line) the first time, and read back from there afterwards.

    Items are ``(image, label)``. ``label`` is the position of the class id
    in the sorted list of all class ids present. ``image`` is an RGB PIL
    image passed through ``transform``.
    """
    def __init__(self, dir, transform=None):
        # NOTE: ``dir`` shadows the builtin but is kept so keyword callers still work.
        self.images_dir = join(dir, 'Data', 'CLS-LOC', 'val')
        self.annotations_file = join(dir, 'Annotations', 'val_annotations.txt')
        self.xml_dir = join(dir, 'Annotations', 'CLS-LOC', 'val')
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor()
            ])

        # Build the text cache from the XMLs on first use.
        if not os.path.exists(self.annotations_file):
            print(f"[Info] '{self.annotations_file}' not found – parsing XMLs and creating it...")
            samples = self._parse_xml_annotations()
            self._write_annotations(samples)
            print(f"[Info] Wrote {len(samples)} entries to '{self.annotations_file}'.")
        else:
            print(f"[Info] Loading samples from '{self.annotations_file}'.")

        # Always load from the cache file so both paths produce identical samples.
        self.samples = self._read_annotations()

        # Map each class id to a contiguous integer label (sorted order).
        classes = sorted({cid for _, cid in self.samples})
        self.class_to_idx = {cid: i for i, cid in enumerate(classes)}

    def _parse_xml_annotations(self):
        """Return ``(filename, class_id)`` pairs from every ``.xml`` in ``xml_dir``.

        Raises:
            ValueError: if an XML lacks ``<filename>`` or ``<object>/<name>``.
        """
        samples = []
        # sorted() makes the sample order, and therefore dataset indices,
        # independent of the filesystem's directory listing order.
        for fname in sorted(os.listdir(self.xml_dir)):
            if not fname.endswith('.xml'):
                continue
            xml_path = join(self.xml_dir, fname)
            root = ET.parse(xml_path).getroot()
            filename = root.findtext('filename')
            obj = root.find('object')
            class_id = obj.findtext('name') if obj is not None else None
            if not filename or not class_id:
                raise ValueError(
                    f"'{xml_path}' is missing <filename> or <object>/<name>")
            samples.append((filename, class_id))
        return samples

    def _write_annotations(self, samples):
        """Write ``samples`` to ``annotations_file``, one ``"fn cid"`` per line."""
        tmp_path = self.annotations_file + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as out:
            out.writelines(f"{fn} {cid}\n" for fn, cid in samples)
        # Rename only after a complete write, so an interrupted run cannot
        # leave a truncated cache that later runs would silently trust.
        os.replace(tmp_path, self.annotations_file)

    def _read_annotations(self):
        """Load ``(filename, class_id)`` pairs from ``annotations_file``.

        Blank lines are skipped.

        Raises:
            ValueError: on a non-blank line that is not exactly two fields.
        """
        samples = []
        with open(self.annotations_file, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(parts) != 2:
                    raise ValueError(
                        f"{self.annotations_file}:{lineno}: expected "
                        f"'<filename> <class_id>', got {line!r}")
                samples.append((parts[0], parts[1]))
        return samples

    def __len__(self):
        return len(self.samples)

    def _image_path(self, fn):
        """Return the image path for ``fn``, adding ``.JPEG`` if ``fn`` has no extension."""
        path = join(self.images_dir, fn)
        if not os.path.splitext(fn)[1]:
            path += '.JPEG'
        return path

    def __getitem__(self, idx):
        fn, cid = self.samples[idx]
        # The context manager closes the underlying file. convert() loads the
        # pixel data first, so the returned image stays usable after close.
        with Image.open(self._image_path(fn)) as im:
            img = im.convert('RGB')
        label = self.class_to_idx[cid]
        if self.transform:
            img = self.transform(img)
        return img, label



def get_imagenet_loaders(imagenet_path, batch_size, seed=42, num_workers=0):
    """
    Build two disjoint, one-image-per-class loaders over the ImageNet val split.

    For every class, two distinct images are drawn at random. The first goes
    into subset 1 and the second into subset 2, so the subsets never share an
    image and each holds exactly one image per class.

    Args:
        imagenet_path: Dataset root passed to ``Imagenet``.
        batch_size: Batch size for both loaders.
        seed: Seed for the per-class draw; the same seed gives the same subsets.
        num_workers: ``DataLoader`` worker processes (0 loads in the main process).

    Returns:
        ``(loader1, loader2)``: unshuffled ``DataLoader``s over the two subsets.

    Raises:
        ValueError: if the dataset is empty or any class has fewer than 2 images.
    """
    val_dataset = Imagenet(imagenet_path)

    # Group dataset indices by class id, in first-seen order.
    class_to_indices = {}
    for idx, (_, class_id) in enumerate(val_dataset.samples):
        class_to_indices.setdefault(class_id, []).append(idx)

    # Sanity check. Guard the average below against division by zero.
    num_classes = len(class_to_indices)
    if num_classes == 0:
        raise ValueError(f"No samples found under '{imagenet_path}'.")
    print(f"Found {num_classes} classes, "
          f"each with on average {len(val_dataset) // num_classes} images.")

    # A private RNG keeps the draw reproducible without reseeding the global
    # `random` module. Random(seed) produces the same draws that
    # random.seed(seed) followed by random.sample(...) would.
    rng = random.Random(seed)
    subset1_idxs = []
    subset2_idxs = []
    for class_id, idxs in class_to_indices.items():
        if len(idxs) < 2:
            raise ValueError(f"Class {class_id} has only {len(idxs)} images (<2)")
        first, second = rng.sample(idxs, 2)
        subset1_idxs.append(first)
        subset2_idxs.append(second)

    loader1 = DataLoader(Subset(val_dataset, subset1_idxs),
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=num_workers)
    loader2 = DataLoader(Subset(val_dataset, subset2_idxs),
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=num_workers)
    return loader1, loader2


if __name__ == "__main__":
    # Script entry point: build both loaders from a path given on the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Build two disjoint one-image-per-class ImageNet val loaders.")
    parser.add_argument(
        "imagenet_path",
        help="dataset root containing 'Data/CLS-LOC/val' and 'Annotations/'")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="batch size for both loaders (default: 32)")
    args = parser.parse_args()

    loader1, loader2 = get_imagenet_loaders(args.imagenet_path, args.batch_size)
    print("Subset1:", len(loader1.dataset), "images")
    print("Subset2:", len(loader2.dataset), "images")
