import json
import os
import pickle
import random
from abc import ABC, abstractmethod
from queue import Empty, Full, LifoQueue, Queue
from threading import Lock, Thread
from time import time
from multiprocessing import Manager

import ray
from utils import DOCKER_VOLUME, create_dirs

# from human_aware_rl.rllib.rllib import load_agent
from overcooked_ai_py.mdp.actions import Action, Direction
# from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld, ObjectState
from overcooked_ai_py.planning.planners_simplified import (
    NO_COUNTERS_PARAMS,
    LLMActionManager
)
import tensorflow as tf 
import numpy as np 
from human_aware_rl.baselines_utils import get_agent_from_saved_model
# from human_aware_rl.rllib.rllib import load_agent
# 
# from human_aware_rl.rllib.rllib import (
#     RlLibAgent,
# )
# from human_aware_rl.imitation.behavior_cloning_tf2 import (
#     get_bc_params,
#     load_bc_model,
#     _get_base_ae,
#     BehaviorCloningPolicy
# )

from human_aware_rl.imitation.behavioural_cloning import load_bc_model_from_path, get_bc_agent_from_model

import pickle

LLM_PLAYER_IDX = 1
LAYOUT_MAP = {
    'cramped_room': 'simple',
    'asymmetric_advantages': 'unident_s',
    'coordination_ring': 'random1',
    'forced_coordination': 'random0',
    'counter_circuit_o_1order': 'random3'
}

def set_global_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_random_seed(seed)


INVERTED_LAYOUT_MAPS = {'simple': 'cramped_room', 
                        'unident_s': 'asymmetric_advantages', 
                        'random1': 'coordination_ring', 
                        'random0': 'forced_coordination', 
                        'random3': 'counter_circuit_o_1order'}

def fix_filetype(path, filetype):
    if path[-len(filetype) :] == filetype:
        return path
    else:
        return path + filetype
    
def load_pickle(filename):
    with open(fix_filetype(filename, ".pickle"), "rb") as f:
        return pickle.load(f) 
    
class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

# Relative path to where all static pre-trained agents are stored on server
AGENT_DIR = None

# Maximum allowable game time (in seconds)
MAX_GAME_TIME = None


def _configure(max_game_time, agent_dir):
    global AGENT_DIR, MAX_GAME_TIME
    MAX_GAME_TIME = max_game_time
    AGENT_DIR = agent_dir


def fix_bc_path(path):
    """
    Loading a PPO agent trained with a BC agent requires loading the BC model as well when restoring the trainer, even though the BC model is not used in game
    For now the solution is to include the saved BC model and fix the relative path to the model in the config.pkl file
    """

    import dill

    # the path is the agents/Rllib.*/agent directory
    agent_path = os.path.dirname(path)
    with open(os.path.join(agent_path, "config.pkl"), "rb") as f:
        data = dill.load(f)
    bc_model_dir = data["bc_params"]["bc_config"]["model_dir"]
    last_dir = os.path.basename(bc_model_dir)
    bc_model_dir = os.path.join(agent_path, "bc_params", last_dir)
    data["bc_params"]["bc_config"]["model_dir"] = bc_model_dir
    with open(os.path.join(agent_path, "config.pkl"), "wb") as f:
        dill.dump(data, f)


