import logging
from collections import deque
from copy import deepcopy

import numpy as np
from pymoo.indicators.hv import Hypervolume
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
from scipy.linalg import null_space
from scipy.linalg import orth
from scipy.spatial import Delaunay

def compute_ref_point(_y: np.ndarray, _minimize=True):
    """Return a reference point just beyond the worst value of each objective.

    The point is pushed past the worst value by 10% of that objective's range
    (max - min) over the rows of ``_y``.

    Parameters
    ----------
    _y : np.ndarray
        Objective values, one row per point, one column per objective.
    _minimize : bool
        If True, the worst value is the column maximum and the margin is added.
        Otherwise the worst value is the column minimum and the margin is subtracted.

    Returns
    -------
    np.ndarray
        The reference point, with one entry per objective.
    """
    upper = np.max(_y, axis=0)
    lower = np.min(_y, axis=0)
    margin = 0.1 * (upper - lower)
    return upper + margin if _minimize else lower - margin

def _hypervolume_contributions(_y_pool, _front):
    """Return, for each row index in ``_front``, the hypervolume lost when that row is dropped from ``_y_pool``."""
    hv = Hypervolume(compute_ref_point(_y_pool))
    total = hv(_y_pool)
    return np.array([total - hv(np.delete(_y_pool, i, axis=0)) for i in _front])


def get_xy_based_on_hvc(_x: np.ndarray, _y: np.ndarray, _n: int):
    """Select ``_n`` rows of ``(_x, _y)`` by front rank, then by hypervolume contribution.

    Selection works on the pool of not-yet-selected rows:

    - Take the indices of the pool's first non-dominated front (pymoo
      ``NonDominatedSorting`` with ``only_non_dominated_front=True``).
    - If the whole front fits in the remaining budget, take all of it.
    - Otherwise keep the front members with the largest hypervolume
      contribution. The contribution is measured against
      ``compute_ref_point(pool)``, which uses the minimisation convention by
      default.

    Contributions are computed once per overflowing front. They are not
    recomputed after each pick.

    Parameters
    ----------
    _x : np.ndarray
        Decision vectors, shape (N, dim).
    _y : np.ndarray
        Objective vectors aligned with ``_x``, shape (N, n_obj).
    _n : int
        Number of rows to select.

    Returns
    -------
    tuple of np.ndarray
        ``(x_selected, y_selected)``. When ``_n == len(_x)`` the inputs are
        returned unchanged (same objects). Otherwise rows are stacked onto
        empty float64 arrays, so the outputs have float dtype.

    Raises
    ------
    ValueError
        If ``_n`` exceeds the number of available rows.
    """
    # Shapes are read before validation, matching the original evaluation order.
    selected_x = np.zeros((0, _x.shape[1]))
    selected_y = np.zeros((0, _y.shape[1]))

    if _n > len(_x):
        raise ValueError("The number of points to select is greater than the number of points available.")
    if _n == len(_x):
        return _x, _y

    remaining_x = deepcopy(_x)
    remaining_y = deepcopy(_y)
    while len(selected_x) < _n:
        front = NonDominatedSorting().do(remaining_y, only_non_dominated_front=True)
        n_missing = _n - len(selected_x)
        if len(front) <= n_missing:
            chosen = front
        else:
            # Truncate the overflowing front: largest contributions first.
            contributions = _hypervolume_contributions(remaining_y, front)
            chosen = front[np.argsort(contributions)[::-1][:n_missing]]
        selected_x = np.vstack((selected_x, remaining_x[chosen]))
        selected_y = np.vstack((selected_y, remaining_y[chosen]))
        remaining_x = np.delete(remaining_x, chosen, axis=0)
        remaining_y = np.delete(remaining_y, chosen, axis=0)
    return selected_x, selected_y
            
def batch_distance_from_m_point_to_n_line(__f, __point_on_line, __line_normal):
    """Return the (m, n) matrix of Euclidean distances from m points to n lines.

    Line ``j`` passes through ``__point_on_line[j]`` with direction
    ``__line_normal[j]``. Entry ``[i, j]`` is the length of the component of
    ``__f[i] - __point_on_line[j]`` orthogonal to that direction.

    Parameters
    ----------
    __f : array_like, shape (m, d)
        Query points.
    __point_on_line : array_like, shape (n, d)
        One anchor point per line.
    __line_normal : array_like, shape (n, d)
        One direction vector per line. It need not be unit length. A zero
        vector yields NaN distances (0/0) for that line.

    Returns
    -------
    np.ndarray, shape (m, n)
        The distance matrix. If m == 0 it has shape (0, n).

    Raises
    ------
    AssertionError
        If any input is not 2-D, or the line arrays have different row counts.
    ValueError
        If the three inputs do not share the same point dimension ``d``.
    """
    _f = np.array(__f)
    _point_on_line = np.array(__point_on_line)
    _line_normal = np.array(__line_normal)
    assert _f.ndim == 2 and _point_on_line.ndim == 2 and _line_normal.ndim == 2
    assert _point_on_line.shape[0] == _line_normal.shape[0]
    # Without this check, a d == 1 input would broadcast silently into wrong distances.
    if not (_f.shape[1] == _point_on_line.shape[1] == _line_normal.shape[1]):
        raise ValueError(
            f"Dimension mismatch: points have {_f.shape[1]} coordinates, "
            f"line points {_point_on_line.shape[1]}, line directions {_line_normal.shape[1]}."
        )

    # Broadcasting replaces explicit np.repeat copies: pa is (m, n, d), ba is (1, n, d).
    pa = _f[:, np.newaxis, :] - _point_on_line[np.newaxis, :, :]
    ba = _line_normal[np.newaxis, :, :]
    # Projection coefficient of each offset onto each line direction, shape (m, n).
    t = (pa * ba).sum(axis=2) / (ba * ba).sum(axis=2)
    # Norm of the residual after removing the along-line component.
    return np.linalg.norm(pa - t[:, :, np.newaxis] * ba, axis=2)

# def nondominated_estimate(_f: np.ndarray, _f_pf: np.ndarray):
#     """
#     Check whether each point in _f is non-dominated by all points in _f_pf.
#     Parameters
#     ----------
#     _f : np.ndarray
#         The set of points to estimate.
#     _f_pf : np.ndarray
#         The Pareto front to compare against.
#     Returns
#     -------
#     np.ndarray
#         A boolean array indicating if each point in _f is non-dominated.
#     """
#     _n = _f.shape[0]
#     _n_pf = _f_pf.shape[0]
#     if _n == 0 or _n_pf == 0:
#         return 0
#     _f_pf_tile = np.tile(_f_pf, (_n, 1, 1))
#     _f_tile = np.tile(_f[:, np.newaxis, :], (1, _n_pf, 1))
#     _nondominated = np.all(np.all(_f_tile > _f_pf_tile, axis=-1),axis=-1)
#     return np.logical_not(_nondominated)

def _distance_to_line(_f, _point_on_line, _line_normal):
    """Return the distance from point ``_f`` to each line given row-wise.

    Row ``k`` of ``_point_on_line`` / ``_line_normal`` defines line ``k`` by an
    anchor point and a direction. ``_f`` is promoted to 2-D and broadcast
    against the anchors. The result is a 1-D array with one distance per
    broadcast row.
    """
    offsets = np.atleast_2d(_f) - _point_on_line
    directions = np.atleast_2d(_line_normal)
    coef = np.sum(offsets * directions, axis=1) / np.sum(directions * directions, axis=1)
    # Remove the along-line component; what remains is perpendicular to the line.
    perpendicular = offsets - coef[:, np.newaxis] * directions
    return np.linalg.norm(perpendicular, axis=1)

def _chain_neighbors(n_upoints):
    """Return, for each of ``n_upoints`` points ordered along a chain, the indices of its adjacent points."""
    return [[j for j in (i - 1, i + 1) if 0 <= j < n_upoints] for i in range(n_upoints)]


def _triangulated_neighbors(points, num_objectives, max_retries=10):
    """Return per-point neighbour index arrays from a Delaunay triangulation of ``points``.

    ``num_objectives`` points are drawn at random (global ``np.random`` state).
    Their differences to the first one span an orthonormal basis (``orth``),
    and all points are projected onto it and triangulated. Presumably the
    points lie on a (num_objectives - 1)-dimensional hyperplane (TODO confirm).

    Any exception in an attempt is logged and the attempt is retried with a
    new random sample. After ``max_retries`` retries a RuntimeError is raised.
    """
    n_retry = 0
    while True:
        try:
            idx = np.random.choice(np.arange(len(points)), num_objectives, replace=False)
            Ps = points[idx]
            P0 = Ps[0]
            uv = (Ps[1:] - P0).reshape(len(Ps) - 1, -1)
            T = orth(uv.T).T
            points_proj = np.dot(T, (points - P0).T).T
            tri = Delaunay(points_proj)
            break
        except Exception as e:
            logging.info(f'Error in triangulation: {e}')
            n_retry += 1
            if n_retry > max_retries:
                raise RuntimeError(
                    f'Triangulation failed after {max_retries} retries. Error: {e}'
                ) from e
    indptr, indices = tri.vertex_neighbor_vertices
    return [indices[indptr[i]:indptr[i + 1]] for i in range(len(points))]


