from multiagent_mujoco.mujoco_multi import MujocoMulti
import numpy as np
import time
from .. import MultiAgentEnv


class Humanoid_17x1(MultiAgentEnv):
    """Discrete-action wrapper around the ``MujocoMulti`` Humanoid-v2 task.

    The wrapped environment is created with ``agent_conf="17x1"``.  Every
    agent picks one of ``categories + 1`` discrete action indices, which
    :meth:`step` maps linearly onto a one-element continuous action in
    ``[-1, 1]``.  Rewards, observations and the global state are rescaled by
    fixed constants before being handed back to the caller.
    """

    # Divisor applied to the reward returned by the wrapped environment.
    _REWARD_SCALE = 30
    # Entries of each agent observation that get rescaled, and their divisor.
    _OBS_SCALED_ENTRIES = slice(1, 3)
    _OBS_SCALE = 100
    # Ordered (index-or-slice, divisor) pairs applied to the global state.
    # The order matters: indices 280 and 284 also fall inside the final
    # 275:292 slice, so they end up divided by 10 and then by 20.
    # NOTE(review): confirm that this double scaling is intended.
    _STATE_SCALES = (
        (0, 20),
        (6, 20),
        (10, 20),
        (11, 20),
        (slice(14, 19), 20),
        (slice(20, 22), 20),
        (slice(24, 45), 20),
        (slice(55, 57), 20),
        (61, 20),
        (slice(63, 65), 20),
        (74, 20),
        (81, 20),
        (84, 20),
        (slice(93, 95), 20),
        (slice(103, 107), 20),
        (slice(113, 115), 20),
        (124, 20),
        (134, 20),
        (144, 20),
        (154, 20),
        (164, 20),
        (174, 20),
        (184, 20),
        (slice(191, 269), 50),
        (280, 10),
        (284, 10),
        (slice(275, 292), 20),
    )

    def __init__(
        self,
        episode_limit=1000,
        env_name='Humanoid_17x1',
        agent_obsk=1,
        seed=0,
        categories=30
    ):
        """Create and seed the wrapped ``MujocoMulti`` environment.

        Args:
            episode_limit: Forwarded to ``MujocoMulti`` as ``episode_limit``.
            env_name: Stored on the instance; not passed to the wrapped env.
            agent_obsk: Forwarded to ``MujocoMulti`` as ``agent_obsk``.
            seed: Passed to the wrapped environment's ``seed`` method.
            categories: Number of intervals the ``[-1, 1]`` action range is
                split into; each agent has ``categories + 1`` actions.  Must
                be non-zero, since :meth:`step` divides by it.
        """
        self.episode_limit = episode_limit
        self.env_name = env_name
        self.agent_obsk = agent_obsk
        # NOTE: this instance attribute shadows the ``seed()`` method defined
        # on the class, so ``instance.seed`` is this value, not a callable.
        # Kept as-is because callers may read the attribute.
        self.seed = seed
        self.categories = categories

        env_args = {"scenario": "Humanoid-v2",
                    "agent_conf": "17x1",
                    "agent_obsk": self.agent_obsk,
                    "episode_limit": self.episode_limit}
        self.env = MujocoMulti(env_args=env_args)
        self.env.seed(self.seed)

        env_info = self.env.get_env_info()

        self.n_actions = self.categories + 1
        self.n_agents = env_info["n_agents"]

    @classmethod
    def _scale_obs(cls, obs):
        """Rescale the configured entries of one agent observation in place."""
        entries = cls._OBS_SCALED_ENTRIES
        obs[entries] = obs[entries] / cls._OBS_SCALE
        return obs

    def step(self, actions):
        """Apply one discrete action per agent.

        Args:
            actions: One action index per agent.  A torch tensor (moved to
                the CPU first), a NumPy array or a plain sequence is accepted.

        Returns:
            A ``(reward, terminated, info)`` tuple where ``reward`` is the
            wrapped environment's reward divided by 30 and ``info`` is always
            ``None`` (the wrapped environment's info is discarded).
        """
        # Torch tensors: drop autograd tracking (``.numpy()`` refuses tensors
        # that require grad) and move to host memory before converting.
        if hasattr(actions, "detach"):
            actions = actions.detach()
        if hasattr(actions, "to"):
            actions = actions.to('cpu').numpy()
        indices = np.asarray(actions).tolist()

        # Index 0 maps to -1 and index ``categories`` maps to +1.
        continuous_actions = [
            [-1 + index * 2 / self.categories] for index in indices
        ]
        reward, terminated, _ = self.env.step(continuous_actions)
        return reward / self._REWARD_SCALE, terminated, None

    def get_obs(self):
        """Return all agent observations with entries 1:3 divided by 100."""
        observations = self.env.get_obs()
        for agent_obs in observations:
            self._scale_obs(agent_obs)
        return observations

    def get_obs_agent(self, agent_id):
        """Return agent_id's observation with entries 1:3 divided by 100."""
        return self._scale_obs(self.env.get_obs_agent(agent_id))

    def get_obs_size(self):
        """Return the observation size reported by the wrapped environment."""
        return self.env.get_obs_size()

    def get_state(self):
        """Return the wrapped environment's global state after rescaling.

        The divisors in ``_STATE_SCALES`` are applied in order to the object
        returned by the wrapped environment, which is modified in place.
        """
        state = self.env.get_state()
        for entries, divisor in self._STATE_SCALES:
            state[entries] = state[entries] / divisor
        return state

    def get_state_size(self):
        """Return the state size reported by the wrapped environment."""
        return self.env.get_state_size()

    def get_avail_actions(self):
        """Return the availability mask of every agent (all actions are 1)."""
        return [[1] * self.n_actions for _ in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Return the availability mask for agent_id."""
        # Indexing the full mask keeps the IndexError for an invalid agent_id.
        return self.get_avail_actions()[agent_id]

    def get_total_actions(self):
        """Return the number of discrete actions per agent."""
        return self.categories + 1

    def reset(self):
        """Reset the wrapped environment and return (observations, state)."""
        self.env.reset()
        return self.get_obs(), self.get_state()

    def render(self):
        """No-op: rendering is not implemented by this wrapper."""
        pass

    def close(self):
        """No-op: the wrapped environment is not closed here."""
        pass

    def seed(self):
        """No-op; shadowed on instances by the ``seed`` attribute."""
        pass

    def save_replay(self):
        """Save a replay (no-op)."""
        pass

