"""
Custom Tiny-ImageNet dataset loader with consistent class mapping
"""

import os
from PIL import Image
import torch
from torch.utils.data import Dataset


class TinyImageNet(Dataset):
    """
    Custom Tiny-ImageNet loader that ensures consistent class-to-index mapping
    between train and validation sets.

    The class list is always derived from the sub-directories of
    ``<root>/train`` (sorted by name), so a validation dataset built from the
    same ``root`` assigns the same integer label to each class as the
    training dataset does.

    Attributes:
        classes: Sorted list of class directory names found under ``train/``.
        class_to_idx: Mapping from class name to integer label.
        data: List of image file paths, one per sample.
        targets: List of integer labels, parallel to ``data``.
    """

    # Lower-case suffixes; matching is case-insensitive, so the upper-case
    # '.JPEG' extension used by Tiny-ImageNet is covered by '.jpeg'.
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png', '.bmp', '.tiff')

    def __init__(self, root, train=True, transform=None, verbose=True):
        """
        Args:
            root: Path to tiny-imagenet-200 directory
            train: If True, load training set, else validation set
            transform: Transform to apply to images
            verbose: If True (default), print diagnostic messages while
                indexing the dataset; pass False to silence them
        """
        self.root = root
        self.train = train
        self.transform = transform
        self.verbose = verbose

        # Load class names from training directory (consistent ordering)
        train_dir = os.path.join(root, 'train')
        self._log(f"🔍 DEBUG: Looking for train dir at: {train_dir}")
        self._log(f"🔍 DEBUG: Root dir exists: {os.path.exists(root)}")
        self._log(f"🔍 DEBUG: Train dir exists: {os.path.exists(train_dir)}")

        if os.path.isdir(train_dir):
            # Only real, non-hidden sub-directories are classes. Stray files
            # (.DS_Store, README, ...) or hidden folders would otherwise be
            # counted as classes and shift every label index.
            self.classes = sorted(
                name
                for name in os.listdir(train_dir)
                if not name.startswith('.')
                and os.path.isdir(os.path.join(train_dir, name))
            )
            self._log(f"🔍 DEBUG: Found {len(self.classes)} classes: {self.classes[:5]}...")
        else:
            # Best effort: keep going with an empty class list so the caller
            # gets an empty dataset instead of an exception.
            self._log(f"❌ DEBUG: Train directory not found at {train_dir}")
            root_contents = os.listdir(root) if os.path.isdir(root) else 'Root not found'
            self._log(f"🔍 DEBUG: Contents of root ({root}): {root_contents}")
            self.classes = []

        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Load data
        self.data = []
        self.targets = []

        if train:
            self._load_train_data()
        else:
            self._load_val_data()

        self._log(f"🔍 DEBUG: TinyImageNet loaded - Mode: {'train' if train else 'val'}, Samples: {len(self.data)}")

    def _log(self, message):
        """Print a diagnostic message unless the dataset was built with verbose=False."""
        # getattr keeps this usable on instances created without __init__.
        if getattr(self, 'verbose', True):
            print(message)

    @classmethod
    def _is_image(cls, file_name):
        """Return True if file_name ends with a known image extension (case-insensitive)."""
        return file_name.lower().endswith(cls.IMAGE_EXTENSIONS)

    def _add_images(self, directory, class_idx):
        """Append every image file in directory to data/targets with label class_idx.

        File names are visited in sorted order because os.listdir makes no
        ordering guarantee; sorting keeps sample indices reproducible.

        Returns:
            Number of images added.
        """
        added = 0
        for img_name in sorted(os.listdir(directory)):
            if self._is_image(img_name):
                self.data.append(os.path.join(directory, img_name))
                self.targets.append(class_idx)
                added += 1
        return added

    def _load_train_data(self):
        """Load training data from train/<class>/images/ (or train/<class>/ as a fallback)."""
        train_dir = os.path.join(self.root, 'train')
        total_samples = 0

        for class_name in self.classes:
            class_dir = os.path.join(train_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            class_idx = self.class_to_idx[class_name]

            # TinyImageNet structure: each class has an 'images' subdirectory.
            # Fall back to the class directory itself for flattened layouts.
            images_dir = os.path.join(class_dir, 'images')
            source_dir = images_dir if os.path.isdir(images_dir) else class_dir

            class_samples = self._add_images(source_dir, class_idx)

            # Only report the first few classes to keep the output short.
            if class_idx < 3:
                self._log(f"🔍 DEBUG: Class {class_name}: found {class_samples} images in {source_dir}")

            total_samples += class_samples

        self._log(f"🔍 DEBUG: Total training samples loaded: {total_samples}")

    def _load_val_data(self):
        """Load validation data - TinyImageNet val has images/ folder + val_annotations.txt

        Each annotation line is tab-separated; the first field is the image
        file name and the second is the class name. Images without an
        annotation, or annotated with a class that is not present under
        train/, are skipped.
        """
        val_dir = os.path.join(self.root, 'val')
        images_dir = os.path.join(val_dir, 'images')
        annotations_file = os.path.join(val_dir, 'val_annotations.txt')

        has_images = os.path.isdir(images_dir)
        has_annotations = os.path.isfile(annotations_file)
        self._log(f"🔍 DEBUG: Val images dir: {images_dir} (exists: {has_images})")
        self._log(f"🔍 DEBUG: Val annotations: {annotations_file} (exists: {has_annotations})")

        if not (has_images and has_annotations):
            # Best effort: leave the dataset empty rather than raising.
            self._log("❌ Validation data not found - missing images dir or annotations file")
            return

        # Read annotations file to get image -> class mapping
        img_to_class = {}
        with open(annotations_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 2:
                    img_to_class[parts[0]] = parts[1]

        self._log(f"🔍 DEBUG: Found {len(img_to_class)} annotations")

        # Load images in sorted order so sample indices are reproducible.
        loaded_classes = set()
        for img_name in sorted(os.listdir(images_dir)):
            if not self._is_image(img_name):
                continue
            class_idx = self.class_to_idx.get(img_to_class.get(img_name))
            if class_idx is None:
                continue
            self.data.append(os.path.join(images_dir, img_name))
            self.targets.append(class_idx)
            loaded_classes.add(class_idx)

        if self.data:
            self._log(f"TinyImageNet val loader: loaded {len(self.data)} images from {len(loaded_classes)} classes")
            self._log(f"Label range: {min(self.targets)} to {max(self.targets)}")
            self._log(f"First 10 class indices with data: {sorted(loaded_classes)[:10]}")
        else:
            self._log("❌ No validation images loaded!")
            self._log(f"Sample annotations: {list(img_to_class.items())[:3]}")
            self._log(f"Sample class names: {list(self.class_to_idx.keys())[:3]}")

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (image, target) pair at position idx.

        The image is opened from disk on every call, converted to RGB, and
        passed through ``transform`` when one was provided.
        """
        img_path = self.data[idx]
        target = self.targets[idx]

        # convert() forces the pixel data to be read while the file is open.
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __repr__(self):
        return f"TinyImageNet(root={self.root}, train={self.train}, samples={len(self.data)})"