import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.larrp_unimodal import AdaptiveRankReducedLinear

class AdaptiveRankReducedAE_NInFEA(torch.nn.Module):
    """Convolutional multimodal autoencoder for NInFEA dataset with adaptive conv layers per modality"""
    
    def __init__(self, input_shapes, latent_dims, depth=2, hidden_dim=512, dropout=0.0, 
                 initial_rank_ratio=1.0, min_rank=10, conv_depth=3):
        super(AdaptiveRankReducedAE_NInFEA, self).__init__()
        
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(len(input_shapes))])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(len(input_shapes))])
        self.adaptive_layers = nn.ModuleList()  # Track adaptive rank layers for rank reduction
        
        self.input_shapes = input_shapes  # List of tuples: (seq_length, n_channels) or (total_dim,)
        self.hidden_dim = hidden_dim
        self.conv_depth = conv_depth
        self.conv_configs = []  # Store conv configuration for each modality

        print(f"Creating AdaptiveRankReducedAE_NInFEA for {len(input_shapes)} modalities with:")
        print(f"   input_shapes={input_shapes}, latent_dims={latent_dims}")
        print(f"   depth={depth}, hidden_dim={hidden_dim}, adaptive conv layers (target final_dim≤10000), dropout={dropout}")
        print(f"   initial_rank_ratio: {initial_rank_ratio}, min_rank: {min_rank}")

        # Build convolutional encoders/decoders for each modality
        # Support both 2D: (seq_length, n_channels) and 3D: (height, width, n_channels)
        for m in range(len(input_shapes)):
            input_shape = input_shapes[m]
            
            # Determine if this is 2D or 3D data
            if len(input_shape) == 2:
                # 2D data: (seq_length, n_channels)
                seq_length, n_channels = input_shape
                print(f"Modality {m}: 2D input shape {seq_length} x {n_channels} channels")
                is_2d_conv = False
                
                conv_config = {
                    'original_seq_length': seq_length,
                    'n_channels': n_channels,
                    'conv_layers': [],
                    'is_2d': is_2d_conv,
                    'input_shape': input_shape
                }
            elif len(input_shape) == 3:
                # 3D data: (height, width, n_channels) 
                height, width, n_channels = input_shape
                print(f"Modality {m}: 3D input shape {height} x {width} x {n_channels} channels")
                is_2d_conv = True
                
                conv_config = {
                    'original_height': height,
                    'original_width': width,
                    'n_channels': n_channels,
                    'conv_layers': [],
                    'is_2d': is_2d_conv,
                    'input_shape': input_shape
                }
            else:
                raise ValueError(f"Unsupported input shape {input_shape} for modality {m}")
            
            # Build convolutional encoder - adaptive number of layers based on input size
            current_channels = n_channels
            conv_layer_count = 0
            
            if not is_2d_conv:
                # 1D convolutions for 2D data (seq_length, n_channels)
                current_seq_length = input_shape[0]  # seq_length
                
                # Continue adding conv layers until final dimension is manageable (≤ 10000)
                while True:
                    # Calculate what the dimension would be after this layer
                    out_channels = 1  # Keep all intermediate channels at 1 to minimize data size
                    kernel_size = 3
                    stride = 2
                    padding = 1
                    
                    next_seq_length = (current_seq_length + 2 * padding - kernel_size) // stride + 1
                    next_conv_dim = next_seq_length * out_channels
                    
                    # Stop if we've reached our target size or if we can't reduce further
                    if next_conv_dim <= 10000 or next_seq_length <= 3:
                        break
                    
                    # Add the convolutional layer
                    self.encoders[m].append(nn.Conv1d(current_channels, out_channels, kernel_size, stride, padding))
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))
                    
                    # Update dimensions
                    current_seq_length = next_seq_length
                    current_channels = out_channels
                    conv_layer_count += 1
                    
                    # Store layer config for decoder reconstruction
                    conv_config['conv_layers'].append({
                        'out_channels': out_channels,
                        'kernel_size': kernel_size,
                        'stride': stride,
                        'padding': padding,
                        'seq_length': current_seq_length
                    })
                
                # Flatten and add final encoder layers
                final_conv_dim = current_seq_length * current_channels
                
            else:
                # 2D convolutions for 3D data (height, width, n_channels)
                current_height = input_shape[0]  # height
                current_width = input_shape[1]   # width
                
                # Continue adding conv layers until final dimension is manageable (≤ 10000)
                while True:
                    # Calculate what the dimension would be after this layer
                    out_channels = 1  # Keep all intermediate channels at 1 to minimize data size
                    kernel_size = 3
                    stride = 2
                    padding = 1
                    
                    next_height = (current_height + 2 * padding - kernel_size) // stride + 1
                    next_width = (current_width + 2 * padding - kernel_size) // stride + 1
                    next_conv_dim = next_height * next_width * out_channels
                    
                    # Stop if we've reached our target size or if we can't reduce further
                    if next_conv_dim <= 10000 or next_height <= 3 or next_width <= 3:
                        break
                    
                    # Add the convolutional layer
                    self.encoders[m].append(nn.Conv2d(current_channels, out_channels, kernel_size, stride, padding))
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))
                    
                    # Update dimensions
                    current_height = next_height
                    current_width = next_width
                    current_channels = out_channels
                    conv_layer_count += 1
                    
                    # Store layer config for decoder reconstruction
                    conv_config['conv_layers'].append({
                        'out_channels': out_channels,
                        'kernel_size': kernel_size,
                        'stride': stride,
                        'padding': padding,
                        'height': current_height,
                        'width': current_width
                    })
                
                # Flatten and add final encoder layers
                final_conv_dim = current_height * current_width * current_channels
            
            conv_config['actual_conv_depth'] = conv_layer_count
            self.encoders[m].append(nn.Flatten())
            conv_config['final_conv_dim'] = final_conv_dim
            
            if not is_2d_conv:
                print(f"  After {conv_layer_count} 1D conv layers: {current_seq_length} x {current_channels} = {final_conv_dim}")
            else:
                print(f"  After {conv_layer_count} 2D conv layers: {current_height} x {current_width} x {current_channels} = {final_conv_dim}")
            
            self.conv_configs.append(conv_config)
            
            # Add fully connected layers after convolution
            for i in range(depth):
                if i == 0:
                    # First FC layer - conv output to hidden
                    encoder_layer = nn.Linear(final_conv_dim, self.hidden_dim)
                    self.encoders[m].append(encoder_layer)
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))
                        
                    # First decoder layer - latent to hidden
                    decoder_layer = nn.Linear(latent_dims[m] + latent_dims[-1], self.hidden_dim)
                    self.decoders[m].append(decoder_layer)
                    self.decoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.decoders[m].append(nn.Dropout(dropout))
                        
                elif i == depth - 1:
                    # Final encoder layer - hidden to latent
                    encoder_layer = nn.Linear(self.hidden_dim, latent_dims[m])
                    self.encoders[m].append(encoder_layer)
                    
                    # Final decoder layer - hidden to conv input
                    decoder_layer = nn.Linear(self.hidden_dim, final_conv_dim)
                    self.decoders[m].append(decoder_layer)
                    self.decoders[m].append(nn.Sigmoid())
                else:
                    # Middle layers
                    encoder_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
                    self.encoders[m].append(encoder_layer)
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))
                        
                    decoder_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
                    self.decoders[m].append(decoder_layer)
                    self.decoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.decoders[m].append(nn.Dropout(dropout))
            
            # Add convolutional decoder (reverse of encoder)
            # Start by reshaping back to conv feature map
            if not is_2d_conv:
                # 1D case: reshape to (batch, channels, seq_length)
                self.decoders[m].append(nn.Unflatten(1, (current_channels, current_seq_length)))
            else:
                # 2D case: reshape to (batch, channels, height, width)  
                self.decoders[m].append(nn.Unflatten(1, (current_channels, current_height, current_width)))
            
            # Add transposed convolutions in reverse order
            if not is_2d_conv:
                # 1D transposed convolutions
                expected_output_sizes = []
                
                # Calculate what sizes we expect at each layer (going backwards)
                if conv_config['actual_conv_depth'] > 0:
                    # Start from the original size and work backwards to calculate expected sizes
                    temp_size = conv_config['original_seq_length']
                    expected_output_sizes.append(temp_size)
                    
                    for i in range(conv_config['actual_conv_depth'] - 1):
                        # Calculate what the size was before this conv layer
                        kernel_size = conv_config['conv_layers'][i]['kernel_size']
                        stride = conv_config['conv_layers'][i]['stride']
                        padding = conv_config['conv_layers'][i]['padding']
                        temp_size = (temp_size + 2 * padding - kernel_size) // stride + 1
                        expected_output_sizes.append(temp_size)
                    
                    expected_output_sizes.reverse()  # Reverse to match deconv order
            
                for i in reversed(range(conv_config['actual_conv_depth'])):
                    layer_config = conv_config['conv_layers'][i]
                    in_channels = layer_config['out_channels']
                    
                    if i == 0:
                        out_channels = conv_config['n_channels']  # Back to original channels
                        expected_output_size = conv_config['original_seq_length']
                    else:
                        out_channels = conv_config['conv_layers'][i-1]['out_channels']
                        expected_output_size = expected_output_sizes[conv_config['actual_conv_depth'] - 1 - i]
                    
                    kernel_size = layer_config['kernel_size']
                    stride = layer_config['stride']
                    padding = layer_config['padding']
                    
                    # Calculate current input size (from the layer config)
                    current_input_size = layer_config['seq_length']
                    
                    # Calculate output_padding to get exact target size
                    # Formula: output_size = (input_size - 1) * stride - 2 * padding + kernel_size + output_padding
                    # Solving for output_padding: output_padding = expected_output_size - ((input_size - 1) * stride - 2 * padding + kernel_size)
                    calculated_output_size = (current_input_size - 1) * stride - 2 * padding + kernel_size
                    output_padding = max(0, expected_output_size - calculated_output_size)
                    
                    self.decoders[m].append(nn.ConvTranspose1d(
                        in_channels, out_channels, kernel_size, stride, padding, output_padding=output_padding
                    ))
                    if i > 0:  # No activation after final layer
                        self.decoders[m].append(nn.ReLU())
                        if dropout > 0:
                            self.decoders[m].append(nn.Dropout(dropout))
            else:
                # 2D transposed convolutions
                expected_output_sizes = []
                
                # Calculate what sizes we expect at each layer (going backwards)
                if conv_config['actual_conv_depth'] > 0:
                    # Start from the original size and work backwards
                    temp_height = conv_config['original_height']
                    temp_width = conv_config['original_width']
                    expected_output_sizes.append((temp_height, temp_width))
                    
                    for i in range(conv_config['actual_conv_depth'] - 1):
                        # Calculate what the sizes were before this conv layer
                        kernel_size = conv_config['conv_layers'][i]['kernel_size']
                        stride = conv_config['conv_layers'][i]['stride']
                        padding = conv_config['conv_layers'][i]['padding']
                        temp_height = (temp_height + 2 * padding - kernel_size) // stride + 1
                        temp_width = (temp_width + 2 * padding - kernel_size) // stride + 1
                        expected_output_sizes.append((temp_height, temp_width))
                    
                    expected_output_sizes.reverse()  # Reverse to match deconv order
                
                for i in reversed(range(conv_config['actual_conv_depth'])):
                    layer_config = conv_config['conv_layers'][i]
                    in_channels = layer_config['out_channels']
                    
                    if i == 0:
                        out_channels = conv_config['n_channels']  # Back to original channels
                        expected_height = conv_config['original_height']
                        expected_width = conv_config['original_width']
                    else:
                        out_channels = conv_config['conv_layers'][i-1]['out_channels']
                        expected_height, expected_width = expected_output_sizes[conv_config['actual_conv_depth'] - 1 - i]
                    
                    kernel_size = layer_config['kernel_size']
                    stride = layer_config['stride']
                    padding = layer_config['padding']
                    
                    # Calculate current input sizes (from the layer config)
                    current_height = layer_config['height']
                    current_width = layer_config['width']
                    
                    # Calculate output_padding for height and width
                    calc_height = (current_height - 1) * stride - 2 * padding + kernel_size
                    calc_width = (current_width - 1) * stride - 2 * padding + kernel_size
                    output_padding_h = max(0, expected_height - calc_height)
                    output_padding_w = max(0, expected_width - calc_width)
                    output_padding = (output_padding_h, output_padding_w)
                    
                    self.decoders[m].append(nn.ConvTranspose2d(
                        in_channels, out_channels, kernel_size, stride, padding, output_padding=output_padding
                    ))
                    if i > 0:  # No activation after final layer
                        self.decoders[m].append(nn.ReLU())
                        if dropout > 0:
                            self.decoders[m].append(nn.Dropout(dropout))
            
            # Flatten final output appropriately
            if not is_2d_conv:
                # 1D case: flatten to (batch, seq_length * n_channels)
                self.decoders[m].append(nn.Flatten())
            else:
                # 3D case: keep as (batch, height, width, n_channels) - transpose from conv format
                pass  # We'll handle this in the forward pass
        
        # Hierarchical shared and specific latent spaces with adaptive rank reduction
        self.adaptive_layer_map = {}  # Maps layer name to index in self.adaptive_layers
        self._build_hierarchical_adaptive_layers(input_shapes, latent_dims, initial_rank_ratio, min_rank)
    
        # Initialize modality weights for balanced training
        self.modality_weights = nn.Parameter(torch.ones(len(input_shapes)), requires_grad=True)
    
    def _build_hierarchical_adaptive_layers(self, input_shapes, latent_dims, initial_rank_ratio, min_rank):
        """Build hierarchical adaptive layers for multi-modal architecture"""
        import itertools
        
        num_modalities = len(input_shapes)
        layer_index = 0
        
        print(f"Building hierarchical adaptive layers for {num_modalities} modalities:")
        
        # Build layers from global shared down to modality-specific
        subspace_level = num_modalities
        while subspace_level > 0:
            if subspace_level == num_modalities:
                # Global shared layer - combines all modalities
                print(f"  Level {subspace_level}: Creating global shared layer")
                input_dim = sum(latent_dims[:num_modalities])  # Sum of all modality latent dims
                output_dim = latent_dims[-1]  # Global shared dimension
                
                shared_layer = AdaptiveRankReducedLinear(
                    input_dim, output_dim,
                    initial_rank_ratio=initial_rank_ratio,
                    min_rank=min_rank
                )
                self.adaptive_layers.append(shared_layer)
                self.adaptive_layer_map['global_shared'] = layer_index
                layer_index += 1
                
            elif subspace_level > 1:
                # Shared subspaces - combinations of modalities at this level
                print(f"  Level {subspace_level}: Creating shared subspaces")
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    print(f"    Creating subspace: {combo_name}")
                    
                    # Input dimension is sum of latent dims for modalities in this combination
                    input_dim = sum(latent_dims[i] for i in combo)
                    # Output dimension is the minimum of the constituent latent dims
                    output_dim = min(latent_dims[i] for i in combo)
                    
                    subspace_layer = AdaptiveRankReducedLinear(
                        input_dim, output_dim,
                        initial_rank_ratio=initial_rank_ratio,
                        min_rank=min_rank
                    )
                    self.adaptive_layers.append(subspace_layer)
                    self.adaptive_layer_map[combo_name] = layer_index
                    layer_index += 1
            
            else:  # subspace_level == 1
                # Modality-specific layers
                print(f"  Level {subspace_level}: Creating modality-specific layers")
                for i in range(num_modalities):
                    layer_name = f"specific_{i}"
                    print(f"    Creating layer: {layer_name}")
                    
                    specific_layer = AdaptiveRankReducedLinear(
                        latent_dims[i], latent_dims[i],
                        initial_rank_ratio=initial_rank_ratio,
                        min_rank=min_rank
                    )
                    self.adaptive_layers.append(specific_layer)
                    self.adaptive_layer_map[layer_name] = layer_index
                    layer_index += 1
            
            subspace_level -= 1
        
        print(f"Created {len(self.adaptive_layers)} adaptive layers:")
        for name, idx in self.adaptive_layer_map.items():
            layer = self.adaptive_layers[idx]
            print(f"  {name}: {layer.in_features} -> {layer.out_features} (index {idx})")
            
        return layer_index
    
    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=[], dim=0):
        """Reduce rank of all adaptive layers based on singular value importance"""
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue
            
            S = layer.get_rank_reduction_info()
            
            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank
                
            # Calculate normalized cumulative energy
            energy = S**2
            normalized_energy = energy / energy.sum()
            cumulative_energy = torch.cumsum(normalized_energy, dim=0)

            # Find the rank that preserves specified energy threshold
            target_rank = max(layer.min_rank, 
                             torch.sum(cumulative_energy < (1.0 - threshold)).item())

            # Alternative: just reduce by fixed ratio, but not below min_rank
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            
            new_rank = target_rank
            
            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                layer.reduce_rank(new_rank, dim=dim)
                changes_made = True
                
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=[], dim=0):
        """Increase rank of all adaptive layers by specified increment"""
        changes_made = False
        
        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True
                
        return changes_made
    
    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)
    
    def encode(self, x, compute_jacobian=False):
        h_concat = []
        for m, x_m in enumerate(x):
            conv_config = self.conv_configs[m]
            
            if not conv_config['is_2d']:
                # 2D input: (batch, seq_length, channels) -> (batch, channels, seq_length)
                x_m = x_m.transpose(1, 2)
            else:
                # 3D input: (batch, height, width, channels) -> (batch, channels, height, width)
                x_m = x_m.permute(0, 3, 1, 2)
            
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)
        
        # Apply hierarchical adaptive layers
        h_encoded = {}
        
        # Global shared layer
        h_all = torch.cat(h_concat, dim=1)
        global_idx = self.adaptive_layer_map['global_shared']
        h_encoded['global_shared'] = self.adaptive_layers[global_idx](h_all)
        
        # Shared subspace layers (combinations of modalities)
        import itertools
        num_modalities = len(h_concat)
        
        for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
            for combo in itertools.combinations(range(num_modalities), subspace_level):
                combo_name = f"shared_{'_'.join(map(str, combo))}"
                if combo_name in self.adaptive_layer_map:
                    # Concatenate the specific modalities for this combination
                    h_combo = torch.cat([h_concat[i] for i in combo], dim=1)
                    combo_idx = self.adaptive_layer_map[combo_name]
                    h_encoded[combo_name] = self.adaptive_layers[combo_idx](h_combo)
        
        # Modality-specific layers
        h_specific = []
        for i in range(num_modalities):
            specific_name = f'specific_{i}'
            if specific_name in self.adaptive_layer_map:
                specific_idx = self.adaptive_layer_map[specific_name]
                h_specific.append(self.adaptive_layers[specific_idx](h_concat[i]))
            else:
                h_specific.append(h_concat[i])  # Fallback if layer doesn't exist
        
        h_encoded['specific'] = h_specific
        
        return h_encoded

    def decode(self, h_encoded):
        """Decode from hierarchical encoded representations"""
        h_global = h_encoded['global_shared']
        h_specific = h_encoded['specific']
        
        x_hat = []
        for m, h_m in enumerate(h_specific):
            # Collect all relevant hierarchical representations for this modality
            h_components = [h_global, h_m]  # Start with global shared and modality-specific
            
            # Add shared subspace representations that include this modality
            import itertools
            num_modalities = len(h_specific)
            
            for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    if combo_name in h_encoded and m in combo:
                        h_components.append(h_encoded[combo_name])
            
            # Separate global and modality-specific components for proper averaging
            modality_components = [h_m]  # Start with modality-specific
            
            # Add shared subspace representations that include this modality (excluding global)
            for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    if combo_name in h_encoded and m in combo:
                        modality_components.append(h_encoded[combo_name])
            
            # Average pool modality-specific and shared subspace representations
            if len(modality_components) > 1:
                # Stack and average modality-specific components
                h_stacked = torch.stack(modality_components, dim=0)  # (num_components, batch_size, latent_dim)
                h_modality_combined = torch.mean(h_stacked, dim=0)   # (batch_size, latent_dim)
            else:
                h_modality_combined = modality_components[0]
            
            # Combine global shared and averaged modality representations for decoder input
            # This maintains the expected input dimensions: latent_dims[m] + latent_dims[-1]
            h_concat = torch.cat([h_global, h_modality_combined], dim=1)
            
            for layer in self.decoders[m]:
                h_concat = layer(h_concat)
            
            # Reshape back to original format based on modality type
            conv_config = self.conv_configs[m]
            
            if not conv_config['is_2d']:
                # 2D modality: handle 1D conv output
                if len(h_concat.shape) == 2:  # Flattened output: (batch, total_features)
                    original_seq_length = conv_config['original_seq_length']
                    n_channels = conv_config['n_channels']
                    
                    # Calculate the actual output size after conv/deconv operations
                    total_features = h_concat.shape[1]  # Total flattened features
                    expected_features = original_seq_length * n_channels
                    
                    if total_features == expected_features:
                        # Perfect match - reshape to original dimensions
                        h_concat = h_concat.view(h_concat.shape[0], original_seq_length, n_channels)
                    else:
                        # Size mismatch due to conv operations - try to infer correct dimensions
                        actual_seq_length = total_features // n_channels
                        if total_features % n_channels == 0:
                            h_concat = h_concat.view(h_concat.shape[0], actual_seq_length, n_channels)
                        else:
                            # If not divisible by n_channels, use original seq_length and adjust channels
                            actual_channels = total_features // original_seq_length
                            if total_features % original_seq_length == 0:
                                h_concat = h_concat.view(h_concat.shape[0], original_seq_length, actual_channels)
                            else:
                                # Last resort: use original dimensions and truncate/pad as needed
                                h_concat = h_concat.view(h_concat.shape[0], -1, n_channels)[:, :original_seq_length, :]
                                
                elif len(h_concat.shape) == 3:  # Conv output: (batch, channels, seq_length)
                    # Transpose from (batch, channels, seq_length) to (batch, seq_length, channels)
                    h_concat = h_concat.transpose(1, 2)
            else:
                # 3D modality: handle 2D conv output
                if len(h_concat.shape) == 4:  # Conv output: (batch, channels, height, width)
                    # Transpose from (batch, channels, height, width) to (batch, height, width, channels)
                    h_concat = h_concat.permute(0, 2, 3, 1)
                elif len(h_concat.shape) == 2:  # Flattened output (shouldn't happen for 3D but handle it)
                    original_height = conv_config['original_height']
                    original_width = conv_config['original_width']
                    n_channels = conv_config['n_channels']
                    
                    total_features = h_concat.shape[1]
                    expected_features = original_height * original_width * n_channels
                    
                    if total_features == expected_features:
                        h_concat = h_concat.view(h_concat.shape[0], original_height, original_width, n_channels)
                    else:
                        # Try to infer dimensions
                        import math
                        sqrt_spatial = int(math.sqrt(total_features // n_channels))
                        if sqrt_spatial * sqrt_spatial * n_channels == total_features:
                            h_concat = h_concat.view(h_concat.shape[0], sqrt_spatial, sqrt_spatial, n_channels)
                        else:
                            # Fallback to original shape and truncate/pad as needed
                            h_concat = h_concat.view(h_concat.shape[0], -1)
                            padding_size = max(0, expected_features - total_features)
                            if padding_size > 0:
                                padding = torch.zeros(h_concat.shape[0], padding_size, device=h_concat.device)
                                h_concat = torch.cat([h_concat, padding], dim=1)
                            h_concat = h_concat[:, :expected_features].view(h_concat.shape[0], original_height, original_width, n_channels)
                
            x_hat.append(h_concat)
        return x_hat

    def forward(self, x):
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h
    
    def encode_modalities(self, x):
        """Encode each modality with its combined representations"""
        h_encoded = self.encode(x)
        h_global = h_encoded['global_shared']
        h_specific = h_encoded['specific']
        
        h_combined = []
        for i, h_m in enumerate(h_specific):
            h_combined.append(torch.cat([h_global, h_m], dim=1))
        return h_combined

class AdaptiveRankReducedAE_NInFEA_2Mods(torch.nn.Module):
    """Convolutional multimodal autoencoder for NInFEA dataset with adaptive conv layers per modality"""
    
    def __init__(self, input_shapes, latent_dims, depth=2, hidden_dim=512, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10, conv_depth=3):
        """Build per-modality conv/FC encoders and decoders plus the adaptive latent layers.

        Args:
            input_shapes: One shape tuple per modality. ``(seq_length, n_channels)``
                selects 1D convolutions, ``(height, width, n_channels)`` selects
                2D convolutions.
            latent_dims: Either a list/tuple in which entry ``m`` is the encoder
                output size of modality ``m`` and the last entry is the shared
                latent size, or a single int that is used for all of them.
            depth: Number of fully connected stages per encoder/decoder. The
                output layers (hidden -> latent, hidden -> conv features) are
                only created when ``depth >= 2``.
            hidden_dim: Width of the hidden fully connected layers.
            dropout: Dropout probability; dropout layers are only added when > 0.
            initial_rank_ratio: Forwarded to every ``AdaptiveRankReducedLinear``.
            min_rank: Forwarded to every ``AdaptiveRankReducedLinear``.
            conv_depth: Stored on the instance only; the number of conv layers
                is derived from the input size.

        Raises:
            ValueError: If ``latent_dims`` is neither an int nor a list/tuple,
                or an input shape does not have 2 or 3 entries.
        """
        super().__init__()

        num_modalities = len(input_shapes)
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(num_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(num_modalities)])
        self.adaptive_layers = nn.ModuleList()  # Track adaptive rank layers for rank reduction

        if isinstance(latent_dims, int):
            self.latent_dim = latent_dims
            # The layers below index latent_dims per modality, so expand the
            # scalar into one entry per modality plus the shared entry.
            layer_latent_dims = [latent_dims] * (num_modalities + 1)
        elif isinstance(latent_dims, (list, tuple)):
            self.latent_dim = latent_dims[-1]  # Global shared latent dim
            layer_latent_dims = latent_dims
        else:
            raise ValueError("latent_dims must be int or list/tuple of ints")

        self.input_shapes = input_shapes  # List of tuples: (seq_length, n_channels) or (height, width, n_channels)
        self.hidden_dim = hidden_dim
        self.conv_depth = conv_depth
        self.conv_configs = []  # Store conv configuration for each modality

        print(f"Creating AdaptiveRankReducedAE_NInFEA for {len(input_shapes)} modalities with:")
        print(f"   input_shapes={input_shapes}, latent_dims={latent_dims}")
        print(f"   depth={depth}, hidden_dim={hidden_dim}, adaptive conv layers (target final_dim≤10000), dropout={dropout}")
        print(f"   initial_rank_ratio: {initial_rank_ratio}, min_rank: {min_rank}")

        # Every conv stage shares this geometry; a single output channel keeps
        # the intermediate feature maps small.
        conv_out_channels = 1
        kernel_size = 3
        stride = 2
        padding = 1

        for m, input_shape in enumerate(input_shapes):
            # The number of entries in the shape decides between 1D and 2D convolutions.
            if len(input_shape) == 2:
                seq_length, n_channels = input_shape
                print(f"Modality {m}: 2D input shape {seq_length} x {n_channels} channels")
                is_2d_conv = False

                conv_config = {
                    'original_seq_length': seq_length,
                    'n_channels': n_channels,
                    'conv_layers': [],
                    'is_2d': is_2d_conv,
                    'input_shape': input_shape
                }
            elif len(input_shape) == 3:
                height, width, n_channels = input_shape
                print(f"Modality {m}: 3D input shape {height} x {width} x {n_channels} channels")
                is_2d_conv = True

                conv_config = {
                    'original_height': height,
                    'original_width': width,
                    'n_channels': n_channels,
                    'conv_layers': [],
                    'is_2d': is_2d_conv,
                    'input_shape': input_shape
                }
            else:
                raise ValueError(f"Unsupported input shape {input_shape} for modality {m}")

            # ---- Convolutional encoder: depth adapts to the input size ----
            current_channels = n_channels
            conv_layer_count = 0

            if not is_2d_conv:
                current_seq_length = seq_length

                while True:
                    next_seq_length = (current_seq_length + 2 * padding - kernel_size) // stride + 1
                    next_conv_dim = next_seq_length * conv_out_channels

                    # The check looks at the size the *next* layer would produce,
                    # so the loop stops before adding the layer that would reach
                    # the threshold.
                    if next_conv_dim <= 10000 or next_seq_length <= 3:
                        break

                    self.encoders[m].append(
                        nn.Conv1d(current_channels, conv_out_channels, kernel_size, stride, padding)
                    )
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))

                    current_seq_length = next_seq_length
                    current_channels = conv_out_channels
                    conv_layer_count += 1

                    # Remember the layer geometry and its output length for the decoder.
                    conv_config['conv_layers'].append({
                        'out_channels': conv_out_channels,
                        'kernel_size': kernel_size,
                        'stride': stride,
                        'padding': padding,
                        'seq_length': current_seq_length
                    })

                final_conv_dim = current_seq_length * current_channels

            else:
                current_height = height
                current_width = width

                while True:
                    next_height = (current_height + 2 * padding - kernel_size) // stride + 1
                    next_width = (current_width + 2 * padding - kernel_size) // stride + 1
                    next_conv_dim = next_height * next_width * conv_out_channels

                    # NOTE(review): this branch uses 100000 while the 1D branch and
                    # the log message above use 10000 - confirm this is intended.
                    if next_conv_dim <= 100000 or next_height <= 3 or next_width <= 3:
                        break

                    self.encoders[m].append(
                        nn.Conv2d(current_channels, conv_out_channels, kernel_size, stride, padding)
                    )
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))

                    current_height = next_height
                    current_width = next_width
                    current_channels = conv_out_channels
                    conv_layer_count += 1

                    # Remember the layer geometry and its output size for the decoder.
                    conv_config['conv_layers'].append({
                        'out_channels': conv_out_channels,
                        'kernel_size': kernel_size,
                        'stride': stride,
                        'padding': padding,
                        'height': current_height,
                        'width': current_width
                    })

                final_conv_dim = current_height * current_width * current_channels

            conv_config['actual_conv_depth'] = conv_layer_count
            self.encoders[m].append(nn.Flatten())
            conv_config['final_conv_dim'] = final_conv_dim

            if not is_2d_conv:
                print(f"  After {conv_layer_count} 1D conv layers: {current_seq_length} x {current_channels} = {final_conv_dim}")
            else:
                print(f"  After {conv_layer_count} 2D conv layers: {current_height} x {current_width} x {current_channels} = {final_conv_dim}")

            self.conv_configs.append(conv_config)

            # ---- Fully connected stages (encoder and decoder built in lockstep) ----
            for i in range(depth):
                if i == 0:
                    # Encoder: conv features -> hidden
                    self.encoders[m].append(nn.Linear(final_conv_dim, self.hidden_dim))
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))

                    # Decoder: [global | modality] latent (2 * latent_dim) -> hidden
                    self.decoders[m].append(nn.Linear(2 * self.latent_dim, self.hidden_dim))
                    self.decoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.decoders[m].append(nn.Dropout(dropout))

                elif i == depth - 1:
                    # Encoder: hidden -> modality latent (no activation)
                    self.encoders[m].append(nn.Linear(self.hidden_dim, layer_latent_dims[m]))

                    # Decoder: hidden -> conv features, squashed by a sigmoid
                    self.decoders[m].append(nn.Linear(self.hidden_dim, final_conv_dim))
                    self.decoders[m].append(nn.Sigmoid())
                else:
                    # Middle stages: hidden -> hidden on both sides
                    self.encoders[m].append(nn.Linear(self.hidden_dim, self.hidden_dim))
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))

                    self.decoders[m].append(nn.Linear(self.hidden_dim, self.hidden_dim))
                    self.decoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.decoders[m].append(nn.Dropout(dropout))

            # ---- Convolutional decoder: mirror of the encoder ----
            if not is_2d_conv:
                # Reshape to (batch, channels, seq_length)
                self.decoders[m].append(nn.Unflatten(1, (current_channels, current_seq_length)))

                for i in reversed(range(conv_layer_count)):
                    layer_config = conv_config['conv_layers'][i]

                    # Transposed layer i must restore the input of encoder layer i:
                    # the raw input for i == 0, otherwise the stored output of
                    # encoder layer i - 1.
                    if i == 0:
                        out_channels = n_channels
                        expected_output_size = seq_length
                    else:
                        previous = conv_config['conv_layers'][i - 1]
                        out_channels = previous['out_channels']
                        expected_output_size = previous['seq_length']

                    # ConvTranspose output without extra padding:
                    # (input - 1) * stride - 2 * padding + kernel_size.
                    # output_padding makes up the difference to the target size.
                    calculated_output_size = (
                        (layer_config['seq_length'] - 1) * layer_config['stride']
                        - 2 * layer_config['padding'] + layer_config['kernel_size']
                    )
                    output_padding = max(0, expected_output_size - calculated_output_size)

                    self.decoders[m].append(nn.ConvTranspose1d(
                        layer_config['out_channels'], out_channels, layer_config['kernel_size'],
                        layer_config['stride'], layer_config['padding'], output_padding=output_padding
                    ))
                    if i > 0:  # No activation after final layer
                        self.decoders[m].append(nn.ReLU())
                        if dropout > 0:
                            self.decoders[m].append(nn.Dropout(dropout))

                # 1D outputs are flattened here; decode() reshapes them afterwards.
                self.decoders[m].append(nn.Flatten())
            else:
                # Reshape to (batch, channels, height, width)
                self.decoders[m].append(nn.Unflatten(1, (current_channels, current_height, current_width)))

                for i in reversed(range(conv_layer_count)):
                    layer_config = conv_config['conv_layers'][i]

                    if i == 0:
                        out_channels = n_channels
                        expected_height, expected_width = height, width
                    else:
                        previous = conv_config['conv_layers'][i - 1]
                        out_channels = previous['out_channels']
                        expected_height, expected_width = previous['height'], previous['width']

                    calc_height = (
                        (layer_config['height'] - 1) * layer_config['stride']
                        - 2 * layer_config['padding'] + layer_config['kernel_size']
                    )
                    calc_width = (
                        (layer_config['width'] - 1) * layer_config['stride']
                        - 2 * layer_config['padding'] + layer_config['kernel_size']
                    )
                    output_padding = (
                        max(0, expected_height - calc_height),
                        max(0, expected_width - calc_width),
                    )

                    self.decoders[m].append(nn.ConvTranspose2d(
                        layer_config['out_channels'], out_channels, layer_config['kernel_size'],
                        layer_config['stride'], layer_config['padding'], output_padding=output_padding
                    ))
                    if i > 0:  # No activation after final layer
                        self.decoders[m].append(nn.ReLU())
                        if dropout > 0:
                            self.decoders[m].append(nn.Dropout(dropout))

                # 2D outputs stay in conv layout; decode() permutes them.

        # Hierarchical shared and specific latent spaces with adaptive rank reduction
        self.adaptive_layer_map = {}  # Maps layer name to index in self.adaptive_layers
        self._build_hierarchical_adaptive_layers(input_shapes, layer_latent_dims, initial_rank_ratio, min_rank)

        # Learnable per-modality weights, initialised to one
        self.modality_weights = nn.Parameter(torch.ones(len(input_shapes)), requires_grad=True)
    
    def _build_hierarchical_adaptive_layers(self, input_shapes, latent_dims, initial_rank_ratio, min_rank):
        """Create the adaptive-rank layers for every level of the modality hierarchy.

        Layers are appended to ``self.adaptive_layers`` and their positions are
        recorded in ``self.adaptive_layer_map`` under the names
        ``'global_shared'``, ``'shared_<i>_<j>...'`` and ``'specific_<i>'``.
        Every layer maps to ``self.latent_dim`` outputs.

        Args:
            input_shapes: Per-modality input shapes; only their count is used.
            latent_dims: Indexable sizes where entry ``i`` is the encoder
                output size of modality ``i``.
            initial_rank_ratio: Forwarded to ``AdaptiveRankReducedLinear``.
            min_rank: Forwarded to ``AdaptiveRankReducedLinear``.

        Returns:
            The number of layers created by this call.
        """
        import itertools

        num_modalities = len(input_shapes)
        layer_index = 0

        print(f"Building hierarchical adaptive layers for {num_modalities} modalities:")

        # Walk from the global shared level down to the modality-specific level.
        for subspace_level in range(num_modalities, 0, -1):
            if subspace_level == num_modalities:
                # Global shared layer - fed with all modality latents concatenated
                print(f"  Level {subspace_level}: Creating global shared layer")
                input_dim = sum(latent_dims[:num_modalities])

                shared_layer = AdaptiveRankReducedLinear(
                    input_dim, self.latent_dim,
                    initial_rank_ratio=initial_rank_ratio,
                    min_rank=min_rank
                )
                self.adaptive_layers.append(shared_layer)
                self.adaptive_layer_map['global_shared'] = layer_index
                layer_index += 1

            elif subspace_level > 1:
                # Shared subspaces - one layer per combination of modalities
                print(f"  Level {subspace_level}: Creating shared subspaces")
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    print(f"    Creating subspace: {combo_name}")

                    # encode() feeds this layer only the latents of the
                    # modalities in the combination, so size it to their sum
                    # rather than to the sum over all modalities.
                    combo_input_dim = sum(latent_dims[i] for i in combo)

                    subspace_layer = AdaptiveRankReducedLinear(
                        combo_input_dim, self.latent_dim,
                        initial_rank_ratio=initial_rank_ratio,
                        min_rank=min_rank
                    )
                    self.adaptive_layers.append(subspace_layer)
                    self.adaptive_layer_map[combo_name] = layer_index
                    layer_index += 1

            else:  # subspace_level == 1
                # Modality-specific layers
                print(f"  Level {subspace_level}: Creating modality-specific layers")
                for i in range(num_modalities):
                    layer_name = f"specific_{i}"
                    print(f"    Creating layer: {layer_name}")

                    specific_layer = AdaptiveRankReducedLinear(
                        latent_dims[i], self.latent_dim,
                        initial_rank_ratio=initial_rank_ratio,
                        min_rank=min_rank
                    )
                    self.adaptive_layers.append(specific_layer)
                    self.adaptive_layer_map[layer_name] = layer_index
                    layer_index += 1

        print(f"Created {len(self.adaptive_layers)} adaptive layers:")
        for name, idx in self.adaptive_layer_map.items():
            layer = self.adaptive_layers[idx]
            print(f"  {name}: {layer.in_features} -> {layer.out_features} (index {idx})")

        return layer_index
    
    def reduce_rank(self, reduction_ratio=0.2, threshold=0.01, layer_ids=[], dim=0):
        """Reduce rank of all adaptive layers based on singular value importance"""
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue
            
            S = layer.get_rank_reduction_info()
            
            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank
                
            # Calculate normalized cumulative energy
            energy = S**2
            normalized_energy = energy / energy.sum()
            cumulative_energy = torch.cumsum(normalized_energy, dim=0)

            # Find the rank that preserves specified energy threshold
            target_rank = max(layer.min_rank, 
                             torch.sum(cumulative_energy < (1.0 - threshold)).item())

            # Alternative: just reduce by fixed ratio, but not below min_rank
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            
            new_rank = max(target_rank, ratio_rank)
            #new_rank = target_rank
            
            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                layer.reduce_rank(new_rank, dim=dim)
                changes_made = True
                
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=[], dim=0):
        """Increase rank of all adaptive layers by specified increment"""
        changes_made = False
        
        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True
                
        return changes_made
    
    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)
    
    def encode(self, x, compute_jacobian=False):
        h_concat = []
        for m, x_m in enumerate(x):
            conv_config = self.conv_configs[m]
            
            if not conv_config['is_2d']:
                # 2D input: (batch, seq_length, channels) -> (batch, channels, seq_length)
                x_m = x_m.transpose(1, 2)
            else:
                # 3D input: (batch, height, width, channels) -> (batch, channels, height, width)
                x_m = x_m.permute(0, 3, 1, 2)
            
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)
        
        # Apply hierarchical adaptive layers
        h_encoded = {}
        
        # Global shared layer
        h_all = torch.cat(h_concat, dim=1)
        global_idx = self.adaptive_layer_map['global_shared']
        h_encoded['global_shared'] = self.adaptive_layers[global_idx](h_all)
        
        # Shared subspace layers (combinations of modalities)
        import itertools
        num_modalities = len(h_concat)
        
        for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
            for combo in itertools.combinations(range(num_modalities), subspace_level):
                combo_name = f"shared_{'_'.join(map(str, combo))}"
                if combo_name in self.adaptive_layer_map:
                    # Concatenate the specific modalities for this combination
                    h_combo = torch.cat([h_concat[i] for i in combo], dim=1)
                    combo_idx = self.adaptive_layer_map[combo_name]
                    h_encoded[combo_name] = self.adaptive_layers[combo_idx](h_combo)
        
        # Modality-specific layers
        h_specific = []
        for i in range(num_modalities):
            specific_name = f'specific_{i}'
            if specific_name in self.adaptive_layer_map:
                specific_idx = self.adaptive_layer_map[specific_name]
                h_specific.append(self.adaptive_layers[specific_idx](h_concat[i]))
            else:
                h_specific.append(h_concat[i])  # Fallback if layer doesn't exist
        
        h_encoded['specific'] = h_specific
        
        return h_encoded

    def decode(self, h_encoded):
        """Decode from hierarchical encoded representations"""
        h_global = h_encoded['global_shared']
        h_specific = h_encoded['specific']
        
        x_hat = []
        for m, h_m in enumerate(h_specific):
            # Collect all relevant hierarchical representations for this modality
            h_components = [h_global, h_m]  # Start with global shared and modality-specific
            
            # Add shared subspace representations that include this modality
            import itertools
            num_modalities = len(h_specific)
            
            for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    if combo_name in h_encoded and m in combo:
                        h_components.append(h_encoded[combo_name])
            
            # Separate global and modality-specific components for proper averaging
            modality_components = [h_m]  # Start with modality-specific
            
            # Add shared subspace representations that include this modality (excluding global)
            for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    if combo_name in h_encoded and m in combo:
                        modality_components.append(h_encoded[combo_name])
            
            # Average pool modality-specific and shared subspace representations
            if len(modality_components) > 1:
                # Stack and average modality-specific components
                h_modality_combined = torch.stack(modality_components, dim=1)  # (num_components, batch_size, latent_dim)
                #h_modality_combined = torch.mean(h_stacked, dim=0)   # (batch_size, latent_dim)
            else:
                h_modality_combined = modality_components[0]
            
            # Combine global shared and averaged modality representations for decoder input
            # This maintains the expected input dimensions: latent_dims[m] + latent_dims[-1]
            h_concat = torch.cat([h_global, h_modality_combined], dim=1)
            
            for layer in self.decoders[m]:
                h_concat = layer(h_concat)
            
            # Reshape back to original format based on modality type
            conv_config = self.conv_configs[m]
            
            if not conv_config['is_2d']:
                # 2D modality: handle 1D conv output
                if len(h_concat.shape) == 2:  # Flattened output: (batch, total_features)
                    original_seq_length = conv_config['original_seq_length']
                    n_channels = conv_config['n_channels']
                    
                    # Calculate the actual output size after conv/deconv operations
                    total_features = h_concat.shape[1]  # Total flattened features
                    expected_features = original_seq_length * n_channels
                    
                    if total_features == expected_features:
                        # Perfect match - reshape to original dimensions
                        h_concat = h_concat.view(h_concat.shape[0], original_seq_length, n_channels)
                    else:
                        # Size mismatch due to conv operations - try to infer correct dimensions
                        actual_seq_length = total_features // n_channels
                        if total_features % n_channels == 0:
                            h_concat = h_concat.view(h_concat.shape[0], actual_seq_length, n_channels)
                        else:
                            # If not divisible by n_channels, use original seq_length and adjust channels
                            actual_channels = total_features // original_seq_length
                            if total_features % original_seq_length == 0:
                                h_concat = h_concat.view(h_concat.shape[0], original_seq_length, actual_channels)
                            else:
                                # Last resort: use original dimensions and truncate/pad as needed
                                h_concat = h_concat.view(h_concat.shape[0], -1, n_channels)[:, :original_seq_length, :]
                                
                elif len(h_concat.shape) == 3:  # Conv output: (batch, channels, seq_length)
                    # Transpose from (batch, channels, seq_length) to (batch, seq_length, channels)
                    h_concat = h_concat.transpose(1, 2)
            else:
                # 3D modality: handle 2D conv output
                if len(h_concat.shape) == 4:  # Conv output: (batch, channels, height, width)
                    # Transpose from (batch, channels, height, width) to (batch, height, width, channels)
                    h_concat = h_concat.permute(0, 2, 3, 1)
                elif len(h_concat.shape) == 2:  # Flattened output (shouldn't happen for 3D but handle it)
                    original_height = conv_config['original_height']
                    original_width = conv_config['original_width']
                    n_channels = conv_config['n_channels']
                    
                    total_features = h_concat.shape[1]
                    expected_features = original_height * original_width * n_channels
                    
                    if total_features == expected_features:
                        h_concat = h_concat.view(h_concat.shape[0], original_height, original_width, n_channels)
                    else:
                        # Try to infer dimensions
                        import math
                        sqrt_spatial = int(math.sqrt(total_features // n_channels))
                        if sqrt_spatial * sqrt_spatial * n_channels == total_features:
                            h_concat = h_concat.view(h_concat.shape[0], sqrt_spatial, sqrt_spatial, n_channels)
                        else:
                            # Fallback to original shape and truncate/pad as needed
                            h_concat = h_concat.view(h_concat.shape[0], -1)
                            padding_size = max(0, expected_features - total_features)
                            if padding_size > 0:
                                padding = torch.zeros(h_concat.shape[0], padding_size, device=h_concat.device)
                                h_concat = torch.cat([h_concat, padding], dim=1)
                            h_concat = h_concat[:, :expected_features].view(h_concat.shape[0], original_height, original_width, n_channels)
                
            x_hat.append(h_concat)
        return x_hat

    def forward(self, x):
        """Run a full autoencoding pass.

        Args:
            x: Sequence with one input tensor per modality.

        Returns:
            A tuple ``(x_hat, h)``: the reconstructions from ``self.decode``
            and the encoded dict from ``self.encode`` they were built from.
        """
        encoded = self.encode(x)
        return self.decode(encoded), encoded
    
    def encode_modalities(self, x):
        """Return one combined latent tensor per modality.

        The input is encoded with ``self.encode``; each entry of the result's
        ``'specific'`` list is concatenated (along dim 1) behind the
        ``'global_shared'`` tensor.

        Args:
            x: Sequence with one input tensor per modality.

        Returns:
            A list with one ``[global_shared | specific_m]`` tensor per modality.
        """
        encoded = self.encode(x)
        shared = encoded['global_shared']
        return [torch.cat([shared, own], dim=1) for own in encoded['specific']]

class AdaptiveRankReducedAE_MM(torch.nn.Module):
    """Convolutional multimodal autoencoder for NInFEA dataset with adaptive conv layers per modality"""
    
    def __init__(self, input_shapes, latent_dims, depth=2, hidden_dim=512, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10, conv_depth=3):
        """Build per-modality encoders/decoders and the hierarchical adaptive latent layers.

        Args:
            input_shapes: One shape tuple per modality. A 2-tuple
                ``(seq_length, n_channels)`` gets a Conv1d front end, a 3-tuple
                ``(height, width, n_channels)`` gets a Conv2d front end, and any
                other length is treated as a flat vector of ``input_shape[0]``
                features with no convolutions.
            latent_dims: Per-modality latent sizes; only ``max(latent_dims)`` is
                used, as the common latent width of every subspace.
            depth: Number of fully connected layers in each encoder and
                decoder. The hidden-to-latent and hidden-to-output projections
                are only created when ``depth >= 2``.
            hidden_dim: Width of the fully connected hidden layers.
            dropout: Dropout probability; dropout layers are only inserted when
                this is greater than zero.
            initial_rank_ratio: Forwarded to the adaptive latent layers.
            min_rank: Forwarded to the adaptive latent layers.
            conv_depth: Stored on ``self.conv_depth``; the number of conv layers
                actually built is decided by the size thresholds below.
        """
        super().__init__()

        n_modalities = len(input_shapes)

        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.adaptive_layers = nn.ModuleList()  # Track adaptive rank layers for rank reduction

        self.input_shapes = input_shapes  # List of shape tuples, one per modality
        self.hidden_dim = hidden_dim
        self.latent_dim = max(latent_dims)  # Common latent dimension across modalities
        self.conv_depth = conv_depth
        self.conv_configs = []  # Store conv configuration for each modality

        print(f"Creating AdaptiveRankReducedAE_MM for {n_modalities} modalities with:")
        print(f"   input_shapes={input_shapes}, latent_dims={latent_dims}")
        print(f"   depth={depth}, hidden_dim={hidden_dim}, adaptive conv layers (target final_dim≤10000), dropout={dropout}")
        print(f"   initial_rank_ratio: {initial_rank_ratio}, min_rank: {min_rank}")

        # decode() feeds each decoder the concatenation of the global latent, the
        # modality-specific latent and every shared-subspace latent (levels
        # n-1 .. 2) whose combination contains that modality. That is
        # 2 ** (n - 1) blocks of width latent_dim. With a single modality no
        # specific layer exists and the raw encoder output is concatenated
        # instead, so two blocks are still fed to the decoder.
        n_latents_per_mod = 1 << max(n_modalities - 1, 1)
        self.n_latents_per_mod = n_latents_per_mod
        print(f"   Total adaptive layers: {n_latents_per_mod} per modality")

        # Build convolutional encoders/decoders for each modality
        # Support both 2D: (seq_length, n_channels) and 3D: (height, width, n_channels)
        for m, input_shape in enumerate(input_shapes):
            # Classify the modality by the length of its shape tuple
            if len(input_shape) == 2:
                # 2D data: (seq_length, n_channels)
                seq_length, n_channels = input_shape
                print(f"Modality {m}: 2D input shape {seq_length} x {n_channels} channels")
                is_2d_conv = False
                is_conv = True

                conv_config = {
                    'original_seq_length': seq_length,
                    'n_channels': n_channels,
                    'conv_layers': [],
                    'is_2d': is_2d_conv,
                    'input_shape': input_shape
                }
            elif len(input_shape) == 3:
                # 3D data: (height, width, n_channels)
                height, width, n_channels = input_shape
                print(f"Modality {m}: 3D input shape {height} x {width} x {n_channels} channels")
                is_2d_conv = True
                is_conv = True

                conv_config = {
                    'original_height': height,
                    'original_width': width,
                    'n_channels': n_channels,
                    'conv_layers': [],
                    'is_2d': is_2d_conv,
                    'input_shape': input_shape
                }
            else:
                print(f"normal 1D input shape {input_shape[0]} features")
                conv_config = None
                is_conv = False
                # NOTE(review): this discards the configs of any conv modality
                # seen earlier, and a conv modality that follows will fail on
                # the `.append` below. Mixing flat and conv modalities looks
                # unsupported -- confirm before relying on it.
                self.conv_configs = None

            if is_conv:
                current_channels = n_channels
                conv_layer_count = 0

                # Every conv layer uses the same geometry (halves each spatial size)
                kernel_size = 3
                stride = 2
                padding = 1

                if not is_2d_conv:
                    # 1D convolutions for 2D data (seq_length, n_channels)
                    out_channels = 1  # Keep all intermediate channels at 1 to minimize data size
                    current_seq_length = input_shape[0]  # seq_length

                    # Keep adding layers while one more layer would still leave
                    # more than 10000 features; the layer that would cross the
                    # threshold is itself NOT added.
                    while True:
                        next_seq_length = (current_seq_length + 2 * padding - kernel_size) // stride + 1
                        next_conv_dim = next_seq_length * out_channels

                        # Stop if we've reached our target size or if we can't reduce further
                        if next_conv_dim <= 10000 or next_seq_length <= 3:
                            break

                        self.encoders[m].append(nn.Conv1d(current_channels, out_channels, kernel_size, stride, padding))
                        self.encoders[m].append(nn.ReLU())
                        if dropout > 0:
                            self.encoders[m].append(nn.Dropout(dropout))

                        current_seq_length = next_seq_length
                        current_channels = out_channels
                        conv_layer_count += 1

                        # Store layer config for decoder reconstruction
                        conv_config['conv_layers'].append({
                            'out_channels': out_channels,
                            'kernel_size': kernel_size,
                            'stride': stride,
                            'padding': padding,
                            'seq_length': current_seq_length
                        })

                    final_conv_dim = current_seq_length * current_channels
                else:
                    # 2D convolutions for 3D data (height, width, n_channels)
                    out_channels = 4  # Keep all intermediate channels at 4 to minimize data size
                    current_height = input_shape[0]  # height
                    current_width = input_shape[1]   # width

                    # Same stopping rule as the 1D case, with a 100000-feature threshold
                    while True:
                        next_height = (current_height + 2 * padding - kernel_size) // stride + 1
                        next_width = (current_width + 2 * padding - kernel_size) // stride + 1
                        next_conv_dim = next_height * next_width * out_channels

                        # Stop if we've reached our target size or if we can't reduce further
                        if next_conv_dim <= 100000 or next_height <= 3 or next_width <= 3:
                            break

                        self.encoders[m].append(nn.Conv2d(current_channels, out_channels, kernel_size, stride, padding))
                        self.encoders[m].append(nn.ReLU())
                        if dropout > 0:
                            self.encoders[m].append(nn.Dropout(dropout))

                        current_height = next_height
                        current_width = next_width
                        current_channels = out_channels
                        conv_layer_count += 1

                        # Store layer config for decoder reconstruction
                        conv_config['conv_layers'].append({
                            'out_channels': out_channels,
                            'kernel_size': kernel_size,
                            'stride': stride,
                            'padding': padding,
                            'height': current_height,
                            'width': current_width
                        })

                    final_conv_dim = current_height * current_width * current_channels

                conv_config['actual_conv_depth'] = conv_layer_count
                self.encoders[m].append(nn.Flatten())
                conv_config['final_conv_dim'] = final_conv_dim
                if not is_2d_conv:
                    print(f"  After {conv_layer_count} 1D conv layers: {current_seq_length} x {current_channels} = {final_conv_dim}")
                else:
                    print(f"  After {conv_layer_count} 2D conv layers: {current_height} x {current_width} x {current_channels} = {final_conv_dim}")
                self.conv_configs.append(conv_config)
            else:
                # No convolutional layers for standard 1D input
                final_conv_dim = input_shape[0]

            # Add fully connected layers after convolution. The encoder layer is
            # always created before its decoder counterpart so that parameter
            # initialisation order stays stable under a fixed seed.
            for i in range(depth):
                if i == 0:
                    # First FC layer - conv output to hidden
                    encoder_layer = nn.Linear(final_conv_dim, self.hidden_dim)
                    self.encoders[m].append(encoder_layer)
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))

                    # First decoder layer - all concatenated latent blocks to hidden
                    decoder_layer = nn.Linear(self.latent_dim * n_latents_per_mod, self.hidden_dim)
                    self.decoders[m].append(decoder_layer)
                    self.decoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.decoders[m].append(nn.Dropout(dropout))

                elif i == depth - 1:
                    # Final encoder layer - hidden to the common latent width
                    encoder_layer = nn.Linear(self.hidden_dim, self.latent_dim)
                    self.encoders[m].append(encoder_layer)

                    # Final decoder layer - hidden to conv input
                    decoder_layer = nn.Linear(self.hidden_dim, final_conv_dim)
                    self.decoders[m].append(decoder_layer)
                    # Depends on the state of conv_configs at this point in the loop
                    if self.conv_configs is not None:
                        self.decoders[m].append(nn.Sigmoid())
                else:
                    # Middle layers
                    encoder_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
                    self.encoders[m].append(encoder_layer)
                    self.encoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.encoders[m].append(nn.Dropout(dropout))

                    decoder_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
                    self.decoders[m].append(decoder_layer)
                    self.decoders[m].append(nn.ReLU())
                    if dropout > 0:
                        self.decoders[m].append(nn.Dropout(dropout))

            # Add convolutional decoder (reverse of encoder)
            if is_conv:
                conv_layers = conv_config['conv_layers']

                if not is_2d_conv:
                    # 1D case: reshape to (batch, channels, seq_length)
                    self.decoders[m].append(nn.Unflatten(1, (current_channels, current_seq_length)))

                    # Undo the conv layers last-to-first
                    for i in reversed(range(conv_config['actual_conv_depth'])):
                        layer_config = conv_layers[i]
                        in_channels = layer_config['out_channels']

                        # The target of each transposed conv is the size the
                        # matching conv layer received as input: the original
                        # input for layer 0, otherwise the previous layer's output.
                        if i == 0:
                            out_channels = conv_config['n_channels']  # Back to original channels
                            expected_output_size = conv_config['original_seq_length']
                        else:
                            out_channels = conv_layers[i - 1]['out_channels']
                            expected_output_size = conv_layers[i - 1]['seq_length']

                        kernel_size = layer_config['kernel_size']
                        stride = layer_config['stride']
                        padding = layer_config['padding']

                        # output_size = (input_size - 1) * stride - 2 * padding + kernel_size + output_padding
                        # so output_padding makes up the difference to the target size.
                        calculated_output_size = (layer_config['seq_length'] - 1) * stride - 2 * padding + kernel_size
                        output_padding = max(0, expected_output_size - calculated_output_size)

                        self.decoders[m].append(nn.ConvTranspose1d(
                            in_channels, out_channels, kernel_size, stride, padding, output_padding=output_padding
                        ))
                        if i > 0:  # No activation after final layer
                            self.decoders[m].append(nn.ReLU())
                            if dropout > 0:
                                self.decoders[m].append(nn.Dropout(dropout))

                    # 1D case: flatten to (batch, seq_length * n_channels)
                    self.decoders[m].append(nn.Flatten())
                else:
                    # 2D case: reshape to (batch, channels, height, width)
                    self.decoders[m].append(nn.Unflatten(1, (current_channels, current_height, current_width)))

                    # Undo the conv layers last-to-first
                    for i in reversed(range(conv_config['actual_conv_depth'])):
                        layer_config = conv_layers[i]
                        in_channels = layer_config['out_channels']

                        if i == 0:
                            out_channels = conv_config['n_channels']  # Back to original channels
                            expected_height = conv_config['original_height']
                            expected_width = conv_config['original_width']
                        else:
                            out_channels = conv_layers[i - 1]['out_channels']
                            expected_height = conv_layers[i - 1]['height']
                            expected_width = conv_layers[i - 1]['width']

                        kernel_size = layer_config['kernel_size']
                        stride = layer_config['stride']
                        padding = layer_config['padding']

                        # Calculate output_padding for height and width
                        calc_height = (layer_config['height'] - 1) * stride - 2 * padding + kernel_size
                        calc_width = (layer_config['width'] - 1) * stride - 2 * padding + kernel_size
                        output_padding = (
                            max(0, expected_height - calc_height),
                            max(0, expected_width - calc_width),
                        )

                        self.decoders[m].append(nn.ConvTranspose2d(
                            in_channels, out_channels, kernel_size, stride, padding, output_padding=output_padding
                        ))
                        if i > 0:  # No activation after final layer
                            self.decoders[m].append(nn.ReLU())
                            if dropout > 0:
                                self.decoders[m].append(nn.Dropout(dropout))

                    # 2D case: the output is left in conv layout here; no
                    # Flatten is appended for this branch.

        # Hierarchical shared and specific latent spaces with adaptive rank reduction
        self.adaptive_layer_map = {}  # Maps layer name to index in self.adaptive_layers
        self._build_hierarchical_adaptive_layers(input_shapes, latent_dims, initial_rank_ratio, min_rank)

        # Initialize modality weights for balanced training
        self.modality_weights = nn.Parameter(torch.ones(len(input_shapes)), requires_grad=True)
    
    def _build_hierarchical_adaptive_layers(self, input_shapes, latent_dims, initial_rank_ratio, min_rank):
        """Create the adaptive latent layers for every level of the modality hierarchy.

        Layers are appended to ``self.adaptive_layers`` and indexed by name in
        ``self.adaptive_layer_map`` in this order: ``global_shared`` (all
        modalities), then ``shared_<i>_<j>...`` for each combination of
        n-1 down to 2 modalities, then ``specific_<i>`` per modality. Every
        layer maps ``latent_dim * (number of modalities it covers)`` inputs to
        ``latent_dim`` outputs. With a single modality only ``global_shared``
        is created.

        Args:
            input_shapes: Only its length (the number of modalities) is used.
            latent_dims: Not used here; sizes come from ``self.latent_dim``.
            initial_rank_ratio: Forwarded to each ``AdaptiveRankReducedLinear``.
            min_rank: Forwarded to each ``AdaptiveRankReducedLinear``.

        Returns:
            The number of layers created by this call.
        """
        import itertools

        num_modalities = len(input_shapes)
        layer_index = 0

        def add_layer(name, n_inputs):
            # Create one layer, append it, record its index under `name`
            nonlocal layer_index
            input_dim = self.latent_dim * n_inputs
            output_dim = self.latent_dim
            self.adaptive_layers.append(AdaptiveRankReducedLinear(
                input_dim, output_dim,
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            ))
            self.adaptive_layer_map[name] = layer_index
            layer_index += 1

        print(f"Building hierarchical adaptive layers for {num_modalities} modalities:")

        # Walk the hierarchy from the global level down to single modalities
        for level in range(num_modalities, 0, -1):
            if level == num_modalities:
                # Global shared layer - combines all modalities
                print(f"  Level {level}: Creating global shared layer")
                add_layer('global_shared', num_modalities)
            elif level > 1:
                # One shared subspace per combination of `level` modalities
                print(f"  Level {level}: Creating shared subspaces")
                for combo in itertools.combinations(range(num_modalities), level):
                    combo_name = "shared_" + "_".join(str(idx) for idx in combo)
                    print(f"    Creating subspace: {combo_name}")
                    add_layer(combo_name, len(combo))
            else:
                # level == 1: one private layer per modality
                print(f"  Level {level}: Creating modality-specific layers")
                for modality in range(num_modalities):
                    layer_name = f"specific_{modality}"
                    print(f"    Creating layer: {layer_name}")
                    add_layer(layer_name, 1)

        print(f"Created {len(self.adaptive_layers)} adaptive layers:")
        for name, idx in self.adaptive_layer_map.items():
            created = self.adaptive_layers[idx]
            print(f"  {name}: {created.in_features} -> {created.out_features} (index {idx})")

        return layer_index
    
    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=None, dim=0):
        """Shrink adaptive layers to the rank that keeps most of their energy.

        For every selected layer, the values returned by
        ``layer.get_rank_reduction_info()`` (treated as singular values) are
        squared, normalised and accumulated. The new rank is the number of
        leading entries whose cumulative energy is still below
        ``1 - threshold``, but never less than ``layer.min_rank``. The layer is
        only touched when that rank is smaller than ``layer.active_dims``.

        Args:
            reduction_ratio: Kept for interface compatibility. The energy-based
                rank is always used, so this value currently has no effect.
            threshold: Fraction of total energy that may be discarded.
            layer_ids: Optional collection of indices into
                ``self.adaptive_layers``. When ``None`` or empty, every layer
                is considered.
            dim: Forwarded to ``layer.reduce_rank``.

        Returns:
            True if the rank of at least one layer was reduced.
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue

            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            # Normalised cumulative energy of the leading components
            energy = S ** 2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Count the components that stay below the retained-energy target
            retained = torch.sum(cumulative_energy < (1.0 - threshold)).item()
            new_rank = max(layer.min_rank, retained)

            # Only reduce if new rank is smaller than current
            if new_rank < layer.active_dims:
                layer.reduce_rank(new_rank, dim=dim)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """Ask the selected adaptive layers to grow their rank.

        Each selected layer's own ``increase_rank`` is called with
        ``mode='multimodal'``; every selected layer is called even after one
        has already reported a change.

        Args:
            increment: Forwarded to ``layer.increase_rank``.
            increase_ratio: Forwarded to ``layer.increase_rank``.
            layer_ids: Optional collection of indices into
                ``self.adaptive_layers``. When ``None`` or empty, every layer
                is considered.
            dim: Forwarded to ``layer.increase_rank``.

        Returns:
            True if at least one layer reported a truthy result.
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made
    
    def get_total_rank(self):
        """Sum ``active_dims`` over every layer in ``self.adaptive_layers``."""
        total = 0
        for layer in self.adaptive_layers:
            total += layer.active_dims
        return total
    
    def encode(self, x, compute_jacobian=False):
        """Encode every modality and project the results onto the latent hierarchy.

        Each input is first moved to channels-first layout when
        ``self.conv_configs`` is set, then passed through its encoder stack.
        The per-modality outputs are concatenated (dim 1) and fed to the
        adaptive layers registered in ``self.adaptive_layer_map``.

        Args:
            x: Sequence with one input tensor per modality.
            compute_jacobian: Accepted but not used by this method.

        Returns:
            A dict with ``'global_shared'``, one ``'shared_<i>_<j>...'`` entry
            per registered modality combination, and ``'specific'`` (a list
            with one tensor per modality; the raw encoder output is used when
            no ``specific_<i>`` layer is registered).
        """
        import itertools

        per_modality = []
        for m, x_m in enumerate(x):
            if self.conv_configs is not None:
                if self.conv_configs[m]['is_2d']:
                    # 3D input: (batch, height, width, channels) -> (batch, channels, height, width)
                    x_m = x_m.permute(0, 3, 1, 2)
                else:
                    # 2D input: (batch, seq_length, channels) -> (batch, channels, seq_length)
                    x_m = x_m.transpose(1, 2)

            for layer in self.encoders[m]:
                x_m = layer(x_m)
            per_modality.append(x_m)

        num_modalities = len(per_modality)
        h_encoded = {}

        # Global shared layer sees every modality at once
        joined_all = torch.cat(per_modality, dim=1)
        h_encoded['global_shared'] = self.adaptive_layers[self.adaptive_layer_map['global_shared']](joined_all)

        # Shared subspaces: combinations of n-1 down to 2 modalities
        for level in range(num_modalities - 1, 1, -1):
            for combo in itertools.combinations(range(num_modalities), level):
                combo_name = "shared_" + "_".join(str(idx) for idx in combo)
                if combo_name not in self.adaptive_layer_map:
                    continue
                joined = torch.cat([per_modality[idx] for idx in combo], dim=1)
                h_encoded[combo_name] = self.adaptive_layers[self.adaptive_layer_map[combo_name]](joined)

        # Modality-specific layers, falling back to the raw encoder output
        specific = []
        for idx, latent in enumerate(per_modality):
            key = f'specific_{idx}'
            if key in self.adaptive_layer_map:
                latent = self.adaptive_layers[self.adaptive_layer_map[key]](latent)
            specific.append(latent)
        h_encoded['specific'] = specific

        return h_encoded

    def decode(self, h_encoded):
        """Decode from hierarchical encoded representations"""
        h_global = h_encoded['global_shared']
        h_specific = h_encoded['specific']
        
        x_hat = []
        for m, h_m in enumerate(h_specific):
            # Collect all relevant hierarchical representations for this modality
            h_components = [h_global, h_m]  # Start with global shared and modality-specific
            
            # Add shared subspace representations that include this modality
            import itertools
            num_modalities = len(h_specific)
            
            for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    if combo_name in h_encoded and m in combo:
                        h_components.append(h_encoded[combo_name])
            
            # Separate global and modality-specific components for proper averaging
            modality_components = [h_m]  # Start with modality-specific
            
            # Add shared subspace representations that include this modality (excluding global)
            for subspace_level in range(num_modalities - 1, 1, -1):  # From n-1 down to 2
                for combo in itertools.combinations(range(num_modalities), subspace_level):
                    combo_name = f"shared_{'_'.join(map(str, combo))}"
                    if combo_name in h_encoded and m in combo:
                        modality_components.append(h_encoded[combo_name])
            
            # Average pool modality-specific and shared subspace representations
            if len(modality_components) > 1:
                # Stack and average modality-specific components
                #h_stacked = torch.stack(modality_components, dim=0)  # (num_components, batch_size, latent_dim)
                #h_modality_combined = torch.mean(h_stacked, dim=0)   # (batch_size, latent_dim)
                h_modality_combined = torch.cat(modality_components, dim=-1)
            else:
                h_modality_combined = modality_components[0]
            
            # Combine global shared and averaged modality representations for decoder input
            # This maintains the expected input dimensions: latent_dims[m] + latent_dims[-1]
            h_concat = torch.cat([h_global, h_modality_combined], dim=1)
            
            for layer in self.decoders[m]:
                h_concat = layer(h_concat)
            
            if self.conv_configs is not None:
                # Reshape back to original format based on modality type
                conv_config = self.conv_configs[m]
                
                if not conv_config['is_2d']:
                    # 2D modality: handle 1D conv output
                    if len(h_concat.shape) == 2:  # Flattened output: (batch, total_features)
                        original_seq_length = conv_config['original_seq_length']
                        n_channels = conv_config['n_channels']
                        
                        # Calculate the actual output size after conv/deconv operations
                        total_features = h_concat.shape[1]  # Total flattened features
                        expected_features = original_seq_length * n_channels
                        
                        if total_features == expected_features:
                            # Perfect match - reshape to original dimensions
                            h_concat = h_concat.view(h_concat.shape[0], original_seq_length, n_channels)
                        else:
                            # Size mismatch due to conv operations - try to infer correct dimensions
                            actual_seq_length = total_features // n_channels
                            if total_features % n_channels == 0:
                                h_concat = h_concat.view(h_concat.shape[0], actual_seq_length, n_channels)
                            else:
                                # If not divisible by n_channels, use original seq_length and adjust channels
                                actual_channels = total_features // original_seq_length
                                if total_features % original_seq_length == 0:
                                    h_concat = h_concat.view(h_concat.shape[0], original_seq_length, actual_channels)
                                else:
                                    # Last resort: use original dimensions and truncate/pad as needed
                                    h_concat = h_concat.view(h_concat.shape[0], -1, n_channels)[:, :original_seq_length, :]
                                    
                    elif len(h_concat.shape) == 3:  # Conv output: (batch, channels, seq_length)
                        # Transpose from (batch, channels, seq_length) to (batch, seq_length, channels)
                        h_concat = h_concat.transpose(1, 2)
                else:
                    # 3D modality: handle 2D conv output
                    if len(h_concat.shape) == 4:  # Conv output: (batch, channels, height, width)
                        # Transpose from (batch, channels, height, width) to (batch, height, width, channels)
                        h_concat = h_concat.permute(0, 2, 3, 1)
                    elif len(h_concat.shape) == 2:  # Flattened output (shouldn't happen for 3D but handle it)
                        original_height = conv_config['original_height']
                        original_width = conv_config['original_width']
                        n_channels = conv_config['n_channels']
                        
                        total_features = h_concat.shape[1]
                        expected_features = original_height * original_width * n_channels
                        
                        if total_features == expected_features:
                            h_concat = h_concat.view(h_concat.shape[0], original_height, original_width, n_channels)
                        else:
                            # Try to infer dimensions
                            import math
                            sqrt_spatial = int(math.sqrt(total_features // n_channels))
                            if sqrt_spatial * sqrt_spatial * n_channels == total_features:
                                h_concat = h_concat.view(h_concat.shape[0], sqrt_spatial, sqrt_spatial, n_channels)
                            else:
                                # Fallback to original shape and truncate/pad as needed
                                h_concat = h_concat.view(h_concat.shape[0], -1)
                                padding_size = max(0, expected_features - total_features)
                                if padding_size > 0:
                                    padding = torch.zeros(h_concat.shape[0], padding_size, device=h_concat.device)
                                    h_concat = torch.cat([h_concat, padding], dim=1)
                                h_concat = h_concat[:, :expected_features].view(h_concat.shape[0], original_height, original_width, n_channels)

            x_hat.append(h_concat)
        return x_hat

    def forward(self, x):
        """Run a full autoencoding pass: encode the input, then decode it.

        Args:
            x: Input passed unchanged to ``self.encode``.
               (presumably one tensor per modality -- confirm against ``encode``)

        Returns:
            A 2-tuple ``(x_hat, h)`` where ``h`` is whatever ``self.encode(x)``
            returned and ``x_hat`` is the result of ``self.decode(h)``.
        """
        latent = self.encode(x)
        # Hand back the latent alongside the reconstruction so callers can
        # use both without running the encoder a second time.
        return self.decode(latent), latent
    
    def encode_modalities(self, x):
        """Build one latent per modality by joining global and specific codes.

        Calls ``self.encode(x)`` and reads only its ``'global_shared'`` and
        ``'specific'`` entries; any other entries in the encoder output are
        not used here.

        Args:
            x: Input passed unchanged to ``self.encode``.

        Returns:
            A list with one tensor per element of the ``'specific'`` entry,
            each being ``torch.cat([global_shared, specific_m], dim=1)``.
            (assumes both are laid out as (batch, features) -- TODO confirm)
        """
        latents = self.encode(x)
        shared_code = latents['global_shared']
        specific_codes = latents['specific']

        # The same global code is prepended to every modality-specific code.
        return [
            torch.cat([shared_code, specific_code], dim=1)
            for specific_code in specific_codes
        ]