"""
Racer_RRT_STAR 2D
@author: *Anonymous*
"""

import os
import sys
import math
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# from scipy.spatial.transform import Rotation as Rot

sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
                "/../../Sampling_based_Planning/")

from Sampling_based_Planning.rrt_2D import env, utils #,  plotting
# import CurvesGenerator.dubins_path as dubins
import CurvesGenerator.draw as draw


class Node:
    """A vertex of the planning tree.

    Besides its position, a node stores the velocity (direction) it was
    reached with, the bookkeeping used by RRT* (parent, accumulated cost and
    running cost) and the polyline of the edge that leads to it from its
    parent.
    """

    def __init__(self, x, y, v=None, friction=1):
        # Position.
        self.x, self.y = x, y
        # Control recorded when this node is used as the start of an edge.
        self.u = 0.
        # Velocity at the node; None means "not set".
        self.v = v
        # Running (state + control) part of the last edge cost.
        self.lqrCost = 0.
        # Tree linkage and accumulated cost from the root.
        self.parent, self.cost = None, 0.0
        # Edge polyline from the parent to this node (fresh lists per node).
        self.path_x, self.path_y = [], []
        # Surface friction coefficient at this node.
        self.friction = friction


class RacerRRTStar:
    """RRT* planner for a point "racer" that only moves toward larger y.

    The tree is rooted at ``s_start`` and grown toward ``s_goal``. Edges are
    straight segments of length at most ``step_len`` built by :meth:`Steer`.
    Parents are only searched among nodes with a strictly smaller y
    (:meth:`Near`, :meth:`Nearest`), so y acts like a time axis; it is the
    ``dt`` used in :meth:`calc_new_cost`. The edge cost penalises lateral
    offset (``x**2``), changes of velocity direction and proximity to the
    circular obstacles supplied by ``env.Env()``.
    """

    # Returned by planning() when it cannot produce a path. It has the same
    # seven entries as a successful result so callers can always unpack it.
    _FAIL = (False, False, False, False, False, False, False)

    def __init__(self, sx, sy, gx, gy, vehicle_radius, step_len,
                 goal_sample_rate, search_radius, iter_max, v0):
        """Build the planner and load the environment.

        Args:
            sx, sy: start position; the start node's velocity is set to ``v0``.
            gx, gy: goal position.
            vehicle_radius: extra clearance added to every circular obstacle
                in :meth:`is_collision`.
            step_len: maximum edge length produced by :meth:`Steer`.
            goal_sample_rate: probability that :meth:`Sample` returns the goal.
            search_radius: scale of the shrinking RRT* neighbourhood radius.
            iter_max: stored but not used by this class.
            v0: velocity of the start node (scalar or 2-vector).
        """
        self.s_start = Node(sx, sy)
        self.s_start.v = v0
        self.s_goal = Node(gx, gy)
        self.vr = vehicle_radius
        self.step_len = step_len
        self.goal_sample_rate = goal_sample_rate
        self.search_radius = search_radius
        self.iter_max = iter_max

        self.env = env.Env()
        self.utils = utils.Utils()

        self.fig, self.ax = plt.subplots()
        # Collision clearance taken from the shared utils object.
        self.delta = self.utils.delta
        self.x_range = self.env.x_range
        self.y_range = self.env.y_range
        # Instance attribute; shadows the obs_circle() staticmethod below.
        self.obs_circle = self.env.obs_circle
        # First obstacle (x, y, r); its radius is the default in calc_new_cost.
        self.obs_circle_array = np.array(self.env.obs_circle[0])
        self.obs_boundary = self.env.obs_boundary
        self.ice_rectangle = self.env.ice_rectangle
        # Ice bounds as [x_min, x_max, y_min, y_max] (see set_mu).
        self.ice_rect = self.env.ice_rect
        self.mu_ice = 0.35
        self.utils.update_obs(self.obs_circle, self.obs_boundary, [])
        self.umax = 0.75

        self.V = [self.s_start]
        self.path = None

    def set_mu(self, node):
        """Set ``node.friction`` to the ice value when it lies on the ice patch.

        Returns the same (mutated) node.
        """
        x_lo, x_hi, y_lo, y_hi = self.ice_rect[0], self.ice_rect[1], self.ice_rect[2], self.ice_rect[3]
        if x_lo <= node.x <= x_hi and y_lo <= node.y <= y_hi:
            node.friction = self.mu_ice
        return node

    def planning(self, yRef, plotpath=True, underway=False, s_underway_x=None,
                 s_underway_y=None, s_underway_v=None):
        """Grow the tree until a goal-reaching node exists and return the first move.

        Args:
            yRef: reference y. ``yRef <= 3.0`` selects a fixed sampling budget
                (1000 samples per round, 200 when ``underway``, round limit 11).
                Otherwise the budget scales with the start-goal y-distance
                (round limit 21). Also echoed in a log line.
            plotpath: plot and save the found path before returning.
            underway: reuse the existing tree. Its root is moved to
                ``(s_underway_x, s_underway_y)`` with velocity ``s_underway_v``
                and every node's parent is re-chosen.

        Returns:
            On success, ``(px0, py0, v, uStar, True, px, py)``:
                ``px0``/``py0`` is the segment steered from the start toward
                the path vertex whose distance from the start is closest to
                ``step_len``.
                ``v`` is the velocity direction at the end of that segment.
                ``uStar`` is the control returned by :meth:`Steer`.
                ``px``/``py`` is the whole path, listed from the goal back to
                the start.
            ``(False, False, 1, 0., False, False, False)`` when the sampling
            budget is exhausted.
            Seven ``False`` values when the start is in collision or the first
            segment cannot be steered.
        """
        if yRef <= 3.0:
            minSteps = 200 if underway else 1000
            maxCounts = 11
        else:
            # Uses the current start, i.e. before any underway move below.
            minSteps = int(450 * (self.s_goal.y + 1. - self.s_start.y) / 30.)
            maxCounts = 21

        if underway:
            self.s_start = Node(s_underway_x, s_underway_y, v=s_underway_v)
            self.V[0] = self.s_start
            if self.is_collision(self.V[0], singleton=True):
                return self._FAIL

            self.path = None
            self._reconnect_existing_nodes()
            for _ in range(minSteps):
                self._extend_tree()

            last_index = self.search_best_goal_node()
            counter = 1
        else:
            last_index = None
            counter = 0
            if self.is_collision(self.V[0], singleton=True):
                return self._FAIL

        while last_index is None:
            for i in range(minSteps):
                if i % 50 == 0:
                    print("Iter:", i + counter * minSteps, ", number of nodes:", len(self.V))
                # Dump diagnostics for the first sample of every round.
                self._extend_tree(verbose=(i == 0))

            last_index = self.search_best_goal_node()
            counter += 1

            if counter >= maxCounts and last_index is None:
                print('Exceeded ', counter * minSteps, ' iterations; Path likely infeasible')
                return False, False, 1, 0., False, False, False

        path = self.generate_final_course(last_index)
        print("get!")
        px = [s[0] for s in path]
        py = [s[1] for s in path]

        if plotpath:
            self._plot_initial_path(px, py)

        # Pick the path vertex whose distance from the start is closest to one
        # step length, but never the last entry (the start itself).
        YY = np.abs(np.sqrt((np.asarray(py) - self.s_start.y) ** 2
                            + (np.asarray(px) - self.s_start.x) ** 2) - self.step_len)
        yIDX = int(np.argmin(YY))
        if yIDX == len(py) - 1:
            yIDX -= 1

        actualnodeEst = Node(px[yIDX], py[yIDX])
        actualnodeReal, uStar = self.Steer(self.s_start, actualnodeEst)
        if not actualnodeReal:
            # Steer could not build the segment; report and fail cleanly
            # instead of dereferencing None below.
            print(self.s_start.x, self.s_start.y)
            print(actualnodeEst.x, actualnodeEst.y)
            print(len(py))
            print(YY)
            return self._FAIL

        print('uStar for y in [', yRef - 0.5, ', ', yRef, '] = ', uStar)
        print('velocity at new node: ', actualnodeReal.v)
        print(actualnodeReal.x, actualnodeReal.y)

        px0 = list(actualnodeReal.path_x)
        py0 = list(actualnodeReal.path_y)

        print('Ascending Time')
        return px0, py0, actualnodeReal.v, uStar, True, px, py

    def _reconnect_existing_nodes(self):
        """Re-choose the parent of every non-root node (used after the root moves)."""
        for k in range(1, len(self.V)):
            test_node = self.V[k]
            others = self.V[:k] + self.V[k + 1:]
            # Near() returns indices into `others`. Indices at or past k must
            # be shifted by one to address self.V, which choose_parent() and
            # rewire() index.
            near_indexes = [j if j < k else j + 1 for j in self.Near(others, test_node)]

            test_node = self.choose_parent(test_node, near_indexes)
            if test_node:
                test_node = self.set_mu(test_node)
                self.V[k] = test_node
                self.rewire(test_node, near_indexes)

    def _extend_tree(self, verbose=False):
        """Run one RRT* extension: sample, steer, choose a parent, insert and rewire."""
        rnd = self.Sample()
        node_nearest, dist_near = self.Nearest(self.V, rnd)
        new_node, _ = self.Steer(node_nearest, rnd, dist_near)

        if verbose:
            print('Distance: ')
            print(dist_near)
            # Nearest/Steer return False/None on failure; guard the prints.
            if node_nearest:
                print('Nearest Node: ')
                print(node_nearest.x, node_nearest.y)
            if new_node:
                print('New Node:')
                print(new_node.x, new_node.y)
            print()

        if not new_node or self.is_collision(new_node):
            return

        near_indexes = self.Near(self.V, new_node)
        new_node = self.choose_parent(new_node, near_indexes)
        if new_node:
            self.V.append(new_node)
            self.rewire(new_node, near_indexes)

    def _add_obstacle_patches(self, ax, include_ice=False):
        """Draw (optionally) the ice patches, then the boundary and circular obstacles on ``ax``."""
        if include_ice:
            for (ox, oy, w, h) in self.ice_rectangle:
                ax.add_patch(patches.Rectangle((ox, oy), w, h, edgecolor='blue',
                                               facecolor='blue', fill=True))
        for (ox, oy, w, h) in self.obs_boundary:
            ax.add_patch(patches.Rectangle((ox, oy), w, h, edgecolor='black',
                                           facecolor='black', fill=True))
        for (ox, oy, r) in self.obs_circle:
            ax.add_patch(patches.Circle((ox, oy), r, edgecolor='black',
                                        facecolor='gray', fill=True))

    def _plot_initial_path(self, px, py):
        """Plot the found path over the ice patches and obstacles, save and show it."""
        fig3, ax3 = plt.subplots(1, 1)
        self._add_obstacle_patches(ax3, include_ice=True)
        ax3.set_ylim([-1., 11.])
        ax3.set_xlim([-5., 5.])
        ax3.plot(px, py, '-r')
        ax3.set_title("Actual Path")
        fig3.savefig('Sampling_based_Planning/rrt_2D/InitialPath.png')
        plt.show()
        plt.pause(2)

    def draw_graph(self, rnd=None):
        """Redraw all tree edges, obstacles and the start/goal markers."""
        plt.cla()

        def _on_key_release(event):
            # Stop the simulation with the Esc key.
            if event.key == 'escape':
                sys.exit(0)

        plt.gcf().canvas.mpl_connect('key_release_event', _on_key_release)
        for node in self.V:
            if node.parent:
                plt.plot(node.path_x, node.path_y, "-g")

        self.plot_grid("dubins rrt*")
        plt.plot(self.s_start.x, self.s_start.y, "xr")
        plt.plot(self.s_goal.x, self.s_goal.y, "xr")
        plt.grid(True)
        self.plot_start_goal_arrow()
        plt.pause(0.01)

    def _heading(self, node):
        """Return ``node.yaw`` if set, else the start-to-goal direction angle.

        Node does not define ``yaw``, so reading it directly raised
        AttributeError in the arrow-plotting helpers.
        """
        yaw = getattr(node, 'yaw', None)
        if yaw is None:
            yaw = math.atan2(self.s_goal.y - self.s_start.y, self.s_goal.x - self.s_start.x)
        return yaw

    def plot_start_goal_arrow(self):
        """Draw heading arrows (length 2) at the start and goal."""
        draw.Arrow(self.s_start.x, self.s_start.y, self._heading(self.s_start), 2, "darkorange")
        draw.Arrow(self.s_goal.x, self.s_goal.y, self._heading(self.s_goal), 2, "darkorange")

    def generate_final_course(self, goal_index):
        """Return the path as ``[[x, y], ...]``, listed from the goal back to the start.

        The goal point comes first. Then come the edge polylines of the node
        ``self.V[goal_index]`` and its ancestors, each reversed. The start
        point is appended last.
        """
        print("final")
        path = [[self.s_goal.x, self.s_goal.y]]
        node = self.V[goal_index]
        while node.parent:
            for (ix, iy) in zip(reversed(node.path_x), reversed(node.path_y)):
                path.append([ix, iy])
            node = node.parent
        path.append([self.s_start.x, self.s_start.y])
        return path

    def calc_dist_to_goal(self, x, y):
        """Euclidean distance from ``(x, y)`` to the goal."""
        return math.hypot(x - self.s_goal.x, y - self.s_goal.y)

    def search_best_goal_node(self):
        """Return the index of the cheapest node that can reach the goal, or None.

        Candidates lie within ``step_len`` of the goal and their steered edge
        to the goal is collision free. Ties go to the lowest index.
        """
        dist_to_goal_list = [self.calc_dist_to_goal(n.x, n.y) for n in self.V]
        print(np.min(np.array(dist_to_goal_list)))
        print(self.step_len)
        # enumerate(): list.index(value) returned the first node with an equal
        # distance, mapping duplicates to the wrong node.
        goal_inds = [i for i, dist in enumerate(dist_to_goal_list) if dist <= self.step_len]

        safe_goal_inds = []
        for goal_ind in goal_inds:
            t_node, _ = self.Steer(self.V[goal_ind], self.s_goal)
            if t_node and not self.is_collision(t_node):
                safe_goal_inds.append(goal_ind)

        if not safe_goal_inds:
            return None
        return min(safe_goal_inds, key=lambda i: self.V[i].cost)

    def rewire(self, new_node, near_inds):
        """Replace ``self.V[i]`` by an edge from ``new_node`` when that is cheaper and collision free.

        NOTE(review): ``near_inds`` come from :meth:`Near`, which only returns
        nodes below ``new_node``. The replacement edges therefore point
        downward in y (negative ``dt`` in calc_new_cost). Children of a
        replaced node keep pointing at the old node object.
        """
        for i in near_inds:
            near_node = self.V[i]
            edge_node, uS = self.Steer(new_node, near_node)
            if not edge_node:
                continue
            edge_node.cost, edge_node.lqrCost = self.calc_new_cost(new_node, near_node, uS)

            # Must be `not`: ~True == -2 and ~False == -1 are both truthy, so
            # the former `~` never rejected a colliding edge.
            no_collision = not self.is_collision(edge_node)
            improved_cost = near_node.cost > edge_node.cost

            if no_collision and improved_cost:
                self.V[i] = edge_node
                self.propagate_cost_to_leaves(new_node)

    def choose_parent(self, new_node, near_inds):
        """Re-steer ``new_node`` from the cheapest collision-free neighbour.

        Returns the new steered node, or None when ``near_inds`` is empty or
        every candidate edge collides. ``near_node.u`` is updated for each
        candidate as a side effect of steering.
        """
        if not near_inds:
            return None

        costs = []
        for i in near_inds:
            near_node = self.V[i]
            t_node, near_node.u = self.Steer(near_node, new_node)
            if t_node and not self.is_collision(t_node):
                tmp, _ = self.calc_new_cost(near_node, new_node)
                costs.append(tmp)
            else:
                costs.append(float("inf"))  # the cost of collision node
        min_cost = min(costs)

        if min_cost == float("inf"):
            print("There is no good path.(min_cost is inf)")
            return None

        min_ind = near_inds[costs.index(min_cost)]
        new_node, _ = self.Steer(self.V[min_ind], new_node)
        return new_node

    def calc_new_cost(self, from_node, to_node, uStar=None):
        """Return ``(total_cost, running_cost)`` for the edge ``from_node -> to_node``.

        The running cost is ``cx + cu``:
            ``cx = 0.15 * to_node.x**2`` (lateral offset).
            ``cu = 0.005 * |from_node.v - to_node.v|**2`` (velocity change).
        An obstacle term ``co`` is added when the nearest circle's surface is
        within 2.5 radii of ``to_node``.
        The total cost is ``from_node.cost + (cu + cx + co) * dt`` with
        ``dt = to_node.y - from_node.y``.
        ``uStar`` is accepted for API compatibility and is unused.
        """
        dt = to_node.y - from_node.y

        cx = 0.15 * to_node.x ** 2
        cu = 0.005 * np.linalg.norm(from_node.v - to_node.v) ** 2

        # Surface distance to, and radius of, the nearest circular obstacle.
        doo = 1000.
        roo = self.obs_circle_array[2]
        for (ox, oy, r) in self.obs_circle:
            tmpDist = np.linalg.norm(np.array([ox - to_node.x, oy - to_node.y])) - r
            if tmpDist <= doo:
                doo = tmpDist
                roo = r

        co = 0.
        if doo <= 2.5 * roo:
            co = (roo + 1.) / np.sqrt(1 + (roo + doo) ** 2)

        return from_node.cost + (cu + cx + co) * dt, cx + cu

    def propagate_cost_to_leaves(self, parent_node):
        """Recompute the cost of every descendant of ``parent_node`` in ``self.V``.

        Iterative (explicit stack) so deep trees cannot hit the recursion
        limit. Each node is updated after its parent, as in the recursive form.
        """
        stack = [parent_node]
        while stack:
            parent = stack.pop()
            for node in self.V:
                if node.parent == parent:
                    node.cost, node.lqrCost = self.calc_new_cost(parent, node)
                    stack.append(node)

    @staticmethod
    def get_distance_and_angle(node_start, node_end):
        """Return the distance and heading angle (radians) from start to end."""
        dx = node_end.x - node_start.x
        dy = node_end.y - node_start.y
        return math.hypot(dx, dy), math.atan2(dy, dx)

    def Near(self, nodelist, node):
        """Indices of nodes in ``nodelist`` that may become parents of ``node``.

        A candidate must lie strictly below ``node`` in y and within the
        RRT* radius. The radius is ``min(search_radius * sqrt(log n / n), 1.1 * step_len)``
        with ``n = len(nodelist) + 1``.
        """
        n = len(nodelist) + 1
        r = min(self.search_radius * math.sqrt((math.log(n)) / n), 1.1 * self.step_len)
        r2 = r ** 2
        return [ind for ind, nd in enumerate(nodelist)
                if (nd.x - node.x) ** 2 + (nd.y - node.y) ** 2 <= r2 and nd.y < node.y]

    def Steer(self, node_start, node_end, d=0):
        """Build a straight edge of at most ``step_len`` from ``node_start`` toward ``node_end``.

        Args:
            node_start: parent node. Its ``u`` is overwritten with the control.
            node_end: target; the edge is truncated to ``step_len``.
            d: distance reported by :meth:`Nearest`. A value below -0.5 is its
               "no candidate" sentinel.

        Returns:
            ``(False, False)`` for the sentinel.
            ``(None, 0.)`` when the start and the target coincide.
            Otherwise ``(node, [direction])``. ``node`` is the new node with an
            11-point polyline, its velocity set to the unit direction, and its
            cost and parent filled in.
        """
        if d < -0.5:
            return False, False

        sx, sy = float(node_start.x), float(node_start.y)
        gx, gy = float(node_end.x), float(node_end.y)

        d = np.sqrt((gx - sx) ** 2 + (gy - sy) ** 2)
        if d >= self.step_len:
            d = self.step_len
        if d <= 0.:
            return None, 0.

        # Unit direction of travel; the small epsilon guards the division.
        v0 = np.array([gx - sx, gy - sy])
        v0 = v0 / (np.linalg.norm(v0) + 0.00000001)
        uStar = [v0]
        v1 = d * v0

        # Velocity control: the start point followed by 10 evenly spaced
        # points along the segment. (IntDynamicsAndGetU holds an alternative
        # acceleration model that is not wired in.)
        fractions = np.arange(1, 11) / 10.
        xlist = [sx] + (v1[0] * fractions + sx).tolist()
        ylist = [sy] + (v1[1] * fractions + sy).tolist()

        node_new = Node(xlist[-1], ylist[-1])
        node_new.path_x = xlist
        node_new.path_y = ylist
        node_start.u = uStar
        node_new.v = v0
        node_new.cost, node_new.lqrCost = self.calc_new_cost(node_start, node_new)
        node_new.parent = node_start

        return node_new, uStar

    def Sample(self):
        """Return a random node above the start's y, or the goal itself.

        The goal is returned with probability ``goal_sample_rate``.
        """
        delta = self.utils.delta
        if random.random() > self.goal_sample_rate:
            return Node(random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
                        random.uniform(self.s_start.y, self.y_range[1] - delta))
        return self.s_goal

    @staticmethod
    def _rollout(x0, v0, u, steps, deltaT):
        """Integrate the damped 1D model under constant ``u``; return (x samples, final v)."""
        x = x0
        vPre = v0
        vPost = v0
        xlist = [x0]
        for _ in range(steps):
            vPost = vPre * (1 - 0.3 * deltaT) + u * deltaT - 0.1 * deltaT * x
            x += 0.5 * (vPre + vPost) * deltaT
            xlist.append(x)
            vPre = vPost
        return xlist, vPost

    def IntDynamicsAndGetU(self, x0, v0, dt, y0, xdes, max_iter=100000):
        """Find a constant lateral control that drives x from ``x0`` to about ``xdes`` over ``dt``.

        The model is ``v' = -0.3 v + u - 0.1 x``, integrated on LL equal
        sub-steps with trapezoidal position updates.
        - If full positive control (``umax``) cannot reach ``xdes``, or full
          negative control overshoots it, the saturated control is returned.
        - Otherwise ``u`` is adjusted until ``|x_end - xdes| < 0.05`` or
          ``max_iter`` adjustments have been made.

        Returns:
            ``(u, xlist, ylist, v_end)``. ``xlist`` and ``ylist`` both have
            ``LL + 1`` samples, and ``ylist`` advances one sub-step per sample.
        """
        LL = int(np.floor(dt / 0.05) + 1)
        deltaT = dt / float(LL)
        # Previously the second sample repeated y0 (the offset started at 0).
        ylist = [y0 + float(tt) * deltaT for tt in range(LL + 1)]

        u = self.umax
        xlist, vPost = self._rollout(x0, v0, u, LL, deltaT)
        if xlist[-1] < xdes:
            return u, xlist, ylist, vPost

        u = -self.umax
        xlist, vPost = self._rollout(x0, v0, u, LL, deltaT)
        if xlist[-1] > xdes:
            return u, xlist, ylist, vPost

        for _ in range(max_iter):
            err = xdes - xlist[-1]
            if abs(err) < 0.05:
                break
            step = min(abs(err), 10.) * .01
            u += step if err > 0 else -step
            xlist, vPost = self._rollout(x0, v0, u, LL, deltaT)

        return u, xlist, ylist, vPost

    @staticmethod
    def Nearest(nodelist, n):
        """Return the node strictly below ``n`` (in y) nearest to it, and its distance.

        Returns ``(False, -1.)`` when no node lies below ``n``. Steer treats
        that distance as its "no candidate" sentinel. Ties go to the first node.
        """
        below = [nd for nd in nodelist if nd.y < n.y]
        if not below:
            return False, -1.
        dists = [np.sqrt((nd.x - n.x) ** 2 + (nd.y - n.y) ** 2) for nd in below]
        k = int(np.argmin(dists))
        return below[k], dists[k]

    def is_collision(self, node, singleton=False):
        """Return True if the node (or its edge polyline) hits a circular obstacle.

        With ``singleton=True`` only ``(node.x, node.y)`` is tested. Otherwise
        every point of ``node.path_x``/``node.path_y`` is tested. A point
        collides when its distance to a circle's centre is less than
        ``r + delta + vehicle_radius``. Boundary rectangles are not checked.
        """
        if singleton:
            xs, ys = [node.x], [node.y]
        else:
            xs, ys = node.path_x, node.path_y

        for ox, oy, r in self.obs_circle:
            dist = np.hypot([ox - x for x in xs], [oy - y for y in ys])
            if dist.min() < r + self.delta + self.vr:
                return True
        return False

    def animation(self):
        """Draw the grid and the start/goal arrows, then show the figure."""
        self.plot_grid("dubins rrt*")
        self.plot_arrow()
        plt.show()

    def plot_arrow(self):
        """Draw heading arrows (length 2.5) at the start and goal."""
        draw.Arrow(self.s_start.x, self.s_start.y, self._heading(self.s_start), 2.5, "darkorange")
        draw.Arrow(self.s_goal.x, self.s_goal.y, self._heading(self.s_goal), 2.5, "darkorange")

    def plot_grid(self, name):
        """Draw obstacles on ``self.ax`` plus start/goal markers and the straight start-goal line."""
        self._add_obstacle_patches(self.ax)

        plt.plot(self.s_start.x, self.s_start.y, "bs", linewidth=3)
        plt.plot(self.s_goal.x, self.s_goal.y, "gs", linewidth=3)
        plt.plot([self.s_start.x, self.s_goal.x], [self.s_start.y, self.s_goal.y], 'k--')

        plt.title(name)
        plt.axis("equal")

    @staticmethod
    def obs_circle():
        """Return a hard-coded list of circular obstacles ``[x, y, r]``.

        Not reached through instances: ``__init__`` assigns the instance
        attribute ``self.obs_circle`` from the environment, which shadows
        this method.
        """
        obs_cir = [
            [10, 10, 3],
            [15, 22, 3],
            [22, 8, 2.5],
            [26, 16, 2],
            [37, 10, 3],
            [37, 23, 3],
            [45, 15, 2]
        ]

        return obs_cir


