"""
Dataset registry and utilities

This module provides a centralized registry for all datasets, along with
utilities to create and manipulate dataset instances.
"""

import os
import sys
import inspect
import torch
import copy
import itertools

from torch.utils.data.dataset import Subset, random_split

# Import dataset classes
from src.datasets.cars import Cars
from src.datasets.dtd import DTD
from src.datasets.eurosat import EuroSAT, EuroSATVal
from src.datasets.gtsrb import GTSRB
from src.datasets.mnist import MNIST
from src.datasets.resisc45 import RESISC45
from src.datasets.svhn import SVHN
from src.datasets.sun397 import SUN397

# Build the name -> class registry from the classes bound in this module at
# this point (i.e. the dataset classes imported above). Utility classes that
# come from torch (e.g. ``Subset``, imported for the helpers below) are
# excluded so they are neither requestable as datasets nor listed as
# available ones in error messages.
registry = {
    name: obj
    for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
    if obj.__module__.split('.')[0] != 'torch'
}


class GenericDataset(object):
    """Bare container exposing the standard dataset attributes.

    Every attribute starts out as ``None``; the helpers in this module fill
    them in after construction:

    * ``train_dataset`` / ``train_loader`` -- training split and its loader.
    * ``test_dataset`` / ``test_loader`` -- evaluation split and its loader
      (a validation split when built by ``split_train_into_train_val``).
    * ``classnames`` -- class names for the dataset.
    """

    def __init__(self):
        # Initialise each slot to None, in a fixed order.
        for attr in ("train_dataset", "train_loader",
                     "test_dataset", "test_loader",
                     "classnames"):
            setattr(self, attr, None)


def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers,
                               val_fraction, max_val_samples=None, seed=0):
    """Create a new dataset whose training split is divided into train and validation.

    The samples of ``dataset.train_dataset`` are randomly partitioned
    (reproducibly, via ``seed``) into a new training split and a validation
    split. The validation split is exposed as ``test_dataset``/``test_loader``
    so the result can be used wherever a regular dataset is expected.

    Args:
        dataset: Source dataset; must expose ``train_dataset`` and ``classnames``.
        new_dataset_class_name: Name of the dynamically created class (a
            ``GenericDataset`` subclass) of the returned object.
        batch_size: Batch size for data loaders.
        num_workers: Number of worker processes for data loaders.
        val_fraction: Fraction of training samples moved to validation; must
            lie strictly between 0 and 1.
        max_val_samples: Optional upper bound on the number of validation samples.
        seed: Seed for the generator passed to ``random_split``.

    Returns:
        New dataset with ``train_dataset``/``train_loader`` (shuffled) holding
        the remaining training samples, ``test_dataset``/``test_loader``
        (not shuffled) holding the validation samples, and a shallow copy of
        ``dataset.classnames``.

    Raises:
        ValueError: If ``val_fraction`` is not in (0, 1) or either resulting
            split would be empty.
    """
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O``.
    if not (0. < val_fraction < 1.):
        raise ValueError("Validation fraction must be between 0 and 1")

    # Calculate sizes (int() floors, so tiny fractions can yield 0).
    total_size = len(dataset.train_dataset)
    val_size = int(total_size * val_fraction)
    if max_val_samples is not None:
        val_size = min(val_size, max_val_samples)
    train_size = total_size - val_size

    if val_size <= 0:
        raise ValueError("Validation set is empty")
    if train_size <= 0:
        raise ValueError("Training set is empty")

    # Seeded generator makes the split reproducible across runs.
    trainset, valset = random_split(
        dataset.train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )

    # A fresh class per call lets the result carry a descriptive type name.
    new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {})
    new_dataset = new_dataset_class()

    # Training split: shuffled loader.
    new_dataset.train_dataset = trainset
    new_dataset.train_loader = torch.utils.data.DataLoader(
        new_dataset.train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True
    )

    # Validation split, exposed through the "test" attributes.
    new_dataset.test_dataset = valset
    new_dataset.test_loader = torch.utils.data.DataLoader(
        new_dataset.test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True
    )

    # Shallow copy so later mutation of either list does not affect the other.
    new_dataset.classnames = copy.copy(dataset.classnames)

    return new_dataset


