import numpy as np
import gym
import torch

from gym_minigrid.wrappers import *


def get_grid(N):
    """Return all (x, y) cells with 1 <= x <= N and 1 <= y <= N.

    Cells are ordered with x as the outer (slow) index and y as the inner one.
    """
    cells = []
    for x in range(1, N + 1):
        cells.extend((x, y) for y in range(1, N + 1))
    return cells


def get_universe(quads):
    """Return the shuffled list of cells covered by the requested quadrants.

    Each quadrant id ('1'..'4') contributes an 8x8 block of cells from
    ``get_grid(8)``, shifted by 9 along x and/or y (presumably one room of a
    four-room layout separated by a wall line -- TODO confirm against the env).
    Unknown quadrant ids raise KeyError.

    The result is shuffled with a fixed-seed RandomState so the ordering (and
    therefore any later cover/complement split) is reproducible.
    """
    shifts = {
        '1': (0, 0),
        '2': (9, 0),
        '3': (0, 9),
        '4': (9, 9),
    }
    spots = []
    for quad_id in quads:
        dx, dy = shifts[quad_id]
        spots += [(x + dx, y + dy) for x, y in get_grid(8)]

    # Fixed seed: identical ordering on every call, for reproducible results.
    np.random.RandomState(42).shuffle(spots)
    return spots


def get_universes(gw_agent_quad, gw_goal_quad, gw_cover, gw_compl):
    """Build the agent-start and goal position universes, optionally subsampled.

    Args:
        gw_agent_quad: Comma-separated quadrant ids (e.g. '1,3') for agent
            starts, or a falsy value to use all of '1,2,3,4'.
        gw_goal_quad: Same as ``gw_agent_quad``, for goal positions.
        gw_cover: Fraction of each (shuffled) universe to keep. When it is
            exactly 1.0 the full universes are returned and ``gw_compl`` is
            ignored.
        gw_compl: If truthy, return the complement of the covered fraction
            instead of the fraction itself.

    Returns:
        ``(agent_u, goal_u)``: lists of ``(x, y)`` tuples.
    """
    all_quads = '1,2,3,4'
    agent_quads = (gw_agent_quad or all_quads).split(',')
    goal_quads = (gw_goal_quad or all_quads).split(',')

    agent_u = get_universe(agent_quads)
    goal_u = get_universe(goal_quads)

    if gw_cover != 1.0:
        num_a = int(len(agent_u) * gw_cover)
        num_g = int(len(goal_u) * gw_cover)

        if gw_compl:
            # Exact complement of the covered prefix below: every position
            # belongs to exactly one of the two splits (previously index
            # ``num`` was dropped from both).
            agent_u = agent_u[num_a:]
            goal_u = goal_u[num_g:]
        else:
            agent_u = agent_u[:num_a]
            goal_u = goal_u[:num_g]
    return agent_u, goal_u


def choice_except(options, bad_val, rnd):
    """Pick a random element of ``options`` that is not equal to ``bad_val``.

    Elements are compared element-wise (via ``np.asarray``), so ``options`` may
    be a 2-D numpy array of positions or a sequence of tuples, and ``bad_val``
    may be a tuple or an array.

    Args:
        options: Indexable sequence of candidate values.
        bad_val: Value to exclude.
        rnd: Random generator exposing ``choice`` (e.g. ``np.random.RandomState``).

    Returns:
        The chosen element of ``options``.

    Raises:
        ValueError: If every option equals ``bad_val`` (or ``options`` is empty).
    """
    bad = np.asarray(bad_val)
    allowed = [i for i, x in enumerate(options)
               if not np.all(np.asarray(x) == bad)]
    if not allowed:
        raise ValueError('choice_except: no option differs from bad_val')
    return options[rnd.choice(allowed)]