class Game(ABC):

    """
    Class representing a game object. Coordinates the simultaneous actions of arbitrary
    number of players. Override this base class in order to use.

    Players can post actions to a `pending_actions` queue, and driver code can call `tick` to apply these actions.


    It should be noted that most operations in this class are not on their own thread safe. Thus, client code should
    acquire `self.lock` before making any modifications to the instance.

    One important exception to the above rule is `enqueue_actions` which is thread safe out of the box
    """

    # Possible TODO: create a static list of IDs used by the class so far to verify id uniqueness
    # This would need to be serialized, however, which might cause too great a performance hit to
    # be worth it

    EMPTY = "EMPTY"

    class Status:
        DONE = "done"
        ACTIVE = "active"
        RESET = "reset"
        INACTIVE = "inactive"
        ERROR = "error"

    def __init__(self, *args, **kwargs):
        """
        players (list): List of IDs of players currently in the game
        spectators (set): Collection of IDs of players that are not allowed to enqueue actions but are currently watching the game
        id (int):   Unique identifier for this game
        pending_actions List[(Queue)]: Buffer of (player_id, action) pairs have submitted that haven't been commited yet
        lock (Lock):    Used to serialize updates to the game state
        is_active(bool): Whether the game is currently being played or not
        """
        self.players = []
        self.spectators = set()
        self.pending_actions = []
        self.id = kwargs.get("id", id(self))
        self.lock = Lock()
        self._is_active = False

    @abstractmethod
    def is_full(self):
        """
        Returns whether there is room for additional players to join or not
        """
        pass

    @abstractmethod
    def apply_action(self, player_idx, action):
        """
        Updates the game state by applying a single (player_idx, action) tuple. Subclasses should try to override this method
        if possible
        """
        pass

    @abstractmethod
    def is_finished(self):
        """
        Returns whether the game has concluded or not
        """
        pass

    def is_ready(self):
        """
        Returns whether the game can be started. Defaults to having enough players
        """
        return self.is_full()

    @property
    def is_active(self):
        """
        Whether the game is currently being played
        """
        return self._is_active

    @property
    def reset_timeout(self):
        """
        Number of milliseconds to pause game on reset
        """
        return 3000

    def apply_actions(self):
        """
        Updates the game state by applying each of the pending actions in the buffer. Is called by the tick method. Subclasses
        should override this method if joint actions are necessary. If actions can be serialized, overriding `apply_action` is
        preferred
        """
        for i in range(len(self.players)):
            try:
                while True:
                    action = self.pending_actions[i].get(block=False)
                    self.apply_action(i, action)
            except Empty:
                pass

    def activate(self):
        """
        Activates the game to let server know real-time updates should start. Provides little functionality but useful as
        a check for debugging
        """
        self._is_active = True

    def deactivate(self):
        """
        Deactives the game such that subsequent calls to `tick` will be no-ops. Used to handle case where game ends but
        there is still a buffer of client pings to handle
        """
        self._is_active = False

    def reset(self):
        """
        Restarts the game while keeping all active players by resetting game stats and temporarily disabling `tick`
        """
        if not self.is_active:
            raise ValueError("Inactive Games cannot be reset")
        if self.is_finished():
            return self.Status.DONE
        self.deactivate()
        self.activate()
        return self.Status.RESET

    def needs_reset(self):
        """
        Returns whether the game should be reset on the next call to `tick`
        """
        return False

    def tick(self):
        """
        Updates the game state by applying each of the pending actions. This is done so that players cannot directly modify
        the game state, offering an additional level of safety and thread security.

        One can think of "enqueue_action" like calling "git add" and "tick" like calling "git commit"

        Subclasses should try to override `apply_actions` if possible. Only override this method if necessary
        """
        if not self.is_active:
            return self.Status.INACTIVE
        if self.needs_reset():
            self.reset()
            return self.Status.RESET

        self.apply_actions()
        return self.Status.DONE if self.is_finished() else self.Status.ACTIVE

    def enqueue_action(self, player_id, action):
        """
        Add (player_id, action) pair to the pending action queue, without modifying underlying game state

        Note: This function IS thread safe
        """
        if not self.is_active:
            # Could run into issues with is_active not being thread safe
            return
        if player_id not in self.players:
            # Only players actively in game are allowed to enqueue actions
            return
        try:
            player_idx = self.players.index(player_id)
            self.pending_actions[player_idx].put(action)
        except Full:
            pass

    def get_state(self):
        """
        Return a JSON compatible serialized state of the game. Note that this should be as minimalistic as possible
        as the size of the game state will be the most important factor in game performance. This is sent to the client
        every frame update.
        """
        return {"players": self.players}

    def to_json(self):
        """
        Return a JSON compatible serialized state of the game. Contains all information about the game, does not need to
        be minimalistic. This is sent to the client only once, upon game creation
        """
        return self.get_state()

    def is_empty(self):
        """
        Return whether it is safe to garbage collect this game instance
        """
        return not self.num_players

    def add_player(self, player_id, idx=None, buff_size=-1):
        """
        Add player_id to the game
        """
        if self.is_full():
            raise ValueError("Cannot add players to full game")
        if self.is_active:
            raise ValueError("Cannot add players to active games")
        if not idx and self.EMPTY in self.players:
            idx = self.players.index(self.EMPTY)
        elif not idx:
            idx = len(self.players)

        padding = max(0, idx - len(self.players) + 1)
        for _ in range(padding):
            self.players.append(self.EMPTY)
            self.pending_actions.append(self.EMPTY)

        self.players[idx] = player_id
        self.pending_actions[idx] = Queue(maxsize=buff_size)

    def add_spectator(self, spectator_id):
        """
        Add spectator_id to list of spectators for this game
        """
        if spectator_id in self.players:
            raise ValueError("Cannot spectate and play at same time")
        self.spectators.add(spectator_id)

    def remove_player(self, player_id):
        """
        Remove player_id from the game
        """
        try:
            idx = self.players.index(player_id)
            self.players[idx] = self.EMPTY
            self.pending_actions[idx] = self.EMPTY
        except ValueError:
            return False
        else:
            return True

    def remove_spectator(self, spectator_id):
        """
        Removes spectator_id if they are in list of spectators. Returns True if spectator successfully removed, False otherwise
        """
        try:
            self.spectators.remove(spectator_id)
        except ValueError:
            return False
        else:
            return True

    def clear_pending_actions(self):
        """
        Remove all queued actions for all players
        """
        for i, player in enumerate(self.players):
            if player != self.EMPTY:
                queue = self.pending_actions[i]
                queue.queue.clear()

    @property
    def num_players(self):
        return len([player for player in self.players if player != self.EMPTY])

    def get_data(self):
        """
        Return any game metadata to server driver.
        """
        return {}


class DummyGame(Game):

    """
    Standin class used to test basic server logic
    """

    def __init__(self, **kwargs):
        super(DummyGame, self).__init__(**kwargs)
        self.counter = 0

    def is_full(self):
        return self.num_players == 2

    def apply_action(self, idx, action):
        pass

    def apply_actions(self):
        self.counter += 1

    def is_finished(self):
        return self.counter >= 100

    def get_state(self):
        state = super(DummyGame, self).get_state()
        state["count"] = self.counter
        return state


class DummyInteractiveGame(Game):

    """
    Standing class used to test interactive components of the server logic
    """

    def __init__(self, **kwargs):
        super(DummyInteractiveGame, self).__init__(**kwargs)
        self.max_players = int(
            kwargs.get("playerZero", "human") == "human"
        ) + int(kwargs.get("playerOne", "human") == "human")
        self.max_count = kwargs.get("max_count", 30)
        self.counter = 0
        self.counts = [0] * self.max_players

    def is_full(self):
        return self.num_players == self.max_players

    def is_finished(self):
        return max(self.counts) >= self.max_count

    def apply_action(self, player_idx, action):
        if action.upper() == Direction.NORTH:
            self.counts[player_idx] += 1
        if action.upper() == Direction.SOUTH:
            self.counts[player_idx] -= 1

    def apply_actions(self):
        super(DummyInteractiveGame, self).apply_actions()
        self.counter += 1

    def get_state(self):
        state = super(DummyInteractiveGame, self).get_state()
        state["count"] = self.counter
        for i in range(self.num_players):
            state["player_{}_count".format(i)] = self.counts[i]
        return state



