import os

import numpy as np
import torch
from joblib import Parallel, delayed
from torch_geometric.utils import k_hop_subgraph
from tqdm import tqdm

from graphsmodel.training import train_subset
from graphsmodel.utils import get_margin_incorrect_vectorized


def generate_maps(train_idx_list, edge_index, num_nodes):
    r"""
    Code from https://github.com/frankhlchi/graph-data-valuation.

    Build the PC-Winter contribution tree for every labeled node.

    The three returned dicts share the nesting ``{labeled: {hop_1: {hop_2: leaf}}}``.
    ``hop_1`` ranges over the labeled node itself followed by the nodes that
    ``k_hop_subgraph`` returns for it with ``num_hops=1``, excluding the node itself.
    For each such ``hop_1``, ``hop_2`` ranges over the analogous 1-hop nodes of
    ``hop_1``, followed by ``hop_1`` itself. The labeled node's own group contains
    only the labeled node.

    Args:
        train_idx_list: iterable of labeled node ids.
        edge_index: graph connectivity forwarded to ``k_hop_subgraph``.
        num_nodes: number of nodes, forwarded to ``k_hop_subgraph``.

    Returns:
        ``(labeled_to_player_map, sample_value_dict, sample_counter_dict)``; the
        leaves are the singleton list ``[hop_2]``, ``0`` and ``0`` respectively.
    """

    def _one_hop(center):
        # Nodes of the 1-hop subgraph around ``center`` (as numpy scalars),
        # with ``center`` itself removed; raises ValueError if it is absent.
        found, _, _, _ = k_hop_subgraph(
            int(center),
            num_hops=1,
            edge_index=edge_index,
            relabel_nodes=False,
            num_nodes=num_nodes,
        )
        others = list(found.cpu().numpy())
        others.remove(center)
        return others

    player_tree, value_tree, count_tree = {}, {}, {}
    for root in train_idx_list:
        first_ring = _one_hop(root)

        # The root's own group is created first so it is the first hop-1 key.
        player_tree[root] = {root: {}}
        value_tree[root] = {root: {}}
        count_tree[root] = {root: {}}

        for neighbour in first_ring:
            second_ring = _one_hop(neighbour)
            # Leaf order: the hop-2 nodes first, then the hop-1 node itself.
            members = second_ring + [neighbour]
            player_tree[root][neighbour] = {m: [m] for m in members}
            value_tree[root][neighbour] = dict.fromkeys(members, 0)
            count_tree[root][neighbour] = dict.fromkeys(members, 0)

        player_tree[root][root][root] = [root]
        value_tree[root][root][root] = 0
        count_tree[root][root][root] = 0

    return player_tree, value_tree, count_tree


