import os
import time

import hydra
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
import wandb
from joblib import Parallel, delayed
from torch_geometric.seed import seed_everything
from tqdm import tqdm

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation import compute_banzhaf, compute_loo
from graphsmodel.evaluation.datamodel import fit_ridge
from graphsmodel.experiments.counterfactual import (
    compute_iso_reg,
    generate_stratified_sampling_mask,
    process_node,
)
from graphsmodel.training import get_logits, node_level_subset_training, train_subset
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_appropriate_edge_index,
    get_dataset,
    get_margin_incorrect_vectorized,
    get_model,
    get_optimizer,
)
from graphsmodel.utils.visualization import get_figsize


def run_experiment(cfg, wandb_run):
    if cfg.train.deterministic:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset = get_dataset(cfg)
    data = dataset[0]

    training_nodes = data.train_mask.nonzero().squeeze()
    val_nodes = data.val_mask.nonzero().squeeze()
    test_nodes = data.test_mask.nonzero().squeeze()

    if cfg.data.subset_mode == "train":
        nodes_to_select = training_nodes
    elif cfg.data.subset_mode == "val":
        nodes_to_select = val_nodes
    elif cfg.data.subset_mode == "test":
        nodes_to_select = test_nodes
    else:
        nodes_to_select = torch.arange(data.num_nodes)

    # length of the valuation tensor to use
    n_samples = len(nodes_to_select)

    datamodel_alpha = f"alpha_{cfg.data.alpha}"
    results_dir = cfg.core.results_dir / datamodel_alpha / "estimate_supports"
    results_dir.mkdir(parents=True, exist_ok=True)

    ####### Uncomment to double-check if the linear regression approach correctly approximates the logistic regression #######
    # adj = pyg.utils.to_scipy_sparse_matrix(data.edge_index)

    # deg_sq = sp.diags(np.power(adj.sum(1).A1, -1/2))
    # lap = deg_sq @ adj @ deg_sq
    # diff_x = lap @ lap @ data.x.numpy()

    # model = linear_model.Ridge(alpha=1, fit_intercept=False)
    # model.fit(diff_x[data.train_mask], np.eye(7)[data.y[data.train_mask].numpy()])
    # print(f"sklearn accuracy: {(model.predict(diff_x[data.test_mask]).argmax(1) == data.y[data.test_mask].numpy()).mean()}")
    ##########################################################################################################################

    true_logits = []
    for model_seed in range(cfg.train.n_models):
        seed_everything(model_seed)
        model = get_model(cfg, data, num_classes=dataset.num_classes)
        if "_target_" in cfg.train.optimizer:
            optimizer = get_optimizer(cfg.train.optimizer, model)
            _, _, _, _, model_logits = node_level_subset_training(
                data=data,
                model=model,
                optimizer=optimizer,
                patience=cfg.train.patience,
                max_epochs=cfg.train.epochs,
            )
        else:
            edge_index = get_appropriate_edge_index(data)
            model.fit(data.x, edge_index, data.y, data.train_mask)
            model_logits = get_logits(data, model)
        true_logits.append(model_logits)
    true_logits = torch.stack(true_logits)
    true_margins = (
        get_margin_incorrect_vectorized(true_logits.unsqueeze(0), data.y)
        .squeeze(0)
        .numpy()
        .mean(0)
    )
    true_test_acc = (true_margins > 0)[data.test_mask].mean().item()
    true_val_acc = (true_margins > 0)[data.val_mask].mean().item()

    print(f"True test accuracy: {true_test_acc}, True val accuracy: {true_val_acc}")

    correctly_classified_nodes = torch.from_numpy(true_margins > 0)
    sampled_mask = generate_stratified_sampling_mask(
        data.y, correctly_classified_nodes & data.test_mask, num_samples_per_class=30
    )
    sampled_nodes = sampled_mask.nonzero(as_tuple=True)[0]

    correctly_classified_test_nodes = (
        correctly_classified_nodes & data.test_mask
    ).nonzero(as_tuple=True)[0]
    assert torch.isin(sampled_nodes, correctly_classified_test_nodes).all()

    subset_res = torch.load(results_dir.parent / "subsets_res.pt")
    subsets = subset_res["subsets"]
    logits = subset_res["logits"]
    sub_logits = subset_res["sub_logits"]
    sub_y = subset_res["sub_y"]

    # DATAMODEL
    datamodel_res = torch.load(results_dir.parent / "datamodel_res.pt")
    if cfg.data.subset_mode == "train":
        margins_dm = datamodel_res["margins_dm"]
        sub_margins_dm = datamodel_res["sub_margins_dm"]
    elif cfg.data.subset_mode == "mixed":
        margins = get_margin_incorrect_vectorized(logits, data.y).mean(1).numpy()

        margins_dm, _ = fit_ridge(subsets, margins)
        sub_margins_dm = datamodel_res["sub_margins_dm"]

    # Banzhaf
    (
        margins_banzhaf,
        _,
        _,
        sub_margins_banzhaf,
        _,
        _,
    ) = compute_banzhaf(
        subsets=subsets,
        logits=logits,
        sub_logits=sub_logits,
        sub_y=sub_y,
        data=data,
        subset_mode=cfg.data.subset_mode,
    )

    # LOO
    loo_res = torch.load(results_dir.parent.parent / "loo_res.pt")
    loo_logits = loo_res["logits"]
    loo_sub_logits = loo_res["sub_logits"]
    loo_sub_y = loo_res["sub_y"]

    (
        margins_loo,
        _,
        _,
        sub_margins_loo,
        _,
        _,
    ) = compute_loo(
        logits=loo_logits,
        sub_logits=loo_sub_logits,
        sub_y=loo_sub_y,
        data=data,
        true_test_acc=true_test_acc,
        true_val_acc=true_val_acc,
        true_margins=true_margins,
    )

    # SHAPLEY
    shapley_res = torch.load(results_dir.parent.parent / "shapley_res.pt")
    margins_shap = shapley_res["margins_shap"]
    sub_margins_shap = shapley_res["sub_margins_shap"]

    # PC-WINTER
    pc_winter_res = torch.load(
        results_dir.parent.parent.parent / "pc_winter_0_0.5_0.7_res.pt"
    )

    margins_pc = pc_winter_res["margins_pc"][:, nodes_to_select]
    sub_margins_pc = pc_winter_res["sub_margins_pc"][:, nodes_to_select]

    print("PC-Winter results found.")

    # PC-Winter permutations
    pc_winter_perms_res = torch.load(
        results_dir.parent.parent.parent / "pc_winter_0_0.99_0.99_res.pt"
    )
    margins_perms_pc = pc_winter_perms_res["margins_pc"][:, nodes_to_select]
    sub_margins_perms_pc = pc_winter_perms_res["sub_margins_pc"][:, nodes_to_select]

    # COUNTERFACTUAL
    try:
        supports_res = torch.load(results_dir / "estimated_supports.pt")
        sampled_nodes = supports_res["sampled_nodes"]
        sampled_mask = supports_res["sampled_mask"]
        correctly_classified_nodes = supports_res["correctly_classified_nodes"]
        # estimated supports
        dm_estimated_supports = supports_res["dm_estimated_supports"]
        banzhaf_estimated_supports = supports_res["banzhaf_estimated_supports"]
        loo_estimated_supports = supports_res["loo_estimated_supports"]
        shap_estimated_supports = supports_res["shap_estimated_supports"]
        pc_estimated_supports = supports_res["pc_estimated_supports"]
        perms_pc_estimated_supports = supports_res["perms_pc_estimated_supports"]
        overall_dm_estimated_supports = supports_res["overall_dm_estimated_supports"]
        overall_banzhaf_estimated_supports = supports_res[
            "overall_banzhaf_estimated_supports"
        ]
        overall_loo_estimated_supports = supports_res["overall_loo_estimated_supports"]
        overall_shap_estimated_supports = supports_res[
            "overall_shap_estimated_supports"
        ]
        overall_pc_estimated_supports = supports_res["overall_pc_estimated_supports"]
        overall_perms_pc_estimated_supports = supports_res[
            "overall_perms_pc_estimated_supports"
        ]
        # per node margins
        dm_per_node_margin = supports_res["dm_per_node_margin"]
        banzhaf_per_node_margin = supports_res["banzhaf_per_node_margin"]
        loo_per_node_margin = supports_res["loo_per_node_margin"]
        shap_per_node_margin = supports_res["shap_per_node_margin"]
        pc_per_node_margin = supports_res["pc_per_node_margin"]
        perms_pc_per_node_margin = supports_res["perms_pc_per_node_margin"]
        overall_dm_per_node_margin = supports_res["overall_dm_per_node_margin"]
        overall_banzhaf_per_node_margin = supports_res[
            "overall_banzhaf_per_node_margin"
        ]
        overall_loo_per_node_margin = supports_res["overall_loo_per_node_margin"]
        overall_shap_per_node_margin = supports_res["overall_shap_per_node_margin"]
        overall_pc_per_node_margin = supports_res["overall_pc_per_node_margin"]
        overall_perms_pc_per_node_margin = supports_res[
            "overall_perms_pc_per_node_margin"
        ]
        # margins results
        datamodel_margins = supports_res["datamodel_margins"]
        banzhaf_margins = supports_res["banzhaf_margins"]
        loo_margins = supports_res["loo_margins"]
        shap_margins = supports_res["shap_margins"]
        pc_margins = supports_res["pc_margins"]
        perms_pc_margins = supports_res["perms_pc_margins"]
        overall_datamodel_margins = supports_res["overall_datamodel_margins"]
        overall_banzhaf_margins = supports_res["overall_banzhaf_margins"]
        overall_loo_margins = supports_res["overall_loo_margins"]
        overall_shap_margins = supports_res["overall_shap_margins"]
        overall_pc_margins = supports_res["overall_pc_margins"]
        overall_perms_pc_margins = supports_res["overall_perms_pc_margins"]
    except FileNotFoundError:
        print("Estimating supports...")

        #### Prepare the dataset to learn the iso regressor ####
        max_exponent = int(np.log2(n_samples))
        ks = 2 ** np.arange(max_exponent + 1)
        S = np.ones((sampled_mask.sum(), n_samples), dtype=bool)

        datamodel_margins = {}
        banzhaf_margins = {}
        loo_margins = {}
        shap_margins = {}
        pc_margins = {}
        perms_pc_margins = {}

        overall_datamodel_margins = {}
        overall_banzhaf_margins = {}
        overall_loo_margins = {}
        overall_shap_margins = {}
        overall_pc_margins = {}
        overall_perms_pc_margins = {}

        for k in tqdm(ks):
            dm_top_k = torch.topk(torch.from_numpy(margins_dm[sampled_mask]), k, dim=-1)
            banzhaf_top_k = torch.topk(
                torch.from_numpy(margins_banzhaf[sampled_mask]), k, dim=-1
            )
            loo_top_k = torch.topk(
                torch.from_numpy(margins_loo[sampled_mask]), k, dim=-1
            )
            shap_top_k = torch.topk(
                torch.from_numpy(margins_shap[sampled_mask]), k, dim=-1
            )
            pc_top_k = torch.topk(torch.from_numpy(margins_pc[sampled_mask]), k, dim=-1)
            perms_pc_top_k = torch.topk(
                torch.from_numpy(margins_perms_pc[sampled_mask]), k, dim=-1
            )

            overall_dm_top_k = torch.topk(
                torch.from_numpy(sub_margins_dm[sampled_mask]), k, dim=-1
            )
            overall_banzhaf_top_k = torch.topk(
                torch.from_numpy(sub_margins_banzhaf[sampled_mask]), k, dim=-1
            )
            overall_loo_top_k = torch.topk(
                torch.from_numpy(sub_margins_loo[sampled_mask]), k, dim=-1
            )
            overall_shap_top_k = torch.topk(
                torch.from_numpy(sub_margins_shap[sampled_mask]), k, dim=-1
            )
            overall_pc_top_k = torch.topk(
                torch.from_numpy(sub_margins_pc[sampled_mask]), k, dim=-1
            )
            overall_perms_pc_top_k = torch.topk(
                torch.from_numpy(sub_margins_perms_pc[sampled_mask]), k, dim=-1
            )

            # mask out the top k
            rows = np.arange(sampled_mask.sum()).reshape(-1, 1)
            dm_S_minus_G = S.copy()
            dm_S_minus_G[rows, dm_top_k.indices] = False

            banzhaf_S_minus_G = S.copy()
            banzhaf_S_minus_G[rows, banzhaf_top_k.indices] = False

            loo_S_minus_G = S.copy()
            loo_S_minus_G[rows, loo_top_k.indices] = False

            shap_S_minus_G = S.copy()
            shap_S_minus_G[rows, shap_top_k.indices] = False

            pc_S_minus_G = S.copy()
            pc_S_minus_G[rows, pc_top_k.indices] = False

            perms_pc_S_minus_G = S.copy()
            perms_pc_S_minus_G[rows, perms_pc_top_k.indices] = False

            overall_dm_S_minus_G = S.copy()
            overall_dm_S_minus_G[rows, overall_dm_top_k.indices] = False

            overall_banzhaf_S_minus_G = S.copy()
            overall_banzhaf_S_minus_G[rows, overall_banzhaf_top_k.indices] = False

            overall_loo_S_minus_G = S.copy()
            overall_loo_S_minus_G[rows, overall_loo_top_k.indices] = False

            overall_shap_S_minus_G = S.copy()
            overall_shap_S_minus_G[rows, overall_shap_top_k.indices] = False

            overall_pc_S_minus_G = S.copy()
            overall_pc_S_minus_G[rows, overall_pc_top_k.indices] = False

            overall_perms_pc_S_minus_G = S.copy()
            overall_perms_pc_S_minus_G[rows, overall_perms_pc_top_k.indices] = False

            # compute the ground-truth margins for each node when removing the top-k
            results = Parallel(n_jobs=-1)(
                delayed(process_node)(
                    sampled_nodes,
                    node_idx,
                    dm_subset,
                    banzhaf_subset,
                    loo_subset,
                    shap_subset,
                    pc_subset,
                    perms_pc_subset,
                    cfg,
                    data,
                )
                for node_idx, (
                    dm_subset,
                    banzhaf_subset,
                    loo_subset,
                    shap_subset,
                    pc_subset,
                    perms_pc_subset,
                ) in enumerate(
                    zip(
                        dm_S_minus_G,
                        banzhaf_S_minus_G,
                        loo_S_minus_G,
                        shap_S_minus_G,
                        pc_S_minus_G,
                        perms_pc_S_minus_G,
                    )
                )
            )

            overall_results = Parallel(n_jobs=-1)(
                delayed(process_node)(
                    sampled_nodes,
                    node_idx,
                    overall_dm_subset,
                    overall_banzhaf_subset,
                    overall_loo_subset,
                    overall_shap_subset,
                    overall_pc_subset,
                    overall_perms_pc_subset,
                    cfg,
                    data,
                )
                for node_idx, (
                    overall_dm_subset,
                    overall_banzhaf_subset,
                    overall_loo_subset,
                    overall_shap_subset,
                    overall_pc_subset,
                    overall_perms_pc_subset,
                ) in enumerate(
                    zip(
                        overall_dm_S_minus_G,
                        overall_banzhaf_S_minus_G,
                        overall_loo_S_minus_G,
                        overall_shap_S_minus_G,
                        overall_pc_S_minus_G,
                        overall_perms_pc_S_minus_G,
                    )
                )
            )

            (
                dm_per_node_margin,
                banzhaf_per_node_margin,
                loo_per_node_margin,
                shap_per_node_margin,
                pc_per_node_margin,
                perms_pc_per_node_margin,
            ) = zip(*results)

            (
                overall_dm_per_node_margin,
                overall_banzhaf_per_node_margin,
                overall_loo_per_node_margin,
                overall_shap_per_node_margin,
                overall_pc_per_node_margin,
                overall_perms_pc_per_node_margin,
            ) = zip(*overall_results)

            datamodel_margins[k] = np.stack(dm_per_node_margin)
            banzhaf_margins[k] = np.stack(banzhaf_per_node_margin)
            loo_margins[k] = np.stack(loo_per_node_margin)
            shap_margins[k] = np.stack(shap_per_node_margin)
            pc_margins[k] = np.stack(pc_per_node_margin)
            perms_pc_margins[k] = np.stack(perms_pc_per_node_margin)

            overall_datamodel_margins[k] = np.stack(overall_dm_per_node_margin)
            overall_banzhaf_margins[k] = np.stack(overall_banzhaf_per_node_margin)
            overall_loo_margins[k] = np.stack(overall_loo_per_node_margin)
            overall_shap_margins[k] = np.stack(overall_shap_per_node_margin)
            overall_pc_margins[k] = np.stack(overall_pc_per_node_margin)
            overall_perms_pc_margins[k] = np.stack(overall_perms_pc_per_node_margin)

        ##############################################################

        #### Estimate margins with the isotonic regressor ####
        dm_estimated_supports = []
        banzhaf_estimated_supports = []
        loo_estimated_supports = []
        shap_estimated_supports = []
        pc_estimated_supports = []
        perms_pc_estimated_supports = []

        overall_dm_estimated_supports = []
        overall_banzhaf_estimated_supports = []
        overall_loo_estimated_supports = []
        overall_shap_estimated_supports = []
        overall_pc_estimated_supports = []
        overall_perms_pc_estimated_supports = []

        for node_idx in range(len(sampled_nodes)):
            dm_estimated_supports.append(
                compute_iso_reg(datamodel_margins, node_idx, n_samples)
            )
            banzhaf_estimated_supports.append(
                compute_iso_reg(banzhaf_margins, node_idx, n_samples)
            )
            loo_estimated_supports.append(
                compute_iso_reg(loo_margins, node_idx, n_samples)
            )
            shap_estimated_supports.append(
                compute_iso_reg(shap_margins, node_idx, n_samples)
            )
            pc_estimated_supports.append(
                compute_iso_reg(pc_margins, node_idx, n_samples)
            )
            perms_pc_estimated_supports.append(
                compute_iso_reg(perms_pc_margins, node_idx, n_samples)
            )

            overall_dm_estimated_supports.append(
                compute_iso_reg(overall_datamodel_margins, node_idx, n_samples)
            )
            overall_banzhaf_estimated_supports.append(
                compute_iso_reg(overall_banzhaf_margins, node_idx, n_samples)
            )
            overall_loo_estimated_supports.append(
                compute_iso_reg(overall_loo_margins, node_idx, n_samples)
            )
            overall_shap_estimated_supports.append(
                compute_iso_reg(overall_shap_margins, node_idx, n_samples)
            )
            overall_pc_estimated_supports.append(
                compute_iso_reg(overall_pc_margins, node_idx, n_samples)
            )
            overall_perms_pc_estimated_supports.append(
                compute_iso_reg(overall_perms_pc_margins, node_idx, n_samples)
            )

        dm_estimated_supports = torch.tensor(dm_estimated_supports)
        banzhaf_estimated_supports = torch.tensor(banzhaf_estimated_supports)
        loo_estimated_supports = torch.tensor(loo_estimated_supports)
        shap_estimated_supports = torch.tensor(shap_estimated_supports)
        pc_estimated_supports = torch.tensor(pc_estimated_supports)
        perms_pc_estimated_supports = torch.tensor(perms_pc_estimated_supports)

        overall_dm_estimated_supports = torch.tensor(overall_dm_estimated_supports)
        overall_banzhaf_estimated_supports = torch.tensor(
            overall_banzhaf_estimated_supports
        )
        overall_loo_estimated_supports = torch.tensor(overall_loo_estimated_supports)
        overall_shap_estimated_supports = torch.tensor(overall_shap_estimated_supports)
        overall_pc_estimated_supports = torch.tensor(overall_pc_estimated_supports)
        overall_perms_pc_estimated_supports = torch.tensor(
            overall_perms_pc_estimated_supports
        )
        ##############################################################

        #### Compute the margins with the estimated supports from the isotonic regression ####
        S = np.ones(n_samples, dtype=bool)

        dm_per_node_margin = []
        banzhaf_per_node_margin = []
        loo_per_node_margin = []
        shap_per_node_margin = []
        pc_per_node_margin = []
        perms_pc_per_node_margin = []

        for node_idx, (dm_k, banzhaf_k, loo_k, shap_k, pc_k, perms_pc_k) in enumerate(
            tqdm(
                zip(
                    dm_estimated_supports,
                    banzhaf_estimated_supports,
                    loo_estimated_supports,
                    shap_estimated_supports,
                    pc_estimated_supports,
                    perms_pc_estimated_supports,
                ),
                total=len(sampled_nodes),
            )
        ):
            cons_dm_k = int(torch.ceil(dm_k * 1.2).item())
            if cons_dm_k > n_samples:
                cons_dm_k = n_samples
            dm_top_k = torch.topk(
                torch.from_numpy(margins_dm[sampled_nodes[node_idx]]),
                cons_dm_k,
                dim=-1,
            )
            dm_S_minus_G = S.copy()
            dm_S_minus_G[dm_top_k.indices] = False

            _, _, _, dm_logits = train_subset(
                subset_idx=None,
                subset=dm_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            dm_per_node_margin.append(
                get_margin_incorrect_vectorized(dm_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_banzhaf_k = int(torch.ceil(banzhaf_k * 1.2).item())
            if cons_banzhaf_k > n_samples:
                cons_banzhaf_k = n_samples
            banzhaf_top_k = torch.topk(
                torch.from_numpy(margins_banzhaf[sampled_nodes[node_idx]]),
                cons_banzhaf_k,
                dim=-1,
            )
            banzhaf_S_minus_G = S.copy()
            banzhaf_S_minus_G[banzhaf_top_k.indices] = False

            _, _, _, banzhaf_logits = train_subset(
                subset_idx=None,
                subset=banzhaf_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            banzhaf_per_node_margin.append(
                get_margin_incorrect_vectorized(banzhaf_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_loo_k = int(torch.ceil(loo_k * 1.2).item())
            if cons_loo_k > n_samples:
                cons_loo_k = n_samples
            loo_top_k = torch.topk(
                torch.from_numpy(margins_loo[sampled_nodes[node_idx]]),
                cons_loo_k,
                dim=-1,
            )
            loo_S_minus_G = S.copy()
            loo_S_minus_G[loo_top_k.indices] = False

            _, _, _, loo_logits = train_subset(
                subset_idx=None,
                subset=loo_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            loo_per_node_margin.append(
                get_margin_incorrect_vectorized(loo_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_shap_k = int(torch.ceil(shap_k * 1.2).item())
            if cons_shap_k > n_samples:
                cons_shap_k = n_samples
            shap_top_k = torch.topk(
                torch.from_numpy(margins_shap[sampled_nodes[node_idx]]),
                cons_shap_k,
                dim=-1,
            )
            shap_S_minus_G = S.copy()
            shap_S_minus_G[shap_top_k.indices] = False

            _, _, _, shap_logits = train_subset(
                subset_idx=None,
                subset=shap_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            shap_per_node_margin.append(
                get_margin_incorrect_vectorized(shap_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_pc_k = int(torch.ceil(pc_k * 1.2).item())
            if cons_pc_k > n_samples:
                cons_pc_k = n_samples
            pc_top_k = torch.topk(
                torch.from_numpy(margins_pc[sampled_nodes[node_idx]]),
                cons_pc_k,
                dim=-1,
            )
            pc_S_minus_G = S.copy()
            pc_S_minus_G[pc_top_k.indices] = False

            _, _, _, pc_logits = train_subset(
                subset_idx=None,
                subset=pc_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            pc_per_node_margin.append(
                get_margin_incorrect_vectorized(pc_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_perms_pc_k = int(torch.ceil(perms_pc_k * 1.2).item())
            if cons_perms_pc_k > n_samples:
                cons_perms_pc_k = n_samples
            perms_pc_top_k = torch.topk(
                torch.from_numpy(margins_perms_pc[sampled_nodes[node_idx]]),
                cons_perms_pc_k,
                dim=-1,
            )
            perms_pc_S_minus_G = S.copy()
            perms_pc_S_minus_G[perms_pc_top_k.indices] = False

            _, _, _, perms_pc_logits = train_subset(
                subset_idx=None,
                subset=perms_pc_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            perms_pc_per_node_margin.append(
                get_margin_incorrect_vectorized(perms_pc_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

        dm_per_node_margin = np.stack(dm_per_node_margin)
        banzhaf_per_node_margin = np.stack(banzhaf_per_node_margin)
        loo_per_node_margin = np.stack(loo_per_node_margin)
        shap_per_node_margin = np.stack(shap_per_node_margin)
        pc_per_node_margin = np.stack(pc_per_node_margin)
        perms_pc_per_node_margin = np.stack(perms_pc_per_node_margin)
        ##############################################################

        #### Compute the sub margins with the estimated supports from the isotonic regression ####
        overall_dm_per_node_margin = []
        overall_banzhaf_per_node_margin = []
        overall_loo_per_node_margin = []
        overall_shap_per_node_margin = []
        overall_pc_per_node_margin = []
        overall_perms_pc_per_node_margin = []

        for node_idx, (dm_k, banzhaf_k, loo_k, shap_k, pc_k, perms_pc_k) in enumerate(
            tqdm(
                zip(
                    overall_dm_estimated_supports,
                    overall_banzhaf_estimated_supports,
                    overall_loo_estimated_supports,
                    overall_shap_estimated_supports,
                    overall_pc_estimated_supports,
                    overall_perms_pc_estimated_supports,
                ),
                total=len(sampled_nodes),
            )
        ):
            cons_dm_k = int(torch.ceil(dm_k * 1.2).item())
            if cons_dm_k > n_samples:
                cons_dm_k = n_samples
            dm_top_k = torch.topk(
                torch.from_numpy(sub_margins_dm[sampled_nodes[node_idx]]),
                cons_dm_k,
                dim=-1,
            )
            dm_S_minus_G = S.copy()
            dm_S_minus_G[dm_top_k.indices] = False

            _, _, _, dm_logits = train_subset(
                subset_idx=None,
                subset=dm_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            overall_dm_per_node_margin.append(
                get_margin_incorrect_vectorized(dm_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_banzhaf_k = int(torch.ceil(banzhaf_k * 1.2).item())
            if cons_banzhaf_k > n_samples:
                cons_banzhaf_k = n_samples
            banzhaf_top_k = torch.topk(
                torch.from_numpy(sub_margins_banzhaf[sampled_nodes[node_idx]]),
                cons_banzhaf_k,
                dim=-1,
            )
            banzhaf_S_minus_G = S.copy()
            banzhaf_S_minus_G[banzhaf_top_k.indices] = False

            _, _, _, banzhaf_logits = train_subset(
                subset_idx=None,
                subset=banzhaf_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            overall_banzhaf_per_node_margin.append(
                get_margin_incorrect_vectorized(banzhaf_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_loo_k = int(torch.ceil(loo_k * 1.2).item())
            if cons_loo_k > n_samples:
                cons_loo_k = n_samples
            loo_top_k = torch.topk(
                torch.from_numpy(sub_margins_loo[sampled_nodes[node_idx]]),
                cons_loo_k,
                dim=-1,
            )
            loo_S_minus_G = S.copy()
            loo_S_minus_G[loo_top_k.indices] = False

            _, _, _, loo_logits = train_subset(
                subset_idx=None,
                subset=loo_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            overall_loo_per_node_margin.append(
                get_margin_incorrect_vectorized(loo_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_shap_k = int(torch.ceil(shap_k * 1.2).item())
            if cons_shap_k > n_samples:
                cons_shap_k = n_samples
            shap_top_k = torch.topk(
                torch.from_numpy(sub_margins_shap[sampled_nodes[node_idx]]),
                cons_shap_k,
                dim=-1,
            )
            shap_S_minus_G = S.copy()
            shap_S_minus_G[shap_top_k.indices] = False

            _, _, _, shap_logits = train_subset(
                subset_idx=None,
                subset=shap_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            overall_shap_per_node_margin.append(
                get_margin_incorrect_vectorized(shap_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_pc_k = int(torch.ceil(pc_k * 1.2).item())
            if cons_pc_k > n_samples:
                cons_pc_k = n_samples
            pc_top_k = torch.topk(
                torch.from_numpy(sub_margins_pc[sampled_nodes[node_idx]]),
                cons_pc_k,
                dim=-1,
            )
            pc_S_minus_G = S.copy()
            pc_S_minus_G[pc_top_k.indices] = False

            _, _, _, pc_logits = train_subset(
                subset_idx=None,
                subset=pc_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            overall_pc_per_node_margin.append(
                get_margin_incorrect_vectorized(pc_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

            cons_perms_pc_k = int(torch.ceil(perms_pc_k * 1.2).item())
            if cons_perms_pc_k > n_samples:
                cons_perms_pc_k = n_samples
            perms_pc_top_k = torch.topk(
                torch.from_numpy(sub_margins_perms_pc[sampled_nodes[node_idx]]),
                cons_perms_pc_k,
                dim=-1,
            )
            perms_pc_S_minus_G = S.copy()
            perms_pc_S_minus_G[perms_pc_top_k.indices] = False

            _, _, _, perms_pc_logits = train_subset(
                subset_idx=None,
                subset=perms_pc_S_minus_G,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            overall_perms_pc_per_node_margin.append(
                get_margin_incorrect_vectorized(perms_pc_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)[sampled_nodes[node_idx]]
            )

        overall_dm_per_node_margin = np.stack(overall_dm_per_node_margin)
        overall_banzhaf_per_node_margin = np.stack(overall_banzhaf_per_node_margin)
        overall_loo_per_node_margin = np.stack(overall_loo_per_node_margin)
        overall_shap_per_node_margin = np.stack(overall_shap_per_node_margin)
        overall_pc_per_node_margin = np.stack(overall_pc_per_node_margin)
        overall_perms_pc_per_node_margin = np.stack(overall_perms_pc_per_node_margin)
        ##############################################################

        #### Check how many estimations are correct ####
        dm_perc_true_supp = (dm_per_node_margin <= 0).mean()
        print(
            f"Datamodel: {round(dm_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        banzhaf_perc_true_supp = (banzhaf_per_node_margin <= 0).mean()
        print(
            f"Banzhaf: {round(banzhaf_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        loo_perc_true_supp = (loo_per_node_margin <= 0).mean()
        print(f"LOO: {round(loo_perc_true_supp.item(), 2) * 100}% correctly estimated.")

        shap_perc_true_supp = (shap_per_node_margin <= 0).mean()
        print(
            f"SHAP: {round(shap_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        pc_perc_true_supp = (pc_per_node_margin <= 0).mean()
        print(
            f"PC-Winter: {round(pc_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        perms_pc_perc_true_supp = (perms_pc_per_node_margin <= 0).mean()
        print(
            f"PC-Winter permutations: {round(perms_pc_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        overall_dm_perc_true_supp = (overall_dm_per_node_margin <= 0).mean()
        print(
            f"Overall datamodel: {round(overall_dm_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        overall_banzhaf_perc_true_supp = (overall_banzhaf_per_node_margin <= 0).mean()
        print(
            f"Overall Banzhaf: {round(overall_banzhaf_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        overall_loo_perc_true_supp = (overall_loo_per_node_margin <= 0).mean()
        print(
            f"Overall LOO: {round(overall_loo_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        overall_shap_perc_true_supp = (overall_shap_per_node_margin <= 0).mean()
        print(
            f"Overall SHAP: {round(overall_shap_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        overall_pc_perc_true_supp = (overall_pc_per_node_margin <= 0).mean()
        print(
            f"Overall PC-Winter: {round(overall_pc_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )

        overall_perms_pc_perc_true_supp = (overall_perms_pc_per_node_margin <= 0).mean()
        print(
            f"Overall PC-Winter permutations: {round(overall_perms_pc_perc_true_supp.item(), 2) * 100}% correctly estimated."
        )
        ##############################################################

        supports_res = {
            "sampled_nodes": sampled_nodes,
            "sampled_mask": sampled_mask,
            "correctly_classified_nodes": correctly_classified_nodes,
            # estimated supports
            "dm_estimated_supports": dm_estimated_supports,
            "banzhaf_estimated_supports": banzhaf_estimated_supports,
            "loo_estimated_supports": loo_estimated_supports,
            "shap_estimated_supports": shap_estimated_supports,
            "pc_estimated_supports": pc_estimated_supports,
            "perms_pc_estimated_supports": perms_pc_estimated_supports,
            "overall_dm_estimated_supports": overall_dm_estimated_supports,
            "overall_banzhaf_estimated_supports": overall_banzhaf_estimated_supports,
            "overall_loo_estimated_supports": overall_loo_estimated_supports,
            "overall_shap_estimated_supports": overall_shap_estimated_supports,
            "overall_pc_estimated_supports": overall_pc_estimated_supports,
            "overall_perms_pc_estimated_supports": overall_perms_pc_estimated_supports,
            # per node margins
            "dm_per_node_margin": dm_per_node_margin,
            "banzhaf_per_node_margin": banzhaf_per_node_margin,
            "loo_per_node_margin": loo_per_node_margin,
            "shap_per_node_margin": shap_per_node_margin,
            "pc_per_node_margin": pc_per_node_margin,
            "perms_pc_per_node_margin": perms_pc_per_node_margin,
            "overall_dm_per_node_margin": overall_dm_per_node_margin,
            "overall_banzhaf_per_node_margin": overall_banzhaf_per_node_margin,
            "overall_loo_per_node_margin": overall_loo_per_node_margin,
            "overall_shap_per_node_margin": overall_shap_per_node_margin,
            "overall_pc_per_node_margin": overall_pc_per_node_margin,
            "overall_perms_pc_per_node_margin": overall_perms_pc_per_node_margin,
            # margins results
            "datamodel_margins": datamodel_margins,
            "banzhaf_margins": banzhaf_margins,
            "loo_margins": loo_margins,
            "shap_margins": shap_margins,
            "pc_margins": pc_margins,
            "perms_pc_margins": perms_pc_margins,
            "overall_datamodel_margins": overall_datamodel_margins,
            "overall_banzhaf_margins": overall_banzhaf_margins,
            "overall_loo_margins": overall_loo_margins,
            "overall_shap_margins": overall_shap_margins,
            "overall_pc_margins": overall_pc_margins,
            "overall_perms_pc_margins": overall_perms_pc_margins,
        }
        torch.save(supports_res, results_dir / "estimated_supports.pt")

    # PLOTS
    fig, ax = plt.subplots(
        figsize=get_figsize(nrows=1, ncols=1), constrained_layout=True
    )
    k = torch.arange(n_samples, dtype=torch.long)
    ax.plot(
        (torch.ceil(dm_estimated_supports * 1.2) <= k[:, None]).float().mean(1),
        label="datamodel",
    )
    ax.plot(
        (torch.ceil(banzhaf_estimated_supports * 1.2) <= k[:, None]).float().mean(1),
        label="banzhaf",
    )
    ax.plot(
        (torch.ceil(loo_estimated_supports * 1.2) <= k[:, None]).float().mean(1),
        label="loo",
    )
    ax.plot(
        (torch.ceil(shap_estimated_supports * 1.2) <= k[:, None]).float().mean(1),
        label="shap",
    )
    ax.plot(
        (torch.ceil(pc_estimated_supports * 1.2) <= k[:, None]).float().mean(1),
        label="pc-winter",
    )
    ax.plot(
        (torch.ceil(perms_pc_estimated_supports * 1.2) <= k[:, None]).float().mean(1),
        label="pc-winter perms",
    )
    if cfg.data.subset_mode in ["test", "mixed"]:
        ax.set_xlim(0, 50)
    ax.set_xlabel("Estimated support size")
    ax.set_ylabel("CDF")
    ax.legend()
    fig.savefig(results_dir / "estimated_supports.png")


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig) -> None:
    """Run the experiment from a Hydra config and print the elapsed time.

    A Weights & Biases run is created only when ``"wandb"`` is present in
    ``cfg.log``; otherwise ``run_experiment`` receives ``wandb_run=None``.
    After the experiment finishes, the CPU time and wall time spent are
    printed as days / hours / minutes / seconds.

    Args:
        cfg: Configuration composed by Hydra from the ``experiment`` config
            under ``PROJECT_ROOT / "conf"``.
    """
    cpu_time_start = time.process_time()
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or turn negative) if the system clock is adjusted during a long run.
    wall_time_start = time.perf_counter()

    # Resolve interpolations and convert to plain Python containers before
    # passing the configuration to wandb as the run config.
    config = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False)
    if "wandb" in cfg.log:
        run = wandb.init(project=cfg.log.project, config=config)
    else:
        run = None
    run_experiment(cfg=cfg, wandb_run=run)

    # Only reached when the experiment returned normally; finish() is not
    # called here if run_experiment raised.
    if run is not None:
        run.finish()

    cpu_time = time.process_time() - cpu_time_start
    wall_time = time.perf_counter() - wall_time_start

    cpu_days, cpu_hours, cpu_minutes, cpu_seconds = convert_seconds_to_dhms(cpu_time)
    wall_days, wall_hours, wall_minutes, wall_seconds = convert_seconds_to_dhms(
        wall_time
    )
    print(
        f"CPU time: {cpu_days} Days: {cpu_hours} Hours: {cpu_minutes} Min: {cpu_seconds:.2f} Sec"
    )
    print(
        f"Wall time: {wall_days} Days: {wall_hours} Hours: {wall_minutes} Min: {wall_seconds:.2f} Sec"
    )


# Script entry point: call the Hydra-decorated main() only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