# -------------------------------------------------------------------------------------------------------------------------------
class OvercookedGame(Game):
    """
    Class for bridging the gap between Overcooked_Env and the Game interface

    Instance variable:
        - max_players (int): Maximum number of players that can be in the game at once
        - mdp (OvercookedGridworld): Controls the underlying Overcooked game logic
        - score (int): Current reward acheived by all players
        - max_time (int): Number of seconds the game should last
        - npc_policies (dict): Maps user_id to policy (Agent) for each AI player
        - npc_state_queues (dict): Mapping of NPC user_ids to LIFO queues for the policy to process
        - curr_tick (int): How many times the game server has called this instance's `tick` method
        - ticker_per_ai_action (int): How many frames should pass in between NPC policy forward passes.
            Note that this is a lower bound; if the policy is computationally expensive the actual frames
            per forward pass can be higher
        - action_to_overcooked_action (dict): Maps action names returned by client to action names used by OvercookedGridworld
            Note that this is an instance variable and not a static variable for efficiency reasons
        - human_players (set(str)): Collection of all player IDs that correspond to humans
        - npc_players (set(str)): Collection of all player IDs that correspond to AI
        - randomized (boolean): Whether the order of the layouts should be randomized

    Methods:
        - npc_policy_consumer: Background process that asynchronously computes NPC policy forward passes. One thread
            spawned for each NPC
        - _curr_game_over: Determines whether the game on the current mdp has ended
    """

    def __init__(
        self,
        layouts=["cramped_room"],
        mdp_params={},
        num_players=2,
        gameTime=30,
        playerZero="vicuna",
        playerOne="human",
        showPotential=False,
        randomized=False,
        ticks_per_ai_action=1,
        **kwargs
    ):
        super(OvercookedGame, self).__init__(**kwargs)
        self.show_potential = showPotential
        self.mdp_params = mdp_params
        self.layouts = layouts
        self.max_players = int(num_players)
        self.mdp = None
        self.mp = None
        self.score = 0
        self.phi = 0
        self.max_time = min(int(gameTime), MAX_GAME_TIME)
        self.npc_policies = {}
        self.npc_state_queues = {}
        self.action_to_overcooked_action = {
            "STAY": Action.STAY,
            "UP": Direction.NORTH,
            "DOWN": Direction.SOUTH,
            "LEFT": Direction.WEST,
            "RIGHT": Direction.EAST,
            "SPACE": Action.INTERACT,
        }
        self.ticks_per_ai_action = ticks_per_ai_action
        self.curr_tick = 0
        self.human_players = set()
        self.npc_players = set()
        self.ai_player_state = None 
        self.human_message = None 
        if randomized:
            random.shuffle(self.layouts)

        if playerZero != "human":
            player_zero_id = 'player' + "_0"
            self.add_player(player_zero_id, idx=0, buff_size=1, is_human=False)
            self.npc_policies[player_zero_id] = self.get_policy(
                playerZero, idx=0
            )
            self.npc_state_queues[player_zero_id] = LifoQueue()

        if playerOne != "human":
            player_one_id = 'player' + "_1"
            self.add_player(player_one_id, idx=1, buff_size=1, is_human=False)
            self.npc_policies[player_one_id] = self.get_policy(
                playerOne, idx=1
            )
            self.npc_state_queues[player_one_id] = LifoQueue()
        # Always kill ray after loading agent, otherwise, ray will crash once process exits
        # Only kill ray after loading both agents to avoid having to restart ray during loading
        if ray.is_initialized():
            ray.shutdown()
        self.action_managers = {}

        if kwargs["dataCollection"]:
            self.write_data = True
            self.write_config = kwargs["collection_config"]
        else:
            self.write_data = False
        self.messages = Manager().dict()
        self.messages[0] = ''
        self.messages[1] = ''
        self.trajectory = []

    def _curr_game_over(self):
        return time() - self.start_time >= self.max_time

    def needs_reset(self):
        return self._curr_game_over() and not self.is_finished()
    
    def set_human_message(self, msg):
        self.human_message = msg

    def add_player(self, player_id, idx=None, buff_size=-1, is_human=True):
        super(OvercookedGame, self).add_player(
            player_id, idx=idx, buff_size=buff_size
        )
        if is_human:
            self.human_players.add(player_id)
        else:
            self.npc_players.add(player_id)

    def remove_player(self, player_id):
        removed = super(OvercookedGame, self).remove_player(player_id)
        if removed:
            if player_id in self.human_players:
                self.human_players.remove(player_id)
            elif player_id in self.npc_players:
                self.npc_players.remove(player_id)
            else:
                raise ValueError("Inconsistent state")

    def npc_policy_consumer(self, policy_id):
        queue = self.npc_state_queues[policy_id]
        policy = self.npc_policies[policy_id]
        while self._is_active:
            state = queue.get()
            # npc_action, _= policy.action(state, self.action_managers[policy_id], self.messages)
            # npc_action, _ = policy.action(state)
            # Uncomment when using BC 
            npc_action = policy.action(state)

            # uncomment when using ppo 
            # npc_action = policy.action(state)
            print(f"INFO BC-AGENT took action: {npc_action}")
            # print(f"INFO: MESSAGES set in thread function {self.messages}")
            super(OvercookedGame, self).enqueue_action(policy_id, npc_action)
    
    def npc_policy_consumer_llm(self, policy_id):
        queue = self.npc_state_queues[policy_id]
        policy = self.npc_policies[policy_id]
        while self._is_active:
            state = queue.get()
            npc_action, _= policy.action(state, self.action_managers[policy_id], self.messages)
            # npc_action, _ = policy.action(state)
            # print(f"INFO: MESSAGES set in thread function {self.messages}")
            super(OvercookedGame, self).enqueue_action(policy_id, npc_action)


    def is_full(self):
        return self.num_players >= self.max_players

    def is_finished(self):
        val = not self.layouts and self._curr_game_over()
        return val

    def is_empty(self):
        """
        Game is considered safe to scrap if there are no active players or if there are no humans (spectating or playing)
        """
        return (
            super(OvercookedGame, self).is_empty()
            or not self.spectators
            and not self.human_players
        )

    def is_ready(self):
        """
        Game is ready to be activated if there are a sufficient number of players and at least one human (spectator or player)
        """
        return super(OvercookedGame, self).is_ready() and not self.is_empty()

    def apply_action(self, player_id, action):
        pass

    def apply_actions(self):
        # Default joint action, as NPC policies and clients probably don't enqueue actions fast
        # enough to produce one at every tick
        joint_action = [Action.STAY] * len(self.players)

        # Synchronize individual player actions into a joint-action as required by overcooked logic
        for i in range(len(self.players)):
            # if this is a human, don't block and inject
            if self.players[i] in self.human_players:
                try:
                    # we don't block here in case humans want to Stay
                    joint_action[i] = self.pending_actions[i].get(block=False)
                except Empty:
                    pass
            else:
                # we block on agent actions to ensure that the agent gets to do one action per state
                joint_action[i] = self.pending_actions[i].get(block=True)

        # Apply overcooked game logic to get state transition
        prev_state = self.state
        self.state, sparse_reward, shaped_reward = self.mdp.get_state_transition(
            prev_state, joint_action
        )
        print(f"{bcolors.WARNING} NEW STATE: {self.state.__dict__}")
        info = {
            'sparse_reward_by_agent': sparse_reward, 
            'shaped_reward_by_agent': shaped_reward
        }

        if self.show_potential:
            self.phi = self.mdp.potential_function(
                prev_state, self.mp, gamma=0.99
            )

        # Send next state to all background consumers if needed
        if self.curr_tick % self.ticks_per_ai_action == 0:
            for npc_id in self.npc_policies:
                self.npc_state_queues[npc_id].put(self.state, block=False)


        
        # Update score based on soup deliveries that might have occured
        curr_reward = sparse_reward
        self.score += curr_reward
        print("Current tick is : ", self.curr_tick)
        if self.curr_tick >= 399:
            print(f"{bcolors.BOLD}Completed 400 timesteps, score {self.score}{bcolors.ENDC}")
            with open('final_score_forced_coordination_gpt4_cache_at_400ts.txt', 'w') as f:
                f.write(f"Completed 400 timesteps forced coordination, score {self.score}")

        transition = {
            "state": json.dumps(prev_state.to_dict()),
            "joint_action": json.dumps(joint_action),
            "reward": curr_reward,
            "time_left": max(self.max_time - (time() - self.start_time), 0),
            "score": self.score,
            "time_elapsed": time() - self.start_time,
            "cur_gameloop": self.curr_tick,
            "layout": json.dumps(self.mdp.terrain_mtx),
            "layout_name": self.curr_layout,
            "trial_id": str(self.start_time),
            "player_0_id": self.players[0],
            # "player_1_id": self.players[1],
            "player_0_is_human": self.players[0] in self.human_players,
            # "player_1_is_human": self.players[1] in self.human_players,
        }

        self.trajectory.append(transition)

        # Return about the current transition
        return prev_state, joint_action, info

    def enqueue_action(self, player_id, action):
        overcooked_action = self.action_to_overcooked_action[action]
        super(OvercookedGame, self).enqueue_action(
            player_id, overcooked_action
        )

    def reset(self):
        status = super(OvercookedGame, self).reset()
        if status == self.Status.RESET:
            # Hacky way of making sure game timer doesn't "start" until after reset timeout has passed
            self.start_time += self.reset_timeout / 1000

    def tick(self):
        self.curr_tick += 1
        return super(OvercookedGame, self).tick()

    def activate(self):
        super(OvercookedGame, self).activate()

        # Sanity check at start of each game
        if not self.npc_players.union(self.human_players) == set(self.players):
            raise ValueError("Inconsistent State")

        self.curr_layout = self.layouts.pop()
        self.mdp = OvercookedGridworld.from_layout_name(
            self.curr_layout, **self.mdp_params
        )
        
        if self.show_potential:
            self.mp = MotionPlanner.from_pickle_or_compute(
                self.mdp, counter_goals=NO_COUNTERS_PARAMS
            )
        no_counter_params =  {
            "start_orientations": False,
            "wait_allowed": False,
            "counter_goals": [],
            "counter_drop": [],
            "counter_pickup": [],
            "same_motion_goals": False,
        }
        # self.mp_mm = MediumLevelActionManager.from_pickle_or_compute(self.mdp, mlam_params=no_counter_params, force_compute=True)
    
        # self.mdp.start_player_positions = [(3, 1), (1, 1)]
        
        self.state = self.mdp.get_standard_start_state()
        # self.state.players[0].orientation = Direction.NORTH
        # self.state.players[1].orientation = Direction.EAST 
        # self.state.players[0].set_object(ObjectState('dish', self.state.players[0].position))
        # self.state.players[1].set_object(ObjectState('onion', self.state.players[1].position))
        # # self.state.players[1].set_object(ObjectState('soup', self.state.players[1].position, ('onion', 3, 20)))
        # self.state.add_object(ObjectState('soup', (3, 0), ('onion', 3, 5)))
        # self.state.add_object(ObjectState('soup', (4, 1), ('onion', 2, 0)))
        # self.state.add_object(ObjectState('soup', (4, 1), ('onion', 3, 10)))
        # self.state.add_object(ObjectState('onion', (2, 1)))
        # self.state.add_object(ObjectState('onion', (2, 2)))
        # self.state.add_object(ObjectState('dish', (2, 3)))

        # self.simple_planner = SimplePlanner(self.mdp)

        # self.ai_player_state = 

        # sys.exit(0)
        
        if self.show_potential:
            self.phi = self.mdp.potential_function(
                self.state, self.mp, gamma=0.99
            )
        self.start_time = time()
        self.curr_tick = 0
        self.score = 0
        self.threads = []
        
        for idx, npc_policy in enumerate(self.npc_policies):
            # if idx ==LLM_PLAYER_IDX:
            self.action_managers[npc_policy] = LLMActionManager(self.mdp, npc_policy, self.curr_layout)
            self.npc_policies[npc_policy].reset()
            self.npc_state_queues[npc_policy].put(self.state)
            # if idx == LLM_PLAYER_IDX:
            t = Thread(target=self.npc_policy_consumer_llm, args=(npc_policy,))
            # else:
                # t = Thread(target=self.npc_policy_consumer, args=(npc_policy,))

            self.threads.append(t)
            t.start()

    def deactivate(self):
        super(OvercookedGame, self).deactivate()
        # Ensure the background consumers do not hang
        for npc_policy in self.npc_policies:
            self.npc_state_queues[npc_policy].put(self.state)

        # Wait for all background threads to exit
        for t in self.threads:
            t.join()

        # Clear all action queues
        self.clear_pending_actions()

    def get_state(self):
        state_dict = {}
        state_dict["potential"] = self.phi if self.show_potential else None
        state_dict["state"] = self.state.to_dict()
        state_dict["score"] = self.score
        state_dict["time_left"] = max(
            self.max_time - (time() - self.start_time), 0
        )
        return state_dict

    def to_json(self):
        obj_dict = {}
        obj_dict["terrain"] = self.mdp.terrain_mtx if self._is_active else None
        obj_dict["state"] = self.get_state() if self._is_active else None
        return obj_dict

    def get_policy(self, npc_id, idx):
        return LLMPolicy()



    def get_data(self):
        """
        Returns and then clears the accumulated trajectory
        """
        data = {
            "uid": str(time()),
            "trajectory": self.trajectory,
        }
        self.trajectory = []
        # if we want to store the data and there is data to store
        if self.write_data and len(data["trajectory"]) > 0:
            configs = self.write_config
            # create necessary dirs
            # data_path = create_dirs(configs, self.curr_layout)
            # the 3-layer-directory structure should be able to uniquely define any experiment
            # with open(os.path.join(data_path, "result.pkl"), "wb") as f:
            #     pickle.dump(data, f)
        return data
