from modules.agents import REGISTRY as agent_REGISTRY
from components.action_selectors import REGISTRY as action_REGISTRY
from modules.opponent.context_builder import ContextBuilder
from modules.opponent.contrastive_encoder import ContextEncoder
import torch as th
import pdb


# This multi-agent controller shares parameters between agents
class BasicMAC:
    """Multi-agent controller whose agents all share one network's parameters.

    Besides the shared agent network (``self.agent``), the controller owns a
    ``ContextBuilder`` and a ``ContextEncoder`` (``self.ctx_encoder``); the
    encoder's parameters are exposed separately via :meth:`encoder_params`.
    """

    def __init__(self, scheme, groups, args):
        """Build the shared agent, the context modules and the action selector.

        Args:
            scheme: episode-batch scheme; ``scheme["obs"]["vshape"]`` (plus
                ``scheme["actions_onehot"]["vshape"][0]`` when
                ``obs_last_action`` is set) sizes the agent input.
            groups: not used by this class (presumably accepted for interface
                compatibility with other controllers — confirm).
            args: run configuration; reads ``n_agents``, ``agent``,
                ``agent_output_type``, ``action_selector`` and the optional
                ``save_probs`` / ``obs_*`` / ``ctx_*`` flags.
        """
        self.n_agents = args.n_agents
        self.scheme = scheme
        self.args = args
        input_shape = self._get_input_shape(scheme)
        self._build_agents(input_shape)
        self.agent_output_type = args.agent_output_type

        self.action_selector = action_REGISTRY[args.action_selector](args)
        self.save_probs = getattr(self.args, 'save_probs', False)

        self.hidden_states = None

    def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
        """Choose actions for the batch entries ``bs`` at episode step ``t_ep``.

        The network runs on the *whole* batch, so every entry's hidden state
        advances. Only the rows selected by ``bs`` are then passed to the
        action selector.

        Args:
            ep_batch: episode batch; ``z_t`` is used only when present in
                ``ep_batch.data.transition_data``.
            t_ep: timestep within the episode.
            t_env: global environment step, forwarded to the action selector.
            bs: batch-entry selector applied to the network outputs.
            test_mode: forwarded to :meth:`forward` and the action selector.

        Returns:
            Whatever ``self.action_selector.select_action`` returns.
        """
        z_t = ep_batch['z_t'][:, t_ep] if 'z_t' in ep_batch.data.transition_data else None
        agent_outs, avail_actions = self.forward(ep_batch, t=t_ep, test_mode=test_mode, z_t=z_t)
        chosen = self.action_selector.select_action(agent_outs[bs], avail_actions[bs], t_env, test_mode=test_mode)
        return chosen

    def forward(self, ep_batch, t, test_mode=False, z_t=None):
        """Run the shared agent network for every agent at timestep ``t``.

        Agents are folded into the batch dimension, so one network call covers
        all of them. ``self.hidden_states`` is replaced by the new hidden state.

        Args:
            ep_batch: episode batch (see :meth:`_build_inputs`); also provides
                ``avail_actions``.
            t: timestep index.
            test_mode: if true, the agent is switched to ``eval()`` first.
            z_t: optional context tensor, reshaped to ``(B * n_agents, -1)``.
                Presumably it holds one vector per agent; TODO confirm.

        Returns:
            ``(agent_outs, avail_actions)``. ``agent_outs`` has shape
            ``(B, n_agents, -1)`` and ``avail_actions`` is
            ``ep_batch["avail_actions"][:, t]``.
        """
        agent_inputs = self._build_inputs(ep_batch, t)
        avail_actions = ep_batch["avail_actions"][:, t]

        B, n_agents, _ = agent_inputs.size()
        flat_inputs = agent_inputs.reshape(B * n_agents, -1)
        # init_hidden() permits a None hidden state, so don't dereference it blindly.
        if self.hidden_states is None:
            flat_hidden = None
        else:
            flat_hidden = self.hidden_states.reshape(B * n_agents, -1)
        z_rep = z_t.reshape(B * n_agents, -1) if z_t is not None else None

        if test_mode:
            # NOTE(review): nothing here switches back to train(); presumably
            # the learner calls agent.train() before updates — confirm.
            self.agent.eval()

        flat_outs, flat_h = self.agent(flat_inputs, flat_hidden, z_t=z_rep)

        # reshape (rather than view) also works for non-contiguous outputs.
        agent_outs = flat_outs.reshape(B, n_agents, -1)
        self.hidden_states = None if flat_h is None else flat_h.reshape(B, n_agents, -1)
        return agent_outs, avail_actions

    def init_hidden(self, batch_size):
        """Reset ``self.hidden_states`` to the agent's initial hidden state.

        The agent's initial state is broadcast to
        ``(batch_size, n_agents, -1)`` and cloned, so every (batch, agent)
        slot owns its own memory. The result stays None if the agent returns
        None.
        """
        self.hidden_states = self.agent.init_hidden()
        if self.hidden_states is not None:
            self.hidden_states = self.hidden_states.unsqueeze(0).expand(batch_size, self.n_agents, -1).clone()

    def parameters(self):
        """Return the shared agent network's parameters (not the encoder's)."""
        return self.agent.parameters()

    def encoder_params(self):
        """Return the context encoder's parameters."""
        return self.ctx_encoder.parameters()

    def load_state(self, other_mac):
        """Copy the agent network weights from ``other_mac``.

        NOTE(review): the context encoder is not copied; confirm whether
        target controllers are meant to keep their own encoder.
        """
        self.agent.load_state_dict(other_mac.agent.state_dict())

    def cuda(self):
        """Move the agent network and the context encoder to the GPU."""
        self.agent.cuda()
        self.ctx_encoder.to('cuda')

    def save_models(self, path):
        """Save the agent weights to ``{path}/agent.th``.

        NOTE(review): the context encoder is not saved here; confirm whether
        it is checkpointed elsewhere.
        """
        th.save(self.agent.state_dict(), "{}/agent.th".format(path))

    def load_models(self, path):
        """Load agent weights from ``{path}/agent.th``, mapping storages to CPU."""
        self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage))

    def _build_agents(self, input_shape):
        """Create the shared agent network and the context builder/encoder.

        The encoder's input width is the sum of the enabled context features:
        the observation size (``ctx_mean_obs``), ``n_actions``
        (``ctx_last_actions``) and 1 (``ctx_reward``). All three flags
        default to True.
        """
        self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args)

        self.ctx_window = getattr(self.args, 'ctx_window', 8)
        self.z_dim = getattr(self.args, 'z_dim', 16)
        self.ctx_hid = getattr(self.args, 'ctx_hid', 64)

        self.ctx_builder = ContextBuilder(self.args, self.scheme)
        in_dim = 0
        if getattr(self.args, "ctx_mean_obs", True):
            in_dim += self.scheme['obs']['vshape']
        if getattr(self.args, "ctx_last_actions", True):
            in_dim += self.args.n_actions
        if getattr(self.args, "ctx_reward", True):
            in_dim += 1
        self.ctx_encoder = ContextEncoder(in_dim=in_dim, hid_dim=self.ctx_hid, z_dim=self.z_dim)
        # Only move the encoder to the GPU when one exists and the run has not
        # disabled CUDA, so CPU-only runs don't fail. A missing ``use_cuda``
        # is treated as True; cuda() also moves the encoder.
        if getattr(self.args, 'use_cuda', True) and th.cuda.is_available():
            self.ctx_encoder.to('cuda')

    def _build_inputs(self, ep_batch, t):
        """Assemble the per-agent network input at timestep ``t``.

        The observation is concatenated along the last axis with the previous
        one-hot action (zeros at ``t == 0``) when ``obs_last_action`` is set,
        and with a one-hot agent id when ``obs_agent_id`` is set.

        Returns:
            Tensor of shape ``(B, n_agents, input_shape)``.
        """
        obs = ep_batch['obs'][:, t]  # [B, n_agents, obs_dim]
        B, n_agents, obs_dim = obs.size()
        inputs = [obs]
        if getattr(self.args, 'obs_last_action', False):
            if t == 0:
                last_a = th.zeros(B, n_agents, ep_batch['actions_onehot'].shape[-1], device=obs.device)
            else:
                last_a = ep_batch['actions_onehot'][:, t-1]
            inputs.append(last_a.reshape(B, n_agents, -1))

        if getattr(self.args, 'obs_agent_id', False):
            aid = th.eye(n_agents, device=obs.device).unsqueeze(0).expand(B, -1, -1)
            inputs.append(aid)
        return th.cat(inputs, dim=-1)

    def _get_input_shape(self, scheme):
        """Return the agent input width that matches :meth:`_build_inputs`.

        The optional ``obs_*`` flags default to False, exactly as in
        :meth:`_build_inputs`.
        """
        input_shape = scheme["obs"]["vshape"]
        if getattr(self.args, 'obs_last_action', False):
            input_shape += scheme["actions_onehot"]["vshape"][0]
        if getattr(self.args, 'obs_agent_id', False):
            input_shape += self.n_agents

        return input_shape
