from multiagent_mujoco.mujoco_multi import MujocoMulti
import numpy as np
import time
from .. import MultiAgentEnv


class HalfCheetah_6x1(MultiAgentEnv):
    """Discrete-action wrapper around the "6x1" HalfCheetah MujocoMulti task.

    Each agent picks an integer action index in ``[0, categories]``; ``step``
    maps it linearly onto a continuous value in ``[-1, 1]`` before forwarding
    it to the wrapped :class:`MujocoMulti` environment. Rewards, observations
    and the global state are rescaled by the fixed constants defined below.
    """

    # Fixed rescaling constants applied to raw MujocoMulti outputs.
    # NOTE(review): the physical meaning of the rescaled entries is not visible
    # from this file; the values presumably keep inputs/rewards in a similar
    # numeric range -- confirm before changing them.
    _REWARD_SCALE = 5        # reward returned by step() is divided by this
    _OBS_ENTRY = 1           # index rescaled inside each agent observation
    _OBS_ENTRY_SCALE = 10
    _STATE_TAIL_LEN = 7      # number of trailing global-state entries rescaled
    _STATE_TAIL_SCALE = 10
    _STATE_ENTRY = 1         # single global-state index rescaled (after the tail)
    _STATE_ENTRY_SCALE = 5

    def __init__(
        self,
        episode_limit=1000,
        env_name='HalfCheetah_6x1',
        agent_obsk=1,
        seed=0,
        categories=30
    ):
        """Create and seed the wrapped MujocoMulti environment.

        Args:
            episode_limit: Episode step limit, forwarded to MujocoMulti via
                ``env_args["episode_limit"]``.
            env_name: Label stored on the instance; not passed to MujocoMulti
                (the scenario is fixed to ``"HalfCheetah-v2"``).
            agent_obsk: Forwarded to MujocoMulti via ``env_args["agent_obsk"]``.
            seed: Seed handed to the wrapped environment's ``seed``.
            categories: Number of discretisation intervals; every agent has
                ``categories + 1`` discrete actions.

        Raises:
            ValueError: If ``categories`` is less than 1, since ``step``
                divides by it.
        """
        if categories < 1:
            raise ValueError(f"categories must be >= 1, got {categories!r}")

        self.episode_limit = episode_limit
        self.env_name = env_name
        self.agent_obsk = agent_obsk
        # Stored under a private name: assigning ``self.seed`` would shadow the
        # ``seed()`` method on the instance and make ``env.seed()`` uncallable.
        self._seed = seed
        self.categories = categories

        env_args = {
            "scenario": "HalfCheetah-v2",
            "agent_conf": "6x1",
            "agent_obsk": self.agent_obsk,
            "episode_limit": self.episode_limit,
        }
        self.env = MujocoMulti(env_args=env_args)
        self.env.seed(self._seed)

        env_info = self.env.get_env_info()
        self.n_actions = self.categories + 1
        self.n_agents = env_info["n_agents"]

    def step(self, actions):
        """Advance the wrapped environment by one joint step.

        Args:
            actions: One integer action index per agent, each in
                ``[0, categories]``. Either a tensor exposing ``.to('cpu')``
                and ``.numpy()`` (presumably torch, on any device) or any
                array-like accepted by ``np.asarray``.

        Returns:
            ``(reward, terminated, info)``: the wrapped env's reward divided
            by ``_REWARD_SCALE``, its ``terminated`` flag unchanged, and
            ``None`` for ``info`` (the wrapped env's info is discarded).
        """
        if hasattr(actions, "to"):
            # Tensors may live on an accelerator; bring them to host memory.
            actions = actions.to('cpu').numpy()
        indices = np.asarray(actions).tolist()
        # Index k in [0, categories] maps linearly onto [-1, 1]; each agent's
        # continuous action is wrapped in a one-element list.
        continuous_actions = [[-1 + idx * 2 / self.categories] for idx in indices]
        reward, terminated, _ = self.env.step(continuous_actions)
        return reward / self._REWARD_SCALE, terminated, None

    def _scale_agent_obs(self, obs):
        """Rescale entry ``_OBS_ENTRY`` of one agent observation in place; return it."""
        obs[self._OBS_ENTRY] = obs[self._OBS_ENTRY] / self._OBS_ENTRY_SCALE
        return obs

    def get_obs(self):
        """Return all agent observations, each rescaled in place.

        The container returned by the wrapped env is mutated and returned
        as-is, so its type (list of arrays or 2-D array) is preserved.
        """
        origin_obs = self.env.get_obs()
        for agent_obs in origin_obs:
            self._scale_agent_obs(agent_obs)
        return origin_obs

    def get_obs_agent(self, agent_id):
        """Return the rescaled observation for ``agent_id``."""
        return self._scale_agent_obs(self.env.get_obs_agent(agent_id))

    def get_obs_size(self):
        """Return the per-agent observation size reported by the wrapped env."""
        return self.env.get_obs_size()

    def get_state(self):
        """Return the global state with fixed entries rescaled in place.

        The trailing ``_STATE_TAIL_LEN`` entries are divided first, then entry
        ``_STATE_ENTRY``. The slice division assumes the state is a numpy
        array (a plain list would not support ``/``).
        """
        state = self.env.get_state()
        tail = -self._STATE_TAIL_LEN
        state[tail:] = state[tail:] / self._STATE_TAIL_SCALE
        state[self._STATE_ENTRY] = state[self._STATE_ENTRY] / self._STATE_ENTRY_SCALE
        return state

    def get_state_size(self):
        """Return the global state size reported by the wrapped env."""
        return self.env.get_state_size()

    def get_avail_actions(self):
        """Return an availability mask per agent; every action is always available."""
        return [[1] * self.n_actions for _ in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Return the availability mask for ``agent_id`` (IndexError if out of range)."""
        return self.get_avail_actions()[agent_id]

    def get_total_actions(self):
        """Return the number of discrete actions per agent (``categories + 1``)."""
        return self.n_actions

    def reset(self):
        """Reset the wrapped env and return ``(observations, state)``, both rescaled."""
        self.env.reset()
        return self.get_obs(), self.get_state()

    def render(self):
        """No-op: rendering is not supported by this wrapper."""
        pass

    def close(self):
        """No-op: the wrapped env is not explicitly closed here."""
        pass

    def seed(self, seed=None):
        """Optionally re-seed the wrapped environment.

        Args:
            seed: New seed. With the default ``None`` (matching the original
                zero-argument call) nothing happens.
        """
        if seed is not None:
            self._seed = seed
            self.env.seed(seed)

    def save_replay(self):
        """Save a replay (not implemented; no-op)."""
        pass

