import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.larrp_unimodal import AdaptiveRankReducedLinear

class AdaptiveRankReducedAE(torch.nn.Module):
    """Multimodal autoencoder with adaptive-rank shared and modality-specific heads.

    Every modality owns an encoder and a decoder stack of ``nn.Linear`` layers
    (optionally wrapped by a strided 1D convolution when the input is very
    wide). The per-modality encodings are

    * concatenated and projected by one ``AdaptiveRankReducedLinear`` into the
      shared latent space, and
    * individually projected by one ``AdaptiveRankReducedLinear`` each into a
      modality-specific latent space.

    Each decoder reconstructs its modality from ``cat([shared, specific])``.

    ``self.adaptive_layers`` holds the adaptive layers in the fixed order
    ``[shared, specific_0, specific_1, ...]``; the ``layer_ids`` arguments of
    :meth:`reduce_rank` and :meth:`increase_rank` index into that list.
    """

    # Inputs wider than this are first shrunk by a strided Conv1d.
    _CONV_INPUT_THRESHOLD = 100000
    _CONV_KERNEL_SIZE = 3
    _CONV_STRIDE = 3
    _CONV_PADDING = 0
    # Lower bound for the width of the hidden layers.
    _MIN_HIDDEN_DIM = 100

    def __init__(self, input_dims, latent_dims, depth=2, width=0.5, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10, activation=None):
        """Build the per-modality encoders/decoders and the adaptive latent heads.

        Args:
            input_dims: Feature size of each modality (one entry per modality).
            latent_dims: Latent size of each modality followed by the size of
                the shared latent space as the LAST entry, i.e.
                ``len(latent_dims) == len(input_dims) + 1`` is expected.
            depth: Number of linear layers per encoder and per decoder.
            width: Hidden width as a fraction of the (possibly conv-reduced)
                input size; the result is clamped to at least 100.
            dropout: Dropout probability after each hidden activation;
                ``0.0`` adds no dropout layers.
            initial_rank_ratio: Forwarded to every ``AdaptiveRankReducedLinear``.
            min_rank: Forwarded to every ``AdaptiveRankReducedLinear``.
            activation: Optional output activation appended to every decoder:
                ``'sigmoid'``, ``'softmax'`` (over ``dim=1``) or ``None``.
        """
        super().__init__()

        n_modalities = len(input_dims)
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.adaptive_layers = nn.ModuleList()  # Track adaptive rank layers for rank reduction
        self.input_dims = input_dims
        self.convolution = [False] * n_modalities  # Track if convolutional block is used

        print(f"Creating AdaptiveRankReducedAE for {len(input_dims)} modalities with\n   input_dims={input_dims}, latent_dims={latent_dims}, "
              f"depth={depth}, width={width}, dropout={dropout}")
        print(f"   initial_rank_ratio: {initial_rank_ratio}, min_rank: {min_rank}")

        kernel_size = self._CONV_KERNEL_SIZE
        stride = self._CONV_STRIDE
        padding = self._CONV_PADDING

        for m, input_dim in enumerate(input_dims):
            encoder = self.encoders[m]
            decoder = self.decoders[m]
            ff_input_dim = input_dim
            output_padding = 0

            # Large input dimension handling with convolutional block
            use_conv = input_dim > self._CONV_INPUT_THRESHOLD
            if use_conv:
                print(f"Input dimension {input_dim} is too large, using convolutional block to reduce it.")
                # Use a 1D convolutional layer to reduce the input dimension
                encoder.append(torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size,
                                               stride=stride, padding=padding))
                encoder.append(torch.nn.Flatten())
                reduced_dim, output_padding = self._conv_geometry(input_dim, kernel_size, stride, padding)
                print(f"Reduced input dimension from {input_dim} to {reduced_dim} using convolutional block.")
                ff_input_dim = reduced_dim
                self.convolution[m] = True

            # Ensure hidden dim is at least _MIN_HIDDEN_DIM
            hidden_dim = max(int(width * ff_input_dim), self._MIN_HIDDEN_DIM)
            print(f"Modality {m}: ff_input_dim={ff_input_dim}, hidden_dim={hidden_dim}")

            # The decoder consumes the shared latent concatenated with the specific one.
            decoder_input_dim = latent_dims[m] + latent_dims[-1]

            for i in range(depth):
                is_first = i == 0
                is_last = i == depth - 1

                # Sizes depend only on the position in the stack; this also wires
                # depth == 1 correctly (input -> latent, latent -> input).
                enc_in = ff_input_dim if is_first else hidden_dim
                enc_out = latent_dims[m] if is_last else hidden_dim
                dec_in = decoder_input_dim if is_first else hidden_dim
                dec_out = ff_input_dim if is_last else hidden_dim

                # Keep the creation order (encoder layer, then decoder layer) so
                # that seeded initialisation stays reproducible.
                encoder.append(nn.Linear(enc_in, enc_out))
                decoder.append(nn.Linear(dec_in, dec_out))

                if is_last:
                    # Bottleneck / reconstruction layers stay purely linear.
                    continue

                # Add activation
                encoder.append(nn.ReLU())
                decoder.append(nn.ReLU())

                # Add dropout if specified
                if dropout > 0.0:
                    encoder.append(nn.Dropout(dropout))
                    decoder.append(nn.Dropout(dropout))

            if use_conv:
                # Upsample back to the original input dimension. output_padding
                # restores the samples the strided Conv1d dropped when input_dim
                # is not aligned with the stride.
                decoder.append(torch.nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=kernel_size,
                                                        stride=stride, padding=padding,
                                                        output_padding=output_padding))
                decoder.append(torch.nn.Flatten())

        # Optional output activation on every decoder
        if activation == 'sigmoid':
            for decoder in self.decoders:
                decoder.append(nn.Sigmoid())
        elif activation == 'softmax':
            for decoder in self.decoders:
                decoder.append(nn.Softmax(dim=1))

        # Shared head: all modality encodings concatenated -> shared latent space
        # (the last entry of latent_dims).
        shared_layer = AdaptiveRankReducedLinear(
            sum(latent_dims[:n_modalities]), latent_dims[-1],
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)

        # Specific heads: each modality encoding -> its own latent space.
        for i in range(n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                latent_dims[i], latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)

        # Initialize modality weights for balanced training
        self.modality_weights = nn.Parameter(torch.ones(n_modalities), requires_grad=True)

    @staticmethod
    def _conv_geometry(input_dim, kernel_size, stride, padding):
        """Return ``(reduced_dim, output_padding)`` for the conv / transposed-conv pair.

        ``reduced_dim`` is the output length of a ``Conv1d`` (dilation 1) applied
        to ``input_dim`` samples. ``output_padding`` is the remainder the strided
        convolution discards, which the matching ``ConvTranspose1d`` needs in
        order to reproduce exactly ``input_dim`` samples.
        """
        span = input_dim + 2 * padding - kernel_size
        return int(span // stride + 1), int(span % stride)

    def _selected_layers(self, layer_ids):
        """Yield ``(index, layer)`` for the adaptive layers chosen by ``layer_ids``.

        ``None`` selects every layer; otherwise only indices contained in
        ``layer_ids`` are yielded (an empty collection selects nothing).
        """
        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids is None or i in layer_ids:
                yield i, layer

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=None, dim=0):
        """Reduce rank of the adaptive layers based on singular value importance.

        For each selected layer two candidate ranks are computed and the LARGER
        (more conservative) one is used:

        * energy based: number of leading singular values whose cumulative
          normalised energy stays below ``1 - threshold``;
        * ratio based: ``int(current_rank * reduction_ratio)``.

        Neither candidate goes below ``layer.min_rank``.

        Args:
            reduction_ratio: Factor applied to the current rank for the
                ratio-based candidate.
            threshold: Fraction of spectral energy that may be discarded by the
                energy-based candidate.
            layer_ids: Indices into ``self.adaptive_layers`` to process
                (0 = shared, 1.. = modality-specific). ``None`` (default)
                processes all layers.
            dim: Forwarded unchanged to ``layer.reduce_rank``.

        Returns:
            bool: ``True`` if at least one layer was reduced.
        """
        changes_made = False

        for _, layer in self._selected_layers(layer_ids):
            # Singular values of the layer.
            # NOTE(review): assumed to be sorted by magnitude (SVD order), as the
            # original author noted - confirm in AdaptiveRankReducedLinear.
            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            # Normalised cumulative energy of the spectrum
            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Rank that preserves the requested energy, never below min_rank
            kept_by_energy = torch.sum(cumulative_energy < (1.0 - threshold)).item()
            target_rank = max(layer.min_rank, kept_by_energy)

            # Alternative: reduce by a fixed ratio, but not below min_rank
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))

            # Take the larger of the two approaches
            new_rank = max(target_rank, ratio_rank)

            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                # which_dims=None: no explicit dimension selection is passed on.
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """Increase rank of the adaptive layers.

        Args:
            increment: Forwarded to ``layer.increase_rank``.
            increase_ratio: Forwarded to ``layer.increase_rank``.
            layer_ids: Indices into ``self.adaptive_layers`` to process
                (0 = shared, 1.. = modality-specific). ``None`` (default)
                processes all layers.
            dim: Forwarded to ``layer.increase_rank``.

        Returns:
            bool: ``True`` if at least one layer reported a change.
        """
        changes_made = False

        # Every selected layer must be visited, so no short-circuiting here.
        for _, layer in self._selected_layers(layer_ids):
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return total rank (sum of ``active_dims``) across all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)

    def encode(self, x, compute_jacobian=False):
        """Encode all modalities into the shared and modality-specific spaces.

        Args:
            x: Sequence with one input tensor per modality, in modality order.
            compute_jacobian: If ``True`` also compute a contractive-style
                penalty for every adaptive layer.

        Returns:
            If ``compute_jacobian`` is ``False``: ``(h_shared, specific_outputs)``
            where ``specific_outputs`` is a list with one tensor per modality.

            If ``compute_jacobian`` is ``True``: ``(h_split, contractive_losses)``
            where ``h_split`` is the list ``[h_shared, specific_0, ...]`` and
            ``contractive_losses`` holds one Python float per adaptive layer.
            The floats are detached, so they are suitable for monitoring but
            carry no gradient.
        """
        h_concat = []
        for m, x_m in enumerate(x):
            if self.convolution[m]:
                # Conv1d expects a channel axis: (batch, 1, features).
                x_m = x_m.view(x_m.shape[0], 1, -1)
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)
        h = torch.cat(h_concat, dim=1)

        shared_layer = self.adaptive_layers[0]
        specific_layers = self.adaptive_layers[1:]

        if not compute_jacobian:
            h_shared = shared_layer(h)
            specific_outputs = []
            for i, layer in enumerate(specific_layers):
                specific_outputs.append(layer(h_concat[i]))
            return (h_shared, specific_outputs)

        h_split = [shared_layer(h)]
        weights = [shared_layer.get_weights()]
        for i, layer in enumerate(specific_layers):
            h_split.append(layer(h_concat[i]))
            weights.append(layer.get_weights())

        contractive_losses = []
        for activation, weight in zip(h_split, weights):
            batch_size = activation.shape[0]
            # Derivative of ReLU w.r.t. its input (for sigmoid it would be h * (1 - h)).
            derivative = (activation > 0).float()
            # Vectorised penalty of one layer:
            # sum_j ( (h'(a)_j)^2 * sum_i (W_ji)^2 )
            w_squared = torch.sum(weight**2, dim=1)
            d_squared = derivative**2
            contractive_loss_layer = torch.sum(d_squared * w_squared.unsqueeze(0))  # Broadcasting over batch
            # Average over batch
            contractive_losses.append((contractive_loss_layer / batch_size).detach().cpu().item())
        return h_split, contractive_losses

    def decode(self, h):
        """Reconstruct every modality from its shared + specific latent code.

        Args:
            h: Pair ``(h_shared, h_specific)`` as returned by :meth:`encode`
                with ``compute_jacobian=False``.

        Returns:
            list: One reconstruction tensor per modality.
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            out = torch.cat([h_shared, h_m], dim=1)
            for layer in self.decoders[m]:
                if self.convolution[m] and isinstance(layer, nn.ConvTranspose1d):
                    # ConvTranspose1d expects a channel axis: (batch, 1, features).
                    out = out.view(out.shape[0], 1, -1)
                out = layer(out)
            x_hat.append(out)
        return x_hat

    def forward(self, x):
        """Encode then decode ``x``; return ``(reconstructions, latents)``."""
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h

    def encode_modalities(self, x):
        """Return, per modality, the shared code concatenated with its specific code."""
        h_shared, h_specific = self.encode(x)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]