import torch
import numpy as np
from torch import optim 
import pdb
from scipy.interpolate import interpn

def find_nearest(array, value):
    """Return the (flat) index of the element of ``array`` closest to ``value``.

    Ties resolve to the first such element, following ``np.argmin``.
    """
    gaps = array - value
    return np.argmin(np.absolute(gaps))

class DubinsCar():
    """Dubins-car environment: constant-speed vehicle steered by a turn rate.

    The internal state is ``[x, y, theta]``. The agent observes
    ``[x, y, sin(theta), cos(theta)]``. A precomputed grid ``J_theta`` is
    indexed as ``[x, y, theta]``; the original code describes it as the
    Jacobian used to determine the optimal safety action. It drives the
    bang-bang ``safe_policy``.
    """

    def __init__(self, margin=0.5, j_theta_path="envs/J_theta.npy"):
        """Create the environment and load the ``J_theta`` lookup grid.

        Args:
            margin: the episode terminates as "reached" once the car is
                closer than ``1 - margin`` to the origin.
            j_theta_path: path of the ``.npy`` file holding the 3-D
                ``J_theta`` grid. A relative path is resolved against the
                current working directory.
        """
        self.dt = 0.05
        self.n_state = 3  # state = [x, y, theta]
        self.n_obs = 4    # obs passed to agent = [x, y, sin(theta), cos(theta)]
        self.n_ctrl = 1   # action = [delta], used directly as theta_dot
        self.v = 1        # constant speed

        self.u_max = 1    # action bounds
        self.u_min = -1

        self.x = np.zeros(self.n_state)

        # Jacobian w.r.t. theta, used to determine the optimal safety action.
        self.J_theta = np.load(j_theta_path)
        # Axes for indexing J_theta. The grid is assumed to span x, y in
        # [-3, 3] and theta in [-pi, pi] uniformly; confirm against whatever
        # produced the file.
        nx, ny, ntheta = self.J_theta.shape
        self.x_axis = np.linspace(-3.0, 3.0, nx)
        self.y_axis = np.linspace(-3.0, 3.0, ny)
        self.theta_axis = np.linspace(-np.pi, np.pi, ntheta)
        self.margin = margin

    @staticmethod
    def _obs(state):
        """Map a state ``[..., (x, y, theta)]`` to ``[..., (x, y, sin, cos)]``."""
        theta = state[..., 2]
        return np.stack(
            [state[..., 0], state[..., 1], np.sin(theta), np.cos(theta)],
            axis=-1,
        )

    @staticmethod
    def _nearest_indices(axis, values):
        """For each of ``values``, the index of the closest entry of 1-D ``axis``.

        Ties resolve to the first index, following ``argmin``.
        """
        return np.abs(values[:, None] - axis[None, :]).argmin(axis=1)

    def lx(self, state):
        """Signed distance of the position to the unit circle (negative inside).

        Accepts a single state or any batch shaped ``(..., n_state)``.
        """
        return np.linalg.norm(state[..., :2], axis=-1) - 1

    def reset(self):
        """Sample a state uniformly from [-3, 3]^2 x [-pi, pi] and return its observation."""
        self.x = np.random.uniform([-3, -3, -np.pi], [3, 3, np.pi])
        return self._obs(self.x)

    def step(self, u, x_init=None):
        """Advance one explicit-Euler step of length ``dt`` under turn rate ``u``.

        Handles a single state. The termination test below needs scalar
        conditions.

        Args:
            u: turn-rate command (becomes ``theta_dot``). It is not clipped
                to ``[u_min, u_max]`` here.
            x_init: state ``[x, y, theta]`` to step from. Defaults to the
                current ``self.x``.

        Returns:
            ``(obs, c, is_terminal)``:

            - ``obs`` is ``[x, y, sin(theta), cos(theta)]`` of the new state.
            - ``c`` is 1 if the new position lies inside the unit circle and
              0 otherwise.
            - ``is_terminal`` is True once the car is within ``1 - margin``
              of the origin or either coordinate exceeds 3.2 in magnitude.

        Side effect: ``self.x`` is replaced by the new state.
        """
        # `is None`, not `== None`: comparing an ndarray to None is
        # elementwise, and `if` on the resulting array raises ValueError.
        if x_init is None:
            x_init = self.x
        # Float dtype so the derivatives below are not truncated to integers.
        x_init = np.asarray(x_init, dtype=float)
        theta = x_init[..., 2]

        x_dot = np.zeros_like(x_init)
        x_dot[..., 0] = self.v * np.cos(theta)
        x_dot[..., 1] = self.v * np.sin(theta)
        x_dot[..., 2] = u
        self.x = x_init + self.dt * x_dot

        # Wrap the heading back into [-pi, pi] with a single +-2*pi shift.
        # That suffices when the previous heading was in range and
        # |u| * dt <= 2*pi. `heading` is a view, so self.x is updated in place.
        heading = self.x[..., 2]
        heading[heading > np.pi] -= 2 * np.pi
        heading[heading < -np.pi] += 2 * np.pi

        dist = np.linalg.norm(self.x[..., :2], axis=-1)
        reached = dist < 1 - self.margin
        out_of_bound = np.any(np.abs(self.x[..., :2]) > 3.2)
        # Already reached the circle | too far from the area of interest.
        is_terminal = bool(reached or out_of_bound)
        if reached:
            print("Reached Circle")
        elif out_of_bound:
            print("Out of Bound")

        obs = self._obs(self.x)

        # Convention: C(s) = 1 if reach goal else 0
        c = 1 if dist < 1 else 0
        return obs, c, is_terminal

    def safe_policy(self, obs):
        """Bang-bang safety action looked up from the ``J_theta`` grid.

        Each observation is snapped to the nearest grid point in
        ``(x, y, theta)``. The action at that point depends on ``J_theta``:

        - positive: ``u_min``
        - negative: ``u_max``
        - neither (zero or NaN): a uniform random value in ``[u_min, u_max]``

        Args:
            obs: one observation ``[x, y, sin(theta), cos(theta)]`` or a
                batch shaped ``(N, 4)``.

        Returns:
            Array shaped ``(N, n_ctrl)``, with ``N == 1`` for a single
            observation. Every control column of a row gets the same value.
        """
        # A single observation (as returned by reset/step) becomes a batch of 1.
        obs = np.atleast_2d(np.asarray(obs, dtype=float))
        # arctan2 recovers theta in [-pi, pi] from (sin, cos) in every quadrant.
        theta = np.arctan2(obs[:, 2], obs[:, 3])

        x_idx = self._nearest_indices(self.x_axis, obs[:, 0])
        y_idx = self._nearest_indices(self.y_axis, obs[:, 1])
        theta_idx = self._nearest_indices(self.theta_axis, theta)
        j_theta = self.J_theta[x_idx, y_idx, theta_idx]

        positive = j_theta > 0
        negative = j_theta < 0
        undecided = ~(positive | negative)

        u = np.empty(len(obs))
        u[positive] = self.u_min
        u[negative] = self.u_max
        u[undecided] = np.random.uniform(
            self.u_min, self.u_max, size=int(undecided.sum())
        )
        return np.repeat(u[:, None], self.n_ctrl, axis=1)
