import random
import numpy as np

class AutoencoderSearchSpace:
    """Search space of autoencoder architectures with a fixed encoder.

    The encoder is designed once at construction time so that its parameter
    count reaches ``train_samples * input_dim`` where the width and depth
    limits allow it. Sampling, mutation and crossover only vary the latent
    dimension, the decoder layers, the L1 weight and the decoder skip
    connections.

    Architectures are plain dicts with the keys ``encoder_depth``,
    ``encoder_width``, ``decoder_depth``, ``decoder_width``, ``latent_dim``,
    ``l1_weight`` and ``skip_connections``.
    """

    # Number of training samples assumed when none (or 0) is supplied.
    DEFAULT_TRAIN_SAMPLES = 1000
    # The encoder always has at least this many layers ...
    MIN_ENCODER_DEPTH = 2
    # ... and never grows beyond this many.
    MAX_ENCODER_DEPTH = 6

    def __init__(self, input_dim, max_depth, min_depth, max_width, min_width,
                 max_latent_dim, min_latent_dim, l1_range, train_samples=None):
        """
        Search space for autoencoder architectures with an overdetermined encoder

        Args:
            input_dim: Dimension of input data
            max_depth: Maximum number of layers in decoder
            min_depth: Minimum number of layers in decoder
            max_width: Maximum width of layers
            min_width: Minimum width of layers
            max_latent_dim: Maximum latent dimension
            min_latent_dim: Minimum latent dimension
            l1_range: Tuple of (min, max) for L1 regularization strength
            train_samples: Number of training samples (needed to calculate
                overdetermined encoder params); a falsy value (None or 0)
                falls back to DEFAULT_TRAIN_SAMPLES
        """
        self.input_dim = input_dim
        self.max_decoder_depth = max_depth
        self.min_decoder_depth = min_depth
        self.max_width = max_width
        self.min_width = min_width
        self.max_latent_dim = max_latent_dim
        self.min_latent_dim = min_latent_dim
        self.l1_min, self.l1_max = l1_range
        self.train_samples = train_samples or self.DEFAULT_TRAIN_SAMPLES

        # Design an overdetermined encoder
        self.encoder_architecture = self._design_overdetermined_encoder()

    @staticmethod
    def _sample_power_of_two(low, high):
        """Return a random power of two inside ``[low, high]``.

        The exponent bounds are computed with integer arithmetic (ceil for
        the lower bound, floor for the upper one), so the result never falls
        below ``low`` when ``low`` is not itself a power of two.

        Fallbacks when the interval contains no power of two:
            * ``low <= high``: a uniformly random integer from the interval.
            * ``low > high`` (empty interval): ``high``, so the upper limit
              is the one that is respected.
        """
        low, high = int(low), int(high)
        if high >= 1:
            min_exp = (max(low, 1) - 1).bit_length()  # ceil(log2(low))
            max_exp = high.bit_length() - 1           # floor(log2(high))
            if min_exp <= max_exp:
                return 2 ** random.randint(min_exp, max_exp)
        if low <= high:
            return random.randint(low, high)
        return high

    def _calculate_linear_layer_params(self, in_dim, out_dim):
        """Calculate number of parameters in a linear layer (weights + biases)"""
        return (in_dim * out_dim) + out_dim

    def _count_encoder_params(self, widths):
        """Total linear-layer parameters of an encoder ``input_dim -> widths...``."""
        total = 0
        in_dim = self.input_dim
        for out_dim in widths:
            total += self._calculate_linear_layer_params(in_dim, out_dim)
            in_dim = out_dim
        return total

    def _design_overdetermined_encoder(self):
        """Design an encoder with enough parameters to represent training data

        The target is ``train_samples * input_dim`` parameters. Only the
        layers listed in ``encoder_width`` are counted; a projection from the
        last encoder layer to the latent space is not part of the count.

        Returns:
            Dict with ``encoder_depth`` (number of layers) and
            ``encoder_width`` (list of layer widths, input side first).
        """
        required_params = self.train_samples * self.input_dim

        # First layer: twice the input size, clamped to the width limits.
        widths = [min(max(self.input_dim * 2, self.min_width), self.max_width)]

        # Reach the minimum depth by halving, never going below min_width.
        while len(widths) < self.MIN_ENCODER_DEPTH:
            widths.append(max(widths[-1] // 2, self.min_width))

        # Grow until the parameter budget is met or nothing can grow anymore.
        while self._count_encoder_params(widths) < required_params:
            narrowest = widths.index(min(widths))

            # The first layer is never widened; any other narrowest layer is
            # doubled as long as that stays within max_width.
            if narrowest > 0:
                doubled = min(widths[narrowest] * 2, self.max_width)
                if doubled > widths[narrowest]:
                    widths[narrowest] = doubled
                    continue

            # Nothing left to widen: append a new minimal layer, or give up
            # once the depth limit is reached (best effort).
            if len(widths) >= self.MAX_ENCODER_DEPTH:
                break
            widths.append(self.min_width)

        return {
            'encoder_depth': len(widths),
            'encoder_width': widths
        }

    def _sample_skip_connections(self, decoder_depth):
        """
        Sample random skip connections for decoder

        Candidates are all pairs ``(from_idx, to_idx)`` of decoder layer
        indices with ``to_idx >= from_idx + 2`` and ``to_idx < decoder_depth``;
        a random-sized random subset of them is kept.

        Args:
            decoder_depth: Number of decoder layers

        Returns:
            Sorted list of tuples (from_idx, to_idx) representing skip connections
        """
        candidates = [
            (from_idx, to_idx)
            for from_idx in range(decoder_depth)
            for to_idx in range(from_idx + 2, decoder_depth)
        ]

        # sample how many skip connections to keep, then which ones
        keep = random.randint(0, len(candidates))
        chosen = random.sample(candidates, keep)
        # sort by from_idx first and to_idx second
        chosen.sort()
        return chosen

    def _enforce_non_decreasing_decoder(self, decoder_width):
        """Ensure decoder layers are non-decreasing in width (from latent to output)

        The list is modified in place and also returned.
        """
        for i in range(1, len(decoder_width)):
            decoder_width[i] = max(decoder_width[i], decoder_width[i-1])
        return decoder_width

    def sample_random_architecture(self):
        """Sample a random architecture from the search space focusing on decoder and latent space

        The latent dimension and the decoder widths are powers of two
        whenever the configured bounds contain one; decoder widths are drawn
        from ``[latent_dim, max_width]`` and made non-decreasing.

        Returns:
            A new architecture dict; its lists are not shared with this
            search space.
        """
        # Random latent dimension
        latent_dim = self._sample_power_of_two(self.min_latent_dim, self.max_latent_dim)

        # Random decoder
        decoder_depth = random.randint(self.min_decoder_depth, self.max_decoder_depth)
        decoder_width = [
            self._sample_power_of_two(latent_dim, self.max_width)
            for _ in range(decoder_depth)
        ]

        # Ensure decoder widths are non-decreasing
        decoder_width = self._enforce_non_decreasing_decoder(decoder_width)

        # Random L1 weight
        l1_weight = random.uniform(self.l1_min, self.l1_max)

        # Sample random skip connections
        skip_connections = self._sample_skip_connections(decoder_depth)

        architecture = {
            # Fixed encoder from design (copied so callers cannot alter it)
            'encoder_depth': self.encoder_architecture['encoder_depth'],
            'encoder_width': list(self.encoder_architecture['encoder_width']),
            'decoder_depth': decoder_depth,
            'decoder_width': decoder_width,
            'latent_dim': latent_dim,
            'l1_weight': l1_weight,
            'skip_connections': skip_connections
        }
        print(f"Sampled architecture: {architecture}")

        return architecture

    def mutate_architecture(self, architecture, mutation_rate=0.3):
        """Mutate an architecture focusing on decoder and latent space

        Each of the following happens independently with probability
        ``mutation_rate``: shifting the latent dimension, adding or removing
        the last decoder layer, re-sampling each decoder width, re-sampling
        the L1 weight. Skip connections are not mutated, only filtered so
        they fit the (possibly changed) decoder depth.

        Args:
            architecture: Architecture dict to start from; it is not modified.
            mutation_rate: Probability of each individual mutation.

        Returns:
            A new, mutated architecture dict.
        """
        mutated = dict(architecture)
        # dict() is only a shallow copy: duplicate the lists as well so the
        # in-place edits below cannot leak into the parent architecture.
        for key in ('encoder_width', 'decoder_width', 'skip_connections'):
            if key in mutated:
                mutated[key] = list(mutated[key])

        # Mutation for latent dimension
        if random.random() < mutation_rate:
            delta = random.choice([-5, -2, 2, 5])
            mutated['latent_dim'] = max(self.min_latent_dim,
                                        min(self.max_latent_dim,
                                            mutated['latent_dim'] + delta))

        # Mutation for decoder depth
        if random.random() < mutation_rate:
            if random.random() < 0.5 and mutated['decoder_depth'] > self.min_decoder_depth:
                # Remove a layer
                mutated['decoder_depth'] -= 1
                mutated['decoder_width'].pop()
            elif mutated['decoder_depth'] < self.max_decoder_depth:
                # Add a layer
                mutated['decoder_depth'] += 1
                mutated['decoder_width'].append(
                    self._sample_power_of_two(self.min_width, self.max_width))

        # Mutation for decoder widths
        for i in range(mutated['decoder_depth']):
            if random.random() < mutation_rate:
                mutated['decoder_width'][i] = self._sample_power_of_two(
                    self.min_width, self.max_width)

        # Ensure decoder widths are non-decreasing
        mutated['decoder_width'] = self._enforce_non_decreasing_decoder(mutated['decoder_width'])

        # Mutation for L1 weight
        if random.random() < mutation_rate:
            mutated['l1_weight'] = random.uniform(self.l1_min, self.l1_max)

        # Skip connections are not mutated for now; only drop the ones that
        # no longer fit the decoder depth.
        mutated['skip_connections'] = self._ensure_skip_connection_limits(
            mutated['skip_connections'], mutated['decoder_depth'])

        return mutated

    def crossover(self, parent1, parent2):
        """Crossover between two architectures focusing on decoder and latent space

        The latent dimension and L1 weight are each taken from a random
        parent, the decoder (depth and widths) from one parent as a whole.
        Skip connections come from a single parent (70%) or are a random mix
        of both parents (30%), and are always filtered to fit the child's
        decoder depth.

        Args:
            parent1: First parent architecture dict (not modified).
            parent2: Second parent architecture dict (not modified).

        Returns:
            A new child architecture dict.
        """
        child = {}

        # Fixed encoder (no crossover needed)
        child['encoder_depth'] = self.encoder_architecture['encoder_depth']
        child['encoder_width'] = self.encoder_architecture['encoder_width'][:]

        # Crossover latent dimension
        child['latent_dim'] = random.choice([parent1['latent_dim'], parent2['latent_dim']])

        # Crossover decoder structure
        decoder_donor = parent1 if random.random() < 0.5 else parent2
        child['decoder_depth'] = decoder_donor['decoder_depth']
        child['decoder_width'] = decoder_donor['decoder_width'][:]

        # Ensure decoder widths are non-decreasing
        child['decoder_width'] = self._enforce_non_decreasing_decoder(child['decoder_width'])

        # Crossover L1 weight
        child['l1_weight'] = random.choice([parent1['l1_weight'], parent2['l1_weight']])

        # Crossover skip connections
        # Choose skip connections from either parent or combine them
        if random.random() < 0.7:
            # Take skip connections from one parent
            inherited = random.choice([
                parent1['skip_connections'][:],
                parent2['skip_connections'][:]
            ])
        else:
            # Combine skip connections from both parents, each connection
            # being kept with probability 0.5
            inherited = []
            for parent in (parent1, parent2):
                for conn in parent['skip_connections']:
                    if random.random() < 0.5 and conn not in inherited:
                        inherited.append(conn)

        # Whichever way they were chosen, the connections may come from a
        # parent with a deeper decoder than the child's: make them unique and
        # drop the ones that do not fit the child's decoder depth.
        child['skip_connections'] = self._ensure_skip_connection_limits(
            inherited, child['decoder_depth'])

        return child

    def _ensure_skip_connection_limits(self, skip_connections, decoder_depth):
        """Ensure skip connections do not exceed decoder depth

        Args:
            skip_connections: Iterable of (from_idx, to_idx) pairs.
            decoder_depth: Number of decoder layers.

        Returns:
            Sorted list of unique (from_idx, to_idx) tuples with
            ``from_idx < decoder_depth`` and ``to_idx <= decoder_depth``.
        """
        # Normalise to tuples so pairs stored as lists are hashable too.
        unique = {tuple(conn) for conn in skip_connections}
        # NOTE(review): to_idx == decoder_depth is accepted here although
        # _sample_skip_connections only produces to_idx < decoder_depth;
        # confirm which bound the model builder expects.
        return sorted(
            (from_idx, to_idx) for from_idx, to_idx in unique
            if from_idx < decoder_depth and to_idx <= decoder_depth
        )
