from modules.agents import REGISTRY as agent_REGISTRY
from components.action_selectors import REGISTRY as action_REGISTRY
import torch as th
import torch.nn as nn
from modules.agents.rnn_agent import RNNAgent

class SeqMAC:
    """Sequential multi-agent controller with one independent network per agent.

    Agents are evaluated in index order at each timestep. Agent ``i`` receives
    its own observation concatenated with the actions of agents ``0..i-1``
    (freshly selected ones when acting, the recorded ones when ``actions`` is
    passed to ``forward``). Last-action and agent-id input features are not
    used by this controller.
    """

    def __init__(self, scheme, groups, args):
        """Build the action selector and one ``RNNAgent`` per agent.

        Args:
            scheme: Mapping read as ``scheme["obs"]["vshape"]`` to size the
                agent inputs.
            groups: Unused here; kept for interface compatibility.
            args: Config object; ``n_agents``, ``agent_output_type`` and
                ``action_selector`` are read here.
        """
        self.n_agents = args.n_agents
        self.args = args

        self.agent_output_type = args.agent_output_type

        self.action_selector = action_REGISTRY[args.action_selector](args)

        # Populated by init_hidden(); forward() indexes it as [:, agent_id].
        self.hidden_states = None

        self.agents = th.nn.ModuleList([
            RNNAgent(self._get_input_shape_agent(agent_id, scheme), args)
            for agent_id in range(self.n_agents)
        ])

    def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
        """Select actions for the batch elements indexed by ``bs`` at step ``t_ep``.

        Thin wrapper around ``forward`` with ``actions=None``.
        """
        return self.forward(ep_batch, t_ep, t_env, None, bs=bs, test_mode=test_mode)

    def _forward_agent(self, agent_input, hidden_state, agent_id):
        """Run agent ``agent_id``'s network; return ``(output, new_hidden)``."""
        return self.agents[agent_id](agent_input, hidden_state)

    def _select_action_agent(self, agent_id, agent_id_output, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
        """Pick an action for one agent from its network output.

        Only the rows indexed by ``bs`` of the output and of the agent's
        availability mask are handed to the action selector.
        """
        # unsqueeze(1) adds a singleton agent axis to match agent_id_output.
        avail_actions = ep_batch["avail_actions"][:, t_ep, agent_id].unsqueeze(1)
        return self.action_selector.select_action(
            agent_id_output[bs], avail_actions[bs], t_env, test_mode=test_mode
        )

    def forward(self, ep_batch, t, t_env, actions=None, bs=slice(None), test_mode=False):
        """Evaluate all agents sequentially at timestep ``t``.

        Args:
            ep_batch: Episode batch exposing ``batch_size`` and the
                ``"obs"`` / ``"avail_actions"`` entries.
            t: Timestep index into the batch.
            t_env: Environment step counter, forwarded to the action selector.
            actions: ``None`` to select actions; otherwise the recorded
                actions, indexed as ``actions[:, t, :i]`` to condition agent
                ``i`` (presumably ``(batch, time, n_agents, 1)`` - confirm
                against the learner).
            bs: Index (slice or list of ints) of the batch rows to select
                actions for. Ignored when ``actions`` is given.
            test_mode: Forwarded to the action selector.

        Returns:
            With ``actions`` given: agent outputs stacked along dim 1, i.e.
            ``(batch_size, n_agents, n_out)``. With ``actions=None``: the
            selected actions for rows ``bs``, concatenated over agents on the
            last dim.

        Side effects:
            Replaces ``self.hidden_states`` with the new hidden states of all
            agents (for the whole batch, not only rows ``bs``).
        """
        selecting = actions is None
        output_list = []
        hidden_list = []
        action_list = []
        pre_action = None
        avail_actions = ep_batch["avail_actions"][:, t]

        for i in range(self.n_agents):
            if i > 0:
                if selecting:
                    pre_action = th.cat(action_list, dim=-1)
                else:
                    pre_action = actions[:, t, :i]

            agent_input = self._build_inputs_agent(i, pre_action, ep_batch, t)
            agent_output, agent_hidden = self._forward_agent(agent_input, self.hidden_states[:, i], i)
            agent_output = agent_output.unsqueeze(1)
            hidden_list.append(agent_hidden.unsqueeze(1))

            # Softmax the agent outputs if they're policy logits
            if self.agent_output_type == "pi_logits":
                if getattr(self.args, "mask_before_softmax", True):
                    # Make the logits for unavailable actions very negative to
                    # minimise their effect on the softmax
                    reshaped_avail_actions = avail_actions[:, i].reshape(ep_batch.batch_size, 1, -1)
                    agent_output[reshaped_avail_actions == 0] = -1e10
                agent_output = th.nn.functional.softmax(agent_output, dim=-1)
            output_list.append(agent_output)

            if selecting:
                chosen = self._select_action_agent(i, agent_output, ep_batch, t, t_env, bs, test_mode)
                # The selector only sees rows `bs`, but later agents are run on
                # the whole batch. Scatter the chosen actions into a full-batch
                # placeholder (zeros for unselected rows) so the conditioning
                # input always has batch_size rows.
                full_actions = chosen.new_zeros((ep_batch.batch_size,) + tuple(chosen.shape[1:]))
                full_actions[bs] = chosen
                action_list.append(full_actions)

        self.hidden_states = th.cat(hidden_list, dim=1)

        if not selecting:
            # Each entry has a singleton agent axis at dim 1.
            return th.cat(output_list, dim=1)
        # Hand back only the rows the caller asked for.
        return th.cat(action_list, dim=-1)[bs]

    def init_hidden(self, batch_size):
        """Reset ``self.hidden_states`` from each agent's ``init_hidden()``.

        The per-agent initial states are concatenated along dim 0 and expanded
        (without copying) over a leading batch axis of size ``batch_size``.
        Assumes each ``init_hidden()`` returns a 2-D ``(1, hidden_dim)`` tensor
        - TODO confirm against RNNAgent.
        """
        hidden = th.cat([a.init_hidden() for a in self.agents])
        self.hidden_states = hidden.unsqueeze(0).expand(batch_size, -1, -1)  # bav

    def parameters(self):
        """Return one ``{'params': ...}`` group per agent.

        Note this is a list of dicts (the per-parameter-group form accepted by
        torch optimizers), not a flat iterator of tensors.
        """
        return [{'params': a.parameters()} for a in self.agents]

    def load_state(self, other_mac):
        """Copy every agent's weights from the matching agent of ``other_mac``."""
        for i in range(self.n_agents):
            self.agents[i].load_state_dict(other_mac.agents[i].state_dict())

    def cuda(self, device="cuda:0"):
        """Move every agent network to ``device``."""
        for a in self.agents:
            a.cuda(device=device)

    def save_models(self, path):
        """Save each agent's state dict to ``<path>/agent<i>.th``."""
        for i in range(self.n_agents):
            th.save(self.agents[i].state_dict(), "{}/agent{}.th".format(path, i))

    def load_models(self, path):
        """Load each agent's state dict from ``<path>/agent<i>.th``.

        ``map_location`` keeps tensors on the storage they are deserialised to
        instead of the device they were saved from.
        """
        for i in range(self.n_agents):
            self.agents[i].load_state_dict(th.load("{}/agent{}.th".format(path, i), map_location=lambda storage, loc: storage))

    def _build_inputs_agent(self, agent_id, pre_action, batch, t):
        """Assemble the flat network input for one agent at timestep ``t``.

        The input is the agent's observation followed, for ``agent_id > 0``,
        by ``pre_action`` (the actions of the preceding agents). Every piece
        is flattened to ``(batch.batch_size, -1)`` before concatenation, so
        ``pre_action`` must cover the whole batch.
        """
        # Assumes flat observations.
        bs = batch.batch_size
        inputs = [batch["obs"][:, t, agent_id]]

        # Ordered execution: agent i is conditioned on agents 0..i-1.
        if agent_id > 0:
            inputs.append(pre_action)  # pre_action is like bs*(id)

        return th.cat([x.reshape(bs, -1) for x in inputs], dim=1)

    def _get_input_shape_agent(self, agent_id, scheme):
        """Return the input width of agent ``agent_id``'s network.

        This is the single-agent observation size plus one scalar per
        preceding agent's action.
        """
        # vshape for obs of one single agent, plus the previous agents' actions.
        return scheme["obs"]["vshape"] + agent_id