import numpy as np
import gym
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import os.path as osp
import wandb
from PIL import Image
from collections import defaultdict

from gym_minigrid.wrappers import *
from goal_prox.envs.debug_viz import DebugViz
import rlf.rl.utils as rutils
import random


def get_grid(N):
    """Enumerate every (x, y) cell of an N-by-N grid with 1-based coordinates.

    Cells are ordered with x as the slow index and y as the fast index, both
    running from 1 to N inclusive. A non-positive N yields an empty list.
    """
    coords = range(1, N + 1)
    return [(x, y) for x in coords for y in coords]


def get_universe(quads):
    """Collect the cells of the requested quadrants in a fixed shuffled order.

    Args:
        quads: iterable of quadrant labels, each one of '1', '2', '3', '4'.

    Returns:
        A list of (x, y) tuples. Every requested quadrant contributes the
        cells of `get_grid(8)` translated by that quadrant's offset; the
        combined list is then shuffled with a fixed-seed RNG, so the same
        `quads` always produce the same ordering.

    Raises:
        KeyError: if a label is not one of the four known quadrants.
    """
    # (dx, dy) translation applied to the base grid for each quadrant.
    quad_offsets = {
        '1': (0, 0),
        '2': (9, 0),
        '3': (0, 9),
        '4': (9, 9),
    }
    base_cells = get_grid(8)

    cells = []
    for label in quads:
        dx, dy = quad_offsets[label]
        cells += [(x + dx, y + dy) for x, y in base_cells]

    # A fixed seed keeps the shuffled order reproducible across runs.
    np.random.RandomState(42).shuffle(cells)
    return cells


def get_universes(gw_agent_quad, gw_goal_quad, gw_cover, gw_compl):
    """Build the candidate position lists for the agent and for the goal.

    Args:
        gw_agent_quad: comma-separated quadrant labels for the agent
            (e.g. '1,3'); a falsy value selects all four quadrants.
        gw_goal_quad: same format, for the goal.
        gw_cover: fraction of each shuffled universe to keep. Exactly 1.0
            returns both universes untouched.
        gw_compl: when truthy (and gw_cover != 1.0), return the tail of each
            universe instead of its leading fraction.

    Returns:
        A pair (agent_universe, goal_universe) of lists of (x, y) tuples.
    """
    every_quad = '1,2,3,4'
    agent_u = get_universe((gw_agent_quad or every_quad).split(','))
    goal_u = get_universe((gw_goal_quad or every_quad).split(','))

    if gw_cover == 1.0:
        return agent_u, goal_u

    agent_cut = int(len(agent_u) * gw_cover)
    goal_cut = int(len(goal_u) * gw_cover)

    if gw_compl:
        # NOTE(review): the tail starts at cut + 1, so the cell at index
        # `cut` belongs to neither the covered split nor the complement --
        # confirm this one-cell gap is intentional.
        return agent_u[agent_cut + 1:], goal_u[goal_cut + 1:]
    return agent_u[:agent_cut], goal_u[:goal_cut]


def choice_except(options, bad_val, rnd):
    """Randomly pick an entry of `options` that differs from `bad_val`.

    An entry is excluded when `(entry == bad_val).all()` is true, so entries
    must support element-wise comparison (e.g. rows of a numpy array).

    Args:
        options: indexable collection of candidate entries.
        bad_val: the value that must not be returned.
        rnd: random source exposing `choice(sequence)`.

    Returns:
        `options[k]` for a randomly chosen admissible index k.
    """
    allowed = []
    for idx, candidate in enumerate(options):
        if (candidate == bad_val).all():
            continue
        allowed.append(idx)
    return options[rnd.choice(allowed)]


def convert_to_graph(grid_state):
    """Turn a one-hot grid state into a traversability map plus key positions.

    Args:
        grid_state: tensor of shape (channels, width, height). The channel
            with the largest value at each cell selects the cell type:
            0 = empty, 1 = wall, 2 = goal, 3 = agent start.

    Returns:
        (graph, agent_pos, goal_pos) where `graph` is a numpy array indexed
        as graph[j, i] (second spatial axis first) holding 1 for cells that
        can be entered and 0 for walls, and the positions are (i, j) tuples
        or None when no such cell exists.

    Raises:
        ValueError: if a cell's arg-max channel is not one of 0..3.
    """
    _, width, height = grid_state.shape
    # The map is written as graph[j, i] with j < height and i < width, so it
    # must be allocated (height, width); allocating (width, height) raised an
    # IndexError for any non-square grid.
    graph = np.zeros((height, width))
    agent_pos = None
    goal_pos = None
    for i in range(width):
        for j in range(height):
            grid_val = grid_state[:, i, j]
            node_type = torch.argmax(grid_val).item()
            if node_type == 2:
                # Goal
                goal_pos = (i, j)
                node_val = 1
            elif node_type == 3:
                # Agent start
                agent_pos = (i, j)
                node_val = 1
            elif node_type == 1:
                # Wall
                node_val = 0
            elif node_type == 0:
                # Empty
                node_val = 1
            else:
                print(grid_val)
                raise ValueError('Unrecognized grid val')
            graph[j, i] = node_val
    return graph, agent_pos, goal_pos


