import gym
import numpy as np
# from overcooked_ai_py.mdp.actions import Action
# from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
# from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
# from overcooked_ai_py.planning.planners import MediumLevelPlanner, NO_COUNTERS_PARAMS

from overcooked_ai_py.mdp.actions import Action
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.planning.planners import MediumLevelActionManager, NO_COUNTERS_PARAMS

from pantheonrl.common.multiagentenv import SimultaneousEnv

from overcookedgym.overcooked_utils import DIVERSE_COORPERATION_STYLE_LIST, DIVERSE_ORDERS_STYLE_LIST, CENTER_POTS_STYLE_LIST, CROSSWAY_STYLE_LIST


class OvercookedMultiEnv(SimultaneousEnv):
    """Two-player simultaneous-move wrapper around an Overcooked-AI environment.

    The ego agent controls player ``ego_agent_idx`` of the underlying
    ``OvercookedGridworld``; the partner ("alt") agent controls the other
    player. Observations are per-player feature vectors produced by
    ``OvercookedGridworld.featurize_state`` with a ``MediumLevelActionManager``.
    """

    def __init__(self, layout_name, ego_agent_idx=0, baselines=False, masked_events=None, start_all_orders=None, instruction=None):
        """Build the MDP, the base environment and the gym spaces.

        Args:
            layout_name: name of the Overcooked layout to load.
            ego_agent_idx: player index (0 or 1) controlled by the ego agent.
            baselines: if True, seed numpy's global RNG with 0.
            masked_events: forwarded to ``OvercookedGridworld.from_layout_name``
                and stored on the instance; when truthy, no layout-specific
                style list is selected (``self.style_list`` is None).
            start_all_orders: forwarded to ``OvercookedGridworld.from_layout_name``.
            instruction: stored as ``self.instruction``.
        """
        super().__init__()

        DEFAULT_ENV_PARAMS = {
            "horizon": 400
        }
        rew_shaping_params = {
            "PLACEMENT_IN_POT_REW": 3,
            "DISH_PICKUP_REWARD": 3,
            "SOUP_PICKUP_REWARD": 5,
            "DISH_DISP_DISTANCE_REW": 0,
            "POT_DISTANCE_REW": 0,
            "SOUP_DISTANCE_REW": 0,
        }

        self.mdp = OvercookedGridworld.from_layout_name(
            layout_name=layout_name,
            rew_shaping_params=rew_shaping_params,
            masked_events=masked_events,
            start_all_orders=start_all_orders,
        )
        mlam = MediumLevelActionManager.from_pickle_or_compute(self.mdp, NO_COUNTERS_PARAMS, force_compute=False)

        self.base_env = OvercookedEnv.from_mdp(self.mdp, **DEFAULT_ENV_PARAMS)
        # featurize_fn(state) yields one feature array per player (player 0, player 1).
        self.featurize_fn = lambda x: self.mdp.featurize_state(x, mlam)

        if baselines:
            np.random.seed(0)

        self.observation_space = self._setup_observation_space()
        self.lA = len(Action.ALL_ACTIONS)
        self.action_space = gym.spaces.Discrete(self.lA)
        self.ego_agent_idx = ego_agent_idx
        # multi_reset also initializes self.old_state to the start state.
        self.multi_reset()

        self.masked_events = masked_events
        self.instruction = instruction
        # Only fall back to the layout's predefined style list when no
        # masked events were specified explicitly.
        self.style_list = None if self.masked_events else self._default_style_list(layout_name)

    @staticmethod
    def _default_style_list(layout_name):
        """Return the predefined style list for ``layout_name``, or None if there is none.

        Prints which list was selected. "big_room" reuses the
        diverse_coordination list.
        """
        layout_styles = {
            "diverse_coordination": ("diverse_coordination", DIVERSE_COORPERATION_STYLE_LIST),
            "big_room": ("diverse_coordination", DIVERSE_COORPERATION_STYLE_LIST),
            "diverse_orders": ("diverse_orders", DIVERSE_ORDERS_STYLE_LIST),
            "center_pots": ("center_pots", CENTER_POTS_STYLE_LIST),
            "crossway": ("crossway", CROSSWAY_STYLE_LIST),
        }
        if layout_name not in layout_styles:
            return None
        style_name, style_list = layout_styles[layout_name]
        print(f"Using {style_name} style list")
        return style_list

    def set_instruction_list(self, instruction_list):
        """Store ``instruction_list`` on the instance."""
        self.instruction_list = instruction_list

    def _featurized_obs_shape(self):
        """Shape of one player's featurized observation, taken from the standard start state."""
        dummy_state = self.mdp.get_standard_start_state()
        return self.featurize_fn(dummy_state)[0].shape

    @staticmethod
    def _unbounded_box(shape):
        """Return a float64 Box of ``shape`` bounded by -inf/+inf in every dimension."""
        high = np.ones(shape, dtype=np.float32) * np.inf
        return gym.spaces.Box(-high, high, dtype=np.float64)

    def _setup_observation_space(self):
        """Observation space for plain featurized observations."""
        return self._unbounded_box(self._featurized_obs_shape())

    def _setup_observation_space_with_action(self):
        """Observation space for features plus two appended one-hot action vectors.

        Assumes the featurized observation is 1-D (only ``shape[0]`` is used).
        """
        obs_shape = self._featurized_obs_shape()
        action_shape = len(Action.ALL_ACTIONS)
        # Own action + partner action, each one-hot over all actions.
        return self._unbounded_box((obs_shape[0] + action_shape * 2,))

    def _split_obs(self, state):
        """Featurize ``state`` and return ``(ego_obs, alt_obs)`` according to ``ego_agent_idx``."""
        ob_p0, ob_p1 = self.featurize_fn(state)
        if self.ego_agent_idx == 0:
            return ob_p0, ob_p1
        return ob_p1, ob_p0

    def multi_step(self, ego_action, alt_action):
        """Advance the base environment one step with both agents' actions.

        Args:
            ego_action: int index into ``Action.INDEX_TO_ACTION`` for the ego agent.
            alt_action: int index into ``Action.INDEX_TO_ACTION`` for the partner.

        Returns:
            ``((ego_obs, alt_obs), (ego_reward, alt_reward), done, {})``. Each
            reward is the environment reward plus that agent's own entry of
            ``info['shaped_r_by_agent']``. The base env's info is not forwarded.
        """
        ego_action = Action.INDEX_TO_ACTION[ego_action]
        alt_action = Action.INDEX_TO_ACTION[alt_action]
        ego_idx = self.ego_agent_idx
        alt_idx = 1 - ego_idx
        # The base env expects actions ordered by player index.
        if ego_idx == 0:
            joint_action = (ego_action, alt_action)
        else:
            joint_action = (alt_action, ego_action)

        next_state, reward, done, info = self.base_env.step(joint_action)
        self.old_state = next_state

        # shaped_r_by_agent is ordered by player index (like joint_action), so
        # the ego's share lives at ego_agent_idx -- not always at position 0.
        shaped_by_agent = info['shaped_r_by_agent']
        ego_reward = reward + shaped_by_agent[ego_idx]
        alt_reward = reward + shaped_by_agent[alt_idx]
        # Note: the evaluator does not include the shaped reward.

        ego_obs, alt_obs = self._split_obs(next_state)

        if False:  # TODO: set right condition (enables action-augmented observations)
            ego_action_idx = Action.ACTION_TO_INDEX[ego_action]
            alt_action_idx = Action.ACTION_TO_INDEX[alt_action]
            ego_obs = np.concatenate([ego_obs, np.eye(self.lA)[ego_action_idx], np.eye(self.lA)[alt_action_idx]])
            alt_obs = np.concatenate([alt_obs, np.eye(self.lA)[alt_action_idx], np.eye(self.lA)[ego_action_idx]])

        return (ego_obs, alt_obs), (ego_reward, alt_reward), done, {}

    def multi_reset(self):
        """Reset the base environment and return ``(ego_obs, alt_obs)``.

        The ego/partner assignment to players is fixed by ``ego_agent_idx``;
        no randomization of indices or start positions is done here. Also
        resets ``self.old_state`` to the new start state.
        """
        self.base_env.reset()
        self.old_state = self.base_env.state
        return self._split_obs(self.base_env.state)

    def multi_reset_with_action(self):
        """Reset and return observations with zero-filled action slots appended.

        Matches the layout of ``_setup_observation_space_with_action``:
        features, then own action, then partner action (all zeros at reset).
        """
        ego_obs, alt_obs = self.multi_reset()
        no_action = np.zeros(self.lA)
        ego_obs = np.concatenate([ego_obs, no_action, no_action])
        alt_obs = np.concatenate([alt_obs, no_action, no_action])
        return (ego_obs, alt_obs)

    def render(self, mode='human', close=False):
        """Rendering is not implemented; this is a no-op."""
        pass
