from operator import index

import numpy as np
from matplotlib.pyplot import thetagrids
from scipy.stats import alpha

from LS_PLUS_LINES import kmeans_plus
from set_of_lines import SetOfLines
from set_of_points import SetOfPoints
from itertools import product
import time
from shapely.geometry import LineString
import osmnx as ox
from sklearn.preprocessing import StandardScaler


def max_line_distance_per_cluster(spans, displacements, labels, centers):
    """
    For every cluster, find the distance from its center to its farthest line.

    Args:
        spans: direction vectors of the lines, indexed by line.
        displacements: one point per line, indexed by line.
        labels: cluster label of every line; each label is used to index
            ``centers``.
        centers: cluster centers, indexed by label.

    Returns:
        dict: maps each label that occurs in ``labels`` to the largest
        line-to-center distance among the lines carrying that label.
    """
    farthest = {}
    for cluster_id in np.unique(labels):
        center = centers[cluster_id]
        members = np.where(labels == cluster_id)[0]
        member_dists = [
            line_to_point_distance(spans[m], displacements[m], center)
            for m in members
        ]
        farthest[cluster_id] = np.max(member_dists)
    return farthest


def compute_cost_to_centers(spans, displacements, weights, centers):
    """
    Compute the total weighted cost of assigning every line to its nearest center.

    The cost of one line is its (unsquared) Euclidean distance to the closest
    center, multiplied by the weight of the line.

    Args:
        spans (np.ndarray): shape (n, d), direction vectors of the lines. They
            do not have to be unit length: the projection term is divided by
            the squared span norm (clamped below at 1e-12).
        displacements (np.ndarray): shape (n, d), one point on each line.
        weights (np.ndarray): shape (n,), weight of each line.
        centers (np.ndarray): shape (k, d), center points.

    Returns:
        float: sum over lines of weight * distance to the nearest center.
    """
    # Broadcast every line against every center: (n, 1, d) versus (1, k, d).
    offsets = centers[np.newaxis, :, :] - displacements[:, np.newaxis, :]  # (n, k, d)
    directions = spans[:, np.newaxis, :]  # (n, 1, d)

    offset_sq = np.sum(offsets ** 2, axis=2)  # (n, k)
    along = np.sum(offsets * directions, axis=2)  # (n, k)
    span_sq = np.sum(directions * directions, axis=2)  # (n, 1)

    # Pythagoras: squared perpendicular part = squared offset - squared
    # projection onto the line direction.
    perp_sq = offset_sq - along ** 2 / np.maximum(span_sq, 1e-12)
    # Round-off can push tiny values below zero; clamp before the square root.
    perp_sq = np.maximum(perp_sq, 0.0)

    nearest = np.min(np.sqrt(perp_sq), axis=1)  # (n,)
    return np.sum(weights * nearest)


def line_to_point_distance(span, displacement, point):
    """
    Return the shortest distance between a point and a line.

    The line is ``displacement + t * span``. ``span`` does not have to be unit
    length: the projection is divided by the squared span norm, matching
    compute_line_to_point_distances and compute_cost_to_centers.

    Args:
        span (np.ndarray): shape (d,), direction vector of the line.
        displacement (np.ndarray): shape (d,), a point on the line.
        point (np.ndarray): shape (d,), the query point.

    Returns:
        float: Euclidean distance from ``point`` to the line. A zero ``span``
        degenerates to the distance between ``point`` and ``displacement``.
    """
    diff = point - displacement
    span_sq = np.dot(span, span)
    if span_sq > 0:
        # Remove the component of diff that runs along the line direction.
        diff = diff - (np.dot(diff, span) / span_sq) * span
    return np.linalg.norm(diff)


def assign_labels(spans, displacements, center_set):
    """
    Label every line with the index of its nearest center.

    Args:
        spans (np.ndarray): direction vector of each line, one row per line.
        displacements (np.ndarray): a point on each line, one row per line.
        center_set (np.ndarray): candidate centers, one row per center.

    Returns:
        np.ndarray: integer array with one entry per line holding the row
        index in ``center_set`` of the closest center.
    """
    line_count = spans.shape[0]
    center_count = center_set.shape[0]
    labels = np.zeros(line_count, dtype=int)

    for line_idx in range(line_count):
        span, disp = spans[line_idx], displacements[line_idx]
        per_center = [
            line_to_point_distance(span, disp, center_set[c])
            for c in range(center_count)
        ]
        labels[line_idx] = np.argmin(per_center)
    return labels


def corrupt_labels(labels, k, error_rate=0.2, seed=None):
    """
    Return a copy of ``labels`` in which part of every cluster is relabelled.

    For each cluster ``c`` in ``range(k)``, ``floor(error_rate * n_c)`` of its
    ``n_c`` members are chosen without replacement, and each chosen member
    receives a label drawn uniformly from ``range(k)`` that differs from its
    original label.

    Args:
        labels (np.ndarray): integer cluster labels; the array is not modified.
        k (int): number of clusters.
        error_rate (float): fraction of each cluster to relabel.
        seed: forwarded to ``np.random.default_rng`` for reproducibility.

    Returns:
        np.ndarray: the corrupted copy. When ``k < 2`` no alternative label
        exists, so the copy is returned unchanged.
    """
    corrupted = labels.copy()
    if k < 2:
        # Guard: with a single cluster the rejection loop below could never
        # draw a different label and would spin forever.
        return corrupted

    rng = np.random.default_rng(seed)
    for cluster in range(k):
        # Members of the current cluster.
        indices = np.where(labels == cluster)[0]
        if indices.size == 0:
            continue

        # Pick the share of this cluster that gets a wrong label.
        n_corrupt = int(np.floor(error_rate * indices.size))
        corrupt_indices = rng.choice(indices, size=n_corrupt, replace=False)

        for idx in corrupt_indices:
            # Rejection sampling: redraw until the label differs from the
            # original one.
            new_label = rng.integers(low=0, high=k)
            while new_label == labels[idx]:
                new_label = rng.integers(low=0, high=k)
            corrupted[idx] = new_label

    return corrupted


