import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Autoencoder(nn.Module):
    def __init__(self, input_dim, architecture_config):
        """
        Configurable fully-connected autoencoder with optional decoder skip connections.

        Encoder and decoder depths/widths are configured independently, so the
        architecture may be asymmetric.

        Args:
            input_dim: Dimension of input data (size of the last axis of ``x``).
            architecture_config: Dict containing:
                - encoder_depth: Number of hidden layers in the encoder
                - encoder_width: List of neurons per encoder layer (needs at least
                  ``encoder_depth`` entries)
                - decoder_depth: Number of hidden layers in the decoder
                - decoder_width: List of neurons per decoder layer (needs at least
                  ``decoder_depth`` entries)
                - latent_dim: Dimension of latent space
                - l1_weight: L1 regularization strength
                - skip_connections: List of ``(from_idx, to_idx)`` tuples for decoder
                  skip connections. ``from_idx`` is ``-1`` for the latent code, or the
                  index of an earlier decoder layer (``0 <= from_idx < to_idx``).
                  ``to_idx`` is the decoder layer whose pre-activation receives a
                  learned linear projection of the source. Connections whose
                  ``to_idx`` is not in ``range(decoder_depth)`` are ignored.

        Raises:
            ValueError: If a skip connection that targets a decoder layer has a
                source that is neither ``-1`` nor an earlier decoder layer.
        """
        super().__init__()

        self.input_dim = input_dim
        self.encoder_depth = architecture_config['encoder_depth']
        self.encoder_width = architecture_config['encoder_width']
        self.decoder_depth = architecture_config['decoder_depth']
        self.decoder_width = architecture_config['decoder_width']
        self.latent_dim = architecture_config['latent_dim']
        self.l1_weight = architecture_config['l1_weight']
        self.skip_connections = architecture_config['skip_connections']

        # Build encoder
        self.encoder_layers = nn.ModuleList()
        prev_dim = input_dim
        for i in range(self.encoder_depth):
            self.encoder_layers.append(nn.Linear(prev_dim, self.encoder_width[i]))
            prev_dim = self.encoder_width[i]

        self.latent_layer = nn.Linear(prev_dim, self.latent_dim)

        # Build decoder
        self.decoder_layers = nn.ModuleList()
        prev_dim = self.latent_dim
        for i in range(self.decoder_depth):
            self.decoder_layers.append(nn.Linear(prev_dim, self.decoder_width[i]))
            prev_dim = self.decoder_width[i]
        # Width of the tensor that actually reaches the output layer. This equals
        # decoder_width[-1] in the usual case, but stays correct when
        # decoder_depth == 0 or decoder_width has extra entries.
        decoder_out_dim = prev_dim

        # Build skip projections. They are grouped by target layer, then by their
        # order in skip_connections, so parameter / state_dict ordering is stable.
        # _skip_plan[i] holds (source_idx, skip_layer_idx) pairs feeding decoder
        # layer i.
        self.skip_layers = nn.ModuleList()
        self._skip_plan = [[] for _ in range(self.decoder_depth)]
        for to_idx in range(self.decoder_depth):
            for conn in self.skip_connections:
                if conn[1] != to_idx:
                    continue
                from_idx = conn[0]
                if from_idx == -1:
                    src_dim = self.latent_dim
                elif 0 <= from_idx < to_idx:
                    src_dim = self.decoder_width[from_idx]
                else:
                    raise ValueError(
                        f"Invalid skip connection {tuple(conn)}: source must be -1 "
                        f"(latent) or an earlier decoder layer index (< {to_idx})."
                    )
                self._skip_plan[to_idx].append((from_idx, len(self.skip_layers)))
                self.skip_layers.append(nn.Linear(src_dim, self.decoder_width[to_idx]))

        # Output layer
        self.output_layer = nn.Linear(decoder_out_dim, input_dim)

        self.n_params = self._count_parameters()

    def _count_parameters(self):
        """Return the number of trainable parameters in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _l1_penalty(self):
        """Return the L1 norm of all decoder-side weight matrices, scaled.

        Covers the decoder, skip and output layers. Biases and encoder weights
        are excluded. The sum is scaled by ``l1_weight / n_params``.
        """
        weights = [layer.weight for layer in self.decoder_layers]
        weights += [layer.weight for layer in self.skip_layers]
        weights.append(self.output_layer.weight)
        total = sum(torch.sum(torch.abs(w)) for w in weights)
        return total * self.l1_weight / self.n_params

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Input tensor whose last dimension is ``input_dim``. Leading
                (batch) dimensions, or none at all, are passed through unchanged.

        Returns:
            Tuple ``(output, l1_loss)``:
            - ``output`` is the reconstruction, with the same shape as ``x``.
            - ``l1_loss`` is the scaled L1 penalty on the decoder-side weights.
        """
        # Encoding
        for layer in self.encoder_layers:
            x = F.relu(layer(x))

        # Latent representation
        z = self.latent_layer(x)

        # The regularizer acts on decoder weights rather than on the latents.
        l1_loss = self._l1_penalty()

        h = z
        activations = []
        for layer, plan in zip(self.decoder_layers, self._skip_plan):
            pre = layer(h)
            if plan:
                skips = [
                    self.skip_layers[k](z if src == -1 else activations[src])
                    for src, k in plan
                ]
                # Sum across skip sources only (new leading axis). This keeps the
                # per-sample shape instead of mixing samples across the batch.
                pre = pre + torch.stack(skips, dim=0).sum(dim=0)
            h = F.relu(pre)
            activations.append(h)

        output = self.output_layer(h)

        return output, l1_loss
