import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Callable, Dict, Any
import random
from enum import Enum
import time

class ProblemType(Enum):
    """Names of the benchmark problems a CLMEA run can be configured with.

    Each member's value is the problem name as a plain string; CLMEA's
    Pareto-front plot uses that value in the figure title.
    """
    # ZDT-named problems.
    ZDT1 = "ZDT1"
    ZDT2 = "ZDT2"
    ZDT3 = "ZDT3"
    ZDT4 = "ZDT4"
    ZDT6 = "ZDT6"
    # DTLZ-named problems.
    DTLZ1 = "DTLZ1"
    DTLZ2 = "DTLZ2"
    DTLZ3 = "DTLZ3"
    # Presumably a placeholder for a user-supplied problem -- TODO confirm how
    # CLMEA.evaluate_objectives is expected to handle it.
    CUSTOM = "CUSTOM"

class SelectionMethod(Enum):
    """Parent-selection strategies accepted by CLMEA's ``selection_method`` argument."""
    # Size-2 tournament decided by Pareto dominance (see CLMEA.selection).
    TOURNAMENT = "tournament"
    # Uniform random draw with replacement (see CLMEA.selection).
    RANDOM = "random"
    # NOTE(review): check CLMEA.selection for the level of support before using.
    ROULETTE = "roulette"

class CrossoverMethod(Enum):
    """Recombination operators accepted by CLMEA's ``crossover_method`` argument."""
    # Simulated binary crossover, controlled by CLMEA.eta_c.
    SBX = "sbx"
    # NOTE(review): check CLMEA.crossover for the level of support of the two
    # members below before using them.
    UNIFORM = "uniform"
    BLEND = "blend"

