# Toy environment similar to Derk.
import gym
from gym.spaces import MultiDiscrete, Box
import numpy as np

class MultiDiscreteToyEnv(gym.Env):
    """Point agent moving inside the box ``[-1, 1]^n_dim`` with fixed goals.

    Actions come from ``MultiDiscrete([2 * n_dim, 4, 2])`` and are read as
    ``(direction, fraction, act)``:

    * ``direction`` selects a signed unit vector: index ``2k`` is ``+e_k`` and
      index ``2k + 1`` is ``-e_k``.
    * ``fraction`` in ``{0, 1, 2, 3}`` is divided by 3 and scales
      ``step_size`` (0.1), so a move covers 0, 1/3, 2/3 or all of a step.
    * ``act`` is part of the action space but is not read by this class.

    Observations are the agent location followed by the flattened goal
    locations, cast to ``float32`` and returned with a leading batch axis of
    size 1. ``reset()`` must be called before ``step()`` or ``render()``
    because it is what creates ``self.loc`` and ``self.t``.
    """

    def __init__(self, goal_locs, goal_rads, reset_loc=None) -> None:
        """Set up spaces, the direction lookup table and the goal layout.

        Args:
            goal_locs: ``n_goals x n_dim`` array of goal centres.
            goal_rads: ``n_goals x 1`` array of goal radii (only used when
                rendering).
            reset_loc: optional fixed start location; when ``None`` a random
                start is drawn on every ``reset()``.
        """
        super().__init__()
        #up-down-left-right movement + movement_fraction + act (0 or 1)
        self.n_dim = goal_locs.shape[1] #dimension of the space
        self.action_space = MultiDiscrete([2*self.n_dim, 4, 2])
        self.n_goals = goal_locs.shape[0]
        self.observation_space = Box(-1., 1., shape=(self.n_dim + self.n_goals*self.n_dim,))
        self.step_size = 0.1
        pos_actions = np.eye(self.n_dim)
        # Interleave +e_k and -e_k so that rows are [+e_0, -e_0, +e_1, -e_1, ...].
        action_directions = list(np.stack((pos_actions, -pos_actions), axis=1).reshape(-1, self.n_dim))
        self.movement_vector_mapping = dict(zip(np.arange(2*self.n_dim), action_directions))
        self.goal_locs = goal_locs#n_goals x n_dims array
        self.goal_rads = goal_rads#n_goals x 1 array
        self.reset_loc = reset_loc
        self.viewer = None

    def _observation(self):
        """Return the current observation as a ``(1, obs_dim)`` float32 array."""
        state = np.concatenate((self.loc, self.goal_locs.flatten()), dtype=np.float32)
        return state.reshape((1, -1))

    def step(self, action):
        """Move the agent according to ``action`` and advance the step counter.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``reward`` is
            always ``None``, ``done`` is a shape-``(1,)`` array holding
            ``False`` and ``info`` is an empty dict.
        """
        self.loc = self.__get_next_loc(action)
        self.t += 1
        return self._observation(), None, np.array(False).reshape((1,)), {}

    def reset(self):
        """Reset the step counter and place the agent at its start location.

        Without a fixed ``reset_loc`` the start is sampled uniformly from the
        half of the box with a non-positive first coordinate.

        Returns:
            The initial observation as a ``(1, obs_dim)`` float32 array.
        """
        self.t = 0
        low = self.observation_space.low.copy()[:self.n_dim]
        high = self.observation_space.high.copy()[:self.n_dim]
        high[0] = low[0] + (high[0] - low[0])/2.0 #start from left half only
        if self.reset_loc is None:
            reset_loc = np.random.uniform(low, high)
        else:
            # Copy so that later updates of self.loc never alias reset_loc.
            reset_loc = self.reset_loc.copy()
        self.loc = reset_loc
        return self._observation()

    def __get_next_loc(self, action):
        """Compute the clipped location reached by applying ``action``.

        ``action`` is indexed as ``action[:, i]`` and must expose
        ``.cpu().numpy()`` and ``.item()``; presumably a torch tensor of shape
        ``(1, 3)`` -- confirm against the caller.
        """
        movement_fraction = action[:,1] / (self.action_space.nvec[1] - 1)
        movement_fraction = movement_fraction.cpu().numpy()
        movement_direction = self.movement_vector_mapping[action[:,0].item()]
        proposed_next_loc = self.loc + movement_fraction * self.step_size * movement_direction
        # Keep the agent inside the location part of the observation box.
        next_loc = np.clip(proposed_next_loc, self.observation_space.low[:self.n_dim], self.observation_space.high[:self.n_dim])
        return next_loc

    def _build_viewer(self, screen_width, screen_height, scale):
        """Create the viewer with goals, agent marker and axis lines.

        ``self.viewer`` and ``self.agenttrans`` are only assigned once the
        whole scene was built, so a failure here cannot leave a half-built
        viewer behind that later ``render()`` calls would try to reuse.
        """
        # Imported lazily: the rendering module needs a display backend.
        from gym.envs.classic_control import rendering

        viewer = rendering.Viewer(screen_width, screen_height)
        try:
            #draw goals
            goal_trans = None
            for goal_loc, goal_rad in zip(self.goal_locs, self.goal_rads):
                # goal_rads rows are length-1 arrays; the circle needs a scalar.
                radius = float(np.ravel(goal_rad)[0]) * scale
                goal = viewer.draw_circle(radius=radius, res=30)
                goal.set_color(1., 0., 0.)
                goal_trans = rendering.Transform()
                goal.add_attr(goal_trans)
                viewer.add_geom(goal)

                goaldx = goal_loc[0].item() * scale + screen_width / 2.0
                goaldy = goal_loc[1].item() * scale + screen_height / 2.0
                goal_trans.set_translation(goaldx, goaldy)

            agent = viewer.draw_circle(radius=3, res=30)
            agent.set_color(.2, .2, .2)
            agent_trans = rendering.Transform()
            agent.add_attr(agent_trans)
            viewer.add_geom(agent)

            #axis for convenience
            a1 = viewer.draw_line((0, screen_height / 2.0), (screen_width, 0 + screen_height / 2.0))
            viewer.add_geom(a1)
            a2 = viewer.draw_line((screen_width / 2.0, 0.0), (screen_width / 2.0, screen_height))
            viewer.add_geom(a2)
        except Exception:
            # Do not leak the window when scene construction fails.
            viewer.close()
            raise

        if goal_trans is not None:
            # Kept for backward compatibility: as before, this attribute ends
            # up referring to the transform of the last goal drawn.
            self.goaltrans = goal_trans
        self.agenttrans = agent_trans
        self.viewer = viewer

    def render(self, mode='human'):
        """Draw the goals and the agent; only supported for ``n_dim == 2``.

        Args:
            mode: ``'rgb_array'`` asks the viewer to return the frame as an
                array; any other value renders to the window.

        Returns:
            Whatever ``viewer.render`` returns for the requested mode.
        """
        assert self.n_dim == 2, "Only available in 2dim"

        screen_width = 400
        screen_height = 400
        worldwidth = 2
        scale = screen_width / worldwidth

        if self.viewer is None:
            self._build_viewer(screen_width, screen_height, scale)

        # Map world coordinates in [-1, 1] to pixels with the origin centred.
        agentdx = self.loc[0].item() * scale + screen_width / 2.0
        agentdy = self.loc[1].item() * scale + screen_height / 2.0
        self.agenttrans.set_translation(agentdx, agentdy)

        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))

    def close(self):
        """Close the viewer, if any, so a later ``render()`` builds a new one."""
        if self.viewer:
            self.viewer.close()
            # Without this reset, render() would reuse the closed window.
            self.viewer = None