def extract_class_data(dataset, cls_idx, batch_size, num_workers):
    """Create a subset dataset containing only the specified classes.

    Args:
        dataset: Source dataset exposing ``train_dataset``, ``test_dataset`` and
            ``classnames``. If it also has ``class_splits``, then
            ``class_splits['train']`` / ``class_splits['test']`` must map
            ``str(class_index)`` to the sample indices of that class; otherwise
            every sample of both splits is scanned (slow).
        cls_idx: Class indices to keep. ``None`` or an empty sequence means
            "no filtering" and returns ``dataset`` unchanged.
        batch_size: Batch size for data loaders.
        num_workers: Number of worker processes for data loaders.

    Returns:
        A ``GenericDataset`` whose splits are ``Subset`` views of the original
        splits restricted to the selected classes, or ``dataset`` itself when
        no filtering is requested.

    Raises:
        ValueError: If any class index is negative or not smaller than the
            number of classes.

    Note:
        Labels are not remapped: samples keep their original class indices,
        while ``classnames`` of the result lists only the selected classes in
        ``cls_idx`` order.
    """
    # Return original dataset if no filtering needed
    if cls_idx is None or len(cls_idx) == 0:
        return dataset

    # Validate class indices before building anything.
    if max(cls_idx) >= len(dataset.classnames):
        raise ValueError(f"Class index {max(cls_idx)} exceeds the number of classes ({len(dataset.classnames)})")
    # A negative index would pick a classname from the end of the list via
    # Python negative indexing, while the sample lookup below uses the raw
    # value, so names and samples would disagree. Reject it explicitly.
    if min(cls_idx) < 0:
        raise ValueError(f"Class index {min(cls_idx)} is negative")

    # Create subset container with the selected classnames
    subset = GenericDataset()
    subset.classnames = [dataset.classnames[i] for i in cls_idx]

    # Get indices for selected classes
    if hasattr(dataset, 'class_splits'):
        # Use pre-computed class splits if available
        train_split = dataset.class_splits['train']
        train_subset = list(itertools.chain.from_iterable(
            train_split[str(i)] for i in cls_idx
        ))

        test_split = dataset.class_splits['test']
        test_subset = list(itertools.chain.from_iterable(
            test_split[str(i)] for i in cls_idx
        ))
    else:
        # Otherwise scan every sample of both splits, unpacking each item as
        # (input, target).
        print("Warning: Using slow class extraction method. Consider implementing class_splits.")
        train_subset = [
            i for i, (_, target) in enumerate(dataset.train_dataset) if target in cls_idx
        ]
        test_subset = [
            i for i, (_, target) in enumerate(dataset.test_dataset) if target in cls_idx
        ]

    # Create subset datasets and loaders
    subset.train_dataset = Subset(dataset.train_dataset, train_subset)
    subset.train_loader = torch.utils.data.DataLoader(
        subset.train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True
    )

    subset.test_dataset = Subset(dataset.test_dataset, test_subset)
    subset.test_loader = torch.utils.data.DataLoader(
        subset.test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True
    )

    return subset


