# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
import enum
import dm_env
from dm_env import specs
import numpy as np
import matplotlib.pyplot as plt
from url_benchmark.dmc import ExtendedTimeStep
import torch.nn as nn
import torch
import pickle


# Modulus applied to observation hashes: GridWorld timesteps carry
# `obs_hash = hash(str(observation)) % MAX_SEED`, so hashes lie in [0, MAX_SEED).
MAX_SEED = 20000

class ObservationType(enum.IntEnum):
    """Kinds of observation a GridWorld can be asked to produce."""

    # Integer index of the tile the agent occupies (row * width + column).
    STATE_INDEX = 1
    # Flattened grid with a 1 on the agent's tile and 0 everywhere else.
    AGENT_ONEHOT = 2
    # Grid with three feature channels: walls, agent position, goal position.
    GRID = 3
    # (agent_y, agent_x, goal_y, goal_x).
    AGENT_GOAL_POS = 4
    # Agent (y, x) divided element-wise by the layout shape.
    AGENT_POS = 5

class Reward:
    """Directional reward computed from a batch of observations.

    The reward reads one column of a 2-D observation batch and optionally
    negates it:

      * 'up':    -obs[:, 0]
      * 'right':  obs[:, 1]
      * 'down':   obs[:, 0]
      * 'left':  -obs[:, 1]
    """

    # direction name -> (observation column, whether the column is negated)
    _DIRECTIONS = {
        'up': (0, True),
        'right': (1, False),
        'down': (0, False),
        'left': (1, True),
    }

    def __init__(self, rew):
        """Store the direction name; it is validated when `reward` is called."""
        self.rew = rew

    def reward(self, obs):
        """Return the per-row reward for `obs`.

        Args:
          obs: 2-D indexable batch; column 0 and column 1 are read.

        Returns:
          The selected column of `obs`, negated for 'up' and 'left'.

        Raises:
          ValueError: if `self.rew` is not one of the four direction names
            (previously this case silently returned None).
        """
        try:
            column, negate = self._DIRECTIONS[self.rew]
        except (KeyError, TypeError):
            raise ValueError(
                'Unknown reward direction {!r}; expected one of {}.'.format(
                    self.rew, sorted(self._DIRECTIONS))) from None
        values = obs[:, column]
        return -values if negate else values


