import gym
import numpy as np 
from gym_minigrid.minigrid import *
import operator
from collections import namedtuple
from functools import reduce

class RoverEnv(MiniGridEnv):
    """
    Rover environment with static obstacles.

    A goal square is placed in the bottom-right interior cell, and a number of
    Ball objects are scattered on the grid as obstacles. Observations are the
    fully observable encoded grid, flattened and cast to float32.
    """

    def __init__(
            self,
            size=16,
            agent_start_pos=(1, 1),
            agent_start_dir=0,
            n_obstacles=0
    ):
        """
        Args:
            size: width and height of the square grid (walls included).
            agent_start_pos: fixed start cell, or None to let
                ``place_agent()`` choose the start on every reset.
            agent_start_dir: start direction used when ``agent_start_pos``
                is given.
            n_obstacles: requested number of Ball obstacles; clamped to at
                most ``int(size / 2 + 1)``.
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        # Reduce obstacles if there are too many. Clamp to the same bound that
        # is accepted, so asking for more obstacles never yields fewer
        # (previously n=size/2+1 was kept but n=size/2+2 dropped to size/2).
        self.n_obstacles = int(min(n_obstacles, size / 2 + 1))

        # Init super
        super().__init__(
            grid_size=size,
            max_steps=4 * size * size,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

        # One (object, color, state) triple per cell, flattened.
        img_size = self.width * self.height * 3

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(img_size,),
            dtype='float32'
        )

        # Allow only 3 actions permitted: left, right, forward
        self.action_space = spaces.Discrete(self.actions.forward + 1)
        self.reward_range = (0, 1)

    def _gen_grid(self, width, height):
        """Build walls, the goal, the agent start, and the obstacles."""
        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Place a goal square in the bottom-right corner
        self.grid.set(width - 2, height - 2, Goal())

        # Place the agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        # Place obstacles
        self.obstacles = []
        for _ in range(self.n_obstacles):
            ball = Ball()
            self.obstacles.append(ball)
            self.place_obj(ball, max_tries=100)

        self.mission = ""

    def step(self, action):
        """
        Advance the environment by one action.

        Returns:
            ``(obs, [reward, -0.1], done, info)``. ``obs`` is the flattened
            grid encoding from ``observation``. ``info['cost']`` holds
            ``get_cost(action)``, evaluated before the agent moves.

        Raises:
            ValueError: if ``action`` is outside ``[0, action_space.n)``.
        """
        # Invalid action
        if action >= self.action_space.n or action < 0:
            raise ValueError("Invalid action!")

        # The cost depends on the cell in front of the agent, so it must be
        # computed before the base class moves/turns the agent.
        cost = self.get_cost(action)

        # Update the agent's position/direction
        obs, reward, done, info = MiniGridEnv.step(self, action)

        # Surface the cost instead of silently discarding it.
        # NOTE(review): the second reward component is a constant -0.1;
        # confirm whether it was meant to incorporate ``cost``.
        info['cost'] = cost

        return self.observation(obs), [reward, -0.1], done, info

    def reset(self):
        """Reset the base environment and return the flattened observation."""
        obs = MiniGridEnv.reset(self)

        return self.observation(obs)

    def observation(self, obs):
        """
        Encode the whole grid as a flat float32 vector.

        The ``obs`` argument is ignored; the full grid is re-encoded and the
        agent's cell is overwritten with (agent, red, agent_dir).
        """
        env = self.unwrapped
        full_grid = env.grid.encode()
        x, y = env.agent_pos
        full_grid[x, y] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        # Cast explicitly to the dtype declared by observation_space;
        # dividing the encoding by 1. produced float64 instead.
        return full_grid.flatten().astype(np.float32)

    def get_cost(self, action):
        """
        Return the cost of taking ``action`` from the current state.

        Returns:
            1.0 if ``action`` is forward and the cell in front of the agent
            holds any object other than a goal (e.g. a wall or a ball);
            otherwise 0.0.
        """
        front_cell = self.grid.get(*self.front_pos)
        not_clear = front_cell and front_cell.type != 'goal'

        if action == self.actions.forward and not_clear:
            return 1.
        return 0.


class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper that uses the fully observable RGB image as the only observation
    output, with no language/mission. This lets the agent solve the
    gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        """
        Args:
            env: environment to wrap; its ``width``/``height`` size the image.
            tile_size: pixel width and height of each rendered grid cell.
        """
        super().__init__(env)

        self.tile_size = tile_size

        # Rendered frames are row-major images (rows, cols, channels), i.e.
        # (height * tile_size, width * tile_size, 3). The previous
        # (width, height) order only matched for square grids.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.height * tile_size, self.env.width * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        """Ignore ``obs`` and return a full RGB render of the grid."""
        env = self.unwrapped

        return env.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )