import numpy as np
import torch
from sklearn.isotonic import IsotonicRegression

from graphsmodel.training import train_subset
from graphsmodel.utils import get_margin_incorrect_vectorized


def generate_stratified_sampling_mask(y, mask, num_samples_per_class=30):
    """
    Build a boolean mask selecting at most ``num_samples_per_class`` nodes per class.

    For every distinct label in ``y``, the nodes that carry this label and are
    enabled in ``mask`` are the candidates. If a class has more candidates than
    ``num_samples_per_class``, a uniformly random subset of that size is kept
    (drawn with ``torch.randperm``, so the result depends on the global torch
    RNG state); otherwise all candidates of the class are kept.

    Args:
        y (torch.Tensor): Label of every node.
        mask (torch.Tensor): Boolean tensor marking the nodes eligible for sampling.
        num_samples_per_class (int): Upper bound on the nodes selected per class.

    Returns:
        torch.Tensor: A boolean tensor shaped like ``mask`` that is True exactly
        at the selected nodes. Selected nodes are always a subset of ``mask``.

    """
    sampled_mask = torch.zeros_like(mask, dtype=torch.bool)

    for cls in torch.unique(y):
        candidates = ((y == cls) & mask).nonzero(as_tuple=True)[0]
        if len(candidates) > num_samples_per_class:
            # Too many candidates: keep a random subset of the requested size.
            keep = torch.randperm(len(candidates))[:num_samples_per_class]
            candidates = candidates[keep]
        # One indexed assignment instead of a Python loop over single nodes.
        sampled_mask[candidates] = True

    return sampled_mask


def _margin_for_subset(subset, target_node, cfg, data):
    """Train on ``subset`` and return the averaged margin at ``target_node``."""
    _, _, _, logits = train_subset(
        subset_idx=None, subset=subset, cfg=cfg, data=data, logits_on_data=True
    )
    margins = get_margin_incorrect_vectorized(logits.unsqueeze(0), data.y)
    # Drop the leading axis added above, average over the first remaining axis
    # and read the entry of the node under evaluation.
    return margins.squeeze(0).numpy().mean(0)[target_node]


def process_node(
    sampled_nodes,
    node_idx,
    dm_subset,
    banzhaf_subset,
    loo_subset,
    shap_subset,
    pc_subset,
    perms_pc_subset,
    cfg,
    data,
):
    """
    Process a node in the graph by computing the margin values for different data valuation approaches.

    For each approach a model is trained via ``train_subset`` on the given
    subset, the margins are computed with ``get_margin_incorrect_vectorized``
    against ``data.y``, and the averaged margin of the node
    ``sampled_nodes[node_idx]`` is returned. The subsets are trained in the
    order dm, banzhaf, loo, shap, pc, perms_pc.

    Args:
        sampled_nodes (torch.Tensor): A PyTorch tensor of the node indices of sampled test nodes.
        node_idx (int): The index of the current test node.
        dm_subset (torch.Tensor): A boolean PyTorch tensor representing the nodes to consider with the top-k removed for dm.
        banzhaf_subset (torch.Tensor): A boolean PyTorch tensor representing the nodes to consider with the top-k removed for banzhaf.
        loo_subset (torch.Tensor): A boolean PyTorch tensor representing the nodes to consider with the top-k removed for loo.
        shap_subset (torch.Tensor): A boolean PyTorch tensor representing the nodes to consider with the top-k removed for shap.
        pc_subset (torch.Tensor): Same kind of subset for the pc approach (presumably also a boolean tensor with the top-k removed; confirm with the caller).
        perms_pc_subset (torch.Tensor): Same kind of subset for the perms_pc approach (presumably also a boolean tensor with the top-k removed; confirm with the caller).
        cfg (dict): The configuration parameters.
        data (object): The Data object representing the graph.

    Returns:
        tuple: A tuple containing the margin values for dm, banzhaf, loo, shap, pc, and perms_pc, in that order.

    """
    target_node = sampled_nodes[node_idx]
    subsets = (
        dm_subset,
        banzhaf_subset,
        loo_subset,
        shap_subset,
        pc_subset,
        perms_pc_subset,
    )
    # The generator is consumed in order, so training happens in the same
    # sequence as the listing above.
    return tuple(
        _margin_for_subset(subset, target_node, cfg, data) for subset in subsets
    )


def compute_iso_reg(margins, node_idx, n_samples):
    """
    Estimate the smallest k at which a node becomes misclassified.

    A non-increasing isotonic regression is fitted with the keys of ``margins``
    as inputs and the ``node_idx`` column of the stacked values as targets. The
    fit is then evaluated at ``0, 1, ..., n_samples - 1`` (inputs outside the
    fitted range are clipped), and the first position whose predicted margin is
    non-positive is returned.

    Args:
        margins (dict): Maps a numeric key (presumably the number of removed
            nodes k; confirm with the caller) to an array of margins indexed by node.
        node_idx (int): Column of the stacked margins to regress on.
        n_samples (int): Number of positions at which the fit is evaluated.

    Returns:
        The first index with a predicted margin <= 0, or ``n_samples`` if the
        predicted margin stays positive everywhere.

    """
    removal_counts = np.array(list(margins.keys()))
    node_margins = np.stack(list(margins.values()))[:, node_idx]

    regressor = IsotonicRegression(increasing=False, out_of_bounds="clip")
    regressor.fit(removal_counts, node_margins)
    fitted_margins = regressor.predict(np.arange(n_samples, dtype=int))

    misclassified_at = np.flatnonzero(fitted_margins <= 0)
    if len(misclassified_at) == 0:
        # if the node is never misclassified, return the number of samples
        # accounting for the conservative estimate of data support
        return n_samples
    # the smallest k such that the node is misclassified
    return misclassified_at[0]