def convert_to_graph(grid_state):
    """Convert a one-hot grid state into a traversability map and agent/goal positions.

    Args:
        grid_state: Tensor (anything ``torch.as_tensor`` accepts, e.g. a numpy
            array) of shape ``(channels, width, height)``. Each cell's argmax
            channel is interpreted as 0 = empty, 1 = wall, 2 = goal, 3 = agent.

    Returns:
        ``(graph, agent_pos, goal_pos)`` where ``graph`` is a float ndarray of
        shape ``(height, width)`` indexed ``graph[j, i]``, holding 1 for
        traversable cells (empty/goal/agent) and 0 for walls. ``agent_pos`` and
        ``goal_pos`` are ``(i, j)`` tuples of the last matching cell visited
        (``i`` outer loop, ``j`` inner), or ``None`` if absent.

    Raises:
        ValueError: If a cell's argmax channel is not one of 0-3.
    """
    grid_state = torch.as_tensor(grid_state)
    _, width, height = grid_state.shape
    # Rows follow j (height) and columns follow i (width), matching the
    # graph[j, i] assignment below; this also holds for non-square grids.
    graph = np.zeros((height, width))
    agent_pos = None
    goal_pos = None
    for i in range(width):
        for j in range(height):
            grid_val = grid_state[:, i, j]
            node_type = torch.argmax(grid_val).item()
            if node_type == 2:
                # Goal
                goal_pos = (i, j)
                node_val = 1
            elif node_type == 3:
                # Agent start
                agent_pos = (i, j)
                node_val = 1
            elif node_type == 1:
                # Wall
                node_val = 0
            elif node_type == 0:
                # Empty
                node_val = 1
            else:
                raise ValueError(f'Unrecognized grid val: {grid_val}')
            graph[j, i] = node_val
    return graph, agent_pos, goal_pos


def get_env_for_pos(agent_pos, goal_pos, args):
    """Create the env named ``args.env.name`` with explicit agent/goal positions.

    The env is built with ``gym.make(..., agent_pos=..., goal_pos=...)`` (so
    the registered env must accept those keyword arguments), then wrapped in
    ``FullyObsWrapper`` and finally ``GoalCheckerWrapper``.
    """
    base_env = gym.make(args.env.name, agent_pos=agent_pos, goal_pos=goal_pos)
    fully_observed = FullyObsWrapper(base_env)
    return GoalCheckerWrapper(fully_observed, args)

def get_grid_obs_for_env(env):
    """Return the current one-hot grid observation for a wrapped env.

    Regenerates the raw observation with ``env.gen_obs()``, passes it through
    ``env.observation`` and converts the result with ``env._get_obs``.
    """
    raw_obs = env.gen_obs()
    wrapped_obs = env.observation(raw_obs)
    return env._get_obs(wrapped_obs)

def disp_gw_state(state):
    """Print a one-hot grid state as ASCII art, one printed line per j index.

    ``state`` is indexed as ``state[channel, i, j]``. Each cell's argmax channel
    is drawn as ' ' (0, empty), 'x' (1, wall), 'g' (2, goal) or 'a' (3, agent).
    Cells whose argmax is any other channel contribute no character.
    """
    glyphs = {0: ' ', 1: 'x', 2: 'g', 3: 'a'}
    n_i, n_j = state.shape[1], state.shape[2]
    for j in range(n_j):
        line = ''.join(glyphs.get(torch.argmax(state[:, i, j]).item(), '')
                       for i in range(n_i))
        print(line)


def sample_range_no_replace(cache_vals, N, avoid_coords):
    """Sample an interior cell of an N x N grid without replacement.

    On the first call, the candidate list (every ``(x, y)`` with
    ``1 <= x, y <= N - 2`` that is not in ``avoid_coords``) is built and cached
    in ``cache_vals``. Later calls reuse that cache, so their ``N`` and
    ``avoid_coords`` arguments are ignored. Each candidate is returned once per
    pass; once all have been used, the pool is refilled and a new pass begins.

    Args:
        cache_vals: Mutable dict that persists state between calls.
        N: Grid side length (used only on the first call).
        avoid_coords: Cells to exclude (used only on the first call).

    Returns:
        An ``(x, y)`` tuple.

    Raises:
        ValueError: From ``np.random.choice`` if there are no candidates at all.
    """
    if 'all_starts' not in cache_vals:
        cache_vals['all_starts'] = [(x, y)
                                    for x in range(1, N-1)
                                    for y in range(1, N-1) if (x, y) not in avoid_coords]
        cache_vals['sample_idx'] = list(range(len(cache_vals['all_starts'])))

    all_starts = cache_vals['all_starts']

    if not cache_vals['sample_idx']:
        # Every candidate has been used once; start a new pass.
        cache_vals['sample_idx'] = list(range(len(all_starts)))
    # Fetch the pool only after the possible refill above; a reference taken
    # earlier would still point at the exhausted, empty list.
    sample_idx = cache_vals['sample_idx']

    idx = np.random.choice(sample_idx)
    sample_idx.remove(idx)
    return all_starts[idx]


