import numpy as np
from graphsmodel.utils import get_margin_incorrect_vectorized


def _positive_rate(scores, split_mask):
    """Return the fraction of positive entries in ``scores[:, split_mask]`` along the last axis.

    When ``scores`` is a NumPy masked array, masked entries are skipped by
    ``mean`` rather than counted as incorrect.
    """
    return (scores[:, split_mask] > 0).mean(-1)


def compute_loo(
    logits,
    sub_logits,
    sub_y,
    data,
    true_test_acc,
    true_val_acc,
    true_margins,
):
    """Compute leave-one-out (LOO) utilities for full-graph and subgraph logits.

    Every LOO utility is ``reference - observed``. The reference values are
    ``true_test_acc``, ``true_val_acc`` and ``true_margins``. The observed
    values are derived from the margins that ``get_margin_incorrect_vectorized``
    computes for the given logits.

    Args:
        logits: Full-graph logits, scored against ``data.y``.
        sub_logits: Subgraph logits, scored against ``sub_y``.
        sub_y: Labels used for the subgraph margins.
        data: Object exposing ``y``, ``test_mask`` and ``val_mask``. The masks
            index the last axis of the margin arrays.
        true_test_acc: Reference test accuracy subtracted from.
        true_val_acc: Reference validation accuracy subtracted from.
        true_margins: Reference margins; must broadcast against the computed
            margin arrays.

    Returns:
        A 6-tuple, in order:
        ``(margins_loo.T, test_accs_loo, val_accs_loo,
        sub_margins_loo.T, sub_test_accs_loo, sub_val_accs_loo)``.
        The subgraph accuracy terms come from masked arrays (NaN/inf margins
        excluded). The subgraph margin term uses the unmasked margins, so any
        NaN/inf values propagate into it.
    """
    # The helper's result exposes ``.mean(1).numpy()``; presumably a torch
    # tensor whose axis 1 is averaged out (e.g. repeated runs) — TODO confirm.
    full_margins = get_margin_incorrect_vectorized(logits, data.y).mean(1).numpy()
    full_test_acc = _positive_rate(full_margins, data.test_mask)
    full_val_acc = _positive_rate(full_margins, data.val_mask)

    sub_margins = get_margin_incorrect_vectorized(sub_logits, sub_y).mean(1).numpy()
    # Hide NaN/inf margins so they drop out of the subgraph accuracy averages.
    finite_sub_margins = np.ma.masked_invalid(sub_margins)
    sub_test_acc = _positive_rate(finite_sub_margins, data.test_mask)
    sub_val_acc = _positive_rate(finite_sub_margins, data.val_mask)

    # NOTE(review): the margin LOO below uses the *unmasked* subgraph margins,
    # unlike the accuracies above — confirm callers expect NaN/inf here.
    return (
        (true_margins - full_margins).T,
        true_test_acc - full_test_acc,
        true_val_acc - full_val_acc,
        (true_margins - sub_margins).T,
        true_test_acc - sub_test_acc,
        true_val_acc - sub_val_acc,
    )