def generate_permutations(
    n_subsets,
    labeled_node_list,
    label_trunc_ratio,
    group_trunc_ratio_hop_1,
    group_trunc_ratio_hop_2,
    labeled_to_player_map,
):
    r"""
    Code adapted from https://github.com/frankhlchi/graph-data-valuation

    Generate permutations for the PC-Winter algorithm aligning with the number of subsets.

    Permutations are sampled until exactly ``n_subsets`` player insertions (one
    trained subset each) have been produced in total. The last permutation is
    truncated when the budget runs out, mirroring ``generate_computation_tree``.

    Args:
        n_subsets: total number of subsets (player insertions) to generate
            across all permutations.
        labeled_node_list: labeled node ids. A shuffled copy is used; the
            caller's sequence is not modified.
        label_trunc_ratio: fraction of the shuffled labeled nodes dropped from
            the end of each permutation.
        group_trunc_ratio_hop_1: fraction of hop-1 players (not counting the
            labeled node itself) dropped per labeled node.
        group_trunc_ratio_hop_2: fraction of hop-2 players (not counting the
            hop-1 player itself) dropped per hop-1 player.
        labeled_to_player_map: nested ``{labeled: {hop_1: {hop_2: ...}}}``
            mapping, e.g. the first return value of ``generate_maps``.

    Returns:
        A list of permutations, each a dict ``{labeled: {hop_1: [hop_2, ...]}}``.
        Each labeled node's first hop-1 key is the labeled node itself, and each
        hop-2 list starts with its hop-1 player.

    Raises:
        ValueError: if a full pass over the labeled nodes adds no subset (e.g.
            empty ``labeled_node_list`` or ``label_trunc_ratio >= 1``), which
            would otherwise loop forever.
    """
    # Shuffle a private copy so the caller's list is left untouched.
    labeled_nodes = list(labeled_node_list)
    trunc_label_len = int(np.ceil(len(labeled_nodes) * (1 - label_trunc_ratio)))

    counter = 0
    perms = []
    while counter < n_subsets:
        counter_before = counter
        np.random.shuffle(labeled_nodes)
        perm = {}

        for labeled_node in labeled_nodes[:trunc_label_len]:
            if counter >= n_subsets:
                break

            # Hop-1 players in random order, with the labeled node forced first.
            hop_1_list = list(labeled_to_player_map[labeled_node].keys())
            np.random.shuffle(hop_1_list)
            hop_1_list.remove(labeled_node)
            hop_1_list.insert(0, labeled_node)

            # Keep the labeled node plus a (1 - ratio) share of the others.
            truncate_length = (
                int(np.ceil((len(hop_1_list) - 1) * (1 - group_trunc_ratio_hop_1))) + 1
            )
            hop_1_list = hop_1_list[: min(truncate_length, len(hop_1_list))]

            for player_hop_1 in hop_1_list:
                remaining = n_subsets - counter
                if remaining <= 0:
                    break

                # Hop-2 players in random order, with the hop-1 player forced first.
                hop_2_list = list(
                    labeled_to_player_map[labeled_node][player_hop_1].keys()
                )
                np.random.shuffle(hop_2_list)
                hop_2_list.remove(player_hop_1)
                hop_2_list.insert(0, player_hop_1)

                truncate_length = (
                    int(np.ceil((len(hop_2_list) - 1) * (1 - group_trunc_ratio_hop_2)))
                    + 1
                )
                # Never exceed the remaining subset budget, so the total is exact.
                truncate_length = min(truncate_length, len(hop_2_list), remaining)
                hop_2_list = hop_2_list[:truncate_length]

                perm.setdefault(labeled_node, {})[player_hop_1] = hop_2_list
                counter += len(hop_2_list)

        if counter == counter_before:
            raise ValueError(
                "Permutation sampling made no progress; check labeled_node_list "
                "and the truncation ratios."
            )
        perms.append(perm)

    return perms


def generate_computation_tree(
    adj,
    labeled_node_list,
    label_trunc_ratio,
    group_trunc_ratio_hop_1,
    group_trunc_ratio_hop_2,
    target_subsets,
):
    r"""
    Sample one truncated PC-Winter permutation directly from an adjacency matrix.

    Labeled nodes are visited in random order. For each one, its neighbours
    (hop-1 players, the labeled node itself first) and their neighbours (hop-2
    players, the hop-1 player itself first) are shuffled and truncated.
    Sampling stops once ``target_subsets`` player insertions have been collected.

    Args:
        adj: adjacency matrix; ``adj[i].nonzero()[1]`` must give the neighbour
            column indices of node ``i`` (presumably a SciPy sparse matrix —
            confirm against callers).
        labeled_node_list: labeled node ids. A shuffled copy is used; the
            caller's sequence is not modified.
        label_trunc_ratio: fraction of shuffled labeled nodes dropped.
        group_trunc_ratio_hop_1: fraction of hop-1 players (not counting the
            labeled node) dropped per labeled node.
        group_trunc_ratio_hop_2: fraction of hop-2 players (not counting the
            hop-1 player) dropped per hop-1 player.
        target_subsets: maximum number of player insertions to emit.

    Returns:
        ``(perm, subsets_count)``. ``perm`` is ``{labeled: {hop_1: [hop_2, ...]}}``
        and ``subsets_count`` (<= ``target_subsets``) is the total length of all
        hop-2 lists.
    """
    # Shuffle a private copy so the caller's list is left untouched.
    labeled_nodes = list(labeled_node_list)
    np.random.shuffle(labeled_nodes)

    perm = {}
    trunc_label_len = int(np.ceil(len(labeled_nodes) * (1 - label_trunc_ratio)))
    subsets_count = 0
    for labeled_node in labeled_nodes[:trunc_label_len]:
        if subsets_count >= target_subsets:
            break

        # Hop-1 players: shuffled neighbours with the labeled node forced first.
        one_hop_neighbors = adj[labeled_node].nonzero()[1].tolist()
        np.random.shuffle(one_hop_neighbors)
        if labeled_node in one_hop_neighbors:
            one_hop_neighbors.remove(labeled_node)
        one_hop_neighbors.insert(0, labeled_node)

        # Keep the labeled node plus a (1 - ratio) share of the others.
        truncate_length = (
            int(np.ceil((len(one_hop_neighbors) - 1) * (1 - group_trunc_ratio_hop_1)))
            + 1
        )
        one_hop_neighbors = one_hop_neighbors[
            : min(truncate_length, len(one_hop_neighbors))
        ]

        for player_hop_1 in one_hop_neighbors:
            remaining = target_subsets - subsets_count
            if remaining <= 0:
                break

            # Hop-2 players: shuffled neighbours with the hop-1 player forced first.
            sub_neighbors = adj[player_hop_1].nonzero()[1].tolist()
            np.random.shuffle(sub_neighbors)
            if player_hop_1 in sub_neighbors:
                sub_neighbors.remove(player_hop_1)
            sub_neighbors.insert(0, player_hop_1)

            truncate_length = (
                int(np.ceil((len(sub_neighbors) - 1) * (1 - group_trunc_ratio_hop_2)))
                + 1
            )
            # Never exceed the remaining subset budget.
            truncate_length = min(truncate_length, len(sub_neighbors), remaining)
            sub_neighbors = sub_neighbors[:truncate_length]

            perm.setdefault(labeled_node, {})[player_hop_1] = sub_neighbors
            subsets_count += len(sub_neighbors)

    return perm, subsets_count


