import numpy as np
from math import *
import torch
from gym.spaces import Box
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('tkagg')

class motion_planning_env():
    """2-D point-robot navigation environment with a gym-style interface.

    The state is an (x, y) position inside the rectangle [0, x_lim] x [0, y_lim].
    Actions are 2-D direction vectors in [-1, 1]^2. Each step moves the agent a
    fixed distance ``interval`` along the action's direction, and the action's
    magnitude is ignored. A reward of 100 is given whenever the agent is inside
    the goal region.
    """

    def __init__(self, x_lim, y_lim, initial_state, goal, goal_radius, is_goal_circle=False):
        """Create the environment.

        Args:
            x_lim, y_lim: upper bounds of the workspace (the lower bounds are 0).
            initial_state: starting position as a list [x, y]. It is stored but not
                used by ``reset``.
            goal: goal centre as a list [x, y].
            goal_radius: radius of the goal disc, or half side length of the goal
                square.
            is_goal_circle: if True the goal region is a disc; otherwise it is an
                axis-aligned square.
        """
        self.x_lim = x_lim
        self.y_lim = y_lim
        self.initial_state = initial_state
        self.observation_space = Box(low=np.array([0.0, 0.0]), high=np.array([x_lim, y_lim]), dtype=np.float64)
        self.action_space = Box(low=np.array([-1.0, -1.0]), high=np.array([1.0, 1.0]), dtype=np.float64)
        self.interval = 0.1  # distance travelled per step
        self.goal = goal
        self.goal_radius = goal_radius
        self.is_goal_circle = is_goal_circle

    def seed(self, seed):
        """Seed the global numpy and torch random number generators."""
        np.random.seed(seed)
        torch.manual_seed(seed)

    def reset(self, state=None):
        """Set and return the current state.

        Args:
            state: position to start from. If None, a position is drawn uniformly
                from the workspace.
        """
        if state is None:
            state = np.array([np.random.uniform(0, self.x_lim), np.random.uniform(0, self.y_lim)])
        self.state = state
        return self.state

    def is_in_goal(self, state=None):
        """Return 1 if ``state`` (default: the current state) is in the goal region, else 0."""
        s = self.state if state is None else state
        dx = s[0] - self.goal[0]
        dy = s[1] - self.goal[1]
        if self.is_goal_circle:
            inside = sqrt(dx ** 2 + dy ** 2) <= self.goal_radius
        else:
            inside = abs(dx) <= self.goal_radius and abs(dy) <= self.goal_radius
        return 1 if inside else 0

    def is_done(self):
        """Episode-termination flag; always 0.

        Termination on leaving the workspace or on reaching the goal is
        deliberately disabled, so episode length is controlled by the caller.
        """
        return 0

    def reward(self):
        """Return 100.0 if the current state is inside the goal region, else 0.0."""
        return 100.0 if self.is_in_goal() else 0.0

    def step(self, action):
        """Move ``interval`` units along ``action``'s direction.

        If the move would leave the workspace, or the action has zero (or NaN)
        length, the agent stays where it is.

        Returns:
            (state, reward, done, info) where ``info`` is always 0.0.
        """
        norm = sqrt(action[0] ** 2 + action[1] ** 2)
        # A zero-length action has no direction; dividing by it would raise
        # ZeroDivisionError (Python floats) or produce NaN (numpy). Stay put instead.
        if norm > 0:
            x = self.state[0] + self.interval * (action[0] / norm)
            y = self.state[1] + self.interval * (action[1] / norm)
            if 0.0 <= x <= self.x_lim and 0.0 <= y <= self.y_lim:
                self.state = np.array([x, y])
        return self.state, self.reward(), self.is_done(), 0.0

    def draw_trajectory(self, trajectory, trajectory_a=None, iteration=None,
                        learned_reward_function=None, close=True):
        """Draw a reward heat-map with ``trajectory`` overlaid and save it as a PDF.

        The workspace is sampled on a 0.1-unit grid, so plot coordinates are the
        world coordinates multiplied by 10. The heat-map is min-max normalised
        to [0, 1].

        Args:
            trajectory: sequence of states. Each element is indexed with [0] and [1],
                and the values must be convertible with ``float()``.
            trajectory_a: actions along the trajectory. Currently unused; kept for
                API compatibility.
            iteration: if given, the file is written as ``reward<iteration>.pdf``;
                otherwise it is ``reward.pdf``.
            learned_reward_function: optional callable taking a 2-element torch
                tensor and returning a single-element tensor. If None, the true
                goal region is drawn (1 inside, 0 outside).
            close: close the figure after saving, so that repeated calls do not
                accumulate open figures.
        """
        # int() is needed because np.zeros rejects float shapes (e.g. x_lim=10.0).
        nx = int(10 * self.x_lim)
        ny = int(10 * self.y_lim)

        fig, ax = plt.subplots()
        ax.axis('scaled')
        ax.axis([0, nx, 0, ny])

        reward_map = np.zeros((ny, nx))
        for ix in range(nx):
            for iy in range(ny):
                point = [0.1 * ix, 0.1 * iy]
                if learned_reward_function is None:
                    reward_map[iy, ix] = 1.0 if self.is_in_goal(point) else 0.0
                else:
                    reward_map[iy, ix] = learned_reward_function(torch.tensor(point)).item()

        low = np.min(reward_map)
        span = np.max(reward_map) - low
        if span > 0:
            reward_map = (reward_map - low) / span
        else:
            # A constant map (e.g. goal outside the grid) would otherwise become all-NaN.
            reward_map = np.zeros_like(reward_map)

        ax.imshow(reward_map, cmap='viridis', origin='lower', extent=[0, nx, 0, ny])
        ax.set_xticks([])
        ax.set_yticks([])

        xs = [10 * float(p[0]) for p in trajectory]
        ys = [10 * float(p[1]) for p in trajectory]
        ax.plot(xs, ys, 'tab:blue', linewidth=3)

        if iteration is None:
            filename = 'reward.pdf'
        else:
            filename = 'reward' + str(iteration) + '.pdf'
        fig.savefig(filename, bbox_inches='tight', pad_inches=0)
        if close:
            plt.close(fig)




















