import random

import gym
import gym_minigrid.minigrid as minigrid
import networkx as nx
from networkx import grid_graph
import numpy as np

from . import multigrid
from . import register


class BinaryStochasticChoice(multigrid.MultiGridEnv):
  """Grid world with a stochastic binary choice between a red and a blue ball.

  On every reset the agent is placed in the bottom interior row (centre
  column), the red ball in the top-left interior corner and the blue ball in
  the top-right interior corner. One ball is sampled as the goal (red with
  probability `p`). The goal ball carries a normally distributed reward, while
  the other ball carries 0 reward unless `unexclusive_goals` is set. Picking
  up either ball terminates the episode. Optionally, random walls are
  scattered over the grid. If a ball becomes unreachable, the walls on its
  obstacle-free shortest path are removed.
  """

  def __init__(self,
               p=0.5,
               rewards=(0.75, 1.0),
               reward_spreads=(10, 0),
               size=15,
               use_walls=False,
               agent_view_size=5,
               max_steps=250,
               n_clutter=15,
               n_agents=1,
               obl_correction=False,
               unexclusive_goals=False,
               fixed_environment=False,
               seed=0):
    """Initializes a domain-randomized binary-choice environment.

    Args:
      p: Probability that the red ball is the goal on a given reset.
      rewards: (red, blue) mean of the normal distribution each ball's reward
        is drawn from when it carries a reward.
      reward_spreads: (red, blue) standard deviation of that distribution.
      size: The number of tiles across one side of the grid; i.e. make a
        size x size grid.
      use_walls: If True, randomly scatter walls on every reset.
      agent_view_size: The number of tiles in one side of the agent's partially
        observed view of the grid.
      max_steps: The maximum number of steps that can be taken before the
        episode terminates.
      n_clutter: Number of random wall placement attempts when `use_walls` is
        set, i.e. the maximum number of walls.
      n_agents: Ignored; the base class is always created with one agent.
      obl_correction: If True, the goal colour is drawn from a separate
        unseeded RNG instead of the level-seeded `np_random`, maintaining a
        uniform prior over levels.
      unexclusive_goals: If True, the non-goal ball also carries a sampled
        reward instead of 0.
      fixed_environment: Forwarded to `MultiGridEnv`.
      seed: Forwarded to `MultiGridEnv`.
    """
    self.agent_start_dir = 0
    self.agent_start_pos = None
    self.red_pos = None
    self.blue_pos = None
    self.goal_pos = None
    self.target = None
    self.size = size

    self.use_walls = use_walls
    self.n_clutter = n_clutter

    self.p = p
    # Copy into fresh lists so the environment never aliases caller sequences.
    self.rewards = list(rewards)
    self.reward_spreads = list(reward_spreads)
    self.obl_correction = obl_correction
    if self.obl_correction:
      self.unseeded_np_random, _ = gym.utils.seeding.np_random()

    self.unexclusive_goals = unexclusive_goals

    self.wall_locs = set()

    self.world = multigrid.World

    # Initialise metrics and the path-finding graph BEFORE the base
    # constructor: it may seed/reset the environment (this class's seed()
    # calls reset()), and the values reset() computes must not be wiped
    # afterwards.
    self.reset_metrics()
    # NetworkX graph of the interior cells, used for shortest-path metrics.
    self.graph = grid_graph(dim=[size-2, size-2])

    super().__init__(
        n_agents=1,  # Only agent 0 is ever placed (see reset()).
        minigrid_mode=True,
        grid_size=size,
        max_steps=max_steps,
        agent_view_size=agent_view_size,
        see_through_walls=False,  # Set this to True for maximum speed
        competitive=True,
        fixed_environment=fixed_environment,
        seed=seed
    )

    self.action_space = gym.spaces.Discrete(4)

  def seed(self, seed):
    """Seed the base-class RNG, remember the level seed, and reset.

    Returns:
      The first observation of the freshly generated level.
    """
    super().seed(seed)
    self.level_seed = seed
    obs = self.reset()
    return obs

  def _gen_grid(self, width, height):
    """Create an empty grid enclosed by walls; reset() populates it."""
    self.grid = multigrid.Grid(width, height)
    self.grid.wall_rect(0, 0, width, height)

  def step(self, actions):
    """Step the base env; on termination report which ball was picked up."""
    obs, r, done, info = super().step(actions)

    if done:
      info['target'] = self.target

    return obs, r, done, info

  def _handle_pickup(self, agent_id, reward, fwd_pos, fwd_cell):
    """End the episode and record the picked-up ball's colour in `target`.

    `target` becomes 'red', 'blue', or None. `reward` and `fwd_cell` are not
    used here; reward bookkeeping presumably happens in the base class
    (NOTE(review): confirm).
    """
    self.done[agent_id] = True

    fwd_tuple = tuple(fwd_pos)

    if fwd_tuple == tuple(self.red_pos):
      self.target = 'red'
    elif fwd_tuple == tuple(self.blue_pos):
      self.target = 'blue'
    else:
      self.target = None

  def reset_metrics(self):
    """Reset per-level metrics to their "not yet computed" values."""
    self.distance_to_goal = -1
    self.n_clutter_placed = 0
    self.passable = -1
    self.shortest_path_length = 0

  def reset(self):
    """Use domain randomization to create a new level.

    Returns:
      The first observation of the new level.
    """
    self.graph = grid_graph(dim=[self.width-2, self.height-2])
    self.wall_locs.clear()

    self.step_count = 0

    # Current position and direction of the agent
    self.reset_agent_status()

    self.agent_start_pos = None
    self.goal_pos = None

    # Extra metrics
    self.reset_metrics()

    # Create empty grid
    self._gen_grid(self.width, self.height)

    # Place the agent in the bottom interior row, centre column.
    self.agent_start_dir = 3  # Face north
    self.agent_start_pos = self.place_agent_at_pos(
        0, np.array([int(self.width/2.), self.height-2]), rand_dir=False)

    # Red and blue balls sit in the top-left / top-right interior corners.
    self.red_pos = np.array([1, 1])
    self.blue_pos = np.array([self.width-2, 1])

    # Obstacle-free shortest paths, in graph coordinates (grid coords - 1,
    # because the border walls are not graph nodes). compute_shortest_path()
    # uses them to carve a route if walls make a ball unreachable.
    self.empty_shortest_path_red = nx.shortest_path(
        self.graph, tuple(self.agent_start_pos-1), tuple(self.red_pos-1))
    self.empty_shortest_path_blue = nx.shortest_path(
        self.graph, tuple(self.agent_start_pos-1), tuple(self.blue_pos-1))

    if self.use_walls:
      # Randomly place walls, never on the balls or the agent start.
      key_pos = (tuple(self.red_pos), tuple(self.blue_pos),
                 tuple(self.agent_start_pos))
      for _ in range(int(self.n_clutter)):
        wall_pos_x = self._rand_int(1, self.width-1)
        wall_pos_y = self._rand_int(1, self.height-1)
        wall_pos = (wall_pos_x, wall_pos_y)
        if wall_pos in self.wall_locs or wall_pos in key_pos:
          continue

        self.place_obj(minigrid.Wall(), top=wall_pos, size=(1, 1),
                       max_tries=100)
        self.wall_locs.add(wall_pos)

    # Sample the goal colour, then the ball rewards (non-goal ball gets 0
    # unless unexclusive_goals is set).
    self.red_reward, self.blue_reward = 0, 0

    if self.obl_correction:
      red_outcome = self.unseeded_np_random.rand()  # Maintain uniform prior over levels
    else:
      red_outcome = self.np_random.rand()

    if red_outcome < self.p:
      self.goal_pos = self.red_pos
      self.red_reward = self.np_random.normal(
          self.rewards[0], self.reward_spreads[0])
      if self.unexclusive_goals:
        self.blue_reward = self.np_random.normal(
            self.rewards[1], self.reward_spreads[1])
    else:
      self.goal_pos = self.blue_pos
      self.blue_reward = self.np_random.normal(
          self.rewards[1], self.reward_spreads[1])
      if self.unexclusive_goals:
        self.red_reward = self.np_random.normal(
            self.rewards[0], self.reward_spreads[0])

    self.put_obj(
        multigrid.Ball(self.world,
                       index=minigrid.COLOR_TO_IDX['red'],
                       reward=self.red_reward),
        *self.red_pos)

    self.put_obj(
        multigrid.Ball(self.world,
                       index=minigrid.COLOR_TO_IDX['blue'],
                       reward=self.blue_reward),
        *self.blue_pos)

    self.target = None

    for (x, y) in self.wall_locs:
      self.graph.remove_node((x-1, y-1))

    self.compute_shortest_path()

    # Count walls after compute_shortest_path() may have removed some.
    self.n_clutter_placed = len(self.wall_locs)

    # Return first observation
    obs = self.gen_obs()

    return obs

  def reset_agent_status(self):
    """Reset the agent's position, direction, done, and carrying status."""
    self.agent_pos = [None] * self.n_agents
    self.agent_dir = [self.agent_start_dir] * self.n_agents
    self.done = [False] * self.n_agents
    self.carrying = [None] * self.n_agents

  def reset_agent(self):
    """Put the agent back at its start position, keeping goal and walls.

    Raises:
      ValueError: If no start position has been set, or the start cell holds
        an object the agent cannot overlap.

    Returns:
      The first observation after the agent reset.
    """
    # Remove the previous agents from the world
    for a in range(self.n_agents):
      if self.agent_pos[a] is not None:
        self.grid.set(self.agent_pos[a][0], self.agent_pos[a][1], None)

    # Current position and direction of the agent
    self.reset_agent_status()

    if self.agent_start_pos is None:
      raise ValueError('Trying to place agent at empty start position.')
    self.place_agent_at_pos(0, self.agent_start_pos, rand_dir=False)

    for a in range(self.n_agents):
      assert self.agent_pos[a] is not None
      assert self.agent_dir[a] is not None

      # Check that the agent doesn't overlap with an object. The None check
      # must come first: calling `.type` on an empty cell would raise.
      start_cell = self.grid.get(*self.agent_pos[a])
      if not (start_cell is None or start_cell.type == 'agent' or
              start_cell.can_overlap()):
        raise ValueError('Wrong object in agent start position.')

    # Step count since episode start
    self.step_count = 0

    # Return first observation
    obs = self.gen_obs()

    return obs

  def remove_wall(self, x, y):
    """Remove the wall at grid cell (x, y), if there is one."""
    if (x, y) in self.wall_locs:
      self.wall_locs.remove((x, y))
    obj = self.grid.get(x, y)
    if obj is not None and obj.type == 'wall':
      self.grid.set(x, y, None)

  def compute_shortest_path(self):
    """Make both balls reachable and record path metrics for the level.

    Sets `distance_to_goal` to the Manhattan distance from the agent start to
    the goal. Any ball unreachable from the start gets its obstacle-free
    shortest path cleared of walls. `self.graph` is then rebuilt from the
    remaining walls, and `shortest_path_length` is set to the shortest-path
    distance from the start to the goal. No-op if the start or goal is unset.
    """
    if self.agent_start_pos is None or self.goal_pos is None:
      return

    self.distance_to_goal = abs(
        self.goal_pos[0] - self.agent_start_pos[0]) + abs(
            self.goal_pos[1] - self.agent_start_pos[1])

    # Graph nodes are grid coordinates minus 1: the outer walls exist in the
    # Grid but not in the networkx graph.
    start_node = (self.agent_start_pos[0] - 1, self.agent_start_pos[1] - 1)
    red_node = (self.red_pos[0] - 1, self.red_pos[1] - 1)
    blue_node = (self.blue_pos[0] - 1, self.blue_pos[1] - 1)
    goal_node = (self.goal_pos[0] - 1, self.goal_pos[1] - 1)

    passable_red = nx.has_path(self.graph, source=start_node, target=red_node)
    passable_blue = nx.has_path(self.graph, source=start_node, target=blue_node)

    # Carve a route to any unreachable ball. Path endpoints are skipped; they
    # are the agent start and a ball, which reset() never covers with walls.
    if not passable_red:
      for (x, y) in self.empty_shortest_path_red[1:-1]:
        self.remove_wall(x + 1, y + 1)

    if not passable_blue:
      for (x, y) in self.empty_shortest_path_blue[1:-1]:
        self.remove_wall(x + 1, y + 1)

    # Rebuild the graph from the surviving walls (same dimensions as reset()).
    self.graph = grid_graph(dim=[self.width-2, self.height-2])
    for (x, y) in self.wall_locs:
      self.graph.remove_node((x-1, y-1))

    # After carving, both balls (and therefore the goal) are reachable.
    self.passable, self.passable_red, self.passable_blue = True, True, True

    self.shortest_path_length = nx.shortest_path_length(
        self.graph, source=start_node, target=goal_node)

  def goal_color(self):
    """Return 'red' if the current goal is the red ball, else 'blue'."""
    if tuple(self.goal_pos) == tuple(self.red_pos):
      return 'red'
    return 'blue'

