import numpy as np
from sklearn.neighbors import KDTree

class LineIndex:
    """Approximate nearest-segment lookup based on direction, then midpoint.

    Each segment is summarised by its unit direction vector (``end - start``
    normalised) and its midpoint. A query first selects the segments whose
    direction vectors are closest to the query's direction, then returns the
    one among those whose midpoint is closest to the query's midpoint.

    This is a heuristic: the result is the best midpoint match *among the
    direction candidates*, not necessarily the globally nearest segment.

    NOTE(review): directions are signed, so a segment and its reverse
    (start/end swapped) have opposite direction vectors and are treated as
    maximally dissimilar. Confirm that this is intended for the callers.

    Attributes:
        lines: The sequence passed to the constructor, kept as-is; query
            results are elements of this sequence.
        directions: ``(n_lines, dim)`` array of unit direction vectors.
        midpoints: ``(n_lines, dim)`` array of segment midpoints.
        direction_index: KDTree built over ``directions``.
        spatial_index: KDTree built over ``midpoints``. Not used by
            ``find_nearest_line``; kept as a public attribute.
    """

    def __init__(self, lines):
        """Build the direction and midpoint indexes.

        Args:
            lines: Indexable sequence of ``(start, end)`` pairs. All
                endpoints must be array-likes of the same dimensionality.

        Raises:
            ValueError: If ``lines`` is empty, is not a sequence of
                ``(start, end)`` point pairs, or contains a segment whose
                length is zero or not finite.
        """
        segments = np.asarray(lines, dtype=float)
        if segments.size == 0:
            raise ValueError("lines must contain at least one segment")
        if segments.ndim != 3 or segments.shape[1] != 2:
            raise ValueError(
                "lines must be a sequence of (start, end) point pairs"
            )

        starts = segments[:, 0]
        ends = segments[:, 1]
        deltas = ends - starts
        lengths = np.linalg.norm(deltas, axis=1)
        # A zero or non-finite length would turn the direction into NaN/inf,
        # which the KDTree rejects with a far less helpful message.
        if not np.all(np.isfinite(lengths) & (lengths > 0)):
            raise ValueError(
                "every line must have a finite, non-zero length"
            )

        self.lines = lines
        self.directions = deltas / lengths[:, np.newaxis]
        self.midpoints = (starts + ends) / 2

        self.direction_index = KDTree(self.directions)
        self.spatial_index = KDTree(self.midpoints)

    def find_nearest_line(self, query_start, query_end, n_candidates=10):
        """Return the indexed line that best matches the query segment.

        Args:
            query_start: Array-like start point of the query segment.
            query_end: Array-like end point of the query segment.
            n_candidates: How many direction-nearest lines to consider
                before picking the one with the closest midpoint. Clamped
                to the number of indexed lines.

        Returns:
            The element of ``self.lines`` selected by the two-stage search.

        Raises:
            ValueError: If ``n_candidates`` is smaller than 1 or the query
                segment has a zero or non-finite length.
        """
        if n_candidates < 1:
            raise ValueError("n_candidates must be at least 1")

        query_start = np.asarray(query_start, dtype=float)
        query_end = np.asarray(query_end, dtype=float)
        delta = query_end - query_start
        length = np.linalg.norm(delta)
        if not (np.isfinite(length) and length > 0):
            raise ValueError(
                "query segment must have a finite, non-zero length"
            )
        query_direction = delta / length

        # KDTree.query raises if k exceeds the number of indexed points, so
        # never ask for more candidates than there are lines.
        k = min(n_candidates, len(self.directions))
        _, dir_idxs = self.direction_index.query([query_direction], k=k)
        candidate_idxs = dir_idxs[0]

        # With only a handful of candidates a direct distance computation is
        # cheaper than building a temporary KDTree for every query.
        query_midpoint = (query_start + query_end) / 2
        offsets = self.midpoints[candidate_idxs] - query_midpoint
        squared_distances = np.sum(offsets * offsets, axis=1)

        nearest_idx = int(candidate_idxs[np.argmin(squared_distances)])
        return self.lines[nearest_idx]