def gw_empty_spawn(env, cache_vals, args, N=8):
    """Randomise the agent start of the empty grid before reset, if enabled.

    Does nothing unless ``args.env.gw_rand_pos`` is truthy. Otherwise it draws
    a start cell via ``sample_range_no_replace`` (never (6, 6), presumably the
    goal cell of the default 8x8 layout -- TODO confirm) and stores it on
    ``env.env.agent_start_pos``.
    """
    if not args.env.gw_rand_pos:
        return
    env.env.agent_start_pos = sample_range_no_replace(cache_vals, N, [(6, 6)])


def gw_room_spawn(env, cache_vals, args):
    """Set default agent/goal positions on ``env.env`` before a reset.

    Without ``args.env.gw_rand_pos``, only the agent default position is set,
    to [1, 1]. Otherwise the agent and goal universes from ``get_universes``
    are computed once and cached in ``cache_vals['universe']``. Then either:

    * ``args.env.gw_goal_pos is None``: draw the agent uniformly from its
      universe, and the goal from its universe excluding the agent cell; or
    * ``gw_goal_pos`` is an 'x,y' string: use it as the goal and draw the agent
      from its universe excluding that cell.

    Draws use ``env.np_random``.
    """
    if not args.env.gw_rand_pos:
        # Always start at the top left corner.
        env.env._agent_default_pos = [1, 1]
        return

    if 'universe' not in cache_vals:
        cache_vals['universe'] = get_universes(args.env.gw_agent_quad,
                                               args.env.gw_goal_quad,
                                               args.env.gw_cover,
                                               args.env.gw_compl)
    # Convert each universe on its own: they can differ in length (e.g. a
    # single agent quadrant vs. all goal quadrants), and a single np.array
    # over both would be ragged.
    agent_univ, goal_univ = (np.array(u) for u in cache_vals['universe'])

    if args.env.gw_goal_pos is None:
        # Sample a position from both.
        agent_start = agent_univ[env.np_random.choice(len(agent_univ))]
        # Sample from anywhere except for agent position.
        goal_start = choice_except(goal_univ, agent_start, env.np_random)
    else:
        goal_start = tuple(int(x) for x in args.env.gw_goal_pos.split(','))
        agent_start = choice_except(agent_univ, goal_start, env.np_random)

    env.env._agent_default_pos = agent_start
    env.env._goal_default_pos = goal_start


# Lookup from a raw MiniGrid cell encoding (each 3-tuple of the 'image'
# observation) to a one-hot vector over four node channels:
# index 0 = empty, 1 = wall, 2 = goal, 3 = agent.
NODE_TO_ONE_HOT = {
    cell: [1 if channel == hot else 0 for channel in range(4)]
    for cell, hot in [
        ((1, 0, 0), 0),   # empty square
        ((9, 0, 0), 0),   # lava, deliberately encoded as empty so it is steppable
        ((2, 5, 0), 1),   # wall
        ((8, 1, 0), 2),   # goal
        # agent: the last field presumably encodes its facing direction (0-3)
        ((10, 0, 0), 3),
        ((10, 0, 1), 3),
        ((10, 0, 2), 3),
        ((10, 0, 3), 3),
    ]
}


def get_obs_shape(ob_space, k="observation"):
    """Return the observation shape of ``ob_space``.

    For a ``gym.spaces.Dict``, this is the shape of the sub-space stored under
    key ``k``. For any other space, it is ``ob_space.shape``.
    """
    if not isinstance(ob_space, gym.spaces.Dict):
        return ob_space.shape
    return ob_space.spaces[k].shape