'''
def get_all_intersection_points(spans,displacements):
    """
    this method returns n(n-1) points, where each n-1 points in the n-1 points on each line that are closest to the
    rest n-1 lines.

    Args:
        ~

    Returns:
        np.ndarray: all the "intersection" points
    """
    spans = spans
    displacements = displacements
    dim = np.shape(spans)[1]
    size = np.shape(spans)[0]



    t = range(size)
    indexes_repeat_all_but_one = np.array(
        [[x for i, x in enumerate(t) if i != j] for j, j in enumerate(t)]).reshape(-1)

    spans_rep_each = spans[
        indexes_repeat_all_but_one]  # repeat of the spans, each span[i] is being repeated size times in a sequance
    spans_rep_all = np.repeat(spans.reshape(1, -1), size - 1, axis=0).reshape(-1,
                                                                              dim)  # repeat of the spans, all the spans block is repeated size-1 times
    disp_rep_each = displacements[
        indexes_repeat_all_but_one]  # repeat of the displacements, each span[i] is being repeated size times in a sequance
    disp_rep_all = np.repeat(displacements.reshape(1, -1), size - 1, axis=0).reshape(-1,
                                                                                     dim)  # repeat of the displacements, all the spans block is repeated size-1 times

    W0 = disp_rep_each - disp_rep_all
    a = np.sum(np.multiply(spans_rep_each, spans_rep_each), axis=1)
    b = np.sum(np.multiply(spans_rep_each, spans_rep_all), axis=1)
    c = np.sum(np.multiply(spans_rep_all, spans_rep_all), axis=1)
    d = np.sum(np.multiply(spans_rep_each, W0), axis=1)
    e = np.sum(np.multiply(spans_rep_all, W0), axis=1)
    be = np.multiply(b, e)
    cd = np.multiply(c, d)
    be_minus_cd = be - cd
    ac = np.multiply(a, c)
    b_squared = np.multiply(b, b)
    ac_minus_b_squared = ac - b_squared
    s_c = be_minus_cd / ac_minus_b_squared
    """
    for i in range(len(s_c)):
        if np.isnan(s_c[i]):
            s_c[i] = 0
    """
    s_c_repeated = np.repeat(s_c.reshape(-1, 1), dim, axis=1)
    G = disp_rep_each + np.multiply(s_c_repeated, spans_rep_each)

    b = np.where(np.isnan(G))
    c = np.where(np.isinf(G))
    G2 = np.delete(G, np.concatenate((b[0], c[0]), axis=0), axis=0).reshape(-1, dim)

    if len(G2) == 0:  # that means all the lines are parallel, take k random points from the displacements set
        return displacements;

    b2 = np.where(np.isnan(G2))
    c2 = np.where(np.isinf(G2))
    d2 = np.sum(b2)
    e2 = np.sum(c2)
    f2 = d2 + e2
    if f2 > 0:
        x = 2

    return G2
'''

