import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
import copy

import matplotlib.pyplot as plt

import warnings

class stateHistory():
    """
    Accumulates the transitions of one episode as stacked numpy arrays.

    Four parallel histories are kept: previous states, resulting states,
    actions and goals. Each history is `None` until the first `add()`, holds
    a copy of the first sample after one `add()`, and from the second `add()`
    onwards holds the samples stacked row-wise with `np.vstack`.
    """

    def __init__(self):
        self.prev_states = None
        self.states = None
        self.actions = None
        self.goals = None

    @staticmethod
    def _append(stack, sample):
        """
        Return `stack` extended with `sample`.

        The first sample is stored as an independent copy so that later
        in-place edits by the caller cannot alter the history. `np.array`
        is used instead of `sample.copy()` so that lists/tuples are accepted
        on the first call exactly as `np.vstack` accepts them afterwards.
        """
        if stack is None:
            return np.array(sample)
        return np.vstack((stack, sample))

    def add(self, prev_state, current_action, current_goal, new_state):
        """
        Record one transition.

        Inputs:
        -------
        prev_state : state before the action was applied
        current_action : action that was applied
        current_goal : goal associated with this transition
        new_state : state after the action was applied
        """
        self.prev_states = self._append(self.prev_states, prev_state)
        self.actions = self._append(self.actions, current_action)
        self.goals = self._append(self.goals, current_goal)
        self.states = self._append(self.states, new_state)

    def get_all(self):
        """
        Return the histories as a tuple (prev_states, states, actions, goals).
        Every entry is None if nothing has been added since the last reset.
        """
        return self.prev_states, self.states, self.actions, self.goals

    def reset(self):
        """ Discard everything recorded so far """
        self.prev_states = None
        self.states = None
        self.actions = None
        self.goals = None


def close_plot():
    """ Close a pyplot figure by calling `plt.close()` with no arguments """
    plt.close()

