import torch
from Network.network import Network
from Network.network_utils import reduce_function
import torch.nn as nn
from Network.General.Factor.factored import return_values
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork

from Network.General.Factor.factor_utils import final_mlp_args, final_conv_args
import copy

def init_attn_decode(args):
    """Build the decoder network that follows the masked attention stack.

    A deep copy of ``args`` is specialised for the decoder so the caller's
    config is never mutated.

    Args:
        args: network configuration namespace. Reads ``args.aggregate_final``,
            ``args.factor_net.final_layers``, ``args.factor_net.reduce_function``,
            ``args.factor.query_aggregate``, and, depending on the branch,
            ``args.factor.final_embed_dim`` / ``args.factor.num_objects`` /
            ``args.factor.num_keys`` or ``args.embed_dim`` /
            ``args.mask_attn.model_dim``.

    Returns:
        An ``MLPNetwork`` when ``args.aggregate_final`` is set (the attention
        output is reduced and flattened first), otherwise a ``ConvNetwork``
        that maps each embedding to the object dimension.
    """
    final_args = copy.deepcopy(args)
    # TODO: hardcoded final hidden sizes for now
    final_args.hidden_sizes = args.factor_net.final_layers

    if args.aggregate_final:  # does not work with a post-channel
        final_embed_dim = args.factor.final_embed_dim
        if args.factor_net.reduce_function == 'cat':
            # 'cat' concatenates instead of pooling, so the MLP input widens
            # by the number of concatenated pieces.
            if args.factor.query_aggregate:
                bonus_factor = args.factor.num_keys
            else:
                bonus_factor = args.factor.num_objects * args.factor.num_keys
            final_args.num_inputs = final_embed_dim * bonus_factor
        else:
            final_args.num_inputs = final_embed_dim
        return MLPNetwork(final_args)

    # need a network to go from the embed_dim to the object_dim
    if args.factor.query_aggregate:
        final_args.object_dim = args.embed_dim
    else:
        final_args.object_dim = args.mask_attn.model_dim
    return ConvNetwork(final_args)


class MultiHeadAttentionBase(Network):
    """Stack of attention sublayers applied sequentially to the keys.

    Each layer consumes the keys produced by the previous one. With
    ``args.factor_net.repeat_layers`` a single sublayer instance is shared by
    every layer; otherwise ``args.mask_attn.num_layers`` independent sublayers
    are held in an ``nn.ModuleList``.
    """

    def __init__(self, args, SublayerType):
        super().__init__(args)
        self.num_layers = args.mask_attn.num_layers
        self.repeat_layers = args.factor_net.repeat_layers
        if not self.repeat_layers:
            # one independent sublayer per layer
            sublayers = [SublayerType(args) for _ in range(self.num_layers)]
            self.multi_head_attention = nn.ModuleList(sublayers)
        else:
            # a single sublayer reused for every layer (weights shared)
            self.multi_head_attention = SublayerType(args)
            sublayers = [self.multi_head_attention]
        self.query_aggregate = args.factor.query_aggregate
        self.model = sublayers

    def forward(self, keys, queries, m):
        """Run every layer and collect per-layer attention and embeddings.

        The final layer is called with ``query_final=True`` only when
        ``query_aggregate`` is off; all other calls pass ``False``.

        Returns:
            ``(keys, attns, embeds)``: the keys after the last layer, and the
            per-layer attention / embedding outputs stacked along dim 1
            (per the original note: batch, num_layers, num_heads, num_keys,
            num_queries for the attention — not verified here).
        """
        attn_per_layer, embed_per_layer = [], []
        last_idx = self.num_layers - 1
        for layer_idx in range(self.num_layers):
            if self.repeat_layers:
                sublayer = self.multi_head_attention
            else:
                sublayer = self.multi_head_attention[layer_idx]
            is_query_final = (not self.query_aggregate) and layer_idx == last_idx
            keys, layer_attn, layer_embed = sublayer(keys, queries, m, query_final=is_query_final)
            attn_per_layer.append(layer_attn)
            embed_per_layer.append(layer_embed)
        stacked_attn = torch.stack(attn_per_layer, dim=1)
        stacked_embed = torch.stack(embed_per_layer, dim=1)
        return keys, stacked_attn, stacked_embed


class BaseMaskedAttentionNetwork(Network):
    """Masked multi-head attention followed by a decoder network.

    Subclasses supply :meth:`compute_attention`. This class owns the layered
    attention stack (``MultiHeadAttentionBase``) and the decoder built by
    ``init_attn_decode``, and handles reshaping around the decoder in
    :meth:`forward`.
    """

    def __init__(self, args, SublayerType):
        super().__init__(args)
        self.fp = args.factor
        self.fnp = args.factor_net

        self.multi_head_attention = MultiHeadAttentionBase(args, SublayerType)
        layers = [self.multi_head_attention]

        self.aggregate_final = args.aggregate_final
        self.decode = init_attn_decode(args)
        layers.append(self.decode)

        self.model = layers
        self.train()
        self.reset_network_parameters()

    def compute_attention(self, key, query, mask):
        """Abstract hook returning ``(values, attention_weights, embeddings)``.

        The base implementation returns ``(None, None, None)``; subclasses are
        expected to override it.
        """
        return None, None, None

    def forward(self, key, query, mask, ret_settings=None):
        """Compute attention over ``key``/``query`` and decode the result.

        Args:
            key: key tensor passed to :meth:`compute_attention`.
            query: query tensor passed to :meth:`compute_attention`.
            mask: mask passed to :meth:`compute_attention` and forwarded to
                ``return_values``.
            ret_settings: which outputs ``return_values`` should return;
                ``None`` means an empty list.

        Returns:
            Whatever ``return_values`` produces for ``ret_settings``, built
            from the decoded output flattened to ``(batch, -1)``.
        """
        if ret_settings is None:  # avoid a shared mutable default
            ret_settings = []
        # TODO: raw_attention should just be an aggregation of the masks, but right now masks pass straight through
        x, attns, embed = self.compute_attention(key, query, mask)
        reduction = x
        if self.aggregate_final:
            # combine the attention outputs using the reduce function, then flatten
            x = reduce_function(self.fnp.reduce_function, x)
            # reshape (not view): the reduced tensor need not be contiguous
            x = x.reshape(x.shape[0], -1)
            reduction = x
            # final network goes to batch, num_outputs
            x = self.decode(x)
        else:
            if not self.fp.query_aggregate:
                # x is batch, keys, embed_dim, queries: decode each key slice
                # separately, iterating the dimension actually being indexed.
                # TODO: 2D conv would make this more efficient
                x = torch.stack(
                    [self.decode(x[:, i].transpose(-1, -2)) for i in range(x.shape[1])],
                    dim=1,
                )
            else:
                x = self.decode(x.transpose(-1, -2))  # decode should be convnet
            x = x.transpose(-1, -2)
            x = x.reshape(x.shape[0], -1)
        return return_values(ret_settings, x, (key, query), embed, reduction, attn=attns, mask=mask)