# -*- coding: utf-8 -*-
import numpy as np
import geatpy as ea
from typing import List, Dict, Callable, Optional, Any
import hvwfg
from copy import deepcopy


class NSGA2_TSP_Geatpy(ea.MoeaAlgorithm):
    """
    NSGA-II algorithm implementation based on Geatpy, specifically for the
    Traveling Salesman Problem (TSP).

    Chromosomes use permutation encoding: each row is a city visiting order.
    Every callable passed as a keyword argument is registered as a variation
    operator. In each generation the operators are applied in keyword order.
    Each one receives a copy of the current offspring chromosome matrix and
    returns the replacement matrix. Non-callable keyword arguments are only
    forwarded to the Geatpy base class.
    """

    def __init__(self, problem, population, **kwargs):
        """
        Args:
            problem: Geatpy problem. Must expose ``M`` (number of objectives),
                ``maxormins`` and ``ref_points``.
            population: Geatpy population to evolve.
            **kwargs: Forwarded to ``ea.MoeaAlgorithm``. Callable values are
                additionally used as variation operators (see class docstring).
        """
        # Call parent class constructor
        super().__init__(problem, population, **kwargs)

        self.name = "NSGA2-TSP"
        self.problem = problem
        self.population = population

        # Reference point for hypervolume computation, supplied by the problem
        self.ref_point = self.problem.ref_points

        # Non-dominated sorting method selection
        if problem.M < 10:
            self.ndSort = ea.ndsortESS  # Use ESS when number of objectives is small
        else:
            self.ndSort = ea.ndsortTNS  # Use TNS when number of objectives is large

        self.selFunc = "tour"  # Tournament selection

        # Register user-supplied operators. Only callables qualify: any other
        # kwarg is a setting for the base class. Wrapping such a setting would
        # make ``operator.do(...)`` fail in every generation.
        self.Opti_operators = {}
        self.kwargs = kwargs
        for operator_name, operator in kwargs.items():
            if callable(operator):
                self.Opti_operators[operator_name] = TSPOperatorWrapper(operator)

        print("Algorithm initialization successful")

    def _setup_crossover_operator(self):
        """Return the configured TSP crossover operator, or default OX crossover.

        ``tsp_crossover`` is not assigned anywhere in this class, so it is
        read defensively with ``getattr``.
        """
        tsp_crossover = getattr(self, "tsp_crossover", None)
        if tsp_crossover:
            print("Set tsp_crossover operator")
            return TSPCrossoverWrapper(tsp_crossover)
        print("Default tsp_crossover operator")
        return TSP_OX_Crossover()  # Order crossover operator

    def _setup_mutation_operator(self):
        """Return the configured TSP mutation operator, or default swap mutation.

        ``tsp_mutation`` is not assigned anywhere in this class, so it is read
        defensively with ``getattr``.
        NOTE(review): ``TSP_Swap_Mutation.do`` takes ``(Encoding, OldChrom, Field)``,
        while ``TSPMutationWrapper.do`` takes a single chromosome matrix. Callers
        must account for the differing signatures.
        """
        tsp_mutation = getattr(self, "tsp_mutation", None)
        if tsp_mutation:
            print("Set tsp_mutation operator")
            return TSPMutationWrapper(tsp_mutation)
        print("Default tsp_mutation operator")
        return TSP_Swap_Mutation()  # Swap mutation operator

    def run(self, prophetPop=None):
        """
        Execute the evolutionary loop.

        Args:
            prophetPop: Optional population of prior-knowledge individuals. They
                are placed in front of the random initial population, and the
                result is truncated to the configured population size.

        Returns:
            ``(self.finishing(population), True)``.
        """
        # ========================== Initialization Configuration ===========================
        population = self.population
        NIND = population.sizes
        self.initialization()  # Initialize some dynamic parameters of the algorithm class

        # =========================== Prepare for Evolution =============================
        # Initialize population chromosome matrix (permutation encoding)
        population.initChrom()

        # Insert prior knowledge
        if prophetPop is not None:
            population = (prophetPop + population)[:NIND]

        # Calculate objective function values of the population
        self.call_aimFunc(population)

        # Initial non-dominated sorting: fitness is the reciprocal of the front level
        [levels, criLevel] = self.ndSort(
            population.ObjV, NIND, None, population.CV, self.problem.maxormins
        )
        population.FitnV = (1 / levels).reshape(-1, 1)

        # =========================== Start Evolution =============================
        while not self.terminated(population):
            try:
                # Select individuals to participate in evolution
                offspring = population[
                    ea.selecting(self.selFunc, population.FitnV, NIND)
                ]
                # Execute all operators in registration order
                for operator_name, operator in self.Opti_operators.items():
                    parent_pops = deepcopy(offspring.Chrom)
                    try:
                        sub_pops = operator.do(parent_pops)
                    except Exception as e:
                        print(
                            f"Operator exception_{operator_name}_{self.currentGen}_{e}"
                        )
                        # Skip this operator and keep the current chromosomes.
                        # This avoids reusing a stale or undefined ``sub_pops``.
                        continue
                    offspring.Chrom = sub_pops

                # Calculate objective function values of evolved individuals
                self.call_aimFunc(offspring)

                # Reinsertion to generate new generation population
                population = self.reinsertion(population, offspring, NIND)

                self.pop = population  # Attach current population for callback access
                # NOTE(review): ``callback`` is not defined in this class; presumably
                # the caller assigns it. Any exception here is caught below.
                self.callback(self)

            except Exception as e:
                print(f"Warning: Error in generation {self.currentGen}: {e}")
                # Try to continue running when error occurs
                continue

        return self.finishing(population), True

    def reinsertion(self, population, offspring, NUM):
        """
        Select the next generation from the combined parent and offspring pool.

        Individuals are ranked by non-dominated level (lower is better), then
        by crowding distance (larger is better). The ``NUM`` best are kept.

        Args:
            population: Parent population.
            offspring: Evaluated offspring population.
            NUM: Number of individuals to retain.

        Returns:
            The retained sub-population.
        """
        # Merge parent and offspring generations
        population = population + offspring

        # Non-dominated sorting
        [levels, criLevel] = self.ndSort(
            population.ObjV, NUM, None, population.CV, self.problem.maxormins
        )

        # Calculate crowding distance
        dis = ea.crowdis(population.ObjV, levels)

        # Fitness = rank in (worst -> best) order, so a larger FitnV means a
        # better individual. lexsort uses its last key (-levels) as the primary key.
        population.FitnV[:, 0] = np.argsort(
            np.lexsort(np.array([dis, -levels])), kind="mergesort"
        )

        # Select individuals to retain for next generation
        chooseFlag = ea.selecting("dup", population.FitnV, NUM)

        return population[chooseFlag]

    def check_legal_pop(self, pops):
        """
        Return True iff every row of ``pops`` visits each city 0..n-1 exactly
        as a set, where ``n`` is ``self.problem.instance.n``.
        """
        expected = set(range(self.problem.instance.n))
        return all(set(pop) == expected for pop in pops)

    def cal_HV(self, PF, ref):
        """
        Normalized hypervolume of the front ``PF``.

        The WFG hypervolume of ``PF`` w.r.t. reference point ``ref`` is divided
        by the volume of the box spanned from the origin to ``ref``. This
        presumably assumes minimized objectives that are bounded below by 0.

        Args:
            PF: Objective vectors, one point per row.
            ref: Reference point, one coordinate per objective.

        Returns:
            Hypervolume divided by ``prod(ref)``.
        """
        ref_arr = np.ascontiguousarray(ref, dtype=float)
        ref_region = float(np.prod(ref_arr))
        # Compute the (expensive) hypervolume once, not once per objective
        hv_val = hvwfg.wfg(np.ascontiguousarray(PF, dtype=float), ref_arr)
        return hv_val / ref_region


