from __future__ import annotations

import math
from typing import Any, Optional

import numpy as np


def make_non_dominated_sorter():
    """Create and return a new pymoo ``NonDominatedSorting`` object.

    pymoo is imported inside the function, so it is only loaded when a sorter
    is actually requested.
    """
    from pymoo.util.nds import non_dominated_sorting as nds_module

    sorter = nds_module.NonDominatedSorting()
    return sorter


def _spawn_offspring(parents, num_mutate, train_dataset, mutation_cfg, rng, optimize_kwargs):
    """Mutate and fit ``num_mutate`` copies of every parent graph.

    Returns three parallel lists: the children, the value each child's
    ``optimize`` call returned (its fitting error), and each child's
    ``compute_cost()``.
    """
    children = []
    errors = []
    costs = []
    for parent in parents:
        for _ in range(num_mutate):
            child = parent.copy()
            child.mutate(mutation_cfg=mutation_cfg, rng=rng)
            children.append(child)
            errors.append(child.optimize(train_dataset, rng=rng, **optimize_kwargs))
            costs.append(child.compute_cost())
    # Attach each child's error only once every child has been evaluated.
    for child, error in zip(children, errors):
        child.optimization_error = error
    return children, errors, costs


def _sort_fronts(nds, rows):
    """Return ``nds.do`` fronts for 2-objective ``rows``, or ``[]`` if there are none.

    The objective matrix is always shaped ``(n, 2)`` so the sorter never sees a
    degenerate 1-D array when the candidate pool is empty.
    """
    objectives = np.asarray(rows, dtype=float).reshape(-1, 2)
    if objectives.shape[0] == 0:
        return []
    return nds.do(objectives)


def _unique_by_cost(front, costs):
    """Keep only the first index in ``front`` for each distinct compute cost."""
    seen_costs = set()
    unique = []
    for idx in front:
        cost = costs[idx]
        if cost in seen_costs:
            continue
        seen_costs.add(cost)
        unique.append(idx)
    return unique


def _extend_by_spread(selected, candidates, errors, costs, quota):
    """Greedily append up to ``quota`` indices from ``candidates`` to ``selected``.

    Each pick is the candidate whose nearest already-selected point is farthest
    away, measured by squared Euclidean distance in (error, cost) space. Picked
    candidates leave the pool, so no index is ever selected twice.
    """
    remaining = list(candidates)
    for _ in range(quota):
        if not remaining:
            break
        best_distance = -1.0
        best_pos = 0
        for pos, item in enumerate(remaining):
            nearest = float("inf")
            for existing in selected:
                distance = (errors[existing] - errors[item]) ** 2 + (
                    costs[existing] - costs[item]
                ) ** 2
                # NaN distances (NaN/inf errors) compare False and are ignored.
                if distance < nearest:
                    nearest = distance
            if nearest > best_distance:
                best_distance = nearest
                best_pos = pos
        selected.append(remaining.pop(best_pos))


def _prune_best(nds, best):
    """Reduce ``best`` to its non-dominated front.

    Entries with a NaN error are dropped, and only the first entry is kept for
    each compute cost.
    """
    fronts = _sort_fronts(nds, [(error, cost) for _, error, cost in best])
    front = [best[idx] for idx in fronts[0]] if len(fronts) else []
    pruned = []
    seen_costs = set()
    for graph, error, cost in front:
        if math.isnan(error) or cost in seen_costs:
            continue
        seen_costs.add(cost)
        pruned.append((graph, error, cost))
    return pruned


def mutate_one_step_remove_redundance(
    compute_graph_list,
    num_mantain: int,
    best,
    train_dataset,
    num_mutate: Optional[int] = 2,
    nds: Optional[Any] = None,
    num_steps: int = 20,
    eps: float = 1e-6,
    eval_dtype: str = "float32",
    metric_format: str = "fp32",
    mutation_cfg: Optional[Any] = None,
    optimizer_cfg: Optional[Any] = None,
    rng: Optional[np.random.Generator] = None,
    population_size: Optional[int] = None,
    metric_type: str = "rel",
):
    """Run the NSGA-II-style DAG mutation/selection loop used by workers.

    The objectives are expression cost and fitting error. The function keeps
    a diverse working population and a non-dominated best-of set while
    removing redundant best candidates with identical compute cost.

    Each step does the following:

    1. Every graph in the population is copied ``num_mutate`` times, and each
       copy is mutated, fitted via ``optimize`` and costed.
    2. Only these offspring compete; parents are not carried over.
    3. The offspring are split into non-dominated fronts on (cost, error).
    4. Front 0 is appended to ``best``.
    5. Fronts fill the next population in order, with duplicate costs removed
       inside each front.
    6. A front that does not fit entirely is sampled greedily for spread in
       (error, cost) space.

    Args:
        compute_graph_list: starting population. Each graph must provide
            ``copy``, ``mutate``, ``optimize`` and ``compute_cost``.
        num_mantain: maximum number of graphs kept per step.
        best: list of ``(graph, error, cost)`` tuples. A new list is built and
            returned; the caller's list is not mutated.
        train_dataset: forwarded to ``graph.optimize``.
        num_mutate: offspring per parent. If ``None``, it is derived from
            ``population_size`` (ceil of population_size / parents, at least 1),
            or 2 when no population size is given.
        nds: object whose ``do(F)`` returns fronts as index sequences. Defaults
            to ``make_non_dominated_sorter()``.
        num_steps: number of mutation/selection rounds.
        eps, eval_dtype, metric_format, optimizer_cfg, metric_type: forwarded
            to ``graph.optimize``.
        mutation_cfg: forwarded to ``graph.mutate``.
        rng: forwarded to both ``mutate`` and ``optimize``.
        population_size: only used when ``num_mutate`` is ``None``.

    Returns:
        ``(population, best)``. ``population`` is the last selected list of
        graphs. ``best`` is the pruned non-dominated set: no NaN errors and one
        entry per cost.
    """
    if nds is None:
        nds = make_non_dominated_sorter()
    if num_mutate is None:
        if population_size:
            num_mutate = max(1, int(math.ceil(population_size / max(len(compute_graph_list), 1))))
        else:
            num_mutate = 2

    optimize_kwargs = dict(
        method="LM_nelder_mead",
        eps=eps,
        ULP=False,
        eval_dtype=eval_dtype,
        metric_format=metric_format,
        optimizer_cfg=optimizer_cfg,
        metric_type=metric_type,
    )

    for _ in range(num_steps):
        graphs, errors, costs = _spawn_offspring(
            compute_graph_list, num_mutate, train_dataset, mutation_cfg, rng, optimize_kwargs
        )
        fronts = _sort_fronts(nds, list(zip(costs, errors)))

        selected = []
        for front_idx, front in enumerate(fronts):
            if front_idx == 0:
                # The full first front (before cost de-duplication) feeds the best-of set.
                best = best + [(graphs[i], errors[i], costs[i]) for i in front]
            front = _unique_by_cost(front, costs)
            room = num_mantain - len(selected)
            if len(front) > room:
                _extend_by_spread(selected, front, errors, costs, room)
            else:
                selected.extend(front)

        compute_graph_list = [graphs[i] for i in selected]
        best = _prune_best(nds, best)

    return compute_graph_list, best
