import gym
import tqdm
import numpy as np
from hsp.envs.overcooked.overcooked_ai_py.utils import mean_and_std_err
from hsp.envs.overcooked.overcooked_ai_py.mdp.actions import Action, Direction
from hsp.envs.overcooked.overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld, BASE_REW_SHAPING_PARAMS,FANCP_REW_SHAPING_PARAMS
from hsp.envs.overcooked.overcooked_ai_py.planning.planners import MediumLevelPlanner, NO_COUNTERS_PARAMS
from hsp.envs.overcooked.overcooked_ai_py.mdp.overcooked_trajectory import TIMESTEP_TRAJ_KEYS, EPISODE_TRAJ_KEYS, DEFAULT_TRAJ_KEYS
from hsp.envs.overcooked.overcooked_ai_py.visualization.state_visualizer import StateVisualizer
from utils.pic2video import write_gif
import imageio
import os
from hsp.envs.overcooked.script_agent import SCRIPT_AGENTS
import matplotlib.pyplot as plt
import os
import pickle
import argparse
from pathlib import Path

# Default environment parameters (only the episode horizon is set here).
DEFAULT_ENV_PARAMS = dict(horizon=400)

# Horizon used by OvercookedEnv when none is given; any horizon at or above
# this value is reported as "(near-)infinite" by OvercookedEnv in debug mode.
MAX_HORIZON = 1e10