# ========================== TSP-specific Operator Implementation ==========================


class TSP_OX_Crossover(ea.Recombination):
    """
    TSP Order Crossover (OX) operator.

    Keeps a contiguous segment of one parent and fills the remaining positions
    with the other parent's cities in their cyclic order. This preserves the
    relative city visiting order of both parents.
    """

    def __init__(self, XOVR=0.8):
        # Probability that a given pair of parents is recombined
        self.XOVR = XOVR

    def do(self, OldChrom):
        """
        Execute order crossover on a chromosome matrix.

        Rows are shuffled into disjoint pairs. Each pair is recombined with
        probability ``XOVR`` and replaced by its two children at the same rows.
        When the number of rows is odd, the last row is not paired and is
        returned unchanged, so the population size is preserved.

        Args:
            OldChrom: 2-D array with one tour per row. It is not modified.

        Returns:
            A new array with the same shape as ``OldChrom``.
        """
        Nind, Lind = OldChrom.shape
        NewChrom = OldChrom.copy()
        if Lind < 2:
            # Two distinct cut points cannot be drawn from fewer than two genes
            return NewChrom

        # Pair an even number of rows; an odd leftover (the last row) stays as-is
        n_paired = Nind - Nind % 2
        pairs = np.random.permutation(n_paired).reshape(-1, 2)

        for p1_idx, p2_idx in pairs:
            parent1 = NewChrom[p1_idx].copy()
            parent2 = NewChrom[p2_idx].copy()

            if np.random.rand() < self.XOVR:
                # Select crossover segment [start, end]
                start, end = sorted(np.random.choice(Lind, 2, replace=False))

                # Create offspring
                child1 = self.ox_crossover(parent1, parent2, start, end)
                child2 = self.ox_crossover(parent2, parent1, start, end)

                # Roll back any child that is not a rearrangement of its parent's
                # genes. Comparing against the parent also accepts encodings that
                # do not start at 0.
                NewChrom[p1_idx] = (
                    child1 if self._is_rearrangement(child1, parent1) else parent1
                )
                NewChrom[p2_idx] = (
                    child2 if self._is_rearrangement(child2, parent2) else parent2
                )

        return NewChrom

    @staticmethod
    def _is_rearrangement(candidate, reference):
        """Return True iff ``candidate`` contains exactly the genes of ``reference``."""
        return np.array_equal(np.sort(candidate), np.sort(reference))

    def ox_crossover(self, parent1, parent2, start, end):
        """
        Build one OX child.

        Args:
            parent1: Tour donating the segment ``[start, end]`` (inclusive).
            parent2: Tour donating the remaining cities, in cyclic order
                starting just after ``end``.
            start, end: Inclusive segment bounds.

        Returns:
            The child tour. Positions that could not be filled (only possible
            with invalid parents) stay at -1.
        """
        n = len(parent2)
        child = -np.ones_like(parent1)

        # Copy crossover segment from parent1
        child[start : end + 1] = parent1[start : end + 1]
        # Set-based membership instead of scanning ``child`` for every city
        used = set(child[start : end + 1].tolist())

        # Fill remaining cities from parent2, maintaining their cyclic order
        p2 = parent2.tolist()
        fill_pos = (end + 1) % n
        for i in range(n):
            city = p2[(end + 1 + i) % n]
            if city not in used:
                child[fill_pos] = city
                used.add(city)
                fill_pos = (fill_pos + 1) % n

        return child