# -------------------------------------------------------------------------------------------------------------------------------
        


class OvercookedTutorial(OvercookedGame):

    """
    Wrapper on OvercookedGame that includes additional data for tutorial mechanics, most notably the introduction of tutorial "phases"

    Instance Variables:
        - curr_phase (int): Indicates what tutorial phase we are currently on
        - phase_two_score (float): The exact sparse reward the user must obtain to advance past phase 2
    """

    def __init__(
        self,
        layouts=["tutorial_0"],
        mdp_params={},
        playerZero="human",
        playerOne="AI",
        phaseTwoScore=15,
        **kwargs
    ):
        super(OvercookedTutorial, self).__init__(
            layouts=layouts,
            mdp_params=mdp_params,
            playerZero=playerZero,
            playerOne=playerOne,
            showPotential=False,
            **kwargs
        )
        self.phase_two_score = phaseTwoScore
        self.phase_two_finished = False
        self.max_time = 0
        self.max_players = 2
        self.ticks_per_ai_action = 1
        self.curr_phase = 0
        # we don't collect tutorial data
        self.write_data = False

    @property
    def reset_timeout(self):
        return 1

    def needs_reset(self):
        if self.curr_phase == 0:
            return self.score > 0
        elif self.curr_phase == 1:
            return self.score > 0
        elif self.curr_phase == 2:
            return self.phase_two_finished
        return False

    def is_finished(self):
        return not self.layouts and self.score >= float("inf")

    def reset(self):
        super(OvercookedTutorial, self).reset()
        self.curr_phase += 1

    def get_policy(self, *args, **kwargs):
        return TutorialAI()

    def apply_actions(self):
        """
        Apply regular MDP logic with retroactive score adjustment tutorial purposes
        """
        _, _, info = super(OvercookedTutorial, self).apply_actions()

        human_reward, ai_reward = info["sparse_reward_by_agent"]

        # We only want to keep track of the human's score in the tutorial
        self.score -= ai_reward

        # Phase two requires a specific reward to complete
        if self.curr_phase == 2:
            self.score = 0
            if human_reward == self.phase_two_score:
                self.phase_two_finished = True