def get_env_for_pos(agent_pos, goal_pos, args):
    """Create a wrapped grid-world env with the agent and goal at given cells.

    The env named by `args.env_name` is built through `gym.make`, receiving
    `agent_pos` and `goal_pos` as keyword arguments, and is then wrapped in
    `FullyObsWrapper` followed by `GoalCheckerWrapper`.
    """
    base_env = gym.make(args.env_name, agent_pos=agent_pos, goal_pos=goal_pos)
    full_obs_env = FullyObsWrapper(base_env)
    return GoalCheckerWrapper(full_obs_env, args)

def get_grid_obs_for_env(env):
    """Return the env's current observation after all processing stages.

    Chains `env.gen_obs()` -> `env.observation(...)` -> `env._get_obs(...)`
    and returns the final result.
    """
    raw_obs = env.gen_obs()
    wrapped_obs = env.observation(raw_obs)
    return env._get_obs(wrapped_obs)

def plot_prox_heatmap(get_prox_func_fn, save_dir, iter_count,
                      name, args, obs_shape, with_compl):
    """Save a heatmap of the value `get_prox_func_fn` gives each start cell.

    Args:
        get_prox_func_fn: callable invoked as `fn(states, action=None)` on
            the batch of observations; its result is indexed once per
            evaluated cell and read with `.item()`.
        save_dir: output directory, created if it does not exist.
        iter_count: integer embedded in the output file name.
        name: file-name prefix and wandb log key.
        args: config object; reads the `gw_*` options, `device` and `no_wb`.
        obs_shape: observation shape; entries 1 and 2 size the heatmap.
        with_compl: forwarded to `get_universes` as its `gw_compl` flag.
    """
    import os

    agent_u, goal_u = get_universes(args.gw_agent_quad, args.gw_goal_quad,
                                    args.gw_cover, with_compl)

    if args.gw_goal_pos is None:
        fixed_goal = goal_u[0]
    else:
        fixed_goal = tuple(int(coord) for coord in args.gw_goal_pos.split(','))

    # The door positions are always evaluated, ahead of the universe cells
    # (from which the goal cell is excluded).
    door_cells = [(4, 9), (11, 9), (9, 4), (9, 14)]
    agent_u = door_cells + [pos for pos in agent_u if pos != fixed_goal]

    # One observation per start cell, with the last axis moved to the front.
    states = [
        get_grid_obs_for_env(
            get_env_for_pos(pos, fixed_goal, args)).transpose(2, 0, 1)
        for pos in agent_u
    ]
    states = torch.FloatTensor(states).to(args.device)
    proximities = get_prox_func_fn(states, action=None)

    # Values are plotted as returned (no rescaling); cells that were not
    # evaluated stay at 0.
    p_vals = np.zeros((obs_shape[1], obs_shape[2]))
    for idx, (x, y) in enumerate(agent_u):
        p_vals[x, y] = proximities[idx].item()

    heat_ax = sns.heatmap(p_vals, linewidth=0.5, annot=False,
                          xticklabels=False, yticklabels=False)
    heat_ax.tick_params(left=False, bottom=False)

    # Make sure the output directory exists before writing the figure.
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    save_path = osp.join(save_dir, '%s_heat_%i.png' % (name, iter_count))
    print('Saved proximity heatmap to %s' % save_path)
    plt.savefig(save_path)
    plt.clf()

    if not args.no_wb:
        wandb.log({name: [wandb.Image(Image.open(save_path))]})

def plot_prox_heatmap_full(get_prox_func_fn, save_dir, iter_count,
                      name, args, obs_shape, with_compl):
    """Save a heatmap of per-cell values averaged over four agent headings.

    For every start cell four observations are generated (one per heading
    index 0..3), `get_prox_func_fn(states, action=None)` is evaluated on the
    whole batch, and the four results of each cell are averaged.

    Args:
        get_prox_func_fn: callable returning a tensor with one value per
            observation (it is reshaped with `.view(num_cells, 4)`).
        save_dir: output directory, created if it does not exist.
        iter_count: integer embedded in the output file name.
        name: file-name prefix and wandb log key.
        args: config object; reads the `gw_*` options, `device` and `no_wb`.
        obs_shape: observation shape; entries 1 and 2 size the heatmap.
        with_compl: forwarded to `get_universes` as its `gw_compl` flag.
    """
    import os

    num_headings = 4

    agent_u, goal_u = get_universes(args.gw_agent_quad, args.gw_goal_quad,
                                    args.gw_cover, with_compl)

    if args.gw_goal_pos is None:
        fixed_goal = goal_u[0]
    else:
        fixed_goal = tuple(int(coord) for coord in args.gw_goal_pos.split(','))

    # The door positions are always evaluated, ahead of the universe cells
    # (from which the goal cell is excluded).
    door_cells = [(4, 9), (11, 9), (9, 4), (9, 14)]
    agent_u = door_cells + [pos for pos in agent_u if pos != fixed_goal]

    states = []
    for pos in agent_u:
        for heading in range(num_headings):
            env = get_env_for_pos(pos, fixed_goal, args)
            # NOTE(review): this sets the attribute on the wrapper object
            # returned by get_env_for_pos -- confirm it reaches the env
            # that actually produces the observation.
            env.agent_dir = heading
            states.append(get_grid_obs_for_env(env).transpose(2, 0, 1))

    states = torch.FloatTensor(states).to(args.device)
    proximities = get_prox_func_fn(states, action=None)
    # Row k holds the four per-heading values of agent_u[k].
    proximities = proximities.view(len(agent_u), num_headings)

    # Values are plotted as returned (no rescaling); cells that were not
    # evaluated stay at 0.
    p_vals = np.zeros((obs_shape[1], obs_shape[2]))
    for idx, (x, y) in enumerate(agent_u):
        p_vals[x, y] = proximities[idx].mean().item()

    heat_ax = sns.heatmap(p_vals, linewidth=0.5, annot=False,
                          xticklabels=False, yticklabels=False)
    heat_ax.tick_params(left=False, bottom=False)

    # Make sure the output directory exists before writing the figure.
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    save_path = osp.join(save_dir, '%s_heat_%i.png' % (name, iter_count))
    print('Saved proximity heatmap to %s' % save_path)
    plt.savefig(save_path)
    plt.clf()

    if not args.no_wb:
        wandb.log({name: [wandb.Image(Image.open(save_path))]})