class GoalCheckerWrapper(gym.Wrapper):
    """Re-encode a MiniGrid env as a one-hot grid with absolute-direction moves.

    Observations are the ``'image'`` observation mapped cell-by-cell through
    ``NODE_TO_ONE_HOT`` into float32 arrays of shape ``(d0, d1, 4)``, where
    ``d0, d1`` are the first two dims of the wrapped image space. Actions
    0-3 point the agent in that direction and move it forward one cell. With
    ``args.env.gw_diag_action_space``, actions 4-7 are diagonal moves built
    from two forward moves. ``info['ep_found_goal']`` reports whether a
    positively-rewarded terminal step has occurred since the last reset.

    Assumes ``self.env.env`` is the underlying MiniGrid env (e.g. ``env`` is
    a ``FullyObsWrapper`` around it) -- as built by ``get_env_for_pos``.
    """

    def __init__(self, env, args):
        super().__init__(env)
        ob_s = env.observation_space.spaces['image'].shape

        # Keep the two leading dims of the raw image; replace its per-cell
        # encoding with the 4 one-hot channels of NODE_TO_ONE_HOT.
        ob_shape = (ob_s[0], ob_s[1], 4)

        low = 0.0
        high = 1.0
        self.observation_space = gym.spaces.Box(shape=ob_shape,
                                                low=np.float32(low),
                                                high=np.float32(high),
                                                dtype=np.float32)

        # Transform the action space to the cardinal directions (plus the
        # four diagonals when enabled).
        self.gw_diag_action_space = args.env.gw_diag_action_space
        if self.gw_diag_action_space:
            self.action_space = gym.spaces.Discrete(8)
        else:
            self.action_space = gym.spaces.Discrete(4)
        # Persistent scratch state shared with the spawn hooks across resets.
        self.cache_vals = {}
        self.args = args
        # Env-name -> hook that sets spawn positions before each reset.
        self.set_cond = {
            'MiniGrid-Empty-8x8-v0': gw_empty_spawn,
            'MiniGrid-FourRooms-v0': gw_room_spawn,
            'MiniGrid-Deceptive-v0': gw_room_spawn,
            'MiniGrid-Deceptive-v1': gw_room_spawn,
        }
        # Also initialised here so step() works even before the first reset().
        self.found_goal = False

    def _get_obs(self, obs_dict):
        """Map ``obs_dict['image']`` through NODE_TO_ONE_HOT into a float32 grid.

        The dtype matches ``observation_space`` (float32), so observations pass
        ``observation_space.contains``. Raises KeyError for any cell encoding
        missing from NODE_TO_ONE_HOT.
        """
        cells = obs_dict['image'].reshape(-1, 3)
        one_hot = np.array([NODE_TO_ONE_HOT[tuple(c)] for c in cells],
                           dtype=np.float32)
        return one_hot.reshape(get_obs_shape(self.observation_space))

    def reset(self, **kwargs):
        """Run the env-specific spawn hook, reset the env and return the one-hot obs.

        Keyword arguments are forwarded unchanged to the wrapped ``reset``.
        """
        # Call whatever setup is specific to this version of grid world.
        self.set_cond[self.args.env.name](self.env, self.cache_vals, self.args)
        self.found_goal = False
        reset_obs = self.env.reset(**kwargs)
        return self._get_obs(reset_obs)

    def _diag_step(self, a):
        """Perform diagonal action ``a`` (4-7) as two forward moves.

        The agent faces ``a % 4``, moves forward, turns right and moves forward
        again; its original direction is restored afterwards. The step-count
        increments of the first two sub-steps are undone so the whole move
        counts as one step. If the first move already ends the episode (e.g.
        the agent stepped onto the goal), its result is returned immediately
        instead of stepping past a terminal state.
        """
        base = self.env.env
        final_dir = base.agent_dir
        base.agent_dir = a % 4
        result = self.env.step(self.env.actions.forward)
        if result[2]:
            base.agent_dir = final_dir
            return result
        base.step_count -= 1
        self.env.step(self.env.actions.right)
        base.step_count -= 1
        result = self.env.step(self.env.actions.forward)
        base.agent_dir = final_dir
        return result

    def step(self, a):
        """Apply absolute-direction action ``a``; return ``(obs, reward, done, info)``."""
        if self.gw_diag_action_space and a >= 4:
            # for action = 4, 5, 6, 7, they are diagonal actions
            obs_dict, reward, done, info = self._diag_step(a)
        else:
            # Face the requested direction, then move one cell that way.
            self.env.env.agent_dir = a
            obs_dict, reward, done, info = self.env.step(self.env.actions.forward)

        obs = self._get_obs(obs_dict)

        if done and reward > 0.0:
            self.found_goal = True
        info['ep_found_goal'] = float(self.found_goal)
        return obs, reward, done, info
