from envs.multiagentenv import MultiAgentEnv
from utils.dict2namedtuple import convert
import numpy as np


class TwoState(MultiAgentEnv):
    """Two-agent, two-action environment with two states.

    The environment starts every episode in state ``x == 2``. While in that
    state, the joint action ``(1, 1)`` yields a reward of 1 and the joint
    action ``(0, 0)`` moves the environment to state ``x == 1``. ``step``
    never changes ``x`` once it is 1, so that state yields zero reward until
    ``reset`` is called. Episodes end after ``episode_limit`` steps.
    """

    def __init__(self, batch_size=None, **kwargs):
        """Set up the fixed environment parameters.

        Args:
            batch_size: Accepted for interface compatibility; not used here.
            **kwargs: Must contain ``"env_args"``. When ``env_args`` is a
                dict, ``kwargs["args"]`` is also required and is stored as
                ``self.args``.
        """
        # Unpack arguments from sacred
        args = kwargs["env_args"]
        if isinstance(args, dict):
            args = convert(args)
            # NOTE(review): self.args is only assigned on this branch, so it
            # does not exist when env_args is not a dict -- confirm intended.
            self.args = kwargs["args"]

        # Define the agents and actions
        self.n_agents = 2
        self.n_actions = 2
        self.episode_limit = 100
        self.t = 0

        # Current state; takes the values 1 or 2 (see get_state).
        self.x = 2

    def reset(self):
        """Restart the episode and return ``(observations, state)``."""
        self.x = 2
        self.t = 0
        return self.get_obs(), self.get_state()

    def step(self, actions):
        """Apply the joint action and advance time by one step.

        Args:
            actions: Indexable collection holding agent 0's action at index 0
                and agent 1's action at index 1.

        Returns:
            Tuple ``(reward, terminated, info)``. ``info["episode_limit"]``
            mirrors ``terminated``, which is true exactly when the step
            counter reaches ``episode_limit``.
        """
        reward = 0
        if self.x == 2:
            first, second = actions[0], actions[1]

            if first == 1 and second == 1:
                reward = 1

            # Both agents must pick action 0 to trigger the transition; the
            # previous code tested agent 0 twice and ignored agent 1.
            if first == 0 and second == 0:
                self.x = 1

        self.t += 1

        terminated = self.t == self.episode_limit
        info = {"episode_limit": terminated}

        return reward, terminated, info

    def get_obs(self):
        """Return one observation per agent; each is a fresh state vector."""
        return [self.get_state() for _ in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Per-agent observation lookup; not implemented for this env."""
        raise NotImplementedError

    def get_obs_size(self):
        """Return the observation length, which equals the state length."""
        return self.get_state_size()

    def get_state(self):
        """Return the state as a length-2 one-hot array (index ``x - 1``)."""
        s = np.zeros(2)
        s[self.x - 1] = 1
        return s

    def get_state_size(self):
        """Return the length of the state vector."""
        return 2

    def get_avail_actions(self):
        """Return the available-action mask of every agent, in agent order."""
        return [
            self.get_avail_agent_actions(agent_id)
            for agent_id in range(self.n_agents)
        ]

    def get_avail_agent_actions(self, agent_id):
        """Return the action mask for ``agent_id``; every action is allowed."""
        return np.ones(self.n_actions)

    def get_total_actions(self):
        """Return the total number of actions an agent could ever take."""
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        return self.n_actions

    def get_stats(self):
        """Environment statistics; not implemented for this env."""
        raise NotImplementedError