def build_gridworld_task(task,
                         discount=0.99,
                         penalty_for_walls=0,
                         observation_type=ObservationType.AGENT_ONEHOT,
                         max_episode_length=200):
    """Construct a particular Gridworld layout with start/goal states.

    Args:
      task: string name of the task to use. One of {'simple', 'fourroom',
        'obstacle', 'random_goal'}.
      discount: Discounting factor included in all Timesteps.
      penalty_for_walls: Reward added when hitting a wall (should be negative).
      observation_type: Enum observation type to use. One of:
        * ObservationType.STATE_INDEX: int32 index of agent occupied tile.
        * ObservationType.AGENT_ONEHOT: NxN float32 grid, with a 1 where the
          agent is and 0 elsewhere.
        * ObservationType.GRID: NxNx3 float32 grid of feature channels.
          First channel contains walls (1 if wall, 0 otherwise), second the
          agent position (1 if agent, 0 otherwise) and third goal position
          (1 if goal, 0 otherwise)
        * ObservationType.AGENT_GOAL_POS: float32 tuple with
          (agent_y, agent_x, goal_y, goal_x).
        * ObservationType.AGENT_POS: float32 (agent_y, agent_x) divided by
          the layout shape.
      max_episode_length: If set, will terminate an episode after this many
        steps.

    Returns:
      A `GridWorld` built from the selected task specification.

    Raises:
      KeyError: if `task` is not one of the known task names.
    """
    tasks_specifications = {
        'simple': {
            'layout': [
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ],
            'start_state': (2, 2),
            'randomize_goals': False,
            'pos_neg_goals': [(2, 7), (3, 7), (7, 7), (7, 4), (7, 2)],
            'pos_neg_obs': [(1, 5), (1, 5), (7, 5), (7, 2), (4, 2)],
            'rewards': ['up', 'right', 'down', 'left'],
            '_pos_neg_goals': True,
            'random_rewards': False,
            'goal_list': [(2, 7), (3, 7), (7, 2), (2, 2), (4, 2), (2, 6),
                          (1, 2), (7, 5), (3, 6)],
        },
        'fourroom': {
            'layout': [
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
                [-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, -1, 0, -1, -1, -1, -1, -1, 0, -1, -1],
                [-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ],
            'start_state': (2, 2),
            'randomize_goals': True,
            'pos_neg_goals': [(6, 4), (6, 9), (4, 9), (1, 2), (9, 8)],
            'pos_neg_obs': [(7, 2), (8, 5), (2, 5), (5, 2), (5, 8)],
            'rewards': ['up', 'right', 'down', 'left'],
            '_pos_neg_goals': True,
            'random_rewards': False,
            'goal_list': [(3, 7), (2, 6), (7, 6), (8, 8), (7, 4), (7, 2),
                          (2, 2), (4, 2), (4, 7), (1, 7)],
        },
        'obstacle': {
            'layout': [
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
                [-1, 0, 0, 0, 0, 0, -1, 0, 0, -1],
                [-1, 0, 0, 0, -1, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ],
            'start_state': (2, 2),
            'goal_state': (2, 8),
            # This task defines no pos/neg goal pairs, so the paired-goal
            # mode (on by default in GridWorld) must be switched off;
            # otherwise reset() tries to take len() of a missing list.
            '_pos_neg_goals': False,
        },
        'random_goal': {
            'layout': [
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, -1, -1, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ],
            'start_state': (2, 2),
            # No pos/neg goal pairs are defined for this task either.
            '_pos_neg_goals': False,
        },
    }
    try:
        spec = tasks_specifications[task]
    except KeyError:
        # Same exception type as the plain dict lookup, with a usable message.
        raise KeyError(
            'Unknown task {!r}; expected one of {}.'.format(
                task, sorted(tasks_specifications))) from None
    return GridWorld(
        discount=discount,
        penalty_for_walls=penalty_for_walls,
        observation_type=observation_type,
        max_episode_length=max_episode_length,
        **spec)

class GridWorld(dm_env.Environment):

    def __init__(self,
                 layout,
                 start_state,
                 goal_state=None,
                 pos_neg_goals=None,
                 pos_neg_obs=None,
                 rewards=None,
                 _pos_neg_goals=True,
                 random_rewards=False,
                 goal_list=None,
                 observation_type=ObservationType.STATE_INDEX,
                 discount=1.0,
                 penalty_for_walls=0,
                 reward_goal=1,
                 max_episode_length=None,
                 randomize_goals=False) -> None:
        """Build a grid environment.

        Simple gridworld defined by a map layout, a start and a goal state.

        Layout should be a NxN grid, containing:
          * 0: empty
          * -1: wall
          * Any other positive value: value indicates reward

        Args:
          layout: NxN array of numbers, indicating the layout of the environment.
          start_state: Tuple (y, x) of starting location.
          goal_state: Optional tuple (y, x) of goal location. If None, one is
            drawn with `_sample_goal()` at construction time.
          pos_neg_goals: Optional list of (y, x) goals; `sample_goal` picks one
            index and returns the entry together with `pos_neg_obs[index]`.
          pos_neg_obs: Optional list of (y, x) locations paired index-wise with
            `pos_neg_goals`; `reset` stores the sampled entry as `neg_goal`.
          rewards: Optional list of direction names used to build a `Reward`
            when `random_rewards` is active.
          _pos_neg_goals: If true (and `randomize_goals` is false), `reset`
            samples a (goal, neg_goal) pair every episode.
          random_rewards: If true (and both other modes are off), `reset`
            samples a `Reward` every episode.
          goal_list: Optional list of (y, x) candidates that `_sample_goal`
            draws from.
          observation_type: Enum observation type to use. One of:
            * ObservationType.STATE_INDEX: int32 index of agent occupied tile.
            * ObservationType.AGENT_ONEHOT: NxN float32 grid, with a 1 where the
              agent is and 0 elsewhere.
            * ObservationType.GRID: NxNx3 float32 grid of feature channels.
              First channel contains walls (1 if wall, 0 otherwise), second the
              agent position (1 if agent, 0 otherwise) and third goal position
              (1 if goal, 0 otherwise)
            * ObservationType.AGENT_GOAL_POS: float32 tuple with
              (agent_y, agent_x, goal_y, goal_x)
            * ObservationType.AGENT_POS: float32 (agent_y, agent_x) divided by
              the layout shape.
          discount: Discounting factor included in all Timesteps.
          penalty_for_walls: Reward added when hitting a wall (should be negative).
          reward_goal: Reward added when finding the goal (should be positive).
          max_episode_length: If set, will terminate an episode after this many
            steps.
          randomize_goals: If true, randomize goal at every episode.

        Raises:
          ValueError: if `observation_type` is not an `ObservationType` member.
        """
        # The rest of the class compares the type with `is`, so only genuine
        # enum members are usable; isinstance also avoids the TypeError that
        # `x in EnumClass` raises for non-members on older Python versions.
        if not isinstance(observation_type, ObservationType):
            raise ValueError('observation_type should be a ObservationType instace.')
        self._layout = np.array(layout)
        self._start_state = start_state
        self._state = self._start_state
        self._number_of_states = np.prod(np.shape(self._layout))
        self._discount = discount
        self._penalty_for_walls = penalty_for_walls
        self._reward_goal = reward_goal
        self._observation_type = observation_type
        self._layout_dims = self._layout.shape
        self._max_episode_length = max_episode_length
        self._num_episode_steps = 0
        self._randomize_goals = randomize_goals
        self._goal_state: tp.Tuple[int, int]
        self.goal_list = goal_list
        if goal_state is None:
            # No explicit goal supplied: draw one now.
            goal_state = self._sample_goal()
        # Assigned directly (not through the `goal_state` setter), so the
        # layout itself is not modified here.
        self._goal_state = goal_state
        self.pos_neg_goals = pos_neg_goals
        self.pos_neg_obs = pos_neg_obs
        self.rewards = rewards
        # Also stored under the underscore-prefixed name so either spelling
        # resolves on the instance.
        self._rewards = rewards
        self._pos_neg_goals = _pos_neg_goals
        self._random_rewards = random_rewards

        # Width of the network input: the first dimension of the observation
        # spec, or 1 when the spec is scalar (STATE_INDEX has an empty shape,
        # for which indexing shape[0] would raise IndexError).
        obs_shape = self.observation_spec().shape
        input_dim = obs_shape[0] if obs_shape else 1
        self.rni = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def sample_goal(self):
        """Draw the per-episode target according to the active goal mode.

        The modes are checked in this order:
          * `randomize_goals`: returns a single goal from `_sample_goal()`.
          * `_pos_neg_goals`: returns `(pos_neg_goals[i], pos_neg_obs[i])` for
            one random index `i`.
          * `random_rewards`: returns a `Reward` built from one random entry
            of the `rewards` list given to the constructor.

        Returns:
          The value described above, or None when no mode is enabled.
        """
        if self._randomize_goals:
            return self._sample_goal()
        if self._pos_neg_goals:
            idx = np.random.randint(len(self.pos_neg_goals))
            return self.pos_neg_goals[idx], self.pos_neg_obs[idx]
        if self._random_rewards:
            # The constructor stores the direction names as `self.rewards`;
            # the former `self._rewards` lookup raised AttributeError.
            names = self.rewards
            return Reward(names[np.random.randint(len(names))])
        return None

    def _sample_goal(self):
        """Randomly sample reachable non-starting state."""
        # Sample a new goal
        # n = 0
        # max_tries = 1e5
        # while n < max_tries:
        #     goal_state = tuple(np.random.randint(d) for d in self._layout_dims)
        #     if goal_state != self._state and self._layout[goal_state] == 0:
        #         # Reachable state found!
        #         return goal_state
        #     n += 1
        rand_idx = np.random.randint(len(self.goal_list))
        return self.goal_list[rand_idx]
        # raise ValueError('Failed to sample a goal state.')

    @property
    def number_of_states(self):
        """Total number of cells in the layout (product of its dimensions)."""
        return self._number_of_states

    @property
    def goal_state(self):
        return self._goal_state

    @goal_state.setter
    def goal_state(self, new_goal):
        if self._layout[new_goal] < 0:
            raise ValueError('This is not a valid goal!')
        # Zero out any other goal
        self._layout[self._layout > 0] = 0
        # Setup new goal location
        self._layout[new_goal] = self._reward_goal
        self._goal_state = new_goal

    def set_state(self, x, y):
        """Place the agent at column `x`, row `y`.

        Note the argument order is (x, y) whereas the state is stored as (y, x).
        """
        self._state = (y, x)

    def observation_spec(self):
        """Return the dm_env spec matching the configured observation type.

        Returns None if the configured type matches none of the known members.
        """
        kind = self._observation_type
        if kind is ObservationType.AGENT_ONEHOT:
            # One entry per grid cell.
            return specs.Array(
                shape=(self._number_of_states, ),
                dtype=np.float32,
                name='observation_agent_onehot')
        if kind is ObservationType.GRID:
            # Layout dimensions plus three feature channels.
            grid_shape = self._layout_dims + (3,)
            return specs.Array(
                shape=grid_shape,
                dtype=np.float32,
                name='observation_grid')
        if kind is ObservationType.AGENT_POS:
            return specs.Array(
                shape=(2,),
                dtype=np.float32,
                name='observation_agent_pos')
        if kind is ObservationType.AGENT_GOAL_POS:
            return specs.Array(
                shape=(4,),
                dtype=np.float32,
                name='observation_agent_goal_pos')
        if kind is ObservationType.STATE_INDEX:
            return specs.DiscreteArray(
                self._number_of_states,
                dtype=int,
                name='observation_state_index')
        return None

    def action_spec(self):
        """Return the spec of the discrete action space (five actions).

        `step` interprets the codes as 0 up, 1 right, 2 down, 3 left, 4 stay.
        """
        return specs.DiscreteArray(5, dtype=int, name='action')

    def get_state(self):
        """Return the agent's current location as stored, i.e. (y, x)."""
        return self._state

    def get_goal_obs(self):
        """Encode the current goal location as an observation.

        Only the AGENT_ONEHOT, AGENT_POS and STATE_INDEX observation types are
        handled; for any other type the method returns None.
        """
        kind = self._observation_type
        grid_shape = self._layout.shape
        if kind is ObservationType.AGENT_ONEHOT:
            # One-hot grid with the goal cell set, flattened.
            onehot = np.zeros(grid_shape, dtype=np.float32)
            onehot[self._goal_state] = 1
            return onehot.flatten()
        if kind is ObservationType.AGENT_POS:
            # (y, x) scaled by the layout dimensions.
            position = np.array(self._goal_state, dtype=np.float32)
            extent = np.array(grid_shape, dtype=np.float32)
            return position / extent
        if kind is ObservationType.STATE_INDEX:
            row, col = self._goal_state
            return row * grid_shape[1] + col
        return None

    def get_neg_goal_obs(self):
        """Encode the current negative goal (`self.neg_goal`) as an observation.

        `neg_goal` is only assigned elsewhere in the class (for example by
        `reset` in pos/neg-goal mode), so it must have been set before calling.
        Only the AGENT_ONEHOT, AGENT_POS and STATE_INDEX observation types are
        handled; for any other type the method returns None.
        """
        kind = self._observation_type
        grid_shape = self._layout.shape
        if kind is ObservationType.AGENT_ONEHOT:
            # One-hot grid with the negative-goal cell set, flattened.
            onehot = np.zeros(grid_shape, dtype=np.float32)
            onehot[self.neg_goal] = 1
            return onehot.flatten()
        if kind is ObservationType.AGENT_POS:
            # (y, x) scaled by the layout dimensions.
            position = np.array(self.neg_goal, dtype=np.float32)
            extent = np.array(grid_shape, dtype=np.float32)
            return position / extent
        if kind is ObservationType.STATE_INDEX:
            row, col = self.neg_goal
            return row * grid_shape[1] + col
        return None

    def get_obs(self):
        """Encode the agent's current state using the configured observation type.

        Returns None if the configured type matches none of the known members.
        """
        kind = self._observation_type
        grid_shape = self._layout.shape
        if kind is ObservationType.AGENT_ONEHOT:
            # One-hot grid with the agent's cell set, flattened.
            onehot = np.zeros(grid_shape, dtype=np.float32)
            onehot[self._state] = 1
            return onehot.flatten()
        if kind is ObservationType.GRID:
            agent_row, agent_col = self._state[0], self._state[1]
            goal_row, goal_col = self._goal_state[0], self._goal_state[1]
            channels = np.zeros(grid_shape + (3,), dtype=np.float32)
            channels[..., 0] = self._layout < 0  # walls
            channels[agent_row, agent_col, 1] = 1  # agent
            channels[goal_row, goal_col, 2] = 1  # goal
            return channels
        if kind is ObservationType.AGENT_POS:
            # (y, x) scaled by the layout dimensions.
            position = np.array(self._state, dtype=np.float32)
            extent = np.array(grid_shape, dtype=np.float32)
            return position / extent
        if kind is ObservationType.AGENT_GOAL_POS:
            # Tuple concatenation: (agent_y, agent_x, goal_y, goal_x), unscaled.
            return np.array(self._state + self._goal_state, dtype=np.float32)
        if kind is ObservationType.STATE_INDEX:
            row, col = self._state
            return row * grid_shape[1] + col
        return None

    def reset(self):
        """Start a new episode and return its first timestep.

        The agent goes back to the start state and the step counter is cleared.
        Then, depending on the first enabled mode:
          * `randomize_goals`: a new goal is sampled and installed through the
            `goal_state` setter.
          * `_pos_neg_goals`: a (goal, neg_goal) pair is sampled; the goal is
            installed through the setter and `neg_goal` is stored.
          * `random_rewards`: a sampled `Reward` is stored as `reward_func`.

        NOTE(review): `obs_hash` is built from Python's built-in string hash,
        which is salted per process unless PYTHONHASHSEED is fixed, so the
        values are presumably not comparable across runs.
        """
        self._state = self._start_state
        self._num_episode_steps = 0

        if self._randomize_goals:
            # The setter already stores the new goal in `_goal_state`.
            self.goal_state = self.sample_goal()
        elif self._pos_neg_goals:
            self.goal_state, self.neg_goal = self.sample_goal()
        elif self._random_rewards:
            self.reward_func = self.sample_goal()

        first_obs = self.get_obs()
        return ExtendedTimeStep(
            step_type=dm_env.StepType.FIRST,
            action=0,
            reward=0.0,
            discount=1,
            observation=first_obs,
            obs_hash=hash(str(first_obs)) % MAX_SEED)
    
    def reset_at_state(self, state):
        """Start a new episode with the agent placed at `state`.

        `state` is forwarded to `set_state(state[0], state[1])`, i.e. it is
        read as (x, y). The goal is left untouched and the step counter is
        cleared.
        """
        x, y = state[0], state[1]
        self.set_state(x, y)
        self._num_episode_steps = 0
        first_obs = self.get_obs()
        return ExtendedTimeStep(
            step_type=dm_env.StepType.FIRST,
            action=0,
            reward=0.0,
            discount=1,
            observation=first_obs,
            obs_hash=hash(str(first_obs)) % MAX_SEED)

    def step(self, action):
        """Apply `action` and return the resulting timestep.

        Actions: 0 up, 1 right, 2 down, 3 left, 4 stay. Moving into a wall
        (layout value -1) leaves the agent in place and yields
        `penalty_for_walls`; an empty cell (0) yields 0; any other cell yields
        the value stored in the layout. Reaching a goal does not end the
        episode; the timestep is LAST only once `max_episode_length` steps
        have been taken.

        Raises:
          ValueError: if `action` is none of the five codes.
        """
        row, col = self._state
        # (d_row, d_col) per action code, in order: up, right, down, left, stay.
        offsets = ((-1, 0), (0, 1), (1, 0), (0, -1), (0, 0))
        for code, (d_row, d_col) in enumerate(offsets):
            if action == code:
                target = (row + d_row, col + d_col)
                break
        else:
            raise ValueError(
                'Invalid action: {} is not 0, 1, 2, 3, or 4.'.format(action))

        cell = self._layout[target[0], target[1]]
        discount = self._discount
        if cell == -1:
            # Bumped into a wall: stay put and pay the penalty.
            reward = self._penalty_for_walls
            target = (row, col)
        elif cell == 0:
            reward = 0.
        else:
            # Goal cell: collect its value, the episode carries on.
            reward = cell

        self._state = target
        self._num_episode_steps += 1

        limit = self._max_episode_length
        out_of_time = limit is not None and self._num_episode_steps >= limit
        step_type = dm_env.StepType.LAST if out_of_time else dm_env.StepType.MID

        next_obs = self.get_obs()
        return ExtendedTimeStep(
            step_type=step_type,
            action=action,
            reward=np.float32(reward),
            discount=discount,
            observation=next_obs,
            obs_hash=hash(str(next_obs)) % MAX_SEED)

    def get_state_list(self):
        # print('Storing states')
        state_list = []
        for y in range(self._layout.shape[0]):
            for x in range(self._layout.shape[1]):
                if self._layout[y, x] >= 0:
                    state_list.append((x, y))
                # print(x, y, self._layout[y, x])
        # print(state_list)
        # print(len(state_list))
        return state_list
    
    def get_action_list(self):
        return [0, 1, 2, 3, 4]
    
    def get_single_transition(self, action):
        """Apply `action` once and return a timestep that is always LAST.

        Movement and reward rules are the same as in `step` (0 up, 1 right,
        2 down, 3 left, 4 stay; walls block and yield `penalty_for_walls`),
        and the agent's state is updated, but the episode step counter is not
        advanced.

        Raises:
          ValueError: if `action` is none of the five codes.
        """
        row, col = self._state
        # (d_row, d_col) per action code, in order: up, right, down, left, stay.
        offsets = ((-1, 0), (0, 1), (1, 0), (0, -1), (0, 0))
        for code, (d_row, d_col) in enumerate(offsets):
            if action == code:
                target = (row + d_row, col + d_col)
                break
        else:
            raise ValueError(
                'Invalid action: {} is not 0, 1, 2, 3, or 4.'.format(action))

        cell = self._layout[target[0], target[1]]
        discount = self._discount
        if cell == -1:
            # Bumped into a wall: stay put and pay the penalty.
            reward = self._penalty_for_walls
            target = (row, col)
        elif cell == 0:
            reward = 0.
        else:
            # Goal cell: collect the value stored in the layout.
            reward = cell

        self._state = target

        next_obs = self.get_obs()
        return ExtendedTimeStep(
            step_type=dm_env.StepType.LAST,
            action=action,
            reward=np.float32(reward),
            discount=discount,
            observation=next_obs,
            obs_hash=hash(str(next_obs)) % MAX_SEED)
    
    def get_single_transition_state(self, state, action):
        # print(self._layout.shape)
        y, x = state
        if action == 0:  # up
          new_state = (y - 1, x)
        elif action == 1:  # right
          new_state = (y, x + 1)
        elif action == 2:  # down
          new_state = (y + 1, x)
        elif action == 3:  # left
          new_state = (y, x - 1)
        elif action == 4: # stay
          new_state = (y, x)
        else:
          raise ValueError(
              'Invalid action: {} is not 0, 1, 2, 3, or 4.'.format(action))

        new_y, new_x = new_state
        if self._layout[new_y, new_x] == -1:  # wall
            new_state = (y, x)
        return new_state

    def get_obs_from_state(self, state, obs_type=None):
        x,y = state
        state = y,x
        if obs_type is None:
            obs_type = self._observation_type
        if obs_type is ObservationType.AGENT_ONEHOT:
            obs = np.zeros(self._layout.shape, dtype=np.float32)
            # Place agent
            obs[state] = 1
            return obs.flatten()
        elif obs_type is ObservationType.GRID:
            obs = np.zeros(self._layout.shape + (3,), dtype=np.float32)
            obs[..., 0] = self._layout < 0
            obs[state[0], state[1], 1] = 1
            obs[self._goal_state[0], self._goal_state[1], 2] = 1
            return obs
        elif obs_type is ObservationType.AGENT_POS:
            return np.array(state, dtype=np.float32) / np.array(self._layout.shape, dtype=np.float32)
        elif obs_type is ObservationType.AGENT_GOAL_POS:
            return np.array(state + self._goal_state, dtype=np.float32)
        elif obs_type is ObservationType.STATE_INDEX:
            y, x = state
            return y * self._layout.shape[1] + x

        return obs
    
    def get_obs_from_state_xy(self, state, obs_type=None):
        """Encode a location given in internal (y, x) order as an observation.

        Despite the name, `state` is used directly as a (row, column) index.

        Args:
          state: (y, x) tuple.
          obs_type: Optional `ObservationType`; defaults to the environment's
            configured observation type.

        Returns:
          The observation for that location in the requested encoding.

        Raises:
          ValueError: if `obs_type` is not a known `ObservationType` member
            (previously this fell through to an unbound local variable).
        """
        if obs_type is None:
            obs_type = self._observation_type
        grid_shape = self._layout.shape
        if obs_type is ObservationType.AGENT_ONEHOT:
            # One-hot grid with the given cell set, flattened.
            onehot = np.zeros(grid_shape, dtype=np.float32)
            onehot[state] = 1
            return onehot.flatten()
        if obs_type is ObservationType.GRID:
            channels = np.zeros(grid_shape + (3,), dtype=np.float32)
            channels[..., 0] = self._layout < 0  # walls
            channels[state[0], state[1], 1] = 1  # agent
            channels[self._goal_state[0], self._goal_state[1], 2] = 1  # goal
            return channels
        if obs_type is ObservationType.AGENT_POS:
            # (y, x) scaled by the layout dimensions.
            return np.array(state, dtype=np.float32) / np.array(grid_shape, dtype=np.float32)
        if obs_type is ObservationType.AGENT_GOAL_POS:
            # Tuple concatenation: (y, x, goal_y, goal_x), unscaled.
            return np.array(state + self._goal_state, dtype=np.float32)
        if obs_type is ObservationType.STATE_INDEX:
            row, col = state
            return row * grid_shape[1] + col
        raise ValueError('Unsupported observation type: {!r}'.format(obs_type))
    
    def get_state_from_obs(self, obs):
        """Recover the agent's (y, x) location from an observation.

        The decoding follows the environment's configured observation type and
        inverts the encoding produced by `get_obs`.

        Args:
          obs: An observation in the configured encoding. For AGENT_POS it may
            be a NumPy array or an object offering `.long()` (e.g. a tensor).

        Returns:
          The location as a 2-element tuple or array, depending on the type.

        Raises:
          ValueError: if the configured observation type is not a known
            `ObservationType` member (previously an unbound local variable).
        """
        kind = self._observation_type
        grid_shape = self._layout.shape
        if kind is ObservationType.AGENT_ONEHOT:
            return np.unravel_index(np.argmax(obs), grid_shape)
        if kind is ObservationType.GRID:
            # Channel 1 carries the agent position.
            return np.unravel_index(np.argmax(obs[..., 1]), grid_shape)
        if kind is ObservationType.AGENT_POS:
            # Undo the division by the layout shape; the small epsilon keeps
            # float round-off from truncating to the cell below.
            scaled = obs * np.array(grid_shape) + 0.0001
            if isinstance(scaled, np.ndarray):
                # `int` is what the removed `np.compat.long` alias stood for.
                return scaled.astype(int)
            return scaled.long()
        if kind is ObservationType.AGENT_GOAL_POS:
            # `get_obs` emits raw (agent_y, agent_x, goal_y, goal_x) for this
            # type, so the agent position is simply the first two entries;
            # multiplying by the layout shape would not invert that encoding.
            return np.array(obs[:2], dtype=np.int32)
        if kind is ObservationType.STATE_INDEX:
            width = grid_shape[1]
            return (obs // width, obs % width)
        raise ValueError('Unsupported observation type: {!r}'.format(kind))

    def plot_grid(self, add_start=False):
        """Draw the layout on a new 4x4 inch matplotlib figure.

        Walls (layout values <= -1) are shaded grey and thin grid lines are
        drawn between cells. If the instance has `pos_goals` / `neg_goals`
        attributes, each (y, x) entry is labelled with a green "P" or a red
        "N"; missing or None attributes are skipped.

        Args:
          add_start: If true, mark the start state with an "S".
        """
        asbestos = (127 / 255, 140 / 255, 141 / 255, 0.8)
        grid_kwargs = {'color': (220 / 255, 220 / 255, 220 / 255, 0.5)}
        plt.figure(figsize=(4, 4))

        # White RGBA canvas with the wall cells painted grey.
        img = np.ones((self._layout.shape[0], self._layout.shape[1], 4))
        img[self._layout <= -1] = np.array(asbestos)

        plt.imshow(img, interpolation=None)
        ax = plt.gca()
        ax.grid(0)
        plt.xticks([])
        plt.yticks([])

        if add_start:
            plt.text(
                self._start_state[1],
                self._start_state[0],
                r'$\mathbf{S}$',
                fontsize=16,
                ha='center',
                va='center')

        # These attributes are not set by the constructor, so look them up
        # defensively instead of hiding every possible error behind a bare
        # `except`.
        markers = (
            (getattr(self, 'pos_goals', None), r'$\mathbf{P}$', 'green'),
            (getattr(self, 'neg_goals', None), r'$\mathbf{N}$', 'red'),
        )
        for cells, label, colour in markers:
            if cells is None:
                continue
            for cell in cells:
                plt.text(
                    cell[1],
                    cell[0],
                    label,
                    fontsize=16,
                    ha='center',
                    va='center',
                    color=colour)

        # Separator lines between neighbouring rows and columns.
        h, w = self._layout.shape
        for y in range(h - 1):
            plt.plot([-0.5, w - 0.5], [y + 0.5, y + 0.5], **grid_kwargs)
        for x in range(w - 1):
            plt.plot([x + 0.5, x + 0.5], [-0.5, h - 0.5], **grid_kwargs)

    def render(self, return_rgb=True):
        """Draw the grid with the agent and optionally return it as an image.

        Args:
          return_rgb: If true, rasterise the figure, close it and return the
            pixels. If false, the figure is left open and nothing is returned.

        Returns:
          A (height, width, 3) uint8 array when `return_rgb` is true,
          otherwise None.
        """
        carrot = (235 / 255, 137 / 255, 33 / 255, 0.8)
        self.plot_grid(add_start=False)
        # Add the agent location
        plt.text(
            self._state[1],
            self._state[0],
            u'😃',
            fontname='symbola',
            fontsize=18,
            ha='center',
            va='center',
            color=carrot)
        if return_rgb:
            fig = plt.gcf()
            plt.axis('tight')
            plt.subplots_adjust(0, 0, 1, 1, 0, 0)
            fig.canvas.draw()
            # Read the rendered RGBA buffer and drop the alpha channel. This
            # replaces `np.fromstring(canvas.tostring_rgb())`, both halves of
            # which are deprecated or removed in current NumPy/Matplotlib.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            data = rgba[..., :3].copy()
            plt.close(fig)
            return data
        return None

    def plot_policy(self, policy):
        """Draw one action glyph per cell for a tabular policy.

        Args:
          policy: Array indexed as `policy[y, x]` holding an action code
            (0 up, 1 right, 2 down, 3 left, 4 stay) for every cell.

        The goal cell is left blank. The figure is created by `plot_grid` and
        is neither saved nor closed here.
        """
        # Five glyphs, one per action code; the last one (a dot) stands for
        # "stay", which previously had no glyph and raised IndexError.
        action_names = [
            r'$\uparrow$', r'$\rightarrow$', r'$\downarrow$', r'$\leftarrow$',
            r'$\cdot$'
        ]
        self.plot_grid()
        plt.title('Policy Visualization')
        h, w = self._layout.shape
        for y in range(h):
            for x in range(w):
                if (y, x) != self._goal_state:
                    action_name = action_names[policy[y, x]]
                    plt.text(x, y, action_name, ha='center', va='center')

    def plot_greedy_policy(self, q):
        greedy_actions = np.argmax(q, axis=2)
        self.plot_policy(greedy_actions)

    def plot_policy_from_list(self, work_dir, obs_list, act_list, diversity=None, title=''):
        """Plot one action glyph per observation and save the figure as a PNG.

        Args:
          work_dir: Directory the image is written to.
          obs_list: Observations; each is decoded with `get_state_from_obs`.
          act_list: Action codes, indexed in step with `obs_list`.
          diversity: Accepted but not used by this method.
          title: File name (without extension) of the saved image.

        NOTE(review): the figure opened by `plot_grid` is not closed here, so
        callers invoking this repeatedly may want to close it themselves.
        """
        print('Plotting policy')
        glyphs = [
            r'$\uparrow$', r'$\rightarrow$', r'$\downarrow$', r'$\leftarrow$', r'$\cdot$'
        ]
        self.plot_grid()
        text_style = dict(ha='center', va='center', fontsize='large', color='green')
        for idx, obs in enumerate(obs_list):
            row, col = self.get_state_from_obs(obs)
            plt.text(col, row, glyphs[act_list[idx]], **text_style)

        target = ''.join([str(work_dir), '/', title, '.png'])
        plt.savefig(target, bbox_inches='tight')


    def plot_v_function(self, work_dir, obs_list, v_list, a_list, title=''):
        """Plot a value heat-map with action glyphs and save it as a PNG.

        Args:
          work_dir: Directory the image is written to.
          obs_list: Observations; each is decoded with `get_state_from_obs`.
          v_list: Values indexed in step with `obs_list`; must offer
            `.cpu().numpy()` (e.g. a torch tensor).
          a_list: Action codes indexed in step with `obs_list`.
          title: File name (without extension) of the saved image.

        Cells not covered by `obs_list` are filled with the minimum value
        minus one.

        NOTE(review): the figure opened by `plot_grid` is not closed here.
        """
        print('Plotting V function')
        glyphs = [
            r'$\uparrow$', r'$\rightarrow$', r'$\downarrow$', r'$\leftarrow$', r'$\cdot$'
        ]
        self.plot_grid()
        n_rows, n_cols = self._layout.shape
        # Background level: strictly below every plotted value.
        background = np.min(v_list.cpu().numpy()) - 1
        v_map = np.ones((n_rows, n_cols)) * background

        text_style = dict(ha='center', va='center', fontsize='large', color='green')
        for idx, obs in enumerate(obs_list):
            row, col = self.get_state_from_obs(obs)
            plt.text(col, row, glyphs[a_list[idx]], **text_style)
            v_map[row, col] = v_list[idx]

        plt.imshow(v_map, cmap='magma', interpolation='nearest')
        target = ''.join([str(work_dir), '/', title, '.png'])
        plt.savefig(target, bbox_inches='tight')

    def plot_bf_function(self, work_dir, obs_list, v_list, a_list, title=''):
        """Plot a value heat-map with (possibly several) action glyphs per cell.

        Args:
          work_dir: Directory the image is written to.
          obs_list: Locations whose first entry is the column and second entry
            the row of a cell.
          v_list: 2-D values indexed as `v_list[row, col]`.
          a_list: Mapping from `(row, col)` to an iterable of action codes.
          title: File name (without extension) of the saved image.

        Cells not covered by `obs_list` are filled with -30.

        NOTE(review): the figure opened by `plot_grid` is not closed here.
        """
        print('Plotting V function')
        action_names = [
            r'$\uparrow$', r'$\rightarrow$', r'$\downarrow$', r'$\leftarrow$', r'$\cdot$'
        ]
        self.plot_grid()
        h, w = self._layout.shape
        v_map = np.ones((h, w)) * (-30)
        for obs in obs_list:
            # obs[0] is the horizontal (column) coordinate, obs[1] the row.
            col, row = obs[0], obs[1]
            for code in a_list[(row, col)]:
                plt.text(col, row, action_names[code],
                         ha='center', va='center', fontsize='large', color='green')
            v_map[row, col] = v_list[row, col]

        plt.imshow(v_map, cmap='magma', interpolation='nearest')
        plt.savefig(str(work_dir) + '/' + title + '.png', bbox_inches='tight')

    def setup_eval(self, eval_type, i=None):
        """Set up an evaluation task and build its reward structure.

        Args:
          eval_type: Which kind of task to build. One of:
            * 'goal': a single positive state, ``self.goal_list[i]``.
            * 'pos_neg_goal': one positive and one negative state, taken from
              ``self.pos_neg_goals[i]`` and ``self.pos_neg_obs[i]``.
            * 'rni2' / 'rni3': positive and negative regions obtained by
              thresholding a freshly initialised, frozen random network over
              every state returned by ``self.get_state_list()``.
          i: Index of the goal to use for 'goal' and 'pos_neg_goal'. When
            ``None`` an index is drawn uniformly at random. Ignored by the
            'rni*' types.

        Returns:
          A 6-tuple ``(pos_states, pos_goals_set, neg_states, neg_goals_set,
          reward_array, reward_func)``:
            * pos_states / neg_states: sets of states rewarded +1 / -1.
            * pos_goals_set / neg_goals_set: lists holding
              ``self.get_obs_from_state_xy(s)`` for each of those states.
            * reward_array: array shaped like ``self._layout`` holding 1.0 at
              positive states, -1.0 at negative states and 0.0 elsewhere.
            * reward_func: callable mapping a single state to a float reward,
              or a tensor of observations to a float32 tensor of rewards.

        Raises:
          ValueError: If ``eval_type`` is unknown, or the goal list required by
            the chosen type is missing or empty.
        """
        import random

        if eval_type == 'goal':
            goal_list = self.goal_list
            if goal_list is None or len(goal_list) == 0:
                raise ValueError("env must have a non-empty goal_list for 'goal' eval_type")
            if i is None:
                # No explicit index requested: sample a random goal.
                i = random.randrange(len(goal_list))
            goal_state = goal_list[i]
            self.goal_state = goal_state
            pos_states = {tuple(goal_state)}
            print('Goal state:', goal_state)
            neg_states = set()
        elif eval_type == 'pos_neg_goal':
            pos_neg_goals = self.pos_neg_goals
            if pos_neg_goals is None or len(pos_neg_goals) == 0:
                raise ValueError("env must have a non-empty pos_neg_goals for 'pos_neg_goal' eval_type")
            if i is None:
                i = random.randrange(len(pos_neg_goals))
            # NOTE(review): the negative goal is read from `pos_neg_obs`, not
            # `pos_neg_goals` -- confirm this pairing is intended.
            self.pos_goal, self.neg_goal = self.pos_neg_goals[i], self.pos_neg_obs[i]
            pos_states = {tuple(self.pos_goal)}
            neg_states = {tuple(self.neg_goal)}
        elif eval_type in ('rni2', 'rni3'):
            # Random, frozen scoring network f(s); only the offsets b1/b2 are fit.
            self.rni = nn.Sequential(
                nn.Linear(2, 128),
                nn.ReLU(),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 1)
            )
            for param in self.rni.parameters():
                param.requires_grad = False
            self.b1 = nn.Parameter(torch.tensor(0.0))
            self.b2 = nn.Parameter(torch.tensor(0.0))

            states = self.get_state_list()  # a plain list, not a tensor
            states_tensor = [self.get_obs_from_state(s, ObservationType.AGENT_POS) for s in states]
            # The two task families differ only in the level used for b2 and
            # in how the negative region is cut out below.
            b2_level = 0.85 if eval_type == 'rni2' else 0.4
            self.train_b(self.rni, self.b1, states_tensor, p=0.2)
            self.train_b(self.rni, self.b2, states_tensor, p=b2_level)

            with torch.no_grad():
                obs_batch = torch.as_tensor(np.asarray(states_tensor), dtype=torch.float32)
                # reshape(-1) keeps a 1-D score vector even for a single state.
                f_s = self.rni(obs_batch).reshape(-1)
                # Positive goals: f(s) + b1 >= 0 (both families).
                pos_mask = (f_s + self.b1 >= 0).tolist()
                if eval_type == 'rni2':
                    # Negative goals: f(s) + b2 <= 0.
                    neg_mask = (f_s + self.b2 <= 0).tolist()
                else:
                    # Negative goals: the band f(s) + b2 >= 0 and f(s) + b1 <= 0.
                    neg_mask = ((f_s + self.b2 >= 0) & (f_s + self.b1 <= 0)).tolist()
            pos_states = {s for s, hit in zip(states, pos_mask) if hit}
            neg_states = {s for s, hit in zip(states, neg_mask) if hit}
        else:
            raise ValueError(f"Unknown eval_type: {eval_type}")

        self.reward_array = np.zeros(self._layout.shape)
        for s in pos_states:
            self.reward_array[s[0], s[1]] = 1.0
        for s in neg_states:
            self.reward_array[s[0], s[1]] = -1.0

        def _state_reward(state):
            """Reward of one state: +1 if positive, -1 if negative, else 0."""
            key = tuple(state)
            if key in pos_states:
                return 1.0
            if key in neg_states:
                return -1.0
            return 0.0

        def reward_func(states):
            """Reward for one state, or a float32 tensor for a batch of observations."""
            if isinstance(states, torch.Tensor):
                # detach()/cpu() make this work for GPU or grad-tracking inputs.
                rewards = [
                    _state_reward(self.get_state_from_obs(obs.detach().cpu().numpy()))
                    for obs in states
                ]
                return torch.tensor(rewards, dtype=torch.float32)
            return _state_reward(states)

        self.pos_goals = pos_states
        self.neg_goals = neg_states
        pos_goals_set = [self.get_obs_from_state_xy(s) for s in pos_states]
        neg_goals_set = [self.get_obs_from_state_xy(s) for s in neg_states]
        return pos_states, pos_goals_set, neg_states, neg_goals_set, self.reward_array, reward_func
    
    def train_b(self, rni, b, states, p):
        """Fit the scalar offset ``b`` with a pinball (quantile) loss.

        Minimises ``mean(max(p * u, (p - 1) * u))`` with ``u = -rni(s) - b``
        over the given states, using 1000 Adam steps at learning rate 0.01.
        The minimiser of that loss is the ``p``-quantile of ``-rni(states)``.

        Args:
          rni: Scoring network mapping a batch of observations to one score
            each. It is evaluated once, without gradient tracking, so it is
            assumed to be deterministic (the callers in this file pass a
            frozen ``nn.Sequential``).
          b: Trainable scalar (``nn.Parameter``) updated in place.
          states: Observations, either as a nested sequence / array or as a
            tensor.
          p: Quantile level of the pinball loss.

        Returns:
          The same ``b`` object after optimisation.
        """
        def pinball(u, tau):
            return torch.maximum(tau * u, (tau - 1) * u)

        # Convert the inputs once. Re-wrapping an existing tensor with
        # torch.tensor() on every step copies it and triggers PyTorch's
        # copy-construct UserWarning.
        if isinstance(states, torch.Tensor):
            inputs = states.detach().to(torch.float32)
        else:
            inputs = torch.as_tensor(np.asarray(states), dtype=torch.float32)

        # The network output never changes during the fit, so the target
        # y = -f(s) is computed once outside the optimisation loop.
        with torch.no_grad():
            y = -rni(inputs).reshape(-1)  # Shape: (N,)

        optimizer = torch.optim.Adam([b], lr=0.01)
        optimizer.zero_grad()
        for _ in range(1000):  # Number of optimization steps
            loss = pinball(y - b, p).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        return b
    

    def plot_eval_plot(self, work_dir, title=''):
        """Draw the grid with ``self.plot_grid()`` and save it as a PNG.

        The image is written to ``<work_dir>/<title>.png``; the target
        directory is created first if it does not exist yet.

        Only the grid itself is drawn. Per-cell markers for positive and
        negative reward cells of ``self.reward_array`` are currently not
        rendered.

        Args:
          work_dir: Directory that receives the image (converted with ``str``).
          title: File name without the ``.png`` extension.
        """
        import os

        print('Plotting eval plot')
        self.plot_grid()

        out_path = str(work_dir) + '/' + title + '.png'
        # savefig does not create missing directories, so make sure the
        # destination exists instead of failing with FileNotFoundError.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path, bbox_inches='tight')

    def create_evaluation_set(self, path, num_tasks=10, plot_dir='eval_plots2'):
        """Generate a fixed set of evaluation tasks, plot them and pickle them.

        For each of the task families 'goal', 'rni2' and 'rni3' this builds
        ``num_tasks`` tasks with ``self.setup_eval`` and saves one plot per
        task via ``self.plot_eval_plot``. 'goal' tasks use the loop index as
        goal index, so ``self.goal_list`` must hold at least ``num_tasks``
        entries. 'pos_neg_goal' tasks are not generated.

        The pickled object is a dict mapping each family name to a list of
        dicts with keys 'pos_states', 'neg_states' and 'reward_array'. The
        reward function is not stored.

        Args:
          path: Destination of the pickle file.
          num_tasks: Number of tasks to create per family.
          plot_dir: Directory receiving the per-task plots; created if missing.
        """
        import os

        # The plots are written into this directory, so it has to exist.
        os.makedirs(str(plot_dir), exist_ok=True)

        eval_set = {
            'goal': [],
            'rni2': [],
            'rni3': []
        }
        for eval_type, tasks in eval_set.items():
            for i in range(num_tasks):
                if eval_type == 'goal':
                    # Goal tasks are selected by index into the goal list.
                    result = self.setup_eval(eval_type, i)
                else:
                    # The rni families draw a fresh random network per call.
                    result = self.setup_eval(eval_type)
                pos_states, _, neg_states, _, reward_array, _ = result
                tasks.append({
                    'pos_states': pos_states,
                    'neg_states': neg_states,
                    'reward_array': reward_array,
                })
                self.plot_eval_plot(work_dir=plot_dir, title=f'eval_{eval_type}_{i}')

        # Save the evaluation set to the given path
        with open(path, 'wb') as f:
            pickle.dump(eval_set, f)

    def load_evaluation_set(self, path):
        """Read a pickled evaluation set from ``path`` into ``self.eval_set``.

        Args:
          path: Location of the pickle file to load.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as handle:
            loaded = pickle.load(handle)
        self.eval_set = loaded
    
    def sample_eval_task(self, eval_type, index):
        """Activate one stored evaluation task from ``self.eval_set``.

        Looks up ``self.eval_set[eval_type][index]``, rebuilds the reward
        function from its positive / negative state sets and stores the task
        on the environment (``pos_goals``, ``neg_goals``, ``reward_array``,
        ``reward_func``).

        Args:
          eval_type: Task family; must be a key of the loaded ``self.eval_set``.
          index: Position of the task within that family (non-negative).

        Returns:
          A 4-tuple ``(pos_goals_set, neg_goals_set, reward_array,
          reward_func)``. The first two are lists holding
          ``self.get_obs_from_state_xy(s)`` for each positive / negative
          state. ``reward_func`` maps a single state to a float reward, or a
          tensor of observations to a float32 tensor of rewards.

        Raises:
          ValueError: If ``eval_type`` is not present in ``self.eval_set``.
          IndexError: If ``index`` is negative or past the end of the family.
        """
        # Validate against the families actually stored, so that every key
        # written into the evaluation set (e.g. 'rni2', 'rni3') can be sampled.
        if eval_type not in self.eval_set:
            raise ValueError(f"Unknown eval_type: {eval_type}")
        tasks = self.eval_set[eval_type]
        if index < 0 or index >= len(tasks):
            raise IndexError(f"Invalid index: {index}")

        curr_eval = tasks[index]
        pos_states = curr_eval['pos_states']
        neg_states = curr_eval['neg_states']
        reward_array = curr_eval['reward_array']

        def _state_reward(state):
            """Reward of one state: +1 if positive, -1 if negative, else 0."""
            key = tuple(state)
            if key in pos_states:
                return 1.0
            if key in neg_states:
                return -1.0
            return 0.0

        def reward_func(states):
            """Reward for one state, or a float32 tensor for a batch of observations."""
            if isinstance(states, torch.Tensor):
                # detach()/cpu() make this work for GPU or grad-tracking inputs.
                rewards = [
                    _state_reward(self.get_state_from_obs(obs.detach().cpu().numpy()))
                    for obs in states
                ]
                return torch.tensor(rewards, dtype=torch.float32)
            return _state_reward(states)

        self.pos_goals = pos_states
        self.neg_goals = neg_states
        pos_goals_set = [self.get_obs_from_state_xy(s) for s in pos_states]
        neg_goals_set = [self.get_obs_from_state_xy(s) for s in neg_states]
        self.reward_array = reward_array
        self.reward_func = reward_func
        return pos_goals_set, neg_goals_set, reward_array, reward_func


if __name__ == '__main__':
    # Smoke test: build the four-room layout with one-hot agent observations
    # and reset it once.
    env = build_gridworld_task(
        task='fourroom',
        observation_type=ObservationType.AGENT_ONEHOT,
    )
    env.reset()

    # Scratch workflow, kept disabled for reference:
    #
    #   Generate and store an evaluation set (also writes one plot per task):
    #     env.create_evaluation_set('eval_set2.pkl')
    #
    #   Load a stored set and activate one of its tasks:
    #     env.load_evaluation_set('eval_set.pkl')
    #     pos, neg, arr, reward_func = env.sample_eval_task('goal', 0)
    #     env.plot_eval_plot(work_dir='.', title='sampled_eval_plot')
    #
    #   Merge families from two pickled sets into a new file:
    #     new_eval_set = {}
    #     with open('eval_set2.pkl', 'rb') as f:
    #         new_eval_set["goal"] = pickle.load(f)["goal"]
    #     with open('eval_set.pkl', 'rb') as f:
    #         new_eval_set["rni"] = pickle.load(f)["rni"]
    #     with open('new_eval_set.pkl', 'wb') as f:
    #         pickle.dump(new_eval_set, f)