class DummyOvercookedGame(OvercookedGame):
    """
    Class that hardcodes the AI to be random. Used for debugging
    """

    def __init__(self, layouts=["cramped_room"], **kwargs):
        super(DummyOvercookedGame, self).__init__(layouts, **kwargs)

    def get_policy(self, *args, **kwargs):
        return DummyAI()


class DummyAI:
    """
    Randomly samples actions. Used for debugging
    """

    def action(self, state):
        [action] = random.sample(
            [
                Action.STAY,
                Direction.NORTH,
                Direction.SOUTH,
                Direction.WEST,
                Direction.EAST,
                Action.INTERACT,
            ],
            1,
        )
        return action, None

    def reset(self):
        pass


class DummyComputeAI(DummyAI):
    """
    Performs simulated compute before randomly sampling actions. Used for debugging
    """

    def __init__(self, compute_unit_iters=1e5):
        """
        compute_unit_iters (int): Number of for loop cycles in one "unit" of compute. Number of
                                    units performed each time is randomly sampled
        """
        super(DummyComputeAI, self).__init__()
        self.compute_unit_iters = int(compute_unit_iters)

    def action(self, state):
        # Randomly sample amount of time to busy wait
        iters = random.randint(1, 10) * self.compute_unit_iters

        # Actually compute something (can't sleep) to avoid scheduling optimizations
        val = 0
        for i in range(iters):
            # Avoid branch prediction optimizations
            if i % 2 == 0:
                val += 1
            else:
                val += 2

        # Return randomly sampled action
        return super(DummyComputeAI, self).action(state)

