# forward_forward/models/layers/class_grouping.py

"""Class Grouping utilities for Forward-Forward models with multiple modes"""

import itertools
from enum import Enum
from typing import Dict, List, Tuple, Union, Optional, Type, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassGroupingMode(Enum):
    """Enumeration of different class grouping modes."""
    # Original mode: grouped classes share one label, shrinking the label space.
    DIMENSION_REDUCTION = "dimension_reduction"
    # Label space is unchanged; groups only restrict which negatives are drawn.
    GROUP_AWARE_NEGATIVE = "group_aware_negative"


class ClassGroupingManager:
    """Manages class grouping for Forward-Forward layers.

    This class handles the mapping of original classes to grouped classes
    and provides utilities for label transformation and inverse mapping.

    Supports two modes:
    1. DIMENSION_REDUCTION: Groups classes together, reducing effective number of classes
    2. GROUP_AWARE_NEGATIVE: Uses grouping only to avoid same-group negative sampling

    Attributes set by ``__init__``:
        class_groups: The mapping passed in by the caller (stored as given).
        num_original_classes: Size of the original label space.
        mode: The active ``ClassGroupingMode``.
        class_to_group: Internal group id for every class in
            ``range(num_original_classes)``.
        grouped_classes: Set of classes that appear in ``class_groups``.
        group_to_classes: Inverse of ``class_to_group``.
        num_grouped_classes: Size of the label space seen by the layer.
    """

    def __init__(
        self,
        class_groups: Dict[int, List[int]],
        num_original_classes: int,
        mode: ClassGroupingMode = ClassGroupingMode.DIMENSION_REDUCTION
    ):
        """
        Initialize class grouping manager.

        Args:
            class_groups: Dictionary mapping group_id to list of original class indices
                         e.g., {0: [3, 5]} means classes 3 and 5 are grouped together as group 0
            num_original_classes: Total number of original classes
            mode: Grouping mode to use

        Raises:
            ValueError: If a class appears in more than one group (or twice in
                the same group), or lies outside ``[0, num_original_classes)``.

        Example:
            # Group classes 3 and 5 together, and 1 and 8 together
            groups = {0: [3, 5], 1: [1, 8]}

            # Dimension reduction mode (original behavior)
            manager = ClassGroupingManager(groups, num_original_classes=10,
                                         mode=ClassGroupingMode.DIMENSION_REDUCTION)

            # Group-aware negative sampling mode (new behavior)
            manager = ClassGroupingManager(groups, num_original_classes=10,
                                         mode=ClassGroupingMode.GROUP_AWARE_NEGATIVE)
        """
        self.class_groups = class_groups
        self.num_original_classes = num_original_classes
        self.mode = mode

        # Create mapping from original class to group
        self.class_to_group = {}
        self.grouped_classes = set()

        if self.mode == ClassGroupingMode.DIMENSION_REDUCTION:
            # Grouped labels are later used as one-hot indices, so the ids must
            # be dense in [0, num_grouped_classes). Non-empty groups are
            # renumbered in ascending id order; keys that already are 0..k-1
            # keep their ids.
            populated = sorted(gid for gid, members in class_groups.items() if members)
            internal_id = {gid: dense for dense, gid in enumerate(populated)}
            next_group_id = len(populated)
        else:
            # Ids only express "same group or not" here, so user ids are kept.
            internal_id = {gid: gid for gid in class_groups}
            next_group_id = max(class_groups) + 1 if class_groups else 0

        # Track which classes are grouped
        for group_id, classes in class_groups.items():
            for cls in classes:
                if not 0 <= cls < num_original_classes:
                    raise ValueError(
                        f"Class {cls} in group {group_id} is outside the valid "
                        f"range [0, {num_original_classes})"
                    )
                if cls in self.grouped_classes:
                    raise ValueError(f"Class {cls} appears in multiple groups")
                self.class_to_group[cls] = internal_id[group_id]
                self.grouped_classes.add(cls)

        # Every ungrouped class becomes its own singleton group with a fresh id
        # placed after all user ids. Mapping an ungrouped class to its own
        # index would merge it into any user group that has the same id.
        for cls in range(num_original_classes):
            if cls not in self.grouped_classes:
                self.class_to_group[cls] = next_group_id
                next_group_id += 1

        if self.mode == ClassGroupingMode.DIMENSION_REDUCTION:
            # Calculate number of effective classes after grouping
            self.num_grouped_classes = len(set(self.class_to_group.values()))
        else:
            # In group-aware negative mode, we keep original number of classes
            self.num_grouped_classes = num_original_classes

        # Create reverse mapping for interpretation
        self.group_to_classes = {}
        for cls, group in self.class_to_group.items():
            self.group_to_classes.setdefault(group, []).append(cls)

        print(f"Class grouping initialized ({self.mode.value}):")
        print(f"  Original classes: {num_original_classes}")
        if self.mode == ClassGroupingMode.DIMENSION_REDUCTION:
            print(f"  Grouped classes: {self.num_grouped_classes}")
        print(f"  Groups: {self.group_to_classes}")

    def _group_id_tensor(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Return a 1-D lookup tensor whose entry ``c`` is the group id of class ``c``."""
        return torch.tensor(
            [self.class_to_group[cls] for cls in range(self.num_original_classes)],
            dtype=dtype,
            device=device,
        )

    def transform_labels(self, labels: torch.Tensor) -> torch.Tensor:
        """
        Transform original labels to grouped labels.

        Args:
            labels: Original labels tensor of shape (B,)

        Returns:
            torch.Tensor: Grouped labels tensor of shape (B,)
                         - In DIMENSION_REDUCTION mode: transformed to grouped space
                         - In GROUP_AWARE_NEGATIVE mode: unchanged (identity mapping)
        """
        if self.mode == ClassGroupingMode.GROUP_AWARE_NEGATIVE:
            # In group-aware negative mode, labels remain unchanged
            return labels

        # Build the lookup table in one shot and gather from it.
        mapping = self._group_id_tensor(labels.dtype, labels.device)
        return mapping[labels]

    def get_same_group_classes(self, class_idx: int) -> List[int]:
        """
        Get all classes that belong to the same group as the given class.

        Args:
            class_idx: Original class index

        Returns:
            List[int]: All classes in the same group (including the input class).
                      A new list is returned, so callers may modify it freely.
        """
        if class_idx not in self.class_to_group:
            return [class_idx]

        group_id = self.class_to_group[class_idx]
        return list(self.group_to_classes.get(group_id, [class_idx]))

    def get_valid_negative_classes(self, positive_class: int) -> List[int]:
        """
        Get all classes that are valid as negative samples for the given positive class.
        In group-aware mode, this excludes classes from the same group.

        Args:
            positive_class: The positive class index

        Returns:
            List[int]: List of valid negative classes. In DIMENSION_REDUCTION
                      mode this is every index of the grouped label space and
                      ``positive_class`` is not consulted.
        """
        if self.mode == ClassGroupingMode.DIMENSION_REDUCTION:
            # In dimension reduction mode, use transformed space
            return list(range(self.num_grouped_classes))

        # In group-aware negative mode, exclude same-group classes
        same_group_classes = set(self.get_same_group_classes(positive_class))
        return [cls for cls in range(self.num_original_classes)
                if cls not in same_group_classes]

    def create_group_aware_negative_mask(self, labels: torch.Tensor) -> torch.Tensor:
        """
        Create a mask for valid negative classes for each sample in the batch.

        Args:
            labels: Batch of labels (B,)

        Returns:
            torch.Tensor: Boolean mask of shape (B, num_classes) where True indicates
                         valid negative classes for each sample
        """
        # A class is a valid negative exactly when its group id differs from
        # the group id of the sample's label; comparing the two id vectors by
        # broadcasting avoids a Python loop with one host sync per sample.
        labels = labels.long()
        group_of_class = self._group_id_tensor(torch.long, labels.device)
        return group_of_class[labels].unsqueeze(1) != group_of_class.unsqueeze(0)

    def get_num_classes(self) -> int:
        """Get number of classes after grouping."""
        return self.num_grouped_classes

    def get_group_info(self) -> Dict[int, List[int]]:
        """Get information about groups (a copy, including the member lists)."""
        return {group: list(members) for group, members in self.group_to_classes.items()}

    def is_dimension_reduction_mode(self) -> bool:
        """Check if using dimension reduction mode."""
        return self.mode == ClassGroupingMode.DIMENSION_REDUCTION

    def is_group_aware_negative_mode(self) -> bool:
        """Check if using group-aware negative sampling mode."""
        return self.mode == ClassGroupingMode.GROUP_AWARE_NEGATIVE


def create_class_grouping_manager(
    layer_name: str,
    layer_config: Dict,
    num_original_classes: int
) -> Optional["ClassGroupingManager"]:
    """
    Create a class grouping manager from layer configuration.

    Reads two keys from ``layer_config``:
        - ``"class_groups"``: mapping of group id (int or numeric string) to a
          list of original class indices. When absent or ``None`` no manager
          is created.
        - ``"class_grouping_mode"``: a ``ClassGroupingMode`` value; defaults to
          ``"dimension_reduction"``. An unrecognised value prints a warning
          and falls back to the default.

    Args:
        layer_name: Name of the layer (used in error messages)
        layer_config: Layer configuration dictionary
        num_original_classes: Total number of original classes

    Returns:
        ClassGroupingManager if class grouping is specified, None otherwise

    Raises:
        ValueError: If a class index is not an int in
            ``[0, num_original_classes)``, if a string group id is not
            numeric, or if two keys denote the same group id once converted
            to int (e.g. ``"0"`` and ``0``).
    """
    class_groups = layer_config.get("class_groups", None)
    if class_groups is None:
        return None

    # Convert string keys to integers if needed
    processed_groups = {}
    for raw_group_id, raw_classes in class_groups.items():
        group_id = int(raw_group_id) if isinstance(raw_group_id, str) else raw_group_id

        # Two distinct keys can collapse to the same id after conversion;
        # overwriting would silently drop one of the groups.
        if group_id in processed_groups:
            raise ValueError(
                f"Duplicate group id {group_id} in class_groups for layer '{layer_name}'"
            )

        # Copy so the manager never shares list objects with the config.
        classes = list(raw_classes)

        # Validate class indices (bool is an int subclass, reject it explicitly)
        for cls in classes:
            if (isinstance(cls, bool) or not isinstance(cls, int)
                    or cls < 0 or cls >= num_original_classes):
                raise ValueError(f"Invalid class index {cls} for group {group_id}")

        processed_groups[group_id] = classes

    # Get grouping mode from config
    grouping_mode_str = layer_config.get("class_grouping_mode", "dimension_reduction")
    try:
        grouping_mode = ClassGroupingMode(grouping_mode_str)
    except ValueError:
        print(f"Warning: Invalid class_grouping_mode '{grouping_mode_str}', "
              f"using default 'dimension_reduction'")
        grouping_mode = ClassGroupingMode.DIMENSION_REDUCTION

    return ClassGroupingManager(processed_groups, num_original_classes, grouping_mode)


def _calculate_total_partitions(spatial_dims) -> int:
    """Calculate total number of partitions based on n_partitions setting."""
    if spatial_dims is None:
        return 1
    
    if isinstance(spatial_dims, int):
        return spatial_dims
    
    if isinstance(spatial_dims, (tuple, list)):
        total = 1
        for n in spatial_dims:
            total *= n
        return total
    
    raise ValueError(f"Unsupported n_partitions type: {type(spatial_dims)}")


# Enhanced FleaConvWithEncodingBase to support both grouping modes
class FleaConvWithEncodingBase(nn.Module):
    """Enhanced base class with support for both class grouping modes.

    The layer normalises its input, applies a convolution, a normalisation
    layer, an optional activation and dropout. Pooled per-channel "goodness"
    values of the resulting features are compared against a learned embedding
    of a one-hot label vector (``label_encoder``).

    Subclasses must define the class attribute ``_NDIMS`` (1 or 2) and pass
    the matching convolution class as ``conv_cls``.
    """

    def __init__(
        self,
        conv_cls: Type[nn.Module],
        *,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        label_dimension: Optional[int] = None,
        dropout: float = 0.0,
        activation_function: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        g_power: int = 2,
        alpha: int = 1,
        input_shape: int = 32,
        n_partitions: int = 1,
        # Enhanced parameter for class grouping
        class_grouping_manager: Optional["ClassGroupingManager"] = None,
        **kwargs
    ):
        """Build the convolution, normalisation and label-encoding sub-modules.

        Args:
            conv_cls: Convolution class to instantiate (e.g. ``nn.Conv2d``).
            in_channels, out_channels, kernel_size, stride, padding, dilation,
                groups, bias, padding_mode: Forwarded to ``conv_cls``.
            label_dimension: Number of original classes; 10 is used when this
                is ``None`` or 0.
            dropout: Dropout probability.
            activation_function: Callable applied after normalisation in
                ``forward``; ``None`` means no activation.
            g_power: Exponent used when pooling features into goodness values.
            alpha: Scale applied to the log-probabilities when sampling hard
                negative labels.
            input_shape: Size of each spatial axis of the input (one int is
                used for every axis).
            n_partitions: ``None``, an int, or one count per spatial axis; see
                ``image_encoder``.
            class_grouping_manager: Optional grouping manager; when given, its
                ``get_num_classes()`` sets the label dimension of the layer.
            **kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: If ``_NDIMS`` is neither 1 nor 2.
        """
        super().__init__()

        # Store class grouping manager
        self.class_grouping_manager = class_grouping_manager

        # Determine effective number of classes based on grouping mode
        if class_grouping_manager is not None:
            effective_num_classes = class_grouping_manager.get_num_classes()
            mode_str = class_grouping_manager.mode.value
            print(f"Layer using class grouping ({mode_str}): {effective_num_classes} effective classes")
        else:
            effective_num_classes = label_dimension or 10

        # Configuration parameters
        self.label_dimension = effective_num_classes  # Use effective classes
        self.original_label_dimension = label_dimension or 10  # Keep original for reference
        self.activation_function = activation_function
        self.num_classes = effective_num_classes
        self.accepts_label = True
        self.accepts_mode = True
        self.g_power = g_power
        self.alpha = alpha

        # Initialize convolution layer
        self.conv = conv_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode
        )
        self.out_channels = self.conv.out_channels

        # None and int are stored as given; any other iterable becomes a list.
        if n_partitions is None or isinstance(n_partitions, int):
            self.n_partitions = n_partitions
        else:
            self.n_partitions = list(n_partitions)

        # Initialize dropout if needed
        self.dropout = nn.Dropout2d(p=dropout)

        # Label-dependent scaling layer (uses effective number of classes)
        embedding_dim = out_channels * _calculate_total_partitions(self.n_partitions)
        self.label_encoder = nn.Linear(self.num_classes, embedding_dim, bias=True)

        # Initialize normalization layers
        ndims = self.__class__._NDIMS
        if ndims == 1:
            normalized_shape = [in_channels, input_shape]
        elif ndims == 2:
            normalized_shape = [in_channels, input_shape, input_shape]
        else:
            raise NotImplementedError(f"_NDIMS={ndims} not supported")

        self.normalizer = nn.LayerNorm(normalized_shape, elementwise_affine=False)

        # NOTE(review): this size estimate assumes int kernel_size/padding and
        # ignores stride and dilation; it is kept unchanged so the choice of
        # normalisation layer stays the same for existing configurations.
        output_shape = (input_shape - kernel_size + padding * 2) + 1
        if output_shape > 1:
            # BatchNorm2d rejects the 3D activations of a 1D convolution, so
            # the batch-norm class has to follow the layer dimensionality.
            batch_norm_cls = nn.BatchNorm1d if ndims == 1 else nn.BatchNorm2d
            self.batch_norm = batch_norm_cls(out_channels, affine=False)
        else:
            self.batch_norm = nn.GroupNorm(1, out_channels)

        # Learnable feature threshold, initialised to 1.0
        self.feature_threshold = nn.Parameter(torch.tensor(1.0))

    def _transform_labels_if_needed(self, labels: torch.Tensor) -> torch.Tensor:
        """Transform labels using class grouping if available."""
        if self.class_grouping_manager is not None:
            return self.class_grouping_manager.transform_labels(labels)
        return labels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standard forward pass through convolutional layers.

        Applies, in order: input normalisation, convolution, batch/group
        normalisation, the activation function (skipped when it is ``None``)
        and dropout.
        """
        x = self.normalizer(x)
        x = self.conv(x)
        x = self.batch_norm(x)

        # activation_function defaults to None in __init__; treat as identity.
        if self.activation_function is not None:
            x = self.activation_function(x)

        return self.dropout(x)

    def embedding_alignment(
        self,
        feature_embedding: torch.Tensor,
        labels: torch.Tensor,
        mode: str,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, float]]:
        """Compute label-modulated goodness scores with enhanced class grouping support.

        Args:
            feature_embedding: Output of ``image_encoder``; its second
                dimension must match the output size of ``label_encoder``.
            labels: Original class labels, one per sample.
            mode: ``"positive"`` to score against the true label or
                ``"negative"`` to score against a sampled wrong label.

        Returns:
            In positive mode, a tensor with one score per sample. In negative
            mode, a tuple ``(scores, accuracy)`` where ``accuracy`` is the
            fraction of samples whose arg-max class score equals the
            (transformed) label, as a Python float.

        Raises:
            ValueError: If ``mode`` is neither ``"positive"`` nor ``"negative"``.
        """
        if mode not in ("positive", "negative"):
            raise ValueError(
                f"Unsupported mode '{mode}'; expected 'positive' or 'negative'"
            )

        # Transform labels if class grouping is enabled
        labels = labels.long()
        transformed_labels = self._transform_labels_if_needed(labels)

        # Encode labels (one-hot for positive/negative classes)
        if mode == "positive":
            label_vec = self.generate_true_labels(transformed_labels, self.num_classes)
        else:
            class_scores = self.gpredict(feature_embedding)
            label_vec = self.generate_hard_labels_with_grouping(
                class_scores, labels, transformed_labels, self.num_classes, self.alpha
            )

        # Get scaled directions
        label_embedding = self.label_encoder(label_vec)

        # Original scoring behavior (dot product -> summing)
        alignment = torch.einsum('bc,bc->b', feature_embedding, label_embedding)
        scores = alignment - self.feature_threshold

        if mode == "positive":
            return scores

        accuracy = (torch.argmax(class_scores, dim=1) == transformed_labels).float().mean().item()
        return scores, accuracy

    def generate_hard_labels_with_grouping(
        self,
        class_scores: torch.Tensor,
        original_labels: torch.Tensor,
        transformed_labels: torch.Tensor,
        num_classes: int,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Generate hard labels for training with group-aware negative sampling.

        Uses the group-aware sampler when a grouping manager in
        GROUP_AWARE_NEGATIVE mode is attached; otherwise falls back to
        ``generate_hard_labels`` on the transformed labels.
        """
        if (self.class_grouping_manager is not None and
                self.class_grouping_manager.is_group_aware_negative_mode()):
            # Group-aware negative sampling mode
            return self._generate_group_aware_hard_labels(
                class_scores, original_labels, num_classes, alpha
            )

        # Standard hard label generation (dimension reduction mode or no grouping)
        return self.generate_hard_labels(
            class_scores, transformed_labels, num_classes, alpha
        )

    def _generate_group_aware_hard_labels(
        self,
        class_scores: torch.Tensor,
        original_labels: torch.Tensor,
        num_classes: int,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Generate hard labels avoiding same-group classes.

        Raises:
            ValueError: If some sample has no class outside its own group, so
                no negative label can be drawn for it.
        """
        # Get mask for valid negative classes
        valid_mask = self.class_grouping_manager.create_group_aware_negative_mask(
            original_labels.long()
        )
        valid_counts = valid_mask.sum(dim=1, keepdim=True)
        if bool((valid_counts == 0).any()):
            # Otherwise the uniform fallback below would divide 0 by 0.
            raise ValueError(
                "At least one sample has no valid negative class outside its group"
            )
        invalid_mask = ~valid_mask

        # Push invalid classes to the most negative finite value
        masked_preds = class_scores.masked_fill(
            invalid_mask, -torch.finfo(class_scores.dtype).max
        )

        # Log-space operations for numerical stability
        log_probs = F.log_softmax(masked_preds, dim=1)
        probs = (alpha * log_probs).exp().clamp(min=1e-10, max=1e10)

        # Ensure no same-group classes are selected (the clamp lifted them to 1e-10)
        probs = probs.masked_fill(invalid_mask, 0)

        # Normalize; rows without usable mass fall back to uniform over valid classes
        probs_sum = probs.sum(dim=1, keepdim=True)
        uniform_valid = valid_mask.to(probs.dtype) / valid_counts
        probs = torch.where(probs_sum > 0, probs / probs_sum, uniform_valid)

        # Sample wrong labels from valid classes
        wrong_labels = torch.multinomial(probs, num_samples=1).squeeze(1)

        return F.one_hot(wrong_labels, num_classes=num_classes).float()

    @torch.no_grad()
    def gpredict(self, feature_embedding: torch.Tensor) -> torch.Tensor:
        """Predict class scores from goodness vectors.

        The score of class ``c`` is the dot product of the embedding with the
        label encoder's output for one-hot ``c`` (weight column ``c`` plus
        bias), minus ``feature_threshold``.
        """
        # Compute base scores
        class_scores = feature_embedding @ self.label_encoder.weight  # (B, num_classes)
        # Add bias contribution if present (same value for every class)
        if self.label_encoder.bias is not None:
            class_scores += feature_embedding @ self.label_encoder.bias.unsqueeze(1)

        # TODO 17 Sep
        return class_scores - self.feature_threshold

    @torch.no_grad()
    def predict(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Full prediction pipeline with class grouping considerations.

        Returns:
            ``(class_scores, predictions, features)`` where ``predictions`` is
            the arg-max of ``class_scores`` and ``features`` is the output of
            ``forward``.
        """
        features = self.forward(inputs)
        feature_embedding = self.image_encoder(features)

        # TODO 17 Sep
        class_scores = self.gpredict(feature_embedding)
        predictions = class_scores.argmax(dim=1)

        return class_scores, predictions, features

    def get_grouped_prediction_mapping(self) -> Optional[Dict[int, List[int]]]:
        """Get mapping from grouped predictions back to original classes."""
        if (self.class_grouping_manager is not None and
                self.class_grouping_manager.is_dimension_reduction_mode()):
            return self.class_grouping_manager.get_group_info()
        return None

    def get_grouping_mode(self) -> Optional[str]:
        """Get the current class grouping mode."""
        if self.class_grouping_manager is not None:
            return self.class_grouping_manager.mode.value
        return None

    def _partition_goodness(self, region: torch.Tensor, dims: Tuple[int, ...]) -> torch.Tensor:
        """Pool ``region`` over ``dims`` as the mean of ``|x| ** g_power``."""
        if self.g_power == 1:
            return region.abs().mean(dim=dims)
        if self.g_power == 2:
            return region.square().mean(dim=dims)
        return region.abs().clamp_min(1e-6).pow(self.g_power).mean(dim=dims)

    def image_encoder(self, features: torch.Tensor) -> torch.Tensor:
        """Encode image per sample from layer features with flexible partitioning.

        Supports:
        - n_partitions = int: Vertical partitioning (bands along the first
          spatial axis, full extent on all other axes)
        - n_partitions = tuple/list: Spatial partitioning (e.g., (2,2) for 2x2 grid)
        - n_partitions = 1 or None: No partitioning (global average)

        Works for any number of spatial axes. Each axis is cut into
        ``size // count`` wide pieces; trailing elements that do not fill a
        whole piece are not used.

        Returns:
            Tensor of shape ``(B, C)`` without partitioning, otherwise
            ``(B, C * number_of_partitions)`` with the partitions concatenated
            along dim 1, the last partitioned axis varying fastest.

        Raises:
            ValueError: If ``n_partitions`` has an unsupported type, its length
                differs from the number of spatial axes, or an axis is too
                small for the requested number of partitions.
        """
        spatial_shape = features.shape[2:]
        n_spatial_dims = len(spatial_shape)
        reduce_dims = tuple(range(2, features.dim()))
        setting = getattr(self, 'n_partitions', None)

        # Resolve the setting into one partition count per spatial axis
        if setting is None:
            return self._partition_goodness(features, reduce_dims)
        if isinstance(setting, int):
            if setting <= 1:
                return self._partition_goodness(features, reduce_dims)
            if n_spatial_dims == 0:
                raise ValueError("Vertical partitioning requires at least one spatial dimension")
            # Vertical partitioning - only split along the first spatial axis
            counts = [setting] + [1] * (n_spatial_dims - 1)
        elif isinstance(setting, (tuple, list)):
            counts = list(setting)
            if len(counts) != n_spatial_dims:
                raise ValueError(f"n_partitions tuple length {len(counts)} must match "
                                 f"spatial dimensions {n_spatial_dims}")
        else:
            raise ValueError(f"Unsupported n_partitions type: {type(setting)}")

        # Calculate partition sizes
        sizes = []
        for dim_size, n_part in zip(spatial_shape, counts):
            partition_size = dim_size // n_part
            if partition_size == 0:
                raise ValueError(f"Feature map dimension {dim_size} is too small for {n_part} partitions")
            sizes.append(partition_size)

        # Visit every grid cell and pool it to one goodness value per channel
        keep_batch_and_channel = (slice(None), slice(None))
        partition_goodness = []
        for cell in itertools.product(*(range(n_part) for n_part in counts)):
            window = tuple(
                slice(idx * size, (idx + 1) * size) for idx, size in zip(cell, sizes)
            )
            region = features[keep_batch_and_channel + window]
            partition_goodness.append(self._partition_goodness(region, reduce_dims))

        # Concatenate all partition goodness values
        return torch.cat(partition_goodness, dim=1)

    def generate_true_labels(self, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Generate one-hot encoded vectors for true labels."""
        return F.one_hot(labels, num_classes).float()

    def generate_hard_labels(
        self,
        class_scores: torch.Tensor,
        labels: torch.Tensor,
        num_classes: int,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Generate hard labels for training by sampling from the wrong class probabilities.

        Args:
            class_scores: Per-class scores of shape (B, num_classes).
            labels: True class index per sample, shape (B,), integer dtype.
            num_classes: Size of the label space (must be at least 2).
            alpha: Scale applied to the log-probabilities before sampling.

        Returns:
            One-hot float tensor of shape (B, num_classes) marking, for each
            sample, a sampled class different from its true label.

        Raises:
            ValueError: If ``num_classes`` is smaller than 2.
        """
        if num_classes < 2:
            raise ValueError("generate_hard_labels requires at least 2 classes")

        B = class_scores.shape[0]
        rows = torch.arange(B, device=class_scores.device)

        # 1. Stable masking
        masked_preds = class_scores.clone()
        masked_preds.scatter_(1, labels.unsqueeze(1), -torch.finfo(class_scores.dtype).max)

        # 2. Log-space operations
        log_probs = F.log_softmax(masked_preds, dim=1)
        probs = (alpha * log_probs).exp().clamp(min=1e-10, max=1e10)

        # 3. Explicitly mask true label and normalize
        probs[rows, labels] = 0
        probs_sum = probs.sum(dim=1, keepdim=True)

        # Fallback for rows without usable mass (e.g. NaN scores): uniform over
        # the wrong classes only, so the true label can never be drawn.
        fallback = torch.full_like(probs, 1.0 / (num_classes - 1))
        fallback[rows, labels] = 0
        probs = torch.where(probs_sum > 0, probs / probs_sum, fallback)

        # 4. Sample wrong labels
        wrong_labels = torch.multinomial(probs, num_samples=1).squeeze(1)

        return F.one_hot(wrong_labels, num_classes=num_classes).float()

    def extra_repr(self):
        """String representation of layer configuration."""
        repr_str = (
            f"in_channels={self.conv.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"kernel_size={self.conv.kernel_size}, "
            f"stride={self.conv.stride}, "
            f"padding={self.conv.padding}, "
            f"dilation={self.conv.dilation}, "
            f"groups={self.conv.groups}, "
            f"label_dimension={self.label_dimension}, "
            f"g_power={self.g_power}, "
            f"alpha={self.alpha}"
        )

        if self.class_grouping_manager is not None:
            mode_str = self.class_grouping_manager.mode.value
            repr_str += f", class_groups={self.class_grouping_manager.get_group_info()}"
            repr_str += f", grouping_mode={mode_str}"

        return repr_str


# -----------------------------------------------------------------------------#
#                           Dimension-Specific Implementations                #
# -----------------------------------------------------------------------------#
class FleaConv1dWithEncoding(FleaConvWithEncodingBase):
    """Label-encoding Flea convolution layer built on ``nn.Conv1d``."""

    # Number of spatial axes; read by the base class constructor.
    _NDIMS = 1

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded untouched to the base class.
        super().__init__(nn.Conv1d, **kwargs)


class FleaBlockWithEncoding(FleaConvWithEncodingBase):
    """Label-encoding Flea convolution layer built on ``nn.Conv2d``."""

    # Number of spatial axes; read by the base class constructor.
    _NDIMS = 2

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded untouched to the base class.
        super().__init__(nn.Conv2d, **kwargs)