class TSP_Swap_Mutation(ea.Mutation):
    """
    TSP swap mutation operator.

    Each tour is mutated independently with probability ``Pm``. A mutation
    exchanges the cities at two distinct, randomly chosen positions.
    """

    def __init__(self, Pm=0.1):
        # Per-individual mutation probability
        self.Pm = Pm

    def do(self, Encoding, OldChrom, Field):
        """
        Apply swap mutation to a copy of ``OldChrom``.

        Args:
            Encoding: Unused here. Presumably kept to mirror Geatpy's mutation
                ``do`` call signature.
            OldChrom: 2-D array with one tour per row. It is not modified.
            Field: Unused here.

        Returns:
            The mutated copy.
        """
        _, n_genes = OldChrom.shape
        mutated = OldChrom.copy()

        # Each ``row`` is a view into ``mutated``, so swaps write straight back
        for row in mutated:
            if np.random.rand() < self.Pm:
                i, j = np.random.choice(n_genes, 2, replace=False)
                row[[i, j]] = row[[j, i]]

        return mutated


# ========================== Custom Operator Wrapper ==========================
class TSPOperatorWrapper:
    """
    Adapt a plain callable to the ``.do(chromosome_matrix)`` interface.

    Args:
        operator: Callable that takes a chromosome matrix and returns the new
            matrix. If it is falsy (e.g. ``None``), ``do`` becomes an identity
            pass-through instead of being left undefined.
    """

    def __init__(self, operator: Optional[Callable]):
        self.do = operator if operator else self._identity

    @staticmethod
    def _identity(chromosome):
        """Return ``chromosome`` unchanged (no-op operator)."""
        return chromosome


