import numpy as np
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

from graphsmodel.training import train_subset
from graphsmodel.utils import get_margin_incorrect_vectorized


def shapley_permutation(t, rnd_perm, n_samples, p_trunc, cfg, data):
    """Compute marginal contributions of the nodes along one permutation.

    Entries of ``rnd_perm`` are skipped until the first one that appears in
    the nonzero indices of ``data.train_mask``. From that position onwards,
    every step trains a model (via ``train_subset``) on the nodes between the
    first training node and the current position, and stores the change of
    each metric with respect to the previous step at index
    ``rnd_perm[position]``. Accuracies are the fraction of positive margins
    under ``data.test_mask`` / ``data.val_mask``.

    Args:
        t: Identifier of the permutation; only used to derive ``subset_idx``
            and echoed back in the result.
        rnd_perm: Indexable sequence of node ids defining the visiting order.
        n_samples: Length of the boolean subset mask and of the per-sample
            output arrays.
        p_trunc: Truncation fraction; the walk stops once more than
            ``int(p_trunc * n_samples)`` positions have been evaluated.
        cfg: Configuration object forwarded unchanged to ``train_subset``.
        data: Graph data object providing ``num_nodes``, ``y``,
            ``train_mask``, ``val_mask`` and ``test_mask``.

    Returns:
        Tuple ``(t, rnd_perm, sub_test_accs_diff, sub_val_accs_diff,
        test_accs_diff, val_accs_diff, margins_diff, sub_margins_diff)``.
        The accuracy arrays have shape ``(n_samples,)`` and the margin arrays
        ``(n_samples, data.num_nodes)``; entries of nodes that were never
        evaluated stay at zero.
    """
    num_nodes = data.num_nodes

    # Output buffers, all indexed by the node id read from the permutation.
    margins_diff = np.zeros((n_samples, num_nodes))
    sub_margins_diff = np.zeros((n_samples, num_nodes))
    test_accs_diff = np.zeros(n_samples)
    val_accs_diff = np.zeros(n_samples)
    sub_test_accs_diff = np.zeros(n_samples)
    sub_val_accs_diff = np.zeros(n_samples)

    # Metrics of the previous step; the empty subset counts as all zeros.
    last_test_acc = 0
    last_val_acc = 0
    last_sub_test_acc = 0
    last_sub_val_acc = 0
    last_margins = np.zeros(num_nodes)
    last_sub_margins = np.zeros(num_nodes)

    train_node_ids = data.train_mask.nonzero(as_tuple=True)[0]
    max_scanned = int(p_trunc * n_samples)

    start_pos = None  # position of the first training node in the permutation
    n_scanned = 0
    for pos in range(len(rnd_perm)):
        node = rnd_perm[pos]

        if start_pos is None:
            if node not in train_node_ids:
                continue
            start_pos = pos

        # The comparison is strict, so up to ``max_scanned + 1`` positions are
        # evaluated before the walk is truncated.
        # NOTE(review): with this bound ``subset_idx`` spans
        # [t * max_scanned, t * max_scanned + max_scanned], whose upper end
        # equals the lower end for ``t + 1`` -- confirm this is intended.
        if n_scanned > max_scanned:
            break

        # Everything from the first training node up to the current position.
        subset = torch.zeros(n_samples, dtype=torch.bool)
        subset[rnd_perm[start_pos : pos + 1]] = True

        _, sub_logits, sub_y, full_logits = train_subset(
            subset_idx=t * max_scanned + n_scanned,
            subset=subset,
            cfg=cfg,
            data=data,
            logits_on_data=True,
        )

        # --- metrics derived from the logits paired with ``data.y`` ---
        full_margins = (
            get_margin_incorrect_vectorized(full_logits.unsqueeze(0), data.y)[0]
            .numpy()
            .mean(0)
        )
        is_correct = full_margins > 0
        test_acc = is_correct[data.test_mask].mean().item()
        val_acc = is_correct[data.val_mask].mean().item()

        test_accs_diff[node] = test_acc - last_test_acc
        val_accs_diff[node] = val_acc - last_val_acc
        margins_diff[node] = full_margins - last_margins
        last_test_acc = test_acc
        last_val_acc = val_acc
        last_margins = full_margins

        # --- metrics derived from the logits paired with ``sub_y`` ---
        raw_sub_margins = get_margin_incorrect_vectorized(
            sub_logits.unsqueeze(0), sub_y
        )[0].numpy()
        sub_margins = np.nanmean(raw_sub_margins, axis=0)

        # NaN margins are masked out so they do not count towards accuracy.
        masked_sub_margins = np.ma.masked_invalid(sub_margins)
        sub_test_acc = (masked_sub_margins[data.test_mask] > 0).mean()
        sub_val_acc = (masked_sub_margins[data.val_mask] > 0).mean()

        sub_test_accs_diff[node] = sub_test_acc - last_sub_test_acc
        sub_val_accs_diff[node] = sub_val_acc - last_sub_val_acc
        last_sub_test_acc = sub_test_acc
        last_sub_val_acc = sub_val_acc

        # Nodes with a NaN margin inherit the previous value, which makes
        # their marginal contribution for this step exactly zero.
        sub_margins = np.where(np.isnan(sub_margins), last_sub_margins, sub_margins)
        sub_margins_diff[node] = sub_margins - last_sub_margins
        last_sub_margins = sub_margins

        n_scanned += 1

    return (
        t,
        rnd_perm,
        sub_test_accs_diff,
        sub_val_accs_diff,
        test_accs_diff,
        val_accs_diff,
        margins_diff,
        sub_margins_diff,
    )


