import numpy as np

from graphsmodel.utils import get_margin_incorrect_vectorized


def train_mode_subsets_counts(subsets, nodes_subsets_in, data):
    """Build reciprocal co-occurrence counts for subsets indexed by train nodes.

    For each column ``i`` of ``subsets`` (one per ``True`` entry counted by
    ``data.train_mask.sum()``) and each column ``j``:

    * ``count_in[i, j]``  -- number of subsets (rows) that contain both
      ``i`` and ``j``;
    * ``count_out[i, j]`` -- ``nodes_subsets_in[j] - count_in[i, j]``, i.e.
      the subsets that contain ``j`` but not ``i``.

    The reciprocals of these counts are written into the columns selected by
    ``data.train_mask``. A count of zero produces a factor of 0, not ``inf``.

    Args:
        subsets: SciPy sparse matrix with one row per subset and one column
            per train node; an entry equal to 1 marks membership.
            (Assumes columns are ordered like the ``True`` entries of
            ``data.train_mask`` -- confirm with the caller.)
        nodes_subsets_in: 1-D NumPy array; entry ``j`` is the number of
            subsets containing column ``j`` of ``subsets``.
        data: Object exposing ``train_mask`` (boolean mask over all nodes)
            and ``num_nodes``.

    Returns:
        Tuple ``(per_node_subsets_in, per_node_subsets_notin)``. The first
        has shape ``(len(nodes_subsets_in), data.num_nodes)``, the second
        ``(data.train_mask.sum(), data.num_nodes)``. Columns outside
        ``data.train_mask`` keep ``1 / nodes_subsets_in[i]`` in the first
        array and 0 in the second.

    Raises:
        AssertionError: If ``nodes_subsets_in[i]`` disagrees with the number
            of rows of ``subsets`` whose column ``i`` equals 1.
    """
    num_train = int(data.train_mask.sum())
    per_node_subsets_notin = np.zeros((num_train, data.num_nodes))

    # Zero counts are expected; the infinities they produce are mapped to 0
    # after the loop, so the divide-by-zero warnings carry no information.
    with np.errstate(divide="ignore"):
        per_node_subsets_in = np.repeat(
            1 / nodes_subsets_in[:, None], data.num_nodes, axis=-1
        )

        for train_idx in range(num_train):
            # ravel() keeps the mask 1-D even when there is a single subset
            # (squeeze() would collapse it to a 0-d array).
            subsets_in_mask = (subsets[:, train_idx].toarray() == 1).ravel()
            subsets_with_node = subsets[subsets_in_mask]
            assert nodes_subsets_in[train_idx] == subsets_with_node.shape[0]

            sub_counts_in = subsets_with_node.sum(0).A1
            per_node_subsets_in[train_idx, data.train_mask] = 1 / sub_counts_in

            sub_counts_out = nodes_subsets_in - sub_counts_in
            per_node_subsets_notin[train_idx, data.train_mask] = (
                1 / sub_counts_out
            )

    # One clean-up pass over each matrix instead of one per loop iteration.
    per_node_subsets_in[np.isinf(per_node_subsets_in)] = 0
    per_node_subsets_notin[np.isinf(per_node_subsets_notin)] = 0

    return per_node_subsets_in, per_node_subsets_notin


