from ActualCausal.Train.train_utils import compute_adaptive_rate

def compute_embed_reg(args, result, params, batch):
    """Return the embedding regularization weight to use for this batch.

    Args:
        args: configuration object; reads
            ``args.inter.regularization.embedding.adaptive_embed_reg``
            (indexed at 0 and 1) and ``adaptive_embed_reg_type``.
        result: model output; ``result.log_probs.shape[-1]`` is read only
            when ``params`` has no ``converged_active_loss_value`` entry.
        params: container supporting ``in`` and attribute access; reads
            ``embed_reg`` and, when present, ``converged_active_loss_value``.
        batch: forwarded unchanged to ``compute_adaptive_rate``.

    Returns:
        ``params.embed_reg`` unchanged when ``adaptive_embed_reg[0] > 0``,
        otherwise whatever ``compute_adaptive_rate`` returns.
    """
    embed_cfg = args.inter.regularization.embedding
    bias = embed_cfg.adaptive_embed_reg[0]
    # NOTE(review): a strictly positive first entry selects the fixed weight,
    # so only non-positive biases ever reach compute_adaptive_rate -- confirm
    # this is the intended sense of the comparison.
    if bias > 0:
        return params.embed_reg

    embed_reg = params.embed_reg
    if "converged_active_loss_value" in params:
        baseline_likelihood = params.converged_active_loss_value
    else:
        # Fallback baseline: 3.5 per entry of the last log_probs dimension.
        baseline_likelihood = 3.5 * result.log_probs.shape[-1]
    flatten_factor = embed_cfg.adaptive_embed_reg[1]

    # The same embed_reg value is supplied as both the third and fourth
    # positional arguments. The meaning of the literal False is defined by
    # compute_adaptive_rate, which is not visible here -- TODO confirm.
    return compute_adaptive_rate(
        embed_cfg.adaptive_embed_reg_type,
        False,
        embed_reg,
        embed_reg,
        result,
        batch,
        baseline_likelihood,
        bias,
        flatten_factor,
    )


def compute_embedding_losses(args, params, model, batch, results, use_masked=False):
    """Compute the embedding penalty term from a forward-pass result.

    The pre-embeddings are expected to be stored on ``results``; the penalty
    is element 1 of the selected pre-embedding entry, averaged over its last
    dimension and scaled by the weight from ``compute_embed_reg``.

    Args:
        args: configuration object, forwarded to ``compute_embed_reg``.
        params: training parameters, forwarded to ``compute_embed_reg``.
        model: not used by this function.
        batch: forwarded to ``compute_embed_reg``.
        results: a result object, or a plain tuple of them, in which case
            only the first element is used.
        use_masked: when true, read ``results.masked_pre_embeddings``;
            otherwise read ``results.pre_embeddings_query``.

    Returns:
        ``<pre-embeddings>[1].mean(dim=-1) * float(embed_reg)``.
    """
    # Exact type match: only a plain tuple is unwrapped, tuple subclasses
    # are passed through as-is.
    if type(results) is tuple:
        results = results[0]

    reg_weight = compute_embed_reg(args, results, params, batch)

    if use_masked:
        pre_embeddings = results.masked_pre_embeddings
    else:
        pre_embeddings = results.pre_embeddings_query
    return pre_embeddings[1].mean(dim=-1) * float(reg_weight)
