import math

import numpy as np

from coreset_for_k_means_for_lines_median import CorsetForKMeansForLines
from main_LS import generate_n_nonparallel_lines
from parameters_config import ParameterConfig
from set_of_lines_median import SetOfLines
import time
from shapely.geometry import LineString
import osmnx as ox
from sklearn.preprocessing import StandardScaler

def remove_parallel_lines(spans, displacements, weights, tol=1e-6):
    """
    Deduplicate lines by direction, keeping the first line of every parallel group.

    Lines are scanned in their original order. A line that has not been claimed by
    an earlier survivor is kept. Every line whose direction is parallel or
    anti-parallel to the kept line is then claimed and dropped. "Parallel" means
    the absolute cosine between the two directions exceeds ``1 - tol``.
    Displacements are not compared, so distinct but parallel lines are also removed.

    Args:
        spans (np.ndarray): shape (n, d), direction vector of each line. Vectors are
            normalized internally, so their length does not matter.
        displacements (np.ndarray): shape (n, d), a point on each line; filtered
            alongside ``spans``.
        weights (np.ndarray): shape (n,), weight of each line; filtered alongside
            ``spans``.
        tol (float): tolerance on the absolute cosine, not on the angle itself.
            Larger values treat more lines as parallel and drop more of them.

    Returns:
        tuple: ``(spans, displacements, weights)`` restricted to the surviving lines,
        in their original order.
    """
    count, _ = spans.shape
    directions = spans / np.linalg.norm(spans, axis=1, keepdims=True)
    claimed = np.zeros(count, dtype=bool)
    survivors = []

    for idx, direction in enumerate(directions):
        if claimed[idx]:
            continue
        survivors.append(idx)
        # |cos| close to 1 means the angle between directions is close to 0 or pi.
        claimed |= np.abs(np.dot(directions, direction)) > 1 - tol

    return spans[survivors], displacements[survivors], weights[survivors]


def compute_cost_to_centers(spans, displacements, weights, centers):
    """
    Compute the weighted k-median cost of a set of lines with respect to given centers.

    Line ``i`` is ``{displacements[i] + t * spans[i] : t in R}``. For each line, the
    perpendicular Euclidean distance to every center is computed, and the line
    contributes ``weights[i] * min_j dist(centers[j], line_i)``. The distances are
    NOT squared.

    Args:
        spans (array-like): shape (n, d), direction vector of each line. The vectors
            need not be unit length, because the projection is divided by
            ||span||^2. That denominator is floored at 1e-12 so a zero span does not
            divide by zero.
        displacements (array-like): shape (n, d), a point on each line.
        weights (array-like): shape (n,), weight of each line.
        centers (array-like): shape (k, d) with k >= 1, the candidate center points.

    Returns:
        float: ``sum_i weights[i] * min_j dist(centers[j], line_i)``. This is 0.0
        when there are no lines.

    Raises:
        ValueError: if ``centers`` is empty.
    """
    spans = np.asarray(spans)
    displacements = np.asarray(displacements)
    weights = np.asarray(weights)
    centers = np.asarray(centers)
    if centers.shape[0] == 0:
        raise ValueError("centers must contain at least one point")

    # Broadcast every line against every center: (n, 1, d) vs (1, k, d) -> (n, k, d).
    span_expand = spans[:, np.newaxis, :]
    diff = centers[np.newaxis, :, :] - displacements[:, np.newaxis, :]

    norm_sq = np.sum(diff ** 2, axis=2)  # (n, k) squared distance to the anchor point
    proj = np.sum(diff * span_expand, axis=2)  # (n, k) unnormalized projection on span
    s_norm_sq = np.sum(span_expand * span_expand, axis=2)  # (n, 1)

    # Pythagoras: perpendicular^2 = |diff|^2 - (diff . s)^2 / |s|^2.
    perp_sq = norm_sq - proj ** 2 / np.maximum(s_norm_sq, 1e-12)
    # Clamp tiny negative values caused by floating-point cancellation.
    perp_sq = np.maximum(perp_sq, 0.0)
    min_distances = np.min(np.sqrt(perp_sq), axis=1)  # (n,) distance to nearest center

    return np.sum(weights * min_distances)





def _standardize_lines(spans, displacements):
    """Rescale every direction vector to unit length and z-score each displacement coordinate."""
    spans = spans / np.linalg.norm(spans, axis=1, keepdims=True)
    displacements = StandardScaler().fit_transform(displacements)
    return spans, displacements


def _load_osm_lines(north, south, east, west):
    """
    Turn every segment of the OSM drive network inside a bounding box into a line.

    Each consecutive coordinate pair of every ``LineString`` edge of the projected
    graph becomes one line with these attributes:
      - direction: the segment's unit direction;
      - displacement: the point on the line closest to the origin;
      - weight: 1.
    Near-parallel duplicates are then removed with ``remove_parallel_lines``.

    Returns:
        tuple: ``(spans, displacements, weights)`` as numpy arrays.
    """
    # NOTE(review): the positional (north, south, east, west) call matches the
    # pre-2.0 osmnx API; osmnx >= 2.0 expects bbox=(west, south, east, north).
    # Confirm against the installed version.
    graph = ox.graph_from_bbox(north, south, east, west, network_type='drive')
    edges = ox.graph_to_gdfs(ox.project_graph(graph), nodes=False, edges=True)

    spans = []
    displacements = []
    for geom in edges["geometry"]:
        if not isinstance(geom, LineString):
            continue
        coords = [np.array(c) for c in geom.coords]
        for p1, p2 in zip(coords[:-1], coords[1:]):
            vec = p2 - p1
            length = np.linalg.norm(vec)
            if length == 0:  # skip repeated vertices
                continue
            span = vec / length
            # Closest point on the line to the origin: p1 + <-p1, span> * span.
            spans.append(span)
            displacements.append(p1 - np.dot(p1, span) * span)

    spans = np.array(spans)
    displacements = np.array(displacements)
    weights = np.ones(len(spans))
    print(np.shape(spans)[0])
    # A single pass already leaves no pair of near-parallel lines, so the second
    # call in the old inline version was a no-op.
    return remove_parallel_lines(spans, displacements, weights)