def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16,
                val_fraction=0.1, max_val_samples=5000):
    """Get dataset by name with appropriate configuration.

    Supported name forms:

    * ``"precomputed_<Name>"`` -- precomputed features, loaded via
      ``_get_precomputed_dataset`` (no class filtering is applied).
    * ``"<i>,<j>,..._<Name>"`` -- ``<Name>`` restricted to the listed class
      indices. A prefix that does not parse as integers is left in the name.
    * ``"<Name>Val"`` -- the registered class of that name if one exists,
      otherwise a train/val split of ``<Name>``'s training data.
    * ``"<Name>"`` -- the class registered under that name in ``registry``.

    Args:
        dataset_name: Name of the dataset to load (see formats above).
        preprocess: Preprocessing function for images.
        location: Root directory for datasets.
        batch_size: Batch size for data loaders.
        num_workers: Number of worker processes for data loading.
        val_fraction: Fraction of training data to use for validation (for
            derived Val datasets).
        max_val_samples: Maximum number of validation samples (for derived Val
            datasets).

    Returns:
        Configured dataset instance.

    Raises:
        ValueError: If the (base) dataset name is not in ``registry``.
    """
    # Precomputed features bypass the registry and class filtering entirely.
    if dataset_name.startswith('precomputed_'):
        return _get_precomputed_dataset(dataset_name, location, batch_size, num_workers)

    # Extract class indices if specified in name (format: "1,2,3_DatasetName")
    cls_idx = None
    if '_' in dataset_name:
        cls_part, dataset_part = dataset_name.split('_', 1)
        try:
            cls_idx = [int(i) for i in cls_part.split(',')]
        except ValueError:
            # Not a class specification: keep the full name unchanged.
            pass
        else:
            dataset_name = dataset_part

    # Val dataset without a predefined class: derive it from the base dataset.
    if dataset_name.endswith('Val') and dataset_name not in registry:
        # Strip only the trailing "Val"; splitting on "Val" would truncate
        # names that contain "Val" elsewhere.
        base_dataset_name = dataset_name[:-len('Val')]
        base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers)

        # Filter before splitting so both splits only contain the requested classes.
        if cls_idx:
            base_dataset = extract_class_data(base_dataset, cls_idx, batch_size, num_workers)

        return split_train_into_train_val(
            base_dataset, dataset_name, batch_size, num_workers,
            val_fraction, max_val_samples
        )

    # Standard dataset, or a predefined Val dataset from the registry.
    if dataset_name not in registry:
        raise ValueError(f"Unknown dataset: {dataset_name}. Available datasets: {list(registry.keys())}")

    dataset_class = registry[dataset_name]
    dataset = dataset_class(
        preprocess, location=location, batch_size=batch_size, num_workers=num_workers
    )

    # Apply class filtering if needed
    if cls_idx:
        dataset = extract_class_data(dataset, cls_idx, batch_size, num_workers)

    return dataset


def _get_precomputed_dataset(dataset_name, location, batch_size=128, num_workers=8):
    """Get a dataset with precomputed features.

    The feature directory is searched under ``location`` in this order:
    ``precomputed_features/<name>``, ``<name>``, ``features/<name>``,
    ``precomputed/<name>``. The first existing path is used.

    Args:
        dataset_name: Name of the dataset (with 'precomputed_' prefix).
        location: Root data directory.
        batch_size: Batch size for dataloaders.
        num_workers: Number of worker processes.

    Returns:
        A ``PrecomputedFeatures`` instance built from the located directory.

    Raises:
        FileNotFoundError: If none of the candidate directories exists.
        ImportError: If ``PrecomputedFeatures`` cannot be imported.
    """
    # Extract the actual dataset name
    base_dataset_name = dataset_name[len("precomputed_"):]

    # Candidate directories: the standard layout first, then common alternatives.
    candidates = [
        os.path.join(location, "precomputed_features", base_dataset_name),
        os.path.join(location, base_dataset_name),
        os.path.join(location, "features", base_dataset_name),
        os.path.join(location, "precomputed", base_dataset_name),
    ]
    feature_dir = next((path for path in candidates if os.path.exists(path)), None)
    if feature_dir is None:
        raise FileNotFoundError(
            f"Could not find precomputed features for {base_dataset_name} "
            f"(searched: {', '.join(candidates)})"
        )

    # Import at runtime to avoid circular imports. Only the import itself is
    # guarded, so ImportErrors raised while constructing the dataset are not
    # misreported as failures to import this class.
    try:
        from src.datasets.precomputed_features import PrecomputedFeatures
    except ImportError as e:
        raise ImportError(f"Could not import PrecomputedFeatures class: {e}") from e

    # Create and return dataset
    return PrecomputedFeatures(
        feature_dir=feature_dir,
        batch_size=batch_size,
        num_workers=num_workers
    )