import numpy as np
import torch
from .gap_torch import get_gap as get_gap_torch

'''
Point-to-segment distance helpers, adapted from
https://newbedev.com/find-the-shortest-distance-between-a-point-and-line-segments-not-line
'''


def length(v):
    '''Euclidean norm of every row of `v`, reducing over axis 1.'''
    squared_norms = np.square(v).sum(axis=1)
    return np.sqrt(squared_norms)

def vector(b, e):
    '''Displacement from `b` to `e`, i.e. `e - b` (element-wise for arrays).'''
    displacement = e - b
    return displacement

def distance(p0, p1):
    '''Row-wise Euclidean distance between p0[i] and p1[i] (reduces over axis 1).'''
    delta = p1 - p0
    return np.sqrt(np.sum(np.square(delta), axis=1))

def add(v, w):
    '''Return `v + w` (element-wise for arrays).'''
    total = v + w
    return total

def unit(v):
    '''
    Scale every row of `v` to unit Euclidean length.

    Rows of zero length are divided by zero and come out as NaN.
    '''
    magnitude = np.sqrt(np.sum(np.square(v), axis=1))
    # Divide through the transpose so each row is scaled by its own magnitude:
    # https://stackoverflow.com/questions/19602187/numpy-divide-each-row-by-a-vector-element
    scaled_columns = v.T / magnitude
    return scaled_columns.T



def get_segment_to_point_dist(start, end, points, verbose=False):
    '''
    Closest point on each segment start[i] -> end[i] to the query point points[i].

    start, end: (N, 2) segment end points (a single (1, 2) row broadcasts)
    points:     (N, 2) query points
    verbose:    print one summary line per point

    returns
        dist:    (N,)   distance from each point to its segment
        nearest: (N, 2) closest point on each segment
        t:       (N,)   position of `nearest` along the segment, 0 = start, 1 = end
        angle:   (N,)   sign (-1, 0, 1) of cross(end - start, point - start);
                 +1 means the point is counter-clockwise of the segment
                 direction (to its left in a y-up frame)

    Zero-length segments (start == end) give t = 0 and nearest = start
    instead of NaN.
    '''
    # Segment direction and offset of each point from its segment start.
    line_vec = end - start
    pnt_vec = points - start

    # Projection parameter t = <line_vec, pnt_vec> / |line_vec|^2.
    line_len_sq = np.sum(np.square(line_vec), axis=1)
    degenerate = line_len_sq == 0
    # Divide by 1 on zero-length segments to avoid 0/0; those rows get t = 0.
    t = np.sum(line_vec * pnt_vec, axis=1) / np.where(degenerate, 1.0, line_len_sq)
    t = np.where(degenerate, 0.0, t)
    # Clamp to the segment: t < 0 snaps to start, t > 1 snaps to end.
    t = np.clip(t, 0.0, 1.0)

    # Closest point relative to start, its distance to the query point,
    # then translated back to absolute coordinates.
    nearest = line_vec * t[:, np.newaxis]
    dist = np.sqrt(np.sum(np.square(pnt_vec - nearest), axis=1))
    nearest = nearest + start

    # z-component of the 2-D cross product, written out explicitly because
    # np.cross on 2-element vectors is deprecated since NumPy 2.0.
    angle = np.sign(line_vec[:, 0] * pnt_vec[:, 1] - line_vec[:, 1] * pnt_vec[:, 0])

    if verbose:
        for d, p, ti, a in zip(dist, nearest, t, angle):
            print("dist %.4f" % d, "closest point:", p, "  t", ti, 'angle', a)
    return dist, nearest, t, angle  # (N,) (N,2) (N,) (N,)


# Algorithm (NumPy path of get_gap):
#
#   - For every point find the nearest racing-line vertex (rl_point) and the
#     two segments that meet there: s0 = (rl_point - 1 -> rl_point) and
#     s1 = (rl_point -> rl_point + 1). rl_point is clamped to [1, M - 2] so
#     both segments always exist (no wrap-around to the other end of the line).
#   - Project the point on both segments (projections clamped to the segment
#     ends) and keep the projection closer to the point. That is the exact
#     closest point on s0 U s1, and it covers every combination of the
#     projection parameters t0 / t1: perpendicular to s0 only, s1 only, both,
#     or neither.
#   - Sign the distance with the side of s0 (cross product). Where that sign
#     is 0 (point collinear with s0, or s0 of zero length), use the side of s1.
def get_gap(points, racing_line):
    '''
    Signed gap between each point and a polyline racing line.

    points:      (N, 2) positions (x, y) -> 'VehiclePositionX', 'VehiclePositionY'.
                 np.ndarray, or torch.Tensor (delegated to the torch implementation).
    racing_line: (M, >=2) racing-line samples. The NumPy path uses only the
                 first two columns (x, y); extra columns such as
                 'VehicleHeading' are ignored.

    returns
        dist:           (N,) distance to the closest point on the racing line,
                        multiplied by -1 * sign(cross(segment, point - segment start)).
                        It is positive when the point is clockwise of the
                        racing-line direction (to its right in a y-up frame).
        closest_points: (N, 2) closest point on the racing line

    raises
        TypeError:  points is neither an np.ndarray nor a torch.Tensor
        ValueError: the racing line has fewer than 3 points (NumPy path)
    '''
    if isinstance(points, torch.Tensor):
        dist, closest_points = get_gap_torch(points, racing_line)
        return dist, closest_points
    if not isinstance(points, np.ndarray):
        raise TypeError(
            "points must be a numpy.ndarray or torch.Tensor, got %s" % type(points).__name__
        )

    # Only x, y take part in the geometry, so rows match the (N, 2) points.
    line_xy = np.asarray(racing_line)[:, :2]
    num_line_points = len(line_xy)
    if num_line_points < 3:
        raise ValueError(
            "racing_line needs at least 3 points, got %d" % num_line_points
        )

    # (N, M) distance from every point to every racing-line vertex.
    dx = points[:, 0].reshape(-1, 1) - line_xy[:, 0]
    dy = points[:, 1].reshape(-1, 1) - line_xy[:, 1]
    vertex_dist = np.sqrt(np.square(dx) + np.square(dy))

    # Nearest vertex per point, clamped so rl_point - 1 and rl_point + 1 are
    # valid. A negative index would silently wrap to the last vertex and
    # build a segment that does not exist on an open racing line.
    rl_point = np.clip(np.argmin(vertex_dist, axis=1), 1, num_line_points - 2)

    dist_s0, points_s0, _, angle_s0 = get_segment_to_point_dist(
        line_xy[rl_point - 1], line_xy[rl_point], points)
    dist_s1, points_s1, _, angle_s1 = get_segment_to_point_dist(
        line_xy[rl_point], line_xy[rl_point + 1], points)

    # The closer clamped projection is the closest point on both segments;
    # ties go to s0.
    use_s0 = dist_s0 <= dist_s1
    closest_points = np.where(use_s0[:, np.newaxis], points_s0, points_s1)
    dist = np.where(use_s0, dist_s0, dist_s1)

    side = np.where(angle_s0 != 0, angle_s0, angle_s1)
    dist = dist * side * (-1)
    return dist, closest_points


