"""
Dataset, dataloader and sampler utilities

Core functionality for dataset handling, including subset sampling,
image folders with paths, and feature extraction utilities.
"""

import os
import torch
import collections
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, Sampler, random_split
import torchvision.datasets as datasets

class SubsetSampler(Sampler):
    """Sampler that yields a fixed sequence of dataset indices, in order.

    Unlike ``SubsetRandomSampler`` no shuffling is performed: indices are
    produced exactly in the order they were given.

    Args:
        indices: Indices to sample from
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Walk the stored indices verbatim.
        yield from self.indices

    def __len__(self):
        count = len(self.indices)
        return count


class ImageFolderWithPaths(datasets.ImageFolder):
    """Extension of ImageFolder whose items also carry the image file path.

    Each item is a dict ``{'images': ..., 'labels': ..., 'image_paths': ...}``.

    Args:
        path: Root directory path
        transform: Image transform function
        flip_label_prob: Probability that each sample's label is replaced, once
            at construction time, by a class index drawn uniformly from all
            classes (for robust training). Because the draw includes the
            original class, the effective change rate is slightly lower than
            ``flip_label_prob``.
    """
    def __init__(self, path, transform, flip_label_prob=0.0):
        super().__init__(path, transform)
        self.flip_label_prob = flip_label_prob

        # Apply label flipping if requested
        if self.flip_label_prob > 0:
            print(f'Using label flipping with probability {self.flip_label_prob}')
            self._apply_label_flipping()

    def _apply_label_flipping(self):
        """Randomly reassign labels in place, keeping ``samples`` and ``targets`` in sync.

        Uses the global ``random`` module state, so results are reproducible
        only if the caller seeds ``random`` beforehand.
        """
        import random
        num_classes = len(self.classes)
        for i, (img_path, _) in enumerate(self.samples):
            if random.random() < self.flip_label_prob:
                new_label = random.randint(0, num_classes - 1)
                self.samples[i] = (img_path, new_label)
                # ImageFolder keeps a separate ``targets`` list built from
                # ``samples`` at init; update it too so anything reading
                # ``targets`` sees the same (flipped) labels as __getitem__.
                self.targets[i] = new_label

    def __getitem__(self, index):
        """Return the transformed image, its (possibly flipped) label and its path."""
        image, label = super().__getitem__(index)
        return {
            'images': image,
            'labels': label,
            'image_paths': self.samples[index][0],
        }


def maybe_dictionarize(batch):
    """Normalize a batch into a dictionary with standard keys.

    Dicts pass through untouched (same object). Sequences are mapped by length:

    * 2 items -> ``images``, ``labels``
    * 3 items -> ``images``, ``labels``, ``index``; if the first item is a
      list/tuple of exactly two views (two transforms), they become
      ``images`` and ``images_``.

    Args:
        batch: Input batch that may be a tuple or dict

    Returns:
        Dictionary with standardized keys

    Raises:
        ValueError: if a non-dict batch has neither 2 nor 3 elements.
    """
    if isinstance(batch, dict):
        return batch

    size = len(batch)

    if size == 2:
        return {'images': batch[0], 'labels': batch[1]}

    if size == 3:
        head, labels, index = batch[0], batch[1], batch[2]
        has_two_views = isinstance(head, (list, tuple)) and len(head) == 2
        if has_two_views:
            return {'images': head[0], 'images_': head[1],
                    'labels': labels, 'index': index}
        return {'images': head, 'labels': labels, 'index': index}

    raise ValueError(f'Unexpected batch format with {len(batch)} elements')


def get_features_helper(image_encoder, dataloader, device):
    """Run ``image_encoder`` over every batch of ``dataloader`` and collect outputs.

    Args:
        image_encoder: Model mapping a batch of images to features. It is moved
            to ``device`` and switched to eval mode.
        dataloader: Iterable of batches in any format accepted by
            ``maybe_dictionarize``; each batch must provide ``'images'``.
        device: Device to run the encoder on (string or ``torch.device``).

    Returns:
        A ``collections.defaultdict(list)`` holding ``'features'`` plus every
        non-image key found in the batches. Tensor-valued entries are
        concatenated along dim 0 and converted to NumPy arrays; non-tensor
        values (e.g. lists of paths) are flattened into Python lists.
    """
    all_data = collections.defaultdict(list)

    # Setup encoder
    image_encoder = image_encoder.to(device)
    target = torch.device(device)
    n_gpus = torch.cuda.device_count()
    # DataParallel requires the module to live on device_ids[0]. Wrap only when
    # running on CUDA, and put the requested GPU first, so a CPU run or a run on
    # e.g. 'cuda:1' on a multi-GPU machine does not fail.
    if target.type == 'cuda' and n_gpus > 1:
        primary = target.index if target.index is not None else torch.cuda.current_device()
        device_ids = [primary] + [i for i in range(n_gpus) if i != primary]
        image_encoder = torch.nn.DataParallel(image_encoder, device_ids=device_ids)
    image_encoder.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting features"):
            batch = maybe_dictionarize(batch)
            features = image_encoder(batch['images'].to(device))

            all_data['features'].append(features.cpu())

            # Collect all metadata (everything except the raw images)
            for key, val in batch.items():
                if key == 'images':
                    continue
                if hasattr(val, 'cpu'):
                    # Tensor-like: keep per-batch chunks, concatenated below.
                    all_data[key].append(val.cpu())
                else:
                    # Plain sequences (e.g. image paths): flatten into one list.
                    all_data[key].extend(val)

    # Concatenate tensor chunks into single NumPy arrays
    for key, val in all_data.items():
        if len(val) > 0 and torch.is_tensor(val[0]):
            all_data[key] = torch.cat(val).numpy()

    return all_data


def _feature_cache_dir(image_encoder, dname, split):
    """Return ``<encoder.cache_dir>/<dname>/<split>``, or None when caching is off."""
    cache_root = getattr(image_encoder, 'cache_dir', None)
    if cache_root is None:
        return None
    return os.path.join(cache_root, dname, split)


def _load_trusted_cache_file(path):
    """Load a file that this module previously wrote with ``torch.save``.

    The cache stores NumPy arrays and Python lists. Recent PyTorch defaults
    ``torch.load`` to ``weights_only=True``, which refuses to unpickle NumPy
    arrays, so full unpickling is requested explicitly. SECURITY: this
    executes arbitrary pickle code. It is acceptable only because the cache
    directory is written by ``get_features``; never point it at untrusted files.
    """
    import inspect
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    # Older PyTorch (no weights_only kwarg) always fully unpickles.
    return torch.load(path)


def get_features(is_train, image_encoder, dataset, device):
    """Get features for a dataset, using the on-disk cache if available.

    The cache lives in ``<image_encoder.cache_dir>/<dataset class name>/<split>``
    with one ``<key>.pt`` file per entry of the feature dict. A cache directory
    without ``features.pt`` (e.g. a run interrupted while writing) is treated
    as a miss, and the features are re-extracted.

    Args:
        is_train: Whether to use training or validation set
        image_encoder: Model for encoding images; caching is enabled when it
            has a non-None ``cache_dir`` attribute
        dataset: Dataset object exposing ``train_loader`` / ``test_loader``
        device: Device to use for computation

    Returns:
        Dictionary of features and metadata
    """
    split = 'train' if is_train else 'val'
    dname = type(dataset).__name__
    cache_dir = _feature_cache_dir(image_encoder, dname, split)

    # Use cache if available and complete enough to be usable
    if cache_dir is not None and os.path.isdir(cache_dir):
        cached_files = [f for f in os.listdir(cache_dir) if f.endswith('.pt')]
        if 'features.pt' in cached_files:
            print(f'Loading features from cache: {cache_dir}')
            return {
                os.path.splitext(cached_file)[0]:
                    _load_trusted_cache_file(os.path.join(cache_dir, cached_file))
                for cached_file in cached_files
            }

    # Extract features if not cached
    print(f'Extracting {split} features for {dname}')
    loader = dataset.train_loader if is_train else dataset.test_loader
    data = get_features_helper(image_encoder, loader, device)

    # Cache features if requested
    if cache_dir is not None:
        os.makedirs(cache_dir, exist_ok=True)
        print(f'Caching features to {cache_dir}')
        for name, val in data.items():
            torch.save(val, os.path.join(cache_dir, f'{name}.pt'))

    return data


class FeatureDataset(Dataset):
    """Dataset over pre-computed features (see ``get_features``).

    Each item is a dict with one entry per key of the feature dict, indexed at
    ``idx``; the ``'features'`` entry is returned as a float32 tensor.

    Args:
        is_train: Whether to use training or validation data
        image_encoder: Model used for feature extraction
        dataset: Original dataset to extract features from
        device: Device to use for computation
    """
    def __init__(self, is_train, image_encoder, dataset, device):
        self.data = get_features(is_train, image_encoder, dataset, device)

    def __len__(self):
        return len(self.data['features'])

    def __getitem__(self, idx):
        item = {key: values[idx] for key, values in self.data.items()}
        if 'features' in item:
            # as_tensor accepts NumPy rows (zero-copy), NumPy scalars (1-D
            # feature arrays) and tensors alike; from_numpy only takes ndarrays.
            item['features'] = torch.as_tensor(item['features']).float()
        return item


def get_dataloader(dataset, is_train, args, image_encoder=None):
    """Get appropriate dataloader based on settings

    Args:
        dataset: Dataset exposing ``train_loader`` / ``test_loader``
        is_train: Whether to use training or validation data
        args: Arguments with dataloader settings (``batch_size``,
            ``num_workers``, ``device``; optionally ``subsample`` as a float
            fraction and ``seed`` for the subsample split, default 42)
        image_encoder: Optional encoder; when given, batches are pre-computed
            features (``FeatureDataset``) instead of images

    Returns:
        Data loader for the requested dataset configuration
    """
    # Use feature dataset if encoder provided
    if image_encoder is not None:
        feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device)
        dataloader = DataLoader(
            feature_dataset,
            batch_size=args.batch_size,
            shuffle=is_train,
            num_workers=args.num_workers,
            pin_memory=True
        )
    else:
        # Use existing loaders
        dataloader = dataset.train_loader if is_train else dataset.test_loader

    # Subsample training data if requested
    subsample = getattr(args, 'subsample', None)
    if is_train and isinstance(subsample, float):
        src = dataloader.dataset
        subsample_size = int(len(src) * subsample)

        if subsample_size < len(src):
            lengths = [subsample_size, len(src) - subsample_size]
            print(f"Subsampling from {len(src)} to {subsample_size} samples")

            # `args.seed or 42` would silently turn seed 0 into 42; only fall
            # back to the default when no seed is set.
            seed = getattr(args, 'seed', None)
            if seed is None:
                seed = 42

            # Create new dataset and loader
            new_dataset, _ = random_split(
                src,
                lengths,
                generator=torch.Generator().manual_seed(seed)
            )

            # Keep the source loader's collate function (e.g. a custom one for
            # dict or variable-size batches). It is only meaningful when the
            # source loader batches automatically (batch_size is not None).
            collate_fn = dataloader.collate_fn if dataloader.batch_size is not None else None

            dataloader = DataLoader(
                new_dataset,
                batch_size=dataloader.batch_size,
                shuffle=True,
                num_workers=args.num_workers,
                pin_memory=True,
                collate_fn=collate_fn
            )

    return dataloader