def disp_gw_state(state):
    """Print an ASCII rendering of a one-hot grid state, one line per row.

    `state` is indexed as state[:, i, j]; each printed line fixes j and
    walks i. The arg-max channel of a cell picks its glyph: 0 -> ' '
    (empty), 1 -> 'x' (wall), 2 -> 'g' (goal), 3 -> 'a' (agent). Any other
    channel index contributes no character.
    """
    glyphs = {0: ' ', 1: 'x', 2: 'g', 3: 'a'}
    for j in range(state.shape[2]):
        row_chars = [
            glyphs.get(torch.argmax(state[:, i, j]).item(), '')
            for i in range(state.shape[1])
        ]
        print(''.join(row_chars))

class GwProxPlotter(DebugViz):
    """Accumulates proximity values per grid cell and renders heatmaps.

    `add` folds batches into running per-cell sums and counts; `plot` writes
    the averaged heatmaps, runs the function-based heatmaps and then clears
    the accumulators.
    """

    def __init__(self, save_dir, args, obs_shape):
        super().__init__(save_dir, args)
        self.obs_shape = obs_shape
        self.reset()

    def reset(self):
        """Drop all accumulated sums and counts."""
        def blank_grid():
            # One zero-filled grid per plot name, sized by the spatial
            # (trailing) dimensions of the observation shape.
            return np.zeros(self.obs_shape[1:])

        self.sum_prox = defaultdict(blank_grid)
        self.count = defaultdict(blank_grid)

    def _should_process_batches(self):
        """Return True unless positions are random and no goal is pinned."""
        if not self.args.gw_rand_pos:
            return True
        return self.args.gw_goal_pos is not None

    def add(self, batches):
        """Fold each batch's 'prox' values into the per-cell accumulators.

        `batches` maps a plot name to a dict with a 'state' tensor and a
        'prox' sequence. Cells where plane 3 of a state equals 1 are the
        ones credited with the matching proximity value.
        """
        if not self._should_process_batches():
            return

        for name, batch in batches.items():
            state = batch['state'].cpu().numpy()
            proxs = batch['prox']
            # Coordinates of every cell whose plane-3 entry is set.
            _, agent_x, agent_y = np.where(state[:, 3] == 1)

            for x, y, prox in zip(agent_x, agent_y, proxs):
                self.sum_prox[name][x, y] += prox
                self.count[name][x, y] += 1

    def _plot_accumulated(self, iter_count, plot_names):
        """Write one averaged heatmap per plot name that has any data."""
        for plot_name in plot_names:
            avg_prox = self.sum_prox[plot_name] / self.count[plot_name]
            if not np.isnan(avg_prox).all():
                # Cells that were never visited remain NaN in the average.
                sns.heatmap(avg_prox)
                save_name = osp.join(self.save_dir, '%s_%i.png' % (plot_name, iter_count))
                plt.savefig(save_name)
                print(f"Saved to {save_name}")
                plt.clf()

    def plot(self, iter_count, plot_names, plot_funcs):
        """Render all heatmaps for this iteration, then reset the state.

        `plot_funcs` maps a name to a proximity function; each one is
        plotted with `plot_prox_heatmap_full`, and once more on the
        complement split (suffix "_compl") when `args.gw_cover` is not 1.0.
        """
        if self._should_process_batches():
            self._plot_accumulated(iter_count, plot_names)

        for func_name, plot_func in plot_funcs.items():
            plot_prox_heatmap_full(plot_func, self.save_dir, iter_count,
                                   func_name, self.args, self.obs_shape,
                                   False)
            if self.args.gw_cover != 1.0:
                plot_prox_heatmap_full(plot_func, self.save_dir, iter_count,
                                       func_name + "_compl", self.args,
                                       self.obs_shape, True)

        self.reset()


