import gymnasium as gym
import numpy as np
import pygame
import cv2
from overcooked_ai_py.mdp.actions import Action
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
import random


class SingleAgentOvercooked(gym.Env):
    """
    A Gymnasium wrapper for a single-agent Overcooked environment.

    The wrapper drives an ``OvercookedEnv`` with one agent, tracks cumulative
    sparse/shaped rewards, and can render the current state to an image via
    ``get_frame``. Optionally, every (re)build of the environment can shuffle
    the boundary tiles of the layout, pick a random recipe, and force the
    dispensers to match the chosen recipe.

    Attributes:
        terrain_mtx: Layout template (list of rows of tile characters). It is
            shuffled in place when ``random_layout`` is enabled. The terrain
            actually in play is ``base_env.mdp.terrain_mtx``, which can differ
            from this template when ``force_ingredients`` rewrites dispensers.
        start_player_positions: Start positions handed to the MDP.
        base_env: The underlying ``OvercookedEnv``.
        possible_all_orders: Recipes read from the default environment's
            initial state; random recipes are drawn from this collection.
        score: Reward of the most recent step (sparse + factor * shaped).
        shaped_reward: Cumulative shaped reward of the current episode.
        sparse_reward: Cumulative sparse reward of the current episode.
        reward: Cumulative sparse reward + factor * cumulative shaped reward.
    """

    def __init__(self, layout_name="cramped_room", horizon=400, reward_shaping_horizon=1000000,
                 random_layout=False, random_recipe=False, force_ingredients=False):
        """
        Build the environment for ``layout_name``.

        Args:
            layout_name: Name of the Overcooked layout to load.
            horizon: Episode length passed to ``OvercookedEnv.from_mdp``.
            reward_shaping_horizon: Timestep at which the shaping factor
                reaches zero when ``anneal_reward_shaping_factor`` is used.
            random_layout: Shuffle boundary tiles and the agent start position
                on construction and on every ``reset``.
            random_recipe: Pick a random recipe on construction and on every
                ``reset``.
            force_ingredients: On ``reset``, rewrite onion/tomato dispensers
                to match the chosen recipe (only has an effect together with
                ``random_recipe``).
        """
        super().__init__()
        self.layout_name = layout_name
        self.horizon = horizon
        self.random_layout = random_layout
        self.random_recipe = random_recipe
        self.force_ingredients = force_ingredients

        # Load the unmodified layout first: it provides the terrain template
        # and the list of recipes that randomization draws from.
        mdp = OvercookedGridworld.from_layout_name(layout_name, old_dynamics=True)
        assert isinstance(mdp, OvercookedGridworld), f"Expected OvercookedGridworld, got {type(mdp)}"

        # Keep a private copy so shuffling the template never mutates the
        # terrain of an MDP that is still alive.
        self.terrain_mtx = [list(row) for row in mdp.terrain_mtx]
        self.start_player_positions = mdp.start_player_positions

        # the default env
        self.base_env = OvercookedEnv.from_mdp(mdp, horizon=self.horizon)
        # e.g. [('onion', 'onion', 'onion'), ('tomato', 'tomato', 'tomato')]
        self.possible_all_orders = self.base_env.state.all_orders

        # customized requirement
        custom_params = {}
        if self.random_layout:
            custom_params.update(self._randomized_layout_params())
        if self.random_recipe:
            _, recipe_params = self._randomized_recipe_params()
            custom_params.update(recipe_params)

        # If either condition is true, regenerate the MDP and base_env
        if self.random_layout or self.random_recipe:
            self._rebuild_env(custom_params)

        # Define discrete action space
        self.action_space = gym.spaces.Discrete(len(Action.ALL_ACTIONS))

        self.observation_space = gym.spaces.Dict({
            "state": gym.spaces.Discrete(1),  # Placeholder for state
            "info": gym.spaces.Dict({
                "shaped_r_by_agent": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.float32),
                "sparse_r_by_agent": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.float32)
            })
        })
        # State visualizer
        self.visualizer = StateVisualizer(is_rendering_hud=True)
        # Initialize state
        self.state = None
        self.info = None

        # for reward score
        self.score = 0  # current frame reward
        self.shaped_reward = 0  # cumulative shaped reward
        self.sparse_reward = 0  # cumulative sparse reward
        self.reward = 0  # cumulative sparse reward + decay * shaped reward

        # for reward shaping
        self._initial_reward_shaping_factor = 1
        self.reward_shaping_horizon = reward_shaping_horizon
        self.reward_shaping_factor = self._initial_reward_shaping_factor

    # ------------------------------------------------------------------
    # Environment (re)construction helpers
    # ------------------------------------------------------------------
    def _randomized_layout_params(self):
        """Shuffle the layout template and return the matching MDP overrides."""
        self.shuffle_non_empty_tiles_in_place()
        return {
            # Hand the MDP its own copy: the template is shuffled in place on
            # the next reset and must not alias a live MDP's terrain.
            "terrain": [list(row) for row in self.terrain_mtx],
            "start_player_positions": self.start_player_positions,
        }

    def _randomized_recipe_params(self):
        """Pick a random recipe; return ``(chosen, mdp_overrides)``."""
        assert self.possible_all_orders, "No available recipes to choose from"
        chosen = random.choice(self.possible_all_orders)
        overrides = {
            "start_all_orders": [{"ingredients": list(chosen)}],
            "possible_all_orders": self.possible_all_orders,
        }
        return chosen, overrides

    def _terrain_forced_to(self, chosen):
        """
        Return a copy of the layout template with dispensers matched to ``chosen``.

        If the recipe contains onion, tomato dispensers ('T') become onion
        dispensers ('O'); if it contains tomato, 'O' becomes 'T'. A recipe
        containing both therefore swaps the two dispenser kinds. The template
        itself is left untouched so the original dispenser types are still
        available for later episodes.
        """
        replacement = {}
        if "onion" in chosen:
            replacement['T'] = 'O'
        if "tomato" in chosen:
            replacement['O'] = 'T'
        return [[replacement.get(tile, tile) for tile in row] for row in self.terrain_mtx]

    def _rebuild_env(self, custom_params):
        """Recreate ``base_env`` from the layout with ``custom_params`` applied."""
        mdp = OvercookedGridworld.from_layout_name(
            self.layout_name,
            old_dynamics=True,
            **custom_params
        )
        self.base_env = OvercookedEnv.from_mdp(mdp, horizon=self.horizon)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_sparse_reward(self):
        """Return the cumulative sparse reward of the current episode."""
        return self.sparse_reward

    def shuffle_non_empty_tiles_in_place(self):
        """
        Randomly permute the boundary tiles of ``terrain_mtx`` in place and
        pick a random start position for the agent.

        Only non-empty tiles on the boundary (excluding the four corners and
        a few layout-specific pinned positions) take part in the shuffle.
        The start position is drawn from all non-corner empty tiles and is
        stored in ``start_player_positions`` as a one-element list of
        ``(x, y)`` tuples.

        Raises:
            AssertionError: If the layout has no empty tile for the agent.
        """
        height = len(self.terrain_mtx)
        width = len(self.terrain_mtx[0])

        # All coordinates below are (y, x), i.e. (row, column).
        corners = {(0, 0), (0, width - 1), (height - 1, 0), (height - 1, width - 1)}
        if self.layout_name.startswith("m_room"):
            pinned = {(0, 2)}
        elif self.layout_name.startswith("forced_room"):
            pinned = {(0, 2), (4, 2)}
        else:
            pinned = set()

        shuffle_slots = []
        empty_positions = []
        for y, row in enumerate(self.terrain_mtx):
            for x, tile in enumerate(row):
                if (y, x) in corners:
                    continue
                if tile == ' ':
                    empty_positions.append((x, y))  # Overcooked expects (x, y)
                    continue
                on_boundary = y == 0 or y == height - 1 or x == 0 or x == width - 1
                if on_boundary and (y, x) not in pinned:
                    shuffle_slots.append((y, x))

        # Validate before mutating so a failure leaves the terrain untouched.
        assert empty_positions, "No empty space available for agent start"

        # Shuffle non-corner boundary non-empty values
        shuffled_tiles = [self.terrain_mtx[y][x] for y, x in shuffle_slots]
        random.shuffle(shuffled_tiles)
        for (y, x), tile in zip(shuffle_slots, shuffled_tiles):
            self.terrain_mtx[y][x] = tile

        # Random empty tile (not in corners) for agent start
        self.start_player_positions = [random.choice(empty_positions)]

    def step(self, action):
        """
        Advance the environment by one action.

        Args:
            action: Index into ``Action.INDEX_TO_ACTION``.

        Returns:
            A 5-tuple ``(observation, score, done, False, info)`` where
            ``observation`` is ``{"state": <env state>, "info": info}``,
            ``score`` is this step's sparse reward plus the shaping factor
            times this step's shaped reward, ``done`` is the flag reported by
            the underlying environment, the truncation flag is always
            ``False``, and ``info`` holds the cumulative shaped and sparse
            rewards of the episode.
        """
        action_str = Action.INDEX_TO_ACTION[action]  # Convert action index to action string
        self.state, _, done, env_info = self.base_env.step((action_str,))

        step_shaped = env_info["shaped_r_by_agent"][0]
        step_sparse = env_info["sparse_r_by_agent"][0]

        self.shaped_reward += step_shaped
        self.sparse_reward += step_sparse
        self.reward = self.sparse_reward + self.reward_shaping_factor * self.shaped_reward
        self.score = step_sparse + self.reward_shaping_factor * step_shaped

        info = {
            "shaped_r_by_agent": self.shaped_reward,
            "sparse_r_by_agent": self.sparse_reward,
        }
        self.info = info

        return {
            "state": self.state,
            "info": info
        }, self.score, done, False, info

    def get_state(self):
        """Return the underlying ``OvercookedEnv`` (not just its state object)."""
        return self.base_env

    def get_frame(self, color_mode="BGR", reward_mode="merged"):
        """
        Render the current state to an image array.

        Args:
            color_mode: ``"BGR"`` (default) converts the frame to BGR channel
                order for OpenCV; any other value leaves it in RGB.
            reward_mode: ``"merged"`` shows the combined cumulative reward in
                the HUD; any other value shows the cumulative sparse reward.

        Returns:
            The rendered frame as a NumPy array.
        """
        hud_score = self.reward if reward_mode == "merged" else self.sparse_reward
        pygame_surface = self.visualizer.render_state(
            state=self.base_env.state,
            grid=self.base_env.mdp.terrain_mtx,
            hud_data=StateVisualizer.default_hud_data(self.base_env.state, score=hud_score)
        )

        # Convert Pygame surface to NumPy array (RGB)
        buffer = pygame.surfarray.array3d(pygame_surface)
        image = np.rot90(buffer, 1)  # Rotate correctly
        image = np.flip(image, axis=0)  # Flip for correct orientation
        if color_mode == "BGR":
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # Convert to BGR

        return image

    def close(self):
        """
        Cleanup the environment.
        """
        cv2.destroyAllWindows()

    def reset(self, seed=None, options=None):
        """
        Reset the environment to the initial state.

        When any of ``random_layout``, ``random_recipe`` or
        ``force_ingredients`` is enabled, the underlying MDP and environment
        are rebuilt with a freshly randomized configuration.

        Args:
            seed: Optional seed. When given, it seeds both the Gymnasium RNG
                and Python's global ``random`` module, which drives the
                layout and recipe randomization.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where ``observation`` is
            ``{"state": <env state>, "info": <zeroed cumulative rewards>}``
            and ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            # Layout/recipe randomization uses the global `random` module.
            random.seed(seed)

        custom_params = {}

        # Randomize layout if requested
        if self.random_layout:
            custom_params.update(self._randomized_layout_params())

        chosen = None
        # Randomize recipe if requested
        if self.random_recipe:
            chosen, recipe_params = self._randomized_recipe_params()
            custom_params.update(recipe_params)

        # Match dispensers to the chosen recipe. This works on a copy so the
        # template keeps its original dispenser types for later episodes.
        if self.force_ingredients and chosen:
            custom_params["terrain"] = self._terrain_forced_to(chosen)  # Overwrite

        # If any condition is true, regenerate the MDP and base_env
        if self.random_layout or self.random_recipe or self.force_ingredients:
            self._rebuild_env(custom_params)

        # Reset environment state
        self.base_env.reset()
        self.state = self.base_env.state
        self.score = 0
        self.reward = 0
        self.shaped_reward = 0
        self.sparse_reward = 0
        self.info = {
            "shaped_r_by_agent": self.shaped_reward,
            "sparse_r_by_agent": self.sparse_reward
        }

        return {
            "state": self.state,
            "info": self.info
        }, {}

    def _get_observation(self):
        """
        Return the rendered frame of the current state (see ``get_frame``).
        """
        return self.get_frame()

    def _anneal(self, start_v, curr_t, end_t, end_v=0, start_t=0):
        """
        Linearly anneal from start_v at time start_t to end_v at time end_t.

        Past ``end_t`` the result stays at ``end_v``. If ``end_t`` equals
        ``start_t`` no annealing is possible and ``start_v`` is returned.
        """
        if end_t == start_t:
            return start_v
        fraction = max(1 - (curr_t - start_t) / (end_t - start_t), 0)
        return fraction * start_v + (1 - fraction) * end_v

    def anneal_reward_shaping_factor(self, timesteps: int):
        """
        Update the reward shaping factor using the current training timesteps.

        The factor decays linearly from its initial value to 0 over
        ``reward_shaping_horizon`` timesteps.
        """
        new_factor = self._anneal(
            self._initial_reward_shaping_factor,
            timesteps,
            self.reward_shaping_horizon,
            end_v=0  # You can change this if you want a nonzero floor.
        )
        self.set_reward_shaping_factor(new_factor)

    def set_reward_shaping_factor(self, factor: float):
        """Set the weight applied to shaped rewards."""
        self.reward_shaping_factor = factor


def main():
    """
    Run a random-action demo of ``SingleAgentOvercooked`` in an OpenCV window.

    Frames are shown at roughly two per second; press 'q' in the window to
    quit. Episodes are reset automatically when they finish.
    """
    env = SingleAgentOvercooked(layout_name="cramped_room_ot_co", horizon=20, random_layout=True, random_recipe=True)
    window_name = "Overcooked-AI"

    try:
        # Open a window to display frames
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

        # Gymnasium environments must be reset before the first step.
        env.reset()

        # Upper bound on the number of random steps; 'q' exits earlier.
        for _ in range(1000000):
            action = env.action_space.sample()  # Sample random action
            _, _, done, _, _ = env.step(action)

            # Show the current frame
            cv2.imshow(window_name, env.get_frame())

            # waitKey paces the loop (500 ms per frame) and polls the keyboard.
            # Exit if 'q' is pressed
            if cv2.waitKey(500) & 0xFF == ord('q'):
                break

            if done:
                print("Episode finished, resetting environment.")
                env.reset()
    finally:
        # env.close() destroys all OpenCV windows, even if the loop raised.
        env.close()


# Run the interactive demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()

