import numpy as np
import torch


'''
From https://newbedev.com/find-the-shortest-distance-between-a-point-and-line-segments-not-line

'''

def length(v):
    '''Return the Euclidean norm of every row of ``v``.

    The squares are summed along dimension 1, so ``v`` must have at least
    two dimensions; for an (N, D) input the result has shape (N,).
    '''
    squared_norm = v.pow(2).sum(dim=1)
    return squared_norm.sqrt()

def vector(b,e):
    '''Return the displacement from ``b`` (begin) to ``e`` (end), i.e. ``e - b``.'''
    displacement = e - b
    return displacement

def distance(p0,p1):
    '''Return the row-wise Euclidean distance between ``p0`` and ``p1``.'''
    # Displacement from p0 to p1, then its per-row norm.
    delta = p1 - p0
    return length(delta)

def add(v,w):
    '''Return the element-wise sum ``v + w``.'''
    total = v + w
    return total

def unit(v):
    '''Scale every row of the 2-D tensor ``v`` to unit Euclidean length.

    A row of zero length is divided by zero, so it does not come back finite.
    '''
    norms = length(v)
    # Put the rows on the last axis so the (N,) norms broadcast one per row,
    # divide, then restore the original orientation.
    # https://stackoverflow.com/questions/19602187/numpy-divide-each-row-by-a-vector-element
    columns_first = v.transpose(0, 1)
    normalised = columns_first / norms
    return normalised.transpose(0, 1)

def torch_2d_cross(a, b):
    '''Return the row-wise scalar 2-D cross product ``a_x * b_y - b_x * a_y``.

    Column 0 of each argument is used as x and column 1 as y.
    '''
    a_x, a_y = a[:, 0], a[:, 1]
    b_x, b_y = b[:, 0], b[:, 1]
    return (a_x * b_y) - (b_x * a_y)

def get_segment_to_point_dist(start, end, points, verbose=False):
    '''Row-wise distance from 2-D points to line segments.

    Row i of ``points`` is measured against the segment that runs from row i
    of ``start`` to row i of ``end``.

    Args:
        start: segment start points; column 0 is used as x, column 1 as y.
        end: segment end points, laid out like ``start``.
        points: query points, laid out like ``start``.
        verbose: when True, print one diagnostic line per point.

    Returns:
        Tuple ``(dist, nearest, t, angle)``:
            dist    (N,)   distance from each point to its segment.
            nearest (N, 2) closest location on each segment.
            t       (N,)   position of ``nearest`` along the segment, clipped
                           to [0, 1] (0 is ``start``, 1 is ``end``).
            angle   (N,)   sign (+1, 0 or -1) of the 2-D cross product of
                           ``end - start`` with ``point - start``.

    A zero-length segment (``start == end``) is treated as the single point
    ``start``: ``t`` is 0, ``nearest`` is ``start`` and ``angle`` is 0.
    '''
    # Segment direction and the vector from the segment start to each point.
    line_vec = end - start
    pnt_vec = points - start

    # Projection parameter: t = (pnt_vec . line_vec) / |line_vec|^2.
    seg_len_sq = torch.sum(line_vec * line_vec, dim=1)
    projection = torch.sum(line_vec * pnt_vec, dim=1)
    # For a zero-length segment the numerator is 0 as well; dividing by 1
    # instead of 0 yields t = 0 rather than NaN.
    safe_len_sq = torch.where(seg_len_sq > 0, seg_len_sq, torch.ones_like(seg_len_sq))
    # Restrict the projection to the segment itself.
    t = torch.clip(projection / safe_len_sq, 0, 1)

    # Closest location, expressed relative to `start`.
    offset = line_vec * t.unsqueeze(1)
    dist = torch.sqrt(torch.sum(torch.square(pnt_vec - offset), dim=1))
    # Translate back into the coordinates of the inputs.
    nearest = offset + start

    # The sign of the 2-D cross product tells which side of the segment's
    # direction each point lies on (0 when collinear).
    cross = line_vec[:, 0] * pnt_vec[:, 1] - pnt_vec[:, 0] * line_vec[:, 1]
    angle = torch.sign(cross)

    if verbose:
        for d, closest, t_i, side in zip(dist, nearest, t, angle):
            print("dist %.4f" % d, "closest point:", closest, "  t", t_i, 'angle', side)
    return dist, nearest, t, angle  # (N,) (N,2) (N,) (N,)

