import os
from PIL import Image
from torch.utils.data import Dataset


class Tiny_ImageNetDataset(Dataset):
    """Tiny ImageNet dataset read from a directory tree under ``root``.

    The loaders below expect this layout:

    * ``wnids.txt`` -- one wnid per line; line order defines the integer labels.
    * ``words.txt`` -- ``wnid<TAB>description`` lines used for class names.
    * ``train/<wnid>/images/*.JPEG`` -- training images grouped by class.
    * ``val/images/`` plus ``val/val_annotations.txt`` -- validation images and a
      tab-separated annotation file whose first two columns are filename, wnid.

    Args:
        root: Dataset root directory.
        split: ``'train'`` or ``'val'``.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__``.

    Attributes:
        class_to_idx: Mapping wnid -> integer label.
        idx_to_class: Inverse mapping integer label -> wnid.
        data: Image file paths.
        targets: Integer labels, parallel to ``data``.
        classes: Human-readable class names indexed by label.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'val'``.
    """

    def __init__(self, root, split='train', transform=None):
        # Validate the split before touching the filesystem so a typo is
        # reported as such rather than as an unrelated I/O error.
        if split not in ('train', 'val'):
            raise ValueError(
                f"Unknown split {split!r}; expected 'train' or 'val'"
            )

        self.root = root
        self.split = split
        self.transform = transform
        self.class_to_idx, self.idx_to_class = self._load_classes()

        if split == 'train':
            self.data, self.targets = self._load_train()
        else:
            self.data, self.targets = self._load_val()

        self.classes = self.get_class_names()

    def _load_classes(self):
        """Read ``wnids.txt`` and build the wnid <-> label mappings.

        Labels are assigned in file order. Blank lines are ignored so that a
        stray empty line does not turn into a bogus class.

        Returns:
            ``(class_to_idx, idx_to_class)``.
        """
        wnid_file = os.path.join(self.root, 'wnids.txt')
        with open(wnid_file, 'r', encoding='utf-8') as f:
            wnids = [line.strip() for line in f if line.strip()]
        class_to_idx = {wnid: idx for idx, wnid in enumerate(wnids)}
        idx_to_class = {idx: wnid for wnid, idx in class_to_idx.items()}
        return class_to_idx, idx_to_class

    def _load_train(self):
        """Collect ``(path, label)`` pairs from ``train/<wnid>/images/*.JPEG``.

        Directory listings are sorted because ``os.listdir`` order is
        arbitrary; sorting makes sample order reproducible. Entries that are
        not class directories, or whose wnid is absent from ``wnids.txt``, are
        skipped (matching how ``_load_val`` ignores unknown wnids).

        Returns:
            ``(data, targets)`` lists of image paths and integer labels.
        """
        data, targets = [], []
        train_dir = os.path.join(self.root, 'train')
        for wnid in sorted(os.listdir(train_dir)):
            label = self.class_to_idx.get(wnid)
            class_dir = os.path.join(train_dir, wnid, 'images')
            if label is None or not os.path.isdir(class_dir):
                continue
            for fname in sorted(os.listdir(class_dir)):
                if fname.endswith('.JPEG'):
                    data.append(os.path.join(class_dir, fname))
                    targets.append(label)
        return data, targets

    def _load_val(self):
        """Collect ``(path, label)`` pairs from ``val/val_annotations.txt``.

        Each annotation line is tab-separated; only the first two columns
        (filename, wnid) are used. Lines with fewer than two columns, or whose
        wnid is not in ``wnids.txt``, are skipped.

        Returns:
            ``(data, targets)`` lists of image paths and integer labels.
        """
        val_dir = os.path.join(self.root, 'val')
        img_dir = os.path.join(val_dir, 'images')
        annot_file = os.path.join(val_dir, 'val_annotations.txt')
        data, targets = [], []
        with open(annot_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    continue
                fname, wnid = parts[0], parts[1]
                label = self.class_to_idx.get(wnid)
                if label is not None:
                    data.append(os.path.join(img_dir, fname))
                    targets.append(label)
        return data, targets

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Load sample ``index`` as ``(image, label)``.

        The image is converted to RGB and then passed through ``transform``
        when one was given.
        """
        img_path = self.data[index]
        label = self.targets[index]
        # Use a context manager so the underlying file handle is closed
        # promptly; convert() returns an independent, fully-loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def get_class_names(self):
        """Return human-readable class names ordered by label.

        Names come from ``words.txt`` (``wnid<TAB>description``). Only the
        first tab separates wnid from description, so descriptions that
        themselves contain tabs are kept intact. Blank or tab-less lines are
        ignored. A wnid with no entry in ``words.txt`` falls back to the wnid
        itself instead of raising ``KeyError``.
        """
        word_file = os.path.join(self.root, 'words.txt')
        wnid_to_name = {}
        with open(word_file, 'r', encoding='utf-8') as f:
            for line in f:
                wnid, sep, desc = line.strip().partition('\t')
                # Only keep entries for classes this dataset actually uses.
                if sep and wnid in self.class_to_idx:
                    wnid_to_name[wnid] = desc
        class_names = []
        for i in range(len(self.idx_to_class)):
            wnid = self.idx_to_class[i]
            class_names.append(wnid_to_name.get(wnid, wnid))
        return class_names