def val_mode_subsets_counts(subsets, nodes_subsets_in, data):
    """Build reciprocal co-occurrence counts for subsets indexed by val nodes.

    For each column ``i`` of ``subsets`` (one per ``True`` entry counted by
    ``data.val_mask.sum()``) and each column ``j``:

    * ``count_in[i, j]``  -- number of subsets (rows) that contain both
      ``i`` and ``j``;
    * ``count_out[i, j]`` -- ``nodes_subsets_in[j] - count_in[i, j]``, i.e.
      the subsets that contain ``j`` but not ``i``.

    The reciprocals of these counts are written into the columns selected by
    ``data.val_mask``. A count of zero produces a factor of 0, not ``inf``.

    Args:
        subsets: SciPy sparse matrix with one row per subset and one column
            per validation node; an entry equal to 1 marks membership.
            (Assumes columns are ordered like the ``True`` entries of
            ``data.val_mask`` -- confirm with the caller.)
        nodes_subsets_in: 1-D NumPy array; entry ``j`` is the number of
            subsets containing column ``j`` of ``subsets``.
        data: Object exposing ``val_mask`` (boolean mask over all nodes)
            and ``num_nodes``.

    Returns:
        Tuple ``(per_node_subsets_in, per_node_subsets_notin)``. The first
        has shape ``(len(nodes_subsets_in), data.num_nodes)``, the second
        ``(data.val_mask.sum(), data.num_nodes)``. Columns outside
        ``data.val_mask`` keep ``1 / nodes_subsets_in[i]`` in the first
        array and 0 in the second.

    Raises:
        AssertionError: If ``nodes_subsets_in[i]`` disagrees with the number
            of rows of ``subsets`` whose column ``i`` equals 1.
    """
    num_val = int(data.val_mask.sum())
    per_node_subsets_notin = np.zeros((num_val, data.num_nodes))

    # Zero counts are expected; the infinities they produce are mapped to 0
    # after the loop, so the divide-by-zero warnings carry no information.
    with np.errstate(divide="ignore"):
        per_node_subsets_in = np.repeat(
            1 / nodes_subsets_in[:, None], data.num_nodes, axis=-1
        )

        for node_idx in range(num_val):
            # ravel() keeps the mask 1-D even when there is a single subset
            # (squeeze() would collapse it to a 0-d array).
            subsets_in_mask = (subsets[:, node_idx].toarray() == 1).ravel()
            subsets_with_node = subsets[subsets_in_mask]
            assert nodes_subsets_in[node_idx] == subsets_with_node.shape[0]

            sub_counts_in = subsets_with_node.sum(0).A1
            per_node_subsets_in[node_idx, data.val_mask] = 1 / sub_counts_in

            sub_counts_out = nodes_subsets_in - sub_counts_in
            per_node_subsets_notin[node_idx, data.val_mask] = 1 / sub_counts_out

    # One clean-up pass over each matrix instead of one per loop iteration.
    per_node_subsets_in[np.isinf(per_node_subsets_in)] = 0
    per_node_subsets_notin[np.isinf(per_node_subsets_notin)] = 0

    return per_node_subsets_in, per_node_subsets_notin


def test_mode_subsets_counts(subsets, nodes_subsets_in, data):
    per_node_subsets_in = np.repeat(
        1 / nodes_subsets_in[:, None], data.num_nodes, axis=-1
    )
    per_node_subsets_notin = np.repeat(
        np.zeros((data.test_mask.sum(), 1)), data.num_nodes, axis=-1
    )

    for node_idx in range(data.test_mask.sum()):
        subsets_in_mask = (subsets[:, node_idx].toarray() == 1).squeeze()
        assert nodes_subsets_in[node_idx] == subsets[subsets_in_mask].shape[0]

        sub_counts_in = subsets[subsets_in_mask].sum(0).A1
        per_node_subsets_in[node_idx, data.test_mask] = 1 / sub_counts_in
        per_node_subsets_in[np.isinf(per_node_subsets_in)] = 0

        sub_counts_out = nodes_subsets_in - sub_counts_in

        per_node_subsets_notin[node_idx, data.test_mask] = 1 / sub_counts_out
        per_node_subsets_notin[np.isinf(per_node_subsets_notin)] = 0

    return per_node_subsets_in, per_node_subsets_notin


