import torch as th
import torch.nn as nn
import numpy as np
import torch.nn.functional as F



class MtQattnMixer(nn.Module):
    """Attention-based mixing network, optionally conditioned on a task encoding.

    The global state is split into per-ally and per-enemy feature vectors by a
    decomposer object. Every entity is embedded with a linear encoder, the
    embeddings are pooled with dot-product attention into one state summary,
    and hypernetworks turn that summary (plus the task encoding when
    ``main_args.use_encoding`` is set) into

    * ``k``: non-negative per-agent mixing weights that sum to one over agents
    * ``b``: a state-dependent scalar bias

    so that ``q_tot = sum_i k_i * q_i + b``.

    Args:
        surrogate_decomposer: default decomposer. Must expose ``state_nf_al``,
            ``state_nf_en``, ``timestep_number_state_dim``,
            ``state_last_action``, ``state_timestep_number``,
            ``n_actions_no_attack`` and ``decompose_state`` (plus
            ``decompose_action_info`` when ``state_last_action`` is set).
        main_args: configuration namespace. Reads ``mixing_embed_dim``,
            ``attn_embed_dim``, ``policy_entity_embed_dim``, ``use_encoding``,
            ``encoding_dim`` (only if ``use_encoding``), and optionally
            ``hypernet_layers`` (default 1) and ``hypernet_embed``
            (only for 2 hypernet layers).

    Raises:
        NotImplementedError: if ``hypernet_layers`` is greater than 2.
        ValueError: if ``hypernet_layers`` is neither 1 nor 2.
    """

    def __init__(self, surrogate_decomposer, main_args):
        super().__init__()

        self.main_args = main_args
        self.embed_dim = main_args.mixing_embed_dim
        self.attn_embed_dim = main_args.attn_embed_dim
        self.entity_embed_dim = main_args.policy_entity_embed_dim

        self.decomposer = surrogate_decomposer

        # Per-entity feature sizes and state layout flags.
        state_nf_al = surrogate_decomposer.state_nf_al
        state_nf_en = surrogate_decomposer.state_nf_en
        timestep_state_dim = surrogate_decomposer.timestep_number_state_dim
        self.state_last_action = surrogate_decomposer.state_last_action
        self.state_timestep_number = surrogate_decomposer.state_timestep_number
        self.n_actions_no_attack = surrogate_decomposer.n_actions_no_attack

        # When the state carries last actions, the compact action features are
        # concatenated to each ally's features before encoding, so they must be
        # n_actions_no_attack + 1 wide.
        ally_input_dim = state_nf_al
        if self.state_last_action:
            ally_input_dim += self.n_actions_no_attack + 1
        self.ally_encoder = nn.Linear(ally_input_dim, self.entity_embed_dim)
        self.enemy_encoder = nn.Linear(state_nf_en, self.entity_embed_dim)

        self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)

        # Global mixing input: pooled embedding [+ timestep] [+ task encoding].
        # NOTE: the timestep width is reserved here, but w_and_b() raises when
        # state_timestep_number is set because that path is not implemented.
        mixing_input_dim = self.entity_embed_dim
        if self.state_timestep_number:
            mixing_input_dim += timestep_state_dim
        if main_args.use_encoding:
            mixing_input_dim += main_args.encoding_dim
        # Per-agent mixing input additionally carries the agent's own embedding.
        entity_mixing_input_dim = mixing_input_dim + self.entity_embed_dim

        n_layers = getattr(main_args, "hypernet_layers", 1)
        if n_layers > 2:
            raise NotImplementedError("Sorry >2 hypernet layers is not implemented!")
        if n_layers not in (1, 2):
            raise ValueError("Error setting number of hypernet layers.")
        self.hyper_w_1 = self._build_hypernet(entity_mixing_input_dim, n_layers)
        self.hyper_w_final = self._build_hypernet(mixing_input_dim, n_layers)

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(mixing_input_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(
            nn.Linear(mixing_input_dim, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, 1),
        )

    def _build_hypernet(self, input_dim, n_layers):
        """Return a 1- or 2-layer hypernetwork mapping input_dim -> embed_dim."""
        if n_layers == 1:
            return nn.Linear(input_dim, self.embed_dim)
        hypernet_embed = self.main_args.hypernet_embed
        return nn.Sequential(
            nn.Linear(input_dim, hypernet_embed),
            nn.ReLU(),
            nn.Linear(hypernet_embed, self.embed_dim),
        )

    def _embed_entities(self, states, task_decomposer):
        """Decompose `states` and encode allies and enemies separately.

        Returns:
            (ally_embed, enemy_embed) with the entity axis first:
            [n_agents, bs, seq_len, entity_embed_dim] and
            [n_enemies, bs, seq_len, entity_embed_dim].
        """
        ally_states, enemy_states, last_action_states, _ = task_decomposer.decompose_state(states)
        ally_states = th.stack(ally_states, dim=0)
        enemy_states = th.stack(enemy_states, dim=0)

        if self.state_last_action:
            last_action_states = th.stack(last_action_states, dim=0)
            _, _, compact_action_states = task_decomposer.decompose_action_info(last_action_states)
            ally_states = th.cat([ally_states, compact_action_states], dim=-1)

        return self.ally_encoder(ally_states), self.enemy_encoder(enemy_states)

    def _attend(self, entity_embed, bs, seq_len):
        """Pool entity embeddings with dot-product attention.

        Args:
            entity_embed: [n_entities, bs, seq_len, entity_embed_dim].

        Returns:
            Tensor of shape [bs, seq_len, entity_embed_dim].
        """
        n_entities = entity_embed.size(0)
        flat = bs * seq_len

        proj_query = self.query(entity_embed).permute(1, 2, 0, 3).reshape(
            flat, n_entities, self.attn_embed_dim)
        proj_key = self.key(entity_embed).permute(1, 2, 3, 0).reshape(
            flat, self.attn_embed_dim, n_entities)
        energy = th.bmm(proj_query / (self.attn_embed_dim ** 0.5), proj_key)
        # Softmax runs over dim=1, the same axis the bmm below contracts, so
        # every output column is a convex combination of entity embeddings.
        attn_score = F.softmax(energy, dim=1)
        # The raw entity embeddings act as values (no value projection).
        proj_value = entity_embed.permute(1, 2, 3, 0).reshape(
            flat, self.entity_embed_dim, n_entities)
        pooled = th.bmm(proj_value, attn_score).mean(dim=-1)
        return pooled.reshape(bs, seq_len, self.entity_embed_dim)

    def forward(self, agent_qs, states, task_encoding, task_decomposer=None):
        """Mix per-agent Q-values into a total Q-value.

        Args:
            agent_qs: [bs, seq_len, n_agents, 1] (the trailing singleton axis
                is required for the broadcast against the mixing weights).
            states: [bs, seq_len, state_dim].
            task_encoding: [bs, encoding_dim] or [bs, seq_len, encoding_dim];
                ignored (may be None) when ``main_args.use_encoding`` is false.
            task_decomposer: decomposer to use instead of the default one.

        Returns:
            q_tot of shape [bs, seq_len, 1].
        """
        k, b = self.w_and_b(states, task_encoding, task_decomposer)
        return (agent_qs * k).sum(dim=2) + b

    def w_and_b(self, states, task_encoding, task_decomposer=None):
        """Compute the mixing weights and bias for `states`.

        Args:
            states: [bs, seq_len, state_dim].
            task_encoding: [bs, encoding_dim] or [bs, seq_len, encoding_dim];
                ignored (may be None) when ``main_args.use_encoding`` is false.
            task_decomposer: decomposer to use instead of the default one.

        Returns:
            (k, b): k has shape [bs, seq_len, n_agents, 1], is non-negative and
            sums to one over the agent axis; b has shape [bs, seq_len, 1].

        Raises:
            NotImplementedError: if the state carries a timestep number.
        """
        if task_decomposer is None:
            task_decomposer = self.decomposer
        if self.state_timestep_number:
            # Fail before doing any work: timestep features are not supported.
            raise NotImplementedError("Not Implemented")

        bs, seq_len = states.size(0), states.size(1)

        ally_embed, enemy_embed = self._embed_entities(states, task_decomposer)
        # The number of agents is fixed by the decomposed state itself.
        n_agents = ally_embed.size(0)
        attn_out = self._attend(th.cat([ally_embed, enemy_embed], dim=0), bs, seq_len)

        # Per-agent input: pooled summary + the agent's own embedding.
        entity_parts = [
            attn_out.unsqueeze(2).expand(-1, -1, n_agents, -1),
            ally_embed.permute(1, 2, 0, 3),
        ]
        mixing_parts = [attn_out]
        if self.main_args.use_encoding:
            if task_encoding.dim() == 2:
                # [bs, encoding_dim] -> broadcast over the time axis.
                task_encoding = task_encoding.unsqueeze(1).expand(-1, seq_len, -1)
            entity_parts.append(task_encoding.unsqueeze(2).expand(-1, -1, n_agents, -1))
            mixing_parts.append(task_encoding)
        entity_mixing_input = th.cat(entity_parts, dim=-1)
        mixing_input = th.cat(mixing_parts, dim=-1)

        # First layer: abs() keeps the weights non-negative.
        w1 = th.abs(self.hyper_w_1(entity_mixing_input)).view(-1, n_agents, self.embed_dim)
        b1 = self.hyper_b_1(mixing_input).view(-1, 1, self.embed_dim)

        # Second layer
        w_final = th.abs(self.hyper_w_final(mixing_input)).view(-1, self.embed_dim, 1)
        v = self.V(mixing_input).view(-1, 1, 1)

        # Both layers are collapsed into one linear map per agent and
        # normalised over agents.
        # NOTE: if every weight of a timestep is exactly zero the denominator
        # is zero and k becomes NaN.
        k = th.bmm(w1, w_final).view(bs, -1, n_agents)
        k = (k / k.sum(dim=2, keepdim=True)).reshape(bs, seq_len, n_agents, 1)

        b = (th.bmm(b1, w_final) + v).reshape(bs, seq_len, 1)
        return k, b



class QattnMixer(nn.Module):
    """Attention-based mixing network over decomposed entity states.

    The global state is split into per-ally and per-enemy feature vectors by a
    decomposer object. Every entity is embedded with a linear encoder, the
    embeddings are pooled with dot-product attention into one state summary,
    and hypernetworks turn that summary into

    * ``k``: non-negative per-agent mixing weights that sum to one over agents
    * ``b``: a state-dependent scalar bias

    so that ``q_tot = sum_i k_i * q_i + b``.

    Args:
        surrogate_decomposer: default decomposer. Must expose ``state_nf_al``,
            ``state_nf_en``, ``timestep_number_state_dim``,
            ``state_last_action``, ``state_timestep_number``,
            ``n_actions_no_attack`` and ``decompose_state`` (plus
            ``decompose_action_info`` when ``state_last_action`` is set).
        main_args: configuration namespace. Reads ``mixing_embed_dim``,
            ``attn_embed_dim``, ``entity_embed_dim``, and optionally
            ``hypernet_layers`` (default 1) and ``hypernet_embed``
            (only for 2 hypernet layers).

    Raises:
        NotImplementedError: if ``hypernet_layers`` is greater than 2.
        ValueError: if ``hypernet_layers`` is neither 1 nor 2.
    """

    def __init__(self, surrogate_decomposer, main_args):
        super().__init__()

        self.main_args = main_args
        self.embed_dim = main_args.mixing_embed_dim
        self.attn_embed_dim = main_args.attn_embed_dim
        self.entity_embed_dim = main_args.entity_embed_dim

        self.decomposer = surrogate_decomposer

        # Per-entity feature sizes and state layout flags.
        state_nf_al = surrogate_decomposer.state_nf_al
        state_nf_en = surrogate_decomposer.state_nf_en
        timestep_state_dim = surrogate_decomposer.timestep_number_state_dim
        self.state_last_action = surrogate_decomposer.state_last_action
        self.state_timestep_number = surrogate_decomposer.state_timestep_number
        self.n_actions_no_attack = surrogate_decomposer.n_actions_no_attack

        # When the state carries last actions, the compact action features are
        # concatenated to each ally's features before encoding, so they must be
        # n_actions_no_attack + 1 wide.
        ally_input_dim = state_nf_al
        if self.state_last_action:
            ally_input_dim += self.n_actions_no_attack + 1
        self.ally_encoder = nn.Linear(ally_input_dim, self.entity_embed_dim)
        self.enemy_encoder = nn.Linear(state_nf_en, self.entity_embed_dim)

        self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)

        # Global mixing input: pooled embedding [+ timestep].
        # NOTE: the timestep width is reserved here, but w_and_b() raises when
        # state_timestep_number is set because that path is not implemented.
        mixing_input_dim = self.entity_embed_dim
        if self.state_timestep_number:
            mixing_input_dim += timestep_state_dim
        # Per-agent mixing input additionally carries the agent's own embedding.
        entity_mixing_input_dim = mixing_input_dim + self.entity_embed_dim

        n_layers = getattr(main_args, "hypernet_layers", 1)
        if n_layers > 2:
            raise NotImplementedError("Sorry >2 hypernet layers is not implemented!")
        if n_layers not in (1, 2):
            raise ValueError("Error setting number of hypernet layers.")
        self.hyper_w_1 = self._build_hypernet(entity_mixing_input_dim, n_layers)
        self.hyper_w_final = self._build_hypernet(mixing_input_dim, n_layers)

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(mixing_input_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(
            nn.Linear(mixing_input_dim, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, 1),
        )

    def _build_hypernet(self, input_dim, n_layers):
        """Return a 1- or 2-layer hypernetwork mapping input_dim -> embed_dim."""
        if n_layers == 1:
            return nn.Linear(input_dim, self.embed_dim)
        hypernet_embed = self.main_args.hypernet_embed
        return nn.Sequential(
            nn.Linear(input_dim, hypernet_embed),
            nn.ReLU(),
            nn.Linear(hypernet_embed, self.embed_dim),
        )

    def _embed_entities(self, states, task_decomposer):
        """Decompose `states` and encode allies and enemies separately.

        Returns:
            (ally_embed, enemy_embed) with the entity axis first:
            [n_agents, bs, seq_len, entity_embed_dim] and
            [n_enemies, bs, seq_len, entity_embed_dim].
        """
        ally_states, enemy_states, last_action_states, _ = task_decomposer.decompose_state(states)
        ally_states = th.stack(ally_states, dim=0)
        enemy_states = th.stack(enemy_states, dim=0)

        if self.state_last_action:
            last_action_states = th.stack(last_action_states, dim=0)
            _, _, compact_action_states = task_decomposer.decompose_action_info(last_action_states)
            ally_states = th.cat([ally_states, compact_action_states], dim=-1)

        return self.ally_encoder(ally_states), self.enemy_encoder(enemy_states)

    def _attend(self, entity_embed, bs, seq_len):
        """Pool entity embeddings with dot-product attention.

        Args:
            entity_embed: [n_entities, bs, seq_len, entity_embed_dim].

        Returns:
            Tensor of shape [bs, seq_len, entity_embed_dim].
        """
        n_entities = entity_embed.size(0)
        flat = bs * seq_len

        proj_query = self.query(entity_embed).permute(1, 2, 0, 3).reshape(
            flat, n_entities, self.attn_embed_dim)
        proj_key = self.key(entity_embed).permute(1, 2, 3, 0).reshape(
            flat, self.attn_embed_dim, n_entities)
        energy = th.bmm(proj_query / (self.attn_embed_dim ** 0.5), proj_key)
        # Softmax runs over dim=1, the same axis the bmm below contracts, so
        # every output column is a convex combination of entity embeddings.
        attn_score = F.softmax(energy, dim=1)
        # The raw entity embeddings act as values (no value projection).
        proj_value = entity_embed.permute(1, 2, 3, 0).reshape(
            flat, self.entity_embed_dim, n_entities)
        pooled = th.bmm(proj_value, attn_score).mean(dim=-1)
        return pooled.reshape(bs, seq_len, self.entity_embed_dim)

    def forward(self, agent_qs, states, task_decomposer=None):
        """Mix per-agent Q-values into a total Q-value.

        Args:
            agent_qs: [bs, seq_len, n_agents, 1] (the trailing singleton axis
                is required for the broadcast against the mixing weights).
            states: [bs, seq_len, state_dim].
            task_decomposer: decomposer to use instead of the default one.

        Returns:
            q_tot of shape [bs, seq_len, 1].
        """
        k, b = self.w_and_b(states, task_decomposer)
        return (agent_qs * k).sum(dim=2) + b

    def w_and_b(self, states, task_decomposer=None):
        """Compute the mixing weights and bias for `states`.

        Args:
            states: [bs, seq_len, state_dim].
            task_decomposer: decomposer to use instead of the default one.

        Returns:
            (k, b): k has shape [bs, seq_len, n_agents, 1], is non-negative and
            sums to one over the agent axis; b has shape [bs, seq_len, 1].

        Raises:
            NotImplementedError: if the state carries a timestep number.
        """
        if task_decomposer is None:
            task_decomposer = self.decomposer
        if self.state_timestep_number:
            # Fail before doing any work: timestep features are not supported.
            raise NotImplementedError("Not Implemented")

        bs, seq_len = states.size(0), states.size(1)

        ally_embed, enemy_embed = self._embed_entities(states, task_decomposer)
        # Take the agent count from the decomposed state rather than from
        # main_args, so a task_decomposer with a different team size works.
        n_agents = ally_embed.size(0)
        attn_out = self._attend(th.cat([ally_embed, enemy_embed], dim=0), bs, seq_len)

        # Per-agent input: pooled summary + the agent's own embedding.
        entity_mixing_input = th.cat([
            attn_out.unsqueeze(2).expand(-1, -1, n_agents, -1),
            ally_embed.permute(1, 2, 0, 3),
        ], dim=-1)
        mixing_input = attn_out

        # First layer: abs() keeps the weights non-negative.
        w1 = th.abs(self.hyper_w_1(entity_mixing_input)).view(-1, n_agents, self.embed_dim)
        b1 = self.hyper_b_1(mixing_input).view(-1, 1, self.embed_dim)

        # Second layer
        w_final = th.abs(self.hyper_w_final(mixing_input)).view(-1, self.embed_dim, 1)
        v = self.V(mixing_input).view(-1, 1, 1)

        # Both layers are collapsed into one linear map per agent and
        # normalised over agents.
        # NOTE: if every weight of a timestep is exactly zero the denominator
        # is zero and k becomes NaN.
        k = th.bmm(w1, w_final).view(bs, -1, n_agents)
        k = (k / k.sum(dim=2, keepdim=True)).reshape(bs, seq_len, n_agents, 1)

        b = (th.bmm(b1, w_final) + v).reshape(bs, seq_len, 1)
        return k, b
    