import copy
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork

def final_mlp_args(args):
    """Build the argument namespace for a final MLP head.

    Returns a deep copy of ``args`` (the caller's object is not modified) with:
      * ``num_inputs`` set to ``args.embed_dim``, multiplied by
        ``args.factor.num_objects`` when ``args.factor_net.reduce_function``
        is ``'cat'`` (per-object embeddings are concatenated);
      * ``hidden_sizes`` set to ``args.factor_net.final_layers``. This is the
        list object from the ORIGINAL ``args``, not from the copy.
    """
    net_args = copy.deepcopy(args)
    width = args.embed_dim
    if args.factor_net.reduce_function == 'cat':
        # concatenation stacks one embedding per object
        width = width * args.factor.num_objects
    net_args.num_inputs = width
    net_args.hidden_sizes = args.factor_net.final_layers  # TODO: hardcoded final hidden sizes for now
    return net_args

def final_conv_args(args):
    """Build the argument namespace for a final convolutional head.

    Returns a deep copy of ``args`` (the caller's object is not modified)
    whose ``object_dim`` is ``args.embed_dim``. Its ``hidden_sizes`` is
    ``args.factor_net.final_layers``, the list object from the original
    ``args``.
    """
    conv_args = copy.deepcopy(args)
    embedding_width = args.embed_dim
    conv_args.object_dim = embedding_width
    conv_args.hidden_sizes = args.factor_net.final_layers  # TODO: hardcoded final hidden sizes for now
    return conv_args

def init_decode(args):
    """Construct the decoder network selected by ``args.aggregate_final``.

    * ``aggregate_final`` falsy: a ``ConvNetwork`` built from
      ``final_conv_args(args)``.
    * ``aggregate_final`` truthy: an ``MLPNetwork`` built from
      ``final_mlp_args(args)``, with ``num_inputs`` replaced by
      ``args.factor.final_embed_dim``. Under ``'cat'`` reduction that width is
      multiplied by ``num_keys`` (when ``query_aggregate``) or by
      ``num_objects * num_keys`` (otherwise).
    """
    if not args.aggregate_final:
        # need a network to go from the embed_dim to the object_dim
        return ConvNetwork(final_conv_args(args))

    # aggregated path; does not work with a post-channel
    mlp_args = final_mlp_args(args)
    if args.factor.query_aggregate:
        key_multiplier = args.factor.num_keys
    else:
        key_multiplier = args.factor.num_objects * args.factor.num_keys
    embed_width = args.factor.final_embed_dim
    if args.factor_net.reduce_function == 'cat':
        embed_width = embed_width * key_multiplier
    mlp_args.num_inputs = embed_width
    return MLPNetwork(mlp_args)
