"""Key-Lock environment with yellow and blue keys/doors."""

from __future__ import annotations

import time
import gymnasium as gym
import numpy as np
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Wall, Door, Key, Goal
from minigrid.minigrid_env import MiniGridEnv


class KeyLockEnv(MiniGridEnv):
    """MiniGrid task with yellow/blue keys, matching locked doors and a goal.

    High-level action space (``Discrete(6)``):
        0 = up, 1 = down, 2 = left, 3 = right -- turn to face that direction,
            then move one cell forward unless the cell ahead is blocked;
        4 = pick up the key directly in front of the agent;
        5 = toggle the door directly in front of the agent.

    Observation: the tuple ``(x, y, dir, yellow_door_open, blue_door_open,
    yellow_key_on_map, blue_key_on_map)``.

    Reward: -0.01 on every step, replaced by +1.0 (and ``terminated=True``)
    on the step the agent first arrives at the goal cell.
    """

    # MiniGrid primitive action indices forwarded to ``MiniGridEnv.step``.
    _MG_FORWARD = 2
    _MG_PICKUP = 3
    _MG_TOGGLE = 5

    # High-level movement action -> MiniGrid facing direction
    # (0=east, 1=south, 2=west, 3=north).
    _MOVE_ACTION_TO_DIR = {0: 3, 1: 1, 2: 2, 3: 0}

    _STEP_PENALTY = -0.01
    _GOAL_REWARD = 1.0

    def __init__(
        self,
        size: int = 15,
        agent_start_pos: tuple[int, int] | None = (1, 1),
        agent_start_dir: int = 0,
        yellow_key_pos=(12, 3),
        yellow_door_pos=(3, 8),
        blue_key_pos=(12, 12),
        blue_door_pos=(9, 3),
        goal_pos=(3, 12),
        max_steps: int | None = None,
        **kwargs,
    ):
        """Configure object positions and the compact action/observation spaces.

        Args:
            size: Grid width and height. The interior wall layout built in
                ``_gen_grid`` uses hard-coded coordinates for a 15x15 grid.
            agent_start_pos: Fixed start cell, or ``None`` to place the agent
                at a random empty cell on every reset.
            agent_start_dir: Start facing direction used with a fixed start.
            yellow_key_pos, yellow_door_pos, blue_key_pos, blue_door_pos,
            goal_pos: ``(x, y)`` cells of the keys, locked doors and goal.
            max_steps: Episode step limit; defaults to ``4 * size**2``.
            **kwargs: Forwarded to ``MiniGridEnv.__init__``.
        """
        # Normalise every coordinate to a tuple: positions are compared with
        # ``==`` against tuples, and e.g. [3, 12] != (3, 12) in Python.
        self.agent_start_pos = (
            tuple(agent_start_pos) if agent_start_pos is not None else None
        )
        self.agent_start_dir = agent_start_dir
        self.size = size
        self.yellow_key_pos = tuple(yellow_key_pos)
        self.yellow_door_pos = tuple(yellow_door_pos)
        self.blue_key_pos = tuple(blue_key_pos)
        self.blue_door_pos = tuple(blue_door_pos)
        self.goal_pos = tuple(goal_pos)
        # 1 while the key still lies at its start cell, 0 once picked up.
        self.yellow_key_on_map = 1
        self.blue_key_on_map = 1
        # Legacy aliases for the yellow key/door; not read by this class.
        self.key_pos = self.yellow_key_pos
        self.door_pos = self.yellow_door_pos

        mission_space = MissionSpace(mission_func=self._gen_mission)

        if max_steps is None:
            max_steps = 4 * size**2

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            see_through_walls=True,
            max_steps=max_steps,
            **kwargs,
        )

        # Replace MiniGrid's default spaces with the compact ones used here.
        self.action_space = gym.spaces.Discrete(6)
        self.observation_space = gym.spaces.Tuple((
            gym.spaces.Discrete(size),  # x
            gym.spaces.Discrete(size),  # y
            gym.spaces.Discrete(4),     # direction
            gym.spaces.Discrete(2),     # yellow_door_open
            gym.spaces.Discrete(2),     # blue_door_open
            gym.spaces.Discrete(2),     # yellow_key_on_map
            gym.spaces.Discrete(2),     # blue_key_on_map
        ))

    @staticmethod
    def _gen_mission():
        return "pick up yellow and blue keys, open both doors, and reach the goal"

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return ``(tuple_obs, info)``.

        Both key flags are set *before* ``super().reset`` because that call
        regenerates the grid via ``_gen_grid``, which places keys according
        to these flags.
        """
        self.yellow_key_on_map = 1
        self.blue_key_on_map = 1
        _, info = super().reset(seed=seed, options=options)
        self.carrying = None
        # Defensive re-placement; ``_gen_grid`` has already placed these.
        self._place_keys_doors_goal()
        return self._get_obs(), info

    def _gen_grid(self, width, height):
        """Build the fixed wall layout, then place objects and the agent.

        NOTE: the interior wall coordinates below are hard-coded for the
        default 15x15 grid; other sizes are presumably unsupported.
        """
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Upper-right area: row y=5 (x=9..14) and column x=9 at y=0..2 and
        # y=4, leaving an opening at (9, 3) (the default blue-door cell).
        for x in range(9, 15):
            self.grid.set(x, 5, Wall())
        for y in (0, 1, 2, 4):
            self.grid.set(9, y, Wall())

        # Lower-right area: row y=9 (x=9..14) and column x=9 at y=10 and
        # y=13, leaving openings at (9, 11) and (9, 12).
        for x in range(9, 15):
            self.grid.set(x, 9, Wall())
        for y in (10, 13):
            self.grid.set(9, y, Wall())

        # Lower-left area: row y=9 at x=0..2 and x=4..5 (opening at (3, 9)),
        # column x=6 at y=9..14, and corridor walls at x=2 and x=4 for y=7..8.
        for x in (0, 1, 2, 4, 5):
            self.grid.set(x, 9, Wall())
        for y in range(9, 15):
            self.grid.set(6, y, Wall())
        for y in (7, 8):
            self.grid.set(2, y, Wall())
            self.grid.set(4, y, Wall())

        # Objects go down first, so random agent placement (which only picks
        # empty cells) can never put the agent on a key, door or the goal.
        self._place_keys_doors_goal()
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = self._gen_mission()

    def _place_keys_doors_goal(self):
        """(Re)place both keys, both locked doors and the goal on the grid.

        A key is placed only while its ``*_key_on_map`` flag is set;
        otherwise its cell is cleared. Keys, doors and the goal never
        overwrite a cell that currently holds a Wall.
        """
        self._place_key('yellow', self.yellow_key_pos, self.yellow_key_on_map)
        self._place_locked_door('yellow', self.yellow_door_pos)
        self._place_key('blue', self.blue_key_pos, self.blue_key_on_map)
        self._place_locked_door('blue', self.blue_door_pos)
        self._set_unless_wall(self.goal_pos, Goal())

    def _place_key(self, color, pos, on_map):
        """Put a ``color`` key at ``pos`` if ``on_map``; otherwise empty the cell."""
        if on_map:
            self._set_unless_wall(pos, Key(color))
        else:
            self.grid.set(*pos, None)

    def _place_locked_door(self, color, pos):
        """Put a locked ``color`` door at ``pos`` unless that cell is a Wall."""
        self._set_unless_wall(pos, Door(color, is_locked=True))

    def _set_unless_wall(self, pos, obj):
        """Set ``obj`` at ``pos`` unless the cell currently holds a Wall."""
        if not isinstance(self.grid.get(*pos), Wall):
            self.grid.set(*pos, obj)

    def _key_at(self, pos, color):
        """Return 1 if a ``color`` key lies at ``pos``, else 0."""
        cell = self.grid.get(*pos)
        return int(isinstance(cell, Key) and cell.color == color)

    def _door_open_at(self, pos):
        """Return 1 if the cell at ``pos`` is an open Door (any color), else 0."""
        cell = self.grid.get(*pos)
        return int(isinstance(cell, Door) and cell.is_open)

    def _get_obs(self):
        """Build the tuple observation and refresh the key-on-map flags.

        Returns:
            ``(x, y, dir, yellow_door_open, blue_door_open,
            yellow_key_on_map, blue_key_on_map)``.
        """
        x, y = self.agent_pos
        # The flags are re-derived from the grid so they stay in sync with it.
        self.yellow_key_on_map = self._key_at(self.yellow_key_pos, 'yellow')
        self.blue_key_on_map = self._key_at(self.blue_key_pos, 'blue')
        return (
            x,
            y,
            self.agent_dir,
            self._door_open_at(self.yellow_door_pos),
            self._door_open_at(self.blue_door_pos),
            self.yellow_key_on_map,
            self.blue_key_on_map,
        )

    def _get_front_pos(self):
        """Return the cell directly ahead of the agent as an ``(x, y)`` tuple."""
        x, y = self.agent_pos
        if self.agent_dir == 0:
            return (x + 1, y)
        elif self.agent_dir == 1:
            return (x, y + 1)
        elif self.agent_dir == 2:
            return (x - 1, y)
        else:
            return (x, y - 1)

    def _can_pickup_key(self, key_color='yellow'):
        """Return True if the agent is empty-handed and faces the on-map ``key_color`` key."""
        if self.carrying is not None:
            return False
        if key_color == 'yellow':
            key_pos, key_on_map = self.yellow_key_pos, self.yellow_key_on_map
        elif key_color == 'blue':
            key_pos, key_on_map = self.blue_key_pos, self.blue_key_on_map
        else:
            return False
        front = self._get_front_pos()
        if front != key_pos or key_on_map != 1:
            return False
        cell = self.grid.get(*front)
        return isinstance(cell, Key) and cell.color == key_color

    def _can_toggle_door(self, door_color='yellow'):
        """Return True if the agent faces the ``door_color`` door and that key is off the map.

        Only requires the matching key to have left its start cell (held now
        or already used); whether it is currently held is not checked here.
        """
        if door_color == 'yellow':
            door_pos, key_on_map = self.yellow_door_pos, self.yellow_key_on_map
        elif door_color == 'blue':
            door_pos, key_on_map = self.blue_door_pos, self.blue_key_on_map
        else:
            return False
        if self._get_front_pos() != door_pos:
            return False
        cell = self.grid.get(*door_pos)
        return isinstance(cell, Door) and cell.color == door_color and key_on_map == 0

    def _at_goal(self):
        """Return True if the agent currently stands on the goal cell."""
        # tuple() keeps the comparison a single bool even if agent_pos is an array.
        return tuple(self.agent_pos) == self.goal_pos

    def _carrying_key(self, color):
        """Return True if the agent is holding a key of ``color``."""
        return isinstance(self.carrying, Key) and self.carrying.color == color

    @staticmethod
    def _is_blocked(cell):
        """Return True if the agent may not step into ``cell``.

        Walls and keys always block; doors block unless unlocked and open.
        """
        if isinstance(cell, (Wall, Key)):
            return True
        if isinstance(cell, Door):
            return cell.is_locked or not cell.is_open
        return False

    def _do_pickup(self):
        """Handle action 4 and return ``(terminated, truncated, info)``."""
        can_yellow = self._can_pickup_key('yellow')
        can_blue = self._can_pickup_key('blue')
        if not (can_yellow or can_blue):
            # No-op: MiniGridEnv.step is not called, so its step counter
            # does not advance.
            return False, False, {}
        _, _, terminated, truncated, info = super().step(self._MG_PICKUP)
        if can_yellow and self._carrying_key('yellow'):
            self.yellow_key_on_map = 0
            self.grid.set(*self.yellow_key_pos, None)
        if can_blue and self._carrying_key('blue'):
            self.blue_key_on_map = 0
            self.grid.set(*self.blue_key_pos, None)
        return terminated, truncated, info

    def _do_toggle(self):
        """Handle action 5 and return ``(terminated, truncated, info)``.

        If the toggle leaves the door matching the held key open, the key is
        consumed: the agent's hands are emptied and the key stays off the map.
        """
        if not (self._can_toggle_door('yellow') or self._can_toggle_door('blue')):
            # No-op: not facing a usable door (see _can_toggle_door).
            return False, False, {}
        held_color = self.carrying.color if isinstance(self.carrying, Key) else None
        _, _, terminated, truncated, info = super().step(self._MG_TOGGLE)
        carrying_after = self.carrying

        if held_color == 'yellow' and self._door_open_at(self.yellow_door_pos):
            if carrying_after is None:
                self.yellow_key_on_map = 0
                self.grid.set(*self.yellow_key_pos, None)
            # Consume the key even if MiniGrid left it in the agent's hands.
            self.carrying = None
        if held_color == 'blue' and self._door_open_at(self.blue_door_pos):
            if carrying_after is None:
                self.blue_key_on_map = 0
                self.grid.set(*self.blue_key_pos, None)
            self.carrying = None
        return terminated, truncated, info

    def _do_move(self, action):
        """Handle movement actions 0-3 and return ``(terminated, truncated, info)``."""
        # Facing is set directly (not via MiniGrid turn actions), so turning
        # by itself does not advance MiniGrid's step counter.
        self.agent_dir = self._MOVE_ACTION_TO_DIR[action]
        if self._is_blocked(self.grid.get(*self._get_front_pos())):
            return False, False, {}
        _, _, terminated, truncated, info = super().step(self._MG_FORWARD)
        return terminated, truncated, info

    def step(self, action):
        """Apply one high-level action.

        Args:
            action: Integer in 0-5 (see the class docstring).

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``reward`` is -0.01,
            or +1.0 with ``terminated=True`` on first arrival at the goal.

        Raises:
            ValueError: If ``action`` is not one of 0-5.
        """
        if action not in (0, 1, 2, 3, 4, 5):
            raise ValueError(f"Invalid action: {action}. Must be 0-5.")
        action = int(action)
        at_goal_before = self._at_goal()

        if action == 4:
            terminated, truncated, info = self._do_pickup()
        elif action == 5:
            terminated, truncated, info = self._do_toggle()
        else:
            terminated, truncated, info = self._do_move(action)

        # MiniGrid's own reward is discarded in favour of this scheme.
        reward = self._STEP_PENALTY
        if self._at_goal() and not at_goal_before:
            reward = self._GOAL_REWARD
            terminated = True

        return self._get_obs(), reward, terminated, truncated, info


