from gym.utils import EzPickle
from gym import spaces
from gym import Env
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import torch

class PointDynamics(object):
    """
    State: position.
    Action: velocity.
    """
    def __init__(self, dim, sigma):
        self.dim = dim
        self.sigma = sigma
        self.s_dim = dim
        self.a_dim = dim

    def forward(self, state, action):
        mu_next = state + action
        state_next = mu_next + self.sigma * np.random.normal(size=self.s_dim)
        return state_next


class MultiGoalEnv(Env, EzPickle): 
    """
    Move a 2D point mass to one of the goal positions. Cost is the distance to
    the closest goal.

    State: position.
    Action: velocity.
    """
    def __init__(self, goal_reward=10, actuation_cost_coeff=30.0, distance_cost_coeff=1.0, init_sigma=0.05, max_steps=None, plot_format=None, env_name=None):
        EzPickle.__init__(**locals())

        self.dynamics = PointDynamics(dim=2, sigma=0)
        self.init_mu = np.zeros(2, dtype=np.float32)
        self.init_sigma = init_sigma
        self.max_steps = max_steps
        # goal
        self.goal_positions = np.array(((5, 0),(-5, 0),(0, 5),(0, -5)), dtype=np.float32)
        self.goal_names_plotting = np.array(['$G_1$', '$G_2$', '$G_3$', '$G_4$'])
        self.num_goals = len(self.goal_positions)
        self.goal_threshold = 0.5 #1.0
        self.goal_reward = goal_reward
        # reward
        self.action_cost_coeff = actuation_cost_coeff
        self.distance_cost_coeff = distance_cost_coeff
        # plotting
        self.xlim = (-7, 7)
        self.ylim = (-7, 7)
        self.vel_bound = 1.
        
        # logging
        self.episode_observations = [] 
        self.all_episode_observations = [] 
        self.ep_len = 0

        # Plotter params, to be cleaned tomorrow. 
        self.entropy_obs_names = np.array(['a', 'b', 'c', 'd', 'e', 'f'])
        self.entropy_obs_names_plotting = np.array(['$s_a$', '$s_b$', '$s_c$', '$s_d$', '$s_e$', '$s_f$'])

        self._obs_lst = np.array([
            [0,0],
            [1, 0],
            [2, 0],
            [2.5,2.5],
            [4, 0],
            [2.5,-2.5],
        ])
        self._n_samples = 10
        self.n_plots = len(self._obs_lst)
        self.x_size = (1.5 * self.n_plots + 1)
        self.y_size = 17 
        self.agent_failure = None
        self.plot_format = plot_format
        self.env_name = env_name
        self.intersection = False

        
    def reset(self, init_state=None):
        if init_state:
            unclipped_observation = init_state
        else: 
            unclipped_observation = self.init_mu

        self.observation = np.clip(unclipped_observation, self.observation_space.low, self.observation_space.high)
        self.all_episode_observations.append([list(self.observation)])
        self.ep_len = 0
        return self.observation

    @property
    def observation_space(self):
        return spaces.Box(
            low=np.array((self.xlim[0], self.ylim[0])),
            high=np.array((self.xlim[1], self.ylim[1])),
            dtype=np.float32,
            shape=(2,))

    @property
    def action_space(self):
        return spaces.Box(
            low=-self.vel_bound,
            high=self.vel_bound,
            shape=(self.dynamics.a_dim, ),
            dtype=np.float32)


    def step(self, action):
        # compute action
        action = action.ravel()
        action = np.clip(action, self.action_space.low, self.action_space.high).ravel()

        # compute observation
        self.observation = self.dynamics.forward(self.observation, action)
        self.observation = np.clip(self.observation, self.observation_space.low,self.observation_space.high)
        self.ep_len += 1

        # compute reward
        reward = self.compute_reward(self.observation, action)

        # compute done
        distance_to_goals = [np.linalg.norm(self.observation - goal_position)for goal_position in self.goal_positions]
        self.min_dist_index = np.argmin(distance_to_goals)
        done = distance_to_goals[self.min_dist_index] < self.goal_threshold
        
        # reward at the goal
        if done:
            if self.env_name == 'test_env':
                self.number_of_hits_mode_acc[self.min_dist_index] += 1
            reward += self.goal_reward

        if done or self.ep_len == self.max_steps:
            self.episode_observations.append(self.observation)
        self.all_episode_observations[-1].append(list(self.observation))
        return self.observation, reward, done, None
    

    def compute_reward(self, observation, action): 
        # penalize the L2 norm of acceleration
        # noinspection PyTypeChecker
        action_cost = np.sum(action ** 2) * self.action_cost_coeff

        # penalize squared dist to goal
        cur_position = self.observation

        # noinspection PyTypeChecker
        goal_cost = self.distance_cost_coeff * np.amin([
            np.sum((cur_position - goal_position) ** 2)
            for goal_position in self.goal_positions
        ])

        # penalize staying with the log barriers
        costs = [action_cost, goal_cost]
        reward = -np.sum(costs)
        return reward
    
    def reset_rendering(self):
        self.episode_observations = []
        self.all_episode_observations = []
        self.number_of_hits_mode = np.zeros(self.num_goals)
        self.number_of_hits_mode_acc = np.zeros(self.num_goals)

    
    def render(self, itr, fig_path, ac=None, paths=None):
        paths = self.all_episode_observations
        self._init_plot(self.x_size, self.y_size)
        for a in range(len(self.all_episode_observations)):
            positions = np.array(paths[a])
            self._ax_lst[0].plot(positions[:, 0], positions[:, 1], color='blue')
        

        if ac.pi.actor!='svgd_sql':
            if not ac.pi.actor=='sac':
                num_particles_tmp = ac.pi.num_particles
                ac.pi.num_particles = self._n_samples
                ac.pi.Kernel.num_particles = ac.pi.num_particles
                ac.pi.identity = torch.eye(ac.pi.num_particles).to(ac.pi.device)

            self.entropy_list = []
            for i in range(len(self._obs_lst)):
                self.entropy_tmp = []
                o = torch.as_tensor(self._obs_lst[i], dtype=torch.float32).to(ac.pi.device).view(-1,1,self.observation_space.shape[0]).repeat(1,ac.pi.num_particles,1).view(-1,self.observation_space.shape[0])
                for _ in range(100):
                    a, log_p = ac(o, action_selection=ac.pi.test_action_selection, with_logprob=True)
                    self.entropy_tmp.append(round(-log_p.detach().item(), 2))
                self.entropy_list.append(round(np.array(self.entropy_tmp).mean(), 2))
            
            
            if not ac.pi.actor=='sac':
                ac.pi.num_particles = num_particles_tmp
                ac.pi.Kernel.num_particles = ac.pi.num_particles
                ac.pi.identity = torch.eye(ac.pi.num_particles).to(ac.pi.device)


            for i in range(len(self._obs_lst)):
                self._ax_lst[0].scatter(self._obs_lst[i, 0], self._obs_lst[i, 1], c='#003f40', marker='x', s=50, zorder=2)
                self._ax_lst[0].annotate(self.entropy_obs_names_plotting[i], (self._obs_lst[i,0] - 0.25, self._obs_lst[i,1] + 0.2), fontsize=18, color='#003f40', zorder=2)
        

        self._plot_level_curves(self._obs_lst, ac)
        self._plot_action_samples(ac)
        plt.plot()
        plt.savefig(fig_path+ '/env_' + str(itr) + '.' + 'png')   
        plt.close()

        modes_dist = (((positions).reshape(-1,1,2) - np.expand_dims(self.goal_positions,0))**2).sum(-1)
        ind = modes_dist[np.where(modes_dist.min(-1)<1)[0]].argmin(-1)
        self.number_of_hits_mode[ind]+=1
    
    
    def _init_plot(self, x_size, y_size, grid_size=(5,3), debugging=False):
        self._ax_lst = []
        self.fig_env = plt.figure(figsize=(x_size, y_size), constrained_layout=True) 
        self._ax_lst.append(plt.subplot2grid(grid_size, (0,0), colspan=3, rowspan=3))
        self._ax_lst[0].axis('equal')
        self._ax_lst[0].set_xlim(self.xlim)
        self._ax_lst[0].set_ylim(self.xlim)
        self._ax_lst[0].set_title('Multigoal Environment', fontsize=20)

        self._ax_lst[0].xaxis.set_tick_params(labelsize=15)
        self._ax_lst[0].yaxis.set_tick_params(labelsize=15)
        x_min, x_max = tuple(1.1 * np.array(self.xlim))
        y_min, y_max = tuple(1.1 * np.array(self.ylim))
        X, Y = np.meshgrid(
            np.arange(x_min, x_max, 0.01),
            np.arange(y_min, y_max, 0.01)
        )
        goal_costs = np.amin([
            (X - goal_x) ** 2 + (Y - goal_y) ** 2
            for goal_x, goal_y in self.goal_positions
        ], axis=0)
        costs = goal_costs
        contours = self._ax_lst[0].contour(X, Y, costs, 20)
        self._ax_lst[0].clabel(contours, inline=1, fontsize=10, fmt='%.0f')
        self._ax_lst[0].set_xlim([x_min, x_max])
        self._ax_lst[0].set_ylim([y_min, y_max])
        self._ax_lst[0].plot(self.goal_positions[:, 0], self.goal_positions[:, 1], 'ro')


        goal_names_positions = np.array([[0.5, -0.2], [-1.2, -0.2], [-0.4, 0.5], [-0.4, -1]]) + self.goal_positions
        for i in range(len(self.goal_positions)):
                self._ax_lst[0].annotate(self.goal_names_plotting[i], goal_names_positions[i], fontsize=19, color='red', zorder=2)


        if not debugging:
            for i in range(self.n_plots):
                ax = plt.subplot2grid(grid_size, (3 + i//3,i%3))
                ax.set_xlim((-1, 1))
                ax.set_ylim((-1, 1))
                ax.xaxis.set_tick_params(labelsize=15)
                ax.yaxis.set_tick_params(labelsize=15)
                ax.grid(True)
                self._ax_lst.append(ax)
            self._line_objects = list()
        return self._ax_lst[0]

    def _plot_level_curves(self, _obs_lst, ac):
        # Create mesh grid.
        xs = np.linspace(-1, 1, 50)
        ys = np.linspace(-1, 1, 50)
        xgrid, ygrid = np.meshgrid(xs, ys)
        a = np.concatenate((np.expand_dims(xgrid.ravel(), -1), np.expand_dims(ygrid.ravel(), -1)), -1)
        a = torch.from_numpy(a.astype(np.float32)).to(ac.pi.device)
        for i in range(len(_obs_lst)):
            o = torch.Tensor(_obs_lst[i]).repeat([a.shape[0],1]).to(ac.pi.device)
            with torch.no_grad():
                qs = ac.q1(o.to(ac.pi.device), a).cpu().detach().numpy()
            qs = qs.reshape(xgrid.shape)
            cs = self._ax_lst[i+1].contour(xgrid, ygrid, qs, 20)
            self._line_objects += cs.collections
            self._line_objects += self._ax_lst[i+1].clabel(
                cs, inline=1, fontsize=10, fmt='%.2f')

    def _plot_action_samples(self, ac):
        if ac.pi.actor in  ['svgd_nonparam', 'svgd_p0_param']:
            num_particles_tmp = ac.pi.num_particles
            ac.pi.num_particles = self._n_samples
            ac.pi.Kernel.num_particles = ac.pi.num_particles
            ac.pi.identity = torch.eye(ac.pi.num_particles).to(ac.pi.device)
        for i in range(len(self._obs_lst)):
            if ac.pi.actor in  ['svgd_nonparam', 'svgd_p0_param', 'svgd_sql']:
                
                o = torch.as_tensor(self._obs_lst[i], dtype=torch.float32).view(-1,1,self.observation_space.shape[0]).repeat(1,ac.pi.num_particles,1).view(-1,self.observation_space.shape[0]).to(ac.pi.device)
            else:
                o = torch.as_tensor(self._obs_lst[i], dtype=torch.float32).repeat([self._n_samples,1]).to(ac.pi.device)
            actions, _ = ac(o, action_selection=None, with_logprob=False)
            actions = actions.cpu().detach().numpy().squeeze()
            
            x, y = actions[:, 0], actions[:, 1]
            if ac.pi.actor not in  ['svgd_sql']:
                self._ax_lst[i+1].set_title('Entr(' + self.entropy_obs_names_plotting[i] + ')=' + str(self.entropy_list[i]), fontsize=15)
            self._line_objects += self._ax_lst[i+1].plot(x, y, 'b*')
        if ac.pi.actor in  ['svgd_nonparam', 'svgd_p0_param']:
            ac.pi.num_particles = num_particles_tmp
            ac.pi.Kernel.num_particles = ac.pi.num_particles
            ac.pi.identity = torch.eye(ac.pi.num_particles).to(ac.pi.device)