import math
import operator
from functools import reduce
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.core import ObservationWrapper, ObsType, ActType, WrapperObsType, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, IDX_TO_COLOR, STATE_TO_IDX, OBJECT_TO_IDX
from minigrid.core.world_object import Goal
from minigrid.core.grid import Grid


# Object-type name -> integer index for the object kinds listed here.
# NOTE(review): presumably a subset of minigrid's OBJECT_TO_IDX restricted to
# the objects this module tracks -- confirm against minigrid.core.constants.
OBJECT_TO_IDX_TRACK = {
    "wall": 2,
    "door": 4,
    "key": 5,
    "ball": 6,
    "box": 7,
    "goal": 8,
    "lava": 9,
    "agent": 10,
}

# Reverse lookup (integer index -> object-type name), derived from the table
# above so the two mappings cannot drift apart.
IDX_TO_OBJECT_TRACK = {
    index: name for name, index in OBJECT_TO_IDX_TRACK.items()
}


class FullyObsObjWrapper(ObservationWrapper):
    """Observation wrapper exposing the full grid, the carried object and the agent pose.

    The incoming observation is discarded; every returned value is read from
    the unwrapped environment's current state.

    NOTE(review): ``observation_space`` is declared as a ``spaces.Dict`` but
    ``observation`` returns a plain 3-tuple, so the two do not agree. An
    earlier dict-shaped return (keys "image", "carrying", "agent", "grid")
    appears to have been replaced by the tuple -- confirm which one consumers
    expect.
    """

    def __init__(self, env):
        """Wrap ``env`` and redeclare the ``"image"`` entry of its observation space.

        Args:
            env: Environment to wrap. Must expose ``width`` and ``height``
                attributes and a ``Dict`` observation space.
        """
        super().__init__(env)

        # One 3-channel uint8 cell per grid position.
        grid_shape = (self.env.width, self.env.height, 3)
        full_grid_space = spaces.Box(low=0, high=255, shape=grid_shape, dtype="uint8")

        # TODO: Augment this observation space to include a variable-size
        # dictionary of objects if possible.

        # Keep every existing entry of the wrapped space; only "image" changes.
        updated_spaces = dict(self.observation_space.spaces)
        updated_spaces["image"] = full_grid_space
        self.observation_space = spaces.Dict(updated_spaces)

    def observation(self, obs):
        """Build the observation from the unwrapped environment's state.

        Args:
            obs: Observation produced by the wrapped environment (unused).

        Returns:
            A 3-tuple ``(grid, carrying, agent)`` where ``grid`` is
            ``env.grid.encode()``, ``carrying`` is the carried object's
            ``encode()`` result or ``(1, 0, 0)`` when nothing is carried, and
            ``agent`` is ``(x, y, direction)``.
        """
        base_env = self.unwrapped
        grid_encoding = base_env.grid.encode()

        held_object = base_env.carrying
        if held_object is None:
            # Placeholder triple for empty hands.
            # NOTE(review): presumably minigrid's "empty" object encoding -- confirm.
            carrying_encoding = (1, 0, 0)
        else:
            carrying_encoding = held_object.encode()

        agent_state = (base_env.agent_pos[0], base_env.agent_pos[1], base_env.agent_dir)
        return grid_encoding, carrying_encoding, agent_state