class OvercookedEnv(object):
    """An environment wrapper for the OvercookedGridworld Markov Decision Process.

    The environment keeps track of the current state of the agent, updates
    it as the agent takes actions, and provides rewards to the agent.
    """

    def __init__(self, mdp, start_state_fn=None, horizon=MAX_HORIZON, debug=False):
        """
        Build the environment and immediately reset it.

        mdp (OvercookedGridworld or function): either an MDP instance, or a
            zero-argument callable returning MDP instances (it is called once
            here to validate its return type, and again at every reset)
        start_state_fn: forwarded to `reset`; None, a number, or a callable
            producing the start state
        horizon (float): number of steps before the environment returns done=True
        debug (bool): when True, print a notice if the horizon is at least
            MAX_HORIZON and the start state has no order list

        Raises ValueError when `mdp` is neither an MDP nor a generator of MDPs.
        """
        given_instance = isinstance(mdp, OvercookedGridworld)
        if not given_instance and not (callable(mdp) and isinstance(mdp(), OvercookedGridworld)):
            raise ValueError("Mdp should be either OvercookedGridworld instance or a generating function")
        # Always store a generator so that reset() can treat both cases alike
        self.mdp_generator_fn = (lambda: mdp) if given_instance else mdp

        self.horizon = horizon
        self.start_state_fn = start_state_fn
        self.reset()

        if self.horizon >= MAX_HORIZON:
            if self.state.order_list is None and debug:
                print("Environment has (near-)infinite horizon and no terminal states")

    def __repr__(self):
        """Standard way to view the state of an environment programatically
        is just to print the Env object"""
        return self.mdp.state_string(self.state)

    def print_state_transition(self, a_t, r_t, info):
        print("Timestep: {}\nJoint action taken: {} \t Reward: {} + shape * {} \n{}\n".format(
            self.t, tuple(Action.ACTION_TO_CHAR[a] for a in a_t), r_t, info["shaped_r"], self)
        )

    @property
    def env_params(self):
        return {
            "start_state_fn": self.start_state_fn,
            "horizon": self.horizon
        }

    def display_states(self, *states):
        old_state = self.state
        for s in states:
            self.state = s
            print(self)
        self.state = old_state

    @staticmethod
    def print_state(mdp, s):
        """Print state `s` of `mdp` by building a throwaway environment.

        `s` used to be passed directly as `start_state_fn`, which `reset`
        then tried to call, so a plain state object raised TypeError. A state
        is now wrapped in a function returning it; values that `reset` already
        understands (None, a number, a callable) are passed through unchanged.
        """
        if s is None or callable(s) or isinstance(s, (int, float)):
            start_state_fn = s
        else:
            start_state_fn = lambda: s
        e = OvercookedEnv(mdp, start_state_fn)
        print(e)

    def copy(self):
        """Return a new environment built from a copy of the current mdp, with
        the same start_state_fn and horizon. The new environment is freshly
        reset by its constructor."""
        cloned_mdp = self.mdp.copy()
        return OvercookedEnv(cloned_mdp, self.start_state_fn, self.horizon)

    def step(self, joint_action):
        """Performs a joint action, updating the environment state
        and providing a reward.
        
        On being done, stats about the episode are added to info:
            ep_sparse_r: the environment sparse reward, given only at soup delivery
            ep_shaped_r: the component of the reward that is due to reward shaped (excluding sparse rewards)
            ep_length: length of rollout
        """
        assert not self.is_done()
        self.t += 1
        next_state, mdp_infos = self.mdp.get_state_transition(self.state, joint_action)

        # Update game_stats 
        self._update_game_stats(mdp_infos)

        # Update state and done
        self.state = next_state
        done = self.is_done()
        env_info = self._prepare_info_dict([{}, {}], mdp_infos)
        
        if done: self._add_episode_info(env_info)

        timestep_sparse_reward = sum(mdp_infos["sparse_reward_by_agent"])
        return next_state, timestep_sparse_reward, done, env_info
    
    def reset(self):
        """Resets the environment. Does NOT reset the agent."""
        self.mdp = self.mdp_generator_fn()
        if self.start_state_fn is None:
            self.state = self.mdp.get_standard_start_state()
        elif type(self.start_state_fn) in [float, int]:
            p = np.random.uniform(0, 1)
            if p <= self.start_state_fn:
                self.state = self.mdp.get_random_start_state()
            else:
                self.state = self.mdp.get_standard_start_state()
        else:
            self.state = self.start_state_fn()
        self.cumulative_sparse_rewards = 0
        self.cumulative_shaped_rewards = 0
        self.t = 0

        rewards_dict = {
            "cumulative_sparse_rewards_by_agent": np.array([0] * self.mdp.num_players),
            "cumulative_shaped_rewards_by_agent": np.array([0] * self.mdp.num_players),
            "cumulative_category_rewards_by_agent": np.zeros((self.mdp.num_players, 12)),
        }
        self.game_stats = {**rewards_dict}

    def is_done(self):
        """Whether the episode is over."""
        return self.t >= self.horizon or self.mdp.is_terminal(self.state)

    def _prepare_info_dict(self, joint_agent_action_info, mdp_infos):
        """
        The normal timestep info dict will contain infos specifc to each agent's action taken,
        and reward shaping information.
        """
        # Get the agent action info, that could contain info about action probs, or other
        # custom user defined information
        env_info = {"agent_infos": [joint_agent_action_info[agent_idx] for agent_idx in range(self.mdp.num_players)]}
        # TODO: This can be further simplified by having all the mdp_infos copied over to the env_infos automatically 
        env_info["sparse_r_by_agent"] = mdp_infos["sparse_reward_by_agent"]
        env_info["shaped_r_by_agent"] = mdp_infos["shaped_reward_by_agent"]
        env_info["shaped_info_by_agent"] = mdp_infos["shaped_info_by_agent"]
        env_info["phi_s"] = mdp_infos["phi_s"] if "phi_s" in mdp_infos else None
        env_info["phi_s_prime"] = mdp_infos["phi_s_prime"] if "phi_s_prime" in mdp_infos else None
        return env_info

    def _add_episode_info(self, env_info):
        env_info["episode"] = {
            "ep_game_stats": self.game_stats,
            "ep_sparse_r": sum(self.game_stats["cumulative_sparse_rewards_by_agent"]),
            "ep_shaped_r": sum(self.game_stats["cumulative_shaped_rewards_by_agent"]),
            "ep_sparse_r_by_agent": self.game_stats["cumulative_sparse_rewards_by_agent"],
            "ep_shaped_r_by_agent": self.game_stats["cumulative_shaped_rewards_by_agent"],
            "ep_category_r_by_agent": self.game_stats["cumulative_category_rewards_by_agent"],
            "ep_length": self.t
        }
        return env_info

    def vectorize_shaped_info(self, shaped_info_by_agent):
        def vectorize(d: dict):
            return np.array([v for k, v in d.items()])
        shaped_info_by_agent = np.stack([vectorize(shaped_info) for shaped_info in shaped_info_by_agent])
        return shaped_info_by_agent

    def _update_game_stats(self, infos):
        """
        Update the game stats dict based on the events of the current step
        NOTE: the timer ticks after events are logged, so there can be events from time 0 to time self.horizon - 1
        """
        self.game_stats["cumulative_sparse_rewards_by_agent"] += np.array(infos["sparse_reward_by_agent"])
        self.game_stats["cumulative_shaped_rewards_by_agent"] += np.array(infos["shaped_reward_by_agent"])
        self.game_stats["cumulative_category_rewards_by_agent"] += self.vectorize_shaped_info(infos["shaped_info_by_agent"])

        '''for event_type, bool_list_by_agent in infos["event_infos"].items():
            # For each event type, store the timestep if it occurred
            event_occurred_by_idx = [int(x) for x in bool_list_by_agent]
            for idx, event_by_agent in enumerate(event_occurred_by_idx):
                if event_by_agent:
                    self.game_stats[event_type][idx].append(self.state.timestep)'''

    def execute_plan(self, start_state, joint_action_plan, display=False):
        """Executes action_plan (a list of joint actions) from a start 
        state in the mdp and returns the resulting state."""
        self.state = start_state
        done = False
        if display: print("Starting state\n{}".format(self))
        for joint_action in joint_action_plan:
            self.step(joint_action)
            done = self.is_done()
            if display: print(self)
            if done: break
        successor_state = self.state
        self.reset()
        return successor_state, done

    def run_agents(self, agent_pair, include_final_state=False, display=False, display_until=np.Inf):
        """
        Trajectory returned will a list of state-action pairs (s_t, joint_a_t, r_t, done_t).
        """
        assert self.cumulative_sparse_rewards == self.cumulative_shaped_rewards == 0, \
            "Did not reset environment before running agents"
        trajectory = []
        done = False

        if display: print(self)
        while not done:
            s_t = self.state
            a_t = agent_pair.joint_action(s_t)

            # Break if either agent is out of actions
            if any([a is None for a in a_t]):
                break

            s_tp1, r_t, done, info = self.step(a_t)
            trajectory.append((s_t, a_t, r_t, done))

            if display and self.t < display_until:
                self.print_state_transition(a_t, r_t, info)

        assert len(trajectory) == self.t, "{} vs {}".format(len(trajectory), self.t)

        # Add final state
        if include_final_state:
            trajectory.append((s_tp1, (None, None), 0, True))

        return np.array(trajectory), self.t, self.cumulative_sparse_rewards, self.cumulative_shaped_rewards

    def get_rollouts(self, agent_pair, num_games, display=False, final_state=False, agent_idx=0, reward_shaping=0.0, display_until=np.inf, info=True):
        """
        Simulate `num_games` number rollouts with the current agent_pair and returns processed 
        trajectories.

        Returning excessive information to be able to convert trajectories to any required format 
        (baselines, stable_baselines, etc)

        agent_pair: object providing set_mdp(mdp), joint_action(state) and reset()
        final_state: forwarded to run_agents as include_final_state
        agent_idx: accepted for interface compatibility; not used by this implementation
        reward_shaping: weight of the shaped return inside "ep_returns"
        display / display_until: forwarded to run_agents
        info: print a one-line summary of the returns when True

        Returns a dict of numpy arrays, one entry per key listed below. Entries
        whose per-episode items have different lengths become 1-D object arrays.

        NOTE: standard trajectories format used throughout the codebase
        """
        trajectories = {
            # With shape (n_timesteps, game_len), where game_len might vary across games:
            "ep_observations": [],
            "ep_actions": [],
            "ep_rewards": [], # Individual dense (= sparse + shaped * rew_shaping) reward values
            "ep_dones": [], # Individual done values

            # With shape (n_episodes, ):
            "ep_returns": [], # Sum of dense and sparse rewards across each episode
            "ep_returns_sparse": [], # Sum of sparse rewards across each episode
            "ep_lengths": [], # Lengths of each episode
            "mdp_params": [], # Custom MDP params to for each episode
            "env_params": [] # Custom Env params for each episode
        }

        for _ in tqdm.trange(num_games):
            agent_pair.set_mdp(self.mdp)

            trajectory, time_taken, tot_rews_sparse, tot_rews_shaped = self.run_agents(agent_pair, display=display, include_final_state=final_state, display_until=display_until)
            obs, actions, rews, dones = trajectory.T[0], trajectory.T[1], trajectory.T[2], trajectory.T[3]
            trajectories["ep_observations"].append(obs)
            trajectories["ep_actions"].append(actions)
            trajectories["ep_rewards"].append(rews)
            trajectories["ep_dones"].append(dones)
            trajectories["ep_returns"].append(tot_rews_sparse + tot_rews_shaped * reward_shaping)
            trajectories["ep_returns_sparse"].append(tot_rews_sparse)
            trajectories["ep_lengths"].append(time_taken)
            trajectories["mdp_params"].append(self.mdp.mdp_params)
            trajectories["env_params"].append(self.env_params)

            self.reset()
            agent_pair.reset()

        mu, se = mean_and_std_err(trajectories["ep_returns"])
        if info:
            print("Avg reward {:.2f} (std: {:.2f}, se: {:.2f}) over {} games of avg length {}".format(
                mu, np.std(trajectories["ep_returns"]), se, num_games, np.mean(trajectories["ep_lengths"]))
            )

        def to_array(values):
            # Episodes of different lengths form a ragged list, which recent
            # NumPy versions reject unless an object dtype is requested
            try:
                return np.array(values)
            except ValueError:
                return np.array(values, dtype=object)

        # Converting to numpy arrays
        return {k: to_array(v) for k, v in trajectories.items()}