def mixed_mode_subsets_counts(subsets, nodes_subsets_in, data):
    """Build reciprocal co-occurrence counts for subsets drawn over all nodes.

    For each column ``i`` of ``subsets`` (``data.num_nodes`` of them) and each
    column ``j``:

    * ``count_in[i, j]``  -- number of subsets (rows) that contain both
      ``i`` and ``j``;
    * ``count_out[i, j]`` -- ``nodes_subsets_in[j] - count_in[i, j]``, i.e.
      the subsets that contain ``j`` but not ``i``.

    Row ``i`` of each returned matrix holds the reciprocals of these counts.
    A count of zero produces a factor of 0, not ``inf``.

    Args:
        subsets: SciPy sparse matrix with one row per subset and one column
            per node (``data.num_nodes`` columns); an entry equal to 1 marks
            membership.
        nodes_subsets_in: 1-D NumPy array; entry ``j`` is the number of
            subsets containing column ``j`` of ``subsets``.
        data: Object exposing ``num_nodes``.

    Returns:
        Tuple ``(per_node_subsets_in, per_node_subsets_notin)``. The first
        has shape ``(len(nodes_subsets_in), data.num_nodes)``, the second
        ``(data.num_nodes, data.num_nodes)``.

    Raises:
        AssertionError: If ``nodes_subsets_in[i]`` disagrees with the number
            of rows of ``subsets`` whose column ``i`` equals 1.
    """
    per_node_subsets_notin = np.zeros((data.num_nodes, data.num_nodes))

    # Zero counts are expected; the infinities they produce are mapped to 0
    # after the loop, so the divide-by-zero warnings carry no information.
    with np.errstate(divide="ignore"):
        per_node_subsets_in = np.repeat(
            1 / nodes_subsets_in[:, None], data.num_nodes, axis=-1
        )

        for node_idx in range(data.num_nodes):
            # ravel() keeps the mask 1-D even when there is a single subset
            # (squeeze() would collapse it to a 0-d array).
            subsets_in_mask = (subsets[:, node_idx].toarray() == 1).ravel()
            subsets_with_node = subsets[subsets_in_mask]
            assert nodes_subsets_in[node_idx] == subsets_with_node.shape[0]

            sub_counts_in = subsets_with_node.sum(0).A1
            per_node_subsets_in[node_idx] = 1 / sub_counts_in

            sub_counts_out = nodes_subsets_in - sub_counts_in
            per_node_subsets_notin[node_idx] = 1 / sub_counts_out

    # One clean-up pass over each matrix instead of one per loop iteration.
    per_node_subsets_in[np.isinf(per_node_subsets_in)] = 0
    per_node_subsets_notin[np.isinf(per_node_subsets_notin)] = 0

    return per_node_subsets_in, per_node_subsets_notin


def get_subsets_counts_normalization_factor(
    subsets, nodes_subsets_in, data, subset_mode
):
    """Dispatch to the subset-count routine that matches ``subset_mode``.

    Args:
        subsets: Forwarded unchanged to the selected routine.
        nodes_subsets_in: Forwarded unchanged to the selected routine.
        data: Forwarded unchanged to the selected routine.
        subset_mode: One of ``"train"``, ``"val"``, ``"test"`` or ``"mixed"``.

    Returns:
        Whatever the selected ``*_mode_subsets_counts`` function returns.

    Raises:
        ValueError: If ``subset_mode`` is not one of the four known modes.
    """
    # All four routines share the same keyword signature.
    shared_kwargs = {
        "subsets": subsets,
        "nodes_subsets_in": nodes_subsets_in,
        "data": data,
    }

    if subset_mode == "train":
        return train_mode_subsets_counts(**shared_kwargs)
    if subset_mode == "val":
        return val_mode_subsets_counts(**shared_kwargs)
    if subset_mode == "test":
        return test_mode_subsets_counts(**shared_kwargs)
    if subset_mode == "mixed":
        return mixed_mode_subsets_counts(**shared_kwargs)

    raise ValueError(f"Invalid subset mode: {subset_mode}")


def _banzhaf_of_scores(subsets, subsets_complement, inv_in, inv_notin, scores):
    """Return, per node, mean score over subsets with it minus mean without it."""
    return np.array(
        inv_in * (subsets.T @ scores)
        - inv_notin * np.array(subsets_complement.T @ scores)
    ).squeeze()