def plot_iq_heatmap_full(get_prox_func_fn, save_dir, iter_count,
                      name, args, obs_shape, with_compl):
    """Save a normalized heatmap of per-cell values averaged over headings.

    For every start cell four observations are generated (one per heading
    index 0..3) and `get_prox_func_fn(states)` is evaluated on the whole
    batch. The four values of each cell are averaged, the averages are
    min-max rescaled into [0.1, 1] and drawn as a heatmap.

    Args:
        get_prox_func_fn: callable returning a tensor with one value per
            observation (it is reshaped with `.view(num_cells, 4)`).
        save_dir: output directory, created if it does not exist.
        iter_count: integer embedded in the output file name.
        name: file-name prefix and wandb log key.
        args: config object; reads the `gw_*` options, `device` and `no_wb`.
        obs_shape: observation shape; entries 1 and 2 size the heatmap.
        with_compl: forwarded to `get_universes` as its `gw_compl` flag.
    """
    import os

    num_headings = 4

    agent_u, goal_u = get_universes(args.gw_agent_quad, args.gw_goal_quad,
                                    args.gw_cover, with_compl)

    if args.gw_goal_pos is not None:
        fixed_goal = tuple(int(x) for x in args.gw_goal_pos.split(','))
    else:
        fixed_goal = goal_u[0]
    agent_u = [x for x in agent_u if x != fixed_goal]
    # The door positions
    agent_u = [(4, 9), (11, 9), (9, 4), (9, 14), *agent_u]

    states = []
    for agent_pos in agent_u:
        for heading in range(num_headings):
            env = get_env_for_pos(agent_pos, fixed_goal, args)
            # NOTE(review): this sets the attribute on the wrapper object
            # returned by get_env_for_pos -- confirm it reaches the env
            # that actually produces the observation.
            env.agent_dir = heading
            # Move the last observation axis to the front.
            states.append(get_grid_obs_for_env(env).transpose(2, 0, 1))

    states = torch.FloatTensor(states).to(args.device)
    proximities = get_prox_func_fn(states)
    # Row k holds the four per-heading values of agent_u[k]; average them.
    proximities = proximities.view(len(agent_u), num_headings).mean(dim=1)

    # Rescale to [0.1, 1]. When every value is identical the span is zero;
    # skip the division in that case (all cells end up at 0.1) instead of
    # producing an all-NaN heatmap.
    lowest = proximities.min()
    span = proximities.max() - lowest
    proximities = proximities - lowest
    if span > 0:
        proximities = proximities / span
    proximities = proximities * 0.9 + 0.1

    p_vals = np.zeros((obs_shape[1], obs_shape[2]))
    for i, (x, y) in enumerate(agent_u):
        # Stored as [y, x] because sns.heatmap would otherwise show the grid
        # flipped along its diagonal.
        p_vals[y, x] = proximities[i].item()

    ax = sns.heatmap(p_vals, linewidth=0.5, annot=False, xticklabels=False,
            yticklabels=False)
    ax.tick_params(left=False, bottom=False)

    # Make sure the output directory exists before writing the figure.
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    save_path = osp.join(save_dir, '%s_heat_%i_mean.png' % (name, iter_count))
    print('Saved proximity heatmap to %s' % save_path)
    plt.savefig(save_path)
    plt.clf()

    if not args.no_wb:
        wandb.log({name: [wandb.Image(Image.open(save_path))]})
        


def plot_iq_action_full(get_prox_func_fn, save_dir, iter_count,
                      name, args, obs_shape, with_compl):
    """Save an arrow map of the action chosen at every start cell.

    For each start cell an env is built, its observation is passed to
    `get_prox_func_fn`, and the returned action index (0 = right, 1 = down,
    2 = left, 3 = up) is drawn as an arrow in that cell. A rendering of the
    first env is also saved for reference.

    Args:
        get_prox_func_fn: callable taking one channel-first observation and
            returning an action index.
        save_dir: output directory, created if it does not exist.
        iter_count: integer embedded in the output file names.
        name: file-name prefix and wandb log key.
        args: config object; reads the `gw_*` options and `no_wb`.
        obs_shape: observation shape; entries 1 and 2 size the grid.
        with_compl: forwarded to `get_universes` as its `gw_compl` flag.
    """
    import os

    agent_u, goal_u = get_universes(args.gw_agent_quad, args.gw_goal_quad,
                                    args.gw_cover, with_compl)

    if args.gw_goal_pos is not None:
        fixed_goal = tuple(int(x) for x in args.gw_goal_pos.split(','))
    else:
        fixed_goal = goal_u[0]
    agent_u = [x for x in agent_u if x != fixed_goal]
    # The door positions
    agent_u = [(4, 9), (11, 9), (9, 4), (9, 14), *agent_u]

    # Both the env snapshot and the arrow plot are written into save_dir, so
    # create it up front; otherwise the first image write fails on a fresh
    # run directory.
    os.makedirs(save_dir, exist_ok=True)

    optimal_actions = []
    for idx, agent_pos in enumerate(agent_u):
        env = get_env_for_pos(agent_pos, fixed_goal, args)

        if idx == 0:
            # Render the maze once and save it as an image.
            env_obs = Image.fromarray(env.render('rgb_array'))
            env_img_path = osp.join(save_dir, '%s_env_%i.png' % (name, iter_count))
            env_obs.save(env_img_path)
            print(f"Saved env image to {env_img_path}")

        # Move the last observation axis to the front before querying.
        grid = get_grid_obs_for_env(env).transpose(2, 0, 1)
        optimal_actions.append(get_prox_func_fn(grid))

    # -1 marks cells for which no action was queried.
    opt_ac_grid = np.zeros((obs_shape[1], obs_shape[2])) - 1
    for i, (x, y) in enumerate(agent_u):
        opt_ac_grid[x, y] = optimal_actions[i]

    # Arrow direction per action in plot coordinates. The grid is drawn with
    # the y axis reversed, so "down" and "up" swap sign relative to the
    # env's own (dx, dy) convention.
    action_to_arrow = {
        0: (1, 0),    # right
        1: (0, -1),   # down
        2: (-1, 0),   # left
        3: (0, 1)     # up
    }

    # Create plot
    plt.figure(figsize=(6, 6))
    axes = plt.gca()
    rows, cols = opt_ac_grid.shape

    # Draw each grid cell and add arrows
    for i in range(rows):
        for j in range(cols):
            action = opt_ac_grid[i, j]
            if action != -1:
                # Grid entries are floats; convert for the dict lookup.
                dx, dy = action_to_arrow[int(action)]
                plt.arrow(i, cols - j - 1, dx * 0.3, dy * 0.3, head_width=0.2, head_length=0.2, fc='blue', ec='blue')
            # Draw grid cell boundaries
            axes.add_patch(plt.Rectangle((i - 0.5, cols - j - 1 - 0.5), 1, 1, fill=False, edgecolor='gray'))

    # Set plot limits and labels
    plt.xlim(-0.5, cols - 0.5)
    plt.ylim(-0.5, rows - 0.5)
    axes.set_aspect('equal')
    plt.xticks(range(cols))
    plt.yticks(range(rows))

    # Save the plot to a file
    save_path = osp.join(save_dir, '%s_ac_dir_epoch%i.png' % (name, iter_count))
    print('Saved iq action map to %s' % save_path)
    plt.savefig(save_path)
    plt.close()

    if not args.no_wb:
        wandb.log({name: [wandb.Image(Image.open(save_path))]})
        
