from collections import defaultdict
import logging
import os
import pickle
import time
import numpy as np
import torch
from sklearn.metrics import pairwise_distances
import torchvision.models as models
import torch.nn as nn

# Module-level logger; handlers and levels are left to the importing application.
logger = logging.getLogger(__name__)
# Directory that contains this source file; anchors the on-disk cache below.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Folder holding the pickled (.pkl) per-class selection indices that the
# coreset / herding / k-medoid initializations read and write.
PRECOMPUTED_DIR = os.path.join(BASE_DIR, 'precomputed_indices')

class UNetLike(nn.Module):
    """
    U-Net like architecture for image-to-image transformation.

    Three encoder stages (double conv + 2x2 max-pool), a bottleneck, and three
    decoder stages (2x transposed conv + double conv). Each decoder stage
    concatenates the encoder output of the same depth (skip connection).
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        """Initialize the UNet-like module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            base_channels: Channel width of the first stage; it doubles at
                every deeper stage (x2, x4, x8 at the bottleneck).
        """
        super().__init__()
        c1 = base_channels
        c2, c3, c4 = c1 * 2, c1 * 4, c1 * 8

        # Encoder (downsampling path)
        self.enc1 = self._conv_block(in_channels, c1)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = self._conv_block(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = self._conv_block(c2, c3)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = self._conv_block(c3, c4)

        # Decoder (upsampling path). Each dec block takes twice the channels
        # of its upconv output because the skip tensor is concatenated first.
        self.upconv3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = self._conv_block(c4, c3)

        self.upconv2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = self._conv_block(c3, c2)

        self.upconv1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = self._conv_block(c2, c1)

        # Output layer: 1x1 conv mapping to the requested channel count
        self.out_conv = nn.Conv2d(c1, out_channels, 1)

    def _conv_block(self, in_channels, out_channels):
        """Double convolution block (conv3x3 -> BN -> ReLU, twice)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        MaxPool2d(2) floors odd sizes, so the 2x transposed conv can come out
        one row/column short of the skip tensor. When the sizes already agree
        the input is returned untouched.
        """
        pad_h = skip.shape[-2] - upsampled.shape[-2]
        pad_w = skip.shape[-1] - upsampled.shape[-1]
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
        )

    def forward(self, x):
        """Forward pass.

        Args:
            x: Input tensor of shape (B, C, H, W). H and W do not need to be
                multiples of 8; decoder maps are padded to the skip size.

        Returns:
            torch.Tensor: Output tensor with shape (B, out_channels, H, W).
        """
        # Encoder path; the pre-pool outputs are kept as skip connections
        enc1_out = self.enc1(x)
        enc2_out = self.enc2(self.pool1(enc1_out))
        enc3_out = self.enc3(self.pool2(enc2_out))

        # Bottleneck
        features = self.bottleneck(self.pool3(enc3_out))

        # Decoder path: upsample, align with the skip, concatenate, convolve
        up3 = self._match_size(self.upconv3(features), enc3_out)
        features = self.dec3(torch.cat([up3, enc3_out], dim=1))

        up2 = self._match_size(self.upconv2(features), enc2_out)
        features = self.dec2(torch.cat([up2, enc2_out], dim=1))

        up1 = self._match_size(self.upconv1(features), enc1_out)
        features = self.dec1(torch.cat([up1, enc1_out], dim=1))

        return self.out_conv(features)
    
class SyntheticImage(torch.nn.Module):
    """
    Module to create and manage synthetic images for dataset distillation.
    Supports multiple generation modes (direct, conv, residual variants, unet),
    optional batching, and different initialization strategies (random, mean,
    sample, coreset, herding, k-medoid).
    """
    
    def __init__(self, dataset, num_synthetic=100, device='cuda', initialization='sample', 
                 synth_mode='direct', feature_extractor=None, synth_batch_size=None):
        """
        Initialize the SyntheticImage module.
        
        Args:
            dataset: The dataset to sample from
            num_synthetic: Number of synthetic images to generate
            device: Device to store tensors on
            initialization: How to initialize synthetic images ('random', 'mean', 'sample', or 'coreset')
            synth_mode: Mode of synthetic image generation
                'direct': Directly optimize the synthetic images
                'conv': Use a convolutional network to generate synthetic images
                'residual': Use original images + convolutional network for residual
                'residual_y': Keep original images fixed, only train residual (previously residual_fixed)
                'residual_pure': Only train the conv weights, keep synthetic images fixed
            feature_extractor: Pre-trained model for extracting features (used for coreset selection)
                If None and initialization='coreset', a ResNet18 will be used by default
        """
        super(SyntheticImage, self).__init__()
        self.dataset = dataset
        self.num_synthetic = num_synthetic
        self.device = device
        self.initialization = initialization
        
        # Parse the synth_mode parameter
        self.synth_mode = synth_mode
        self.use_conv = synth_mode in ['conv', 'residual', 'residual_y', 'residual_pure']
        self.use_residual = synth_mode in ['residual', 'residual_y', 'residual_pure']
        self.fixed_base = synth_mode in ['residual_y', 'residual_pure']
        self.train_synthetic = synth_mode != 'residual_pure'

        # Save synth_batch_size if specified
        self.synth_batch_size = synth_batch_size
        
        # Initialize feature extractor for coreset selection
        if initialization == 'coreset' or initialization == 'herding' or initialization == 'k-medoid':
            if feature_extractor is None:
                # Use ResNet18 as default feature extractor
                self.feature_extractor = models.resnet18(pretrained=True)
                self.feature_extractor.fc = nn.Identity()  # Remove final classification layer
                self.feature_extractor.to(device)
                self.feature_extractor.eval()
            else:
                self.feature_extractor = feature_extractor
        
        self.synthetic_images = None
        self.synthetic_labels = None
        self.synthetic_indices = None
        
        real_image, real_label = self.dataset[0]
        self.num_classes = len(self.dataset.class_names)
        # self.dataset_images = self.dataset.data
        self.dataset_labels = self.dataset.targets
        
        # get all indices of images divided per class
        self.class_indices = defaultdict(list)
        for idx, target in enumerate(self.dataset_labels):
            self.class_indices[target].append(idx)
        
        # Calculate images per class for even distribution
        images_per_class = self.num_synthetic // self.num_classes
        remainder = self.num_synthetic % self.num_classes
        
        # Initialize synthetic images
        if self.initialization == 'random':
            (real_image, real_label) = self.dataset[0]
            synthetic_images = []
            synthetic_labels = []
            
            for class_id in range(self.num_classes):
                # Calculate number of images for this class
                n_images = images_per_class + (1 if class_id < remainder else 0)
                
                if n_images > 0:
                    # Generate random images for this class
                    class_images = torch.randn(
                        n_images,
                        real_image.shape[0],
                        real_image.shape[1],
                        real_image.shape[2],
                        device=self.device
                    )
                    synthetic_images.append(class_images)
                    synthetic_labels.extend([class_id] * n_images)
            
            self.synthetic_images = torch.cat(synthetic_images, dim=0).to(self.device)
            self.synthetic_images.requires_grad = self.train_synthetic
            self.synthetic_labels = torch.tensor(synthetic_labels, device=self.device)
            
        elif self.initialization == 'mean':
            # Initialize synthetic images with the mean image of each class
            synthetic_images = []
            synthetic_labels = []
            
            # Compute mean images per class using class_indices
            for class_id in range(self.num_classes):
                if class_id in self.class_indices and len(self.class_indices[class_id]) > 0:
                    # Calculate number of images for this class
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0:
                        # Compute mean image for this class - use more samples
                        # and make sure to properly handle channel dimensions
                        indices = self.class_indices[class_id][:min(len(self.class_indices[class_id]), 100)]  # Use up to 100 images
                        
                        # Stack all images and compute mean properly
                        class_images = torch.stack([self.dataset[idx][0] for idx in indices])
                        mean_image = class_images.mean(dim=0)  # Compute mean along batch dimension
                        
                        # Create n_images copies of the mean image
                        repeated_means = mean_image.unsqueeze(0).repeat(n_images, 1, 1, 1)
                        synthetic_images.append(repeated_means)
                        synthetic_labels.extend([class_id] * n_images)
            
            self.synthetic_images = torch.cat(synthetic_images, dim=0).to(self.device)
            self.synthetic_images.requires_grad = self.train_synthetic
            self.synthetic_labels = torch.tensor(synthetic_labels, device=self.device)
           
        elif self.initialization == 'sample':
            synthetic_images = []
            synthetic_labels = []
            synthetic_indices = {}
            
            # Sample images for each class
            for class_id in range(self.num_classes):
                # Calculate number of images for this class
                n_images = images_per_class + (1 if class_id < remainder else 0)
                
                if n_images > 0 and class_id in self.class_indices:
                    # Use the get_n_images_per_class method to sample images
                    indices, class_images = self.get_n_images_per_class(n_images, class_id)
                    synthetic_images.append(class_images)
                    synthetic_labels.extend([class_id] * n_images)
                    synthetic_indices[class_id] = indices

            self.synthetic_images = torch.cat(synthetic_images, dim=0).to(self.device)
            self.synthetic_images.requires_grad = self.train_synthetic
            self.synthetic_labels = torch.tensor(synthetic_labels, device=self.device)
            self.synthetic_indices = synthetic_indices
            
        elif self.initialization == 'coreset':
            synthetic_images = []
            synthetic_labels = []
            synthetic_indices = {}
            
            # Try to load existing coreset indices from pickle file
            coreset_filename = f'{dataset.name}_coreset_indices_{images_per_class}.pkl'
            coreset_file = os.path.join(PRECOMPUTED_DIR, coreset_filename)
            if os.path.exists(coreset_file):
                logger.info(f"Loading existing coreset indices from {coreset_file}")
                with open(coreset_file, 'rb') as f:
                    synthetic_indices = pickle.load(f)
                
                # Use loaded indices to get images
                for class_id in range(self.num_classes):
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0 and class_id in synthetic_indices:
                        indices = synthetic_indices[class_id][:n_images]  # Take only needed images
                        class_images = torch.stack([self.dataset[i][0] for i in indices])
                        synthetic_images.append(class_images)
                        synthetic_labels.extend([class_id] * len(indices))
            else:
                logger.info(f"No existing coreset found. Computing new coreset and saving to {coreset_file}")
                # Use coreset selection for each class
                for class_id in range(self.num_classes):
                    # Calculate number of images for this class
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0 and class_id in self.class_indices:
                        # Use the coreset selection method
                        indices, class_images = self.get_coreset_images_per_class(n_images, class_id)
                        synthetic_images.append(class_images)
                        synthetic_labels.extend([class_id] * n_images)
                        synthetic_indices[class_id] = indices
                
                # Ensure directory exists then save the computed indices for future use
                os.makedirs(PRECOMPUTED_DIR, exist_ok=True)
                with open(coreset_file, 'wb') as f:
                    pickle.dump(synthetic_indices, f)
                logger.info(f"Coreset indices saved to {coreset_file}")

            self.synthetic_images = torch.cat(synthetic_images, dim=0).to(self.device)
            self.synthetic_images.requires_grad = self.train_synthetic
            self.synthetic_labels = torch.tensor(synthetic_labels, device=self.device)
            self.synthetic_indices = synthetic_indices
            
        elif self.initialization == 'herding':
            synthetic_images = []
            synthetic_labels = []
            synthetic_indices = {}

            herding_filename = f'{dataset.name}_herding_indices_{images_per_class}.pkl'
            herding_file = os.path.join(PRECOMPUTED_DIR, herding_filename)
            if os.path.exists(herding_file):
                logger.info(f"Loading existing herding indices from {herding_file}")
                with open(herding_file, 'rb') as f:
                    synthetic_indices = pickle.load(f)
            
                for class_id in range(self.num_classes):
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0 and class_id in synthetic_indices:
                        indices = synthetic_indices[class_id][:n_images]
                        class_images = torch.stack([self.dataset[i][0] for i in indices])
                        synthetic_images.append(class_images)
                        synthetic_labels.extend([class_id] * len(indices))
            else:
                logger.info(f"No existing herding found. Computing new herding and saving to {herding_file}")
                for class_id in range(self.num_classes):
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0 and class_id in self.class_indices:
                        indices, class_images = self.get_herding_images_per_class(n_images, class_id)
                        synthetic_images.append(class_images)
                        synthetic_labels.extend([class_id] * n_images)
                        synthetic_indices[class_id] = indices

                os.makedirs(PRECOMPUTED_DIR, exist_ok=True)
                with open(herding_file, 'wb') as f:
                    pickle.dump(synthetic_indices, f)
                logger.info(f"Herding indices saved to {herding_file}")

            self.synthetic_images = torch.cat(synthetic_images, dim=0).to(self.device)
            self.synthetic_images.requires_grad = self.train_synthetic
            self.synthetic_labels = torch.tensor(synthetic_labels, device=self.device)
            self.synthetic_indices = synthetic_indices
        
        elif self.initialization == 'k-medoid':
            synthetic_images = []
            synthetic_labels = []
            synthetic_indices = {}

            medoids_filename = f'{dataset.name}_medoids_indices_{images_per_class}.pkl'
            medoids_file = os.path.join(PRECOMPUTED_DIR, medoids_filename)
            if os.path.exists(medoids_file):
                logger.info(f"Loading existing medoids indices from {medoids_file}")
                with open(medoids_file, 'rb') as f:
                    synthetic_indices = pickle.load(f)
            
                for class_id in range(self.num_classes):
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0 and class_id in synthetic_indices:
                        indices = synthetic_indices[class_id][:n_images]
                        class_images = torch.stack([self.dataset[i][0] for i in indices])
                        synthetic_images.append(class_images)
                        synthetic_labels.extend([class_id] * len(indices))
            else:
                logger.info(f"No existing medoids found. Computing new medoids and saving to {medoids_file}")
                for class_id in range(self.num_classes):
                    n_images = images_per_class + (1 if class_id < remainder else 0)
                    
                    if n_images > 0 and class_id in self.class_indices:
                        indices, class_images = self.get_medoid_images_per_class(n_images, class_id)
                        synthetic_images.append(class_images)
                        synthetic_labels.extend([class_id] * n_images)
                        synthetic_indices[class_id] = indices

                os.makedirs(PRECOMPUTED_DIR, exist_ok=True)
                with open(medoids_file, 'wb') as f:
                    pickle.dump(synthetic_indices, f)
                logger.info(f"K-medoids indices saved to {medoids_file}")

            self.synthetic_images = torch.cat(synthetic_images, dim=0).to(self.device)
            self.synthetic_images.requires_grad = self.train_synthetic
            self.synthetic_labels = torch.tensor(synthetic_labels, device=self.device)
            self.synthetic_indices = synthetic_indices
                

        
        # For residual modes with fixed base, create a separate tensor for residual computations
        if self.use_residual and self.fixed_base:
            self.synthetic_images.requires_grad = False
            self.y_images = self.synthetic_images.clone().detach()
            self.y_images.requires_grad = True

        # Create small convolutional network for synthetic image generation, if needed
        if self.use_conv:
            self.small_net = torch.nn.Sequential(
                nn.Conv2d(3, 256, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1),
            ).to(device)
        if self.synth_mode == 'unet':
            self.small_net = UNetLike(in_channels=3, out_channels=3).to(self.device)

    def get_n_images_per_class(self, n, c, deterministic=False):
        """Random sampling of n images from class c"""
        if deterministic:
            indices = np.random.permutation(self.class_indices[c])[:n]
            images = torch.stack([self.dataset[i][0] for i in indices])
            return indices.tolist(), images
        else:
            rng = np.random.default_rng(int(time.time() * 1e6) % (2**32 - 1))
            indices = rng.permutation(self.class_indices[c])[:n]
            images = torch.stack([self.dataset[i][0] for i in indices])
            return indices.tolist(), images
    
    def get_coreset_images_per_class(self, n, c):
        """
        Coreset selection: Find n images from class c that are closest to medoids
        in the feature space extracted by the feature extractor.
        """
        class_indices = self.class_indices[c]
        
        # If we need more images than available, just return all
        if n >= len(class_indices):
            indices = class_indices
            images = torch.stack([self.dataset[i][0] for i in indices])
            return indices, images
        
        # Extract features for all images in this class
        features = []
        
        with torch.no_grad():
            # Process images in batches to avoid memory issues
            batch_size = 64
            for i in range(0, len(class_indices), batch_size):
                batch_indices = class_indices[i:i+batch_size]
                batch_images = torch.stack([self.dataset[idx][0] for idx in batch_indices])
                batch_images = batch_images.to(self.device)
                
                # Extract features
                batch_features = self.feature_extractor(batch_images)
                batch_features /= batch_features.norm(dim=-1, keepdim=True)
                features.append(batch_features.cpu().numpy())
        
        # Concatenate all features
        features = np.concatenate(features, axis=0)
        
        # Compute pairwise distances between all features
        distances = pairwise_distances(features, metric='euclidean')
        
        # Find medoids using a greedy approach
        selected_indices = []
        remaining_indices = list(range(len(class_indices)))
        
        # First, select the point that minimizes the sum of distances to all other points
        sum_distances = distances.sum(axis=1)
        first_medoid = np.argmin(sum_distances)
        selected_indices.append(first_medoid)
        remaining_indices.remove(first_medoid)
        
        # Iteratively select the next medoid that minimizes the maximum distance
        # to the closest already selected medoid
        for _ in range(n - 1):
            if not remaining_indices:
                break
                
            best_candidate = None
            best_score = float('inf')
            
            for candidate in remaining_indices:
                # Find minimum distance from candidate to any selected medoid
                min_dist_to_selected = min([distances[candidate][selected] for selected in selected_indices])
                
                # We want to minimize the maximum distance, so we prefer candidates
                # that are far from already selected medoids
                if min_dist_to_selected < best_score:
                    best_score = min_dist_to_selected
                    best_candidate = candidate
            
            if best_candidate is not None:
                selected_indices.append(best_candidate)
                remaining_indices.remove(best_candidate)
        
        # Convert back to original dataset indices
        selected_dataset_indices = [class_indices[i] for i in selected_indices]
        
        # Load the selected images
        images = torch.stack([self.dataset[i][0] for i in selected_dataset_indices])
        
        return selected_dataset_indices, images

    def get_medoid_images_per_class(self, n, c, fast=True, threshold=2000):
        """
        K-medoids per class.
        - fast=True: approximated with MiniBatchKMeans + selection of the sample closest to the center
          (recommended for large classes, n_samples >> n).
        - fast=False or num_samples <= threshold: uses exact version (slow).
        """
        class_indices = self.class_indices[c]

        if n >= len(class_indices):
            indices = class_indices
            images = torch.stack([self.dataset[i][0] for i in indices])
            return indices, images

        features = []
        with torch.no_grad():
            batch_size = 64
            for i in range(0, len(class_indices), batch_size):
                batch_indices = class_indices[i:i+batch_size]
                batch_images = torch.stack([self.dataset[idx][0] for idx in batch_indices]).to(self.device)
                batch_features = self.feature_extractor(batch_images)
                batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                features.append(batch_features.cpu().numpy())
        features = np.concatenate(features, axis=0).astype(np.float32)  # shape: (num_samples, d)
        num_samples = features.shape[0]

        use_fast = fast or (num_samples > threshold)

        if use_fast:
            try:
                from sklearn.cluster import MiniBatchKMeans
                rng = 42
                kmeans = MiniBatchKMeans(
                    n_clusters=n,
                    batch_size=2048,
                    max_iter=100,
                    n_init="auto",
                    random_state=rng
                )
                kmeans.fit(features)
                centers = kmeans.cluster_centers_.astype(np.float32)  # (n, d)
                centers /= (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-12)
                sims = centers @ features.T

                selected = []
                used = set()
                for row in sims:
                    order = np.argsort(-row)  # descending by similarity
                    pick = None
                    for idx in order:
                        idx = int(idx)
                        if idx not in used:
                            used.add(idx)
                            pick = idx
                            break
                    if pick is not None:
                        selected.append(pick)

                if len(selected) < n:
                    flat_order = np.argsort(-sims.ravel())
                    for flat_idx in flat_order:
                        cand = int(flat_idx % num_samples)
                        if cand not in used:
                            used.add(cand)
                            selected.append(cand)
                        if len(selected) == n:
                            break

                medoid_local_idx = np.array(selected[:n], dtype=int)
                selected_dataset_indices = [class_indices[i] for i in medoid_local_idx]
                images = torch.stack([self.dataset[i][0] for i in selected_dataset_indices])
                return selected_dataset_indices, images
            except Exception as e:
                logger.warning(f"Fast k-medoids fallback to exact due to: {e}")

        np.random.seed(42)
        medoid_indices = np.random.choice(num_samples, size=n, replace=False)

        max_iterations = 100
        tolerance = 1e-6
        best_cost = None

        for _ in range(max_iterations):
            medoid_feats = features[medoid_indices]  # (n, d)
            distances_to_medoids = np.linalg.norm(
                features[:, None, :] - medoid_feats[None, :, :], axis=2
            )
            assignments = np.argmin(distances_to_medoids, axis=1)
            current_cost = np.sum(distances_to_medoids[np.arange(num_samples), assignments])

            new_medoid_indices = medoid_indices.copy()
            improved = False

            for cluster_id in range(n):
                cluster_points = np.where(assignments == cluster_id)[0]
                if len(cluster_points) == 0:
                    continue
                sub = features[cluster_points]
                dists = np.linalg.norm(sub[:, None, :] - sub[None, :, :], axis=2)
                costs = dists.sum(axis=1)
                best_idx_local = np.argmin(costs)
                best_medoid_local = cluster_points[best_idx_local]
                if best_medoid_local != medoid_indices[cluster_id]:
                    new_medoid_indices[cluster_id] = best_medoid_local
                    improved = True

            if not improved or (best_cost is not None and abs(current_cost - best_cost) < tolerance):
                break
            medoid_indices = new_medoid_indices
            best_cost = current_cost

        selected_dataset_indices = [class_indices[i] for i in medoid_indices]
        images = torch.stack([self.dataset[i][0] for i in selected_dataset_indices])
        return selected_dataset_indices, images


    def get_herding_images_per_class(self, n, c):
        """
        Herding selection: Find n images from class c that are closest to the mean
        in the feature space extracted by the feature extractor.
        """
        class_indices = self.class_indices[c]
        
        # If we need more images than available, just return all
        if n >= len(class_indices):
            indices = class_indices
            images = torch.stack([self.dataset[i][0] for i in indices])
            return indices, images
        
        # Extract features for all images in this class
        features = []
        
        with torch.no_grad():
            # Process images in batches to avoid memory issues
            batch_size = 64
            for i in range(0, len(class_indices), batch_size):
                batch_indices = class_indices[i:i+batch_size]
                batch_images = torch.stack([self.dataset[idx][0] for idx in batch_indices])
                batch_images = batch_images.to(self.device)
                
                # Extract features
                batch_features = self.feature_extractor(batch_images)
                batch_features /= batch_features.norm(dim=-1, keepdim=True)
                features.append(batch_features.cpu().numpy())
        
        features = np.concatenate(features, axis=0)
        mean_feature = features.mean(axis=0)
        mean_feature /= np.linalg.norm(mean_feature)
        
        selected_indices = []
        remaining_indices = list(range(len(class_indices)))

        selected_sum = np.zeros_like(mean_feature)

        for i in range(n):
            if not remaining_indices:
                break
            
            target_mean = mean_feature
            best_candidate = None
            best_distances = float('inf')

            for candidate_idx in remaining_indices:
                candidate_feature = features[candidate_idx]

                new_selected_sum = selected_sum + candidate_feature
                new_selected_mean = new_selected_sum / (i + 1)
                distance = np.linalg.norm(new_selected_mean - target_mean)
                if distance < best_distances:
                    best_distances = distance
                    best_candidate = candidate_idx

            if best_candidate is not None:
                selected_indices.append(best_candidate)
                remaining_indices.remove(best_candidate)
                selected_sum += features[best_candidate]

            selected_dataset_indices = [class_indices[i] for i in selected_indices]
            images = torch.stack([self.dataset[i][0] for i in selected_dataset_indices])
        
        return selected_dataset_indices, images
    
    def forward(self, class_id=None, idx=None):
        """Return synthetic images and labels, transformed per ``synth_mode``.

        Selection: ``idx`` takes precedence over ``class_id``. With neither
        given, the whole synthetic set is returned.

        Transformation by ``self.synth_mode``:
            direct: the selected images unchanged.
            conv, unet: ``small_net(images)``.
            residual, residual_pure: ``images + small_net(images)``. Both
                modes compute the same expression in this method.
            residual_y: ``images + small_net(y_images)`` where ``y_images``
                is selected the same way as the images.

        When the whole set is requested and ``self.synth_batch_size`` is set
        and smaller than the number of images, ``small_net`` is applied chunk
        by chunk and the outputs are concatenated along dim 0. Selections by
        ``idx`` or ``class_id`` are always passed to ``small_net`` in one call.

        Args:
            class_id: Optional class id; keeps the items whose label equals it.
            idx: Optional index or list of indices. A plain ``int`` is wrapped
                in a list so the result keeps a leading batch dimension.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (images, labels); labels are
            selected with the same selector as the images.

        Raises:
            ValueError: If ``self.synth_mode`` is not one of the known modes.
        """
        # Resolve the request into a single selector that is applied uniformly
        # to images, labels and (for residual_y) y_images.
        if idx is not None:
            selector = [idx] if isinstance(idx, int) else idx
        elif class_id is not None:
            logger.debug(f"Filtering synthetic images for class {class_id}.")
            # Build the boolean class mask once and reuse it.
            selector = self.synthetic_labels == class_id
        else:
            logger.debug("No class_id specified, returning all synthetic images and labels.")
            selector = None

        def pick(tensor):
            """Apply the selector, or return the tensor itself for the full set."""
            return tensor if selector is None else tensor[selector]

        def run_net(net_input):
            """Run small_net, chunking only when the full set was requested."""
            if selector is not None:
                return self.small_net(net_input)
            batch_size = self.synth_batch_size
            if batch_size is None or batch_size >= len(net_input):
                return self.small_net(net_input)
            # Ceiling division: an exact multiple must not report an extra batch.
            total_batches = -(-len(net_input) // batch_size)
            outputs = []
            starts = range(0, len(net_input), batch_size)
            for batch_number, start in enumerate(starts, start=1):
                logger.debug(f"Processing batch {batch_number} of {total_batches}")
                outputs.append(self.small_net(net_input[start:start + batch_size]))
            return torch.cat(outputs, dim=0)

        images = pick(self.synthetic_images)
        labels = pick(self.synthetic_labels)
        mode = self.synth_mode

        if mode == 'direct':
            return images, labels
        if mode in ('conv', 'unet'):
            return run_net(images), labels
        if mode in ('residual', 'residual_pure'):
            return images + run_net(images), labels
        if mode == 'residual_y':
            # y_images is only read in this mode, so other modes do not
            # require the attribute to exist.
            return images + run_net(pick(self.y_images)), labels
        raise ValueError(f"Unknown synth_mode: {self.synth_mode}")

from torch.utils.data import Dataset

class SyntheticImageDataset(Dataset):
    """Map-style dataset view over a synthetic-image module.

    Each item is obtained by calling the wrapped module with ``idx=...``
    rather than by indexing its tensors directly, so whatever the module's
    call returns for that index is what the dataset yields.

    Args:
        synth_module (SyntheticImage): instance of SyntheticImage (nn.Module)
    """

    def __init__(self, synth_module):
        # Keep a reference only; nothing is copied or precomputed here.
        self.synth = synth_module

    def __len__(self):
        """Return the size of the first dimension of ``synthetic_images``."""
        image_shape = self.synth.synthetic_images.shape
        return image_shape[0]

    def __getitem__(self, idx):
        """Fetch one (image, label) pair.

        Args:
            idx: Integer index of the sample to retrieve.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The first element of the image
            batch and the first element of the label batch returned by the
            wrapped module for ``idx``.
        """
        image_batch, label_batch = self.synth(idx=idx)
        # The module returns batched results; take the first entry of each.
        first_image = image_batch[0]
        first_label = label_batch[0]
        return first_image, first_label