class LLMPolicy:
    """
        takes agent to onions
    """
    def action(self, state, action_manager, game_messages):
        action, message = action_manager.get_next_move(state, game_messages[action_manager.other_player_id])
        # print(action)
        game_messages[action_manager.player_id] = message 
        
        # print(f"{bcolors.OKCYAN}[INFO] Player {action_manager.player_id} Action: {action}")
        
        # print(f"[INFO] Player {action_manager.player_id} Message: {message}{bcolors.ENDC}")
        # print(f"[INFO] Player {action_manager.other_player_id} Previous Message: {game_messages[action_manager.other_player_id]}{bcolors.ENDC}")
        return action, message 
    
    def reset(self):
        pass


class StayAI:
    """
    Always returns "stay" action. Used for debugging
    """

    def action(self, state):
        return Action.STAY, None

    def reset(self):
        pass


class TutorialAI:
    COOK_SOUP_LOOP = [
        # Grab first onion
        Direction.WEST,
        Direction.WEST,
        Direction.WEST,
        Action.INTERACT,
        # Place onion in pot
        Direction.EAST,
        Direction.NORTH,
        Action.INTERACT,
        # Grab second onion
        Direction.WEST,
        Action.INTERACT,
        # Place onion in pot
        Direction.EAST,
        Direction.NORTH,
        Action.INTERACT,
        # Grab third onion
        Direction.WEST,
        Action.INTERACT,
        # Place onion in pot
        Direction.EAST,
        Direction.NORTH,
        Action.INTERACT,
        # Cook soup
        Action.INTERACT,
        # Grab plate
        Direction.EAST,
        Direction.SOUTH,
        Action.INTERACT,
        Direction.WEST,
        Direction.NORTH,
        # Deliver soup
        Action.INTERACT,
        Direction.EAST,
        Direction.EAST,
        Direction.EAST,
        Action.INTERACT,
        Direction.WEST,
    ]

    COOK_SOUP_COOP_LOOP = [
        # Grab first onion
        Direction.WEST,
        Direction.WEST,
        Direction.WEST,
        Action.INTERACT,
        # Place onion in pot
        Direction.EAST,
        Direction.SOUTH,
        Action.INTERACT,
        # Move to start so this loops
        Direction.EAST,
        Direction.EAST,
        # Pause to make cooperation more real time
        Action.STAY,
        Action.STAY,
        Action.STAY,
        Action.STAY,
        Action.STAY,
        Action.STAY,
        Action.STAY,
        Action.STAY,
        Action.STAY,
    ]

    def __init__(self):
        self.curr_phase = -1
        self.curr_tick = -1

    def action(self, state):
        self.curr_tick += 1
        if self.curr_phase == 0:
            return (
                self.COOK_SOUP_LOOP[self.curr_tick % len(self.COOK_SOUP_LOOP)],
                None,
            )
        elif self.curr_phase == 2:
            return (
                self.COOK_SOUP_COOP_LOOP[
                    self.curr_tick % len(self.COOK_SOUP_COOP_LOOP)
                ],
                None,
            )
        return Action.STAY, None

    def reset(self):
        self.curr_tick = -1
        self.curr_phase += 1



# class OvercookedGame(Game):
#     """
#     Class for bridging the gap between Overcooked_Env and the Game interface

#     Instance variable:
#         - max_players (int): Maximum number of players that can be in the game at once
#         - mdp (OvercookedGridworld): Controls the underlying Overcooked game logic
#         - score (int): Current reward acheived by all players
#         - max_time (int): Number of seconds the game should last
#         - npc_policies (dict): Maps user_id to policy (Agent) for each AI player
#         - npc_state_queues (dict): Mapping of NPC user_ids to LIFO queues for the policy to process
#         - curr_tick (int): How many times the game server has called this instance's `tick` method
#         - ticker_per_ai_action (int): How many frames should pass in between NPC policy forward passes.
#             Note that this is a lower bound; if the policy is computationally expensive the actual frames
#             per forward pass can be higher
#         - action_to_overcooked_action (dict): Maps action names returned by client to action names used by OvercookedGridworld
#             Note that this is an instance variable and not a static variable for efficiency reasons
#         - human_players (set(str)): Collection of all player IDs that correspond to humans
#         - npc_players (set(str)): Collection of all player IDs that correspond to AI
#         - randomized (boolean): Whether the order of the layouts should be randomized