def _load_lines(tag):
    """Load spans/displacements/weights from ``spans_{tag}.txt`` etc.; weights are flattened to (n,)."""
    spans = np.loadtxt(f"spans_{tag}.txt")
    displacements = np.loadtxt(f"displacements_{tag}.txt")
    weights = np.loadtxt(f"weights_{tag}.txt").flatten()
    return spans, displacements, weights


def _save_lines(spans, displacements, weights, tag):
    """Write spans/displacements/weights to ``*_{tag}.txt`` with 6 decimals (weights as a column)."""
    print('save data')
    np.savetxt(f"spans_{tag}.txt", spans, fmt="%.6f")
    np.savetxt(f"displacements_{tag}.txt", displacements, fmt="%.6f")
    np.savetxt(f"weights_{tag}.txt", np.reshape(weights, (-1, 1)), fmt="%.6f")


def main(*, num_lines=1000, dim=5, k=3, run_times=3, n_sample=10, seed=2025,
         source="synthetic", bbox=(32.070, 32.058, 118.790, 118.770), save_tag=None):
    """
    Benchmark the coreset for k-median of lines.

    The data is prepared once. Then, ``run_times`` times, a coreset is built, its
    approximate centers are computed, and the cost is evaluated. Finally, the
    min/max/mean/std of the costs and the mean runtime per run are printed.
    Calling ``main()`` with no arguments reproduces the original synthetic
    experiment.

    Args:
        num_lines: number of synthetic lines; also part of the file tag for
            ``source="file"``.
        dim: dimension of the synthetic lines; also part of the file tag.
        k: number of centers.
        run_times: number of independent coreset constructions; must be >= 1.
        n_sample: coreset size, passed as ``m`` to ``coreset``.
        seed: seed for ``generate_n_nonparallel_lines``.
        source: where the lines come from:
            - ``"synthetic"``: generated lines;
            - ``"osm"``: road segments inside ``bbox``;
            - ``"file"``: ``spans_{num_lines}_{dim}.txt`` etc., loaded as-is.
        bbox: ``(north, south, east, west)``, used only when ``source="osm"``.
        save_tag: if not None, the prepared lines are saved to ``*_{save_tag}.txt``.

    Raises:
        ValueError: if ``source`` is unknown or ``run_times < 1``.
    """
    if run_times < 1:
        raise ValueError("run_times must be at least 1")

    if source == "synthetic":
        spans, displacements, weights = generate_n_nonparallel_lines(num_lines, dim, seed=seed)
        spans, displacements = _standardize_lines(spans, displacements)
    elif source == "osm":
        spans, displacements, weights = _load_osm_lines(*bbox)
        spans, displacements = _standardize_lines(spans, displacements)
    elif source == "file":
        # Loaded without rescaling; presumably the files were saved after standardization.
        spans, displacements, weights = _load_lines(f"{num_lines}_{dim}")
    else:
        raise ValueError(f"unknown source: {source!r}")
    print(np.shape(spans)[0])

    if save_tag is not None:
        _save_lines(spans, displacements, weights, save_tag)

    # n and d are taken from the actual data (OSM/file data ignore num_lines/dim).
    dim = np.shape(spans)[1]
    print(dim)
    num_lines = np.shape(spans)[0]

    costs = []
    total_time = 0.0
    print('sampling number')
    for t in range(run_times):
        print(t)
        start = time.perf_counter()  # monotonic clock, suited to benchmarking
        config = ParameterConfig(num_lines, k, dim)
        lines = SetOfLines(spans, displacements, weights)
        coreset = CorsetForKMeansForLines(config).coreset(lines, k, m=n_sample)
        print('coreset')

        approx_centers = coreset.get_4_approx_points(k)
        # Alternative: approx_centers = coreset.get_4_approx_points_ex_search(k)
        # NOTE(review): the cost is evaluated on the weighted coreset, not on the full
        # dataset (spans/displacements/weights); confirm this is the intended metric.
        data_cost = compute_cost_to_centers(
            coreset.spans, coreset.displacements, coreset.weights, approx_centers.points)
        print(data_cost)
        total_time += time.perf_counter() - start
        costs.append(data_cost)

    mean_time = total_time / run_times
    costs = np.array(costs)
    min_cost = np.min(costs)
    max_cost = np.max(costs)
    mean_cost = np.mean(costs)
    std_cost = np.std(costs)
    total_list = [[min_cost, max_cost, mean_cost, std_cost, mean_time]]

    print('coresets')
    print('run times:' + str(run_times))
    print('n: ' + str(num_lines) + ', d: ' + str(dim) + ', k: ' + str(k))
    if source == "osm":
        north, south, east, west = bbox
        print('north: ' + str(north) + ', south: ' + str(south))
        print('east: ' + str(east) + ', west: ' + str(west))
    print('min_cost: ' + str(min_cost) + ', max_cost: ' + str(max_cost) + ', mean_cost: ' + str(
        mean_cost) + ', std_cost: ' + str(std_cost) + ', time: ' + str(mean_time))
    print(total_list)


# Run the coreset benchmark only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()
