import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from Network.network import Network
from Network.network_utils import reduce_function
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.Pair.pair import PairNetwork, merge_key_queries
from Network.General.Factor.factored import return_values


class MultiMLPNetwork(Network):
    '''
    Applies a separate MLP to each object of a fixed-size, fixed-order object set.
    When args.aggregate_final is set, the per-object outputs are reduced with
    args.factor.reduce_fn and passed through one final MLP.
    '''
    def __init__(self, args):
        '''
        Operates only with a fixed number of objects with a fixed order
        applies a separate MLP to each object in order

        args fields read here: factor (object_dim, single_obj_dim, num_objects,
        reduce_fn, final_layers), factor_net (append_keys, append_mask,
        append_broadcast_mask), embed_dim, output_dim, aggregate_final.
        '''
        super().__init__(args)
        # assumes the input is flattened list of input space sized values
        # needs an object dim
        self.fp = args.factor
        self.embed_dim = args.output_dim if args.embed_dim == 0 else args.embed_dim
        self.aggregate_final = args.aggregate_final
        self.append_keys = args.factor_net.append_keys
        self.append_mask = args.factor_net.append_mask
        self.append_broadcast_mask = args.factor_net.append_broadcast_mask

        # every per-object MLP shares one configuration, deep-copied so the
        # caller's args object is not mutated
        mlp_args = copy.deepcopy(args)
        if args.embed_dim <= 0:
            # raw object features, optionally with the key and broadcast mask appended
            mlp_args.num_inputs = self.fp.object_dim + int(self.append_keys) * self.fp.single_obj_dim + self.append_broadcast_mask
        else:
            # embedded inputs: one embed_dim block, plus another if keys are appended
            mlp_args.num_inputs = args.embed_dim + int(self.append_keys) * args.embed_dim
        mlp_args.num_outputs = args.output_dim
        # the final activation is left to the aggregate MLP when there is one
        if args.aggregate_final: mlp_args.activation_final = 'none'
        self.mlps = nn.ModuleList([MLPNetwork(mlp_args) for _ in range(self.fp.num_objects)])
        if args.aggregate_final:
            final_args = copy.deepcopy(args)
            # NOTE(review): the per-object MLPs emit args.output_dim values, but this
            # input size is derived from self.embed_dim; the two only agree when
            # embed_dim is 0 or equal to output_dim -- confirm intended
            final_args.num_inputs = self.embed_dim if self.fp.reduce_fn != 'cat' else self.embed_dim * self.fp.num_objects
            final_args.hidden_sizes = args.factor.final_layers
            # constructed positionally, matching how the per-object MLPs are built
            # above (the args object was previously unpacked with ** here)
            self.MLP = MLPNetwork(final_args)
            self.model = nn.Sequential(self.mlps, self.MLP)
        else:
            self.model = nn.ModuleList([self.mlps])
        self.train()
        self.reset_network_parameters()

    def forward(self, key, query, mask, ret_settings):
        '''
        Runs the i-th MLP on the i-th object slice and optionally aggregates.

        key, query, mask are forwarded to merge_key_queries; query.shape[1] is
        used as the number of objects and the merged result is indexed per
        object along its last axis. mask may be None; otherwise mask[:,0] is
        broadcast over the output dimension (presumably (batch, num_objects)
        -- TODO confirm against callers).
        ret_settings is passed through to return_values, which builds the result.
        '''
        queries = merge_key_queries(key, query, mask, append_keys=self.append_keys, append_mask=self.append_mask, append_broadcast_mask=self.append_broadcast_mask)
        # one dedicated MLP per object, stacked along a new object axis (dim 1)
        embeddings = torch.stack(
            [self.mlps[i](queries[..., i]) for i in range(query.shape[1])],
            dim=1,
        )
        x, reduction = embeddings, None
        if mask is not None: x = x * mask[:,0].unsqueeze(-1)
        if self.aggregate_final:
            # self.fp holds the factor parameters (reduce_fn is read from it in
            # __init__ as well); self.op is never assigned by this class
            # NOTE(review): the reduction runs over the unmasked embeddings, so the
            # masked x computed above is discarded on this path -- confirm intended
            x = reduce_function(self.fp.reduce_fn, embeddings)[0]
            # NOTE(review): for reduce_fn == 'cat' the reduced width is presumably
            # output_dim * num_objects, which this view would fold into the batch
            # axis -- verify against reduce_function
            reduction = x.view(-1, self.output_dim)
            x = self.MLP(reduction)
        return return_values(ret_settings, x, key, query, embeddings, reduction, mask=mask)