import numpy as np
import optax
import chex
import jax.numpy as jnp
from typing import NamedTuple
from scipy.spatial.distance import cdist


class EigGroup(NamedTuple):
    """A cluster of (numerically) repeated eigenvalues.

    Fields:
        eigvals: The eigenvalues that fell into this cluster.
        multiplicity: Number of eigenvalues in the cluster; the builders in
            this module set it to ``len(eigvals)``.
    """

    eigvals: np.ndarray  # values belonging to this cluster
    multiplicity: int  # count of values in the cluster
    # NOTE: an ``indices`` field (list[int] | np.ndarray) was sketched here
    # but is not stored.

    @property
    def median(self):
        """Median of the cluster's eigenvalues, as computed by ``np.median``."""
        members = self.eigvals
        return np.median(members)


class EigPair(NamedTuple):
    """Two eigenvalue groups matched to each other, plus their separation."""

    g1: EigGroup  # group taken from the first list
    g2: EigGroup  # group taken from the second list
    # separation of the pair; align_eigval_groups stores |g1.median - g2.median|
    dist: float | np.ndarray
    # (position of g1, position of g2) in their source lists;
    # align_eigval_groups always sets both to the same index
    idx: tuple[int, int]


def analyze_repeated_eigvals(
    evals, tol: float = 1e-10, rel_tol: bool = False, sort: bool = True
) -> list[EigGroup]:
    """Group (nearly) repeated eigenvalues into clusters.

    The eigenvalues are walked in order. Each group is anchored at the first
    not-yet-grouped value ``evals[i]`` and absorbs the following contiguous
    values whose distance to that anchor is at most the tolerance. Distance
    is always measured to the anchor, not chained between neighbours.

    Args:
        evals: 1-D array-like of eigenvalues. Converted with ``np.asarray``,
            so plain Python lists are accepted.
        tol: Absolute tolerance, or a relative factor when ``rel_tol`` is set.
        rel_tol: If True, a group anchored at ``evals[i]`` uses the tolerance
            ``tol * |evals[i]|``. An anchor that is exactly zero therefore
            only groups with exact zeros.
        sort: If True, sort ascending before grouping. Otherwise only
            contiguous runs in the given order can form a group.

    Returns:
        A list of ``EigGroup`` in traversal order. Multiplicities sum to
        ``len(evals)``; an empty input gives ``[]``.
    """
    # Lists would otherwise fail below when indexed with numpy arrays.
    evals = np.asarray(evals)
    if sort:
        evals = np.sort(evals)

    count = len(evals)
    if rel_tol:
        tols = tol * np.abs(evals)
    else:
        tols = np.full(count, tol)

    groups = []
    start = 0
    while start < count:
        anchor = evals[start]
        stop = start + 1
        while stop < count and abs(evals[stop] - anchor) <= tols[start]:
            stop += 1
        # .copy() so a group never aliases the caller's array (slices are views).
        members = evals[start:stop].copy()
        groups.append(EigGroup(members, stop - start))
        start = stop

    return groups


def align_eigval(groups: list[EigGroup], eigvals):
    """Re-assign ``eigvals`` to the nearest of the given groups.

    Each group's centre is its ``median``. Every eigenvalue goes to the group
    with the closest centre; on ties ``np.argmin`` picks the lowest group
    index. The result holds one ``EigGroup`` per input group, in the same
    order. A group that attracts no eigenvalues comes back empty with
    multiplicity 0.

    Args:
        groups: Reference groups; only their ``median`` is read.
        eigvals: Array-like of eigenvalues. Converted with ``np.asarray`` and
            flattened to 1-D, so lists and multi-dimensional arrays work.

    Returns:
        A list of ``len(groups)`` new ``EigGroup`` objects, or ``[]`` when
        ``groups`` is empty.
    """
    if len(groups) == 0:
        # No centres to assign to; argmin over an empty axis would raise.
        return []

    eigvals = np.asarray(eigvals).ravel()
    centers = np.array([g.median for g in groups])
    # cdist expects 2-D (n_points, n_features) inputs; each scalar is a
    # single-feature point.
    distances = cdist(eigvals.reshape(-1, 1), centers.reshape(-1, 1))
    assignments = np.argmin(distances, axis=1)
    return [
        EigGroup(members, len(members))
        for members in (eigvals[assignments == k] for k in range(len(centers)))
    ]


def align_eigval_groups(
    groups1: list[EigGroup], groups2: list[EigGroup], tol: float = 1e-4
) -> list[EigPair]:
    """Pair two lists of eigenvalue groups position by position.

    The k-th group of ``groups1`` is paired with the k-th group of
    ``groups2``. Pairing stops at the shorter list (``zip`` semantics), so
    surplus groups in the longer list are dropped without notice.

    Args:
        groups1: Groups from the first spectrum.
        groups2: Groups from the second spectrum.
        tol: Accepted for interface compatibility but currently unused; no
            pair is filtered by distance. TODO(review): confirm whether a
            distance cutoff was intended.

    Returns:
        One ``EigPair`` per paired position, with ``dist`` set to
        ``|g1.median - g2.median|`` and ``idx == (k, k)``. Returns ``[]`` if
        either input is empty.
    """
    if min(len(groups1), len(groups2)) == 0:
        return []

    return [
        EigPair(
            first,
            second,
            dist=np.abs(first.median - second.median),
            idx=(k, k),
        )
        for k, (first, second) in enumerate(zip(groups1, groups2))
    ]


def projector_info(P, atol=0.0):
    """Summarise where a matrix ``P`` fails the projector identity ``P @ P == P``.

    Entry (i, j) is flagged when ``|(P @ P - P)[i, j]| > atol``.

    Args:
        P: Square matrix (JAX or NumPy array). Presumably floating point;
            TODO confirm for integer inputs.
        atol: Absolute threshold on the residual. The default ``0.0`` flags
            any nonzero residual.

    Returns:
        A tuple ``(num_non_zero, min_elem, max_elem)``:
        - ``num_non_zero``: number of flagged entries.
        - ``min_elem`` / ``max_elem``: min and max of ``P`` itself (not of the
          residual) over the flagged entries. Both are NaN when nothing is
          flagged, e.g. for an exactly idempotent ``P``, instead of raising on
          an empty selection.
    """
    mask = jnp.abs(P @ P - P) > atol
    num_non_zero = jnp.sum(mask)
    # Masked reductions instead of boolean indexing (P[mask]): the output
    # shape no longer depends on the data, so this also works under jit.
    masked_min = jnp.min(jnp.where(mask, P, jnp.inf))
    masked_max = jnp.max(jnp.where(mask, P, -jnp.inf))
    has_any = num_non_zero > 0
    min_elem = jnp.where(has_any, masked_min, jnp.nan)
    max_elem = jnp.where(has_any, masked_max, jnp.nan)
    return num_non_zero, min_elem, max_elem


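# Alternative align_eigval_groups, kept for reference. It matches groups by
# minimum total median distance via scipy's linear_sum_assignment and keeps
# only pairs whose distance is below ``tol``. Re-enabling it also requires
# `from scipy.optimize import linear_sum_assignment`, which is not imported.
# def align_eigval_groups(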
#     groups1: list[EigGroup], groups2: list[EigGroup], tol: float = 1e-4
# ) -> list[EigPair]:
#     if len(groups1) == 0 or len(groups2) == 0:
#         return []

#     n1, n2 = len(groups1), len(groups2)
#     cost_matrix = np.zeros((n1, n2))

#     for i, g1 in enumerate(groups1):
#         for j, g2 in enumerate(groups2):
#             cost_matrix[i, j] = abs(g1.median - g2.median)
#     row_ind, col_ind = linear_sum_assignment(cost_matrix)

#     aligned_pairs = []
#     for i, j in zip(row_ind, col_ind):
#         if cost_matrix[i, j] < tol:
#             aligned_pairs.append(
#                 EigPair(
#                     groups1[i],
#                     groups2[j],
#                     dist=cost_matrix[i, j],
#                     idx=(i, j),
#                 )
#             )

#     return aligned_pairs