def _evaluate_subset(subset, cfg, data, prev_sub_margins):
    """Train on ``subset`` and return the utilities tracked by PC-Winter.

    Returns ``(test_acc, val_acc, margins, sub_test_acc, sub_val_acc, sub_margins)``.
    ``margins`` are computed from the logits on the full data against ``data.y``.
    The ``sub_*`` values come from the subset logits/labels returned by
    ``train_subset``. NaN entries of ``sub_margins`` are filled from
    ``prev_sub_margins``.
    """
    (_, sub_logits, sub_y, logits) = train_subset(
        subset_idx=None,
        subset=subset,
        cfg=cfg,
        data=data,
        logits_on_data=True,
    )

    margins = (
        get_margin_incorrect_vectorized(logits.unsqueeze(0), data.y)[0]
        .numpy()
        .mean(0)
    )
    test_acc = (margins > 0)[data.test_mask].mean().item()
    val_acc = (margins > 0)[data.val_mask].mean().item()

    sub_margins = np.nanmean(
        get_margin_incorrect_vectorized(sub_logits.unsqueeze(0), sub_y)[0].numpy(),
        axis=0,
    )
    # NaN margins are masked so they do not count towards the accuracies.
    sub_margins_ma = np.ma.masked_invalid(sub_margins)
    sub_test_acc = (sub_margins_ma[data.test_mask] > 0).mean()
    sub_val_acc = (sub_margins_ma[data.val_mask] > 0).mean()

    # Where no margin is available, carry the previous value forward so the
    # marginal contribution for that entry is zero.
    sub_margins = np.where(np.isnan(sub_margins), prev_sub_margins, sub_margins)

    return test_acc, val_acc, margins, sub_test_acc, sub_val_acc, sub_margins


def _credit_marginal(values, delta, labeled_node, player_hop_1, player_hop_2):
    """Accumulate ``delta`` on the player whose insertion produced it.

    The step that inserts the hop-1 player itself (``player_hop_2 ==
    player_hop_1``) is credited to ``player_hop_1``. Every other step is
    credited to ``player_hop_2``, except re-insertions of ``labeled_node``,
    which are skipped. ``values`` may be 1-D (scalar utilities) or 2-D
    (per-node margins, one row per player).
    """
    if player_hop_2 == player_hop_1:
        values[player_hop_1] += delta
    elif player_hop_2 != labeled_node:
        values[player_hop_2] += delta


