import numpy as np
import random
from src.models.autoencoder import Autoencoder
from src.utils.training import train_autoencoder, evaluate_autoencoder
from src.utils.objective_functions import calculate_fitness_score
import torch
import copy

class GeneticAlgorithm:
    def __init__(self, search_space, train_data, val_data, 
                 population_size=20, num_generations=20,
                 fitness_type="mdl", precision_bits=7):
        """
        Set up a genetic-algorithm driver for neural architecture search.

        Args:
            search_space: Object that supplies candidate architectures; this
                class calls its sample_random_architecture(),
                mutate_architecture() and _ensure_skip_connection_limits()
            train_data: Training dataset. Must expose a ``features`` attribute
                whose ``shape`` is read as (n_samples, input_dim)
            val_data: Validation dataset, handed to evaluate_autoencoder
            population_size: Number of architectures kept in each generation
            num_generations: Number of generations to evolve
            fitness_type: Fitness selector forwarded to
                calculate_fitness_score ('negative_loss' or 'mdl')
            precision_bits: Bits per parameter, forwarded to
                calculate_fitness_score for the MDL calculation
        """
        feature_shape = train_data.features.shape

        # Search configuration
        self.search_space = search_space
        self.population_size = population_size
        self.num_generations = num_generations
        self.fitness_type = fitness_type
        self.precision_bits = precision_bits

        # Datasets and the dimensions derived from the training features
        self.train_data = train_data
        self.val_data = val_data
        self.input_dim = feature_shape[1]
        self.n_samples = feature_shape[0]

        # Evolving state; populated while search() runs
        self.population = []
        self.fitness_scores = []
        self.history = []
        
    def _initialize_population(self):
        """Replace the population with population_size freshly sampled architectures"""
        fresh_population = []
        for _ in range(self.population_size):
            fresh_population.append(self.search_space.sample_random_architecture())
        self.population = fresh_population
        
    def _evaluate_population(self, num_epochs_per_arch=5, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Train each candidate briefly and record its fitness.

        Resets self.fitness_scores and appends one number per entry of
        self.population, in population order (higher is better). A progress
        line and a result line are printed for every candidate.

        Args:
            num_epochs_per_arch: Epoch count passed to train_autoencoder
            device: Device the model is moved to; also passed to the training
                and evaluation helpers
        """
        self.fitness_scores = []
        show_mdl = self.fitness_type == "mdl"

        for position, arch_config in enumerate(self.population, start=1):
            print(f"Evaluating architecture {position}/{self.population_size}")

            # Build the candidate, fit it, then measure it on the validation data
            model = Autoencoder(self.input_dim, arch_config).to(device)
            train_loss = train_autoencoder(model, self.train_data, num_epochs=num_epochs_per_arch, device=device)
            val_loss = evaluate_autoencoder(model, self.val_data, device=device)

            # Nn - Nm (Schuster and Krogh 2021)
            dims_dropped = self.input_dim - arch_config['latent_dim']
            decoder_min_capacity = self.n_samples * dims_dropped

            # Fitness score: higher is better
            fitness = calculate_fitness_score(
                model,
                val_loss,
                decoder_min_capacity,
                fitness_type=self.fitness_type,
                precision_bits=self.precision_bits,
            )
            self.fitness_scores.append(fitness.cpu().item())

            loss_report = f"Architecture {position}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}"
            if not show_mdl:
                print(loss_report)
                continue

            # MDL runs additionally report the trainable parameter count and the
            # negated fitness, which is printed as the MDL score
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"{loss_report}, Parameters: {trainable_params}, MDL Score: {-fitness:.2f}")
            
    def _select_parents(self, k=3):
        """
        Tournament selection of parents.

        Runs one tournament per population slot: ``k`` distinct individuals are
        drawn uniformly at random and the one with the highest fitness wins.
        The same individual may win several tournaments, so the result can hold
        repeated references to the same architecture.

        Args:
            k: Requested tournament size. Clamped to [1, population_size],
                because sampling without replacement cannot draw more
                individuals than the population holds

        Returns:
            List of population_size architectures taken (not copied) from
            self.population
        """
        # np.random.choice(..., replace=False) raises ValueError when asked for
        # more samples than exist, so shrink oversized tournaments instead of
        # crashing on small populations.
        tournament_size = max(1, min(k, self.population_size))

        parents = []
        for _ in range(self.population_size):
            # Draw the contenders for this tournament
            tournament_idx = np.random.choice(self.population_size, tournament_size, replace=False)
            tournament_fitness = [self.fitness_scores[i] for i in tournament_idx]

            # The fittest contender becomes a parent
            winner_idx = tournament_idx[np.argmax(tournament_fitness)]
            parents.append(self.population[winner_idx])

        return parents
    
    def _crossover(self, parent1, parent2):
        """
        Perform uniform crossover between two parent architectures.

        Every key that does not contain 'depth' is inherited from one parent:
        chosen at random when both have it, otherwise from the parent that
        has it. 'encoder_depth' and 'decoder_depth' are then recomputed from
        the inherited width lists. If either parent has skip connections, the
        child receives the de-duplicated union of both parents' connections,
        passed through search_space._ensure_skip_connection_limits for the
        child's decoder depth.

        Args:
            parent1: Architecture dict
            parent2: Architecture dict

        Returns:
            New architecture dict. Inherited values are deep-copied, so
            editing the child does not touch either parent.

        Raises:
            KeyError: If neither parent provides 'encoder_width' or
                'decoder_width'
        """
        child = {}

        # Walk the keys in first-seen order. Iterating a set of str keys would
        # follow the per-process hash seed, so the pairing of random draws with
        # keys (and therefore the child) would differ between seeded runs.
        for key in dict.fromkeys([*parent1, *parent2]):
            if 'depth' in key:
                # Depths are derived from the inherited widths below
                continue
            if key in parent1 and key in parent2:
                donor = parent1 if random.random() < 0.5 else parent2
            elif key in parent1:
                donor = parent1
            else:
                donor = parent2
            # Copy so the child never aliases a parent's mutable lists
            child[key] = copy.deepcopy(donor[key])

        child['encoder_depth'] = len(child['encoder_width'])
        child['decoder_depth'] = len(child['decoder_width'])

        skips1 = parent1.get('skip_connections', [])
        skips2 = parent2.get('skip_connections', [])
        if len(skips1) > 0 or len(skips2) > 0:
            # Mix skip connections from both parents: union as (from, to)
            # tuples, duplicates removed, first-seen order kept
            merged_skips = list(dict.fromkeys(
                (from_idx, to_idx) for from_idx, to_idx in [*skips1, *skips2]
            ))
            # Hand the merged list (not just the one inherited by the key loop)
            # to the search space so it can enforce its limits for this depth
            child['skip_connections'] = self.search_space._ensure_skip_connection_limits(
                merged_skips, child['decoder_depth']
            )

        return child
    
    def search(self, num_epochs_per_arch=5, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Execute genetic neural architecture search.

        Every call is an independent run: the population is re-sampled and the
        history restarted. Each generation is evaluated, its best individual
        recorded, and (except after the last generation) replaced by children
        produced through tournament selection, crossover and mutation.

        Args:
            num_epochs_per_arch: Training epochs per candidate, forwarded to
                _evaluate_population
            device: Device string forwarded to _evaluate_population

        Returns:
            Tuple (best_arch, history). best_arch is the highest-fitness
            architecture seen in any generation (None when num_generations is
            0). history holds one dict per generation with the keys
            'generation', 'best_architecture', 'best_val_loss', 'population'
            and 'fitness_scores'. 'best_val_loss' stores the negated best
            fitness; for fitness_type 'mdl' that is the value
            _evaluate_population prints as the MDL score.
        """
        # The population is re-sampled below, so the history must restart as
        # well; otherwise a second call would return the previous run's
        # generations mixed with the new ones. Rebinding (instead of clearing)
        # leaves a history list returned by an earlier call untouched.
        self.history = []

        # Initialize population
        self._initialize_population()

        best_arch = None
        best_fitness = float('-inf')  # Higher fitness (negative loss) is better

        def breed(first, second):
            # Parents are copied before crossover and the result is copied
            # again after mutation, so no child shares state with a parent,
            # with a sibling, or with whatever the mutation step returns.
            offspring = copy.deepcopy(self._crossover(copy.deepcopy(first), copy.deepcopy(second)))
            return copy.deepcopy(self.search_space.mutate_architecture(offspring))

        for generation in range(self.num_generations):
            print(f"Generation {generation+1}/{self.num_generations}")

            # Evaluate current population
            self._evaluate_population(num_epochs_per_arch, device)

            # Record best architecture in this generation
            gen_best_idx = np.argmax(self.fitness_scores)
            gen_best_arch = self.population[gen_best_idx]
            gen_best_fitness = self.fitness_scores[gen_best_idx]
            gen_best_val_loss = -gen_best_fitness

            print(f"Generation {generation+1} best: Validation Loss = {gen_best_val_loss:.6f}")

            # Update overall best
            if gen_best_fitness > best_fitness:
                best_fitness = gen_best_fitness
                best_arch = gen_best_arch
                print(f"New best architecture found! Validation loss: {gen_best_val_loss:.6f}")

            # Record history (list copies, so later generations do not alter
            # what was recorded here)
            self.history.append({
                'generation': generation,
                'best_architecture': gen_best_arch,
                'best_val_loss': gen_best_val_loss,
                'population': self.population.copy(),
                'fitness_scores': self.fitness_scores.copy()
            })

            # The last generation is only evaluated; breeding a successor that
            # is never scored would be wasted work
            if generation == self.num_generations - 1:
                break

            # Create new generation through selection, crossover, and mutation
            parents = self._select_parents()
            new_population = []

            for i in range(0, self.population_size, 2):
                if i + 1 < self.population_size:
                    # Each pair yields two children, one per parent ordering
                    child1 = breed(parents[i], parents[i+1])
                    child2 = breed(parents[i+1], parents[i])

                    new_population.extend([child1, child2])
                    print(f"New architectures: {child1}, {child2}")
                else:
                    # If odd population size, just mutate the last parent
                    child = self.search_space.mutate_architecture(copy.deepcopy(parents[i]))
                    new_population.append(child)
                    print(f"New architecture: {child}")

            self.population = new_population

        print("Search completed!")
        best_val_loss = -best_fitness
        print(f"Best architecture: {best_arch}")
        print(f"Best validation loss: {best_val_loss:.6f}")

        return best_arch, self.history
