import os
import random
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms
from PIL import Image


class RescaledTinyImageNet100:
    """Tiny ImageNet restricted to its first 100 classes (sorted by class ID).

    Builds train/validation datasets and DataLoaders from the hard-coded
    directory ``./data/tiny-imagenet-100-rescaled/``, which must contain
    ``words.txt``, ``train/<class_id>/...`` image folders, and
    ``val/val_annotations.txt`` plus ``val/images/``.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('./tiny-imagenet-200'),
                 batch_size=128,
                 num_workers=16,
                 random_seed=42):
        """Build all datasets and loaders eagerly.

        Args:
            preprocess: Transform applied to every image (train and val).
            location: Accepted for interface compatibility but NOT used; data
                is always read from ``./data/tiny-imagenet-100-rescaled/``.
            batch_size: Batch size for every DataLoader created here.
            num_workers: Worker count for every DataLoader created here.
            random_seed: Seed used to pick the fast-validation subset.
        """
        self.random_seed = random_seed

        # NOTE(review): ``location`` is ignored and this path is hard-coded.
        # Wiring ``location`` through would change which directory existing
        # callers read from, so confirm call sites before changing it.
        self.location = './data/tiny-imagenet-100-rescaled/'

        self.transform = preprocess

        # class ID -> human-readable name, from words.txt
        self.classname_mapping = self.load_classname_mapping()

        self.train_dataset, self.train_classnames = self.create_dataset('train')
        self.test_dataset, self.test_classnames = self.create_val_dataset()

        # Create data loaders
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True,
                                       num_workers=num_workers, pin_memory=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False,
                                      num_workers=num_workers, pin_memory=True)
        self.test_loader_shuffle = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True,
                                              num_workers=num_workers)

        # Fast validation dataset and loader (10% of validation data)
        self.fast_test_dataset, self.fast_test_loader = self.get_fast_val_loader(batch_size, num_workers)

        # Class names in label order, for easy access
        self.classnames = self.test_classnames

    def load_classname_mapping(self):
        """Return a dict mapping class IDs to names, read from ``words.txt``.

        Each non-blank line holds a class ID and a name separated by a tab.
        Only the first tab splits the line, so the name is kept intact.
        Blank lines (e.g. a trailing newline) are skipped. A line without a
        tab still raises ValueError so that malformed files surface.
        """
        words_file = os.path.join(self.location, 'words.txt')
        classname_mapping = {}
        with open(words_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                class_id, class_name = line.split('\t', 1)
                classname_mapping[class_id] = class_name
        return classname_mapping

    def create_dataset(self, split):
        """Return ``(Subset, classnames)`` for the first 100 classes of ``split``.

        Labels are ImageFolder indices, i.e. positions in the sorted list of
        all class folders under ``<location>/<split>``. The kept classes are
        the first 100 in that order, so their labels are 0..99. Only files
        whose name ends in ``.jpeg`` (case-sensitive) are loaded.
        """
        dataset_dir = os.path.join(self.location, split)
        dataset = datasets.ImageFolder(
            root=dataset_dir,
            transform=self.transform,
            is_valid_file=lambda x: x.endswith(".jpeg")
        )

        class_to_idx = dataset.class_to_idx
        first_100_classes = sorted(class_to_idx.keys())[:100]  # Sort to ensure consistency
        first_100_classes_idx = {class_to_idx[cls] for cls in first_100_classes}

        # Keep only samples whose label belongs to the first 100 classes.
        indices = [i for i, (_, label) in enumerate(dataset.samples) if label in first_100_classes_idx]

        # Map folder names (class IDs) to human-readable names via words.txt.
        filtered_classnames = [self.classname_mapping[cls] for cls in first_100_classes]

        return Subset(dataset, indices), filtered_classnames

    def create_val_dataset(self):
        """Return ``(CustomValDataset, classnames)`` for the first 100 val classes.

        Reads ``val/val_annotations.txt`` (whitespace-separated; column 0 is
        the image file name, column 1 the class ID). It keeps rows of the
        first 100 class IDs in sorted order, capped at 5000 rows.

        NOTE(review): labels here are positions in the sorted list of class IDs
        that appear in the annotations. They agree with the training labels
        only if the train folders and the val class IDs sort identically.
        """
        val_dir = os.path.join(self.location, 'val')
        annotations_path = os.path.join(val_dir, 'val_annotations.txt')
        images_folder = os.path.join(val_dir, 'images')

        # ndmin=2 keeps a one-line file 2-D, so ``row[1]`` is always the
        # class column rather than a single character of a field.
        val_annotations = np.loadtxt(annotations_path, dtype=str, ndmin=2)

        class_ids = sorted({row[1] for row in val_annotations})
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_ids)}
        first_100_classes = class_ids[:100]  # Only take first 100 classes
        kept_classes = set(first_100_classes)  # O(1) membership per row

        filtered_annotations = [row for row in val_annotations if row[1] in kept_classes]

        # Limit the validation dataset to 5K samples
        filtered_annotations = filtered_annotations[:5000]

        val_dataset = CustomValDataset(filtered_annotations, images_folder, class_to_idx, self.transform)

        # Map class IDs to human-readable names via words.txt.
        filtered_classnames = [self.classname_mapping[cls] for cls in first_100_classes]
        return val_dataset, filtered_classnames

    def get_fast_val_loader(self, batch_size, num_workers):
        """Return ``(Subset, DataLoader)`` over a seeded random 10% of ``test_dataset``.

        The subset size is ``int(len(test_dataset) * 0.1)``. Indices are drawn
        without replacement using ``self.random_seed``, so the subset is
        reproducible. The returned loader shuffles.
        """
        # A private RNG seeded with the same value draws exactly the indices
        # that ``random.seed(seed); random.sample(...)`` would, but it does
        # not reseed the process-global ``random`` state as a side effect.
        rng = random.Random(self.random_seed)
        val_size = len(self.test_dataset)
        fast_val_size = int(val_size * 0.1)

        indices = rng.sample(range(val_size), fast_val_size)

        fast_val_dataset = Subset(self.test_dataset, indices)

        return fast_val_dataset, DataLoader(fast_val_dataset, batch_size=batch_size, shuffle=True,
                                            num_workers=num_workers)

class CustomValDataset(Dataset):
    """Map-style dataset over validation annotation rows.

    Each row starts with ``(image_file_name, class_id)``. Any further columns
    are ignored, e.g. the bounding-box columns found in the original Tiny
    ImageNet ``val_annotations.txt``.
    """

    def __init__(self, annotations, images_folder, class_to_idx, transform):
        """
        Args:
            annotations: Sequence of rows; ``row[0]`` is the image file name
                relative to ``images_folder`` and ``row[1]`` the class ID.
            images_folder: Directory containing the validation images.
            class_to_idx: Mapping from class ID to integer label.
            transform: Optional callable applied to the RGB PIL image;
                skipped when falsy.
        """
        self.annotations = annotations
        self.images_folder = images_folder
        self.class_to_idx = class_to_idx
        self.transform = transform

    @staticmethod
    def _split_annotation(row):
        """Return ``(image_name, class_code)`` from the first two columns of ``row``."""
        return row[0], row[1]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB and return ``(image, label)``.

        The transform is applied when one is set. Raises KeyError if the row's
        class ID is not in ``class_to_idx``.
        """
        # Only the first two columns matter; unpacking the whole row would fail
        # on annotation files that carry extra (bounding-box) columns.
        image_name, class_code = self._split_annotation(self.annotations[idx])
        image_path = os.path.join(self.images_folder, image_name)
        # The context manager closes the underlying file handle deterministically.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.class_to_idx[class_code]

        if self.transform:
            image = self.transform(image)

        return image, label