import copy
from Network.General.Factor.key_query import KeyQueryEncoder

def init_key_args(args):
    """Return a deep copy of ``args`` set up as the key configuration.

    On the copy, the object dimension is taken from ``args.factor.single_obj_dim``,
    the output dimension from ``args.embed_dim`` and the hidden layer sizes from
    ``args.factor_net.embed_layers``; layer norm and
    ``factor_net.aggregate_final`` are switched off.

    Note: ``hidden_sizes`` is the very object stored at
    ``args.factor_net.embed_layers`` on the ORIGINAL ``args`` (not on the copy),
    so the returned config shares that object with the caller.

    Args:
        args: configuration namespace; it is not modified.

    Returns:
        The modified deep copy.
    """
    key_cfg = copy.deepcopy(args)
    overrides = {
        "object_dim": args.factor.single_obj_dim,
        "output_dim": args.embed_dim,
        "hidden_sizes": args.factor_net.embed_layers,
        "use_layer_norm": False,
    }
    for attr_name, attr_value in overrides.items():
        setattr(key_cfg, attr_name, attr_value)
    key_cfg.factor_net.aggregate_final = False
    return key_cfg

def init_query_args(args):
    """Return a deep copy of ``args`` set up as the query configuration.

    Same overrides as the key configuration except that ``object_dim`` is left
    untouched: output dimension from ``args.embed_dim``, hidden sizes from
    ``args.factor_net.embed_layers`` (the original's object, shared with the
    caller), layer norm and ``factor_net.aggregate_final`` disabled.

    Args:
        args: configuration namespace; it is not modified.

    Returns:
        The modified deep copy.
    """
    query_cfg = copy.deepcopy(args)
    query_cfg.factor_net.aggregate_final = False
    for attr_name, attr_value in (
        ("output_dim", args.embed_dim),
        ("hidden_sizes", args.factor_net.embed_layers),
        ("use_layer_norm", False),
    ):
        setattr(query_cfg, attr_name, attr_value)
    return query_cfg

def init_key_query(args):
    """Build the key and query configs and the ``KeyQueryEncoder`` using them.

    Args:
        args: base configuration namespace, passed to the encoder unchanged.

    Returns:
        A ``(key_args, query_args, key_query_encoder)`` tuple.
    """
    key_args, query_args = init_key_args(args), init_query_args(args)
    return key_args, query_args, KeyQueryEncoder(args, key_args, query_args)

def init_inter_args(args, use_cluster=True):
    """Return a deep copy of ``args`` configured for the ``inter`` network.

    Cluster mode is active when both ``args.cluster.use_cluster`` and the
    ``use_cluster`` argument are truthy. In cluster mode the copy gets one
    output per cluster (``args.cluster.num_clusters``) and ``aggregate_final``
    is set from the cluster-mode expression; otherwise a single output is used.

    Args:
        args: configuration namespace; it is not modified.
        use_cluster: caller-side switch that can disable cluster mode.

    Returns:
        The modified deep copy.
    """
    cluster_mode = args.cluster.use_cluster and use_cluster
    inter_cfg = copy.deepcopy(args)
    if cluster_mode:
        inter_cfg.num_outputs = args.cluster.num_clusters
    else:
        inter_cfg.num_outputs = 1
    # Final activation is always disabled; an earlier variant kept
    # args.activation_final unless it was 'softmax'.
    inter_cfg.activation_final = "none"
    inter_cfg.factor.query_aggregate = False
    # An encoder is only needed when it is not shared.
    inter_cfg.needs_encoding = not args.inter_net.shared_encoding
    # Stores the raw value of the cluster-mode expression (not coerced to bool).
    inter_cfg.aggregate_final = cluster_mode
    return inter_cfg

def init_select_args(args):
    """Return a deep copy of ``args`` configured for the ``select`` network.

    The copy has ``pair.aggregate_final`` disabled, ``include_last`` enabled,
    one output per cluster (``args.cluster.num_clusters``), no final activation
    and ``mask_attn.needs_encoding`` enabled.

    Args:
        args: configuration namespace; it is not modified.

    Returns:
        The modified deep copy.
    """
    select_cfg = copy.deepcopy(args)
    select_cfg.pair.aggregate_final = False
    for attr_name, attr_value in (
        ("include_last", True),
        ("num_outputs", args.cluster.num_clusters),
        ("activation_final", "none"),
    ):
        setattr(select_cfg, attr_name, attr_value)
    select_cfg.mask_attn.needs_encoding = True
    return select_cfg

def init_forward_args(args):
    """Return a deep copy of ``args`` configured for the ``forward`` network.

    The copy has ``aggregate_final`` disabled. ``output_dim`` is set to the
    per-object dimension ``factor.single_obj_dim``, and ``num_outputs`` is set
    to that dimension times ``factor.num_objects``. Both are computed before
    any embedding rescale, so they keep the original (un-embedded) sizes.

    When ``args.embed_dim > 0``, ``factor.needs_encoding`` is set to False.
    ``factor.single_obj_dim`` and ``factor.object_dim`` are set to
    ``embed_dim``, and ``factor.first_obj_dim`` is rescaled by
    ``embed_dim / single_obj_dim`` (rounded down).

    Args:
        args: configuration namespace; it is not modified.

    Returns:
        The modified deep copy.
    """
    forward_args = copy.deepcopy(args)
    forward_args.aggregate_final = False
    # Output sizes use the raw per-object dimension, before the rescale below.
    forward_args.output_dim = forward_args.factor.single_obj_dim
    forward_args.num_outputs = forward_args.factor.single_obj_dim * forward_args.factor.num_objects
    if args.embed_dim > 0:
        forward_args.factor.needs_encoding = False
        forward_args.factor.single_obj_dim = args.embed_dim
        forward_args.factor.object_dim = args.embed_dim
        # Scale first_obj_dim by embed_dim / single_obj_dim, reading the
        # ORIGINAL dims from ``args``. Multiply before the floor division so
        # integer dims are scaled exactly: a float ratio can truncate a unit
        # away, e.g. int(100 * (29 / 100)) == 28. Assumes non-negative dims.
        forward_args.factor.first_obj_dim = int(
            args.embed_dim * args.factor.first_obj_dim // args.factor.single_obj_dim
        )
    return forward_args