class SkillsAndObjWrapper(Wrapper):
    """Wrapper adding open-loop skills (macro-actions) on top of primitive actions.

    Any action listed in ``env_actions`` is forwarded to the wrapped
    environment unchanged. Any other action is treated as a key into
    ``skill_dict``, whose value is a callable taking the current observation
    tuple (see ``observation``) and returning the next primitive action, or
    ``None`` once the skill has finished.

    Observations are read from the unwrapped environment's state rather than
    from what the wrapped environment returns.

    NOTE(review): ``observation_space`` is declared as a ``spaces.Dict`` but
    ``observation`` returns a plain 3-tuple, so the two do not agree --
    confirm which one consumers expect.
    """

    def __init__(self, env, env_actions=None):
        """Wrap ``env`` and set up the (initially empty) skill table.

        Args:
            env: Environment to wrap. Must expose ``width`` and ``height``
                attributes and a ``Dict`` observation space.
            env_actions: Container of primitive action ids passed straight to
                the wrapped environment. Defaults to ``[0, 1, 2, 3, 4, 5]``.
                A fresh list is created per instance when omitted, so
                instances never share (and cannot accidentally mutate) a
                common default.
        """
        # Let's assume we have some open-loop skills given to us.
        # Some assumptions:
        # No dying
        super().__init__(env)

        ############ Observation Space
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )
        #############################

        ############ Action Space
        if env_actions is None:
            env_actions = [0, 1, 2, 3, 4, 5]
        self.env_actions = env_actions
        # Skill id -> callable(observation) returning a primitive action or None.
        self.skill_dict = dict()

        ######## For keeping track of the state in case a macro-action does nothing
        self.prev_step = None

    def step(self, pre_action):
        """Execute a primitive action or run a skill to completion.

        For a skill, primitive actions are requested from the skill and
        executed until it returns ``None`` or the episode ends. Only the
        reward and flags of the last executed primitive step are returned;
        rewards of intermediate steps are not accumulated. If the skill
        produces no action at all, the most recent stored transition is
        returned again.

        Args:
            pre_action: A primitive action id from ``env_actions`` or a key
                of ``skill_dict``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            KeyError: If ``pre_action`` is neither a primitive action nor a
                registered skill.
        """
        action, is_skill = self.action(pre_action)
        if not is_skill:
            self.prev_step = self.env.step(action)
            observation, reward, terminated, truncated, info = self.prev_step
            return self.observation(observation), reward, terminated, truncated, info

        if action is None:
            # The skill did nothing: report the most recent transition.
            observation, reward, terminated, truncated, info = self.prev_step
        while action is not None:
            self.prev_step = self.env.step(action)
            observation, reward, terminated, truncated, info = self.prev_step
            if terminated or truncated:
                # Stepping a finished episode is undefined and would overwrite
                # the terminal reward/flags with those of a later step.
                break
            action, is_skill = self.action(pre_action)
        return self.observation(observation), reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment and seed ``prev_step`` with the result.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Forwarded to the wrapped environment's ``reset``.

        Returns:
            ``(observation, info)``.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        # Stored as a zero-reward, non-terminal transition so a skill that
        # does nothing right after reset still has something to return.
        self.prev_step = (obs, 0, False, False, info)
        return self.observation(obs), info

    def observation(self, obs):
        """Build the observation from the unwrapped environment's state.

        Args:
            obs: Observation produced by the wrapped environment (unused).

        Returns:
            A 3-tuple ``(grid, carrying, agent)`` where ``grid`` is
            ``env.grid.encode()``, ``carrying`` is the carried object's
            ``encode()`` result or ``(1, 0, 0)`` when nothing is carried, and
            ``agent`` is ``(x, y, direction)``.
        """
        env = self.unwrapped
        carrying = env.carrying.encode() if env.carrying is not None else (1, 0, 0)
        agent = (env.agent_pos[0], env.agent_pos[1], env.agent_dir)
        return env.grid.encode(), carrying, agent

    def action(self, action):
        """Resolve ``action`` into the next primitive action.

        Args:
            action: A primitive action id or a key of ``skill_dict``.

        Returns:
            ``(primitive_action, is_skill)``. For a skill,
            ``primitive_action`` is whatever the skill callable returns for
            the current observation (``None`` when it is finished).

        Raises:
            KeyError: If ``action`` is not primitive and not in ``skill_dict``.
        """
        if action in self.env_actions:
            return action, False
        skill = self.skill_dict[action]
        return skill(self.observation(self.prev_step[0])), True