def plot_iq_reward_heatmap(get_prox_func_fn, save_dir, iter_count,
                      name, args, obs_shape, with_compl):
    """Save a per-cell, per-action reward map drawn as colored triangles.

    For every start cell and each of the four actions a fresh env is built,
    the action is stepped once, and
    `get_prox_func_fn(obs, action, next_obs, done)` is evaluated on the
    resulting transition. Each evaluated cell is drawn as four triangles
    (right, down, left, up) colored by the value of the matching action.

    Args:
        get_prox_func_fn: callable taking batched (size 1) tensors
            `(obs, action, next_obs, done)` and returning one value.
        save_dir: output directory, created if it does not exist.
        iter_count: integer embedded in the output file name.
        name: file-name prefix and wandb log key.
        args: config object; reads the `gw_*` options, `device` and `no_wb`.
        obs_shape: observation shape; entries 1 and 2 size the grid.
        with_compl: forwarded to `get_universes` as its `gw_compl` flag.
    """
    import os
    import matplotlib.patches as patches

    num_actions = 4

    agent_u, goal_u = get_universes(args.gw_agent_quad, args.gw_goal_quad,
                                    args.gw_cover, with_compl)

    if args.gw_goal_pos is not None:
        fixed_goal = tuple(int(x) for x in args.gw_goal_pos.split(','))
    else:
        fixed_goal = goal_u[0]
    agent_u = [x for x in agent_u if x != fixed_goal]
    # The door positions
    agent_u = [(4, 9), (11, 9), (9, 4), (9, 14), *agent_u]

    proximities = []
    for agent_pos in agent_u:
        for action in range(num_actions):
            # A fresh env per (cell, action) pair so every step starts from
            # the same state. The agent heading is not set here: the
            # previous code assigned the builtin `dir` to `env.agent_dir`,
            # which was never a valid heading.
            env = get_env_for_pos(agent_pos, fixed_goal, args)
            # Move the last observation axis to the front.
            obs = get_grid_obs_for_env(env).transpose(2, 0, 1)
            next_obs, reward, done, _ = env.step(action)
            next_obs = next_obs.transpose(2, 0, 1)

            # Wrap the transition as tensors with a leading batch axis.
            obs_th = torch.FloatTensor(obs).to(args.device)[None, :]
            action_th = torch.LongTensor([action]).to(args.device)[None, :]
            next_obs_th = torch.FloatTensor(next_obs).to(args.device)[None, :]
            done_th = torch.FloatTensor([done]).to(args.device)[None, :]

            proximities.append(get_prox_func_fn(obs_th, action_th, next_obs_th, done_th))

    proximities = np.array(proximities)
    proximities = proximities.reshape(len(agent_u), num_actions)

    # Values are plotted as returned (no rescaling); cells that were not
    # evaluated stay at 0.
    p_vals = np.zeros((obs_shape[1], obs_shape[2], num_actions))
    for i, (x, y) in enumerate(agent_u):
        for j in range(num_actions):
            p_vals[x, y, j] = proximities[i][j].item()

    # Define color normalization
    cmap = plt.cm.hot
    norm = plt.Normalize(np.min(p_vals), np.max(p_vals))

    # Create plot
    plt.figure(figsize=(6, 6))
    axes = plt.gca()
    rows, cols, _ = p_vals.shape

    def add_triangle(ax, x, y, reward, position):
        """Draw one of the four triangles of cell (x, y), colored by reward.

        `position` picks the triangle: 0 = right, 1 = down, 2 = left, 3 = up
        (in plot coordinates). Every triangle has its apex at the cell
        center and its base on the matching cell edge.
        """
        if position == 0:  # Right triangle
            vertices = [(x, y), (x+0.5, y+0.5), (x+0.5, y-0.5)]
        elif position == 1:  # Down triangle
            vertices = [(x, y), (x-0.5, y-0.5), (x+0.5, y-0.5)]
        elif position == 2:  # Left triangle
            vertices = [(x, y), (x-0.5, y-0.5), (x-0.5, y+0.5)]
        elif position == 3:  # Up triangle
            vertices = [(x, y), (x-0.5, y+0.5), (x+0.5, y+0.5)]

        # Create polygon and add it to the plot
        triangle = patches.Polygon(vertices, color=cmap(norm(reward)), ec='black')
        ax.add_patch(triangle)

    # Set lookup so the per-cell membership test below is O(1).
    evaluated_cells = set(agent_u)

    # Draw colored triangles for every evaluated cell.
    for i in range(rows):
        for j in range(cols):
            if (i, j) not in evaluated_cells:
                continue
            # Coordinates transformation for plot (reversed y)
            x, y = i, rows - 1 - j
            for k in range(num_actions):
                add_triangle(axes, x, y, p_vals[i, j, k], k)

    # Set plot limits and labels
    plt.xlim(-0.5, cols - 0.5)
    plt.ylim(-0.5, rows - 0.5)
    axes.set_aspect('equal')
    plt.xticks(range(cols))
    plt.yticks(range(rows))
    # The mappable is not attached to any Axes, so tell colorbar explicitly
    # which Axes to take space from.
    plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=axes,
                 label='Reward Intensity')

    # Make sure the output directory exists before writing the figure.
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    save_path = osp.join(save_dir, f'{name}_state_action_reward_{iter_count}.png')
    print('Saved proximity heatmap to %s' % save_path)
    plt.savefig(save_path)
    # Close (not just clear) the figure created above so repeated calls do
    # not accumulate open figures.
    plt.close()

    if not args.no_wb:
        wandb.log({name: [wandb.Image(Image.open(save_path))]})
        