#     Methods:
#         - npc_policy_consumer: Background process that asynchronously computes NPC policy forward passes. One thread
#             spawned for each NPC
#         - _curr_game_over: Determines whether the game on the current mdp has ended
#     """

#     def __init__(
#         self,
#         layouts=["cramped_room"],
#         mdp_params={},
#         num_players=2,
#         gameTime=30,
#         playerZero="human",
#         playerOne="human",
#         showPotential=False,
#         randomized=False,
#         ticks_per_ai_action=1,
#         **kwargs
#     ):
#         super(OvercookedGame, self).__init__(**kwargs)
#         self.show_potential = showPotential
#         self.mdp_params = mdp_params
#         self.layouts = layouts
#         self.max_players = int(num_players)
#         self.mdp = None
#         self.mp = None
#         self.score = 0
#         self.phi = 0
#         self.max_time = min(int(gameTime), MAX_GAME_TIME)
#         self.npc_policies = {}
#         self.npc_state_queues = {}
#         self.action_to_overcooked_action = {
#             "STAY": Action.STAY,
#             "UP": Direction.NORTH,
#             "DOWN": Direction.SOUTH,
#             "LEFT": Direction.WEST,
#             "RIGHT": Direction.EAST,
#             "SPACE": Action.INTERACT,
#         }
#         self.ticks_per_ai_action = ticks_per_ai_action
#         self.curr_tick = 0
#         self.human_players = set()
#         self.npc_players = set()

#         if randomized:
#             random.shuffle(self.layouts)

#         if playerZero != "human":
#             player_zero_id = playerZero + "_0"
#             self.add_player(player_zero_id, idx=0, buff_size=1, is_human=False)
#             self.npc_policies[player_zero_id] = self.get_policy(
#                 playerZero, idx=0
#             )
#             self.npc_state_queues[player_zero_id] = LifoQueue()

#         if playerOne != "human":
#             player_one_id = playerOne + "_1"
#             self.add_player(player_one_id, idx=1, buff_size=1, is_human=False)
#             self.npc_policies[player_one_id] = self.get_policy(
#                 playerOne, idx=1
#             )
#             self.npc_state_queues[player_one_id] = LifoQueue()
#         # Always kill ray after loading agent, otherwise, ray will crash once process exits
#         # Only kill ray after loading both agents to avoid having to restart ray during loading
#         if ray.is_initialized():
#             ray.shutdown()

#         if kwargs["dataCollection"]:
#             self.write_data = True
#             self.write_config = kwargs["collection_config"]
#         else:
#             self.write_data = False

#         self.trajectory = []

#     def _curr_game_over(self):
#         return time() - self.start_time >= self.max_time

#     def needs_reset(self):
#         return self._curr_game_over() and not self.is_finished()

#     def add_player(self, player_id, idx=None, buff_size=-1, is_human=True):
#         super(OvercookedGame, self).add_player(
#             player_id, idx=idx, buff_size=buff_size
#         )
#         if is_human:
#             self.human_players.add(player_id)
#         else:
#             self.npc_players.add(player_id)

#     def remove_player(self, player_id):
#         removed = super(OvercookedGame, self).remove_player(player_id)
#         if removed:
#             if player_id in self.human_players:
#                 self.human_players.remove(player_id)
#             elif player_id in self.npc_players:
#                 self.npc_players.remove(player_id)
#             else:
#                 raise ValueError("Inconsistent state")

#     def npc_policy_consumer(self, policy_id):
#         queue = self.npc_state_queues[policy_id]
#         policy = self.npc_policies[policy_id]
#         while self._is_active:
#             state = queue.get()
#             npc_action, _ = policy.action(state)
#             super(OvercookedGame, self).enqueue_action(policy_id, npc_action)

#     def is_full(self):
#         return self.num_players >= self.max_players

#     def is_finished(self):
#         val = not self.layouts and self._curr_game_over()
#         return val

#     def is_empty(self):
#         """
#         Game is considered safe to scrap if there are no active players or if there are no humans (spectating or playing)
#         """
#         return (
#             super(OvercookedGame, self).is_empty()
#             or not self.spectators
#             and not self.human_players
#         )

#     def is_ready(self):
#         """
#         Game is ready to be activated if there are a sufficient number of players and at least one human (spectator or player)
#         """
#         return super(OvercookedGame, self).is_ready() and not self.is_empty()

#     def apply_action(self, player_id, action):
#         pass

#     def apply_actions(self):
#         # Default joint action, as NPC policies and clients probably don't enqueue actions fast
#         # enough to produce one at every tick
#         joint_action = [Action.STAY] * len(self.players)

#         # Synchronize individual player actions into a joint-action as required by overcooked logic
#         for i in range(len(self.players)):
#             # if this is a human, don't block and inject
#             if self.players[i] in self.human_players:
#                 try:
#                     # we don't block here in case humans want to Stay
#                     joint_action[i] = self.pending_actions[i].get(block=False)
#                 except Empty:
#                     pass
#             else:
#                 # we block on agent actions to ensure that the agent gets to do one action per state
#                 joint_action[i] = self.pending_actions[i].get(block=True)

