import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveRankReducedLinear(nn.Module):
    """
    Linear layer with adaptive rank reduction as described in the paper:
    "Rank-Reduced Neural Networks for Data Compression" (https://arxiv.org/pdf/2405.13980)

    The weight is stored in factorized form ``W = U @ V`` with
    ``U`` of shape ``(out_features, max_rank)`` and ``V`` of shape
    ``(max_rank, in_features)``, where ``max_rank = min(in_features, out_features)``.
    Only the first ``active_dims`` columns of ``U`` / rows of ``V`` are used in
    ``forward``; the parameter tensors themselves never change shape.

    Attributes:
        max_rank: Largest rank the factorization can represent.
        min_rank: Lower bound for ``active_dims`` (clamped to ``[1, max_rank]``).
        current_rank: Rank the layer was created with; it is not updated later.
        active_dims: Number of rank components currently used by ``forward``.
        singular_values: Buffer holding the singular values seen by the most
            recent ``reduce_rank`` / ``increase_rank`` call.
    """

    def __init__(self, in_features, out_features, initial_rank_ratio=1.0, min_rank=10, bias=True):
        """
        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            initial_rank_ratio: Fraction of ``max_rank`` that is active at start.
            min_rank: Smallest rank ``reduce_rank`` may go down to.
            bias: If False, the layer has no additive bias.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # The factorization cannot represent more than min(in, out) components.
        self.max_rank = min(in_features, out_features)

        # A minimum above max_rank is unreachable; clamping keeps active_dims
        # inside [1, max_rank] for every later rank change.
        self.min_rank = min(max(1, min_rank), self.max_rank)

        # Start with full rank or the requested fraction of it (never above max_rank).
        self.current_rank = min(self.max_rank, max(1, int(self.max_rank * initial_rank_ratio)))

        # Factor matrices are always allocated at full size.
        self.U = nn.Parameter(torch.empty(out_features, self.max_rank))
        self.V = nn.Parameter(torch.empty(self.max_rank, in_features))

        # Number of leading rank components used in forward().
        self.active_dims = self.current_rank

        # Singular values recorded at the last rank change (for monitoring).
        self.register_buffer('singular_values', torch.ones(self.max_rank))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.init_parameters()

    def init_parameters(self):
        """Xavier-initialize both factor matrices and zero the bias (if any)."""
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.V)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def _store_truncated_factors(self, U, S, V, rank):
        """Write the rank-``rank`` truncation of ``U diag(S) V^T`` into the parameters.

        ``U``, ``S``, ``V`` follow the ``torch.svd`` convention (``V`` is not
        transposed). Components at index >= ``rank`` are zeroed, the untruncated
        ``S`` is recorded in ``singular_values`` and ``active_dims`` is updated.
        """
        with torch.no_grad():
            S_kept = S.clone()
            S_kept[rank:] = 0

            # Store the untruncated spectrum for monitoring.
            self.singular_values = S.detach().clone()

            # Split each singular value evenly between the two factors.
            # clamp_min guards sqrt against a tiny negative rounding artefact.
            sqrt_S = torch.sqrt(S_kept.clamp_min(0))
            self.U.copy_(U * sqrt_S.unsqueeze(0))
            # Row-scaling V^T is the same as diag(sqrt_S) @ V^T without the dense diag.
            self.V.copy_(sqrt_S.unsqueeze(1) * V.t())

            self.active_dims = rank

    def reduce_rank(self, new_rank, dim=0, which_dims=None):
        """Reduce the effective rank by zeroing out singular components.

        Args:
            new_rank: Target rank; raised to ``min_rank`` if smaller.
            dim: Accepted for interface symmetry with ``increase_rank``; unused
                because the singular values form a 1-D vector.
            which_dims: Optional indices (tensor or sequence) of the singular
                components to keep. They are moved to the front, in the given
                order, before truncation. If None, the leading (largest)
                ``new_rank`` components are kept.

        Returns:
            True if the rank was lowered, False if the layer is already at
            ``min_rank`` or ``new_rank`` would not lower the current rank.
        """
        if self.active_dims <= self.min_rank:
            return False
        new_rank = max(int(new_rank), self.min_rank)
        if new_rank >= self.active_dims:
            # A "reduction" must never grow the rank (or push it past max_rank).
            return False

        with torch.no_grad():
            # SVD of the current full product: W = U diag(S) V^T.
            U, S, V = torch.svd(torch.matmul(self.U, self.V))

            if which_dims is not None:
                if torch.is_tensor(which_dims):
                    which_dims = which_dims.tolist()
                # De-duplicate while preserving the caller's order.
                keep = list(dict.fromkeys(int(i) for i in which_dims))
                kept = set(keep)
                dims_to_remove = [i for i in range(len(S)) if i not in kept]
                if len(which_dims) != new_rank:
                    print(f"Warning: which_dims length {len(which_dims)} does not match new_rank {new_rank}")
                if len(dims_to_remove) != (len(S) - new_rank):
                    print(f"Warning: dims_to_remove length {len(dims_to_remove)} does not match expected {len(S) - new_rank}")
                # Reorder so the to-be-dropped components come last. Singular
                # components are the COLUMNS of both U and V.
                order = keep + dims_to_remove
                U = U[:, order]
                S = S[order]
                V = V[:, order]

            self._store_truncated_factors(U, S, V, new_rank)

        return True

    def increase_rank(self, increment=None, increase_ratio=1.1, dim=0, mode='unimodal'):
        """Increase the effective rank by activating more singular components.

        Args:
            increment: Number of components to add. If None it is derived from
                ``increase_ratio`` (at least 1).
            increase_ratio: Multiplicative growth used when ``increment`` is None.
            dim: 0 (output) or 1 (input). Both behave identically because the
                singular values form a 1-D vector; other values raise.
            mode: Currently unused.

        Returns:
            True if the rank was raised, False if it is already at ``max_rank``
            or the increment does not raise it.

        Raises:
            ValueError: If ``dim`` is neither 0 nor 1 (only checked when a rank
                change would actually happen).
        """
        if increment is None:
            increment = max(1, int(self.active_dims * (increase_ratio - 1)))
        if self.active_dims >= self.max_rank:
            return False

        # Never exceed max_rank.
        new_rank = min(self.active_dims + increment, self.max_rank)
        if new_rank <= self.active_dims:
            return False

        if dim not in (0, 1):
            raise ValueError("Invalid dimension for rank increase. Use 0 for output dim, 1 for input dim")

        with torch.no_grad():
            U, S, V = torch.svd(torch.matmul(self.U, self.V))
            self._store_truncated_factors(U, S, V, new_rank)

        return True

    def get_rank_reduction_info(self):
        """Return the singular values of the full product ``U @ V`` (no gradient)."""
        with torch.no_grad():
            _, S, _ = torch.svd(torch.matmul(self.U, self.V))
            return S

    def forward(self, x):
        """Apply the rank-limited linear map to ``x``.

        Only the first ``active_dims`` rank components contribute. In addition,
        weight rows and bias entries at index >= ``active_dims`` are zeroed, so
        those output features are exactly zero. The bias entries are zeroed in
        place on the parameter itself.
        """
        rank = self.active_dims

        # Compute W = U * V on the fly from the active components only.
        W = torch.matmul(self.U[:, :rank], self.V[:rank, :])
        W[rank:] = 0  # Remove output dimensions beyond the active rank

        if self.bias is not None:
            with torch.no_grad():
                self.bias[rank:] = 0  # Zero out bias beyond the active rank

        return F.linear(x, W, self.bias)

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, current_rank={self.active_dims}'

    def get_weights(self):
        """Return ``U[:, :active_dims] @ V[:active_dims, :]`` (rows are not zeroed here)."""
        U_active = self.U[:, :self.active_dims]
        V_active = self.V[:self.active_dims, :]

        return torch.matmul(U_active, V_active)

class AdaptiveRankReducedAE(torch.nn.Module):
    """Fully-connected autoencoder whose bottleneck layer has an adaptive rank.

    The encoder is ``depth`` dense layers ending in an
    ``AdaptiveRankReducedLinear`` that maps to ``latent_dim``; the decoder is
    ``depth`` standard ``nn.Linear`` layers mapping back. All but the last
    layer on each side are followed by ReLU and, if ``dropout > 0``, Dropout.
    For ``input_dim > 100000`` a strided ``Conv1d`` is placed in front of the
    encoder and a matching ``ConvTranspose1d`` behind the decoder.
    """

    def __init__(self, input_dim, latent_dim, depth=2, width=0.5, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=10):
        """
        Args:
            input_dim: Number of features per input sample.
            latent_dim: Size of the bottleneck representation.
            depth: Number of dense layers in the encoder and in the decoder.
            width: Hidden size as a fraction of the dense input size.
            dropout: Dropout probability after hidden activations (0 disables).
            initial_rank_ratio: Forwarded to the adaptive bottleneck layer.
            min_rank: Forwarded to the adaptive bottleneck layer.
        """
        super().__init__()

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.adaptive_layers = []  # Track adaptive rank layers for rank reduction

        ff_input_dim = input_dim
        self.convolution = input_dim > 100000

        # Large input dimension handling with convolutional block
        if self.convolution:
            print(f"Input dimension {input_dim} is too large, using convolutional block to reduce it.")
            padding = 0
            kernel_size = 5
            stride = 4
            # Use a 1D convolutional layer to reduce the input dimension
            self.encoder.append(torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride, padding=padding))
            self.encoder.append(torch.nn.Flatten())
            span = input_dim + 2 * padding - kernel_size
            reduced_dim = span // stride + 1
            # Positions dropped by the strided convolution; the transposed
            # convolution must add them back to reproduce input_dim exactly.
            output_padding = span % stride
            print(f"Reduced input dimension from {input_dim} to {reduced_dim} using convolutional block.")
            ff_input_dim = reduced_dim

        hidden_dim = int(width * ff_input_dim)

        for i in range(depth):
            is_first = i == 0
            is_last = i == (depth - 1)
            # The first layer on each side consumes the side's input; every
            # other layer consumes the hidden width. This also makes depth=1
            # (a single layer that is both first and last) consistent.
            enc_in = ff_input_dim if is_first else hidden_dim
            dec_in = latent_dim if is_first else hidden_dim

            if is_last:
                # Bottleneck layer - THIS is the only place to use AdaptiveRankReducedLinear
                encoder_layer = AdaptiveRankReducedLinear(
                    enc_in, latent_dim,
                    initial_rank_ratio=initial_rank_ratio,
                    min_rank=min_rank
                )
                self.encoder.append(encoder_layer)
                self.adaptive_layers.append(encoder_layer)

                # Final decoder layer - standard linear
                self.decoder.append(nn.Linear(dec_in, ff_input_dim))
            else:
                # Hidden layers - all standard linear
                self.encoder.append(nn.Linear(enc_in, hidden_dim))
                self.decoder.append(nn.Linear(dec_in, hidden_dim))

                # Add activation
                self.encoder.append(nn.ReLU())
                self.decoder.append(nn.ReLU())

                # Add dropout if specified
                if dropout > 0.0:
                    self.encoder.append(nn.Dropout(dropout))
                    self.decoder.append(nn.Dropout(dropout))

        if self.convolution:
            # Add a final transposed convolution to upsample back to the original input dimension
            self.decoder.append(torch.nn.ConvTranspose1d(
                in_channels=1, out_channels=1, kernel_size=kernel_size,
                stride=stride, padding=padding, output_padding=output_padding))
            self.decoder.append(torch.nn.Flatten())

    def reduce_rank(self, reduction_ratio=0.8, threshold=0.01, dim=0):
        """Reduce rank of all adaptive layers based on singular value importance

        For each layer the new rank is the smallest number of leading singular
        values whose cumulative squared magnitude reaches ``1 - threshold`` of
        the total, but never below the layer's ``min_rank``. A layer is only
        touched if that rank is smaller than its current one.

        Args:
            reduction_ratio: Currently unused; kept for interface compatibility
                (default 0.8).
            threshold: Fraction of spectral energy that may be discarded.
            dim: Forwarded to the layer's ``reduce_rank`` (0=output, 1=input).

        Returns:
            True if at least one layer lowered its rank.
        """
        changes_made = False

        for layer in self.adaptive_layers:
            current_rank = layer.active_dims
            if current_rank <= layer.min_rank:
                continue  # Already at minimum rank

            # Get singular values
            S = layer.get_rank_reduction_info()

            # Calculate normalized cumulative energy
            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Components strictly below the threshold are not enough on their
            # own; one more is needed to actually reach 1 - threshold.
            below = int(torch.sum(cumulative_energy < (1.0 - threshold)).item())
            needed = min(below + 1, len(S))

            # Make sure we don't go below the minimum rank
            new_rank = max(layer.min_rank, needed)

            # Only reduce if new rank is smaller than current
            if new_rank < current_rank and layer.reduce_rank(new_rank, dim=dim):
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, dim=0):
        """Increase rank of all adaptive layers by specified increment

        Args:
            increment: Number of dimensions to add (if None, calculated from increase_ratio)
            increase_ratio: Ratio to increase rank by (default 1.1 = 10% increase)
            dim: Dimension along which to increase rank (0=output, 1=input)

        Returns:
            True if at least one layer raised its rank.
        """
        changes_made = False

        for layer in self.adaptive_layers:
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim):
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)

    def encode(self, x):
        """Run ``x`` through the encoder stack and return the latent code."""
        if self.convolution:
            # Conv1d expects (batch, channels, length); use a single channel.
            x = x.reshape(x.shape[0], 1, -1)
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, x):
        """Run a latent code through the decoder stack and return the reconstruction."""
        for layer in self.decoder:
            if self.convolution and isinstance(layer, nn.ConvTranspose1d):
                # Re-add the channel axis before the transposed convolution.
                x = x.reshape(x.shape[0], 1, -1)
            x = layer(x)
        return x

    def forward(self, x):
        """Encode then decode ``x``."""
        x = self.encode(x)
        x = self.decode(x)
        return x