def _match_observed_to_utopia(distss, thresholds, observed_f_scaled):
    """Greedily resolve candidate (observed, utopia) pairs into a utopia-index -> observed-index dict.

    Candidates are all pairs with ``distss[obs, upoint] < thresholds[upoint]``,
    processed FIFO. The rules are:

    - A pair is accepted if neither side is mapped yet.
    - If the utopia point is taken, the candidate replaces the incumbent only
      when it is strictly smaller in every objective. The incumbent is then
      re-queued.
    - If only the observed point is taken, it moves to the new utopia point
      when that line is strictly closer.
    """
    observed = np.asarray(observed_f_scaled)
    cand_obs, cand_upoints = np.where(distss < thresholds)
    map_upoints_obs = {}
    pending = deque(zip(cand_obs, cand_upoints))
    while pending:
        idx_obs, idx_upoint = pending.popleft()
        upoint_taken = idx_upoint in map_upoints_obs
        obs_taken = idx_obs in map_upoints_obs.values()
        if not upoint_taken and not obs_taken:
            map_upoints_obs[idx_upoint] = idx_obs
        elif upoint_taken:
            prev_idx_obs = map_upoints_obs[idx_upoint]
            # NOTE(review): idx_obs may already be mapped to another utopia point here;
            # replacing would then map one observed point twice. Confirm whether intended.
            if np.all(observed[idx_obs] < observed[prev_idx_obs]):
                map_upoints_obs[idx_upoint] = idx_obs
                pending.append((prev_idx_obs, idx_upoint))
        else:
            # Utopia point is free, but the observed point is claimed elsewhere:
            # keep whichever utopia line is closer.
            prev_idx_upoint = next(k for k, v in map_upoints_obs.items() if v == idx_obs)
            if distss[idx_obs, idx_upoint] < distss[idx_obs, prev_idx_upoint]:
                map_upoints_obs.pop(prev_idx_upoint)
                map_upoints_obs[idx_upoint] = idx_obs
                pending.append((idx_obs, prev_idx_upoint))
    return map_upoints_obs


def compute_map_utopia_observed_points(all_utopia_points_ref, all_utopia_normals_ref, observed_f_scaled):
    """Map utopia points (as lines) to nearby observed points.

    Each utopia point ``j`` defines a line through ``all_utopia_points_ref[j]``
    with direction ``all_utopia_normals_ref[j]``. The procedure is:

    1. Compute the distance from every observed point to every line.
    2. Find each utopia point's neighbours.
       - With 2 objectives, neighbours are the adjacent entries in list order.
         This assumes the points are ordered along the front (TODO confirm).
       - With more objectives, neighbours come from a Delaunay triangulation.
         This step uses random sampling and is not deterministic.
    3. Set each utopia point's acceptance threshold to half the mean distance
       from that point to its neighbours' lines.
    4. Resolve observed points closer than the threshold greedily into a
       mapping.

    Returns
    -------
    dict
        Utopia-point index -> observed-point index. Keys and values are the
        numpy integer scalars produced by ``np.where``.

    Raises
    ------
    ValueError
        With 2 objectives and fewer than two utopia points, because no
        neighbour exists.
    RuntimeError
        If the triangulation keeps failing (more than 2 objectives).
    """
    n_upoints = len(all_utopia_points_ref)
    num_objectives = len(all_utopia_points_ref[0])
    distss = batch_distance_from_m_point_to_n_line(
        observed_f_scaled,
        all_utopia_points_ref,
        all_utopia_normals_ref,
    )
    if num_objectives == 2:
        if n_upoints < 2:
            raise ValueError('At least two utopia points are required to define neighbors.')
        # Each endpoint has a single neighbour. The last point previously listed
        # itself, which gave it a zero threshold so it could never be matched.
        all_neighbors = _chain_neighbors(n_upoints)
    else:
        all_neighbors = _triangulated_neighbors(np.array(all_utopia_points_ref), num_objectives)

    thresholds = np.zeros(n_upoints)
    for i_upoint, neighbors in enumerate(all_neighbors):
        these_dists = _distance_to_line(
            all_utopia_points_ref[i_upoint],
            [all_utopia_points_ref[j] for j in neighbors],
            [all_utopia_normals_ref[j] for j in neighbors],
        )
        thresholds[i_upoint] = np.mean(these_dists) / 2.0
    return _match_observed_to_utopia(distss, thresholds, observed_f_scaled)

# def line_surface_intersection(A, v, surface_points, tol=1e-8):
#     """
#     Find the intersection of a line and a hyperplane defined by points.
    
#     A : point on the line (n,)
#     v : direction vector of the line (n,)
#     surface_points : array of shape (n, n) - n points in n-dim space

#     Returns:
#     - Intersection point (numpy array)
#     - None if no intersection (parallel and not lying on surface)
#     """
#     n = A.shape[0]

#     if surface_points.shape != (n, n):
#         raise ValueError("Surface must be defined by n points in n-dimensional space.")
    
#     # Step 1: build U matrix
#     u_vectors = surface_points[1:] - surface_points[0]  # (n-1, n)
#     U = u_vectors.T  # shape (n, n-1)

#     # Step 2: find normal vector n (null space of U^T)
#     ns = null_space(U.T)  # shape (n, 1)
#     if ns.size == 0:
#         raise ValueError("Surface points are degenerate; no unique hyperplane.")

#     normal = ns[:,0]
    
#     # Step 3: find d
#     d = normal @ surface_points[0]

#     # Step 4: check if line and hyperplane are parallel
#     denom = normal @ v
#     if abs(denom) < tol:
#         # Parallel
#         if abs(normal @ A - d) < tol:
#             return A.copy()  # The line lies in the hyperplane
#         else:
#             return None      # No intersection
#     else:
#         # Solve for t
#         t = (d - normal @ A) / denom
#         intersection = A + t * v
#         return intersection