import numpy as np
import os
import cv2
import torch
import gymnasium as gym
from gymnasium import ObservationWrapper, spaces
from torchrl.envs import GymWrapper, TransformedEnv, Compose
from torchrl.envs.transforms import StepCounter, ToTensorImage
from tensordict import TensorDict

class FullyObservableOvercookedWrapper(ObservationWrapper):
    """
    Fully observable Overcooked environment using a structured feature encoding for a single agent.

    Every observation is a dict with three entries:

    * ``"image"``  -- the first output of
      ``mdp.lossless_state_encoding_single_agent(state)``, cast to ``float32``;
    * ``"recipe"`` -- the second output of that call, cast to ``int32``;
    * ``"info"``   -- the innermost environment's ``info`` attribute, passed
      through unchanged.

    The observation produced by the inner environment is ignored; the encoding
    is always recomputed from ``base_env.state``.

    All environment attributes are read via ``self.unwrapped`` rather than
    through implicit wrapper attribute forwarding. That forwarding is
    deprecated in Gymnasium 0.29 and removed in 1.0.
    """

    def __init__(self, env):
        super().__init__(env)
        mdp = self.unwrapped.base_env.mdp
        # Encode a start state once just to learn the observation shapes.
        dummy_state = mdp.get_standard_start_state()
        sample_obs, sample_recipe = mdp.lossless_state_encoding_single_agent(dummy_state)

        obs_shape = sample_obs.shape  # presumably (width, height, num_features) -- TODO confirm
        recipe_shape = sample_recipe.shape  # presumably (2,) -- TODO confirm

        self.observation_space = spaces.Dict({
            "image": spaces.Box(low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32),
            "recipe": spaces.Box(low=0, high=1, shape=recipe_shape, dtype=np.int32),
            # NOTE(review): assumes the env's ``info`` attribute is a dict with these
            # two scalar entries -- verify against SingleAgentOvercooked.
            "info": gym.spaces.Dict({
                "shaped_r_by_agent": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.float32),
                "sparse_r_by_agent": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.float32)
            })
        })

    def observation(self, obs):
        """Return the dict observation for the current ``base_env.state``.

        ``obs`` (the inner environment's observation) is not used. If the base
        environment has no state yet, it is reset first.
        """
        env = self.unwrapped
        if env.base_env.state is None:
            env.base_env.reset()

        # Convert state to lossless encoding
        state_encoding, curr_recipe = env.base_env.mdp.lossless_state_encoding_single_agent(env.base_env.state)

        # Internal sanity checks on the encoder output (stripped under ``python -O``).
        assert state_encoding is not None, "state_encoding is None"
        assert curr_recipe is not None, "curr_recipe is None"

        return {
            "image": state_encoding.astype(np.float32),
            "recipe": curr_recipe.astype(np.int32),
            # Read from the innermost env explicitly; ``self.info`` only resolved
            # via wrapper ``__getattr__`` forwarding, which Gymnasium 1.0 removed.
            "info": env.info
        }


class EnvMethodWrapper(gym.Wrapper):
    """Gym wrapper that can invoke any method of the wrapped env by name."""

    def env_method(self, method_name, *args, **kwargs):
        """Call ``self.env.<method_name>(*args, **kwargs)`` and return its result.

        Raises ``AttributeError`` (from ``getattr``) when the wrapped env has no
        attribute called ``method_name``.
        """
        return getattr(self.env, method_name)(*args, **kwargs)


class ActionMasking_cook(ObservationWrapper):
    """Add an ``"action_mask"`` entry to a dict observation.

    The mask has one entry per discrete action, 1 (allowed) or 0 (masked), for
    player 0. Both rules look at the tile directly in front of the player:

    * moving in the facing direction is masked when that tile's terrain symbol
      is in ``BLOCKING_TERRAIN``;
    * ``interact`` is masked when that tile is empty floor (``" "``).

    Environment attributes are read via ``self.unwrapped`` instead of implicit
    wrapper attribute forwarding. That forwarding is deprecated in Gymnasium 0.29
    and removed in 1.0.
    """

    # Terrain symbols that stop the player from moving into the tile in front.
    # NOTE(review): non-floor symbols missing from this set (if the layout
    # uses any) will not have their forward move masked.
    BLOCKING_TERRAIN = frozenset({"O", "T", "D", "P", "S", "X", "M", "C"})

    def __init__(self, env):
        super().__init__(env)

        # The action mask sets a value for each action of either 0 (invalid) or 1 (valid).
        # NOTE(review): this Box defaults to float32 while ``observation`` emits
        # int8 -- confirm which dtype downstream consumers expect.
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            "action_mask": spaces.Box(0.0, 1.0, shape=(self.action_space.n,))
        })

    def get_terrain_type_at_pos(self, pos):
        """Return the terrain symbol at ``pos`` (an ``(x, y)`` pair) of the base env's MDP.

        The previous implementation read ``self.terrain_mtx``, which this
        wrapper never defines. It now delegates to the MDP's own lookup.
        """
        return self.unwrapped.base_env.mdp.get_terrain_type_at_pos(pos)

    def get_action_mask(self, state, player_idx):
        """Return a float array of length ``action_space.n`` with 1 for allowed and 0 for masked actions."""
        player = state.players[player_idx]
        action_mask = np.ones(self.action_space.n)
        orientation = player.orientation

        # Tile directly in front of the player: its position moved one step
        # along its facing direction. ``Action``/``Direction`` are the
        # module-level overcooked_ai imports.
        front_pos = Action.move_in_direction(player.position, orientation)
        front_terrain = self.get_terrain_type_at_pos(front_pos)

        # Block forward movement into a non-walkable tile.
        if front_terrain in self.BLOCKING_TERRAIN:
            action_mask[Direction.DIRECTION_TO_INDEX[orientation]] = 0

        # Nothing to interact with on an empty floor tile.
        if front_terrain == " ":
            action_mask[Action.ACTION_TO_INDEX["interact"]] = 0.0

        return action_mask

    def observation(self, obs):
        """Return ``obs`` extended with an int8 ``"action_mask"`` for player 0."""
        state = self.unwrapped.base_env.state
        action_mask = self.get_action_mask(state, player_idx=0)
        return {**obs, "action_mask": action_mask.astype(np.int8)}


