import math
import operator
from functools import reduce
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal


class OneHotPartialImage(ObservationWrapper):
    """Flattened one-hot encoding of the agent's partial view.

    The wrapped environment must expose a dict observation space with an
    ``"image"`` entry. Each cell of that image is read as three integer
    channels (object type, color, state). Every channel is one-hot encoded
    and the three codes are concatenated, giving ``num_bits`` values per
    cell. The resulting array is flattened into a 1-D ``float32`` vector.

    Args:
        env: The environment to apply the wrapper
        tile_size: Stored on the instance as ``tile_size``; it is not used
            by this wrapper when building the observation.
    """

    def __init__(self, env, tile_size: int = 8):
        """Compute the encoded layout and replace the observation space.

        Args:
            env: The environment to apply the wrapper
            tile_size: Stored on the instance; unused by the encoding.
        """
        super().__init__(env)

        self.tile_size = tile_size

        # Shape of the raw image observation produced by the wrapped env.
        obs_shape = env.observation_space["image"].shape
        self.obs_shape = obs_shape

        # Number of cells in the view.
        self.obs_size = obs_shape[0] * obs_shape[1]

        # Number of bits per cell: one slot per object type, color and state.
        self.num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        self.observation_space = spaces.Box(
            low=0,
            high=1.0,
            shape=(self.obs_size * self.num_bits,),
            dtype="float32",
        )

    def observation(self, obs) -> np.ndarray:
        """Return the flattened one-hot encoding of ``obs["image"]``.

        Args:
            obs: Observation dict containing an ``"image"`` array whose last
                axis holds (object type, color, state) indices.

        Returns:
            A 1-D ``float32`` array of length
            ``obs_shape[0] * obs_shape[1] * num_bits``.
        """
        img = np.asarray(obs["image"])
        # Widen to a native index type so adding the block offsets below
        # cannot wrap around in a narrow integer dtype.
        cells = img.astype(np.intp)

        out = np.zeros(
            (self.obs_shape[0], self.obs_shape[1], self.num_bits),
            dtype="float32",
        )

        # Start of the color and state blocks inside each cell's bit vector.
        color_offset = len(OBJECT_TO_IDX)
        state_offset = color_offset + len(COLOR_TO_IDX)

        # Broadcast row/column indices against the per-cell channel values
        # so every cell is written in one vectorised assignment.
        row_idx = np.arange(img.shape[0])[:, None]
        col_idx = np.arange(img.shape[1])[None, :]

        out[row_idx, col_idx, cells[..., 0]] = 1.0
        out[row_idx, col_idx, color_offset + cells[..., 1]] = 1.0
        out[row_idx, col_idx, state_offset + cells[..., 2]] = 1.0

        return out.flatten()
    

class OneHotFullImage(ObservationWrapper):
    """Flattened one-hot encoding of the full grid plus the agent direction.

    The observation is built from ``grid.encode()`` of the unwrapped
    environment rather than from the incoming observation. Each cell is read
    as three integer channels (object type, color, state). Every channel is
    one-hot encoded, followed by a one-hot block for the agent's direction,
    giving ``num_bits`` values per cell. The resulting array is flattened
    into a 1-D ``float32`` vector.

    The direction block is identical in every cell: it encodes
    ``agent_dir`` only, not where the agent is.
    NOTE(review): this wrapper does not itself mark the agent's position;
    confirm whether ``grid.encode()`` includes it if that is required.

    Args:
        env: The environment to apply the wrapper
        tile_size: Stored on the instance as ``tile_size``; it is not used
            by this wrapper when building the observation.
    """

    # Size of the one-hot block appended for ``agent_dir``.
    _NUM_DIRECTIONS = 4

    def __init__(self, env, tile_size: int = 8):
        """Compute the encoded layout and replace the observation space.

        Args:
            env: The environment to apply the wrapper
            tile_size: Stored on the instance; unused by the encoding.
        """
        super().__init__(env)

        self.tile_size = tile_size

        # Read the grid from the base environment: ``env`` may itself be a
        # wrapper, and wrappers are not guaranteed to forward arbitrary
        # attributes such as ``grid`` to the environment they wrap.
        grid = env.unwrapped.grid
        obs_shape = (grid.width, grid.height, 3)
        self.obs_shape = obs_shape

        # Number of cells in the grid.
        self.obs_size = obs_shape[0] * obs_shape[1]

        # Number of bits per cell: object type, color, state and direction.
        self.num_bits = (
            len(OBJECT_TO_IDX)
            + len(COLOR_TO_IDX)
            + len(STATE_TO_IDX)
            + self._NUM_DIRECTIONS
        )

        self.observation_space = spaces.Box(
            low=0,
            high=1.0,
            shape=(self.obs_size * self.num_bits,),
            dtype="float32",
        )

    def observation(self, obs) -> np.ndarray:
        """Return the flattened one-hot encoding of the whole grid.

        Args:
            obs: Observation from the wrapped environment. It is ignored;
                the grid is read directly from the unwrapped environment.

        Returns:
            A 1-D ``float32`` array of length
            ``obs_shape[0] * obs_shape[1] * num_bits``.
        """
        env = self.unwrapped
        img = np.asarray(env.grid.encode())
        # Widen to a native index type so adding the block offsets below
        # cannot wrap around in a narrow integer dtype.
        cells = img.astype(np.intp)
        rows, cols = img.shape[0], img.shape[1]

        out = np.zeros(
            (self.obs_shape[0], self.obs_shape[1], self.num_bits),
            dtype="float32",
        )

        # Start of each block inside a cell's bit vector.
        color_offset = len(OBJECT_TO_IDX)
        state_offset = color_offset + len(COLOR_TO_IDX)
        dir_offset = state_offset + len(STATE_TO_IDX)

        # Broadcast row/column indices against the per-cell channel values
        # so every cell is written in one vectorised assignment.
        row_idx = np.arange(rows)[:, None]
        col_idx = np.arange(cols)[None, :]

        out[row_idx, col_idx, cells[..., 0]] = 1.0
        out[row_idx, col_idx, color_offset + cells[..., 1]] = 1.0
        out[row_idx, col_idx, state_offset + cells[..., 2]] = 1.0

        # The agent direction is the same for all cells, so a single slice
        # assignment over the encoded region covers every one of them.
        out[:rows, :cols, dir_offset + env.agent_dir] = 1.0

        return out.flatten()