class MiniBinaryChoice(BinaryStochasticChoice):
  """7x7 binary-choice level: up to 15 walls, 50 steps, exclusive goals."""

  def __init__(self, seed=0, fixed_environment=False, p=0.5,
               rewards=(0.75, 1.0), reward_spreads=(10, 0), use_walls=False,
               obl_correction=False):
    # Fixed layout parameters for this variant.
    super().__init__(
        size=7,
        agent_view_size=5,
        max_steps=50,
        n_clutter=15,
        p=p,
        rewards=rewards,
        reward_spreads=reward_spreads,
        use_walls=use_walls,
        obl_correction=obl_correction,
        fixed_environment=fixed_environment,
        seed=seed)


class MiniBinaryChoiceUnexclusiveGoals(BinaryStochasticChoice):
  """7x7 binary-choice level where both balls carry a sampled reward."""

  def __init__(self, seed=0, fixed_environment=False, p=0.5,
               rewards=(0.75, 1.0), reward_spreads=(10, 0), use_walls=False,
               obl_correction=False):
    # Fixed layout parameters for this variant; the non-goal ball is also
    # rewarded (unexclusive_goals).
    super().__init__(
        size=7,
        agent_view_size=5,
        max_steps=50,
        n_clutter=15,
        unexclusive_goals=True,
        p=p,
        rewards=rewards,
        reward_spreads=reward_spreads,
        use_walls=use_walls,
        obl_correction=obl_correction,
        fixed_environment=fixed_environment,
        seed=seed)