import cv2
from overcooked_ai_py.mdp.actions import Direction, Action
# Human-play controls, one row per action:
# (keyboard key, overcooked action, integer action index 0-5, display label).
# The three lookup tables below are all derived from this single table.
_HUMAN_CONTROLS = (
    ("w", Direction.NORTH, 0, "↑"),         # Move up
    ("s", Direction.SOUTH, 1, "↓"),         # Move down
    ("d", Direction.EAST, 2, "→"),          # Move right
    ("a", Direction.WEST, 3, "←"),          # Move left
    (" ", Action.STAY, 4, "stay"),          # Stay in place
    ("e", Action.INTERACT, 5, "interact"),  # Interact
)
# Key code (as compared against cv2.waitKey in main) -> action.
KEY_TO_ACTION = {ord(key): action for key, action, _, _ in _HUMAN_CONTROLS}
# Action -> integer action index (0-5).
ACTION_TO_INT = {action: index for _, action, index, _ in _HUMAN_CONTROLS}
# Action -> human-readable name.
ACTION_TO_CHAR = {action: label for _, action, _, label in _HUMAN_CONTROLS}

def main():
    """Play the single-agent Overcooked environment interactively from the keyboard.

    Each loop iteration does the following:

    * shows the current frame in an OpenCV window and also writes it to
      ``frame_<layout>.png``;
    * waits for a keypress and maps it to an action index via ``KEY_TO_ACTION``
      and ``ACTION_TO_INT``;
    * steps the wrapped environment and prints the result.

    ``q`` quits. The environment is reset whenever an episode terminates or is
    truncated.
    """
    from single_agent_overcooked import SingleAgentOvercooked
    from wrappers import FullyObservableOvercookedWrapper  # Ensure correct import

    # Create the base environment
    layout_name = "cramped_room"
    base_env = SingleAgentOvercooked(layout_name=layout_name, horizon=10, random_layout=True, random_recipe=True)

    # Wrap with fully observable single-agent encoding
    wrapped_env = FullyObservableOvercookedWrapper(base_env)
    obs, _ = wrapped_env.reset()

    window_name = "Overcooked Human Play"
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    # try/finally so the OpenCV window is closed even if a step raises.
    try:
        while True:
            frame = base_env.get_frame()
            cv2.imshow(window_name, frame)
            cv2.imwrite(f"frame_{layout_name}.png", frame)

            key = cv2.waitKey(0)  # Wait for user input

            if key in KEY_TO_ACTION:
                action = KEY_TO_ACTION[key]  # Convert keypress to action
                action_int = ACTION_TO_INT[action]  # Convert action to integer (0-5)
                action_symbol = ACTION_TO_CHAR[action]  # Get human-readable symbol
            elif key == ord('q'):  # Quit the game
                print("Exiting game.")
                break
            else:
                print("Invalid key. Use W/A/S/D for movement, E for interact, and Space to stay.")
                continue

            obs, reward, done, truncated, info = wrapped_env.step(action_int)
            print(obs['recipe'])
            print(obs['image'].shape)
            print(f"reward: {reward} and info: {info}")
            print(f"Action: {action_symbol}")

            # Under the Gymnasium API an episode is over when it is either
            # terminated or truncated (e.g. by a time limit such as the horizon);
            # checking only ``done`` would keep stepping a finished episode.
            if done or truncated:
                print("Episode finished, resetting environment.")
                obs, _ = wrapped_env.reset()
    finally:
        cv2.destroyAllWindows()

# Start the interactive human-play loop only when this file is executed as a script.
if __name__ == "__main__":
    main()