class PathFollowBaseEnv(gym.Env):
    """
    A simple environment for path following where the path can be programmatically defined.

    The agent moves a 2D point; every step the action is added to the current
    state, the result is scored by the shape generator at the current phase,
    and the phase is then advanced by `direction * phase_step`.
    """
    def __init__(self, phase_step=10/180*np.pi, shape_gen=None, max_steps=36, seed=None, init_choices=False, include_goal=False, **kwargs):
        """
        Inputs:
        -------
        phase_step : phase increment applied on every step of an episode
        shape_gen: object defining the desired shape with a state-phase mapping (a deep copy is stored).
                   - Components used by this class:
                     .get_bounds(): min and max of the state space as a pair of np arrays [xmin, ymin], [xmax, ymax]
                     .get_true_state(phase): unpacked here as (state, extra); only the state is used
                     .get_state(phase): unpacked here as (goal_state, extra); only used when include_goal is True
                     .get_goal_and_reward(phase, state): unpacked here as (goal, reward, extra)
                     .set_rand_seed(seed): set the random seed
        max_steps: max episode steps
        seed: random seed (used for both the environment and the shape generator)
        init_choices: when reset() is asked for a random initial phase, draw it from a
                      discrete grid with spacing 1/max_steps instead of uniformly
        include_goal: append the goal state to the observation (bool)
        kwargs: accepted and ignored
        """
        if shape_gen is None:
            # Fail with a clear message instead of an AttributeError on None below.
            raise ValueError('PathFollowBaseEnv requires a shape_gen object')

        self.max_steps = max_steps
        self.direction = 1
        self.phase_step = phase_step
        self.shape_gen = copy.deepcopy(shape_gen)

        self.init_choices = init_choices
        self.include_goal = include_goal

        self.seed(seed)
        self.shape_gen.set_rand_seed(seed)

        self.state_obs_low, self.state_obs_high = self.shape_gen.get_bounds()
        self.norm_factor = self.state_obs_high - self.state_obs_low
        # A single action may move the agent at most 10% of the state range per axis.
        self.max_act = (self.state_obs_high - self.state_obs_low)*0.1

        # The phase is observed as [cos(phase), sin(phase)], so both entries span [-1, 1].
        # NOTE(review): the state part of the observation is normalized to [0, 1] in
        # normStateObs(), yet the bounds declared here are the raw state bounds. They are
        # left untouched because callers may already scale by them - confirm and align.
        obs_low = np.concatenate((self.state_obs_low, [-1., -1.]))
        obs_high = np.concatenate((self.state_obs_high, [1., 1.]))
        if self.include_goal:
            obs_low = np.concatenate((obs_low, self.state_obs_low))
            obs_high = np.concatenate((obs_high, self.state_obs_high))
        self.observation_space = spaces.Box(low=obs_low, high=obs_high)
        self.action_space = spaces.Box(low=-self.max_act, high=self.max_act)

        self.current_step = 0
        self.current_phase = 0.
        self.current_state = np.array([0., 0.])
        self.state_history = stateHistory()
        self.done = False
        self.terminate_count = 0

        # Handles used by render(); created lazily on the first render call.
        self.fig = None
        self.ax1 = None
        self.a_graph = None
        self.g_graph = None

    def normStateObs(self):
        """
        Build the observation for the current state and phase.
        Outputs [x, y, cos(phi), sin(phi)] where (x, y) is the agent position rescaled
        with the state bounds and phi is the current path phase. When include_goal is
        set, the goal state returned by the shape generator is appended as is.
        """
        normalized_state = (self.current_state - self.state_obs_low)/self.norm_factor
        phase_features = [np.cos(self.current_phase), np.sin(self.current_phase)]
        if self.include_goal:
            # NOTE(review): get_state() is unpacked as a pair here - confirm that this is
            # what the shape generator returns.
            goal_state, _ = self.shape_gen.get_state(self.current_phase)
            return np.concatenate((normalized_state, phase_features, goal_state))
        else:
            return np.concatenate((normalized_state, phase_features))

    def step(self, action):
        """
        Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` to reset this environment's state.
        Accepts an action and returns a tuple (observation, reward, done, info).
        Args:
            action (object): an action provided by the agent; clipped to +/- max_act
        Returns:
            observation (object): agent's observation of the current environment
            reward (float) : amount of reward returned after previous action
            done (bool): whether the episode has ended, in which case further step() calls will return undefined results
            info (dict): 'prevState', 'state', 'goalState' and 'actionTaken' for this step
        """
        if self.done:
            warnings.warn('Termination criteria was reached but episode is still running without reset')

        action = np.clip(action, -self.max_act, self.max_act)
        previous_state = self.current_state.copy()
        # Build a new array rather than using `+=`: the previous array was handed out in
        # info['state'] on the last step and must not change under the caller.
        self.current_state = np.clip(self.current_state + action, self.state_obs_low, self.state_obs_high)

        current_desired, reward, _ = self.shape_gen.get_goal_and_reward(self.current_phase, self.current_state)
        true_desired, _ = self.shape_gen.get_true_state(self.current_phase)

        # The history records the true desired state (not current_desired) as the goal.
        self.state_history.add(previous_state, action, true_desired, self.current_state)

        # Advance the phase and keep it in [0, 2*pi) for either travel direction.
        self.current_phase = (self.current_phase + self.direction * self.phase_step) % (2*np.pi)

        self.current_step += 1
        if self.current_step >= self.max_steps:
            self.done = True

        return self.normStateObs(), reward, self.done, {'prevState': previous_state, 'state':self.current_state, 'goalState':current_desired, 'actionTaken':action}

    def reset(self, init_phase=0.):
        """
        Resets the environment to an initial state and returns an initial observation.
        Args:
            init_phase: phase to start the episode from (default 0). Pass None to draw a
                        random phase; it is then taken from a discrete grid when
                        init_choices is set and uniformly from [0, 2*pi) otherwise.
        """
        self.state_history.reset()
        self.done = False
        self.current_step = 0
        # Only close a figure this environment created: plt.close(None) would close
        # whichever pyplot figure happens to be current.
        if self.fig is not None:
            plt.close(self.fig)
        self.fig, self.ax1, self.a_graph, self.g_graph = None, None, None, None

        if init_phase is None:
            if self.init_choices:
                init_options = np.arange(0, 1.+1e-5, 1./self.max_steps)
                self.current_phase = self.np_random.choice(init_options) * 2*np.pi
            else:
                self.current_phase = self.np_random.random() * 2*np.pi
        else:
            self.current_phase = init_phase

        # Start near the desired state with a small perturbation. The sum creates a new
        # array so the array returned by the shape generator is never modified.
        true_state, _ = self.shape_gen.get_true_state(self.current_phase)
        perturbation = (self.np_random.random(2)-0.5)*2e-3
        self.current_state = np.clip(true_state + perturbation, self.state_obs_low, self.state_obs_high)

        self.terminate_count = 0

        return self.normStateObs()

    def close(self):
        """
        Closes out pending figures in the environment.
        """
        if self.fig is not None:
            plt.close(self.fig)
            # Forget the closed figure so that a later render() opens a new one.
            self.fig, self.ax1, self.a_graph, self.g_graph = None, None, None, None

    def seed(self, seed=None):
        """
        Seeds the environment
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def set_shape_seed(self, seed=None):
        """
        Seed the shape generator with external command
        """
        s_ = self.shape_gen.set_rand_seed(seed)
        return [s_]

    def render(self, mode='human', **kwargs):
        """
        Renders the environment: plots the agent trajectory and the desired trajectory
        recorded so far in an interactive pyplot figure. `mode` and kwargs are ignored.
        """
        # For alternative approach: timed frame render code available from https://stackoverflow.com/a/30365738

        # Create a pyplot figure and hold it and update on call
        if self.fig is None:
            plt.ion()
            self.fig, self.ax1 = plt.subplots()

        _, states, _, goals = self.state_history.get_all()
        if states is not None:
            # After a single step the history holds one 1-D sample; promote it to a
            # one-row matrix so the column indexing below works from the first step on.
            states = np.atleast_2d(states)
            goals = np.atleast_2d(goals)

            if (self.a_graph is None) or (self.g_graph is None):
                self.a_graph = self.ax1.plot(states[:,0], states[:,1], '^b-', label='agent')[0]
                self.g_graph = self.ax1.plot(goals[:,0], goals[:,1], 'xg-', label='desired')[0]

                self.ax1.legend()

                # Show the whole state space with a 10% margin on every side.
                self.ax1.set_xlim(self.state_obs_low[0]-0.1*self.norm_factor[0], self.state_obs_high[0]+0.1*self.norm_factor[0])
                self.ax1.set_ylim(self.state_obs_low[1]-0.1*self.norm_factor[1], self.state_obs_high[1]+0.1*self.norm_factor[1])
            else:
                self.a_graph.set_xdata(states[:,0])
                self.a_graph.set_ydata(states[:,1])

                self.g_graph.set_xdata(goals[:,0])
                self.g_graph.set_ydata(goals[:,1])

        plt.draw()
        plt.pause(3./self.max_steps)

    def get_expert_action(self):
        """
        Return the clipped action that moves the agent from its current state towards the
        true state of the next phase, following the current direction of travel.
        """
        next_phase = self.current_phase + self.direction * self.phase_step
        next_desired, _ = self.shape_gen.get_true_state(next_phase)
        best_action = np.clip(next_desired - self.current_state, -self.max_act, self.max_act)

        return best_action

    def set_direction(self, direction):
        """ Set the multiplier applied to phase_step when the phase is advanced """
        self.direction = direction


class DirectionWrapper(gym.Wrapper):
    """
    Wrapper that selects the travel direction of the wrapped environment and
    prepends a two-element direction indicator to every observation
    ([1, 0] for direction +1, [0, 1] for direction -1).
    """
    def __init__(self, env):
        super(DirectionWrapper, self).__init__(env)
        self.direction = +1
        self.env.set_direction(self.direction)
        inner_space = self.env.observation_space
        low = np.concatenate(([0., 0.], inner_space.low))
        high = np.concatenate(([1., 1.], inner_space.high))
        self.observation_space = spaces.Box(low=low, high=high)

    def reset(self, direction=None, **kwargs):
        """
        Reset the wrapped environment. When `direction` is None it is drawn
        from {-1, +1} with numpy's global random generator.
        """
        self.direction = np.random.randint(2) * 2 - 1 if direction is None else direction
        self.env.set_direction(self.direction)
        inner_obs = self.env.reset(**kwargs).copy()
        indicator = [(self.direction + 1)/2, (-self.direction + 1)/2]
        return np.concatenate((indicator, inner_obs))

    def step(self, action):
        """
        Step the wrapped environment and put this step's reward in front of
        info['RewardBreakdown']: first slot for direction +1, second slot otherwise.
        """
        inner_obs, rew, done, info = self.env.step(action)

        own_terms = [rew, 0.] if self.direction == 1 else [0., rew]
        if 'RewardBreakdown' in info:
            info['RewardBreakdown'] = np.concatenate((own_terms, info['RewardBreakdown']))
        else:
            info['RewardBreakdown'] = np.array(own_terms)

        indicator = [(self.direction + 1)/2, (-self.direction + 1)/2]
        return np.concatenate((indicator, inner_obs)), rew, done, info

    def set_direction(self, direction, **kwargs):
        """ Store the direction and forward it to the wrapped environment """
        self.direction = direction
        self.env.set_direction(direction, **kwargs)


class ShapeChoiceWrapper(gym.Wrapper):
    """
    Wrapper that switches between deep copies of two environments on reset and
    prepends the indicator [env_choice, 1 - env_choice] to every observation.
    """
    def __init__(self, env1, env2):
        super(ShapeChoiceWrapper, self).__init__(env1)
        self.envs = [copy.deepcopy(env1), copy.deepcopy(env2)]
        self.env_choice = 0
        observation_space = self.env.observation_space
        self.observation_space = spaces.Box(low=np.concatenate(([0, 0], observation_space.low)), high=np.concatenate(([1., 1.], observation_space.high)))

    def reset(self, selection=None, **kwargs):
        """
        Activate sub-environment `selection` (0 or 1) and reset it. When
        `selection` is None it is drawn with numpy's global random generator.
        """
        if selection is None:
            selection = np.random.randint(2)
        self.env_choice = selection
        self.env = self.envs[self.env_choice]
        obs = self.env.reset(**kwargs).copy()
        return np.concatenate(([self.env_choice, 1-self.env_choice], obs))

    def step(self, action):
        """
        Step the active sub-environment and put this step's reward in front of
        info['RewardBreakdown']: first slot when env_choice is 1, second slot
        otherwise (the same ordering as the observation indicator).
        """
        obs, rew, done, info = self.env.step(action)

        own_terms = [rew, 0.] if self.env_choice == 1 else [0., rew]
        if 'RewardBreakdown' in info:
            info['RewardBreakdown'] = np.concatenate((own_terms, info['RewardBreakdown']))
        else:
            info['RewardBreakdown'] = np.array(own_terms)

        obs_return = np.concatenate(([self.env_choice, 1-self.env_choice], obs))
        return obs_return, rew, done, info

    def set_direction(self, direction, **kwargs):
        """
        Set the travel direction on every sub-environment.

        reset() may activate a different sub-environment than the one that is
        active now, so updating only the active one would lose a direction
        that is set just before a reset.
        """
        targets = list(self.envs)
        # Before the first reset the active env is the original env1, not one of the copies.
        if not any(self.env is candidate for candidate in targets):
            targets.append(self.env)
        for sub_env in targets:
            sub_env.set_direction(direction, **kwargs)

class TriShapeWrapper(gym.Wrapper):
    """
    Wrapper that switches between deep copies of three environments on reset
    and prepends a three-element one-hot indicator of the active one to every
    observation.
    """
    def __init__(self, env1, env2, env3):
        super(TriShapeWrapper, self).__init__(env1)
        self.envs = [copy.deepcopy(env1), copy.deepcopy(env2), copy.deepcopy(env3)]
        self.env_choice = 0
        observation_space = self.env.observation_space
        self.observation_space = spaces.Box(low=np.concatenate(([0, 0, 0], observation_space.low)), high=np.concatenate(([1., 1., 1.], observation_space.high)))

    def reset(self, selection=None, **kwargs):
        """
        Activate sub-environment `selection` (0, 1 or 2) and reset it. When
        `selection` is None it is drawn with numpy's global random generator.
        """
        if selection is None:
            selection = np.random.randint(3)
        self.env_choice = selection
        indicator = [0, 0, 0]
        indicator[self.env_choice] = 1
        self.env = self.envs[self.env_choice]
        obs = self.env.reset(**kwargs).copy()
        return np.concatenate((indicator, obs))

    def step(self, action):
        """
        Step the active sub-environment and put this step's reward in front of
        info['RewardBreakdown'], in the slot of the active sub-environment.
        """
        obs, rew, done, info = self.env.step(action)

        # Choices 0 and 1 use their own slot; any other choice uses the last slot.
        slot = self.env_choice if self.env_choice in (0, 1) else 2
        own_terms = [0., 0., 0.]
        own_terms[slot] = rew
        if 'RewardBreakdown' in info:
            info['RewardBreakdown'] = np.concatenate((own_terms, info['RewardBreakdown']))
        else:
            info['RewardBreakdown'] = np.array(own_terms)

        indicator = [0, 0, 0]
        indicator[self.env_choice] = 1
        obs_return = np.concatenate((indicator, obs))
        return obs_return, rew, done, info

    def set_direction(self, direction, **kwargs):
        """
        Set the travel direction on every sub-environment.

        reset() may activate a different sub-environment than the one that is
        active now, so updating only the active one would lose a direction
        that is set just before a reset.
        """
        targets = list(self.envs)
        # Before the first reset the active env is the original env1, not one of the copies.
        if not any(self.env is candidate for candidate in targets):
            targets.append(self.env)
        for sub_env in targets:
            sub_env.set_direction(direction, **kwargs)