class GwIQPlotter(DebugViz):
    """Accumulates per-cell values and renders the IQ diagnostic plots.

    `add` folds batches into running per-cell sums and counts; `plot` writes
    the averaged heatmaps, dispatches each named function to its dedicated
    plot routine and then clears the accumulators.
    """

    def __init__(self, save_dir, args, obs_shape):
        super().__init__(save_dir, args)
        self.obs_shape = obs_shape
        self.reset()

    def reset(self):
        """Drop all accumulated sums and counts."""
        def blank_grid():
            # One zero-filled grid per plot name, sized by the spatial
            # (trailing) dimensions of the observation shape.
            return np.zeros(self.obs_shape[1:])

        self.sum_prox = defaultdict(blank_grid)
        self.count = defaultdict(blank_grid)

    def _should_process_batches(self):
        """Return True unless positions are random and no goal is pinned."""
        if not self.args.gw_rand_pos:
            return True
        return self.args.gw_goal_pos is not None

    def add(self, batches):
        """Fold each batch's 'prox' values into the per-cell accumulators.

        `batches` maps a plot name to a dict with a 'state' tensor and a
        'prox' sequence. Cells where plane 3 of a state equals 1 are the
        ones credited with the matching value.
        """
        if not self._should_process_batches():
            return

        for name, batch in batches.items():
            state = batch['state'].cpu().numpy()
            proxs = batch['prox']
            # Coordinates of every cell whose plane-3 entry is set.
            _, agent_x, agent_y = np.where(state[:, 3] == 1)

            for x, y, prox in zip(agent_x, agent_y, proxs):
                self.sum_prox[name][x, y] += prox
                self.count[name][x, y] += 1

    def _plot_accumulated(self, iter_count, plot_names):
        """Write one averaged heatmap per plot name that has any data."""
        for plot_name in plot_names:
            avg_prox = self.sum_prox[plot_name] / self.count[plot_name]
            if not np.isnan(avg_prox).all():
                # Cells that were never visited remain NaN in the average.
                sns.heatmap(avg_prox)
                save_name = osp.join(self.save_dir, '%s_%i.png' % (plot_name, iter_count))
                plt.savefig(save_name)
                print(f"Saved to {save_name}")
                plt.clf()

    def plot(self, iter_count, plot_names, plot_funcs):
        """Render all plots for this iteration, then reset the state.

        `plot_funcs` maps a name to a callable. Only the names 'irlV',
        'getAction' and 'getReward' are plotted; any other entry is skipped.
        """
        if self._should_process_batches():
            self._plot_accumulated(iter_count, plot_names)

        # Which plot routine handles which function name.
        plotters = {
            'irlV': plot_iq_heatmap_full,
            'getAction': plot_iq_action_full,
            'getReward': plot_iq_reward_heatmap,
        }
        for func_name, plot_func in plot_funcs.items():
            plotter = plotters.get(func_name)
            if plotter is None:
                continue
            plotter(plot_func, self.save_dir, iter_count, func_name,
                    self.args, self.obs_shape, False)

        self.reset()



def sample_range_no_replace(cache_vals, N, avoid_coords):
    """Draw a start cell without replacement, refilling once exhausted.

    The candidate cells are all (x, y) with 1 <= x, y <= N - 2 that are not
    in `avoid_coords`. They are computed on the first call and stored in
    `cache_vals` together with the pool of indices not yet drawn.

    Args:
        cache_vals: dict used as persistent state between calls; the keys
            'all_starts' and 'sample_idx' are created and updated here.
        N: grid size; only used on the call that populates the cache.
        avoid_coords: collection of (x, y) cells to exclude; only used on
            the call that populates the cache.

    Returns:
        The drawn (x, y) tuple.
    """
    if 'all_starts' not in cache_vals:
        cache_vals['all_starts'] = [(x, y)
                                    for x in range(1, N-1)
                                    for y in range(1, N-1) if (x, y) not in avoid_coords]
        cache_vals['sample_idx'] = list(range(len(cache_vals['all_starts'])))

    all_starts = cache_vals['all_starts']

    if not cache_vals['sample_idx']:
        # Every cell has been drawn once: start a new round.
        cache_vals['sample_idx'] = list(range(len(all_starts)))
    # Read the pool only after the refill, so the draw below never sees the
    # exhausted list (which previously made np.random.choice raise).
    sample_idx = cache_vals['sample_idx']

    idx = np.random.choice(sample_idx)
    sample_idx.remove(idx)
    return all_starts[idx]