class Overcooked_Old(gym.Env):
    """
    Wrapper for the Env class above that is SOMEWHAT compatible with the standard gym API.

    NOTE: Observations returned are in a dictionary format with various information that is
    necessary to be able to handle the multi-agent nature of the environment. There are probably
    better ways to handle this, but we found this to work with minor modifications to OpenAI Baselines.
    
    NOTE: The index of the main agent in the mdp is randomized at each reset of the environment, and 
    is kept track of by the self.agent_idx attribute. This means that it is necessary to pass on this 
    information in the output to know for which agent index featurizations should be made for other agents.
    
    For example, say one is training A0 paired with A1, and A1 takes a custom state featurization.
    Then in the runner.py loop in OpenAI Baselines, we will get the lossless encodings of the state,
    and the true Overcooked state. When we encode the true state to feed to A1, we also need to know
    what agent index it has in the environment (as encodings will be index dependent).
    """
    env_name = "Overcooked-v0"

    def __init__(self, map_name, seed, baselines_reproducible=True, featurize_type=("bc", "bc"), stuck_time=4, rank=None, use_render=False):
        """
        Build the wrapper: parse the settings, create the mdp, the base
        OvercookedEnv and the planner, and set up the observation/action spaces.

        map_name: layout name; always overrides the parsed --layout_name
        seed: seed given to np.random.seed when baselines_reproducible is True
        baselines_reproducible: see the NOTE below
        featurize_type: pair of "ppo" / "bc", one entry per agent, passed to
            reset_featurize_type
        stuck_time: stored as self.stuck_time
        rank: stored as self.rank
        use_render: stored as self.use_render; when True the replay directory
            is created if it does not exist yet

        All remaining settings come from the command line: parse_args() below
        reads the process arguments with parse_known_args, so unknown
        arguments are ignored and missing ones fall back to the defaults.
        """
        if baselines_reproducible:
            # NOTE:
            # This will cause all agent indices to be chosen in sync across simulation 
            # envs (for each update, all envs will have index 0 or index 1).
            # This is to prevent the randomness of choosing agent indexes
            # from leaking when using subprocess-vec-env in baselines (which
            # seeding does not reach) i.e. having different results for different
            # runs with the same seed.
            # The effect of this should be negligible, as all other randomness is 
            # controlled by the actual run seeds
            np.random.seed(seed)
        def parse_args():
            # Local parser for the options this wrapper reads
            parser = argparse.ArgumentParser(description='hsp', formatter_class=argparse.RawDescriptionHelpFormatter)
            parser.add_argument("--layout_name", type=str, default='cramped_room', help="Name of Submap, 40+ in choice. See /src/data/layouts/.")
            parser.add_argument('--num_agents', type=int, default=2, help="number of players")
            parser.add_argument("--initial_reward_shaping_factor", type=float, default=1.0, help="Shaping factor of potential dense reward.")
            parser.add_argument("--reward_shaping_factor", type=float, default=1.0, help="Shaping factor of potential dense reward.")
            parser.add_argument("--reward_shaping_horizon", type=int, default=100000000, help="Shaping factor of potential dense reward.")
            parser.add_argument("--use_phi", default=False, action='store_true', help="While existing other agent like planning or human model, use an index to fix the main RL-policy agent.")
            parser.add_argument("--use_hsp", default=False, action='store_true')
            parser.add_argument("--random_index", default=False, action='store_true')
            parser.add_argument("--w0", type=str, default="1,1,1,1", help="Weight vector of dense reward 0 in overcooked env.")
            parser.add_argument("--w1", type=str, default="1,1,1,1", help="Weight vector of dense reward 1 in overcooked env.")
            parser.add_argument("--predict_other_shaped_info", default=False, action='store_true', help="Predict other agent's shaped info within a short horizon, default False")
            parser.add_argument("--predict_shaped_info_horizon", default=50, type=int, help="Horizon for shaped info target, default 50")
            parser.add_argument("--predict_shaped_info_event_count", default=10, type=int, help="Event count for shaped info target, default 10")
            parser.add_argument("--use_task_v_out", default=False, action='store_true')
            parser.add_argument("--random_start_prob", default=0., type=float, help="Probability to use a random start state, default 0.")
            parser.add_argument("--use_detailed_rew_shaping", default=False, action="store_true")
            parser.add_argument("--overcooked_version", default="old", type=str, choices=["new", "old"])
            # replay buffer parameters
            parser.add_argument("--episode_length", type=int,
                                default=400, help="Max length for any episode")
            parser.add_argument("--use_render", action='store_true', default=False, help="by default, do not render the env during training. If set, start render. Note: something, the environment has internal render process which is not controlled by this hyperparam.")

            # parse_known_args: arguments this parser does not know are ignored
            all_args = parser.parse_known_args()[0]
            return all_args

        all_args = parse_args()
        # The constructor argument takes precedence over --layout_name
        all_args.layout_name = map_name
        self.all_args = all_args
        self.all_args.layout_name = map_name

        # run dir: <three levels above this file's folder>/save_replay/overcooked/<layout>
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        run_dir = Path(os.path.abspath(os.path.join(
            cur_dir, '..', '..', '..', "save_replay", 
            'overcooked', 
            self.all_args.layout_name)))

        # The directory is only created when rendering is requested
        if not run_dir.exists() and use_render:
            os.makedirs(str(run_dir))

        self.agent_idx = 0 
        self._initial_reward_shaping_factor = all_args.initial_reward_shaping_factor
        self.reward_shaping_factor = all_args.reward_shaping_factor
        self.reward_shaping_horizon = all_args.reward_shaping_horizon
        self.use_phi = all_args.use_phi
        self.use_hsp = all_args.use_hsp
        self.use_detailed_rew_shaping = all_args.use_detailed_rew_shaping
        # No --store_traj option is declared above, so this is False unless
        # the attribute is added to all_args by other means
        self.store_traj = getattr(all_args, "store_traj", False)
        self.rank = rank
        self.random_index = all_args.random_index
        if self.use_hsp:
            # Comma-separated weight strings parsed into float arrays
            self.w0 = self.string2array(all_args.w0)
            self.w1 = self.string2array(all_args.w1)
            print(self.w0, self.w1)
        self.use_render = use_render
        self.num_agents = all_args.num_agents
        self.layout_name = all_args.layout_name
        self.episode_length = all_args.episode_length
        self.episode_limit = all_args.episode_length
        self.random_start_prob = getattr(all_args, "random_start_prob", 0.)
        self.stuck_time = stuck_time
        self.history_sa = []
        self.traj_num = 0
        self.step_count = 0
        self.run_dir = run_dir
        mdp_params = {'layout_name': all_args.layout_name, 'start_order_list': None}
        mdp_params.update({
            "rew_shaping_params": BASE_REW_SHAPING_PARAMS
        })
        env_params = {'horizon': all_args.episode_length}
        # The base env receives the generator function; base_mdp is a separate
        # instance used for featurization and planning
        self.mdp_fn = lambda: OvercookedGridworld.from_layout_name(**mdp_params)
        self.base_mdp = self.mdp_fn()
        # A numeric start_state_fn is read by OvercookedEnv.reset as the
        # probability of a random start state
        self.base_env = OvercookedEnv(self.mdp_fn, start_state_fn=self.random_start_prob, **env_params)
        self.mlp = MediumLevelPlanner.from_pickle_or_compute(
            mdp=self.base_mdp,
            mlp_params=NO_COUNTERS_PARAMS,
            force_compute=False
        )
        # No --use_agent_policy_id option is declared above, so this defaults to False
        self.use_agent_policy_id = dict(all_args._get_kwargs()).get("use_agent_policy_id", False) # Add policy id for loaded policy
        self.agent_policy_id = [-1. for _ in range(self.num_agents)]
        self.featurize_fn_ppo = lambda state: self.base_mdp.lossless_state_encoding(state) # Encoding obs for PPO
        self.featurize_fn_bc = lambda state: self.base_mdp.featurize_state(state, self.mlp) # Encoding obs for BC
        self.featurize_fn_mapping = {
            "ppo": self.featurize_fn_ppo,
            "bc": self.featurize_fn_bc
        }
        self.reset_featurize_type(featurize_type=featurize_type) # with the signature default, both agents use "bc"


        # One optional scripted agent per player; None means the action passed to step is used
        self.script_agent = [None, None]

        # Bookkeeping for the repeated-action / standing-still / congestion
        # punishments computed in step
        self.action_buffer = []
        self.action_old = [0, 0]
        self.same_action_count = [0, 0]
        self.pos1_old = [0, 0]
        self.pos2_old = [0, 0]
        self.keep_pos_count = [0, 0]
        self.congestion_count = 0

    def reset_featurize_type(self, featurize_type=("bc", "bc")):
        """Select the observation encoding of each of the two agents ("ppo" or
        "bc") and rebuild featurize_fn and the observation, shared-observation
        and action space lists accordingly."""
        assert len(featurize_type) == 2
        self.featurize_type = featurize_type

        def featurize(state):
            # self.featurize_type is read at call time; "ppo" encodings are scaled by 255
            encoded = []
            for agent_i, kind in enumerate(self.featurize_type):
                scale = 255 if kind == "ppo" else 1
                encoded.append(self.featurize_fn_mapping[kind](state)[agent_i] * scale)
            return encoded

        self.featurize_fn = featurize

        # reset observation_space, share_observation_space and action_space
        self.observation_space = []
        self.share_observation_space = []
        self.action_space = []
        self._setup_observation_space()
        num_actions = len(Action.ALL_ACTIONS)
        for agent_i in (0, 1):
            self.observation_space.append(self._observation_space(featurize_type[agent_i]))
            self.action_space.append(gym.spaces.Discrete(num_actions))
            self.share_observation_space.append(self._setup_share_observation_space())

    def _anneal(self, start_v, curr_t, end_t, end_v=0, start_t=0):
        """Linearly interpolate from start_v (at start_t) to end_v (at end_t)
        and return the value at curr_t.

        The value stays at end_v once curr_t is past end_t. When end_t is 0
        no annealing takes place and start_v is returned.
        """
        if end_t == 0:
            return start_v
        elapsed = float(curr_t - start_t)
        # Share of start_v still present in the mix, never below zero
        remaining = max(1 - elapsed / (end_t - start_t), 0)
        return remaining * start_v + (1 - remaining) * end_v

    def onehot2idx(self, onehot):
        """Return, for each vector in `onehot`, the index of its largest entry."""
        return [np.argmax(vector) for vector in onehot]
    
    def string2array(self, weight):
        w = []
        for s in weight.split(','):
            w.append(float(s))
        return np.array(w)

    def _action_convertor(self, action):
        return [int(a) for a in action]

    def _observation_space(self, featurize_type):
        """Return the observation space matching `featurize_type` ("ppo" or
        "bc"); any other value raises KeyError."""
        spaces_by_type = dict(
            ppo=self.ppo_observation_space,
            bc=self.bc_observation_space,
        )
        return spaces_by_type[featurize_type]

    def _setup_observation_space(self):
        """Create self.ppo_observation_space and self.bc_observation_space.

        Both shapes are taken from the first agent's encoding of the standard
        start state. The "ppo" space is bounded by [0, inf); the "bc" space by
        [0, max(soup cooking time, items per soup, 5)].
        """
        start_state = self.base_env.mdp.get_standard_start_state()

        # ppo observation
        ppo_shape = self.base_mdp.lossless_state_encoding(start_state)[0].shape
        ppo_high = np.ones(ppo_shape) * float("inf")
        ppo_low = np.ones(ppo_shape) * 0
        self.ppo_observation_space = gym.spaces.Box(np.float32(ppo_low), np.float32(ppo_high), dtype=np.float32)

        # bc observation
        bc_shape = self.base_mdp.featurize_state(start_state, self.mlp)[0].shape
        mdp = self.base_env.mdp
        bc_bound = max(mdp.soup_cooking_time, mdp.num_items_for_soup, 5)
        bc_high = np.ones(bc_shape, dtype=np.float32) * bc_bound
        self.bc_observation_space = gym.spaces.Box(bc_high * 0, bc_high, dtype=np.float32)

    def _setup_share_observation_space(self):
        """Build the shared-observation space, bounded by [0, inf).

        Its shape is that of the first agent's "ppo" encoding of the standard
        start state, with one extra channel when agent policy ids are in use
        and the channel count then multiplied by the number of agents.
        """
        start_state = self.base_env.mdp.get_standard_start_state()
        encoding_shape = self.featurize_fn_ppo(start_state)[0].shape
        height, width, channels = encoding_shape[0], encoding_shape[1], encoding_shape[2]
        if self.use_agent_policy_id:
            channels = channels + 1
        dims = [height, width, channels * self.num_agents]

        upper = np.ones(dims) * float("inf")
        lower = np.ones(dims) * 0
        return gym.spaces.Box(np.float32(lower), np.float32(upper), dtype=np.float32)
    
    def _set_agent_policy_id(self, agent_policy_id):
        """Replace the per-agent policy ids. They are indexed by agent in
        _gen_share_observation, where each id fills an extra observation
        channel when use_agent_policy_id is set."""
        self.agent_policy_id = agent_policy_id

    def _gen_share_observation(self, state):
        share_obs = list(self.featurize_fn_ppo(state))
        if self.agent_idx == 1:
            share_obs = [share_obs[1], share_obs[0]]
        if self.use_agent_policy_id:
            for a in range(self.num_agents):
                share_obs[a] = np.concatenate([share_obs[a], np.ones((*share_obs[a].shape[:2], 1), dtype=np.float32) * self.agent_policy_id[a]], axis=-1)
        share_obs0 = np.concatenate([share_obs[0], share_obs[1]], axis=-1) * 255
        share_obs1 = np.concatenate([share_obs[1], share_obs[0]], axis=-1) * 255
        return np.stack([share_obs0, share_obs1], axis=0)

    def step(self, action, test_mode = False):
        """
        action: 
            (agent with index self.agent_idx action, other agent action)
            is a tuple with the joint action of the primary and secondary agents in index format
        
        returns:
            observation: formatted to be standard input for self.agent_idx's policy

        main_agent_index:
            While existing other agent like planning or human model, use an index to fix the main RL-policy agent.
            Default False for multi-agent training.
        """
        self.step_count += 1
        action = self._action_convertor(action)

        assert all(self.action_space[0].contains(a) for a in action), "%r (%s) invalid"%(action, type(action))

        agent_action, other_agent_action = [Action.INDEX_TO_ACTION[a] for a in action]

        joint_action = [agent_action, other_agent_action]
        
        for a in range(self.num_agents):
            if self.script_agent[a] is not None:
                joint_action[a] = self.script_agent[a].step(self.base_env.mdp, self.base_env.state, a)
        joint_action = tuple(joint_action)

        if self.agent_idx == 1:
            joint_action = (other_agent_action, agent_action)

        if self.stuck_time > 0:
            self.history_sa[-1][1] = joint_action
        
        if self.store_traj:
            self.traj_to_store.append(joint_action)
        
        if self.use_phi:
            raise NotImplementedError
            next_state, sparse_reward, done, info = self.base_env.step(joint_action, display_phi=True)
            potential = info['phi_s_prime'] - info['phi_s']
            dense_reward = (potential, potential)
            shaped_reward_p0 = sparse_reward + self.reward_shaping_factor * dense_reward[0]
            shaped_reward_p1 = sparse_reward + self.reward_shaping_factor * dense_reward[1]
        else:
            next_state, sparse_reward, done, info = self.base_env.step(joint_action)
            
            if self.use_hsp:
                shaped_info = info["shaped_info_by_agent"]
                # ["put_onion_on_X", "put_dish_on_X", "put_soup_on_X", "pickup_onion_from_X", "pickup_onion_from_O", "pickup_dish_from_X", "pickup_dish_from_D", "pickup_soup_from_X", "USEFUL_DISH_PICKUP"
                # "SOUP_PICKUP", "PLACEMENT_IN_POT", "delivery", "stay"]
                vec_shaped_info = np.array([[v for k, v in agent_info.items()] for agent_info in shaped_info]).astype(np.float32)
                # add stay
                vec_shaped_info = np.concatenate([vec_shaped_info, np.array([[joint_action[0] == Action.STAY], [joint_action[1] == Action.STAY]]).astype(np.float32)], axis=1)
                assert len(self.w0) == len(self.w1) == 13 + 1
                dense_reward = info["shaped_r_by_agent"]
                if self.agent_idx == 0:
                    hidden_reward = (np.dot(self.w0[:-1], vec_shaped_info[0]) + sparse_reward * self.w0[-1], np.dot(self.w1[:-1], vec_shaped_info[1]) + sparse_reward * self.w1[-1])
                else:
                    hidden_reward = (np.dot(self.w1[:-1], vec_shaped_info[0]) + sparse_reward * self.w1[-1], np.dot(self.w0[:-1], vec_shaped_info[1]) + sparse_reward * self.w0[-1])
                shaped_reward_p0 = hidden_reward[0] + self.reward_shaping_factor * dense_reward[0]
                shaped_reward_p1 = hidden_reward[1] + self.reward_shaping_factor * dense_reward[1]
            else:
                dense_reward = info["shaped_r_by_agent"]
                shaped_reward_p0 = sparse_reward + self.reward_shaping_factor * dense_reward[0]
                shaped_reward_p1 = sparse_reward + self.reward_shaping_factor * dense_reward[1]
        
        stop_punishment = True
        punishment = 0
        if stop_punishment:
            if self.action_old[0] == action[0]:
                self.same_action_count[0] += 1
            else:
                self.same_action_count[0] = 0
            
            if self.action_old[1] == action[1]:
                self.same_action_count[1] += 1
            else:
                self.same_action_count[1] = 0
            
            if self.same_action_count[0] == 6:
                self.same_action_count[0] = 0
                punishment += 3

            if self.same_action_count[1] == 6:
                self.same_action_count[1] = 0
                punishment += 3
            
            pos1, pos2 = next_state.player_positions
            if list(pos1) == list(self.pos1_old):
                self.keep_pos_count[0] += 1
            else:
                self.keep_pos_count[0] = 0
                
            if list(pos2) == list(self.pos2_old):
                self.keep_pos_count[1] += 1
            else:
                self.keep_pos_count[1] = 0

            if self.keep_pos_count[0] == 10:
                self.keep_pos_count[0] = 0
                punishment += 3

            if self.keep_pos_count[1] == 10:
                self.keep_pos_count[1] = 0
                punishment += 3
            
            dist_pos1_pos2 = np.linalg.norm(np.array(pos1) - np.array(pos2)).reshape(-1)
            if dist_pos1_pos2 == 1 and min(self.keep_pos_count[0], self.keep_pos_count[1]) > 0:
                self.congestion_count += 1
            else:
                self.congestion_count = 0
                
            if self.congestion_count == 3:
                self.congestion_count = 0
                punishment += 5
            
            # print('keep:', self.keep_pos_count)
            # print('congestion_count:', self.congestion_count)
            # print('same_action_count', self.same_action_count)
            # print(punishment)
            self.action_old = action
            self.pos1_old, self.pos2_old = pos1, pos2
        
        if self.store_traj:
            self.traj_to_store.append(info["shaped_info_by_agnet"])
            self.traj_to_store.append(self.base_env.state.to_dict())

        reward = [[shaped_reward_p0], [shaped_reward_p1]]
        if self.agent_idx == 1:
            reward = [[shaped_reward_p1], [shaped_reward_p0]]
        
        self.history_sa = self.history_sa[1:] + [[next_state, None],]

        vec_shaped_info_by_agent = self.base_env.vectorize_shaped_info(info["shaped_info_by_agent"])
        vec_shaped_info_by_agent = np.concatenate([vec_shaped_info_by_agent, (np.array(info["sparse_r_by_agent"])>0).astype(np.int32).reshape(2, 1)], axis=-1)
        if self.agent_idx == 1:
            vec_shaped_info_by_agent = np.stack([vec_shaped_info_by_agent[1], vec_shaped_info_by_agent[0]])
        info["vec_shaped_info_by_agent"] = vec_shaped_info_by_agent

        # stuck
        stuck_info = []
        for agent_id in range(2):
            stuck, history_a = self.is_stuck(agent_id)
            if stuck:
                assert any([a not in history_a for a in Direction.ALL_DIRECTIONS]), history_a
                history_a_idxes = [Action.ACTION_TO_INDEX[a] for a in history_a]
                stuck_info.append([True, history_a_idxes])
            else:
                stuck_info.append([False, []])

        info["stuck"] = stuck_info

        # can_begin_cook_soup
        # can_begin_cook_info = [self.mdp.can_begin_cook_soup(next_state, agent_id) for agent_id in range(self.num_agents)]
        # info["can_begin_cook"] = can_begin_cook_info

        if self.use_render:
            state = self.base_env.state
            self.traj["ep_states"][0].append(state)
            self.traj["ep_actions"][0].append(joint_action)
            self.traj["ep_rewards"][0].append(sparse_reward)
            self.traj["ep_dones"][0].append(done)
            self.traj["ep_infos"][0].append(info)
            if done:
                self.traj['ep_returns'].append(info['episode']['ep_sparse_r'])
                self.traj["mdp_params"].append(self.base_mdp.mdp_params)
                self.traj["env_params"].append(self.base_env.env_params)
                self.render()
        
        if done and self.store_traj:
            self._store_trajectory()

        ob_p0, ob_p1 = self.featurize_fn_bc(self.base_env.state)
        array_obs = self.featurize_fn_ppo(self.base_env.state)
        if self.agent_idx == 1:
            array_obs = [array_obs[1], array_obs[0]]
        array_obs = [array_obs[0]*255,array_obs[1]*255]
        
        both_agents_ob = (ob_p0, ob_p1)
        if self.agent_idx == 1:
            both_agents_ob = (ob_p1, ob_p0)

        # share_obs = self._gen_share_observation(self.base_env.state)
        done = [done, done]
        available_actions = np.ones((2, len(Action.ALL_ACTIONS)), dtype=np.uint8)
        
        #return both_agents_ob, share_obs, reward, done, info, available_actions
        
        reward = np.sum(reward) - punishment
        terminated = np.all(done)
        if terminated and self.use_render:
            write_gif(
                gif_name=os.path.join(self.run_dir, 'save.gif'), 
                pic_path=os.path.join(self.run_dir, 'pic/'),
                duration = 0.3,
            )
        self.marl_obs = self._get_flatten_obs(both_agents_ob, array_obs)
        # self.marl_state = self._get_flatten_state(share_obs)
        if test_mode == True:
            # if sparse_reward >0:
            #    print('sparse_reward:',sparse_reward)
            return sparse_reward, terminated, {}
        else:
            return reward, terminated, {}

    def _get_flatten_obs(self, obs, mat_obs):
        obs_array = np.array(obs)
        obs_mat = np.array(mat_obs)
        self._obs = []
        #obs_dim = self.get_obs_size()
        for key in range(self.num_agents):
            #if obs_array[key].ndim==1 :
            #    #self._obs.append(obs_array[key])
            #    self._obs.append(np.pad(array=obs_array[key], pad_width=(0, self.diff_size), mode='constant', constant_values=(0, 0)))
            #else:
            all_obs = np.concatenate((obs_array[key], mat_obs[key].flatten()))
            #print('all_obs:',all_obs.shape)
            self._obs.append(all_obs)
        return self._obs

    def _get_flatten_state(self, state):
        state_array = np.array(state)
        self._state = []
        self._state.append(state_array[0].flatten())
        return self._state
        
    def get_obs(self):
        """ Returns all agent observations in a list """
        return self.marl_obs

    def get_obs_agent(self, agent_id):
        """ Returns observation for agent_id """
        return self.marl_obs[int(agent_id)]
    
    def get_vector_size(self):
        dummy_state = self.base_env.mdp.get_standard_start_state()
        obs_vector_shape = self.featurize_fn_bc(dummy_state)[0].flatten().shape[0]
        return obs_vector_shape

    def get_array_size(self):
        dummy_state = self.base_env.mdp.get_standard_start_state()
        obs_array_shape = self.featurize_fn_ppo(dummy_state)[0].flatten().shape[0]
        return obs_array_shape                                                                  

    def get_obs_size(self):
        """ Returns the shape of the observation """
        
        return self.get_vector_size()+self.get_array_size()

    def get_state(self):
        #return self.marl_state
        # TODO: another definition for state
        obs = self.featurize_fn_bc(self.base_env.state)
        return [*obs[0], *obs[1]]
    
    def get_state_size(self):
        """ Returns the shape of the state"""
        dummy_state = self.base_env.mdp.get_standard_start_state()
        obs_shape1 = self.featurize_fn(dummy_state)[0].flatten().shape[0]
        obs_shape2 = self.featurize_fn(dummy_state)[1].flatten().shape[0]
        return obs_shape1+obs_shape2
        
    def get_avail_actions(self):
        """Return an all-ones ``uint8`` mask of shape ``(2, number of actions)``.

        Every action in ``Action.ALL_ACTIONS`` is marked as available for
        both rows.
        """
        mask_shape = (2, len(Action.ALL_ACTIONS))
        return np.ones(mask_shape, dtype=np.uint8)
        
    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id """
        return self.get_avail_actions()[0]
        
    def get_total_actions(self):
        """Return the number of entries in ``Action.ALL_ACTIONS``."""
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        action_count = len(Action.ALL_ACTIONS)
        return action_count
        
    # def seed(self):
        # return np.random.seed(seed)
        
    def close(self):
        pass
        # self.base_env.close()
        
    def save_replay(self):
        pass
        

    def anneal_reward_shaping_factor(self, timesteps):
        """
        Set the current reward shaping factor such that we anneal linearly until self.reward_shaping_horizon
        timesteps, given that we are currently at timestep "timesteps"
        """
        new_factor = self._anneal(self._initial_reward_shaping_factor, timesteps, self.reward_shaping_horizon)
        #print('new_factor:',new_factor)
        self.set_reward_shaping_factor(new_factor)

    def set_reward_shaping_factor(self, factor):
        """Store ``factor`` as the current reward shaping factor."""
        self.reward_shaping_factor = factor
        # if self.reward_shaping_factor < 0.1:
        #     self.use_hsp = True

    def reset(self, reset_choose = True):
        """
        When training on individual maps, we want to randomize which agent is assigned to which
        starting location, in order to make sure that the agents are trained to be able to 
        complete the task starting at either of the hardcoded positions.

        NOTE: a nicer way to do this would be to just randomize starting positions, and not
        have to deal with randomizing indices.
        """

        if reset_choose:
            #self.base_env.seed()
            self.traj_num += 1
            self.step_count = 0
            self.base_env.reset()

        if self.random_index: # False
            self.agent_idx = np.random.choice([0,1])
        for a in range(self.num_agents):
            if self.script_agent[a] is not None:
                self.script_agent[a].reset(self.base_env.mdp, self.base_env.state, a)

        self.mdp = self.base_env.mdp
        ob_p0, ob_p1 = self.featurize_fn(self.base_env.state)
        array_obs = self.featurize_fn_ppo(self.base_env.state)
        if self.agent_idx == 1:
            array_obs = [array_obs[1], array_obs[0]]
        array_obs = [array_obs[0]*255,array_obs[1]*255]
        if self.stuck_time > 0:
            self.history_sa = [None for _ in range(self.stuck_time - 1)] + [[self.base_env.state, None]]

        both_agents_ob = (ob_p0, ob_p1) 
        if self.agent_idx == 1:
            both_agents_ob = (ob_p1, ob_p0)

        if self.use_render:
            self.init_traj()
            # self.fake_render()
        
        if self.store_traj:
            self.traj_to_store = []
            self.traj_to_store.append(self.base_env.state.to_dict())

        available_actions = np.ones((2, len(Action.ALL_ACTIONS)), dtype=np.uint8)

        self.marl_obs = self._get_flatten_obs(both_agents_ob, array_obs)
        self.marl_state = self.get_state()
        return self.marl_obs, self.marl_state

    def is_stuck(self, agent_id):
        if self.stuck_time == 0 or None in self.history_sa:
            return False, []
        history_s = [sa[0] for sa in self.history_sa]
        history_a = [sa[1][agent_id] for sa in self.history_sa[:-1]] # last action is None
        player_s = [s.players[agent_id] for s in history_s]
        pos_and_ors = [p.pos_and_or for p in player_s]
        cur_po = pos_and_ors[-1]
        if all([po[0] == cur_po[0] and po[1] == cur_po[1] for po in pos_and_ors]):
            return True, history_a
        return False, []

    def init_traj(self):
        """Start a fresh trajectory record in ``self.traj``.

        Every key in ``DEFAULT_TRAJ_KEYS`` maps to an empty list; each key in
        ``TIMESTEP_TRAJ_KEYS`` additionally gets one empty inner list.
        """
        self.traj = dict((key, []) for key in DEFAULT_TRAJ_KEYS)
        for timestep_key in TIMESTEP_TRAJ_KEYS:
            self.traj[timestep_key].append([])

    def render(self):
        """Render the recorded trajectory ``self.traj`` into ``<run_dir>/pic``.

        Rendering is delegated to
        ``StateVisualizer.display_rendered_trajectory`` with IPython display
        turned off.
        """
        pic_dir = os.path.join(self.run_dir, 'pic')
        visualizer = StateVisualizer()
        visualizer.display_rendered_trajectory(
            self.traj,
            img_directory_path=pic_dir,
            ipython_display=False,
        )


    def fake_render(self):
        """Draw the current state as a character grid and save it as a PNG frame.

        Each cell of ``self.base_mdp.terrain_mtx`` is drawn as matplotlib text:
        a player as its orientation character plus the first letter of the
        held object (or the player index when empty-handed), a counter ("X")
        or pot ("P") holding an object with a suffix describing that object,
        and any other cell as its raw terrain character.

        The frame is written to
        ``<run_dir>/gifs/<layout_name>/traj_num_<traj_num>/step_<step_count>.png``.
        When ``self.step_count == self.episode_length`` the frames
        ``step_0.png`` .. ``step_<episode_length>.png`` are assembled into
        ``traj.gif`` in the same directory and the PNG frames are deleted.

        Returns:
            str: text rendering of the grid (one line per terrain row),
            followed by an order summary when ``state.order_list`` is set.

        Raises:
            ValueError: if a pot holds a soup that is neither "onion" nor
                "tomato".
        """
        state = self.base_env.state
        mdp = self.base_mdp
        terrain = mdp.terrain_mtx
        num_rows = len(terrain)
        players_dict = {player.position: player for player in state.players}

        plt.cla()
        plt.clf()
        plt.axis([0, len(terrain[0]), 0, num_rows])

        grid_rows = []
        for y, terrain_row in enumerate(terrain):
            row_text = ""
            for x, element in enumerate(terrain_row):
                pos = (x, y)
                if pos in players_dict:
                    player = players_dict[pos]
                    orientation = player.orientation
                    assert orientation in Direction.ALL_DIRECTIONS

                    plt_str = Action.ACTION_TO_CHAR[orientation]
                    player_object = player.held_object
                    if player_object:
                        plt_str += player_object.name[:1]
                    else:
                        player_idx_lst = [i for i, p in enumerate(state.players) if p.position == player.position]
                        assert len(player_idx_lst) == 1
                        plt_str += str(player_idx_lst[0])
                    cell_text = plt_str
                elif element == "X" and state.has_object(pos):
                    state_obj = state.get_object(pos)
                    plt_str = element + state_obj.name[:1]
                    cell_text = plt_str
                elif element == "P" and state.has_object(pos):
                    soup_obj = state.get_object(pos)
                    soup_type, num_items, cook_time = soup_obj.state
                    if soup_type == "onion":
                        plt_str = "ø"
                    elif soup_type == "tomato":
                        plt_str = "†"
                    else:
                        raise ValueError()

                    if num_items == mdp.num_items_for_soup:
                        plt_str += str(cook_time)
                    # NOTE: do not currently have terminal graphics
                    # support for cooking times greater than 3.
                    elif num_items == 2:
                        plt_str += "="
                    else:
                        plt_str += "-"
                    cell_text = plt_str
                else:
                    # Plain terrain: the text grid pads with a space, the plot does not.
                    plt_str = element
                    cell_text = element + " "

                row_text += cell_text
                plt.text(x + 0.5, num_rows - y - 0.5, plt_str, ha='center', fontsize=30, alpha=0.4)
            grid_rows.append(row_text + "\n")
        grid_string = "".join(grid_rows)

        if state.order_list is not None:
            # Count the orders that really are "any"; len() of the boolean
            # list counted every order.
            num_any = sum(order == "any" for order in state.order_list)
            grid_string += "Current orders: {}/{} are any's\n".format(
                num_any, len(state.order_list)
            )

        save_dir = f'{self.run_dir}/gifs/{self.layout_name}/traj_num_{self.traj_num}'
        os.makedirs(save_dir, exist_ok=True)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(f"{save_dir}/step_{self.step_count}.png")
        if self.step_count == self.episode_length:
            imgs = [
                imageio.imread(f"{save_dir}/step_{s}.png")
                for s in range(self.episode_length + 1)
            ]
            imageio.mimsave(f"{save_dir}/traj.gif", imgs, duration=0.1)
            # Match on the file name's extension only, so a run directory whose
            # path happens to contain "png" cannot cause traj.gif to be removed.
            for file_name in os.listdir(save_dir):
                if file_name.endswith('.png'):
                    os.remove(os.path.join(save_dir, file_name))

        return grid_string
    
    def _store_trajectory(self):
        if not os.path.exists(f'{self.run_dir}/trajs/{self.layout_name}/'):
            os.makedirs(f'{self.run_dir}/trajs/{self.layout_name}/')
        save_dir = f'{self.run_dir}/trajs/{self.layout_name}/traj_{self.rank}_{self.traj_num}.pkl'
        pickle.dump(self.traj_to_store, open(save_dir, 'wb'))

    def get_env_info(self):
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.num_agents,
                    "episode_limit": self.episode_length,
                    "vector_obsSize": self.get_vector_size(),
                    "array_obsSize": self.get_array_size(),}
        return env_info

    def get_stats(self):
        return {}

if __name__ == '__main__':
    # Smoke test: build the environment, reset it and take one joint step.
    env = Overcooked(r'random3', seed=3)
    initial_obs, initial_state = env.reset()
    obs_shape = np.array(initial_obs).shape
    state_shape = np.array(initial_state).shape
    print('both_agents_ob, share_obs, available_actions:', obs_shape, state_shape)

    joint_action = [1, 3]
    reward, done, info = env.step(joint_action)
    print(env.get_env_info())
    print(env.get_avail_actions())
    np.random.seed(np.random.randint(10000))

    # Disabled random rollout, kept for manual debugging:
    # env.init_traj()
    # for i in range(500):
    #     actions = [np.random.randint(6) for _ in range(2)]
    #     reward, done, info = env.step(actions)
    #     print(i, reward, done, actions)
    #     env.fake_render()
    #     if done == True:
    #         env.reset()