from __future__ import annotations

from minigrid.core.constants import COLOR_NAMES
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall, Ball, Lava
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
import gymnasium as gym
# for GymMDP
from simple_rl.tasks.gym.GymStateClass import GymState
from simple_rl.mdp.MDPClass import MDP
from collections import defaultdict
# wrappers
from minigrid.wrappers import FullyObsWrapper
from agents.FullyObsObjWrapperClass import FullyObsObjWrapper
import matplotlib.pyplot as plt

# Turn matplotlib's axes grid off by default. This runs at import time, so it
# affects any plotting done by code that imports this module.
plt.rcParams["axes.grid"] = False

class TestEnvEasy(MiniGridEnv):
    """Small fixed-layout level: two rooms split by a wall with one locked door.

    All object coordinates in `_gen_grid` are hard-coded, so the layout only
    lines up with the default 9x5 grid.
    """

    def __init__(
        self,
        size=None, # Don't change this since it will change the size of the env but not the loc of objs and walls.
        width=9,
        height=5,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Store the agent start pose and configure the underlying MiniGridEnv.

        Args:
            size: Forwarded as `grid_size`; leave as None (see inline comment).
            width: Grid width in cells.
            height: Grid height in cells.
            agent_start_pos: (x, y) start cell, or None to let `place_agent` choose.
            agent_start_dir: Start direction index assigned to `agent_dir`.
            max_steps: Step budget; derived from the grid dimensions when None.
            **kwargs: Passed through to `MiniGridEnv.__init__`.
        """
        # Read back in _gen_grid when the agent is placed.
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        missions = MissionSpace(mission_func=self._gen_mission)

        if max_steps is None:
            if size:
                max_steps = 8 * size**2
            else:
                max_steps = 8 * (width * height) ** 2

        # NOTE(review): `size` is forwarded as grid_size alongside width/height;
        # presumably MiniGridEnv expects only one of the two -- confirm before
        # passing a size.
        super().__init__(
            mission_space=missions,
            grid_size=size,
            # Set this to True for maximum speed
            see_through_walls=True,
            max_steps=max_steps,
            width=width,
            height=height,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string."""
        return "grand mission"

    def _gen_grid(self, width=9, height=5):
        """Populate `self.grid` with walls, door, balls, keys, goal and the agent."""
        grid = Grid(width, height)
        self.grid = grid

        # Outer boundary walls.
        grid.wall_rect(0, 0, width, height)

        # Vertical wall at x=4 separating the two rooms.
        for row in range(height):
            grid.set(4, row, Wall())

        # Locked door through the separating wall.
        grid.set(4, 2, Door(COLOR_NAMES[1], is_locked=True))

        # Each entry is (x, y, index into COLOR_NAMES).
        for x, y, color_idx in ((3, 2, 0), (6, 3, 1)):
            grid.set(x, y, Ball(COLOR_NAMES[color_idx]))

        # NOTE(review): on the default 9x5 grid the goal below is also put at
        # (7, 3) and therefore replaces the second key -- confirm that is intended.
        for x, y, color_idx in ((3, 3, 1), (7, 3, 2)):
            grid.set(x, y, Key(COLOR_NAMES[color_idx]))

        # Goal one cell in from the bottom-right corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Use the configured start pose if there is one, otherwise let
        # place_agent pick.
        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "grand mission"

class TestEnvMedium(MiniGridEnv):
    """Fixed-layout 9x9 level: four rooms formed by walls at x=4 and y=4.

    Two locked doors, two balls and two keys sit at hard-coded coordinates.
    """

    def __init__(
        self,
        size=9, # Don't change this since it will change the size of the env but not the loc of objs and walls.
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Store the agent start pose and configure the underlying MiniGridEnv.

        Args:
            size: Side length of the square grid (forwarded as `grid_size`).
            agent_start_pos: (x, y) start cell, or None to let `place_agent` choose.
            agent_start_dir: Start direction index assigned to `agent_dir`.
            max_steps: Step budget; defaults to 16 * size**2 when None.
            **kwargs: Passed through to `MiniGridEnv.__init__`.
        """
        # Read back in _gen_grid when the agent is placed.
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        missions = MissionSpace(mission_func=self._gen_mission)

        step_budget = 16 * size**2 if max_steps is None else max_steps

        super().__init__(
            mission_space=missions,
            grid_size=size,
            # Set this to True for maximum speed
            see_through_walls=True,
            max_steps=step_budget,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string."""
        return "grand mission"

    def _gen_grid(self, width, height):
        """Populate `self.grid` with walls, doors, balls, keys, goal and the agent."""
        grid = Grid(width, height)
        self.grid = grid

        # Outer boundary walls.
        grid.wall_rect(0, 0, width, height)

        # Separating walls: column x=4 and row y=4. Both run over range(height),
        # which covers the full row only because the grid is square.
        for k in range(height):
            for x, y in ((4, k), (k, 4)):
                grid.set(x, y, Wall())

        # Each entry is (x, y, index into COLOR_NAMES).
        locked_doors = ((4, 2, 1), (6, 4, 2))
        balls = ((3, 2, 0), (6, 3, 1))
        keys = ((3, 3, 1), (7, 3, 2))

        for x, y, color_idx in locked_doors:
            grid.set(x, y, Door(COLOR_NAMES[color_idx], is_locked=True))
        for x, y, color_idx in balls:
            grid.set(x, y, Ball(COLOR_NAMES[color_idx]))
        for x, y, color_idx in keys:
            grid.set(x, y, Key(COLOR_NAMES[color_idx]))

        # Goal one cell in from the bottom-right corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Use the configured start pose if there is one, otherwise let
        # place_agent pick.
        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "grand mission"

class TestEnvMediumLava(MiniGridEnv):
    """Fixed-layout 9x9 level like the medium one, but with a lava outer rim.

    The rim cells are Lava instead of Wall, except where the separating walls
    at x=4 and y=4 meet the rim.
    """

    def __init__(
        self,
        size=9, # Don't change this since it will change the size of the env but not the loc of objs and walls.
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Store the agent start pose and configure the underlying MiniGridEnv.

        Args:
            size: Side length of the square grid (forwarded as `grid_size`).
            agent_start_pos: (x, y) start cell, or None to let `place_agent` choose.
            agent_start_dir: Start direction index assigned to `agent_dir`.
            max_steps: Step budget; defaults to 64 * size**2 when None.
            **kwargs: Passed through to `MiniGridEnv.__init__`.
        """
        # Read back in _gen_grid when the agent is placed.
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        missions = MissionSpace(mission_func=self._gen_mission)

        step_budget = 64 * size**2 if max_steps is None else max_steps

        super().__init__(
            mission_space=missions,
            grid_size=size,
            # Set this to True for maximum speed
            see_through_walls=True,
            max_steps=step_budget,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string."""
        return "grand mission"

    def _gen_grid(self, width, height):
        """Populate `self.grid` with the lava rim, walls, objects, goal and agent."""
        grid = Grid(width, height)
        self.grid = grid

        # Outer rim, built by hand instead of wall_rect: lava everywhere except
        # index 4, which gets a wall.
        # NOTE(review): the bounds below use height for x and width for y (and
        # vice versa); that is only equivalent because the grid is square.
        for x in (0, height - 1):
            for y in range(width):
                grid.set(x, y, Wall() if y == 4 else Lava())

        for y in (0, width - 1):
            for x in range(height):
                grid.set(x, y, Wall() if x == 4 else Lava())

        # Separating walls: column x=4 and row y=4 (again relying on a square
        # grid for the row).
        for k in range(height):
            for x, y in ((4, k), (k, 4)):
                grid.set(x, y, Wall())

        # Each entry is (x, y, index into COLOR_NAMES).
        locked_doors = ((4, 2, 1), (6, 4, 2))
        balls = ((3, 2, 0), (6, 3, 1))
        keys = ((3, 3, 1), (7, 3, 2))

        for x, y, color_idx in locked_doors:
            grid.set(x, y, Door(COLOR_NAMES[color_idx], is_locked=True))
        for x, y, color_idx in balls:
            grid.set(x, y, Ball(COLOR_NAMES[color_idx]))
        for x, y, color_idx in keys:
            grid.set(x, y, Key(COLOR_NAMES[color_idx]))

        # Goal one cell in from the bottom-right corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Use the configured start pose if there is one, otherwise let
        # place_agent pick.
        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "grand mission"

class TestEnvLessBall(MiniGridEnv):
    """Fixed-layout 16x16 level with walls at x/y = 5 and 10 and only two balls.

    Six locked doors and six keys sit at hard-coded coordinates.
    """

    def __init__(
        self,
        size=16, # Don't change this since it will change the size of the env but not the loc of objs and walls.
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Store the agent start pose and configure the underlying MiniGridEnv.

        Args:
            size: Side length of the square grid (forwarded as `grid_size`).
            agent_start_pos: (x, y) start cell, or None to let `place_agent` choose.
            agent_start_dir: Start direction index assigned to `agent_dir`.
            max_steps: Step budget; defaults to 16 * size**2 when None.
            **kwargs: Passed through to `MiniGridEnv.__init__`.
        """
        # Read back in _gen_grid when the agent is placed.
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        missions = MissionSpace(mission_func=self._gen_mission)

        step_budget = 16 * size**2 if max_steps is None else max_steps

        super().__init__(
            mission_space=missions,
            grid_size=size,
            # Set this to True for maximum speed
            see_through_walls=True,
            max_steps=step_budget,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string."""
        return "grand mission"

    def _gen_grid(self, width, height):
        """Populate `self.grid` with walls, doors, balls, keys, goal and the agent."""
        grid = Grid(width, height)
        self.grid = grid

        # Outer boundary walls.
        grid.wall_rect(0, 0, width, height)

        # Separating walls: columns x=5, x=10 and rows y=5, y=10. The rows run
        # over range(height), which spans the grid only because it is square.
        for k in range(height):
            for x, y in ((5, k), (10, k), (k, 5), (k, 10)):
                grid.set(x, y, Wall())

        # Each entry is (x, y, index into COLOR_NAMES).
        locked_doors = (
            (1, 5, 0),
            (5, 6, 1),
            (10, 8, 2),
            (7, 10, 3),
            (10, 13, 4),
            (5, 14, 5),
        )
        balls = ((1, 4, 1), (9, 8, 3))
        keys = (
            (2, 2, 0),
            (3, 4, 1),
            (8, 8, 2),
            (12, 6, 3),
            (2, 13, 4),
            (8, 11, 5),
        )

        for x, y, color_idx in locked_doors:
            grid.set(x, y, Door(COLOR_NAMES[color_idx], is_locked=True))
        for x, y, color_idx in balls:
            grid.set(x, y, Ball(COLOR_NAMES[color_idx]))
        for x, y, color_idx in keys:
            grid.set(x, y, Key(COLOR_NAMES[color_idx]))

        # Goal one cell in from the bottom-right corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Use the configured start pose if there is one, otherwise let
        # place_agent pick.
        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "grand mission"

class TestEnv(MiniGridEnv):
    """Fixed-layout 16x16 level with walls at x/y = 5 and 10.

    Six locked doors, six balls and six keys sit at hard-coded coordinates.
    """

    def __init__(
        self,
        size=16, # Don't change this since it will change the size of the env but not the loc of objs and walls.
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Store the agent start pose and configure the underlying MiniGridEnv.

        Args:
            size: Side length of the square grid (forwarded as `grid_size`).
            agent_start_pos: (x, y) start cell, or None to let `place_agent` choose.
            agent_start_dir: Start direction index assigned to `agent_dir`.
            max_steps: Step budget; defaults to 16 * size**2 when None.
            **kwargs: Passed through to `MiniGridEnv.__init__`.
        """
        # Read back in _gen_grid when the agent is placed.
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        missions = MissionSpace(mission_func=self._gen_mission)

        step_budget = 16 * size**2 if max_steps is None else max_steps

        super().__init__(
            mission_space=missions,
            grid_size=size,
            # Set this to True for maximum speed
            see_through_walls=True,
            max_steps=step_budget,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string."""
        return "grand mission"

    def _gen_grid(self, width, height):
        """Populate `self.grid` with walls, doors, balls, keys, goal and the agent."""
        grid = Grid(width, height)
        self.grid = grid

        # Outer boundary walls.
        grid.wall_rect(0, 0, width, height)

        # Separating walls: columns x=5, x=10 and rows y=5, y=10. The rows run
        # over range(height), which spans the grid only because it is square.
        for k in range(height):
            for x, y in ((5, k), (10, k), (k, 5), (k, 10)):
                grid.set(x, y, Wall())

        # Each entry is (x, y, index into COLOR_NAMES).
        locked_doors = (
            (1, 5, 0),
            (5, 6, 1),
            (10, 8, 2),
            (7, 10, 3),
            (10, 13, 4),
            (5, 14, 5),
        )
        balls = (
            (1, 4, 1),
            (4, 6, 2),
            (9, 8, 3),
            (7, 9, 4),
            (9, 13, 5),
            (6, 14, 0),
        )
        keys = (
            (2, 2, 0),
            (3, 4, 1),
            (8, 8, 2),
            (12, 6, 3),
            (2, 13, 4),
            (8, 11, 5),
        )

        for x, y, color_idx in locked_doors:
            grid.set(x, y, Door(COLOR_NAMES[color_idx], is_locked=True))
        for x, y, color_idx in balls:
            grid.set(x, y, Ball(COLOR_NAMES[color_idx]))
        for x, y, color_idx in keys:
            grid.set(x, y, Key(COLOR_NAMES[color_idx]))

        # Goal one cell in from the bottom-right corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Use the configured start pose if there is one, otherwise let
        # place_agent pick.
        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "grand mission"

# modified GymMDP
class GymMDP(MDP):
    ''' Class for Gym MDPs '''

    def __init__(self, env_name='CartPole-v0', render=False, render_every_n_episodes=0, wrapper=None, seed=None, **kwargs):
        '''
        Args:
            env_name (str): Environment id handed to gym.make when no ready-made env is supplied.
            render (bool): If True, renders the screen every time step.
            render_every_n_episodes (int): @render must be True, then renders the screen every n episodes.
            wrapper (callable): Optional. Called with the env; its return value is used as the env.
            seed (int): Optional. Used for the initial env reset and reused by later reset() calls.
            **kwargs: If 'customized_env' is present its value is used as the env and the
                other kwargs are unused; otherwise all kwargs are forwarded to gym.make.
        '''
        self.render_every_n_episodes = render_every_n_episodes
        self.episode = 0
        self.env_name = env_name
        if 'customized_env' in kwargs:
            self.env = kwargs['customized_env']
        else:
            self.env = gym.make(env_name, **kwargs)
        if wrapper:
            self.env = wrapper(self.env)
        self.render = render
        # Defined up front so _reward_func has a value even before the first
        # transition (it used to raise AttributeError in that case).
        self.prev_reward = 0
        # Stored as given; None means "unseeded". 0 is a legitimate seed, so all
        # checks compare against None rather than relying on truthiness.
        self.env_seed = seed
        init_state = GymState(self._reset_env(seed))
        MDP.__init__(self, range(self.env.action_space.n), self._transition_func, self._reward_func, init_state=init_state)

    def _reset_env(self, seed):
        ''' Reset the env, seeding it only when @seed is not None; returns the observation. '''
        if seed is None:
            return self.env.reset()[0]
        return self.env.reset(seed=seed)[0]

    def get_parameters(self):
        '''
        Returns:
            (dict) key=param_name (str) --> val=param_val (object).
        '''
        param_dict = defaultdict(int)
        param_dict["env_name"] = self.env_name

        return param_dict

    def _reward_func(self, state, action, next_state):
        '''
        Args:
            state: unused.
            action: unused.
            next_state: unused.

        Returns
            The reward recorded by the most recent _transition_func call
            (0 if no transition has happened yet).
        '''
        return self.prev_reward

    def _transition_func(self, state, action):
        '''
        Args:
            state: unused; the env keeps its own state.
            action: passed straight to env.step.

        Returns
            (GymState) built from the observation and terminal flag of the step.
        '''
        # The truncated flag and info dict returned by the env are not used.
        obs, reward, is_terminal, truncated, info = self.env.step(action)

        if self.render and (self.render_every_n_episodes == 0 or self.episode % self.render_every_n_episodes == 0):
            self.env.render()

        # Stash the reward so _reward_func can report it for this transition.
        self.prev_reward = reward
        self.next_state = GymState(obs, is_terminal=is_terminal)

        return self.next_state

    def reset(self, seed=None, options=None):
        '''
        Args:
            seed (int): Optional. Takes precedence over the seed given at construction.
            options: accepted but not used.

        Returns
            (GymState) the new initial state.
        '''
        if seed is None:
            # Fall back to the construction-time seed (which may itself be None).
            seed = self.env_seed
        elif self.env_seed is not None:
            print("Got a reset seed but the env already has a seed. Going with reset seed.")
        # NOTE(review): only init_state is refreshed here; presumably the MDP base
        # class tracks a separate current state -- confirm whether it needs
        # resetting as well.
        self.init_state = GymState(self._reset_env(seed))
        self.episode += 1
        return self.init_state

    def __str__(self):
        return "gym-" + str(self.env_name)

def main():
    """Build the easy level inside a GymMDP, then drop into the debugger."""
    # The level can optionally be wrapped first, e.g. FullyObsObjWrapper(level).
    env = GymMDP(
        env_name="nl2rlang",
        customized_env=TestEnvEasy(render_mode="human"),
    )
    # `env` is left in scope for interactive inspection.
    breakpoint()


if __name__ == "__main__":
    main()