def compute_banzhaf(
    subsets,
    logits,
    sub_logits,
    sub_y,
    data,
    subset_mode="mixed",
):
    """Compute Banzhaf-style values of nodes from results on sampled subsets.

    Every returned quantity is a difference of two averages over subsets:
    the average score over the subsets that contain a node minus the average
    score over the subsets that do not contain it. Scores are per-node
    margins, or test / validation accuracies (fraction of margins > 0).

    Args:
        subsets: SciPy sparse matrix with one row per subset and one column
            per node considered; entries mark membership (assumed 0/1).
        logits: Passed with ``data.y`` to ``get_margin_incorrect_vectorized``;
            the result is averaged over axis 1 and converted with
            ``.numpy()``, giving one margin row per subset.
        sub_logits: Same role as ``logits`` for the second ("sub") set of
            outputs, paired with ``sub_y``. Non-finite margins are masked
            out before averaging.
        sub_y: Labels passed alongside ``sub_logits``.
        data: Object exposing ``y``, ``test_mask``, ``val_mask`` and the
            attributes required by the routine chosen via ``subset_mode``.
        subset_mode: Forwarded to
            ``get_subsets_counts_normalization_factor``.

    Returns:
        Tuple ``(margins_banzhaf.T, test_accs_banzhaf, val_accs_banzhaf,
        sub_margins_banzhaf.T, sub_test_accs_banzhaf, sub_val_accs_banzhaf)``.

    Raises:
        ValueError: Propagated from
            ``get_subsets_counts_normalization_factor`` for an unknown
            ``subset_mode``.
    """
    margins = get_margin_incorrect_vectorized(logits, data.y).mean(1).numpy()
    test_accs = (margins > 0)[:, data.test_mask].mean(-1)
    val_accs = (margins > 0)[:, data.val_mask].mean(-1)

    nodes_subsets_in = subsets.sum(0).A1
    nodes_subsets_notin = subsets.shape[0] - nodes_subsets_in
    # NOTE(review): a node contained in every subset (or in none) makes one
    # of these reciprocals inf; unlike the per-target factors below they are
    # not zeroed, so inf/NaN can reach the result -- confirm this is intended.
    inv_in = 1 / nodes_subsets_in
    inv_notin = 1 / nodes_subsets_notin

    per_node_subsets_in, per_node_subsets_notin = (
        get_subsets_counts_normalization_factor(
            subsets=subsets,
            nodes_subsets_in=nodes_subsets_in,
            data=data,
            subset_mode=subset_mode,
        )
    )

    # Dense views are built once and reused by every formula below.
    # Subtracting from a 1-element array (not a Python scalar) makes SciPy
    # densify the sparse operand instead of rejecting the operation.
    subsets_complement = np.array([1]) - subsets
    subsets_dense = subsets.toarray()

    margins_banzhaf = np.array(
        inv_in[:, None] * (subsets.T @ margins)
        - inv_notin[:, None] * np.array(subsets_complement.T @ margins)
    ).squeeze()

    test_accs_banzhaf = _banzhaf_of_scores(
        subsets, subsets_complement, inv_in, inv_notin, test_accs
    )
    val_accs_banzhaf = _banzhaf_of_scores(
        subsets, subsets_complement, inv_in, inv_notin, val_accs
    )

    sub_margins = get_margin_incorrect_vectorized(sub_logits, sub_y).mean(1).numpy()
    # Mask non-finite margins so they are left out of the means and sums.
    sub_margins_ma = np.ma.masked_invalid(sub_margins)

    sub_test_margins = sub_margins_ma[:, data.test_mask]
    sub_val_margins = sub_margins_ma[:, data.val_mask]
    sub_test_accs = (sub_test_margins > 0).mean(-1).data
    sub_val_accs = (sub_val_margins > 0).mean(-1).data

    # Masked sums need per-(node, target) normalisers, because each target
    # column has a different number of unmasked subsets.
    sub_margins_banzhaf = (
        per_node_subsets_in * np.ma.dot(subsets_dense.T, sub_margins_ma)
        - per_node_subsets_notin * np.ma.dot((1 - subsets_dense).T, sub_margins_ma)
    ).data

    sub_test_accs_banzhaf = _banzhaf_of_scores(
        subsets, subsets_complement, inv_in, inv_notin, sub_test_accs
    )
    sub_val_accs_banzhaf = _banzhaf_of_scores(
        subsets, subsets_complement, inv_in, inv_notin, sub_val_accs
    )

    return (
        margins_banzhaf.T,
        test_accs_banzhaf,
        val_accs_banzhaf,
        sub_margins_banzhaf.T,
        sub_test_accs_banzhaf,
        sub_val_accs_banzhaf,
    )