def compute_data_shapley(perms, n_samples, p_trunc, cfg, data, n_jobs=1):
    """Aggregate per-permutation marginal contributions into Shapley estimates.

    Each permutation is processed by ``shapley_permutation`` through joblib,
    numbered from 1. The results are then folded into running estimates: for
    every node that was evaluated in permutation ``t`` the estimate becomes
    ``(t - 1) / t * estimate + 1 / t * contribution``. Nodes that were not
    evaluated in a permutation keep their current estimate for that step.

    Args:
        perms: Sized iterable of permutations (sequences of node ids).
        n_samples: Length of the per-sample estimate arrays.
        p_trunc: Truncation fraction forwarded to ``shapley_permutation`` and
            reused here to replay which positions were evaluated.
        cfg: Configuration object forwarded to ``shapley_permutation``.
        data: Graph data object providing ``num_nodes`` and ``train_mask``.
        n_jobs: Number of joblib workers.

    Returns:
        Tuple ``(perms, sub_test_accs_shap, sub_val_accs_shap,
        test_accs_shap, val_accs_shap, margins_shap, sub_margins_shap)``.
        ``perms`` is the list of permutations taken from the results; the
        accuracy arrays have shape ``(n_samples,)`` and the two margin arrays
        are transposed to shape ``(data.num_nodes, n_samples)``.
    """
    progress = tqdm(enumerate(perms, start=1), total=len(perms))
    jobs = (
        delayed(shapley_permutation)(t, rnd_perm, n_samples, p_trunc, cfg, data)
        for t, rnd_perm in progress
    )
    results = Parallel(n_jobs=n_jobs)(jobs)
    # Closing the bar finalises its displayed timing.
    progress.close()

    # Running estimates, indexed by node id.
    sub_test_accs_shap = np.zeros(n_samples)
    sub_val_accs_shap = np.zeros(n_samples)
    test_accs_shap = np.zeros(n_samples)
    val_accs_shap = np.zeros(n_samples)
    margins_shap = np.zeros((n_samples, data.num_nodes))
    sub_margins_shap = np.zeros((n_samples, data.num_nodes))

    train_node_ids = data.train_mask.nonzero(as_tuple=True)[0]
    max_scanned = int(p_trunc * n_samples)

    seen_perms = []
    for t, rnd_perm, sub_test, sub_val, test, val, margins, sub_margins in results:
        keep_weight = (t - 1) / t
        new_weight = 1 / t

        # Each running estimate paired with this permutation's contributions.
        pairs = (
            (sub_test_accs_shap, sub_test),
            (sub_val_accs_shap, sub_val),
            (sub_margins_shap, sub_margins),
            (test_accs_shap, test),
            (val_accs_shap, val),
            (margins_shap, margins),
        )

        # Replay the same walk as ``shapley_permutation`` so that only the
        # nodes it actually evaluated are updated.
        started = False
        n_scanned = 0
        for pos in range(len(rnd_perm)):
            node = rnd_perm[pos]

            if not started:
                if node not in train_node_ids:
                    continue
                started = True

            if n_scanned > max_scanned:
                break

            for estimate, contribution in pairs:
                estimate[node] = (
                    keep_weight * estimate[node] + new_weight * contribution[node]
                )

            n_scanned += 1

        seen_perms.append(rnd_perm)

    return (
        seen_perms,
        sub_test_accs_shap,
        sub_val_accs_shap,
        test_accs_shap,
        val_accs_shap,
        margins_shap.T,
        sub_margins_shap.T,
    )