def compute_single_perm(
    perm,
    n_samples,
    cfg,
    data,
):
    r"""
    Replay one PC-Winter permutation and accumulate per-player marginal contributions.

    For every labeled node in ``perm``, the players are inserted in order: each
    hop-1 player, and for each of them its hop-2 list. After each insertion a
    model is trained on the current subset via ``train_subset``. The change in
    each utility is added to the inserted player's entry; previously hop-1
    entries were overwritten with ``=`` on every hop-2 step. After a labeled
    node's tree is finished, the baseline is re-evaluated on the labeled nodes
    processed so far.

    NOTE(review): ``cur_hop_2_list`` is reset for every hop-1 player and the
    baseline keeps only labeled nodes, so earlier hop-2 (and, across labeled
    nodes, hop-1) players are not part of later subsets. Confirm this matches
    the reference implementation.

    Args:
        perm: ``{labeled: {hop_1: [hop_2, ...]}}`` permutation.
        n_samples: length of the value vectors and of the boolean subset mask.
            Node ids are used as indices, so presumably
            ``n_samples == data.num_nodes``.
        cfg: training configuration forwarded to ``train_subset``.
        data: graph data object exposing ``num_nodes``, ``y``, ``test_mask`` and
            ``val_mask``.

    Returns:
        ``(sub_test_accs_pc, sub_val_accs_pc, test_accs_pc, val_accs_pc,
        margins_pc, sub_margins_pc)``. The first four have shape
        ``(n_samples,)`` and the margin arrays ``(n_samples, data.num_nodes)``.
    """
    sub_test_accs_pc = np.zeros(n_samples)
    sub_val_accs_pc = np.zeros(n_samples)
    test_accs_pc = np.zeros(n_samples)
    val_accs_pc = np.zeros(n_samples)
    margins_pc = np.zeros((n_samples, data.num_nodes))
    sub_margins_pc = np.zeros((n_samples, data.num_nodes))

    cur_labeled_node_list = []

    # Utilities of the previous subset; deltas are measured against these.
    prev_sub_test_accs = 0
    prev_sub_val_acc = 0
    prev_test_accs = 0
    prev_val_accs = 0
    prev_margins = np.zeros(data.num_nodes)
    prev_sub_margins = np.zeros(data.num_nodes)

    for labeled_node, hop_1_groups in perm.items():
        cur_hop_1_list = []

        for player_hop_1, hop_2_players in hop_1_groups.items():
            cur_hop_2_list = []
            cur_hop_1_list += [player_hop_1]

            assert len(hop_2_players) > 0
            for player_hop_2 in hop_2_players:
                cur_hop_2_list += [player_hop_2]

                subset = torch.zeros(n_samples, dtype=torch.bool)
                subset[labeled_node] = True
                subset[cur_labeled_node_list] = True
                subset[cur_hop_1_list] = True
                subset[cur_hop_2_list] = True

                (
                    test_accs,
                    val_accs,
                    margins,
                    sub_test_accs,
                    sub_val_accs,
                    sub_margins,
                ) = _evaluate_subset(subset, cfg, data, prev_sub_margins)

                players = (labeled_node, player_hop_1, player_hop_2)
                _credit_marginal(test_accs_pc, test_accs - prev_test_accs, *players)
                _credit_marginal(val_accs_pc, val_accs - prev_val_accs, *players)
                _credit_marginal(margins_pc, margins - prev_margins, *players)
                _credit_marginal(
                    sub_test_accs_pc, sub_test_accs - prev_sub_test_accs, *players
                )
                _credit_marginal(
                    sub_val_accs_pc, sub_val_accs - prev_sub_val_acc, *players
                )
                _credit_marginal(
                    sub_margins_pc, sub_margins - prev_sub_margins, *players
                )

                prev_test_accs = test_accs
                prev_val_accs = val_accs
                prev_margins = margins
                prev_sub_test_accs = sub_test_accs
                prev_sub_val_acc = sub_val_accs
                prev_sub_margins = sub_margins

        # Reset the baseline to a model trained on the labeled nodes seen so far.
        cur_labeled_node_list += [labeled_node]
        subset = torch.zeros(n_samples, dtype=torch.bool)
        subset[cur_labeled_node_list] = True
        (
            prev_test_accs,
            prev_val_accs,
            prev_margins,
            prev_sub_test_accs,
            prev_sub_val_acc,
            prev_sub_margins,
        ) = _evaluate_subset(subset, cfg, data, prev_sub_margins)

    return (
        sub_test_accs_pc,
        sub_val_accs_pc,
        test_accs_pc,
        val_accs_pc,
        margins_pc,
        sub_margins_pc,
    )


