import numpy as np
import torch
from minigrid.core.world_object import Lava, Wall
from minigrid.minigrid_env import MiniGridEnv
from tensordict import TensorDict


def get_wrapper_chain(env):
    """Return the wrapper layers around ``env``, ordered outermost to innermost.

    Any object exposing an ``env`` attribute is treated as a wrapper layer.
    The walk stops at the first object without that attribute; that object
    (the innermost environment) is not included in the result.

    Args:
        env: The (possibly wrapped) environment to inspect.

    Returns:
        A list of wrapper objects, starting with ``env`` itself when it is a
        wrapper. The list is empty when ``env`` has no ``env`` attribute.
    """

    def _iter_layers(layer):
        # Follow the ``env`` links inward, yielding each wrapping layer.
        while hasattr(layer, "env"):
            yield layer
            layer = layer.env

    return list(_iter_layers(env))


def generate_all_states(env: MiniGridEnv, device: torch.device) -> TensorDict:
    """Generate a tensordict containing the observation of every agent state.

    The environment is reset once. The agent is then placed on every grid
    cell that does not hold a ``Wall`` or ``Lava`` tile, in each of the four
    orientations. For every placement the base environment's ``gen_obs()`` is
    called and the result is passed through the ``observation`` method of each
    wrapper that defines one, from innermost to outermost.

    States are enumerated by x, then y, then orientation.

    Side effects:
        ``env.reset()`` is called. The agent position and direction that the
        base environment holds right after that reset are restored before
        returning, so the enumeration does not leave the agent on the last
        visited cell.

    Args:
        env: The MiniGrid environment, possibly with wrappers.
        device: The device to store the data on.

    Returns:
        A TensorDict whose batch size is the number of enumerated states and
        whose ``"observation"`` entry is a float32 tensor of the stacked
        observations. If no state was enumerated, an empty TensorDict with
        batch size ``[0]`` (and no ``"observation"`` entry) is returned.
    """
    env.reset()
    base_env = env.unwrapped

    num_orientations = 4  # 0: right, 1: down, 2: left, 3: up

    # Select the observation-transforming wrappers once, innermost first, so
    # the filter is not re-evaluated for every single state.
    obs_wrappers = [
        wrapper
        for wrapper in reversed(get_wrapper_chain(env))
        if wrapper is not base_env and hasattr(wrapper, "observation")
    ]

    # Remember the post-reset agent pose; the loop below overwrites it.
    saved_pos = base_env.agent_pos
    saved_dir = base_env.agent_dir

    observations = []
    try:
        for x in range(base_env.width):
            for y in range(base_env.height):
                # The tile does not depend on the orientation, so check it
                # once per cell. Wall and lava cells are skipped.
                if isinstance(base_env.grid.get(x, y), (Wall, Lava)):
                    continue

                for orientation in range(num_orientations):
                    # Set agent state in the base environment.
                    base_env.agent_pos = (x, y)
                    base_env.agent_dir = orientation

                    # Generate the raw observation, then apply the wrappers'
                    # observation transformations in order.
                    obs = base_env.gen_obs()
                    for wrapper in obs_wrappers:
                        obs = wrapper.observation(obs)

                    # The final observation might be a dict or just the image.
                    if isinstance(obs, dict) and "image" in obs:
                        observations.append(obs["image"])
                    else:
                        # Assume the observation is the image itself.
                        observations.append(obs)
    finally:
        # Put the agent back where the reset left it, even on failure.
        base_env.agent_pos = saved_pos
        base_env.agent_dir = saved_dir

    if not observations:
        # Handle case with no valid states.
        return TensorDict({}, batch_size=[0], device=device)

    all_obs = np.stack(observations)
    data = {"observation": torch.tensor(all_obs, dtype=torch.float32, device=device)}

    return TensorDict(data, batch_size=[len(observations)], device=device)
