import torch
import numpy as np
from torch import optim
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

def find_nearest(array, value):
    '''
     Return the index of the entry of `array` closest to `value`.

     array: array-like of numbers (list, tuple or ndarray).
     value: the number to look up.

     Ties resolve to the first closest entry, as `argmin` does. For inputs
     with more than one dimension the index refers to the flattened array.
     An empty `array` raises ValueError (from `argmin`).
    '''
    # np.asarray lets plain Python sequences work; ndarrays pass through as-is.
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()
    
def VisualizeDI(value, fig, ax, saveName = None):
    '''
     Visualize the value function

     Draws `value` as a heat map on `ax` (horizontal axis v in [-2, 2],
     vertical axis x in [-1, 1]) on top of the ground-truth curves, attaches
     a colorbar to `fig`, optionally writes the figure to a PDF, and finally
     closes `fig`.

     value:    2-D array; it is transposed before plotting, so its first axis
               runs along v and its second along x.
     fig, ax:  matplotlib figure and axes to draw on.
     saveName: path of the PDF to write, or None to skip saving.
    '''
    ## Plot the ground truth
    v_neg = np.linspace(-2, 0, 21)
    v_pos = np.linspace(0, 2, 21)
    ax.plot(v_neg, -1 + v_neg**2 / 2, 'k')
    ax.plot(v_pos, 1 - v_pos**2 / 2, 'k')
    ax.plot([-2, 0], [1, 1], 'k')
    ax.plot([0, 2], [-1, -1], 'k')

    ## Heat map of the value function; origin='lower' puts x = -1 at the bottom
    image = ax.imshow(value.T, extent=(-2, 2, -1, 1), vmin = -0.2, vmax = 1,
                      aspect = 'auto', origin = 'lower')
    fig.colorbar(image)

    ax.set_xlim((-2.01, 2.01))
    ax.set_ylim((-1, 1.01))
    ax.set_xlabel('v', fontsize = 16)
    ax.set_ylabel('x', fontsize = 16)

    if saveName is not None:
        with PdfPages(saveName) as pdf:
            pdf.savefig(fig, bbox_inches = 'tight')
        print('Saved Image')
    # Close the figure we were given; a bare plt.close() closes whichever
    # figure is current in pyplot, which is not necessarily `fig`.
    plt.close(fig)
    
class DoubleIntegrator():
    '''
     Discrete-time double integrator environment.

     The state is [x, v] (position, velocity) and the action is a scalar
     acceleration. Position is considered safe inside [x_min, x_max]; an
     episode terminates once position leaves that interval by more than
     `margin`.
    '''
    def __init__(self, margin = 0.4, jv_path = 'envs/Jv.npy'):
        '''
         margin:  how far beyond [x_min, x_max] the position may go before
                  `step` reports a terminal state.
         jv_path: path of the .npy table used by `safe_policy`. The default
                  is relative to the current working directory.
        '''
        self.dt = 0.1
        self.n_state = 2 # state = [x, v]
        self.n_obs = self.n_state
        self.n_ctrl = 1  # action = [a]
        self.x_max = 1
        self.x_min = -1
        self.u_max = 1  # action
        self.u_min = -1

        self.x = np.zeros(self.n_state)
        self.action_space = np.linspace(self.u_min, self.u_max, 5)

        ## Use jacobian to determine optimal safety action
        self.Jv = np.load(jv_path)
        ## For indexing the Jacobian: first axis is v, second axis is x
        nv, nx = self.Jv.shape
        self.v_axis = np.linspace(-2.0, 2.0, nv)
        self.x_axis = np.linspace(-1.0, 1.0, nx)

        self.margin = margin

    def lx(self, state):
        '''
         Distance from the position to the nearer of the two bounds
         (negative once the position is outside [x_min, x_max]).

         state: array whose last axis is [x, v]; leading axes are preserved.
        '''
        x = state[..., 0]
        return np.minimum(self.x_max - x, x - self.x_min)

    def reset(self):
        '''
         Sample a new state uniformly with x in [-1, 1) and v in [-2, 2),
         store it as the current state and return it.
        '''
        self.x = np.random.uniform([-1, -2], [1, 2])
        return self.x

    def step(self, u, x_init=None):
        '''
         Advance one explicit-Euler step of length dt.

         u:      acceleration to apply.
         x_init: state [x, v] to step from; the stored state is used when
                 this is None.

         Returns (next_state, c, is_terminal). The next state is also stored
         as the current state. Only a single state is supported here: the
         flags are plain scalars, so a batch of states raises ValueError.
        '''
        # `is None`, not `== None`: comparing an ndarray with == is
        # elementwise and cannot be used as an `if` condition.
        if x_init is None:
            x_init = self.x
        v = x_init[..., 1]

        x_dot = np.zeros_like(x_init)
        x_dot[..., 0] = v
        x_dot[..., 1] = u
        self.x = x_init + self.dt * x_dot

        pos = self.x[..., 0]
        is_terminal = bool((pos > self.x_max + self.margin)
                           | (pos < self.x_min - self.margin))

        ## Convention of CSC, C(s)=1 if unsafe else 0
        unsafe = bool((pos > self.x_max) | (pos < self.x_min))
        c = 1 if unsafe else 0
        return self.x, c, is_terminal

    def safe_policy(self, state):
        '''
         Pick a bang-bang action for each state from the sign of the
         tabulated Jacobian at the nearest grid point.

         state: array of shape (N, 2) whose rows are [x, v], the same layout
                used by `reset`, `step` and `lx`.

         Returns an (N, n_ctrl) array: u_max where the table entry is
         positive, u_min where it is negative, and a uniform random action
         in [u_min, u_max) otherwise.
        '''
        optimal_u = np.zeros((len(state), self.n_ctrl))
        # Column 0 is position and column 1 is velocity; the table itself is
        # indexed [v, x].
        for i, (x, v) in enumerate(zip(state[..., 0], state[..., 1])):
            v_idx = find_nearest(self.v_axis, v)
            x_idx = find_nearest(self.x_axis, x)

            Jv = self.Jv[v_idx, x_idx]
            if Jv > 0:
                optimal_u[i] = self.u_max
            elif Jv < 0:
                optimal_u[i] = self.u_min
            else:
                optimal_u[i] = np.random.uniform(self.u_min, self.u_max)
        return optimal_u