class BinaryChoice9x9(BinaryStochasticChoice):
  """9x9 binary-choice level: up to 20 walls, 50 steps, exclusive goals."""

  def __init__(self, seed=0, fixed_environment=False, p=0.5,
               rewards=(0.75, 1.0), reward_spreads=(10, 0), use_walls=False,
               obl_correction=False):
    # Fixed layout parameters for this variant.
    super().__init__(
        size=9,
        agent_view_size=5,
        max_steps=50,
        n_clutter=20,
        p=p,
        rewards=rewards,
        reward_spreads=reward_spreads,
        use_walls=use_walls,
        obl_correction=obl_correction,
        fixed_environment=fixed_environment,
        seed=seed)

class BinaryChoiceUnexclusiveGoals9x9(BinaryStochasticChoice):
  """9x9 binary-choice level where both balls carry a sampled reward."""

  def __init__(self, seed=0, fixed_environment=False, p=0.5,
               rewards=(0.75, 1.0), reward_spreads=(10, 0), use_walls=False,
               obl_correction=False):
    # Fixed layout parameters for this variant; the non-goal ball is also
    # rewarded (unexclusive_goals).
    super().__init__(
        size=9,
        agent_view_size=5,
        max_steps=50,
        n_clutter=20,
        unexclusive_goals=True,
        p=p,
        rewards=rewards,
        reward_spreads=reward_spreads,
        use_walls=use_walls,
        obl_correction=obl_correction,
        fixed_environment=fixed_environment,
        seed=seed)


# Fully-qualified name of this module, used to build gym entry points of the
# form '<module>:<ClassName>'. Loaders expose it under different attributes;
# fall back to __name__ so module_path is always bound (otherwise a loader with
# neither attribute, or a None __loader__, would cause a NameError below).
if hasattr(__loader__, 'name'):
  module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
  module_path = __loader__.fullname
else:
  module_path = __name__


# max_episode_steps matches the max_steps=50 each class passes to its base.
register.register(
    env_id='MultiGrid-MiniBinaryChoice-v0',
    entry_point=module_path + ':MiniBinaryChoice',
    max_episode_steps=50,
)

register.register(
    env_id='MultiGrid-MiniBinaryChoiceUnexclusiveGoals-v0',
    entry_point=module_path + ':MiniBinaryChoiceUnexclusiveGoals',
    max_episode_steps=50,
)

register.register(
    env_id='MultiGrid-BinaryChoice9x9-v0',
    entry_point=module_path + ':BinaryChoice9x9',
    max_episode_steps=50,
)

register.register(
    env_id='MultiGrid-BinaryChoiceUnexclusiveGoals9x9-v0',
    entry_point=module_path + ':BinaryChoiceUnexclusiveGoals9x9',
    max_episode_steps=50,
)
