import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.larrp_unimodal import AdaptiveRankReducedLinear


class AdaptiveRankReducedVAE(torch.nn.Module):
    """
    Multi-modal Variational Autoencoder with adaptive rank reduction.

    Architecture:
    - Each modality has its own encoder that produces a pre-latent
      representation of size ``latent_dims[m]``.
    - Adaptive layers predict:
        1. A deterministic shared latent from the concatenated pre-latents
           (no mean/logvar and no sampling).
        2. A modality-specific mean for each modality. The log-variance of
           each modality-specific latent is a single learned scalar that is
           broadcast over batch and latent dimensions.
    - Decoders reconstruct from the concatenation of the shared latent and
      the sampled modality-specific latent.
    """

    # Modalities with more input features than this get a strided Conv1d in
    # front of the feed-forward encoder (and a ConvTranspose1d after the decoder).
    _CONV_INPUT_THRESHOLD = 100000
    _CONV_KERNEL_SIZE = 3
    _CONV_STRIDE = 2
    _CONV_PADDING = 0

    def __init__(self, input_dims, latent_dims, depth=2, width=0.5, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10):
        """
        Build encoders, adaptive latent layers and decoders.

        Args:
            input_dims: Number of input features of each modality.
            latent_dims: Latent sizes. Entry ``m`` is the pre-latent and
                modality-specific latent size of modality ``m``; the last
                entry is used as the size of the shared latent.
            depth: Number of linear layers in each encoder and each decoder.
            width: Hidden size as a fraction of the (possibly conv-reduced)
                input size.
            dropout: Dropout probability after each hidden activation;
                no dropout layers are created when it is 0.
            initial_rank_ratio: Forwarded to every AdaptiveRankReducedLinear.
            min_rank: Forwarded to every AdaptiveRankReducedLinear.
        """
        super().__init__()

        self.n_modalities = len(input_dims)
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(self.n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(self.n_modalities)])
        self.adaptive_layers = nn.ModuleList()  # Track adaptive rank layers for rank reduction

        hidden_dims = [int(width * input_dims[i]) for i in range(self.n_modalities)]
        ff_input_dims = list(input_dims)
        self.convolution = [False for _ in range(self.n_modalities)]  # Track if convolutional block is used

        print(f"Creating AdaptiveRankReducedVAE for {self.n_modalities} modalities with\n   input_dims={input_dims}, latent_dims={latent_dims}, "
              f"depth={depth}, width={width}, dropout={dropout}")
        print(f"   hidden_dims: {hidden_dims}, ff_input_dims: {ff_input_dims}")
        print(f"   initial_rank_ratio: {initial_rank_ratio}, min_rank: {min_rank}")

        # Build encoders for each modality (up to pre-latent representation)
        for m in range(self.n_modalities):
            input_dim = input_dims[m]
            encoder = self.encoders[m]

            # Large input dimension handling with convolutional block
            if input_dim > self._CONV_INPUT_THRESHOLD:
                print(f"Input dimension {input_dim} is too large, using convolutional block to reduce it.")
                encoder.append(torch.nn.Conv1d(in_channels=1, out_channels=1,
                                               kernel_size=self._CONV_KERNEL_SIZE,
                                               stride=self._CONV_STRIDE,
                                               padding=self._CONV_PADDING))
                encoder.append(torch.nn.Flatten())
                reduced_dim = self._conv_output_dim(input_dim)
                print(f"Reduced input dimension from {input_dim} to {reduced_dim} using convolutional block.")
                # Remember the reduced sizes so the decoder mirrors them.
                hidden_dims[m] = int(width * reduced_dim)
                ff_input_dims[m] = reduced_dim
                self.convolution[m] = True

            hidden_dim = hidden_dims[m]

            # Build encoder layers (all standard, no adaptive layers here)
            for i in range(depth):
                # The first layer always consumes the feed-forward input, even
                # when it is also the last layer (depth == 1).
                in_features = ff_input_dims[m] if i == 0 else hidden_dim
                if i == depth - 1:
                    # Final encoder layer outputs pre-latent representation
                    encoder.append(nn.Linear(in_features, latent_dims[m]))
                else:
                    encoder.append(nn.Linear(in_features, hidden_dim))
                    encoder.append(nn.ReLU())
                    if dropout > 0.0:
                        encoder.append(nn.Dropout(dropout))

        # Adaptive layers for latent spaces
        # Shared space: deterministic (no sampling)
        # Modality-specific spaces: variational (mean layer + scalar logvar)
        concat_dim = sum(latent_dims[:self.n_modalities])
        shared_dim = latent_dims[-1]  # Last latent dim is for the shared space

        # Shared layer (deterministic, no VAE parameters)
        self.shared_layer = AdaptiveRankReducedLinear(
            concat_dim, shared_dim,
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(self.shared_layer)

        # Modality-specific spaces: each maps from its own pre-latent to mu
        self.specific_mu_layers = nn.ModuleList()
        self.specific_logvars = nn.ParameterList()

        for i in range(self.n_modalities):
            # Mean
            specific_mu_layer = AdaptiveRankReducedLinear(
                latent_dims[i], latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.specific_mu_layers.append(specific_mu_layer)
            self.adaptive_layers.append(specific_mu_layer)

            # Logvar (single scalar per modality)
            self.specific_logvars.append(nn.Parameter(torch.zeros(1)))

        # Build decoders for each modality
        for m in range(self.n_modalities):
            hidden_dim = hidden_dims[m]
            ff_input_dim = ff_input_dims[m]
            decoder = self.decoders[m]
            # Decoder input is the concatenation of shared + specific latents.
            decoder_input_dim = latent_dims[m] + latent_dims[-1]

            for i in range(depth):
                in_features = decoder_input_dim if i == 0 else hidden_dim
                if i == depth - 1:
                    # Final decoder layer
                    decoder.append(nn.Linear(in_features, ff_input_dim))
                else:
                    decoder.append(nn.Linear(in_features, hidden_dim))
                    decoder.append(nn.ReLU())
                    if dropout > 0.0:
                        decoder.append(nn.Dropout(dropout))

            if self.convolution[m]:
                # Add transpose conv to upsample back. Without output padding
                # the transpose conv yields (L_in - 1) * stride - 2 * padding
                # + kernel_size features, which is one short of the original
                # size when that size is even; output_padding restores it.
                upsampled_dim = ((ff_input_dim - 1) * self._CONV_STRIDE
                                 - 2 * self._CONV_PADDING + self._CONV_KERNEL_SIZE)
                output_padding = input_dims[m] - upsampled_dim
                decoder.append(torch.nn.ConvTranspose1d(in_channels=1, out_channels=1,
                                                        kernel_size=self._CONV_KERNEL_SIZE,
                                                        stride=self._CONV_STRIDE,
                                                        padding=self._CONV_PADDING,
                                                        output_padding=output_padding))
                decoder.append(torch.nn.Flatten())

        # Initialize modality weights for balanced training
        self.modality_weights = nn.Parameter(torch.ones(self.n_modalities), requires_grad=True)

    def _conv_output_dim(self, input_dim):
        """Return the feature count left after the encoder's strided Conv1d."""
        return (input_dim + 2 * self._CONV_PADDING - self._CONV_KERNEL_SIZE) // self._CONV_STRIDE + 1

    @staticmethod
    def _contractive_penalty(weight, activation):
        """
        Return the contractive penalty of one layer as a Python float.

        For every sample, the squared row norms of ``weight`` are summed over
        the output units whose activation is positive; the total is divided
        by the batch size. Assumes ``weight`` is laid out as
        (out_features, in_features) -- TODO confirm against get_weights().
        """
        batch_size = activation.shape[0]
        # The derivative mask is 0/1, so squaring it would not change it.
        active_mask = (activation > 0).float()
        row_sq_norms = torch.sum(weight ** 2, dim=1)
        penalty = torch.sum(active_mask * row_sq_norms.unsqueeze(0)) / batch_size
        return penalty.detach().cpu().item()

    def reparameterize(self, mu, logvar):
        """Reparameterization trick for VAE: return mu + eps * exp(0.5 * logvar)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x, compute_jacobian=False):
        """
        Encode input data to latent representations.

        Args:
            x: Iterable with one input tensor per modality, in modality order.
            compute_jacobian: When True, additionally return per-layer
                contractive penalties for the adaptive layers.

        Returns:
            A tuple ``(z_shared, specific_mus, specific_logvars, h_concat)``:
            - z_shared: Deterministic shared latent representation
            - specific_mus, specific_logvars: Lists of parameters for
              modality-specific distributions
            - h_concat: Concatenated pre-latent representations
            When ``compute_jacobian`` is True the return value is
            ``(that_tuple, contractive_losses)`` where ``contractive_losses``
            is a list of floats: shared layer first, then one entry per
            modality-specific mean layer.
        """
        # Get pre-latent representations from each encoder
        h_list = []
        for m, x_m in enumerate(x):
            if self.convolution[m]:
                # Conv1d expects (batch, channels, length)
                x_m = x_m.view(x_m.shape[0], 1, -1)
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_list.append(x_m)

        # Concatenate for shared space
        h_concat = torch.cat(h_list, dim=1)

        # Predict deterministic shared representation (no sampling)
        z_shared = self.shared_layer(h_concat)

        # Get batch size for broadcasting logvars
        batch_size = z_shared.size(0)

        # Predict modality-specific means and broadcast the scalar logvars
        specific_mus = []
        specific_logvars = []
        for i in range(self.n_modalities):
            specific_mu = self.specific_mu_layers[i](h_list[i])
            specific_logvar = self.specific_logvars[i].expand(batch_size, specific_mu.size(1))
            specific_mus.append(specific_mu)
            specific_logvars.append(specific_logvar)

        latents = (z_shared, specific_mus, specific_logvars, h_concat)
        if not compute_jacobian:
            return latents

        # Shared layer first, then the modality-specific mean layers
        contractive_losses = [
            self._contractive_penalty(self.shared_layer.get_weights(), z_shared)
        ]
        for i in range(self.n_modalities):
            contractive_losses.append(
                self._contractive_penalty(self.specific_mu_layers[i].get_weights(), specific_mus[i])
            )
        return latents, contractive_losses

    def decode(self, z_shared, z_specifics):
        """
        Decode from latent representations to reconstructions.

        Args:
            z_shared: Deterministic shared latent (batch_size, shared_dim)
            z_specifics: List of sampled modality-specific latents

        Returns:
            List of reconstructed outputs for each modality. A final ReLU is
            applied, so every reconstruction is non-negative.
        """
        x_hat = []
        for m in range(self.n_modalities):
            # Concatenate shared and specific latents
            h = torch.cat([z_shared, z_specifics[m]], dim=1)

            for layer in self.decoders[m]:
                if self.convolution[m] and isinstance(layer, nn.ConvTranspose1d):
                    # ConvTranspose1d expects (batch, channels, length)
                    h = h.view(h.shape[0], 1, -1)
                h = layer(h)
            x_hat.append(torch.relu(h))  # Final activation
        return x_hat

    def forward(self, x):
        """
        Forward pass through VAE.

        Args:
            x: Iterable with one input tensor per modality.

        Returns:
            - x_hat: Reconstructed outputs
            - vae_params: Dictionary containing all VAE parameters for loss computation
        """
        # Encode
        z_shared, specific_mus, specific_logvars, _ = self.encode(x)

        # Sample modality-specific latents using reparameterization trick
        # Shared latent is deterministic (no sampling)
        z_specifics = [self.reparameterize(mu, logvar)
                       for mu, logvar in zip(specific_mus, specific_logvars)]

        # Decode
        x_hat = self.decode(z_shared, z_specifics)

        # Package parameters for loss computation
        vae_params = {
            'z_shared': z_shared,  # Deterministic
            'specific_mus': specific_mus,
            'specific_logvars': specific_logvars,
            'z_specifics': z_specifics
        }

        return x_hat, vae_params

    def encode_modalities(self, x):
        """
        Encode and sample latents, then return combined latents for each modality.
        Useful for downstream analysis.

        Returns:
            List with, for every modality, the shared latent concatenated
            with that modality's sampled specific latent. The specific
            latents are sampled, so repeated calls differ.
        """
        z_shared, specific_mus, specific_logvars, _ = self.encode(x)

        # Sample modality-specific latents
        z_specifics = [self.reparameterize(mu, logvar)
                       for mu, logvar in zip(specific_mus, specific_logvars)]

        # Concatenate for each modality
        return [torch.cat([z_shared, z_specific], dim=1) for z_specific in z_specifics]

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=None, dim=0):
        """
        Reduce rank of adaptive layers based on singular value importance.

        Args:
            reduction_ratio: Currently unused; kept for interface
                compatibility (the energy-based target rank is always used).
            threshold: Fraction of spectral energy that may be discarded.
            layer_ids: Indices into ``self.adaptive_layers`` to process.
                ``None`` (default) processes all adaptive layers; an empty
                collection processes none.
            dim: Forwarded to each layer's ``reduce_rank``.

        Returns:
            True if the rank of at least one layer was reduced.
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids is not None and i not in layer_ids:
                continue

            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue

            # Calculate normalized cumulative energy
            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Number of leading components whose cumulative energy stays below
            # the preserved-energy level, but never below min_rank
            new_rank = max(layer.min_rank,
                           torch.sum(cumulative_energy < (1.0 - threshold)).item())

            # Only reduce if new rank is smaller than current
            if new_rank < layer.active_dims:
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """
        Increase rank of adaptive layers.

        Args:
            increment: Forwarded to each layer's ``increase_rank``.
            increase_ratio: Forwarded to each layer's ``increase_rank``.
            layer_ids: Indices into ``self.adaptive_layers`` to process.
                ``None`` (default) processes all adaptive layers; an empty
                collection processes none.
            dim: Forwarded to each layer's ``increase_rank``.

        Returns:
            True if at least one layer reported a change.
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids is not None and i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)
