from spaghettini import quick_register

import torch
from torch.nn.functional import cross_entropy
from torch.nn.functional import mse_loss
from src.dl.losses.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy


# ____ Prover auxiliary losses. ____
@quick_register
def classification_aux_loss(p_aux_out, p_xs, ys_true, other_data, **kwargs):
    """Label-smoothed cross-entropy between the auxiliary output and the labels.

    Args:
        p_aux_out: Auxiliary output, passed to the criterion as ``input``.
        p_xs: Unused here; accepted so the signature matches the other
            auxiliary losses in this module.
        ys_true: Ground-truth labels, passed to the criterion as ``target``.
        other_data: Unused here.
        **kwargs: Must contain ``"alpha"`` and ``"downweight_ce"``; both are
            forwarded unchanged to ``LabelSmoothedCrossEntropy``.

    Returns:
        Whatever the ``LabelSmoothedCrossEntropy`` instance returns when called.

    Raises:
        KeyError: If ``"alpha"`` or ``"downweight_ce"`` is missing from kwargs.
    """
    # A fresh criterion is built on every call from the supplied settings.
    criterion = LabelSmoothedCrossEntropy(
        alpha=kwargs["alpha"],
        downweight_ce=kwargs["downweight_ce"],
    )
    return criterion(input=p_aux_out, target=ys_true)


@quick_register
def autoencoding_aux_loss(p_aux_out, p_xs, ys_true, other_data):
    """Mean-squared reconstruction error between the auxiliary output and the inputs.

    Both tensors are flattened to ``(batch, -1)`` before comparison, where the
    batch size is read from the first dimension of ``p_xs``.

    Args:
        p_aux_out: Reconstruction; must hold a multiple of ``p_xs.shape[0]``
            elements so it can be flattened to the same leading dimension.
        p_xs: Reconstruction target; its first dimension is the batch size.
        ys_true: Unused here; accepted so the signature matches the other
            auxiliary losses in this module.
        other_data: Unused here.

    Returns:
        The ``mse_loss`` between the two flattened tensors (default reduction).
    """
    batch_size = p_xs.shape[0]
    # reshape (not view) so non-contiguous tensors, e.g. the result of a
    # permute/transpose, are flattened instead of raising a RuntimeError.
    flat_pred = p_aux_out.reshape(batch_size, -1)
    flat_target = p_xs.reshape(batch_size, -1)
    return mse_loss(input=flat_pred, target=flat_target)


# ____ Verifier losses. ____
@quick_register
def verifier_input_proof_matching_aux_loss(**kwargs):
    """Cross-entropy loss for the verifier's proof/input matching head.

    The loss is the sum of two cross-entropy terms: the matching outputs are
    scored against all-zero labels and the non-matching outputs against
    all-one labels.

    Keyword Args:
        v_aux_dict: Mapping holding the outputs for matching pairs under the
            key ``"proof_input_matching"``.
        v_aux_dict_nm: Mapping holding the outputs for non-matching pairs
            under the same key.
        ys_true: Tensor used only as a template (shape/dtype/device) for the
            label tensors; its values are never read.

    Returns:
        Matching cross-entropy plus non-matching cross-entropy.

    Raises:
        KeyError: If any of the keyword arguments or dictionary keys is missing.
    """
    # Pull out the required arguments (evaluated in this order).
    required = ("v_aux_dict", "v_aux_dict_nm", "ys_true")
    v_aux_dict, v_aux_dict_nm, ys_true = (kwargs[name] for name in required)

    head = "proof_input_matching"
    logits_match = v_aux_dict[head]
    logits_nonmatch = v_aux_dict_nm[head]

    # Class 0 means "proof and input match", class 1 means "they do not".
    labels_match = torch.zeros_like(ys_true)
    # The non-matching labels have one entry fewer along the first axis.
    # NOTE(review): presumably the non-matching batch drops one sample —
    # confirm against the caller.
    labels_nonmatch = torch.ones_like(ys_true[1:])

    loss_match = cross_entropy(input=logits_match, target=labels_match)
    loss_nonmatch = cross_entropy(input=logits_nonmatch, target=labels_nonmatch)

    # Both terms contribute with equal weight.
    return loss_match + loss_nonmatch


# ____ Probe losses. ____
@quick_register
def flattened_mse_probe_loss(preds, targets):
    """Mean-squared error between predictions and targets after flattening.

    Both tensors are flattened to ``(batch, -1)``, where the batch size is
    read from the first dimension of ``preds``.

    Args:
        preds: Predictions; its first dimension is the batch size.
        targets: Targets; must hold a multiple of ``preds.shape[0]`` elements
            so it can be flattened to the same leading dimension.

    Returns:
        The ``mse_loss`` between the two flattened tensors (default reduction).
    """
    batch_size = preds.shape[0]
    # reshape (not view) so non-contiguous tensors, e.g. the result of a
    # permute/transpose, are flattened instead of raising a RuntimeError.
    flat_preds = preds.reshape(batch_size, -1)
    flat_targets = targets.reshape(batch_size, -1)

    return mse_loss(input=flat_preds, target=flat_targets)