class TSPCrossoverWrapper:
    """
    Crossover wrapper exposing a ``.do(chromosome_matrix)`` method.

    Uses the supplied callable when one is given. Otherwise it falls back to
    the built-in Order Crossover (OX) in ``crossover_default``.
    """

    def __init__(self, crossover_: Optional[Callable]):
        if crossover_:
            self.do = crossover_
        else:
            self.do = self.crossover_default

    def crossover_default(self, parent_chromosome):
        """
        Order Crossover (OX) operator.

        Individuals are shuffled into disjoint pairs. Each pair produces two
        children, which replace the parents at the same row indices. With an
        odd population size, the single unpaired individual is copied unchanged.

        Args:
            parent_chromosome: 2-D array with one tour per row. It is not modified.

        Returns:
            New array with the same shape and dtype as ``parent_chromosome``.
        """
        pop_size = parent_chromosome.shape[0]  # Population size
        chrom_length = parent_chromosome.shape[1]  # Number of cities
        # Start from a copy: rows that are not overwritten below keep their parent
        offspring_chromosome = parent_chromosome.copy()
        if chrom_length < 2:
            # Two distinct cut points cannot be drawn from fewer than two genes
            return offspring_chromosome

        # 1. Shuffle and pair. With an odd size the last shuffled index sits out
        #    and keeps its copied parent.
        indices = np.random.permutation(pop_size)
        paired = indices[: pop_size - pop_size % 2]

        # 2. Each pair generates two children, one per parent ordering
        for idx1, idx2 in paired.reshape(-1, 2):
            parent1 = parent_chromosome[idx1]
            parent2 = parent_chromosome[idx2]
            for offspring_index, p1, p2 in (
                (idx1, parent1, parent2),
                (idx2, parent2, parent1),
            ):
                start, end = sorted(np.random.choice(chrom_length, 2, replace=False))
                offspring_chromosome[offspring_index] = self._ox_child(
                    p1, p2, start, end
                )

        return offspring_chromosome

    @staticmethod
    def _ox_child(p1, p2, start, end):
        """Build one OX child from segment p1[start:end+1] plus p2's remaining cities.

        The remaining cities are taken in p2's cyclic order, starting just after
        ``end``, and placed cyclically from position ``end + 1``. The scan is
        bounded, so invalid parents leave -1 markers instead of looping forever.
        """
        chrom_length = len(p1)
        # Invalid marker (-1) for positions that are not yet filled
        child = -np.ones(chrom_length, dtype=p1.dtype)

        # Inherit the middle segment from p1
        child[start : end + 1] = p1[start : end + 1]
        used = set(p1[start : end + 1].tolist())

        # Fill the remaining slots from p2, wrapping around
        p2_list = p2.tolist()
        slots_left = chrom_length - (end - start + 1)
        fill_pos = (end + 1) % chrom_length
        for k in range(chrom_length):
            if slots_left == 0:
                break
            city = p2_list[(end + 1 + k) % chrom_length]
            if city in used:
                continue
            child[fill_pos] = city
            used.add(city)
            fill_pos = (fill_pos + 1) % chrom_length
            slots_left -= 1

        return child


class TSPMutationWrapper:
    """
    Mutation wrapper exposing a ``.do(chromosome_matrix)`` method.

    Uses the supplied callable when one is given. Otherwise it falls back to
    the built-in swap mutation in ``mutation_default``.
    """

    def __init__(self, mutation: Optional[Callable]):
        if mutation:
            self.do = mutation
        else:
            self.do = self.mutation_default

    def mutation_default(self, offspring_chromosome):
        """
        Swap mutation: in every row, exchange the cities at two distinct
        random positions.

        The array is modified in place and also returned. Chromosomes shorter
        than two genes cannot be swapped and are returned unchanged.

        Args:
            offspring_chromosome: 2-D array with one tour per row.

        Returns:
            ``offspring_chromosome`` itself.
        """
        chrom_length = offspring_chromosome.shape[1]
        if chrom_length < 2:
            return offspring_chromosome

        # Each ``row`` is a view, so the swap writes into the input array
        for row in offspring_chromosome:
            city1, city2 = np.random.choice(chrom_length, 2, replace=False)
            row[[city1, city2]] = row[[city2, city1]]

        return offspring_chromosome