def gw_empty_spawn(env, cache_vals, args, N=8):
    """Spawn hook for the empty grid: optionally randomize the agent start.

    Does nothing unless `args.gw_rand_pos` is set. Otherwise a start cell is
    drawn without replacement (never the cell (6, 6)) and stored on
    `env.env.agent_start_pos`.
    """
    if not args.gw_rand_pos:
        return
    env.env.agent_start_pos = sample_range_no_replace(cache_vals, N, [(6, 6)])


def gw_room_spawn(env, cache_vals, args):
    """Spawn hook for the room layouts: choose agent and goal positions.

    With `args.gw_rand_pos` unset the agent is pinned to [1, 1] and the goal
    is left alone. Otherwise positions are sampled from the universes
    returned by `get_universes` (computed once and kept in
    `cache_vals['universe']`):

    * no `args.gw_goal_pos`: the agent cell is sampled first, then a goal
      cell different from it;
    * `args.gw_goal_pos` given as 'x,y': the goal is fixed there and the
      agent is sampled from the cells different from it.

    The results are written to `env.env._agent_default_pos` and
    `env.env._goal_default_pos`.
    """
    if not args.gw_rand_pos:
        # Always start at the top left corner.
        env.env._agent_default_pos = [1, 1]
        return

    if 'universe' not in cache_vals:
        cache_vals['universe'] = get_universes(args.gw_agent_quad,
                                               args.gw_goal_quad, args.gw_cover, args.gw_compl)
    agent_u, goal_u = cache_vals['universe']
    # Convert each universe on its own: the two lists can differ in length
    # (different quadrant selections), and stacking them into a single
    # array would then be ragged.
    agent_u = np.array(agent_u)
    goal_u = np.array(goal_u)

    if args.gw_goal_pos is None:
        # Sample a position from both.
        agent_start = agent_u[env.np_random.choice(len(agent_u))]
        # Sample from anywhere except for agent position.
        goal_start = choice_except(goal_u, agent_start, env.np_random)
    else:
        goal_start = tuple(int(x) for x in args.gw_goal_pos.split(','))
        agent_start = choice_except(agent_u, goal_start, env.np_random)

    env.env._agent_default_pos = agent_start
    env.env._goal_default_pos = goal_start



def gw_room_spawn_long_hypotenuse(env, cache_vals, args):
    """No-op spawn hook: leaves the env's start configuration untouched."""
    return None


# Lookup from a raw 3-value grid cell to its 4-way one-hot encoding.
# One-hot slots: [empty, wall, goal, agent].
NODE_TO_ONE_HOT = {
    # Empty square
    (1, 0, 0): [1, 0, 0, 0],
    # Lava is encoded like an empty square so it can be stepped on too.
    (9, 0, 0): [1, 0, 0, 0],
    # Wall
    (2, 5, 0): [0, 1, 0, 0],
    # Goal
    (8, 1, 0): [0, 0, 1, 0],
    # Agent: the third value ranges over 0..3 and every variant shares the
    # same encoding.
    **{(10, 0, variant): [0, 0, 0, 1] for variant in range(4)},
}


