import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.larrp_unimodal import AdaptiveRankReducedLinear

class AdaptiveRankReducedAE_CNN(torch.nn.Module):
    """CNN-based multimodal autoencoder with adaptive rank-reduced latent layers.

    Each modality ``m`` has its own encoder (stride-2 ``Conv2d`` blocks followed by
    ``Flatten`` + ``Linear`` to ``latent_dims[m]``) and a mirrored decoder
    (``Linear`` -> ``Unflatten`` -> ``ConvTranspose2d`` blocks). On top of the
    encoder outputs sit ``AdaptiveRankReducedLinear`` layers:

    - ``adaptive_layers[0]`` maps the concatenation of all modality latents to the
      shared latent (size ``latent_dims[-1]``);
    - ``adaptive_layers[1 + m]`` maps modality ``m``'s latent to its
      modality-specific latent (size ``latent_dims[m]``).

    Each decoder consumes ``[h_shared, h_specific[m]]`` concatenated along dim 1.
    """

    def __init__(self, input_dims, latent_dims, input_shapes, depth=2, width=0.5, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10, activation=None):
        """
        CNN-based multimodal autoencoder with adaptive rank reduction.

        Parameters:
        - input_dims: list of input dimensions (for compatibility, will be ignored)
        - latent_dims: list of latent dimensions [modality1_latent, modality2_latent, ..., shared_latent];
          entry ``m`` is modality ``m``'s latent size and the LAST entry is the shared latent size
        - input_shapes: list of tuples [(channels1, height1, width1), (channels2, height2, width2), ...]
        - depth: number of convolutional blocks; each halves H and W (rounding up) (default: 2)
        - width: channel multiplier; block ``i`` has ``max(16, int(32 * width)) * 2**i`` channels (default: 0.5)
        - dropout: dropout rate (default: 0.0)
        - initial_rank_ratio: initial rank ratio for adaptive layers (default: 1.0)
        - min_rank: minimum rank for adaptive layers (default: 10)
        - activation: output activation function ('sigmoid', 'softmax', or None)
        """
        super().__init__()

        n_modalities = len(input_shapes)
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.adaptive_layers = nn.ModuleList()
        self.input_shapes = input_shapes
        self.latent_dims = latent_dims

        print(f"Creating AdaptiveRankReducedAE_CNN for {len(input_shapes)} modalities")
        print(f"   input_shapes={input_shapes}, latent_dims={latent_dims}")
        print(f"   depth={depth}, width={width}, dropout={dropout}")
        print(f"   initial_rank_ratio={initial_rank_ratio}, min_rank={min_rank}")

        for m in range(n_modalities):
            in_channels, height, width_dim = input_shapes[m]
            base_channels = max(16, int(32 * width))
            encoder = self.encoders[m]
            decoder = self.decoders[m]

            # channels_list[i] / sizes[i] describe the tensor entering encoder block i;
            # the last entries describe the final feature map. The decoder walks these
            # lists backwards so that every upsampling step restores the exact size
            # the matching downsampling step consumed (odd sizes included).
            channels_list = [in_channels]
            sizes = [(height, width_dim)]

            # Encoder: series of Conv2d -> ReLU -> (optional Dropout2d)
            for i in range(depth):
                out_channels = base_channels * (2 ** i)
                encoder.append(nn.Conv2d(
                    channels_list[-1], out_channels,
                    kernel_size=3, stride=2, padding=1
                ))
                encoder.append(nn.ReLU())
                if dropout > 0.0:
                    encoder.append(nn.Dropout2d(dropout))

                channels_list.append(out_channels)
                prev_h, prev_w = sizes[-1]
                sizes.append((self._conv_out_size(prev_h), self._conv_out_size(prev_w)))

            current_channels = channels_list[-1]
            current_h, current_w = sizes[-1]

            # Flatten and project to the modality latent space
            flattened_size = current_channels * current_h * current_w
            encoder.append(nn.Flatten())
            encoder.append(nn.Linear(flattened_size, latent_dims[m]))

            print(f"Modality {m}: {input_shapes[m]} -> flattened_size={flattened_size} -> latent={latent_dims[m]}")

            # Store final feature-map shape information
            encoder.final_conv_h = current_h
            encoder.final_conv_w = current_w
            encoder.final_conv_channels = current_channels

            # Decoder input is [h_shared, h_specific[m]], hence the summed width.
            decoder.append(nn.Linear(latent_dims[m] + latent_dims[-1], flattened_size))
            decoder.append(nn.ReLU())
            if dropout > 0.0:
                decoder.append(nn.Dropout(dropout))

            decoder.append(nn.Unflatten(1, (current_channels, current_h, current_w)))

            # Transpose convolutions in reverse encoder order.
            for i in range(depth - 1, -1, -1):
                (in_h, in_w), (out_h, out_w) = sizes[i + 1], sizes[i]
                # Per-axis padding so ConvTranspose2d lands exactly on the encoder's
                # input size; a fixed output_padding=1 always doubles the size,
                # which overshoots whenever the encoder input had an odd size.
                output_padding = (
                    self._transpose_output_padding(in_h, out_h),
                    self._transpose_output_padding(in_w, out_w),
                )
                decoder.append(nn.ConvTranspose2d(
                    channels_list[i + 1], channels_list[i],
                    kernel_size=3, stride=2, padding=1, output_padding=output_padding
                ))

                if i > 0:  # Don't add ReLU/Dropout after the final (output) layer
                    decoder.append(nn.ReLU())
                    if dropout > 0.0:
                        decoder.append(nn.Dropout2d(dropout))

        # Optional output activation (softmax normalizes over dim 1, the channel axis).
        if activation == 'sigmoid':
            for m in range(n_modalities):
                self.decoders[m].append(nn.Sigmoid())
        elif activation == 'softmax':
            for m in range(n_modalities):
                self.decoders[m].append(nn.Softmax(dim=1))

        # adaptive_layers[0]: concatenated modality latents -> shared latent
        shared_layer = AdaptiveRankReducedLinear(
            sum(latent_dims[:n_modalities]), latent_dims[-1],
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)

        # adaptive_layers[1 + m]: modality m latent -> modality-specific latent
        for i in range(n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                latent_dims[i], latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)

        # Learnable per-modality weights (one per modality, initialised to 1).
        self.modality_weights = nn.Parameter(torch.ones(n_modalities), requires_grad=True)

    @staticmethod
    def _conv_out_size(size, kernel_size=3, stride=2, padding=1):
        """Return the spatial size produced by a Conv2d with these hyper-parameters."""
        return (size + 2 * padding - kernel_size) // stride + 1

    @staticmethod
    def _transpose_output_padding(in_size, target_size, kernel_size=3, stride=2, padding=1):
        """Return the ``output_padding`` that makes ConvTranspose2d map ``in_size`` to ``target_size``.

        Raises:
        - ValueError: if no valid output_padding (0 <= p < stride) reaches the target.
        """
        base = (in_size - 1) * stride - 2 * padding + kernel_size
        pad = target_size - base
        if not 0 <= pad < stride:
            raise ValueError(
                f"cannot upsample size {in_size} to {target_size} with "
                f"kernel_size={kernel_size}, stride={stride}, padding={padding}"
            )
        return pad

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=None, dim=0):
        """Reduce the rank of the selected adaptive layers based on singular-value energy.

        For each selected layer the new rank is the larger of:
        (a) the smallest rank whose leading singular values retain at least
            ``1 - threshold`` of the total squared-singular-value energy, and
        (b) ``int(active_dims * reduction_ratio)``.
        Both are floored at the layer's ``min_rank``. A layer is only changed if
        the new rank is smaller than its current ``active_dims``.

        Parameters:
        - reduction_ratio: multiplicative cap on how fast the rank may shrink
        - threshold: fraction of spectral energy that may be discarded
        - layer_ids: indices into ``self.adaptive_layers`` to consider (0 = shared
          layer, 1 + m = modality m). ``None`` or empty selects no layers.
        - dim: forwarded to each layer's ``reduce_rank``

        Returns:
        - True if at least one layer was reduced, else False
        """
        if layer_ids is None:
            layer_ids = ()
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue

            S = layer.get_rank_reduction_info()
            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            energy = S ** 2
            total_energy = energy.sum()
            if total_energy > 0:
                cumulative_energy = torch.cumsum(energy / total_energy, dim=0)
                # Components whose cumulative energy is still short of the target;
                # one more component is required to actually reach it.
                n_short = int(torch.sum(cumulative_energy < (1.0 - threshold)).item())
                energy_rank = min(len(S), n_short + 1)
            else:
                # Degenerate (all-zero) spectrum: avoid 0/0 and fall back to min_rank.
                energy_rank = layer.min_rank
            target_rank = max(layer.min_rank, energy_rank)

            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))

            # Take the more conservative (larger) of the two proposals.
            new_rank = max(target_rank, ratio_rank)

            if new_rank < current_rank:
                # which_dims=None: the singular directions are ordered by energy, so
                # keeping the leading new_rank of them is sufficient.
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """Increase the rank of the selected adaptive layers.

        Parameters:
        - increment / increase_ratio / dim: forwarded to each layer's ``increase_rank``
          (called with ``mode='multimodal'``)
        - layer_ids: indices into ``self.adaptive_layers`` to consider; ``None`` or
          empty selects no layers

        Returns:
        - True if any selected layer reported a change, else False
        """
        if layer_ids is None:
            layer_ids = ()
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return the sum of ``active_dims`` over all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)

    def encode(self, x, compute_jacobian=False):
        """
        Encode each modality through its CNN encoder and the adaptive layers.

        Parameters:
        - x: list of tensors, one per modality, each with shape (B, C, H, W)
        - compute_jacobian: if True, also compute per-layer contractive penalties

        Returns:
        - compute_jacobian=False: tuple ``(h_shared, h_specific)`` where
          ``h_specific`` is a list with one tensor per modality
        - compute_jacobian=True: ``(h_split, contractive_losses)`` where
          ``h_split = [h_shared, *h_specific]`` and ``contractive_losses`` is a
          list of Python floats (detached: they carry no gradient)
        """
        h_concat = []
        for m, x_m in enumerate(x):
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)

        # Concatenate all modality latents for the shared projection.
        h = torch.cat(h_concat, dim=1)

        if not compute_jacobian:
            h_shared = self.adaptive_layers[0](h)
            specific_outputs = []
            for i, layer in enumerate(self.adaptive_layers[1:]):
                specific_outputs.append(layer(h_concat[i]))
            return (h_shared, specific_outputs)

        h_split = [self.adaptive_layers[0](h)]
        weights = [self.adaptive_layers[0].get_weights()]
        for i, layer in enumerate(self.adaptive_layers[1:]):
            h_split.append(layer(h_concat[i]))
            weights.append(layer.get_weights())

        contractive_losses = []
        for act, weight in zip(h_split, weights):
            batch_size = act.shape[0]
            # NOTE(review): gates by (act > 0), i.e. a ReLU-style derivative, and
            # assumes get_weights() is (out_features, in_features) — confirm both
            # against AdaptiveRankReducedLinear.
            derivative = (act > 0).float()
            w_squared = torch.sum(weight ** 2, dim=1)
            d_squared = derivative ** 2
            contractive_loss_layer = torch.sum(d_squared * w_squared.unsqueeze(0))
            contractive_losses.append((contractive_loss_layer / batch_size).detach().cpu().item())
        return h_split, contractive_losses

    def decode(self, h):
        """
        Decode latent representations back to the original modality shapes.

        Parameters:
        - h: tuple of (h_shared, h_specific) latent representations

        Returns:
        - list of reconstructed tensors, each with shape (B, C, H, W)
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            out = torch.cat([h_shared, h_m], dim=1)
            for layer in self.decoders[m]:
                out = layer(out)
            x_hat.append(out)
        return x_hat

    def forward(self, x):
        """Encode then decode ``x``; returns ``(x_hat, (h_shared, h_specific))``."""
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h

    def encode_modalities(self, x):
        """Return, per modality, ``cat([h_shared, h_specific[m]], dim=1)``."""
        h_shared, h_specific = self.encode(x)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]
    

class MMSimData(torch.utils.data.Dataset):
    """Dataset over one or more aligned tensors, one per modality.

    ``data`` is either a list/tuple of tensors sharing their first (sample)
    dimension, or a single tensor treated as one modality. Indexing returns a
    list with one float tensor per modality.
    """

    def __init__(self, data):
        self.data = data
        if self._is_multimodal(data):
            self.n_modalities = len(data)
            self.n_samples = data[0].shape[0]
            self.n_features = [mod.shape[1] for mod in data]
            self.total_features = sum(self.n_features)
        else:
            self.n_modalities = 1
            self.n_samples = data.shape[0]
            self.n_features = data.shape[1]
            self.total_features = data.shape[1]

    @staticmethod
    def _is_multimodal(data):
        """True when ``data`` is a sequence of per-modality tensors."""
        return isinstance(data, (list, tuple))

    def __len__(self):
        return len(self.data[0]) if self._is_multimodal(self.data) else len(self.data)

    def __getitem__(self, idx):
        if self._is_multimodal(self.data):
            return [mod[idx].float() for mod in self.data]
        # Single tensor: index the samples directly (data[0][idx] would pick a
        # feature of the first sample instead).
        return [self.data[idx].float()]