'''
Algorithm:

    -   calculate the closest point and the distance for the two consecutive segments (s0 and s1) that meet at
        the racing-line point nearest to the query point (rl_point)
    -   depending on t (the position of the projection along each segment) we have 4 cases:
            1) 0 < t0 < 1 only
                the point projects perpendicularly onto s0 but not onto s1. Clear, use the closest point on s0
            2) 0 < t1 < 1 only
                the point projects perpendicularly onto s1 but not onto s0. Clear, use the closest point on s1
            3) neither 1 nor 2
                use s1
            4) 1 and 2 together
                use the segment whose distance to the point is smaller

    -  use the cross product to get the sign (below or above). Use the one of s0; it shouldn't change for s1
'''
def get_gap(points, racing_line):
    '''Signed distance from each point to a piecewise-linear racing line.

    For every point the nearest racing-line vertex is located, the two
    segments meeting at that vertex (s0 ending there, s1 starting there) are
    evaluated with ``get_segment_to_point_dist``, and the closest location on
    the better of the two segments is kept.

    Args:
        points: (N, 2) tensor of x, y positions
            ('VehiclePositionX', 'VehiclePositionY').
        racing_line: (M, 2) tensor of x, y positions along the racing line.
            Whole rows are passed on as segment endpoints, so the column
            count presumably has to match ``points`` - TODO confirm.

    Returns:
        Tuple ``(dist, closest_points)``:
            dist (N,): distance from each point to ``closest_points``,
                multiplied by minus the side sign reported for s0. It is 0
                whenever that sign is 0 (point collinear with s0).
            closest_points (N, 2): closest location on the racing line.
    '''
    # Distance from every point to every racing-line vertex, shape (N, M).
    dx = points[:, 0].reshape(-1, 1) - racing_line[:, 0]
    dy = points[:, 1].reshape(-1, 1) - racing_line[:, 1]
    vertex_dist = torch.sqrt(torch.square(dx) + torch.square(dy))

    # Index of the nearest racing-line vertex for each point, shape (N,).
    rl_point = torch.argmin(vertex_dist, dim=1)

    # Cap the index at M - 2 so that rl_point + 1 is always a valid row.
    # NOTE(review): the lower end is not adjusted. When the nearest vertex is
    # index 0, rl_point - 1 is -1 and selects the LAST racing-line row, so s0
    # wraps around from the last vertex to the first. Confirm this is intended
    # (it only makes sense for a closed line).
    rl_point = torch.clamp(rl_point, max=len(racing_line) - 2)

    # s0: segment ending at the nearest vertex.
    dist_s0, points_s0, t_s0, angle_s0 = get_segment_to_point_dist(
        racing_line[rl_point - 1], racing_line[rl_point], points)
    # s1: segment starting at the nearest vertex.
    dist_s1, points_s1, t_s1, _ = get_segment_to_point_dist(
        racing_line[rl_point], racing_line[rl_point + 1], points)

    # True where the point projects strictly inside the segment.
    inside_s0 = (t_s0 > 0.0) & (t_s0 < 1.0)
    inside_s1 = (t_s1 > 0.0) & (t_s1 < 1.0)

    # s1 is the default. s0 takes over when the point projects inside s0,
    # except in the overlapping case (inside both) where s1 is strictly
    # closer: there the segment with the SMALLER distance wins, ties to s0.
    s1_is_closer = inside_s1 & (dist_s1 < dist_s0)
    use_s0 = inside_s0 & ~s1_is_closer
    closest_points = torch.where(use_s0.unsqueeze(1), points_s0, points_s1)

    dist = distance(closest_points, points)
    # The sign is always taken from s0, negated.
    dist = dist * angle_s0 * (-1)

    return dist, closest_points