def get_all_intersection_points(spans, displacements, chunk_size=262144):
    """
    Chunked replacement for the legacy pairwise "intersection point" routine.

    For every linear position ``pos`` in ``0 .. n*(n-1)-1`` a pair of line
    indices is formed exactly as the legacy code did:

        j = pos // (n - 1)
        r = pos %  (n - 1)
        i = r if r < j else r + 1      # legacy "each" index (skips j)
        k = pos % n                    # legacy "all" index (cycles 0..n-1)

    and the point on line ``i`` that is closest to line ``k`` is computed.
    Rows that contain NaN or Inf (parallel or ill-conditioned pairs) are
    dropped afterwards.

    Memory: the temporaries are limited to ``chunk_size`` rows, but the
    result buffer itself still holds ``n*(n-1)`` rows of ``d`` values.

    Args:
        spans (array-like): shape (n, d), direction vector of each line.
        displacements (array-like): shape (n, d), a point on each line.
        chunk_size (int): number of pairs processed per step; must be >= 1.

    Returns:
        np.ndarray: shape (m, d) with ``m <= n*(n-1)``, the finite closest
        points in legacy order. If ``n <= 1`` or every row was dropped, the
        ``displacements`` array is returned instead.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    spans = np.asarray(spans)
    displacements = np.asarray(displacements)
    chunk_size = int(chunk_size)
    if chunk_size < 1:
        # A non-positive chunk would never advance the write position.
        raise ValueError("chunk_size must be a positive integer")

    n, d = spans.shape
    if n <= 1:
        # Degenerate case: there is no pair to compare.
        return displacements

    # Work in a floating dtype. Storing the quotients in an integer buffer
    # would truncate them and turn NaN/Inf into ordinary-looking integers,
    # which defeats the filtering below.
    work_dtype = np.result_type(spans.dtype, displacements.dtype)
    if not np.issubdtype(work_dtype, np.inexact):
        work_dtype = np.dtype(np.float64)
    dirs = spans.astype(work_dtype, copy=False)
    pts = displacements.astype(work_dtype, copy=False)

    total = n * (n - 1)
    out = np.empty((total, d), dtype=work_dtype)

    for lo in range(0, total, chunk_size):
        hi = min(lo + chunk_size, total)

        # Rebuild the legacy pairing for this slice without materialising the
        # huge repeat/reshape index arrays.
        pos = np.arange(lo, hi, dtype=np.int64)
        j = pos // (n - 1)
        r = pos % (n - 1)
        i = r + (r >= j)
        k = pos % n

        u = dirs[i]
        v = dirs[k]
        p = pts[i]
        q = pts[k]

        # Same symbols as the legacy code:
        #   W0 = p - q, a = u.u, b = u.v, c = v.v, d = u.W0, e = v.W0
        #   s  = (b*e - c*d) / (a*c - b^2)
        w0 = p - q
        a = np.einsum('ij,ij->i', u, u)
        b = np.einsum('ij,ij->i', u, v)
        c = np.einsum('ij,ij->i', v, v)
        d_term = np.einsum('ij,ij->i', u, w0)
        e_term = np.einsum('ij,ij->i', v, w0)

        denom = a * c - b * b
        numer = b * e_term - c * d_term

        # No epsilon clamping, as in the legacy version: parallel pairs give
        # NaN/Inf on purpose and are filtered out below, so the matching
        # floating-point warnings are silenced here.
        with np.errstate(divide="ignore", invalid="ignore"):
            s = numer / denom
            out[lo:hi, :] = p + s[:, None] * u

    finite_rows = np.all(np.isfinite(out), axis=1)
    closest = out[finite_rows]

    # Legacy behaviour: if nothing survives (e.g. all lines are parallel),
    # fall back to the displacements.
    if closest.shape[0] == 0:
        return displacements
    return closest


def remove_parallel_lines(spans, displacements, weights, tol=1e-6):
    """
    Keep one representative line per (near-)parallel direction.

    Lines are scanned in order. The first line of every direction is kept,
    and each line whose normalised direction has an absolute dot product
    with it greater than ``1 - tol`` is discarded, whatever its displacement.

    Args:
        spans (np.ndarray): shape (n, d), direction vector of each line.
        displacements (np.ndarray): shape (n, d), a point on each line.
        weights (np.ndarray): shape (n,), weight of each line.
        tol (float): parallelism tolerance; smaller values are stricter.

    Returns:
        tuple: ``(spans, displacements, weights)`` restricted to the kept lines.
    """
    line_count, _ = spans.shape
    unit_dirs = spans / np.linalg.norm(spans, axis=1, keepdims=True)

    covered = np.zeros(line_count, dtype=bool)
    representatives = []
    for idx in range(line_count):
        if not covered[idx]:
            representatives.append(idx)
            cosines = np.dot(unit_dirs, unit_dirs[idx])
            # |cos| close to 1 means the angle is close to 0 or pi.
            covered[np.abs(cosines) > 1 - tol] = True

    return (
        spans[representatives],
        displacements[representatives],
        weights[representatives],
    )

def convert_points_to_lines(data, missing_dim=0):
    """
    Turn a point data set into axis-parallel lines.

    Each point becomes the line through it that runs along coordinate
    ``missing_dim``.

    Args:
        data (np.ndarray): input points, shape (n_samples, d).
        missing_dim (int): coordinate along which every line extends.

    Returns:
        spans (np.ndarray): shape (n_samples, d); every row is the unit vector
            of coordinate ``missing_dim``.
        displacements (np.ndarray): shape (n_samples, d); a copy of ``data``
            with coordinate ``missing_dim`` set to 0.
        weights (np.ndarray): shape (n_samples,), all ones.
    """
    assert isinstance(data, np.ndarray), "数据必须为 numpy.ndarray"
    assert len(data.shape) == 2, "输入数据必须是二维的 (n_samples, d)"
    n, d = data.shape
    assert 0 <= missing_dim < d, f"missing_dim 必须在 0 到 {d - 1} 之间"

    # Every line gets weight 1.
    weights = np.ones(n)

    # Zeroing the free coordinate projects each point onto the hyperplane
    # orthogonal to the line direction.
    displacements = data.copy()
    displacements[:, missing_dim] = 0.0

    # All lines share the same direction: the missing_dim axis.
    spans = np.zeros((n, d))
    spans[:, missing_dim] = 1.0

    return spans, displacements, weights

def blockshaped(arr, nrows, ncols):
    """
    Cut a 2D array into a stack of ``nrows x ncols`` tiles.

    The result has shape (n, nrows, ncols) with n * nrows * ncols == arr.size.
    Tiles are listed left to right, then top to bottom, and each tile keeps
    the layout it had inside ``arr``.
    """
    height, _ = arr.shape
    tiled = arr.reshape(height // nrows, nrows, -1, ncols)
    # Bring the two tile-grid axes together before flattening them into one.
    by_tile = tiled.swapaxes(1, 2)
    return by_tile.reshape(-1, nrows, ncols)


def compute_intersection(span1, disp1, span2, disp2):
    """
    Return the midpoint of the shortest segment connecting two lines.

    Line 1 is ``disp1 + t1 * span1`` and line 2 is ``disp2 + t2 * span2``.
    For intersecting lines the midpoint is the intersection itself. When the
    determinant of the 2x2 system is exactly zero (parallel directions),
    ``disp1`` is used on line 1 and is paired with the point of line 2 at
    parameter ``(span1 . w0) / (span1 . span2)``, or at parameter 0 if that
    dot product is zero.

    Returns:
        np.ndarray: the midpoint, same shape as ``disp1``.
    """
    offset = disp1 - disp2

    uu = np.dot(span1, span1)
    uv = np.dot(span1, span2)
    vv = np.dot(span2, span2)
    uw = np.dot(span1, offset)
    vw = np.dot(span2, offset)

    det = uu * vv - uv * uv
    if det != 0:
        t1 = (uv * vw - vv * uw) / det
        t2 = (uu * vw - uv * uw) / det
    elif uv != 0:
        # Parallel directions: stay at disp1 and project it onto line 2.
        t1, t2 = 0, uw / uv
    else:
        t1, t2 = 0, 0

    on_line1 = disp1 + t1 * span1
    on_line2 = disp2 + t2 * span2
    return (on_line1 + on_line2) / 2

def compute_line_to_point_distances(spans, displacements, point):
    """
    Euclidean distance from one point to each of n lines.

    Parameters:
        spans: (n, d) array, direction vector of every line (any non-zero length)
        displacements: (n, d) array, a point on every line
        point: (d,) array, the query point

    Returns:
        distances: (n,) array, distance from ``point`` to each line
    """
    offsets = point - displacements  # (n, d)
    along = np.sum(offsets * spans, axis=1)  # offset . span
    span_sq = np.sum(spans * spans, axis=1)  # ||span||^2
    scale = along / span_sq
    # Subtract the component parallel to the line; what remains is orthogonal.
    perpendicular = offsets - scale[:, np.newaxis] * spans
    return np.linalg.norm(perpendicular, axis=1)



def get_all_intersection_points_optimized(spans,displacements):
    """
    Fully vectorised computation of the pairwise "closest points" of a line set.

    According to the original author's notes, the result holds, for every
    line, the n-1 points on it that are closest to each of the other n-1
    lines. The derivation behind the reshapes below has not been re-verified
    here; comments marked "author's note" restate the original intent.

    Memory: builds a dense (size*dim) x (size*dim) matrix via ``np.outer`` and
    a (size, size**2, dim) intermediate, so it is only practical for small
    line sets.

    Args:
        spans: (size, dim) array, direction vector of each line.
        displacements: (size, dim) array, a point on each line.

    Returns:
        np.ndarray: (size*(size-1), dim) array; row ``j*size + m`` of the
        internal (size**2, dim) result, with the ``j == m`` rows removed.
    """
    # Leftover from a class-based version of this routine:
    #assert self.get_size() > 0, "set is empty"

    spans = spans
    displacements = displacements
    dim = np.shape(spans)[1]
    size = np.shape(spans)[0]

    spans_repeat_each_point = np.repeat(spans, size,
                                        axis=0)  # (size**2, dim): each spans[i] appears `size` times consecutively
    identity = np.identity(dim)
    identity_repeat_rows_all = np.repeat(identity.reshape(1, -1), size, axis=0).reshape(-1, dim)
    I_final = np.repeat(identity_repeat_rows_all, size, axis=0).reshape(size * dim,
                                                                        size * dim)  # the dim x dim identity tiled `size` times along rows and columns
    G_G_T_all_permutations = np.outer(spans,
                                      spans)  # (size*dim, size*dim) block matrix whose (i, j) block is spans[i] * spans[j]^T
    I_minus_G_G_T_all_permutations = I_final - G_G_T_all_permutations
    I_minus_G_G_tag_blocks = blockshaped(I_minus_G_G_T_all_permutations, dim,
                                              dim)  # split the big matrix into a stack of dim x dim blocks, in row-major block order
    I_minus_G_G_T_s = I_minus_G_G_tag_blocks[0:len(
        I_minus_G_G_tag_blocks):size + 1]  # every (size+1)-th block = the diagonal blocks, i.e. I - spans[i] * spans[i]^T
    I_minus_G_G_T_s_concatenated = I_minus_G_G_T_s.reshape(1, -1).T.reshape(-1,
                                                                            dim).T  # (dim, size*dim): the diagonal blocks laid out side by side
    I_minus_G_G_T_s_F = np.dot(spans, I_minus_G_G_T_s_concatenated)
    I_minus_G_G_T_s_F = I_minus_G_G_T_s_F.reshape(-1, 1).reshape(-1,
                                                                 dim)  # (size**2, dim): row i combines spans[j] with the k-th (I - G G^T), j = i // size, k = i % size
    I_minus_G_G_T_s_F_inv = np.linalg.pinv(I_minus_G_G_T_s_F.reshape(size ** 2, dim,
                                                                     1))  # batched pseudo-inverse of each dim x 1 column; author's note: entry i is ((I - G_j G_j^T) F_m)^+
    I_minus_G_G_T_s_F_inv = I_minus_G_G_T_s_F_inv.reshape(size ** 2, dim)
    displacements_repeat_each_point = np.repeat(displacements, size, axis=0).reshape(size ** 2, dim)
    displacements_repeat_all = np.repeat(displacements.reshape(1, -1), size, axis=0).reshape(size ** 2, dim)
    f_minus_g = displacements_repeat_all - displacements_repeat_each_point  # row j*size + m is displacements[m] - displacements[j]
    I_minus_G_G_T_s_dot_f_minus_g = np.dot(f_minus_g, I_minus_G_G_T_s_concatenated)
    I_minus_G_G_T_s_dot_f_minus_g = I_minus_G_G_T_s_dot_f_minus_g.reshape(-1, 1).reshape(size, -1,
                                                                                         dim)  # author's note: holds more products (I - G_i G_i^T)(f_k - g_l) than needed; the useful subset is picked in the next two statements
    inner_steps = np.arange(0, size ** 2, size + 1)
    I_minus_G_G_T_s_dot_f_minus_g_s = I_minus_G_G_T_s_dot_f_minus_g[:, inner_steps]
    I_minus_G_G_T_s_dot_f_minus_g_s = I_minus_G_G_T_s_dot_f_minus_g_s.reshape(-1,
                                                                              dim)  # author's note: row i is (I - G_i G_i^T)(f_j - g_i)
    final = np.multiply(I_minus_G_G_T_s_F_inv,
                        I_minus_G_G_T_s_dot_f_minus_g_s)  # element-wise product; summed per row below, this is a row-wise dot product
    final_x_stars = np.sum(final, axis=1)  # one scalar per pair: the parameter along the line
    F_x_s = np.multiply(spans_repeat_each_point.T, final_x_stars)
    F_x_s_minus_b = F_x_s.T + displacements_repeat_each_point  # rebuild the points: displacement + scalar * span
    indices = np.arange(0, len(F_x_s_minus_b), size + 1)
    F_x_s_minus_b = np.delete(F_x_s_minus_b, indices,
                              axis=0)  # drop the rows that pair a line with itself
    return F_x_s_minus_b


def get_4_approx_points(spans,displacements, k):
    """
    Pick k candidate centers from the pairwise closest points of the lines.

    According to the original documentation, the returned points minimise the
    sum of squared distances to the lines up to a factor of 4; that guarantee
    is not checked by this code.

    Args:
        spans (np.ndarray): shape (n, d), direction vector of each line.
        displacements (np.ndarray): shape (n, d), a point on each line.
        k (int): the number of required centers; must be positive.

    Returns:
        SetOfPoints: wraps k points sampled without replacement from the
        distinct closest points. If there are at most k distinct points, all
        closest points are wrapped as they are, duplicates included.
    """
    assert k > 0, "k <= 0"
    assert np.shape(spans)[0] > 0, "set is empty"

    dim = np.shape(spans)[1]

    # n(n-1) points: for every line, the points on it closest to each of the
    # other n-1 lines.
    candidates_raw = get_all_intersection_points_optimized(spans, displacements)
    candidates = np.unique(candidates_raw, axis=0)

    if np.shape(candidates.reshape(-1, dim))[0] <= k:
        # NOTE(review): this branch hands over the non-deduplicated points, so
        # the result may hold more than k rows. Kept as in the original.
        chosen = candidates_raw
    else:
        picked = np.random.choice(len(candidates), k, replace=False)
        chosen = candidates[picked]

    return SetOfPoints(chosen)


def sample_k_points_from_centerset(centers_set, k, seed=None):
    """
    Draw k distinct rows uniformly at random from ``centers_set``.

    Args:
        centers_set (np.ndarray): shape (n, d), the candidate points.
        k (int): number of rows to draw; must not exceed n.
        seed (int, optional): if given, reseeds NumPy's global random state
            before drawing.

    Returns:
        np.ndarray: shape (k, d), the sampled rows.

    Raises:
        ValueError: if ``k`` is larger than the number of rows.
    """
    # Reseeding happens first, so it takes effect even when the size check
    # below fails.
    if seed is not None:
        np.random.seed(seed)

    available = centers_set.shape[0]
    if k > available:
        raise ValueError(f"Cannot sample {k} points from a set of only {available} points.")

    chosen_rows = np.random.choice(available, size=k, replace=False)
    return centers_set[chosen_rows]


def generate_n_nonparallel_lines(n, d, seed=None):
    """
    Create n random lines in d dimensions whose directions are pairwise non-parallel.

    Directions are drawn one at a time from a standard normal distribution and
    normalised. A draw is rejected when its absolute cosine with any accepted
    direction exceeds 0.999. At most ``1000 * n`` draws are made.

    Args:
        n (int): number of lines.
        d (int): dimension of the space.
        seed (int, optional): if given, reseeds NumPy's global random state.

    Returns:
        spans: (n, d) array of unit direction vectors
        displacements: (n, d) array drawn uniformly from [-5, 5)
        weights: (n,) array of ones

    Raises:
        RuntimeError: if fewer than n directions were accepted within the
        attempt budget.
    """
    if seed is not None:
        np.random.seed(seed)

    attempt_budget = 1000 * n
    accepted = []
    tries = 0

    while len(accepted) < n and tries < attempt_budget:
        tries += 1

        direction = np.random.randn(d)
        direction /= np.linalg.norm(direction)

        # |cos| above 0.999 counts as (nearly) parallel.
        too_close = any(
            np.abs(np.dot(direction, previous)) > 0.999 for previous in accepted
        )
        if not too_close:
            accepted.append(direction)

    if len(accepted) < n:
        raise RuntimeError(f"Failed to generate {n} non-parallel lines in {d}D after {attempt_budget} attempts.")

    spans = np.array(accepted)
    displacements = np.random.uniform(low=-5, high=5, size=(n, d))
    weights = np.ones(n)
    return spans, displacements, weights


def point_line_distance(point, line_point, line_dir):
    """
    Return the SQUARED distance from ``point`` to the line ``line_point + t * line_dir``.

    ``line_dir`` may have any non-zero length.
    """
    offset = point - line_point
    # Parameter of the foot of the perpendicular dropped from the point.
    t = np.dot(offset, line_dir) / np.dot(line_dir, line_dir)
    foot = line_point + t * line_dir
    residual = point - foot
    return np.sum(residual ** 2)





def gram_schmidt_complete_basis(v1, v2, d):
    """
    Build an orthonormal basis whose first two vectors span ``v1`` and ``v2``.

    ``v1`` is normalised, ``v2`` is orthogonalised against it, and the basis
    is then filled up to ``d`` vectors with orthogonalised random normal
    vectors drawn from NumPy's global random state.

    Returns:
        np.ndarray: the basis vectors as rows; shape (d, d) when d >= 2.

    Raises:
        ValueError: if ``v2`` has (numerically) no component orthogonal to ``v1``.
    """
    first = v1 / np.linalg.norm(v1)
    residual = v2 - np.dot(v2, first) * first
    residual_norm = np.linalg.norm(residual)
    if residual_norm < 1e-10:
        raise ValueError("v1 and v2 are linearly dependent")
    vectors = [first, residual / residual_norm]

    while len(vectors) < d:
        candidate = np.random.randn(d)
        # Strip the components along every vector found so far, one by one.
        for existing in vectors:
            candidate -= np.dot(candidate, existing) * existing
        length = np.linalg.norm(candidate)
        if length < 1e-10:
            # Candidate was (numerically) inside the current span; redraw.
            continue
        vectors.append(candidate / length)

    return np.stack(vectors, axis=0)


def generate_cube_points(span1, disp1, span2, disp2, r):
    """
    Build a 3-level grid of points around the meeting point of two lines.

    The grid is centered at the point returned by compute_intersection and is
    expressed in the orthonormal basis returned by gram_schmidt_complete_basis
    for the two span vectors. Every basis coefficient takes each of the values
    -r, 0 and r.

    Returns:
        np.ndarray: shape (3**d, d), where d is the length of ``span1``.

    Raises:
        ValueError: if compute_intersection returns None.
    """
    dim = span1.shape[0]
    center = compute_intersection(span1, disp1, span2, disp2)
    if center is None:
        raise ValueError("Lines do not intersect (are parallel or skewed).")

    basis = gram_schmidt_complete_basis(span1, span2, dim)  # rows are basis vectors

    grid = []
    # All 3**dim coefficient combinations, in itertools.product order.
    for coeffs in product((-r, 0, r), repeat=dim):
        grid.append(center + np.dot(coeffs, basis))

    return np.array(grid)

def max_distance_to_pointset(point, pointset):
    """
    Return the largest Euclidean distance between ``point`` and any row of ``pointset``.

    Parameters:
    - point: np.ndarray of shape (d,)
    - pointset: np.ndarray of shape (n, d)

    Returns:
    - the maximum of the n point-to-point distances
    """
    offsets = pointset - point
    return np.max(np.linalg.norm(offsets, axis=1))

def sample_random_cube_points(center, basis, r, num_samples):
    """
    Draw ``num_samples`` random grid points around ``center``.

    Each sample picks one coefficient from {-r, 0, r} per row of ``basis``
    (using NumPy's global random state) and returns ``center`` plus the
    matching combination of basis rows.
    """
    dim = basis.shape[0]
    coefficients = np.random.choice([-r, 0, r], size=(num_samples, dim))
    shifts = coefficients @ basis
    return center + shifts


def remove_farthest_lines(spans, displacements, point, alpha=0.2):
    """
    Drop the lines that lie farthest from a given point.

    ``ceil(alpha * n)`` lines are removed, and always at least one.

    Args:
        spans: (n, d) direction vector of each line
        displacements: (n, d) a point on each line
        point: (d,) the reference point
        alpha: float, fraction of lines to remove

    Returns:
        spans_remain: direction vectors of the kept lines
        displacements_remain: displacements of the kept lines
        removed_idx: indices of the removed lines, farthest first
    """
    # Point-to-line distance: strip the component along each line direction.
    offsets = point[None, :] - displacements  # (n, d)
    along = np.sum(offsets * spans, axis=1)
    span_sq = np.sum(spans * spans, axis=1)
    residual = offsets - (along / span_sq)[:, None] * spans
    dists = np.linalg.norm(residual, axis=1)  # (n,)

    total = len(dists)
    drop_count = max(1, int(np.ceil(alpha * total)))
    # Ascending sort, reversed: farthest lines come first.
    removed_idx = np.argsort(dists)[::-1][:drop_count]

    keep_mask = np.ones(total, dtype=bool)
    keep_mask[removed_idx] = False
    return spans[keep_mask], displacements[keep_mask], removed_idx



def best_point_for_lines(points, spans, displacements, agg="sum", weights=None):
    """
    Pick the candidate point with the smallest aggregated distance to a set of lines.

    Candidates are processed one at a time, so no (n, m, d) tensor is built.

    Args:
        points: (m, d) candidate points
        spans: (n, d) direction vector of each line
        displacements: (n, d) a point on each line
        agg: 'sum' | 'mean' | 'max', how the n distances are combined
        weights: (n,) per-line multipliers applied to the distances, or None

    Returns:
        tuple: ``(best_idx, best_point)``; the first candidate wins ties.

    Raises:
        ValueError: if ``agg`` is not one of the supported names (raised while
        evaluating the first candidate).
    """
    candidates = np.asarray(points)        # (m, d)
    directions = np.asarray(spans)         # (n, d)
    anchors = np.asarray(displacements)    # (n, d)

    line_count, _ = directions.shape
    candidate_count = candidates.shape[0]

    # Quantities that do not depend on the candidate point.
    span_sq = np.maximum(np.sum(directions * directions, axis=1), 1e-12)
    anchor_sq = np.sum(anchors * anchors, axis=1)
    anchor_dot_span = np.sum(anchors * directions, axis=1)
    line_weights = None if weights is None else np.asarray(weights).reshape(line_count)

    reducers = (("sum", np.sum), ("mean", np.mean), ("max", np.max))

    best_idx, best_cost = 0, np.inf
    for idx in range(candidate_count):
        p = candidates[idx]

        # ||p - a_i||^2 = ||p||^2 + ||a_i||^2 - 2 a_i . p
        offset_sq = np.dot(p, p) + anchor_sq - 2.0 * (anchors @ p)
        # (p - a_i) . s_i = p . s_i - a_i . s_i
        along = directions @ p - anchor_dot_span

        # Squared perpendicular distance, clamped against round-off.
        perp_sq = np.maximum(offset_sq - (along * along) / span_sq, 0.0)
        dists = np.sqrt(perp_sq)
        if line_weights is not None:
            dists = dists * line_weights

        for name, reduce_fn in reducers:
            if agg == name:
                cost = reduce_fn(dists)
                break
        else:
            raise ValueError("agg must be one of {'sum','mean','max'}")

        if cost < best_cost:
            best_idx, best_cost = idx, cost

    return best_idx, candidates[best_idx]




def learning_augmented_for_lines(spans,displacements,corrupted_labels,k,alpha,radis,iter_num=5):
    """
    Compute one center per predicted cluster from (possibly wrong) labels.

    For every cluster ``i`` in ``range(k)`` the following is repeated
    ``iter_num`` times:
      1. pick two distinct lines of the cluster at random,
      2. build the grid of anchor candidates around their meeting point with
         half-width ``radis[i]`` and choose one anchor at random,
      3. discard the ``alpha`` fraction of lines farthest from the anchor,
      4. among the pairwise closest points of the remaining lines, keep the
         one with the smallest summed distance to those lines.
    The cluster center is the best of these ``iter_num`` points with respect
    to the lines of cluster ``i``.

    Progress is printed to stdout (cluster index, then iteration index).

    Args:
        spans (np.ndarray): shape (n, d), direction vector of each line.
        displacements (np.ndarray): shape (n, d), a point on each line.
        corrupted_labels (np.ndarray): shape (n,), predicted cluster labels.
        k (int): number of clusters.
        alpha (float): fraction of lines removed in step 3.
        radis: indexable by cluster, grid half-width per cluster.
        iter_num (int): number of random restarts per cluster (default 5).

    Returns:
        np.ndarray: shape (k, d), one center per cluster.

    Raises:
        ValueError: if ``iter_num`` is smaller than 1, or if a cluster holds
        fewer than two lines (two are needed to build the anchor grid).
    """
    if iter_num < 1:
        raise ValueError("iter_num must be at least 1")

    centers = []
    for i in range(k):
        print(i)
        members = np.where(corrupted_labels == i)[0]
        spans_i = spans[members]
        displacements_i = displacements[members]
        num_i = np.shape(spans_i)[0]
        if num_i < 2:
            raise ValueError(
                f"cluster {i} contains {num_i} line(s); at least 2 are required"
            )

        best_points = []
        for j in range(iter_num):
            print(j)
            # Two distinct lines of this cluster define the anchor grid.
            first, second = np.random.choice(num_i, size=2, replace=False)
            anchor_set = generate_cube_points(
                spans_i[first], displacements_i[first],
                spans_i[second], displacements_i[second],
                radis[i],
            )
            anchor_point = anchor_set[np.random.randint(anchor_set.shape[0])]

            # NOTE(review): the trimming runs on the FULL line set, not only
            # on the lines of cluster i. Kept as in the original; confirm
            # that this is intended.
            spans_new, disps_new, _ = remove_farthest_lines(
                spans, displacements, anchor_point, alpha
            )

            # 1-median candidate: best pairwise closest point for the
            # trimmed line set.
            centroid_set = get_all_intersection_points(spans_new, disps_new)
            _, best_point = best_point_for_lines(
                centroid_set, spans_new, disps_new, agg="sum", weights=None
            )
            best_points.append(best_point)

        _, final_point = best_point_for_lines(
            best_points, spans_i, displacements_i, agg="sum", weights=None
        )
        centers.append(final_point)
    return np.array(centers)



def main():
    """
    Run the learning-augmented line-clustering experiment on synthetic data.

    Each run samples k centers among the pairwise closest points of the lines,
    derives reference labels from them, corrupts a fraction ``alpha`` of every
    cluster, and runs learning_augmented_for_lines on the corrupted labels.
    The minimum, maximum, mean and standard deviation of the resulting cost,
    together with the mean run time, are printed at the end.
    """
    # ------------------------------------------------------------------
    # Disabled alternative data sources, kept for reference.
    #
    # # road dataset
    # north, south = 32.070, 32.058
    # east, west = 118.790, 118.770
    # G = ox.graph_from_bbox(north, south, east, west, network_type='drive')
    # G_proj = ox.project_graph(G)
    # edges = ox.graph_to_gdfs(G_proj, nodes=False, edges=True)
    # lines = edges["geometry"]
    #
    # # 3. build spans, displacements, weights
    # spans = []
    # displacements = []
    # weights = []
    #
    # for geom in lines:
    #     if isinstance(geom, LineString):
    #         coords = list(geom.coords)
    #         for i in range(len(coords) - 1):
    #             p1 = np.array(coords[i])
    #             p2 = np.array(coords[i + 1])
    #             vec = p2 - p1
    #             norm = np.linalg.norm(vec)
    #             if norm == 0:  # skip repeated points
    #                 continue
    #             span = vec / norm
    #             # displacement: closest point = p1 + proj(-p1 onto span)
    #             proj_len = -np.dot(p1, span)
    #             disp = p1 + proj_len * span
    #
    #             spans.append(span)
    #             displacements.append(disp)
    #             weights.append(1.0)
    #
    # # 4. convert to numpy arrays
    # spans = np.array(spans)
    # displacements = np.array(displacements)
    # weights = np.array(weights)
    #
    # spans, displacements, weights = remove_parallel_lines(spans, displacements, weights)
    # spans = spans / np.linalg.norm(spans, axis=1, keepdims=True)
    # scaler = StandardScaler()
    # displacements = scaler.fit_transform(displacements)
    # print(np.shape(spans)[0])
    #
    # ################# data_set ####################
    #
    # spans = np.loadtxt("spans_10000_5.txt")  # shape: (n, d)
    # displacements = np.loadtxt("displacements_10000_5.txt")  # shape: (n, d)
    # weights = np.loadtxt("weights_10000_5.txt")  # shape: (n, 1) or (n,)
    # weights = weights.flatten()
    # ------------------------------------------------------------------
    num_lines = 1000
    dim = 5

    spans, displacements, weights = generate_n_nonparallel_lines(num_lines, dim, seed=2025)
    weights = np.ones(num_lines)  # every line has weight 1
    spans = spans / np.linalg.norm(spans, axis=1, keepdims=True)
    scaler = StandardScaler()
    displacements = scaler.fit_transform(displacements)

    k = 5
    alpha = 0.5
    dim = np.shape(spans)[1]
    n = np.shape(spans)[0]

    dataset_cost10_list = []
    cost_list = []
    total_list = []

    tt = 0
    run_times = 3
    for _ in range(run_times):
        ##################### obtain reference solution #################
        # candidate centers: pairwise closest points of the lines
        centroid_set = get_all_intersection_points(spans, displacements)
        # initial center set
        center_set = sample_k_points_from_centerset(centroid_set, k)
        # reference labels
        opt_labels = assign_labels(spans, displacements, center_set)

        corrupted_labels = corrupt_labels(opt_labels, k=k, error_rate=alpha, seed=2025)

        t0 = time.time()
        radis1 = max_line_distance_per_cluster(spans, displacements, corrupted_labels, center_set)

        # radis1 is a dict keyed by cluster label. Iterating it directly would
        # yield the labels, not the distances, so look each cluster up
        # explicitly and halve its radius.
        radis = [radis1.get(cluster, 0.0) / 2 for cluster in range(k)]

        print(radis)

        pre_center_set = learning_augmented_for_lines(
            spans.copy(), displacements.copy(), corrupted_labels, k, alpha, radis
        )

        data_cost = compute_cost_to_centers(spans, displacements, weights, pre_center_set)
        print(data_cost)
        t1 = time.time()
        dataset_cost10_list.append(data_cost)

        tt += t1 - t0

    tt = tt / run_times
    dataset_cost10_list = np.array(dataset_cost10_list)
    min_cost = np.min(dataset_cost10_list)
    max_cost = np.max(dataset_cost10_list)
    mean_cost = np.mean(dataset_cost10_list)
    std_cost = np.std(dataset_cost10_list)
    cost_list.append(min_cost)
    cost_list.append(max_cost)
    cost_list.append(mean_cost)
    cost_list.append(std_cost)
    cost_list.append(tt)
    total_list.append(cost_list)

    print('run times:' + str(run_times))
    print('n: ' + str(n) + ', d: ' + str(dim)+ ', k: ' + str(k))
    #print('north: ' + str(north) + ', south: ' + str(south))
    #print('east: ' + str(east) + ', west: ' + str(west))
    print('min_cost: ' + str(min_cost) + ', max_cost: ' + str(max_cost)+ ', mean_cost: ' + str(mean_cost)+ ', std_cost: ' + str(std_cost)+ ', time: ' + str(tt))
    print(total_list)



if __name__ == '__main__':
    main()