class CLMEA:
    """
    Constrained Learning-Augmented Multi-objective Evolutionary Algorithm
    """
    
    def __init__(
        self,
        problem_type: ProblemType = ProblemType.ZDT1,
        population_size: int = 100,
        num_generations: int = 250,
        num_variables: int = 30,
        num_objectives: int = 2,
        crossover_rate: float = 0.9,
        mutation_rate: float = 0.1,
        selection_method: SelectionMethod = SelectionMethod.TOURNAMENT,
        crossover_method: CrossoverMethod = CrossoverMethod.SBX,
    ):
        """Store the run configuration and clear the result containers.

        Args:
            problem_type: Benchmark problem whose objectives are evaluated.
            population_size: Number of individuals kept per generation.
            num_generations: Number of iterations of the main loop.
            num_variables: Number of decision variables per individual.
            num_objectives: Number of objective values per individual.
            crossover_rate: Probability that a pair of parents is recombined.
            mutation_rate: Per-variable probability of being mutated.
            selection_method: Strategy used to pick parents.
            crossover_method: Recombination operator.
        """
        # Result containers; they stay empty until optimize() has been run.
        self.history = []
        self.pareto_front = None
        self.objectives = None
        self.population = None

        # Operator constants: eta_c is the SBX crossover parameter, eta_m the
        # polynomial mutation parameter.
        self.eta_m = 20
        self.eta_c = 20

        # Problem definition.
        self.problem_type = problem_type
        self.num_variables = num_variables
        self.num_objectives = num_objectives

        # Search budget.
        self.population_size = population_size
        self.num_generations = num_generations

        # Variation and selection settings.
        self.selection_method = selection_method
        self.crossover_method = crossover_method
        self.crossover_rate = crossover_rate
        self.mutation_rate = mutation_rate
        
    def initialize_population(self) -> np.ndarray:
        """Initialize the population with random values"""
        return np.random.random((self.population_size, self.num_variables))
    
    def evaluate_objectives(self, population: np.ndarray) -> np.ndarray:
        """Evaluate every individual on the configured benchmark problem.

        The population is evaluated with whole-array operations rather than a
        per-individual Python loop.

        Supported problems:
            * ZDT1, ZDT2, ZDT3, ZDT4, ZDT6 -- two-objective problems. Only
              columns 0 and 1 of the result are written; additional columns
              requested through ``num_objectives`` stay at zero.
            * DTLZ1, DTLZ2, DTLZ3 -- evaluated for any ``num_objectives >= 2``.
              The first ``num_objectives - 1`` variables determine the
              position on the front, the remaining ones feed the distance
              function ``g``.

        Args:
            population: Array of shape (n_individuals, n_variables). Variables
                are used as given; no rescaling to problem-specific domains is
                applied.

        Returns:
            Array of shape (n_individuals, num_objectives).

        Raises:
            NotImplementedError: If ``problem_type`` is ``ProblemType.CUSTOM``;
                no objective function is attached to it in this class.
            ValueError: If ``problem_type`` is unknown, ``num_objectives`` is
                below 2, or a DTLZ problem has fewer than
                ``num_objectives - 1`` variables.
        """
        problem = self.problem_type
        zdt_family = (ProblemType.ZDT1, ProblemType.ZDT2, ProblemType.ZDT3,
                      ProblemType.ZDT4, ProblemType.ZDT6)
        dtlz_family = (ProblemType.DTLZ1, ProblemType.DTLZ2, ProblemType.DTLZ3)

        if problem in zdt_family:
            evaluate = self._evaluate_zdt
        elif problem in dtlz_family:
            evaluate = self._evaluate_dtlz
        elif problem == ProblemType.CUSTOM:
            # Previously this fell through and returned all-zero objectives,
            # which silently turned the whole run into a no-op.
            raise NotImplementedError(
                "ProblemType.CUSTOM has no built-in objective function; "
                "override evaluate_objectives to supply one"
            )
        else:
            raise ValueError(f"Unsupported problem type: {problem!r}")

        if self.num_objectives < 2:
            raise ValueError("num_objectives must be at least 2")

        objectives = np.zeros((population.shape[0], self.num_objectives))
        if population.shape[0] > 0:
            evaluate(population, objectives)
        return objectives

    def _evaluate_zdt(self, population: np.ndarray, objectives: np.ndarray) -> None:
        """Write the two ZDT objectives into columns 0 and 1 of ``objectives``."""
        problem = self.problem_type
        first = population[:, 0]
        tail = population[:, 1:]
        # Guard the averaging against a single-variable problem; the sum over
        # an empty tail is 0, so g collapses to 1 instead of becoming NaN.
        tail_count = max(tail.shape[1], 1)

        if problem == ProblemType.ZDT4:
            g = 1 + 10 * tail.shape[1] + np.sum(
                tail ** 2 - 10 * np.cos(4 * np.pi * tail), axis=1)
        elif problem == ProblemType.ZDT6:
            g = 1 + 9 * (np.sum(tail, axis=1) / tail_count) ** 0.25
        else:
            g = 1 + 9 * np.sum(tail, axis=1) / tail_count

        if problem == ProblemType.ZDT6:
            f1 = 1 - np.exp(-4 * first) * np.sin(6 * np.pi * first) ** 6
        else:
            f1 = first

        ratio = f1 / g
        if problem in (ProblemType.ZDT1, ProblemType.ZDT4):
            f2 = g * (1 - np.sqrt(ratio))
        elif problem in (ProblemType.ZDT2, ProblemType.ZDT6):
            f2 = g * (1 - ratio ** 2)
        else:  # ZDT3: disconnected front produced by the sine term
            f2 = g * (1 - np.sqrt(ratio) - ratio * np.sin(10 * np.pi * f1))

        objectives[:, 0] = f1
        objectives[:, 1] = f2

    def _evaluate_dtlz(self, population: np.ndarray, objectives: np.ndarray) -> None:
        """Write the DTLZ1/2/3 objectives for any number of objectives >= 2."""
        problem = self.problem_type
        num_obj = self.num_objectives
        if population.shape[1] < num_obj - 1:
            raise ValueError(
                "DTLZ problems need at least num_objectives - 1 variables"
            )

        position = population[:, :num_obj - 1]
        centered = population[:, num_obj - 1:] - 0.5

        if problem == ProblemType.DTLZ2:
            g = np.sum(centered ** 2, axis=1)
        else:
            # DTLZ1 and DTLZ3 share the multimodal distance function.
            g = 100 * (centered.shape[1] + np.sum(
                centered ** 2 - np.cos(20 * np.pi * centered), axis=1))

        if problem == ProblemType.DTLZ1:
            scale = 0.5 * (1 + g)
            head, last = position, 1 - position
        else:
            angle = position * np.pi / 2
            scale = 1 + g
            head, last = np.cos(angle), np.sin(angle)

        # Objective m multiplies the first (M - 1 - m) "head" factors and, for
        # every objective but the first, one trailing "last" factor. np.prod
        # over zero columns is 1, which covers the final objective.
        for m in range(num_obj):
            head_count = num_obj - 1 - m
            value = scale * np.prod(head[:, :head_count], axis=1)
            if m > 0:
                value = value * last[:, head_count]
            objectives[:, m] = value
    
    def non_dominated_sorting(self, objectives: np.ndarray) -> List[List[int]]:
        """Fast non-dominated sorting"""
        num_individuals = objectives.shape[0]
        dominates = [[] for _ in range(num_individuals)]
        dominated_by = [0] * num_individuals
        fronts = [[]]
        
        # Calculate domination relationships
        for i in range(num_individuals):
            for j in range(i + 1, num_individuals):
                if self.dominates(objectives[i], objectives[j]):
                    dominates[i].append(j)
                    dominated_by[j] += 1
                elif self.dominates(objectives[j], objectives[i]):
                    dominates[j].append(i)
                    dominated_by[i] += 1
        
        # Find first front
        for i in range(num_individuals):
            if dominated_by[i] == 0:
                fronts[0].append(i)
        
        # Find subsequent fronts
        current_front = 0
        while fronts[current_front]:
            next_front = []
            for i in fronts[current_front]:
                for j in dominates[i]:
                    dominated_by[j] -= 1
                    if dominated_by[j] == 0:
                        next_front.append(j)
            current_front += 1
            if next_front:
                fronts.append(next_front)
        
        return fronts
    
    def dominates(self, a: np.ndarray, b: np.ndarray) -> bool:
        """Check if solution a dominates solution b"""
        # a dominates b if a is not worse than b in all objectives and better in at least one
        not_worse = np.all(a <= b)
        better = np.any(a < b)
        return not_worse and better
    
    def crowding_distance(self, objectives: np.ndarray, front: List[int]) -> np.ndarray:
        """Calculate crowding distance for individuals in a front"""
        num_individuals = len(front)
        distances = np.zeros(num_individuals)
        
        if num_individuals == 0:
            return distances
        
        # For each objective
        for obj_idx in range(self.num_objectives):
            # Get objective values for this front
            obj_values = [objectives[i, obj_idx] for i in front]
            sorted_indices = np.argsort(obj_values)
            
            # Boundary points get infinite distance
            distances[sorted_indices[0]] = float('inf')
            distances[sorted_indices[-1]] = float('inf')
            
            # Calculate distances for intermediate points
            min_val = min(obj_values)
            max_val = max(obj_values)
            
            if max_val - min_val < 1e-10:
                continue
                
            for i in range(1, num_individuals - 1):
                idx = sorted_indices[i]
                next_idx = sorted_indices[i + 1]
                prev_idx = sorted_indices[i - 1]
                
                distances[idx] += (obj_values[next_idx] - obj_values[prev_idx]) / (max_val - min_val)
        
        return distances
    
    def selection(self, population: np.ndarray, objectives: np.ndarray) -> np.ndarray:
        """Pick ``population_size`` parents according to ``selection_method``.

        * ``TOURNAMENT``: two distinct individuals are drawn at random; the
          one that dominates the other wins, and ties are broken by a coin
          flip. With a single-member population that member wins every round.
        * ``RANDOM``: parents are drawn uniformly with replacement.

        Args:
            population: Array of candidate parents, one row per individual.
            objectives: Objective values aligned row-by-row with ``population``.

        Returns:
            Array of ``population_size`` selected rows of ``population``.

        Raises:
            ValueError: If ``population`` is empty.
            NotImplementedError: If ``selection_method`` has no implementation
                here (currently ``ROULETTE``). Such methods used to fall
                through and silently return zero parents.
        """
        pool_size = len(population)
        if pool_size == 0:
            raise ValueError("Cannot select parents from an empty population")

        method = self.selection_method
        if method == SelectionMethod.TOURNAMENT:
            if pool_size == 1:
                # random.sample cannot draw two distinct members from one.
                winners = [0] * self.population_size
            else:
                winners = []
                for _ in range(self.population_size):
                    first, second = random.sample(range(pool_size), 2)
                    if self.dominates(objectives[first], objectives[second]):
                        winners.append(first)
                    elif self.dominates(objectives[second], objectives[first]):
                        winners.append(second)
                    else:
                        # Mutually non-dominated: neither is preferable.
                        winners.append(random.choice((first, second)))
        elif method == SelectionMethod.RANDOM:
            winners = random.choices(range(pool_size), k=self.population_size)
        else:
            raise NotImplementedError(
                f"Selection method {method!r} is not implemented"
            )

        return population[winners]
    
    def crossover(self, parents: np.ndarray) -> np.ndarray:
        """Recombine consecutive pairs of parents into offspring.

        Parents are paired as (0, 1), (2, 3), ...; each pair is recombined
        with probability ``crossover_rate`` using ``crossover_method`` and is
        otherwise copied unchanged. With an odd number of parents the last one
        is copied unchanged. All produced values are clipped to [0, 1].

        Args:
            parents: Array of shape (n_parents, num_variables).

        Returns:
            Array with the same shape and dtype as ``parents``.

        Raises:
            ValueError: If ``crossover_method`` is not a known method.
                Previously any method other than SBX returned an all-zero
                offspring array without warning.
        """
        method = self.crossover_method
        if method == CrossoverMethod.SBX:
            recombine = self._sbx_pair
        elif method == CrossoverMethod.UNIFORM:
            recombine = self._uniform_pair
        elif method == CrossoverMethod.BLEND:
            recombine = self._blend_pair
        else:
            raise ValueError(f"Unsupported crossover method: {method!r}")

        # Start from a copy so every unmated parent is passed through as-is.
        offspring = np.copy(parents)
        for i in range(0, len(parents) - 1, 2):
            if random.random() < self.crossover_rate:
                child1, child2 = recombine(parents[i], parents[i + 1])
                offspring[i] = child1
                offspring[i + 1] = child2

        return offspring

    def _sbx_pair(self, parent1: np.ndarray, parent2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Simulated binary crossover; each variable is recombined with probability 0.5."""
        child1, child2 = np.copy(parent1), np.copy(parent2)
        exponent = 1 / (self.eta_c + 1)

        for j in range(self.num_variables):
            if random.random() <= 0.5:
                continue

            u = random.random()
            if u <= 0.5:
                beta = (2 * u) ** exponent
            else:
                beta = (1 / (2 * (1 - u))) ** exponent

            # The two children are placed symmetrically around the parents' mean.
            child1[j] = np.clip(
                0.5 * ((1 + beta) * parent1[j] + (1 - beta) * parent2[j]), 0, 1)
            child2[j] = np.clip(
                0.5 * ((1 - beta) * parent1[j] + (1 + beta) * parent2[j]), 0, 1)

        return child1, child2

    def _uniform_pair(self, parent1: np.ndarray, parent2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Uniform crossover; each variable is swapped between the parents with probability 0.5."""
        child1, child2 = np.copy(parent1), np.copy(parent2)

        for j in range(self.num_variables):
            if random.random() < 0.5:
                child1[j] = np.clip(parent2[j], 0, 1)
                child2[j] = np.clip(parent1[j], 0, 1)

        return child1, child2

    def _blend_pair(self, parent1: np.ndarray, parent2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Blend (BLX-alpha, alpha = 0.5) crossover.

        Each child variable is drawn uniformly from the parents' interval
        widened by ``alpha`` times its length on both sides.
        """
        alpha = 0.5
        child1, child2 = np.copy(parent1), np.copy(parent2)

        for j in range(self.num_variables):
            low = min(parent1[j], parent2[j])
            high = max(parent1[j], parent2[j])
            margin = alpha * (high - low)
            child1[j] = np.clip(random.uniform(low - margin, high + margin), 0, 1)
            child2[j] = np.clip(random.uniform(low - margin, high + margin), 0, 1)

        return child1, child2
    
    def mutation(self, population: np.ndarray) -> np.ndarray:
        """Perform polynomial mutation"""
        mutated = np.copy(population)
        
        for i in range(len(mutated)):
            for j in range(self.num_variables):
                if random.random() < self.mutation_rate:
                    u = random.random()
                    if u <= 0.5:
                        delta = (2 * u) ** (1 / (self.eta_m + 1)) - 1
                    else:
                        delta = 1 - (2 * (1 - u)) ** (1 / (self.eta_m + 1))
                    
                    mutated[i, j] += delta
                    mutated[i, j] = np.clip(mutated[i, j], 0, 1)
        
        return mutated
    
    def environmental_selection(self, combined_pop: np.ndarray, 
                               combined_obj: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Select next generation using NSGA-II environmental selection"""
        fronts = self.non_dominated_sorting(combined_obj)
        selected_pop = []
        selected_obj = []
        
        current_front = 0
        while len(selected_pop) + len(fronts[current_front]) <= self.population_size:
            selected_pop.extend(combined_pop[fronts[current_front]])
            selected_obj.extend(combined_obj[fronts[current_front]])
            current_front += 1
        
        # If we need more individuals, use crowding distance
        if len(selected_pop) < self.population_size:
            last_front = fronts[current_front]
            crowding_distances = self.crowding_distance(combined_obj, last_front)
            
            # Sort by crowding distance (descending)
            sorted_indices = np.argsort(crowding_distances)[::-1]
            needed = self.population_size - len(selected_pop)
            
            for i in range(needed):
                if i < len(sorted_indices):
                    idx = last_front[sorted_indices[i]]
                    selected_pop.append(combined_pop[idx])
                    selected_obj.append(combined_obj[idx])
        
        return np.array(selected_pop), np.array(selected_obj)
    
    def optimize(self) -> Dict[str, Any]:
        """Run the evolutionary loop and return the final state.

        Each generation selects parents, applies crossover and mutation,
        evaluates the offspring, merges them with the current population and
        keeps the best ``population_size`` individuals. A snapshot of the
        population is stored in ``history`` every 10th generation and a
        progress line is printed every 50th.

        ``history`` is cleared at the start of every call, so a second run on
        the same instance no longer appends to the snapshots of the first.

        Returns:
            Dictionary with the keys ``'population'``, ``'objectives'``,
            ``'pareto_front'`` (decision vectors of the first non-dominated
            front), ``'pareto_front_objectives'`` and ``'history'``.
        """
        # Drop snapshots left over from an earlier run on this instance.
        self.history = []

        # Initialize population
        self.population = self.initialize_population()
        self.objectives = self.evaluate_objectives(self.population)

        for generation in range(self.num_generations):
            # Variation: selection -> crossover -> mutation
            parents = self.selection(self.population, self.objectives)
            offspring = self.mutation(self.crossover(parents))
            offspring_obj = self.evaluate_objectives(offspring)

            # Elitist survival: parents compete with their offspring.
            combined_pop = np.vstack([self.population, offspring])
            combined_obj = np.vstack([self.objectives, offspring_obj])
            self.population, self.objectives = self.environmental_selection(
                combined_pop, combined_obj)

            # Store history
            if generation % 10 == 0:
                self.history.append({
                    'generation': generation,
                    'population': np.copy(self.population),
                    'objectives': np.copy(self.objectives)
                })

            # Print progress
            if generation % 50 == 0:
                print(f"Generation {generation}/{self.num_generations}")

        # Extract Pareto front
        first_front = self.non_dominated_sorting(self.objectives)[0]
        self.pareto_front = self.population[first_front]
        pareto_front_obj = self.objectives[first_front]

        return {
            'population': self.population,
            'objectives': self.objectives,
            'pareto_front': self.pareto_front,
            'pareto_front_objectives': pareto_front_obj,
            'history': self.history
        }
    
    def plot_pareto_front(self, save_path: str = None):
        """Scatter-plot the final population and highlight its first front.

        Only the first two objectives are drawn. If no objectives are
        available yet a hint is printed and nothing is plotted.

        Args:
            save_path: When given, the figure is also written to this path
                before being shown.
        """
        if self.objectives is None:
            print("Please run optimization first")
            return

        figure, axes = plt.subplots(figsize=(10, 8))

        # Background layer: every member of the final population.
        everyone = self.objectives
        axes.scatter(everyone[:, 0], everyone[:, 1],
                     alpha=0.5, label='All Solutions', color='blue')

        # Foreground layer: the non-dominated subset.
        first_front = self.non_dominated_sorting(everyone)[0]
        elite = everyone[first_front]
        axes.scatter(elite[:, 0], elite[:, 1],
                     alpha=0.8, label='Pareto Front', color='red', s=50)

        axes.set_xlabel('Objective 1')
        axes.set_ylabel('Objective 2')
        axes.set_title(f'Pareto Front - {self.problem_type.value}')
        axes.legend()
        axes.grid(True, alpha=0.3)

        if save_path:
            figure.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()
    
    def plot_convergence(self, save_path: str = None):
        """Plot hypervolume and diversity over the recorded history.

        The left panel shows the hypervolume of each snapshot and is only
        filled when every snapshot has exactly two objectives. All snapshots
        are measured against one shared reference point (the worst value seen
        in any snapshot plus a 10% margin of the observed range), so the
        values are comparable across generations. The right panel shows the
        mean per-objective standard deviation of each snapshot.

        If no history has been recorded a hint is printed and nothing is
        plotted.

        Args:
            save_path: When given, the figure is also written to this path
                before being shown.
        """
        if not self.history:
            print("No history data available")
            return

        generations = [snapshot['generation'] for snapshot in self.history]
        snapshots = [snapshot['objectives'] for snapshot in self.history]

        hypervolume = []
        if all(obj.ndim == 2 and obj.shape[1] == 2 and len(obj) > 0
               for obj in snapshots):
            stacked = np.vstack(snapshots)
            worst = stacked.max(axis=0)
            margin = 0.1 * (worst - stacked.min(axis=0))
            # A constant objective has zero range; any positive margin keeps
            # the reference point strictly worse than every sample.
            margin[margin <= 0.0] = 1.0
            ref_point = worst + margin
            hypervolume = [self._hypervolume_2d(obj, ref_point) for obj in snapshots]

        # Diversity metric: average spread of the objective values.
        spread = [np.mean(np.std(obj, axis=0)) for obj in snapshots]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        if hypervolume:
            ax1.plot(generations, hypervolume)
            ax1.set_xlabel('Generation')
            ax1.set_ylabel('Hypervolume (approx)')
            ax1.set_title('Hypervolume Convergence')
            ax1.grid(True, alpha=0.3)

        ax2.plot(generations, spread)
        ax2.set_xlabel('Generation')
        ax2.set_ylabel('Spread')
        ax2.set_title('Diversity Metric')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

    @staticmethod
    def _hypervolume_2d(points: np.ndarray, ref_point: np.ndarray) -> float:
        """Area dominated by ``points`` and bounded by ``ref_point`` (minimisation).

        Sweeps the points by ascending first objective; a point adds area only
        if it improves on the best second objective seen so far, so dominated
        points contribute nothing.
        """
        area = 0.0
        best_second = float(ref_point[1])
        for first, second in sorted(points.tolist()):
            if first >= ref_point[0]:
                break  # this and all later points lie outside the reference box
            if second < best_second:
                area += (ref_point[0] - first) * (best_second - second)
                best_second = second
        return area

# Example usage and testing
if __name__ == "__main__":
    # Demo run: 30-variable ZDT1, 100 individuals, 100 generations.
    settings = {
        'problem_type': ProblemType.ZDT1,
        'population_size': 100,
        'num_generations': 100,
        'num_variables': 30,
        'crossover_rate': 0.9,
        'mutation_rate': 0.1,
    }
    algorithm = CLMEA(**settings)

    print("Running CLMEA optimization...")
    started = time.time()
    results = algorithm.optimize()
    elapsed = time.time() - started

    # Summarise the outcome on the console.
    final_count = results['population'].shape[0]
    front_count = results['pareto_front'].shape[0]
    print(f"Optimization completed in {elapsed:.2f} seconds")
    print(f"Final population size: {final_count}")
    print(f"Pareto front size: {front_count}")

    # Show the final front first, then the convergence curves.
    algorithm.plot_pareto_front()
    algorithm.plot_convergence()