import gym
import numpy as np
import importlib
from pantheonrl.common.multiagentenv import SimultaneousEnv, DummyEnv


class AssistiveMultiEnv(SimultaneousEnv):
    """Two-player simultaneous-move wrapper around an Assistive Gym env.

    Player 0 is the robot: this env's own ``observation_space`` and
    ``action_space``. Player 1 is the human, exposed through
    ``self.human_env``; see :meth:`getDummyEnv`.
    """

    def __init__(self, env_name, ego_agent_idx=0, baselines=False, masked_events=None):
        """
        env_name: id of the form ``'<Prefix>-<version>'``. ``<Prefix> + 'Env'``
            is looked up in ``assistive_gym.envs`` and instantiated with no
            arguments.
        ego_agent_idx: stored on ``self.ego_agent_idx``; not otherwise used
            in this class.
        baselines: if true, seeds NumPy's global RNG with 0 before the base
            env is constructed.
        masked_events: accepted for interface compatibility; unused here.
        """
        super().__init__()

        if baselines:
            # Seeds the *global* NumPy RNG, so it affects any other code using it.
            np.random.seed(0)
        module = importlib.import_module('assistive_gym.envs')
        env_class = getattr(module, env_name.split('-')[0] + 'Env')
        self.base_env = env_class()

        # The robot's spaces are this env's own spaces (player 0)...
        self.observation_space = self.base_env.observation_space_robot
        self.action_space = self.base_env.action_space_robot

        # ...while the human (player 1) is represented by a DummyEnv carrying
        # only its spaces.
        self.human_observation_space = self.base_env.observation_space_human
        self.human_action_space = self.base_env.action_space_human
        self.human_env = DummyEnv(self.human_observation_space, self.human_action_space)

        self.ego_agent_idx = ego_agent_idx
        self.multi_reset()

    def getDummyEnv(self, player_ind: int):
        """Return ``self`` for player 0 (robot) and ``self.human_env`` otherwise."""
        return self.human_env if player_ind else self

    @staticmethod
    def _split_obs(obs):
        """Return ``(robot_obs, human_obs)`` from a joint observation.

        Accepts a dict keyed by ``'robot'``/``'human'`` (the same keys used for
        actions, rewards and dones), or an already-ordered
        ``(robot_obs, human_obs)`` pair.
        """
        if isinstance(obs, dict):
            # Unpacking a dict directly would yield its keys, not its values.
            return obs['robot'], obs['human']
        robot_obs, human_obs = obs
        return robot_obs, human_obs

    def multi_step(self, robot_action, human_action):
        """
        Advance the base env one step with both agents acting at once.

        robot_action, human_action: actions for the robot and the human; they
            are forwarded to the base env as ``{'robot': ..., 'human': ...}``.

        returns:
            ``((robot_obs, human_obs), (robot_reward, human_reward), done, info)``
            where ``done`` is the base env's ``done['__all__']`` flag and
            ``info`` is passed through unchanged.
        """
        joint_action = {'robot': robot_action, 'human': human_action}

        obs, reward, done, info = self.base_env.step(joint_action)

        robot_obs, human_obs = self._split_obs(obs)
        robot_reward, human_reward = reward['robot'], reward['human']
        # Both agents are expected to receive the same reward from the base env.
        assert robot_reward == human_reward, (robot_reward, human_reward)

        return (robot_obs, human_obs), (robot_reward, human_reward), done['__all__'], info

    def multi_reset(self):
        """Reset the base env and return ``(robot_obs, human_obs)``."""
        return self._split_obs(self.base_env.reset())

    def render(self, mode='human', close=False):
        """No-op: rendering is not implemented by this wrapper."""
        pass