def compute_pc_winter(
    perms,
    n_samples,
    cfg,
    data,
    n_jobs=1,
):
    r"""
    Average PC-Winter marginal contributions over a list of permutations.

    Each permutation is processed by ``compute_single_perm``. Permutations are
    dispatched in batches of ``n_jobs`` through joblib, and each batch is folded
    into the running average before the next one starts.

    Side effect: sets ``cfg.data.subset_mode = "mixed"`` on the passed config.

    Args:
        perms: sequence of permutations ``{labeled: {hop_1: [hop_2, ...]}}``.
        n_samples: forwarded to ``compute_single_perm``.
        cfg: training configuration forwarded to ``compute_single_perm``.
        data: graph data object; ``data.num_nodes`` sizes the margin arrays.
        n_jobs: number of parallel workers. Negative values follow the joblib
            convention (``-1`` = all CPUs, ``-2`` = all but one, ...). Values
            resolving below 1 are clamped to 1.

    Returns:
        ``(sub_test_accs_pc, sub_val_accs_pc, test_accs_pc, val_accs_pc,
        margins_pc.T, sub_margins_pc.T)``, averaged over ``perms``. The margin
        arrays are transposed to ``(data.num_nodes, n_samples)``.
    """
    # pc-winter computes values independently of the subset mode, the only
    # difference is that at the end only train/val/test/mixed nodes are returned
    cfg.data.subset_mode = "mixed"

    sub_test_accs_pc = np.zeros(n_samples)
    sub_val_accs_pc = np.zeros(n_samples)
    test_accs_pc = np.zeros(n_samples)
    val_accs_pc = np.zeros(n_samples)
    margins_pc = np.zeros((n_samples, data.num_nodes))
    sub_margins_pc = np.zeros((n_samples, data.num_nodes))

    n_perms = len(perms)

    # joblib convention for negative values; os.cpu_count() may return None.
    if n_jobs < 0:
        n_jobs = (os.cpu_count() or 1) + 1 + n_jobs
    # The batch size below is used as a range() step, which must be positive.
    n_jobs = max(1, n_jobs)

    # The bar counts permutations and is advanced manually once per batch.
    # Iterating a tqdm-wrapped batch list as well would double-count.
    with tqdm(total=n_perms) as pbar:
        for start in range(0, n_perms, n_jobs):
            # process the permutations in batches to save memory
            batch = perms[start : start + n_jobs]
            results = Parallel(n_jobs=n_jobs)(
                delayed(compute_single_perm)(
                    perm,
                    n_samples,
                    cfg,
                    data,
                )
                for perm in batch
            )

            # fold the batch into the running averages; `results` is released
            # when it is rebound on the next batch
            for (
                sub_test_acc,
                sub_val_acc,
                test_acc,
                val_acc,
                margins,
                sub_margins,
            ) in results:
                sub_test_accs_pc += sub_test_acc / n_perms
                sub_val_accs_pc += sub_val_acc / n_perms
                test_accs_pc += test_acc / n_perms
                val_accs_pc += val_acc / n_perms
                margins_pc += margins / n_perms
                sub_margins_pc += sub_margins / n_perms

            pbar.update(len(batch))

    return (
        sub_test_accs_pc,
        sub_val_accs_pc,
        test_accs_pc,
        val_accs_pc,
        margins_pc.T,
        sub_margins_pc.T,
    )


# def compute_pc_winter(
#     perms,
#     n_samples,
#     cfg,
#     data,
#     n_jobs=1,
# ):
#     cfg.data.subset_mode = "mixed"

#     sub_test_accs_pc = np.zeros(n_samples)
#     sub_val_accs_pc = np.zeros(n_samples)
#     test_accs_pc = np.zeros(n_samples)
#     val_accs_pc = np.zeros(n_samples)
#     margins_pc = np.zeros((n_samples, data.num_nodes))
#     sub_margins_pc = np.zeros((n_samples, data.num_nodes))

#     # Process each permutation individually
#     for perm in tqdm(perms):
#         result = compute_single_perm(perm, n_samples, cfg, data)
#         sub_test_accs_pc += result[0] / len(perms)
#         sub_val_accs_pc += result[1] / len(perms)
#         test_accs_pc += result[2] / len(perms)
#         val_accs_pc += result[3] / len(perms)
#         margins_pc += result[4] / len(perms)
#         sub_margins_pc += result[5] / len(perms)

#     return (
#         sub_test_accs_pc,
#         sub_val_accs_pc,
#         test_accs_pc,
#         val_accs_pc,
#         margins_pc.T,
#         sub_margins_pc.T,
#     )
