from common.imports import *

class AdditiveMixer(nn.Module):
    """Parameter-free mixer: the joint value is the sum of its inputs.

    The reduction runs over the last dimension of the input tensor and
    keeps that dimension with size 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, qvalues):
        """Sum ``qvalues`` over the last dimension.

        Args:
            qvalues: Tensor to reduce (presumably one entry per agent on the
                last dimension -- confirm against the caller).

        Returns:
            Tensor with the same leading dimensions and a trailing dimension
            of size 1.
        """
        joint_q = qvalues.sum(dim=-1, keepdim=True)
        return joint_q
    
## From OxWhirl PYMARL repo ##
class MonotonicMixer(nn.Module):
    """State-conditioned monotonic mixing network.

    Hypernetworks map the state to the weights and biases of a two-layer
    mixing network applied to the per-agent Q-values. The generated weights
    go through ``abs`` (so they are non-negative) and the hidden activation
    is ELU, which makes the mixed value non-decreasing in every input
    Q-value.

    Args:
        n_agents: Number of agents, i.e. the size of the last dimension of
            ``qvalues`` in :meth:`forward`.
        state_dim: Size of the state vector fed to the hypernetworks.
        mix_embed_dim: Width of the mixing network's hidden layer.
        hypernet_embed_dim: Hidden width of the two weight hypernetworks;
            only used when ``hypernet_layers == 2``.
        hypernet_layers: Depth of the weight hypernetworks (1 or 2).

    Raises:
        NotImplementedError: If ``hypernet_layers`` is greater than 2.
        ValueError: If ``hypernet_layers`` is any other unsupported value.
    """

    def __init__(self, n_agents, state_dim, mix_embed_dim=32, hypernet_embed_dim=32, hypernet_layers=2):
        super().__init__()

        self.n_agents = n_agents
        self.state_dim = state_dim

        self.embed_dim = mix_embed_dim

        # NOTE: the submodules below are created in a fixed order
        # (hyper_w_1, hyper_w_2, hyper_b_1, v_bias) so that seeded
        # initialisation and state_dict layout stay stable.
        if hypernet_layers == 1:
            self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents)
            self.hyper_w_2 = nn.Linear(self.state_dim, self.embed_dim)
        elif hypernet_layers == 2:
            self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed_dim),
                                           nn.ReLU(),
                                           nn.Linear(hypernet_embed_dim, self.embed_dim * self.n_agents))
            self.hyper_w_2 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed_dim),
                                           nn.ReLU(),
                                           nn.Linear(hypernet_embed_dim, self.embed_dim))
        elif hypernet_layers > 2:
            # Specific exception types (still Exception subclasses) let
            # callers tell "unsupported" apart from "invalid".
            raise NotImplementedError("Sorry >2 hypernet layers is not implemented!")
        else:
            raise ValueError("Error setting number of hypernet layers.")

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.v_bias = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                                    nn.ReLU(),
                                    nn.Linear(self.embed_dim, 1))

    def forward(self, qvalues, states):
        """Mix per-agent Q-values into a single value per batch entry.

        Args:
            qvalues: Tensor of shape ``(batch_size, n_agents)``.
            states: Tensor of shape ``(batch_size, state_dim)``; it is used
                as-is, no reshaping is performed here.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with the mixed value.
        """
        # (batch_size, 1, n_agents) so the mixing layers can be batched matmuls
        qvalues = qvalues.unsqueeze(1)

        # First layer: non-negative weights, unconstrained bias
        w1 = th.abs(self.hyper_w_1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b_1(states).view(-1, 1, self.embed_dim)

        hidden = F.elu(th.bmm(qvalues, w1) + b1)

        # Second layer: non-negative weights
        w_2 = th.abs(self.hyper_w_2(states))
        w_2 = w_2.view(-1, self.embed_dim, 1)

        # State-dependent bias
        v = self.v_bias(states).view(-1, 1, 1)

        # (batch_size, 1, 1) -> (batch_size, 1)
        return (th.bmm(hidden, w_2) + v).squeeze(1)
    
class QPLEXMixer(nn.Module):
    """QPLEX-style mixer with a value stream and an advantage stream.

    Depending on ``is_v`` in :meth:`forward`, the module either sums the
    per-agent values (V stream) or returns an attention-weighted sum of the
    detached per-agent advantages ``qvalues - max_agent_q`` (A stream).
    When ``weighted_head`` is set, the per-agent inputs are first transformed
    as ``w * x + v`` with a positive state-dependent scale ``w`` and a
    state-dependent bias ``v``.

    Args:
        n_agents: Number of agents.
        state_dim: Size of the state vector.
        act_dim: Forwarded to ``AttentionQPLEX`` as its ``action_dim``.
            NOTE(review): ``AttentionQPLEX`` sizes its action extractor input
            as ``state_dim + action_dim`` but concatenates the state with the
            *flattened* actions, so this presumably has to be the flattened
            (joint) action width -- confirm against the caller.
        transf_embed_dim: Hidden width of the scale and bias MLPs.
        n_heads: Number of attention heads (forwarded to ``AttentionQPLEX``).
        mix_embed_dim: Hidden width of the attention extractors.
        mix_embed_layers: Depth of the attention extractors.
        is_minus_one: If True, the advantage weights are shifted by ``-1``.
        weighted_head: If True, apply the state-dependent affine transform
            to the per-agent inputs before mixing.
    """

    def __init__(self, n_agents, state_dim, act_dim, transf_embed_dim=32, n_heads=4, mix_embed_dim=32,
                 mix_embed_layers=2, is_minus_one=True, weighted_head=True):
        super().__init__()

        self.n_agents = n_agents
        self.state_dim = state_dim
        self.act_dim = act_dim

        # state -> per-agent scale (made positive with abs in forward)
        self.mlp_w = nn.Sequential(nn.Linear(self.state_dim, transf_embed_dim),
                                   nn.ReLU(),
                                   nn.Linear(transf_embed_dim, self.n_agents))
        # state -> per-agent bias
        self.v_bias = nn.Sequential(nn.Linear(self.state_dim, transf_embed_dim),
                                    nn.ReLU(),
                                    nn.Linear(transf_embed_dim, self.n_agents))

        # state/action -> per-agent weights for the advantage stream
        self.si_weight = AttentionQPLEX(n_agents, self.state_dim, self.act_dim,
                                        n_heads, mix_embed_dim, mix_embed_layers)

        self.is_minus_one = is_minus_one
        self.weighted_head = weighted_head

    def compute_v_mix(self, qvalues):
        """Return the sum of ``qvalues`` over the last dimension (kept as size 1)."""
        return th.sum(qvalues, dim=-1, keepdim=True)

    def compute_a_mix(self, qvalues, states, actions, max_agent_q):
        """Return the weighted sum of per-agent advantages.

        The advantages ``qvalues - max_agent_q`` are detached, so gradients
        reach the result only through the attention weights.

        Args:
            qvalues: Per-agent values of the picked actions,
                ``(batch_size, n_agents)``.
            states: State tensor passed to the attention weights.
            actions: Action encoding passed to the attention weights.
            max_agent_q: Per-agent maximum values, same shape as ``qvalues``.

        Returns:
            Tensor of shape ``(batch_size, 1)``.
        """
        a_vals = (qvalues - max_agent_q).detach()

        # Weights for the advantage stream
        a_w = self.si_weight(states, actions)

        # Compute mixed advantage
        if self.is_minus_one:
            return th.sum(a_vals * (a_w - 1.), dim=1, keepdim=True)
        return th.sum(a_vals * a_w, dim=1, keepdim=True)

    def calc(self, qvalues, states, actions=None, max_agent_q=None, is_v=False):
        """Dispatch to the V stream (``is_v=True``) or the A stream."""
        if is_v:
            return self.compute_v_mix(qvalues)
        return self.compute_a_mix(qvalues, states, actions, max_agent_q)

    def forward(self, qvalues, states, actions=None, max_agent_q=None, is_v=False):
        """Mix per-agent values into the joint V or the joint A.

        Args:
            qvalues: Per-agent values for the picked actions,
                ``(batch_size, n_agents)``.
            states: State tensor, ``(batch_size, state_dim)``.
            actions: One-hot encoding of the picked actions; required when
                ``is_v`` is False.
            max_agent_q: Each agent's maximum Q-value according to the
                current network; required when ``is_v`` is False. It is
                reshaped to ``(-1, n_agents)`` here, so a flattened tensor
                is accepted.
            is_v: True to combine values (V stream), False to combine
                advantages (A stream).

        Returns:
            Tensor of shape ``(batch_size, 1)``.

        Raises:
            ValueError: If ``is_v`` is False and ``actions`` or
                ``max_agent_q`` is missing.
        """
        if not is_v:
            if actions is None or max_agent_q is None:
                raise ValueError("actions and max_agent_q are required when is_v is False")
            max_agent_q = max_agent_q.view(-1, self.n_agents)

        if self.weighted_head:
            # Only evaluate the state MLPs when their outputs are used.
            # Both are (batch_size, n_agents); the epsilon keeps the scale
            # strictly positive.
            w_2 = th.abs(self.mlp_w(states)) + 1e-10
            v = self.v_bias(states)

            qvalues = w_2 * qvalues + v
            if not is_v:
                max_agent_q = w_2 * max_agent_q + v

        return self.calc(qvalues,
                         states,
                         actions=actions,
                         max_agent_q=max_agent_q,
                         is_v=is_v)
    
class AttentionQPLEX(nn.Module):
    """Multi-head per-agent weights conditioned on state and actions.

    Each head owns three extractors: a *key* (state -> 1), an *agents*
    (state -> n_agents) and an *action* (state + actions -> n_agents)
    network. A head's weight is ``(|key| + 1e-10) * sigmoid(agents) *
    sigmoid(action)``, which is strictly positive; the module returns the
    sum of the head weights.

    Args:
        n_agents: Number of agents (width of the returned weights).
        state_dim: Size of the state vector.
        action_dim: Width of the action part of the action extractor input.
            NOTE(review): :meth:`forward` flattens ``actions`` to
            ``(batch_size, -1)`` before concatenating with the state, so this
            presumably has to equal the flattened action width -- confirm
            against the caller.
        n_heads: Number of attention heads.
        mix_embed_dim: Hidden width of the extractors (unused for 1 layer).
        mix_embed_layers: Depth of every extractor; 1, 2 or 3.

    Raises:
        ValueError: If ``mix_embed_layers`` is not 1, 2 or 3.
    """

    def __init__(self, n_agents, state_dim, action_dim, n_heads, mix_embed_dim, mix_embed_layers):
        super().__init__()

        # Validate once, up front, instead of inside the per-head loop.
        if mix_embed_layers not in (1, 2, 3):
            raise ValueError("Error setting number of adv hypernet layers.")
        depth = int(mix_embed_layers)

        self.n_agents = n_agents
        self.n_heads = n_heads
        state_action_dim = state_dim + action_dim

        self.key_extractors = nn.ModuleList()
        self.agents_extractors = nn.ModuleList()
        self.action_extractors = nn.ModuleList()

        # Per head the networks are created in key -> agents -> action order,
        # which keeps seeded initialisation reproducible.
        for _ in range(self.n_heads):  # multi-head attention
            self.key_extractors.append(
                self._build_extractor(state_dim, 1, mix_embed_dim, depth))  # key
            self.agents_extractors.append(
                self._build_extractor(state_dim, n_agents, mix_embed_dim, depth))  # agent
            self.action_extractors.append(
                self._build_extractor(state_action_dim, n_agents, mix_embed_dim, depth))  # action

    @staticmethod
    def _build_extractor(in_dim, out_dim, hidden_dim, depth):
        """Build a ``depth``-layer ReLU MLP mapping ``in_dim`` to ``out_dim``."""
        if depth == 1:
            # A bare Linear (not a Sequential) keeps the single-layer
            # state_dict keys free of an extra index level.
            return nn.Linear(in_dim, out_dim)

        layers = []
        width = in_dim
        for _ in range(depth - 1):
            layers.append(nn.Linear(width, hidden_dim))
            layers.append(nn.ReLU())
            width = hidden_dim
        layers.append(nn.Linear(width, out_dim))
        return nn.Sequential(*layers)

    def forward(self, states, actions):
        """Compute the summed multi-head weights.

        Args:
            states: Tensor of shape ``(batch_size, state_dim)``.
            actions: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened, e.g. ``(batch_size, n_agents,
                act_dim)`` becomes ``(batch_size, n_agents * act_dim)``.

        Returns:
            Tensor of shape ``(batch_size, n_agents)``.
        """
        actions = actions.reshape(actions.shape[0], -1)
        state_actions = th.cat([states, actions], dim=1)

        head_weights = []
        for key_net, agent_net, action_net in zip(self.key_extractors,
                                                  self.agents_extractors,
                                                  self.action_extractors):
            # (batch_size, 1); broadcasts over the agent dimension below.
            w_key = th.abs(key_net(states)) + 1e-10
            w_agents = th.sigmoid(agent_net(states))
            w_actions = th.sigmoid(action_net(state_actions))
            head_weights.append(w_key * w_agents * w_actions)

        # (batch_size, n_heads, n_agents) -> sum over heads
        return th.stack(head_weights, dim=1).sum(dim=1)
    