#         # Apply overcooked game logic to get state transition
#         prev_state = self.state
#         self.state, info = self.mdp.get_state_transition(
#             prev_state, joint_action
#         )
#         if self.show_potential:
#             self.phi = self.mdp.potential_function(
#                 prev_state, self.mp, gamma=0.99
#             )

#         # Send next state to all background consumers if needed
#         if self.curr_tick % self.ticks_per_ai_action == 0:
#             for npc_id in self.npc_policies:
#                 self.npc_state_queues[npc_id].put(self.state, block=False)

#         # Update score based on soup deliveries that might have occured
#         curr_reward = sum(info["sparse_reward_by_agent"])
#         self.score += curr_reward

#         transition = {
#             "state": json.dumps(prev_state.to_dict()),
#             "joint_action": json.dumps(joint_action),
#             "reward": curr_reward,
#             "time_left": max(self.max_time - (time() - self.start_time), 0),
#             "score": self.score,
#             "time_elapsed": time() - self.start_time,
#             "cur_gameloop": self.curr_tick,
#             "layout": json.dumps(self.mdp.terrain_mtx),
#             "layout_name": self.curr_layout,
#             "trial_id": str(self.start_time),
#             "player_0_id": self.players[0],
#             "player_1_id": self.players[1],
#             "player_0_is_human": self.players[0] in self.human_players,
#             "player_1_is_human": self.players[1] in self.human_players,
#         }

#         self.trajectory.append(transition)

#         # Return about the current transition
#         return prev_state, joint_action, info

#     def enqueue_action(self, player_id, action):
#         overcooked_action = self.action_to_overcooked_action[action]
#         super(OvercookedGame, self).enqueue_action(
#             player_id, overcooked_action
#         )

#     def reset(self):
#         status = super(OvercookedGame, self).reset()
#         if status == self.Status.RESET:
#             # Hacky way of making sure game timer doesn't "start" until after reset timeout has passed
#             self.start_time += self.reset_timeout / 1000

#     def tick(self):
#         self.curr_tick += 1
#         return super(OvercookedGame, self).tick()

#     def activate(self):
#         super(OvercookedGame, self).activate()

#         # Sanity check at start of each game
#         if not self.npc_players.union(self.human_players) == set(self.players):
#             raise ValueError("Inconsistent State")

#         self.curr_layout = self.layouts.pop()
#         self.mdp = OvercookedGridworld.from_layout_name(
#             self.curr_layout, **self.mdp_params
#         )
#         if self.show_potential:
#             self.mp = MotionPlanner.from_pickle_or_compute(
#                 self.mdp, counter_goals=NO_COUNTERS_PARAMS
#             )
#         self.state = self.mdp.get_standard_start_state()
#         if self.show_potential:
#             self.phi = self.mdp.potential_function(
#                 self.state, self.mp, gamma=0.99
#             )
#         self.start_time = time()
#         self.curr_tick = 0
#         self.score = 0
#         self.threads = []
#         for npc_policy in self.npc_policies:
#             self.npc_policies[npc_policy].reset()
#             self.npc_state_queues[npc_policy].put(self.state)
#             t = Thread(target=self.npc_policy_consumer, args=(npc_policy,))
#             self.threads.append(t)
#             t.start()

#     def deactivate(self):
#         super(OvercookedGame, self).deactivate()
#         # Ensure the background consumers do not hang
#         for npc_policy in self.npc_policies:
#             self.npc_state_queues[npc_policy].put(self.state)

#         # Wait for all background threads to exit
#         for t in self.threads:
#             t.join()

#         # Clear all action queues
#         self.clear_pending_actions()

#     def get_state(self):
#         state_dict = {}
#         state_dict["potential"] = self.phi if self.show_potential else None
#         state_dict["state"] = self.state.to_dict()
#         state_dict["score"] = self.score
#         state_dict["time_left"] = max(
#             self.max_time - (time() - self.start_time), 0
#         )
#         return state_dict

#     def to_json(self):
#         obj_dict = {}
#         obj_dict["terrain"] = self.mdp.terrain_mtx if self._is_active else None
#         obj_dict["state"] = self.get_state() if self._is_active else None
#         return obj_dict

#     def get_policy(self, npc_id, idx=0):
#         if npc_id.lower().startswith("rllib"):
#             try:
#                 # Loading rllib agents requires additional helpers
#                 fpath = os.path.join(AGENT_DIR, npc_id, "agent")
#                 fix_bc_path(fpath)
#                 agent = load_agent(fpath, agent_index=idx)
#                 return agent
#             except Exception as e:
#                 raise IOError(
#                     "Error loading Rllib Agent\n{}".format(e.__repr__())
#                 )
#         else:
#             try:
#                 fpath = os.path.join(AGENT_DIR, npc_id, "agent.pickle")
#                 with open(fpath, "rb") as f:
#                     return pickle.load(f)
#             except Exception as e:
#                 raise IOError("Error loading agent\n{}".format(e.__repr__()))

#     def get_data(self):
#         """
#         Returns and then clears the accumulated trajectory
#         """
#         data = {
#             "uid": str(time()),
#             "trajectory": self.trajectory,
#         }
#         self.trajectory = []
#         # if we want to store the data and there is data to store
#         if self.write_data and len(data["trajectory"]) > 0:
#             configs = self.write_config
#             # create necessary dirs
#             data_path = create_dirs(configs, self.curr_layout)
#             # the 3-layer-directory structure should be able to uniquely define any experiment
#             with open(os.path.join(data_path, "result.pkl"), "wb") as f:
#                 pickle.dump(data, f)
#         return data