"""
Multimodal sequence autoencoders for MultiBench datasets with adaptive rank reduction.
Per-modality encoders/decoders are selectable: Conv1D (CNN), GRU, or Transformer.
Designed for temporal sequences from text (GloVe), vision (FACET/OpenFace), and audio (COVAREP) modalities.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.larrp_unimodal import AdaptiveRankReducedLinear
#import numpy as np

# --- Helper Modules for Tensor Permutation ---

class _TransposeForRNN(nn.Module):
    """Swap the channel and time axes: (N, C, T) -> (N, T, C), the batch-first RNN layout."""

    def forward(self, x):
        # Axis 0 (batch) stays put; axes 1 and 2 trade places.
        return torch.transpose(x, 1, 2)

class _TransposeForConv1d(nn.Module):
    """Swap the time and channel axes: (N, T, C) -> (N, C, T), the layout Conv1d consumes."""

    def forward(self, x):
        # Axis 0 (batch) stays put; axes 1 and 2 trade places.
        return torch.transpose(x, 2, 1)

class _SelectGRUSequence(nn.Module):
    """Keep the per-step output sequence of a GRU call and drop its hidden state."""

    def forward(self, x):
        # A GRU call yields the pair (output_sequence, hidden_state).
        sequence, _hidden = x
        return sequence

class _SelectLastHiddenState(nn.Module):
    """Keep the top layer's final hidden state of a GRU/RNN call and drop the output sequence."""

    def forward(self, x):
        # A GRU call yields the pair (output_sequence, hidden_state).
        _sequence, hidden_state = x
        # hidden_state is (num_layers, N, H); the last entry belongs to the top layer.
        top_layer_state = hidden_state[-1]
        return top_layer_state

class _RepeatVector(nn.Module):
    """Tile a per-sample vector (N, H) along a new time axis, giving (N, seq_len, H)."""

    def __init__(self, seq_len):
        super().__init__()
        # Number of time steps to produce for the decoder.
        self.seq_len = seq_len

    def forward(self, x):
        with_time_axis = torch.unsqueeze(x, 1)  # (N, H) -> (N, 1, H)
        return with_time_axis.repeat(1, self.seq_len, 1)  # (N, T, H)

# --- Main Autoencoder Class ---

class AdaptiveRankReducedAE_MultiBench(nn.Module):
    """
    Multimodal autoencoder for MultiBench datasets with selectable encoder/decoder architectures.
    
    Supports 3 modalities with temporal structure:
    - Text: (batch, seq_len, 300) - GloVe embeddings
    - Vision: (batch, seq_len, 35/709) - FACET or OpenFace features
    - Audio: (batch, seq_len, 74) - COVAREP features
    
    Architecture:
    - Modality-specific encoders (Conv1D, GRU, or Transformer)
    - Adaptive low-rank layers for dimension reduction
    - Modality-specific decoders (symmetric to encoders)
    """
    
    def __init__(self,
                 input_shapes,
                 latent_dims,
                 model_type: str = 'conv1d',
                 # Conv1D parameters
                 conv_channels: list = [64, 128, 256],
                 kernel_sizes: list = [3, 3, 3],
                 # GRU parameters
                 gru_hidden_dim: int = 128,
                 gru_num_layers: int = 2,
                 # Attention (Transformer) parameters
                 attn_num_heads: int = 4,
                 attn_num_layers: int = 2,
                 # Common parameters
                 dropout=0.1,
                 initial_rank_ratio=1.0,
                 min_rank=10,
                 activation=None):
        """
        Build one encoder/decoder pair per modality plus the adaptive low-rank layers.

        ``self.adaptive_layers[0]`` is the shared layer fed with the concatenation of
        all encoder outputs; ``self.adaptive_layers[i + 1]`` is the layer specific to
        modality ``i``.

        Args:
            input_shapes: List of tuples [(seq_len, n_features), ...] for each modality
            latent_dims: Latent dimension (int or list of ints per modality). A list is
                indexed per modality and its last entry is used for the shared space.
            model_type: Type of sequence model: 'conv1d', 'gru', or 'attn'.
            conv_channels: List of channel sizes for conv layers (default [64, 128, 256])
            kernel_sizes: List of kernel sizes for conv layers (default [3, 3, 3])
            gru_hidden_dim: Hidden dimension for GRU layers (default 128)
            gru_num_layers: Number of layers for GRU (default 2)
            attn_num_heads: Number of heads for Transformer (default 4)
            attn_num_layers: Number of layers for Transformer (default 2)
            dropout: Dropout rate (default 0.1)
            initial_rank_ratio: Initial rank ratio for adaptive layers (default 1.0)
            min_rank: Minimum rank for adaptive layers (default 10)
            activation: Optional output activation ('tanh', 'sigmoid', 'softmax', or None).
                Any other value adds no activation.

        Raises:
            ValueError: Propagated from the builders for an unknown ``model_type``.
        """
        super(AdaptiveRankReducedAE_MultiBench, self).__init__()

        self.n_modalities = len(input_shapes)
        self.input_shapes = input_shapes
        self.model_type = model_type
        self.activation = activation

        # Store architecture-specific params. The conv lists are copied so the
        # instance never aliases the mutable default arguments shared by all calls.
        self.conv_channels = list(conv_channels)
        self.kernel_sizes = list(kernel_sizes)
        self.n_conv_layers = len(self.conv_channels)
        self.gru_hidden_dim = gru_hidden_dim
        self.gru_num_layers = gru_num_layers
        self.attn_num_heads = attn_num_heads
        self.attn_num_layers = attn_num_layers
        self.dropout = dropout

        # An int is broadcast to every modality plus one extra entry for the shared space.
        if isinstance(latent_dims, int):
            self.latent_dims = [latent_dims] * (self.n_modalities + 1)
        else:
            self.latent_dims = latent_dims

        print(f"Creating AdaptiveRankReducedAE_MultiBench ({self.model_type}) for {self.n_modalities} modalities:")
        print(f"  Input shapes: {input_shapes}")
        print(f"  Latent dims: {self.latent_dims} (last is shared space)")

        # Create encoders and decoders
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.adaptive_layers = nn.ModuleList()

        self.encoder_output_shapes = []
        self.flattened_dims = []

        for m in range(self.n_modalities):
            seq_len, n_features = input_shapes[m]
            latent_dim = self.latent_dims[m]

            # The analytically computed flattened size is ignored in favour of the probed one below.
            encoder, enc_output_shape, _ = self._build_encoder(n_features, seq_len)
            self.encoders.append(encoder)
            self.encoder_output_shapes.append(enc_output_shape)

            # Probe the real flattened size with a dummy forward pass. The probe runs in
            # eval mode: in training mode a batch of one would overwrite the BatchNorm
            # running statistics with values from an all-zero input, and BatchNorm
            # raises outright once a conv stage shrinks the sequence to length 1.
            was_training = encoder.training
            encoder.eval()
            try:
                with torch.no_grad():
                    dummy_in = torch.zeros(1, n_features, seq_len)  # (N, D, T)
                    enc_out = encoder(dummy_in)
                    actual_flattened = int(enc_out.shape[1])
            finally:
                encoder.train(was_training)
            print(f"  Modality {m} ({self.model_type}): {input_shapes[m]} -> flattened {actual_flattened}")

            self.flattened_dims.append(actual_flattened)

            # Build the specified decoder
            decoder = self._build_decoder(latent_dim, enc_output_shape, n_features)
            self.decoders.append(decoder)

        # Optional output activation, appended as the last stage of every decoder.
        # NOTE(review): the decoders emit (N, D, T) and decode() transposes afterwards,
        # so Softmax(dim=-1) normalises over the time axis - confirm this is intended.
        activation_factories = {
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'softmax': lambda: nn.Softmax(dim=-1),
        }
        make_activation = (
            activation_factories.get(self.activation)
            if isinstance(self.activation, str) else None
        )
        if make_activation is not None:
            for decoder in self.decoders:
                decoder.append(make_activation())  # a fresh module per decoder

        # --- Adaptive Layers ---
        # Shared layer takes concatenation of all modality embeddings
        shared_layer = AdaptiveRankReducedLinear(
            sum(self.flattened_dims), self.latent_dims[-1],
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)

        # Modality-specific layers
        for i in range(self.n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                self.flattened_dims[i], self.latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)

        print(f"  Shared layer: {sum(self.flattened_dims)} -> {self.latent_dims[-1]} with min, max rank {self.adaptive_layers[0].min_rank}, {self.adaptive_layers[0].max_rank}")
        for i in range(self.n_modalities):
            print(f"  Specific layer {i}: {self.flattened_dims[i]} -> {self.latent_dims[i]} with min, max rank {self.adaptive_layers[i+1].min_rank}, {self.adaptive_layers[i+1].max_rank}")
    
    def _build_encoder(self, n_features, seq_len):
        """Builds an encoder stack based on self.model_type."""
        layers = nn.ModuleList()
        
        if self.model_type == 'conv1d':
            current_channels = n_features
            current_seq_len = seq_len
            
            for i, (out_channels, kernel_size) in enumerate(zip(self.conv_channels, self.kernel_sizes)):
                layers.append(nn.Conv1d(
                    in_channels=current_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=kernel_size // 2
                ))
                layers.append(nn.BatchNorm1d(out_channels))
                layers.append(nn.ReLU())
                if self.dropout > 0:
                    layers.append(nn.Dropout(self.dropout))
                
                current_seq_len = (current_seq_len + 2 * (kernel_size // 2) - kernel_size) // 2 + 1
                current_channels = out_channels
            
            layers.append(nn.Flatten())
            encoder_output_shape = (current_channels, current_seq_len)
            flattened_dim = current_channels * current_seq_len

        elif self.model_type == 'gru':
            layers.append(_TransposeForRNN()) # (N, D, T) -> (N, T, D)
            layers.append(nn.GRU(
                input_size=n_features,
                hidden_size=self.gru_hidden_dim,
                num_layers=self.gru_num_layers,
                batch_first=True,
                dropout=self.dropout if self.gru_num_layers > 1 else 0
            ))
            #layers.append(_SelectGRUOutput()) # Get output sequence, not hidden state
            layers.append(_SelectLastHiddenState())  # Get last hidden state
            #layers.append(_TransposeForConv1d()) # (N, T, H) -> (N, H, T)
            #layers.append(nn.Flatten())
            
            encoder_output_shape = (self.gru_hidden_dim, seq_len)
            flattened_dim = self.gru_hidden_dim# * seq_len
            
        elif self.model_type == 'attn':
            layers.append(_TransposeForRNN()) # (N, D, T) -> (N, T, D)
            # Note: For Transformer, d_model (embed dim) must match n_features
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_features,
                nhead=self.attn_num_heads,
                dropout=self.dropout,
                batch_first=True
            )
            layers.append(nn.TransformerEncoder(encoder_layer, num_layers=self.attn_num_layers))
            layers.append(_TransposeForConv1d()) # (N, T, D) -> (N, D, T)
            layers.append(nn.Flatten())
            
            encoder_output_shape = (n_features, seq_len)
            flattened_dim = n_features * seq_len

        else:
            raise ValueError(f"Unknown model_type: {self.model_type}")
            
        return nn.Sequential(*layers), encoder_output_shape, flattened_dim
    
    def _build_decoder(self, latent_dim, encoder_output_shape, n_features):
        """Builds a symmetric decoder stack based on self.model_type."""
        layers = nn.ModuleList()
        
        channels_out, seq_len_out = encoder_output_shape
        flattened_dim = channels_out * seq_len_out
        
        # Linear layer to unflatten
        if self.model_type == 'gru':
            flattened_dim = self.gru_hidden_dim # * seq_len
        layers.append(nn.Linear(latent_dim + self.latent_dims[-1], flattened_dim))
        layers.append(nn.ReLU())
        
        # Unflatten to the shape expected by the sequence model
        if self.model_type in {'conv1d', 'attn'}:
            layers.append(nn.Unflatten(1, (channels_out, seq_len_out)))

        if self.model_type == 'conv1d':
            current_channels = channels_out
            for i, (out_channels, kernel_size) in enumerate(zip(
                list(reversed(self.conv_channels))[1:] + [n_features], 
                list(reversed(self.kernel_sizes))
            )):
                layers.append(nn.ConvTranspose1d(
                    in_channels=current_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                    output_padding=1
                ))
                if i < self.n_conv_layers - 1: # No activation on final layer
                    layers.append(nn.BatchNorm1d(out_channels))
                    layers.append(nn.ReLU())
                    if self.dropout > 0:
                        layers.append(nn.Dropout(self.dropout))
                current_channels = out_channels

        elif self.model_type == 'gru':
            #layers.append(_TransposeForRNN()) # (N, H, T) -> (N, T, H)
            layers.append(_RepeatVector(seq_len_out)) # (N, H) -> (N, T, H)
            layers.append(nn.GRU(
                input_size=self.gru_hidden_dim,
                hidden_size=self.gru_hidden_dim, # Hidden-to-hidden
                num_layers=self.gru_num_layers,
                batch_first=True,
                dropout=self.dropout if self.gru_num_layers > 1 else 0
            ))
            #layers.append(_SelectGRUOutput())
            layers.append(_SelectGRUSequence())
            # Final linear layer to map from hidden dim to feature dim
            layers.append(nn.Linear(self.gru_hidden_dim, n_features))
            layers.append(_TransposeForConv1d()) # (N, T, D) -> (N, D, T)
            
        elif self.model_type == 'attn':
            layers.append(_TransposeForRNN()) # (N, D, T) -> (N, T, D)
            decoder_layer = nn.TransformerEncoderLayer(
                d_model=n_features,
                nhead=self.attn_num_heads,
                dropout=self.dropout,
                batch_first=True
            )
            layers.append(nn.TransformerEncoder(decoder_layer, num_layers=self.attn_num_layers))
            layers.append(_TransposeForConv1d()) # (N, T, D) -> (N, D, T)
            # No final linear layer needed, Transformer preserves feature dim D

        else:
            raise ValueError(f"Unknown model_type: {self.model_type}")

        return nn.Sequential(*layers)
    
    def encode(self, x_list):
        """
        Encode list of modality inputs to latent representations.
        
        Args:
            x_list: List of tensors [(batch, seq_len, n_features), ...]
        
        Returns:
            Tuple of (h_shared, [h_specific_0, h_specific_1, ...])
        """
        h_concat = []
        for m, x in enumerate(x_list):
            # Transpose to (batch, channels, seq_len) for Conv1D/our setup
            x = x.transpose(1, 2)
            h = self.encoders[m](x)
            h_concat.append(h)
                
        h_all = torch.cat(h_concat, dim=1)
        
        # Apply adaptive layers
        h_shared = self.adaptive_layers[0](h_all)  # Shared layer
        
        specific_outputs = []
        for i in range(self.n_modalities):
            specific_output = self.adaptive_layers[i + 1](h_concat[i])
            specific_outputs.append(specific_output)
        
        return (h_shared, specific_outputs)
    
    def decode(self, h):
        """
        Decode latent representations to reconstructions.
        
        Args:
            h: Tuple of (h_shared, [h_specific_list])
        
        Returns:
            List of reconstructed tensors [(batch, seq_len, n_features), ...]
        """
        h_shared, h_specific = h
        reconstructions = []
        
        for m, h_m in enumerate(h_specific):
            h_concat = torch.cat([h_shared, h_m], dim=1)
            
            # Pass through the whole decoder sequential
            x = self.decoders[m](h_concat)
            
            # Transpose back to (batch, seq_len, channels)
            x = x.transpose(1, 2)
            
            # Trim to original sequence length if needed
            original_seq_len = self.input_shapes[m][0]
            if x.shape[1] > original_seq_len:
                x = x[:, :original_seq_len, :]
            
            reconstructions.append(x)
        
        return reconstructions
    
    def forward(self, x_list):
        """
        Forward pass: encode and decode.
        
        Args:
            x_list: List of input tensors [(batch, seq_len, n_features), ...]
        
        Returns:
            Tuple of (reconstructions, latents)
            where latents is (h_shared, [h_specific_list])
        """
        latents = self.encode(x_list)
        reconstructions = self.decode(latents)
        
        return reconstructions, latents
    
    def encode_modalities(self, x_list):
        """
        Encode to combined modality representations (shared + specific concatenated).
        
        Args:
            x_list: List of input tensors [(batch, seq_len, n_features), ...]
        
        Returns:
            List of combined latent tensors [(batch, shared_dim + specific_dim), ...]
        """
        h_shared, h_specific = self.encode(x_list)
        h_combined = []
        for i, h_m in enumerate(h_specific):
            h_combined.append(torch.cat([h_shared, h_m], dim=1))
        return h_combined
    
    def get_ranks(self):
        """Get current ranks of all adaptive layers."""
        ranks = []
        for adaptive_layer in self.adaptive_layers:
            ranks.append(adaptive_layer.active_dims)
        return ranks
    
    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)
    
    def get_r_squares(self):
        """Get R² values for all adaptive layers."""
        r_squares = []
        for adaptive_layer in self.adaptive_layers:
            # R² is calculated during rank reduction, return None if not available
            r_squares.append(None)
        return r_squares
    
    def reduce_rank(self, reduction_ratio=0.8, threshold=0.01, layer_ids=[], dim=0):
        """
        Reduce rank of all adaptive layers based on singular value importance.
        
        Args:
            reduction_ratio: Ratio to reduce rank by (default 0.8)
            threshold: Energy threshold for rank reduction (default 0.01)
            layer_ids: List of layer indices to reduce (empty = all layers)
            dim: Dimension along which to reduce (default 0)
        
        Returns:
            Boolean indicating if any changes were made
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            # if layer_ids is specified, only reduce rank for those layers
            if layer_ids and i not in layer_ids:
                continue
            
            # Get singular values
            S = layer.get_rank_reduction_info()
            
            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank
                
            # Calculate normalized cumulative energy
            energy = S**2
            normalized_energy = energy / energy.sum()
            cumulative_energy = torch.cumsum(normalized_energy, dim=0)

            # Find the rank that preserves specified energy threshold
            # Make sure we don't go below the minimum rank
            target_rank = max(layer.min_rank, 
                             torch.sum(cumulative_energy < (1.0 - threshold)).item())
            which_dims = None

            # Alternative: just reduce by fixed ratio, but not below min_rank
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            
            # Take the target rank from energy threshold
            new_rank = max(target_rank, ratio_rank)
            
            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                layer.reduce_rank(new_rank, dim=dim, which_dims=which_dims)
                changes_made = True
                
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=[], dim=0):
        """
        Increase rank of all adaptive layers by specified increment.
        
        Args:
            increment: Fixed increment to add (default None, uses increase_ratio)
            increase_ratio: Ratio to increase rank by (default 1.1)
            layer_ids: List of layer indices to increase (empty = all layers)
            dim: Dimension along which to increase (default 0)
        
        Returns:
            Boolean indicating if any changes were made
        """
        changes_made = False
        
        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids and i not in layer_ids:
                continue
            previous_rank = layer.active_dims
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True
                # set the min rank to the previous active rank
                layer.min_rank = previous_rank
                
        return changes_made

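##########################
# Legacy Conv1D-only version (generated with Copilot), considered bad.
# Superseded by AdaptiveRankReducedAE_MultiBench above; kept for reference only.
##########################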

class AdaptiveRankReducedAE_MultiBench_OLD(nn.Module):
    """
    CNN-based multimodal autoencoder for MultiBench datasets.
    
    Supports 3 modalities with temporal structure:
    - Text: (batch, seq_len, 300) - GloVe embeddings
    - Vision: (batch, seq_len, 35/709) - FACET or OpenFace features
    - Audio: (batch, seq_len, 74) - COVAREP features
    
    Architecture:
    - Modality-specific CNN encoders (1D convolutions over time)
    - Adaptive low-rank layers for dimension reduction
    - Modality-specific CNN decoders with transposed convolutions
    """
    
    def __init__(self, input_shapes, latent_dims, conv_channels=[64, 128, 256], 
                 kernel_sizes=[3, 3, 3], dropout=0.1, 
                 initial_rank_ratio=1.0, min_rank=10):
        """
        Build one CNN encoder/decoder pair per modality plus the adaptive low-rank layers.

        ``self.adaptive_layers[0]`` is the shared layer fed with the concatenation of
        all encoder outputs; ``self.adaptive_layers[i + 1]`` is the layer specific to
        modality ``i``.

        Args:
            input_shapes: List of tuples [(seq_len, n_features), ...] for each modality
            latent_dims: Latent dimension (int or list of ints per modality). A list is
                indexed per modality and its last entry is used for the shared space.
            conv_channels: List of channel sizes for conv layers (default [64, 128, 256])
            kernel_sizes: List of kernel sizes for conv layers (default [3, 3, 3])
            dropout: Dropout rate (default 0.1)
            initial_rank_ratio: Initial rank ratio for adaptive layers (default 1.0)
            min_rank: Minimum rank for adaptive layers (default 10)
        """
        super(AdaptiveRankReducedAE_MultiBench_OLD, self).__init__()

        self.n_modalities = len(input_shapes)
        self.input_shapes = input_shapes

        # Handle latent_dims as int or list
        # Last latent dim is for shared space, others are modality-specific
        if isinstance(latent_dims, int):
            # Create list: [mod1, mod2, ..., shared]
            self.latent_dims = [latent_dims] * (self.n_modalities + 1)
        else:
            self.latent_dims = latent_dims

        self.conv_channels = conv_channels
        self.kernel_sizes = kernel_sizes
        self.n_conv_layers = len(conv_channels)

        print(f"Creating AdaptiveRankReducedAE_MultiBench for {self.n_modalities} modalities:")
        print(f"  Input shapes: {input_shapes}")
        print(f"  Latent dims: {self.latent_dims} (last is shared space)")
        print(f"  Conv channels: {conv_channels}")
        print(f"  Kernel sizes: {kernel_sizes}")
        print(f"  Dropout: {dropout}")
        print(f"  Initial rank ratio: {initial_rank_ratio}, Min rank: {min_rank}")

        # Create encoders and decoders
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.adaptive_layers = nn.ModuleList()

        # Track output shapes after convolutions for each modality
        self.encoder_output_shapes = []
        self.flattened_dims = []  # Track flattened dimensions before adaptive layers

        for m in range(self.n_modalities):
            seq_len, n_features = input_shapes[m]
            latent_dim = self.latent_dims[m]

            # Build CNN encoder (without final adaptive layer)
            encoder, enc_output_shape, flattened_dim = self._build_encoder(
                n_features, seq_len,
                conv_channels, kernel_sizes, dropout
            )
            self.encoders.append(encoder)
            self.encoder_output_shapes.append(enc_output_shape)

            # Probe the real flattened size with a dummy forward pass so it always
            # matches what the encoder produces at runtime. The probe runs in eval
            # mode: in training mode a batch of one would overwrite the BatchNorm
            # running statistics with values from an all-zero input, and BatchNorm
            # raises outright once a conv stage shrinks the sequence to length 1.
            was_training = encoder.training
            encoder.eval()
            try:
                with torch.no_grad():
                    # encoder expects (batch, channels, seq_len) -> (1, n_features, seq_len)
                    dummy_in = torch.zeros(1, n_features, seq_len)
                    enc_out = encoder(dummy_in)
                    # encoder ends with Flatten -> enc_out shape is (1, flattened_dim)
                    actual_flattened = int(enc_out.shape[1])
            finally:
                encoder.train(was_training)
            self.flattened_dims.append(actual_flattened)

            # Build CNN decoder
            decoder = self._build_decoder(
                latent_dim, enc_output_shape, n_features,
                conv_channels, kernel_sizes, dropout
            )
            self.decoders.append(decoder)

            print(f"  Modality {m}: {input_shapes[m]} -> flattened {flattened_dim} -> latent {latent_dim}")

        # Now create the adaptive layers: shared + modality-specific
        # Shared layer takes concatenation of all modality embeddings
        shared_layer = AdaptiveRankReducedLinear(
            sum(self.flattened_dims), self.latent_dims[-1],  # last latent dim is for shared space
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)

        # Modality-specific layers
        for i in range(self.n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                self.flattened_dims[i], self.latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)

        print(f"  Shared layer: {sum(self.flattened_dims)} -> {self.latent_dims[-1]}")
        for i in range(self.n_modalities):
            print(f"  Specific layer {i}: {self.flattened_dims[i]} -> {self.latent_dims[i]}")
    
    def _build_encoder(self, n_features, seq_len, 
                      conv_channels, kernel_sizes, dropout):
        """
        Build CNN encoder with Conv1D layers (without final adaptive layer).
        
        Input: (batch, seq_len, n_features)
        Output: (batch, flattened_dim)
        """
        layers = nn.ModuleList()
        
        current_channels = n_features
        current_seq_len = seq_len
        
        # Conv1D layers expect (batch, channels, seq_len)
        # We'll transpose in forward pass
        
        for i, (out_channels, kernel_size) in enumerate(zip(conv_channels, kernel_sizes)):
            # Conv1D with stride=2 for downsampling
            layers.append(nn.Conv1d(
                in_channels=current_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=kernel_size // 2
            ))
            layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU())
            
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            
            # Update dimensions
            current_seq_len = (current_seq_len + 2 * (kernel_size // 2) - kernel_size) // 2 + 1
            current_channels = out_channels
        
        # Flatten 
        flattened_dim = current_channels * current_seq_len
        layers.append(nn.Flatten())
        
        encoder_output_shape = (current_channels, current_seq_len)
        
        return nn.Sequential(*layers), encoder_output_shape, flattened_dim
    
    def _build_decoder(self, latent_dim, encoder_output_shape, n_features,
                      conv_channels, kernel_sizes, dropout):
        """
        Build CNN decoder with ConvTranspose1D layers.
        Decoder receives concatenated [shared, specific] latent.
        
        Input: (batch, latent_dim + shared_latent_dim)
        Output: (batch, seq_len, n_features)
        """
        layers = nn.ModuleList()
        
        channels_out, seq_len_out = encoder_output_shape
        flattened_dim = channels_out * seq_len_out
        
        # Linear layer to unflatten
        # Input will be latent_dim + shared_latent_dim (concatenated in decode)
        # For now, we'll set this up to accept latent_dim + shared, updated in __init__
        # We'll use latent_dim here and adjust input in forward
        layers.append(nn.Linear(latent_dim + self.latent_dims[-1], flattened_dim))
        layers.append(nn.ReLU())
        
        # Store shape for unflatten
        self.decoder_unflatten_shape = (channels_out, seq_len_out)
        
        # Reverse conv layers with transposed convolutions
        reversed_channels = list(reversed(conv_channels))
        reversed_kernels = list(reversed(kernel_sizes))
        
        current_channels = channels_out
        
        for i, (out_channels, kernel_size) in enumerate(zip(reversed_channels[1:] + [n_features], reversed_kernels)):
            # ConvTranspose1D for upsampling
            layers.append(nn.ConvTranspose1d(
                in_channels=current_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=kernel_size // 2,
                output_padding=1
            ))
            
            # Don't add activation/norm/dropout on final layer
            if i < len(reversed_channels) - 1:
                layers.append(nn.BatchNorm1d(out_channels))
                layers.append(nn.ReLU())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
            
            current_channels = out_channels
        
        return nn.Sequential(*layers)
    
    def encode(self, x_list):
        """
        Encode list of modality inputs to latent representations.
        Returns shared and modality-specific latents like larrp_multimodal.

        Each modality is transposed to channel-first layout, pushed through its
        own encoder stack, and the resulting embeddings are fed to the adaptive
        layers: ``adaptive_layers[0]`` sees the concatenation of all embeddings
        (shared latent) while ``adaptive_layers[i + 1]`` sees only modality
        ``i`` (modality-specific latent).

        Args:
            x_list: List of tensors [(batch, seq_len, n_features), ...]

        Returns:
            Tuple of (h_shared, [h_specific_0, h_specific_1, ...])
        """
        h_concat = []
        for m, x in enumerate(x_list):
            # Conv1d layers expect (batch, channels, seq_len)
            h = x.transpose(1, 2)

            # Apply the encoder layer by layer; this works whether the encoder
            # is stored as an nn.Sequential or as a plain module container.
            for layer in self.encoders[m]:
                h = layer(h)

            h_concat.append(h)

        # Concatenate all modality embeddings for the shared space.
        # NOTE: concatenating on dim=1 assumes every encoder ends with a
        # flattened (batch, features) output.
        h_all = torch.cat(h_concat, dim=1)

        # Shared adaptive layer
        h_shared = self.adaptive_layers[0](h_all)

        # Modality-specific adaptive layers (offset by one past the shared layer)
        specific_outputs = [
            self.adaptive_layers[i + 1](h_concat[i])
            for i in range(self.n_modalities)
        ]

        return (h_shared, specific_outputs)
    
    def decode(self, h):
        """
        Rebuild every modality from its latent codes.

        For each modality the shared latent and the modality-specific latent
        are joined, pushed through the first two decoder modules (the linear
        projection and its activation), reshaped to the channel-first layout
        recorded in ``self.encoder_output_shapes``, and then run through the
        remaining decoder modules.

        Args:
            h: Tuple of (h_shared, [h_specific_0, h_specific_1, ...])

        Returns:
            List of reconstructed tensors [(batch, seq_len, n_features), ...]
        """
        shared_latent, specific_latents = h
        outputs = []

        for idx, specific in enumerate(specific_latents):
            # Shared and modality-specific codes side by side on the feature axis
            joint = torch.cat([shared_latent, specific], dim=1)

            modality_decoder = self.decoders[idx]

            # decoder[0] -> linear projection, decoder[1] -> its activation
            projected = modality_decoder[1](modality_decoder[0](joint))

            # Restore the (batch, channels, seq_len) layout the encoder produced
            n_channels, n_steps = self.encoder_output_shapes[idx]
            feature_map = projected.view(-1, n_channels, n_steps)

            # Remaining modules (transposed convolutions and friends)
            for upsample_layer in modality_decoder[2:]:
                feature_map = upsample_layer(feature_map)

            # Back to time-major: (batch, seq_len, channels)
            sequence = feature_map.transpose(1, 2)

            # Upsampling may overshoot the original length; cut the surplus.
            # A reconstruction that is too short is left as it is.
            target_len = self.input_shapes[idx][0]
            if sequence.shape[1] > target_len:
                sequence = sequence[:, :target_len, :]

            outputs.append(sequence)

        return outputs
    
    def forward(self, x_list):
        """
        Run the full autoencoder: encode the inputs, then decode the latents.

        Args:
            x_list: List of input tensors [(batch, seq_len, n_features), ...]

        Returns:
            Tuple of (reconstructions, latents)
            where latents is (h_shared, [h_specific_list])
        """
        encoded = self.encode(x_list)
        return self.decode(encoded), encoded
    
    def encode_modalities(self, x_list):
        """
        Produce one combined latent per modality.

        The inputs are encoded once; the shared latent is then prepended to
        each modality-specific latent along the feature axis.

        Args:
            x_list: List of input tensors [(batch, seq_len, n_features), ...]

        Returns:
            List of combined latent tensors [(batch, shared_dim + specific_dim), ...]
        """
        shared_latent, specific_latents = self.encode(x_list)
        return [
            torch.cat([shared_latent, specific], dim=1)
            for specific in specific_latents
        ]
    
    def get_ranks(self):
        """Return the ``active_dims`` of each adaptive layer, in layer order."""
        return [layer.active_dims for layer in self.adaptive_layers]
    
    def get_total_rank(self):
        """Sum the ``active_dims`` of every adaptive layer."""
        total = 0
        for adaptive_layer in self.adaptive_layers:
            total += adaptive_layer.active_dims
        return total
    
    def get_r_squares(self):
        """
        Get R² values for all adaptive layers.

        Returns:
            List with one entry per adaptive layer: the layer's ``r_squared``
            attribute when it has one, otherwise ``None``.
        """
        # Previously this always returned None placeholders, even for layers
        # that expose an r_squared value. Read the attribute when present, so
        # this matches AdaptiveRankReducedAE_Static.get_r_squares.
        return [
            getattr(adaptive_layer, 'r_squared', None)
            for adaptive_layer in self.adaptive_layers
        ]
    
    def reduce_rank(self, reduction_ratio=0.8, threshold=0.01, layer_ids=None, dim=0):
        """
        Reduce rank of all adaptive layers based on singular value importance.

        For every selected layer two candidate ranks are computed: one from
        the cumulative energy of the values returned by
        ``layer.get_rank_reduction_info()`` and one from ``reduction_ratio``.
        The larger (more conservative) candidate is applied, and only when it
        is strictly below the layer's current ``active_dims``.

        Args:
            reduction_ratio: Multiplier applied to the current rank to get the
                ratio-based candidate (default 0.8)
            threshold: Energy threshold for rank reduction (default 0.01)
            layer_ids: Collection of layer indices to reduce
                (None or empty = all layers)
            dim: Dimension along which to reduce, forwarded to the layer
                (default 0)

        Returns:
            Boolean indicating if any changes were made
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            # A None/empty selection means "every layer"
            if layer_ids and i not in layer_ids:
                continue

            # Singular values of the layer
            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            # Normalized cumulative energy of the singular values
            energy = S ** 2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Number of leading components whose cumulative energy is still
            # below the target, never going under the layer's minimum rank.
            # NOTE(review): the strict "<" count can retain slightly less than
            # (1 - threshold) of the energy -- confirm whether "+ 1" is intended.
            below_target = torch.sum(cumulative_energy < (1.0 - threshold)).item()
            target_rank = max(layer.min_rank, below_target)

            # Ratio-based candidate, also clamped to the minimum rank
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))

            # Keep the more conservative of the two candidates
            new_rank = max(target_rank, ratio_rank)

            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """
        Increase rank of all adaptive layers by specified increment.

        The request is forwarded to each selected layer's own
        ``increase_rank`` with ``mode='multimodal'``.

        Args:
            increment: Fixed increment to add (default None, uses increase_ratio)
            increase_ratio: Ratio to increase rank by (default 1.1)
            layer_ids: Collection of layer indices to increase
                (None or empty = all layers)
            dim: Dimension along which to increase (default 0)

        Returns:
            Boolean indicating if any changes were made
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            # A None/empty selection means "every layer"
            if layer_ids and i not in layer_ids:
                continue
            # Every selected layer must be visited, so no short-circuiting here
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio,
                                   dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made


# ============================================================================
# Standard Multimodal Autoencoder (for non-sequential data like MM-IMDb)
# ============================================================================

class AdaptiveRankReducedAE_Static(nn.Module):
    """
    Standard multimodal autoencoder for non-sequential (static) features.

    Designed for datasets like MM-IMDb where features are already aggregated
    and don't have temporal structure (e.g., text embeddings, image features).

    Supports multiple modalities with:
    - Modality-specific encoders (feedforward networks)
    - Adaptive low-rank layers for dimension reduction
    - Shared and modality-specific latent representations
    - Modality-specific decoders (feedforward networks)

    Layout of ``self.adaptive_layers``: index 0 is the shared layer (fed with
    the concatenation of all modality latents); index ``i + 1`` is the
    refinement layer of modality ``i``.
    """

    def __init__(self, input_dims, latent_dims, depth=2, width=0.5, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10, activation=None):
        """
        Args:
            input_dims: List of input dimensions for each modality
            latent_dims: List of latent dimensions [mod1_dim, mod2_dim, ..., shared_dim],
                or a single int used for every modality and for the shared space
            depth: Number of linear layers in each encoder/decoder
            width: Width multiplier for hidden layers (hidden size is at least 100)
            dropout: Dropout rate
            initial_rank_ratio: Initial rank ratio for adaptive layers
            min_rank: Minimum rank for adaptive layers
            activation: Optional output activation ('sigmoid', 'softmax', or None)
        """
        super().__init__()

        n_modalities = len(input_dims)
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.adaptive_layers = nn.ModuleList()
        self.input_dims = input_dims
        self.n_modalities = n_modalities

        # Handle latent_dims as int or list
        # Last latent dim is for shared space, others are modality-specific
        if isinstance(latent_dims, int):
            self.latent_dims = [latent_dims] * n_modalities + [latent_dims]
        else:
            self.latent_dims = latent_dims
        shared_dim = self.latent_dims[-1]

        # Per-modality flag: True when a Conv1d front-end shrinks the input
        self.convolution = [False for _ in range(n_modalities)]

        print(f"Creating AdaptiveRankReducedAE_Static for {len(input_dims)} modalities with")
        print(f"   input_dims={input_dims}, latent_dims={self.latent_dims}")
        print(f"   depth={depth}, width={width}, dropout={dropout}")
        print(f"   initial_rank_ratio={initial_rank_ratio}, min_rank={min_rank}")

        # Build encoder/decoder for each modality
        for m, input_dim in enumerate(input_dims):
            ff_input_dim = input_dim
            latent_dim = self.latent_dims[m]

            # Handle very large input dimensions with convolution
            use_conv = input_dim > 100000
            if use_conv:
                print(f"Input dimension {input_dim} is too large, using convolutional block.")
                padding = 0
                kernel_size = 3
                stride = 3
                self.encoders[m].append(nn.Conv1d(in_channels=1, out_channels=1,
                                                   kernel_size=kernel_size, stride=stride,
                                                   padding=padding))
                self.encoders[m].append(nn.Flatten())
                reduced_dim = (input_dim + 2 * padding - kernel_size) // stride + 1
                print(f"Reduced input dimension from {input_dim} to {reduced_dim}")
                ff_input_dim = reduced_dim
                self.convolution[m] = True

            hidden_dim = max(int(width * ff_input_dim), 100)
            print(f"Modality {m}: input_dim={ff_input_dim}, hidden_dim={hidden_dim}, latent_dim={self.latent_dims[m]}")

            # Build feedforward encoder and decoder side by side. Layer widths
            # are derived from the position in the stack so that depth == 1
            # (a layer that is both first and last) maps input -> latent and
            # latent -> input directly instead of expecting a hidden-sized input.
            for i in range(depth):
                is_first = i == 0
                is_last = i == depth - 1

                enc_in = ff_input_dim if is_first else hidden_dim
                enc_out = latent_dim if is_last else hidden_dim
                # The decoder consumes the shared + specific latent concatenation
                dec_in = latent_dim + shared_dim if is_first else hidden_dim
                dec_out = ff_input_dim if is_last else hidden_dim

                self.encoders[m].append(nn.Linear(enc_in, enc_out))
                self.decoders[m].append(nn.Linear(dec_in, dec_out))

                # No activation/dropout after the final layer of either stack
                if not is_last:
                    self.encoders[m].append(nn.ReLU())
                    self.decoders[m].append(nn.ReLU())

                    if dropout > 0.0:
                        self.encoders[m].append(nn.Dropout(dropout))
                        self.decoders[m].append(nn.Dropout(dropout))

            # Add transpose convolution if used convolution
            if use_conv:
                # The strided Conv1d drops the trailing remainder of the input;
                # output_padding restores it so the reconstruction has exactly
                # input_dim features (it is 0 when the input divides evenly).
                output_padding = (input_dim + 2 * padding - kernel_size) % stride
                self.decoders[m].append(nn.ConvTranspose1d(in_channels=1, out_channels=1,
                                                           kernel_size=kernel_size, stride=stride,
                                                           padding=padding,
                                                           output_padding=output_padding))
                self.decoders[m].append(nn.Flatten())

        # Optional output activation
        if activation == 'sigmoid':
            for m in range(n_modalities):
                self.decoders[m].append(nn.Sigmoid())
        elif activation == 'softmax':
            for m in range(n_modalities):
                self.decoders[m].append(nn.Softmax(dim=1))

        # Shared latent space: concatenation of all modality latents -> shared latent
        shared_layer = AdaptiveRankReducedLinear(
            sum(self.latent_dims[:n_modalities]), shared_dim,
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)

        # Modality-specific latent refinement layers
        for i in range(n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                self.latent_dims[i], self.latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)

        # One learnable weight per modality (not used inside this class)
        self.modality_weights = nn.Parameter(torch.ones(n_modalities), requires_grad=True)

    def encode(self, x):
        """
        Encode list of modality inputs to latent representations.

        Args:
            x: List of tensors [(batch, n_features), ...]

        Returns:
            Tuple of (h_shared, [h_specific_0, h_specific_1, ...])
        """
        h_concat = []
        for m, x_m in enumerate(x):
            if self.convolution[m]:
                # Conv1d front-end needs an explicit single-channel axis
                x_m = x_m.view(x_m.shape[0], 1, -1)
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)

        # Concatenate all modality latents
        h = torch.cat(h_concat, dim=1)

        # Shared adaptive layer sees every modality at once
        h_shared = self.adaptive_layers[0](h)

        # Remaining adaptive layers refine one modality each
        specific_outputs = [
            layer(h_concat[i])
            for i, layer in enumerate(self.adaptive_layers[1:])
        ]

        return (h_shared, specific_outputs)

    def decode(self, h):
        """
        Decode latent representations to reconstructions.

        Args:
            h: Tuple of (h_shared, [h_specific_list])

        Returns:
            List of reconstructed tensors [(batch, n_features), ...]
        """
        h_shared, h_specific = h
        x_hat = []

        for m, h_m in enumerate(h_specific):
            out = torch.cat([h_shared, h_m], dim=1)
            for layer in self.decoders[m]:
                if self.convolution[m] and isinstance(layer, nn.ConvTranspose1d):
                    # Transposed conv needs (batch, 1, features)
                    out = out.view(out.shape[0], 1, -1)
                out = layer(out)
            x_hat.append(out)

        return x_hat

    def forward(self, x):
        """
        Forward pass: encode and decode.

        Args:
            x: List of input tensors [(batch, n_features), ...]

        Returns:
            Tuple of (reconstructions, latents)
        """
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h

    def encode_modalities(self, x):
        """
        Encode to combined modality representations (shared + specific concatenated).

        Args:
            x: List of input tensors [(batch, n_features), ...]

        Returns:
            List of combined latent tensors [(batch, shared_dim + specific_dim), ...]
        """
        h_shared, h_specific = self.encode(x)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def get_ranks(self):
        """Get current ranks of all adaptive layers."""
        return [adaptive_layer.active_dims for adaptive_layer in self.adaptive_layers]

    def get_total_rank(self):
        """Return total rank across all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)

    def get_r_squares(self):
        """Get R² values for all adaptive layers (None where a layer has none)."""
        return [
            getattr(adaptive_layer, 'r_squared', None)
            for adaptive_layer in self.adaptive_layers
        ]

    def reduce_rank(self, reduction_ratio=0.8, threshold=0.01, layer_ids=None, dim=0):
        """
        Reduce rank of all adaptive layers based on singular value importance.

        For every selected layer two candidate ranks are computed: one from
        the cumulative energy of the values returned by
        ``layer.get_rank_reduction_info()`` and one from ``reduction_ratio``.
        The larger (more conservative) candidate is applied, and only when it
        is strictly below the layer's current ``active_dims``.

        Args:
            reduction_ratio: Multiplier applied to the current rank to get the
                ratio-based candidate (default 0.8)
            threshold: Energy threshold for rank reduction (default 0.01)
            layer_ids: Collection of layer indices to reduce
                (None or empty = all layers)
            dim: Dimension along which to reduce, forwarded to the layer
                (default 0)

        Returns:
            Boolean indicating if any changes were made
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            # A None/empty selection means "every layer"
            if layer_ids and i not in layer_ids:
                continue

            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            # Normalized cumulative energy
            energy = S ** 2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Energy-based candidate, never below the minimum rank.
            # NOTE(review): the strict "<" count can retain slightly less than
            # (1 - threshold) of the energy -- confirm whether "+ 1" is intended.
            below_target = torch.sum(cumulative_energy < (1.0 - threshold)).item()
            target_rank = max(layer.min_rank, below_target)

            # Ratio-based candidate, also clamped to the minimum rank
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))

            # Keep the more conservative of the two candidates
            new_rank = max(target_rank, ratio_rank)

            if new_rank < current_rank:
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """
        Increase rank of all adaptive layers by specified increment.

        The request is forwarded to each selected layer's own
        ``increase_rank`` with ``mode='multimodal'``.

        Args:
            increment: Fixed increment to add (default None, uses increase_ratio)
            increase_ratio: Ratio to increase rank by (default 1.1)
            layer_ids: Collection of layer indices to increase
                (None or empty = all layers)
            dim: Dimension along which to increase (default 0)

        Returns:
            Boolean indicating if any changes were made
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            # A None/empty selection means "every layer"
            if layer_ids and i not in layer_ids:
                continue
            # Every selected layer must be visited, so no short-circuiting here
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio,
                                   dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made