from Network.network import Network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

class FactoredNetwork(Network):
    """Abstract base for factored networks that take key and query inputs.

    Input layout, per the original design notes:
        key: batch, num_dim
        (query layout is not documented here — TODO confirm)

    A concrete subclass may produce outputs shaped as any one of:
        batch, dim
        batch, query, dim
        batch, key, query, dim
    """

    def __init__(self, args):
        # Construction is delegated entirely to the Network base class.
        super(FactoredNetwork, self).__init__(args)

    def forward(self, key, query, mask, dist_settings, ret_settings):
        """Compute the network output.

        Subclasses are expected to build their return value with
        ``return_values``. This base implementation does nothing and
        returns None.
        """
        return None

def return_values(return_settings, x, pre_embeddings_key, pre_embedding_query, embeddings=None, reduction=None, attn=None, mask=None):
    """Assemble a network's return list from the requested extra outputs.

    The primary output ``x`` always comes first. Then, for each entry in
    ``return_settings`` (in order), the matching auxiliary value is appended.
    Recognised names are 'mask', 'pre_embeddings_key', 'pre_embeddings_query',
    'embeddings', 'reduction' and 'attn'. Unrecognised entries are skipped
    without error. Repeated entries append the value once per repetition.

    Returns:
        list: ``[x, *requested_values]``.
    """
    # Ordered (setting name, value) table replacing the if/elif chain.
    # Matching uses ``name == setting`` so no hashing of the setting is required.
    named_outputs = (
        ('mask', mask),
        ('pre_embeddings_key', pre_embeddings_key),
        ('pre_embeddings_query', pre_embedding_query),
        ('embeddings', embeddings),
        ('reduction', reduction),
        ('attn', attn),
    )
    collected = [x]
    for setting in return_settings:
        for name, value in named_outputs:
            if name == setting:
                collected.append(value)
                break
    return collected