def _sample_disturbance(mode, t, kStep, sy, nx):
    """Return the lateral disturbance added to the next start x position.

    Args:
        mode: one of
            "zero": no disturbance.
            "stochastic": uniform in [-0.5, 0.5).
            "adversarial": from estimated time 8.4 on, push toward x = 0 by
                at most 0.5.
            "sinusoid": the piecewise-constant +0.5 / -0.5 / random profile
                labelled "Sinusoid disturbances" in the original code.
        t: index of the replanning step just completed.
        kStep: y advance per replanning step.
        sy: initial start y.
        nx: x position reached at the end of the step.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode == "zero":
        return 0.
    if mode == "stochastic":
        return float(np.random.rand() - 0.5)
    timeEst = float(t) * kStep + sy
    if mode == "adversarial":
        if timeEst >= 8.4:
            if np.abs(float(nx)) <= 0.5:
                return -float(nx)
            return float(0.5 * (-float(nx) / np.abs(float(nx))))
        return 0.
    if mode == "sinusoid":
        tmp = timeEst % 50.
        if 7.49 <= tmp <= 13.51:
            return 0.5
        if 13.51 < tmp <= 18.51:
            return -0.5
        return float(np.random.rand() - 0.5)
    raise ValueError(f"unknown disturbance mode: {mode!r}")


def main(disturbance="zero"):
    """Receding-horizon replanning demo on the course from y = 2.5 to y = 47.5.

    Each step does the following:
    1. Plan from the current start.
    2. Keep only the first steered segment returned by ``planning``.
    3. Apply a lateral disturbance chosen by ``disturbance`` (see
       ``_sample_disturbance``; the default "zero" matches the previous
       hard-wired choice).
    4. Restart the planner at the disturbed end point.

    The executed trajectory (``X*.csv``), per-point disturbances (``W*.csv``,
    one row per trajectory point) and a quadratic cost (``Cost.csv``) are
    written to ``data/Adv``. The path plot is written to ``images/Adv``.
    """
    for out_dir in ('data/Adv', 'images/Adv'):
        os.makedirs(out_dir, exist_ok=True)

    nRuns = 1
    nFail = 0
    LQcosts = np.zeros((2, nRuns))
    for tn in range(nRuns):
        sx, sy = 0., 2.5
        gx, gy = 0, 47.5
        goal_sample_rate = 0.000
        search_radius = 25.0
        step_len = 10.0
        iter_max = 2500
        vehicle_radius = 0.0
        fullPathX, fullPathY = [], []
        fullTime, fullU = [], []
        WX = []
        wx = 0.
        vNew = 0.
        Rrrtstar = RacerRRTStar(sx, sy, gx, gy, vehicle_radius, step_len,
                                goal_sample_rate, search_radius, iter_max, vNew)

        kStep = 0.5
        # np.round_ was removed in NumPy 2.0; np.round is the supported name.
        nStep = int(np.round((gy - sy) / kStep))
        print(kStep)
        print(nStep)

        for t in range(nStep - 1):
            result = Rrrtstar.planning(sy + float(t + 1) * kStep, plotpath=False)
            # Only the first five entries are needed; some failure paths of
            # planning() return fewer than seven values.
            NX, NY, vNew, US, timeBool = result[:5]

            if NX is False and NY is False:
                if vNew == 1:
                    # Then we hit the obstacle == not great
                    print('Hit the obstacle, add to failed trajectories')
                    nFail += 1
                # No segment to execute; continuing would call len(False).
                break

            if timeBool is not True:
                NX.reverse()
                NY.reverse()
            nx, ny = NX[-1], NY[-1]

            print(len(NX))
            nPts = len(NX)
            fullTime.extend([float(kStep) / float(nPts)] * nPts)
            fullU.extend([US] * nPts)
            # One disturbance entry per executed point, so W rows align with X.
            WX.extend([wx] * nPts)

            fullPathX.append(NX)
            fullPathY.append(NY)
            plt.pause(2)
            print(nx, ny, vNew)
            if np.absolute(ny - gy) < 0.025:
                break

            wx = _sample_disturbance(disturbance, t, kStep, sy, nx)

            # vNew is the velocity direction (a 2-vector from Steer);
            # float(vNew) raised TypeError, so pass it on as an array.
            Rrrtstar = RacerRRTStar(float(nx) + wx, ny, gx, gy, vehicle_radius, step_len,
                                    goal_sample_rate, search_radius, iter_max,
                                    np.asarray(vNew, dtype=float))

        if fullPathX:
            X = np.column_stack((np.concatenate(fullPathX), np.concatenate(fullPathY)))
        else:
            X = np.zeros((0, 2))

        costLQR = 0.
        for kk in range(X.shape[0]):
            # Squared magnitude of the recorded control: u**2 for a scalar,
            # and also valid for Steer's vector-valued controls.
            uSq = float(np.sum(np.square(np.asarray(fullU[kk], dtype=float))))
            costLQR += fullTime[kk] * (0.05 * X[kk, 0] ** 2 + 0.025 * uSq)
        costLQR = costLQR / (gy - sy)

        LQcosts[0, tn] = float(costLQR)
        LQcosts[1, tn] = 1.0

        fig3, ax3 = plt.subplots(1, 1)
        for (ox, oy, w, h) in Rrrtstar.obs_boundary:
            ax3.add_patch(patches.Rectangle((ox, oy), w, h, edgecolor='black',
                                            facecolor='black', fill=True))
        for (ox, oy, r) in Rrrtstar.obs_circle:
            ax3.add_patch(patches.Circle((ox, oy), r, edgecolor='black',
                                         facecolor='gray', fill=True))

        if X.shape[0]:
            ax3.plot(X[0, 0], X[0, 1], "bs", linewidth=3)
        ax3.plot(gx, gy, "gs", linewidth=3)
        ax3.plot(np.array([0., 0.]), np.array([sy, gy]), 'k--')
        ax3.plot(X[:, 0], X[:, 1], 'r')
        theta = np.deg2rad(np.arange(360))
        ax3.plot(2.0 * np.cos(theta), 12.5 + 2.0 * np.sin(theta), 'b')
        ax3.set_xlim([-10., 10.])
        ax3.set_ylim([0., 50.])
        ax3.set_title("Actual Path")
        fig3.savefig('images/Adv/Path' + str(tn + 1) + '.png')

        np.savetxt('data/Adv/X' + str(tn + 1) + '.csv', X, delimiter=',')
        np.savetxt('data/Adv/W' + str(tn + 1) + '.csv', WX, delimiter=',')

    print('Failed runs:', nFail)
    np.savetxt('data/Adv/Cost.csv', LQcosts, delimiter=',')
        

def main2():
    """Plan one path from (0, 0.5) to (0, 9.5), save it as waypoints and plot it.

    The full planned path (listed goal first, as built by
    ``generate_final_course``) is written as ``x,y`` rows to
    ``Sampling_based_Planning/rrt_2D/waypoints.csv``. The plot is saved to
    ``Sampling_based_Planning/rrt_2D/InitialPath.png``.
    """
    sx, sy = 0., 0.5
    gx, gy = 0, 9.5
    goal_sample_rate = 0.005
    search_radius = 5.0
    step_len = 0.25
    iter_max = 2500
    vehicle_radius = 0.0
    vNew = 0.
    Rrrtstar = RacerRRTStar(sx, sy, gx, gy, vehicle_radius, step_len,
                            goal_sample_rate, search_radius, iter_max, vNew)

    kStep = 0.5
    nStep = 1
    print(kStep)
    print(nStep)

    result = Rrrtstar.planning(sy + 1.0 * kStep, plotpath=False)
    # Failure results have fewer than seven entries or False in the path
    # slots; unpacking or indexing them below would raise.
    if len(result) < 7 or result[5] is False:
        print('Planning failed; no waypoints written')
        return
    NX, NY = result[5], result[6]

    os.makedirs('Sampling_based_Planning/rrt_2D', exist_ok=True)
    TOT = np.column_stack((np.asarray(NX), np.asarray(NY)))
    np.savetxt('Sampling_based_Planning/rrt_2D/waypoints.csv', TOT, delimiter=',')

    fig3, ax3 = plt.subplots(1, 1)
    for (ox, oy, w, h) in Rrrtstar.obs_boundary:
        ax3.add_patch(patches.Rectangle((ox, oy), w, h, edgecolor='black',
                                        facecolor='black', fill=True))
    for (ox, oy, r) in Rrrtstar.obs_circle:
        ax3.add_patch(patches.Circle((ox, oy), r, edgecolor='black',
                                     facecolor='gray', fill=True))
    print(len(NX))
    print(len(NY))

    # The path is listed goal first, so the start marker is the LAST point
    # (NX[0] put the blue square on the goal).
    ax3.plot(NX[-1], NY[-1], "bs", linewidth=3)
    ax3.plot(gx, gy, "gs", linewidth=3)
    ax3.plot(np.array([0., 0.]), np.array([sy, gy]), 'k--')
    ax3.plot(np.asarray(NX), np.asarray(NY), 'r')
    ax3.set_xlim([-6., 6.])
    ax3.set_ylim([-1., 11.])
    ax3.set_title("Actual Path")
    fig3.savefig('Sampling_based_Planning/rrt_2D/InitialPath.png')


def main3():
    """Run the iterated (receding-horizon) Racer-RRT* experiment with ice disturbances.

    Starting from ``(sx, sy) = (0, 0.5)`` with velocity ``(0, 1)``, the loop
    repeatedly plans toward ``(gx, gy) = (0, 5.5)``. It takes the last point of
    the planned segment as the nominal next position. When that position lies
    inside ``Rrrtstar.ice_rect``, it blends the planned velocity with the
    previously applied one using ``mu_ice``. The difference between the blended
    and the planned displacement is the disturbance ``wx``, which shifts the
    start state of the next replanning call.

    The loop ends when any of the following happens:
        * the planner returns ``False`` for both path lists;
        * the segment end is within an L1 distance of 0.25 of the goal;
        * ``nStep - 1`` iterations have run.

    Side effects:
        * writes every planned point, one ``x,y`` row each, to
          ``Sampling_based_Planning/rrt_2D/waypoints.csv``;
        * saves the obstacles, the ice patch and the executed segment end
          points to ``Sampling_based_Planning/rrt_2D/IteratedPath.png``.
    """
    nRuns = 1
    nFail = 0  # runs in which the planner reported hitting an obstacle
    LQcosts = np.zeros((2, nRuns))  # filled only by the disabled LQR-cost code below
    for tn in range(nRuns):
        sx, sy = 0., 0.5
        gx, gy = 0, 5.5
        goal_sample_rate = 0.005
        search_radius = 5.0
        step_len = 0.25
        iter_max = 2500
        vehicle_radius = 0.0
        shortPathX = []  # end point of each executed segment
        shortPathY = []
        fullPathX = []   # every planned segment (one list per step)
        fullPathY = []
        fullTime = []
        fullU = []
        WX = []
        fullV = []       # velocity applied after each step
        wx = 0.
        vNew = np.array((0., 1.))
        Rrrtstar = RacerRRTStar(sx, sy, gx, gy, vehicle_radius, step_len,
                                goal_sample_rate, search_radius, iter_max, vNew)
        pReal_x = sx
        pReal_y = sy
        vNet = vNew
        kStep = 0.5
        nStep = 100
        print(kStep)
        print(nStep)

        for t in range(0, nStep - 1):
            if t == 0:
                NX, NY, vNew, US, timeBool, _, _ = Rrrtstar.planning(1.0, plotpath=True)
            else:
                print()
                print('Planning Stage Active For pReal_x, pReal_y, vNet = ', pReal_x, pReal_y, vNet)
                print()
                if t > 1:
                    # RacerRRTStar.__init__ calls plt.subplots(), so every replan
                    # opens a new figure. Close the superseded replanner's figure
                    # (the initial, plotted planner from t == 0 is kept) so open
                    # figures do not accumulate over the run.
                    plt.close(Rrrtstar.fig)
                Rrrtstar = RacerRRTStar(pReal_x, pReal_y, gx, gy, vehicle_radius, step_len,
                                        goal_sample_rate, search_radius, iter_max, vNet)
                NX, NY, vNew, US, timeBool, _, _ = Rrrtstar.planning(
                    pReal_y, plotpath=False, underway=False,
                    s_underway_x=pReal_x, s_underway_y=pReal_y, s_underway_v=vNet)

            if NX is False and NY is False:
                # Planning failed. NOTE(review): on failure `vNew` appears to be
                # a status flag where 1 means an obstacle was hit; confirm
                # against RacerRRTStar.planning().
                if vNew == 1:
                    # Then we hit the obstacle == not great
                    print('Hit the obstacle, add to failed trajectories')
                    nFail += 1
                break

            if timeBool is not True:
                # The lists are flipped before taking the end point
                # (presumably planning() returns them end-first in this case).
                NX.reverse()
                NY.reverse()
            nx = NX[-1]
            ny = NY[-1]

            nPts = len(NX)  # > 0 here, otherwise NX[-1] above would have raised
            print(nPts)
            fullTime.extend([float(kStep) / float(nPts)] * nPts)
            fullU.extend([US] * nPts)
            WX.extend([wx] * nPts)

            fullPathX.append(NX)
            fullPathY.append(NY)
            shortPathX.append(nx)
            shortPathY.append(ny)
            plt.pause(0.5)

            print(nx, ny, vNew)
            if np.absolute(ny - gy) + np.absolute(nx - gx) < 0.25:
                break

            # Disturbance model selection; only the ice model is enabled.
            if 0:
                # Zero disturbances
                wx = 0.*(np.random.rand(1)-0.5)
            elif 0:
                # Stochastic disturbances
                wx = (np.random.rand(1)-0.5)
            elif 1:
                # Ice disturbances
                on_ice = (Rrrtstar.ice_rect[0] <= nx <= Rrrtstar.ice_rect[1]
                          and Rrrtstar.ice_rect[2] <= ny <= Rrrtstar.ice_rect[3])
                if on_ice:
                    # Blend the planned velocity with the previously applied one.
                    # fullV is still empty before the first step is recorded, so
                    # fall back to vNet (it equals fullV[-1] whenever fullV is
                    # non-empty).
                    v0 = fullV[-1] if fullV else vNet
                    vNet = Rrrtstar.mu_ice*vNew + (1 - Rrrtstar.mu_ice)*v0
                    dpReal = Rrrtstar.step_len*vNet
                    dpPlan = np.zeros(2)
                    dpPlan[0] = nx - NX[0]
                    dpPlan[1] = ny - NY[0]
                    wx = dpReal - dpPlan
                else:
                    wx = np.zeros(2)
                    vNet = vNew
                fullV.append(vNet)
                pReal_x = nx + wx[0]
                pReal_y = ny + wx[1]
            else:
                # Adversarial disturbances
                timeEst = float(t)*kStep + sy
                wx = 0.
                if 1:
                    if timeEst >= 8.4:
                        if np.abs(float(nx)) <= 0.5:
                            wx = -float(nx)
                        else:
                            wx = 0.5*(-float(nx)/np.abs(float(nx)))
                else:
                    # Sinusoid disturbances
                    tmp = timeEst % 50.
                    if 7.49 <= tmp <= 13.51:
                        wx = 0.5
                    elif 13.51 < tmp <= 18.51:
                        wx = -0.5
                    else:
                        wx = np.random.rand(1) - 0.5

            # NOTE(review): wx was also appended once per planned point above,
            # so WX mixes per-point and per-step entries. It is only consumed
            # by the disabled code below.
            WX.append(wx)

        '''
        # X[-1, :] = np.array([gx, gy])
        # print(X)
        bigLen = 0
        for i in range(len(fullPathX)):
            bigLen += len(fullPathX[i])
        
        X = np.zeros((bigLen, 2))    
        countLen = 0
        costLQR = 0.
        for i in range(len(fullPathX)):
            count0 = countLen
            countLen += len(fullPathX[i])
            X[count0:countLen,0] = np.array(fullPathX[i])
            X[count0:countLen,1] = np.array(fullPathY[i])
            for kk in range(count0, countLen):
                costLQR += (fullTime[kk]*(0.05*X[kk, 0]**2 + 0.025*float(fullU[kk])**2))
        
        costLQR = costLQR/(gy-sy)
        
        LQcosts[0, tn] = float(costLQR)
        LQcosts[1, tn] = 1.0
        '''
        # Planned segments generally differ in length. np.asarray on that ragged
        # list-of-lists raises ValueError on NumPy >= 1.24, so concatenate the
        # segments instead. This gives the same rows as before when all lengths
        # match. A run that failed at t == 0 yields an empty (0, 2) table.
        allX = np.concatenate([np.ravel(seg) for seg in fullPathX]) if fullPathX else np.empty(0)
        allY = np.concatenate([np.ravel(seg) for seg in fullPathY]) if fullPathY else np.empty(0)
        TOT = np.column_stack((allX, allY))
        np.savetxt('Sampling_based_Planning/rrt_2D/waypoints.csv', TOT, delimiter=',')

        fig3, ax3 = plt.subplots(1, 1)
        for (ox, oy, w, h) in Rrrtstar.ice_rectangle:
            ax3.add_patch(
                patches.Rectangle(
                    (ox, oy), w, h,
                    edgecolor='blue',
                    facecolor='blue',
                    fill=True
                )
            )
        for (ox, oy, w, h) in Rrrtstar.obs_boundary:
            ax3.add_patch(
                patches.Rectangle(
                    (ox, oy), w, h,
                    edgecolor='black',
                    facecolor='black',
                    fill=True
                )
            )

        for (ox, oy, r) in Rrrtstar.obs_circle:
            ax3.add_patch(
                patches.Circle(
                    (ox, oy), r,
                    edgecolor='black',
                    facecolor='gray',
                    fill=True
                )
            )

        if fullPathX:
            # Markers for the first planned segment (absent if planning failed at t == 0).
            ax3.plot(fullPathX[0], fullPathY[0], "bs", linewidth=3)
        ax3.plot(gx, gy, "gs", linewidth=3)
        ax3.plot(np.array([0., 0.]), np.array([sy, gy]), 'k--')
        ax3.plot(np.asarray(shortPathX), np.asarray(shortPathY), 'r', linewidth=1)
        ax3.set_xlim([-6., 6.])
        ax3.set_ylim([-1., 11.])
        ax3.set_title("Actual Path")
        fig3.savefig('Sampling_based_Planning/rrt_2D/IteratedPath.png')

        #np.savetxt('data/Adv/X'+str(tn+1)+'.csv', X, delimiter=',')
        #np.savetxt('data/Adv/W'+str(tn+1)+'.csv', WX, delimiter=',')

    #np.savetxt('data/Adv/Cost.csv', LQcosts, delimiter=',')


if __name__ == '__main__':
    # Run the iterated-replanning experiment only when executed as a script.
    main3()