class GoalCheckerWrapper(gym.Wrapper):
    """Gym wrapper that one-hot encodes observations and remaps actions.

    Observations: the wrapped env's `'image'` entry is converted cell by
    cell through `NODE_TO_ONE_HOT` into a 4-channel array.

    Actions: depending on the `args.gw_*` flags, the integer action is
    translated into one or more primitive actions of the wrapped env (see
    `step`).

    The wrapper also records whether the episode ended with a positive
    reward and reports it as `info['ep_found_goal']`.
    """

    def __init__(self, env, args):
        """Set up observation/action spaces and per-env spawn hooks.

        Args:
            env: env whose observation space is a dict space with an
                'image' entry.
            args: config object; reads `gw_img`, `gw_rotation_action_space`,
                `gw_larger_action_space`, `gw_diag_action_space`,
                `gw_diag_action_dir_unchanged` (and `env_name` in `reset`).
        """
        super().__init__(env)
        assert args.gw_img
        ob_s = env.observation_space.spaces['image'].shape

        # Same spatial size as the source image, with 4 one-hot channels.
        ob_shape = (ob_s[0], ob_s[1], 4)

        low = 0.0
        high = 1.0
        self.observation_space = gym.spaces.Box(shape=ob_shape,
                                                low=np.float32(low),
                                                high=np.float32(high),
                                                dtype=np.float32)

        # Transform the action space to the cardinal directions
        assert not (args.gw_larger_action_space and args.gw_rotation_action_space), 'gw_larger_action_space and gw_rotation_action_space can\'t be true at the same time'
        self.gw_rotation_action_space = args.gw_rotation_action_space
        self.gw_larger_action_space = args.gw_larger_action_space
        self.gw_diag_action_space = args.gw_diag_action_space
        self.gw_diag_action_dir_unchanged = args.gw_diag_action_dir_unchanged
        # The "larger" and "diagonal" modes expose 8 actions; every other
        # mode exposes 4.
        if args.gw_larger_action_space or self.gw_diag_action_space:
            self.action_space = gym.spaces.Discrete(8)
        else:
            self.action_space = gym.spaces.Discrete(4)
        # Scratch dict handed to the spawn hooks so they can keep state
        # between resets.
        self.cache_vals = {}
        self.args = args
        # Spawn hook per env name; `reset` calls the one matching
        # `args.env_name` before resetting the wrapped env.
        self.set_cond = {
            'MiniGrid-Empty-8x8-v0': gw_empty_spawn,
            'MiniGrid-FourRooms-v0': gw_room_spawn,
            'MiniGrid-Deceptive-v0': gw_room_spawn,
            'MiniGrid-Deceptive-v1': gw_room_spawn,
            'MiniGrid-FourRooms-Long-Hypotenuse-v0': gw_room_spawn_long_hypotenuse,
        }
        self.found_goal = False

    def _get_obs(self, obs_dict):
        """Convert an observation dict into the one-hot grid array.

        Each 3-value cell of `obs_dict['image']` is looked up in
        `NODE_TO_ONE_HOT`; a cell missing from that table raises KeyError.
        The result is reshaped to the shape reported by
        `rutils.get_obs_shape` for this wrapper's observation space, or
        flattened when `args.gw_img` is falsy.
        """
        obs = obs_dict['image']

        # Flatten to a list of 3-value cells, encode each, then restore the
        # grid layout.
        obs = obs.reshape(-1, 3)
        obs = np.array(list(map(lambda x: NODE_TO_ONE_HOT[tuple(x)], obs)))
        obs = obs.reshape(*rutils.get_obs_shape(self.observation_space))
        if self.args.gw_img:
            return obs
        else:
            return obs.reshape(-1)

    def reset(self):
        """Run the env-specific spawn hook, reset the env, return the obs.

        Raises KeyError when `args.env_name` has no entry in `set_cond`.
        """
        # Call whatever setup is specific to this version of grid world.
        self.set_cond[self.args.env_name](self.env, self.cache_vals, self.args)
        self.found_goal = False
        reset_obs = self.env.reset()
        return self._get_obs(reset_obs)

    def step(self, a):
        """Translate action `a` into primitive env steps and apply them.

        Modes, checked in this order:

        * rotation: 0 = turn left, 1 = turn right, 2 = forward, 3 = done;
          anything else raises ValueError.
        * larger: face direction `a % 4` and move forward; for `a >= 4`
          move forward a second time.
        * diagonal: for `a < 4` face direction `a` and move forward; for
          `a >= 4` face `a % 4`, move forward, turn right, move forward.
        * default: face direction `a` and move forward.

        Returns:
            `(obs, reward, done, info)` of the last primitive step, with the
            observation one-hot encoded and `info['ep_found_goal']` set to
            1.0 once an episode has ended with a positive reward.
        """
        if self.gw_rotation_action_space:
            if a == 0:
                obs_dict, reward, done, info = self.env.step(self.env.actions.left)
            elif a == 1:
                obs_dict, reward, done, info = self.env.step(self.env.actions.right)
            elif a == 2:
                obs_dict, reward, done, info = self.env.step(self.env.actions.forward)
            elif a == 3:
                obs_dict, reward, done, info = self.env.step(self.env.actions.done)
            else:
                raise ValueError('Invalid action')
        elif self.gw_larger_action_space:
            desired_dir = a % 4
            self.env.env.agent_dir = desired_dir
            obs_dict, reward, done, info = self.env.step(self.env.actions.forward)
            if a >= 4:
                # Undo the step-counter increment of the first move so the
                # double move is counted as a single step.
                # NOTE(review): the first move's reward/done are overwritten
                # by the second move's -- confirm that is intended.
                self.env.env.step_count -= 1
                obs_dict, reward, done, info = self.env.step(self.env.actions.forward)
        elif self.gw_diag_action_space:
            if a < 4:
                desired_dir = a
                self.env.env.agent_dir = desired_dir
                obs_dict, reward, done, info = self.env.step(self.env.actions.forward)
            else:
                # for action = 4, 5, 6, 7, they are diagonal actions
                # Remember the heading so it can be restored afterwards.
                final_dir = self.env.env.agent_dir
                inter_dir = a % 4
                self.env.env.agent_dir = inter_dir
                # Forward, turn right, forward. The counter is decremented
                # after each of the first two primitive steps so the whole
                # move is counted once; their results are discarded.
                _, _, _, _ = self.env.step(self.env.actions.forward)
                self.env.env.step_count -= 1
                _, _, _, _ = self.env.step(self.env.actions.right)
                self.env.env.step_count -= 1
                obs_dict, reward, done, info = self.env.step(self.env.actions.forward)
                if self.gw_diag_action_dir_unchanged:
                    self.env.env.agent_dir = final_dir
        else:
            desired_dir = a
            self.env.env.agent_dir = desired_dir
            obs_dict, reward, done, info = self.env.step(self.env.actions.forward)
            

        obs = self._get_obs(obs_dict)

        # An episode that ends with a positive reward counts as having
        # reached the goal.
        if done and reward > 0.0:
            self.found_goal = True
        info['ep_found_goal'] = float(self.found_goal)
        return obs, reward, done, info
