import torch
from torch_geometric.utils.num_nodes import maybe_num_nodes
import re
import glob
import os
import pickle

from torch_geometric.graphgym.config import cfg
import torch.nn.functional as F


# def calc_k_hops(
#     node_idx,
#     num_hops,
#     edge_index,
#     hop_cutoff=100,
#     get_undirected_hops=False,
#     num_nodes=None,
# ):
#     num_nodes = maybe_num_nodes(edge_index, num_nodes)

#     if get_undirected_hops:
#         edge_index_reversed = edge_index.flip([0])
#         edge_index_undirected = torch.cat([edge_index, edge_index_reversed], dim=1)
#         edge_index_undirected = torch.unique(edge_index_undirected, dim=1)
#         edge_index_undirected = edge_index_undirected[
#             :, edge_index_undirected[0] != edge_index_undirected[1]
#         ]
#         edge_index = edge_index_undirected
#     col, row = edge_index
#     node_mask = row.new_empty(num_nodes, dtype=torch.bool)
#     edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
#     k_hop_neigh_idx = []
#     max_size = 0
#     for node in node_idx:
#         # node = torch.tensor(node, device=row.device)
#         node_mask = row.new_empty(num_nodes, dtype=torch.bool)
#         k_hop_neigh = []
#         for _ in range(num_hops):
#             node_mask.fill_(False)
#             node_mask[node] = True
#             torch.index_select(node_mask, 0, row, out=edge_mask)
#             k_hop_neigh.append(col[edge_mask])
#         k_hop_neigh, _ = torch.cat(k_hop_neigh).unique(return_inverse=True)
#         perm = torch.randperm(len(k_hop_neigh))
#         k_hop_neigh = k_hop_neigh[perm]
#         k_hop_neigh = k_hop_neigh[0:hop_cutoff]
#         max_size = max(max_size, len(k_hop_neigh))
#         k_hop_neigh += 1

#         k_hop_neigh_idx.append(k_hop_neigh)

#     return k_hop_neigh_idx, max_size


from torch_geometric.utils import to_undirected, subgraph, k_hop_subgraph


# def calc_k_hops(
#     node_idx,
#     num_hops,
#     edge_index,
#     hop_cutoff=100,
#     get_undirected_hops=False,
#     num_nodes=None,
# ):
#     # Convert edge_index to undirected, if required
#     node_idx = torch.tensor(node_idx, dtype=torch.long)
#     if not torch.is_tensor(edge_index):
#         edge_index = torch.tensor(edge_index, dtype=torch.long)

#     # Convert edge_index to undirected, if required
#     if get_undirected_hops:
#         edge_index = to_undirected(edge_index, num_nodes=num_nodes)

#     # Setup NeighborSampler for 1-hop neighborhood sampling
#     sampler = NeighborSampler(
#         edge_index=edge_index,
#         sizes=[
#             hop_cutoff
#         ],  # [hop_cutoff] controls the maximum number of neighbors to sample
#         node_idx=node_idx,  # Specify nodes for which to sample neighbors
#         batch_size=node_idx.size(0),  # Process all specified nodes at once
#         shuffle=False,
#         num_nodes=num_nodes if num_nodes is not None else edge_index.max().item() + 1,
#     )

#     k_hop_neigh_idx = []
#     max_size = 0

#     # Iterate through the NeighborSampler
#     for batch_size, n_id, adjs in sampler:
#         # n_id contains the node indices in the sampled neighborhoods, including the input nodes
#         for node, adj in zip(node_idx, adjs):
#             _, edge_index_sub, _, _ = adj  # Get the sampled adjacency information
#             subset = edge_index_sub[1]  # Target nodes of the edges are the neighbors

#             # Apply hop cutoff if necessary (already applied via sizes in NeighborSampler)
#             # Random shuffling could be done if required, but it's not due to NeighborSampler's handling

#             # Adjust indices if necessary, based on your logic
#             subset += 1  # Uncomment if you need to adjust indices

#             # Update max_size
#             max_size = max(max_size, subset.size(0))
#             k_hop_neigh_idx.append(subset)

#     return k_hop_neigh_idx, max_size


from torch_geometric.loader import NeighborSampler
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.loader import NeighborSampler
from torch_geometric.loader import GraphSAINTRandomWalkSampler


class GraphSAINTRandomWalkSampler_custom(GraphSAINTRandomWalkSampler):
    """GraphSAINT random-walk sampler that exposes walks for explicit start nodes.

    Used as a per-node neighbourhood sampler: ``sample_neighbours`` runs walks
    of ``self.walk_length`` steps on ``self.adj`` (both provided by the base
    class) from the given start node(s).
    """

    def sample_neighbours(self, node_idx):
        """Return the nodes visited by random walks started at ``node_idx``.

        ``node_idx`` is flattened before walking, so a scalar tensor or any
        shaped index tensor is accepted. The walk result is flattened to 1-D.
        NOTE(review): presumably the walk includes the start node itself
        (torch_sparse ``random_walk`` returns ``walk_length + 1`` nodes per
        start) — confirm against the installed version.
        """
        starts = node_idx.flatten()
        walks = self.adj.random_walk(starts, self.walk_length)
        return walks.view(-1)


def calc_k_hops(
    node_idx,
    num_hops,
    data_obj,
    hop_cutoff=100,
    get_undirected_hops=False,
    num_nodes=None,
):
    """Sample a neighbourhood tensor for every node in ``node_idx``.

    Despite the name, no explicit k-hop expansion happens here. Each entry of
    ``node_idx`` is passed to ``data_obj.sample_neighbours`` (presumably a
    ``GraphSAINTRandomWalkSampler_custom`` instance — verify against callers).
    Whatever tensor that returns is collected as the node's neighbourhood.

    ``num_hops``, ``hop_cutoff``, ``get_undirected_hops`` and ``num_nodes`` are
    accepted for signature compatibility with earlier implementations. They
    are currently unused; the walk length is controlled by ``data_obj``.

    Returns:
        A tuple ``(neighbourhoods, max_size)``. ``neighbourhoods`` is a list
        with one tensor per entry of ``node_idx``, in the same order.
        ``max_size`` is the largest ``size(0)`` among them, or 0 when
        ``node_idx`` is empty.
    """
    neighbourhoods = [data_obj.sample_neighbours(node) for node in node_idx]
    lengths = [neigh.size(0) for neigh in neighbourhoods]
    return neighbourhoods, max([0, *lengths])


def pad_hops(k_hop_neigh_idx, max_size):
    """Right-pad variable-length index tensors to ``max_size`` and stack them.

    Index 0 is used as the padding value, so node 0 should be reserved for
    padded features. The returned mask tells real entries apart from padding.
    The input sequence is not modified.

    Args:
        k_hop_neigh_idx: sequence of 1-D index tensors, one per node.
        max_size: target length. It must be >= the length of every tensor,
            otherwise ``new_zeros`` receives a negative size and raises.

    Returns:
        A tuple ``(idx, seq_mask)``.
        ``idx`` is a ``torch.long`` tensor of shape
        ``(len(k_hop_neigh_idx), max_size)``.
        ``seq_mask`` is a bool tensor of shape
        ``(1, len(k_hop_neigh_idx), max_size)``, True at real entries.
    """
    if len(k_hop_neigh_idx) == 0:
        # torch.stack rejects an empty list; return correctly shaped empties.
        return (
            torch.zeros((0, max_size), dtype=torch.long),
            torch.zeros((1, 0, max_size), dtype=torch.bool),
        )

    padded = []
    masks = []
    for neigh in k_hop_neigh_idx:
        length = neigh.shape[0]
        pad_size = max_size - length
        padded.append(F.pad(neigh, (0, pad_size), "constant", 0))
        masks.append(
            torch.cat(
                [
                    neigh.new_ones(length, dtype=torch.bool),
                    neigh.new_zeros(pad_size, dtype=torch.bool),
                ],
                dim=0,
            )
        )

    # `.to` instead of `torch.tensor(tensor)`: no copy-construct warning and no
    # redundant copy when the indices are already long.
    idx = torch.stack(padded, dim=0).to(torch.long)
    seq_mask = torch.stack(masks).unsqueeze(0)
    return idx, seq_mask


def pad_hops_dict(k_hop_neigh_idx_dict, max_size):
    """Pad and stack the per-key lists of index tensors in ``k_hop_neigh_idx_dict``.

    For every key, each 1-D tensor is right-padded with 0 to ``max_size`` and
    the list is stacked into a tensor of shape ``(len(list), max_size)``.
    The dict is updated in place (each value is replaced by its stacked
    tensor) and also returned.

    A key whose list is empty gets a ``torch.long`` tensor and a bool mask of
    shape ``(0, max_size)``. Previously ``torch.stack`` raised on such keys.

    Args:
        k_hop_neigh_idx_dict: mapping key -> sequence of 1-D index tensors.
        max_size: target length; must be >= every tensor's length.

    Returns:
        A tuple ``(k_hop_neigh_idx_dict, seq_mask_dict)``. ``seq_mask_dict``
        maps each key to a bool tensor of shape ``(len(list), max_size)``,
        True at real (non-padded) entries.
    """
    seq_mask_dict = {}
    # Assigning to existing keys during iteration is safe (dict size is unchanged).
    for key, seqs in k_hop_neigh_idx_dict.items():
        if len(seqs) == 0:
            k_hop_neigh_idx_dict[key] = torch.zeros((0, max_size), dtype=torch.long)
            seq_mask_dict[key] = torch.zeros((0, max_size), dtype=torch.bool)
            continue

        padded = []
        masks = []
        for seq in seqs:
            length = seq.shape[0]
            pad_size = max_size - length
            padded.append(F.pad(seq, (0, pad_size), "constant", 0))
            masks.append(
                torch.cat(
                    [
                        seq.new_ones(length, dtype=torch.bool),
                        seq.new_zeros(pad_size, dtype=torch.bool),
                    ],
                    dim=0,
                )
            )
        seq_mask_dict[key] = torch.stack(masks)
        k_hop_neigh_idx_dict[key] = torch.stack(padded, dim=0)

    return k_hop_neigh_idx_dict, seq_mask_dict


def hops_multi(
    main_dict,
    subset_keys,
    dataset_name_to_index,
    node_dataset_id,
    node_idx,
    num_hops,
    hop_cutoff=100,
    get_undirected_hops=False,
):
    """Sample neighbourhoods for a batch whose nodes come from several datasets.

    Nodes are grouped by dataset. A node belongs to dataset ``key`` when
    ``node_dataset_id == dataset_name_to_index[key]``. Each group is sampled
    with ``calc_k_hops`` using ``main_dict[key]`` as the sampler. All groups
    are padded to a common length and scattered back into batch order.

    Datasets in ``subset_keys`` with no node in this batch are skipped.
    Previously they produced an empty list that made ``torch.stack`` raise.

    Returns:
        A tuple ``(k_hop_neigh_idx, seq_mask)``.
        ``k_hop_neigh_idx`` is a ``torch.long`` tensor and ``seq_mask`` a bool
        tensor, both of shape ``(len(node_idx), max_size)``. Rows of nodes
        whose dataset is not in ``subset_keys`` stay all-zero / all-False.
        NOTE(review): both are allocated on the default (CPU) device.
    """
    # Compute each dataset's row mask once. This assumes node_dataset_id is a
    # tensor/ndarray so that `==` yields a boolean mask.
    dataset_masks = {
        key: node_dataset_id == dataset_name_to_index[key] for key in subset_keys
    }
    active_keys = [key for key in subset_keys if bool(dataset_masks[key].any())]

    k_hop_neigh_idx_dict = {}
    max_size = 0
    for key in active_keys:
        neighbourhoods, max_size_node = calc_k_hops(
            node_idx[dataset_masks[key]],
            num_hops,
            main_dict[key],
            hop_cutoff=hop_cutoff,
            get_undirected_hops=get_undirected_hops,
            num_nodes=None,
        )
        k_hop_neigh_idx_dict[key] = neighbourhoods
        max_size = max(max_size, max_size_node)

    k_hop_neigh_idx_pad_dict, seq_mask_dict = pad_hops_dict(
        k_hop_neigh_idx_dict, max_size
    )

    k_hop_neigh_idx = torch.zeros((len(node_idx), max_size), dtype=torch.long)
    seq_mask = torch.zeros((len(node_idx), max_size), dtype=torch.bool)
    for key in active_keys:
        rows = dataset_masks[key]
        k_hop_neigh_idx[rows] = k_hop_neigh_idx_pad_dict[key]
        seq_mask[rows] = seq_mask_dict[key]

    return k_hop_neigh_idx, seq_mask


def extract_eig_freq(filename):
    """Return the eigen-frequency count ``N`` encoded as ``eigen_<N>_`` in a file name.

    Only the base name is inspected, so directory components that happen to
    contain ``eigen_`` are ignored. Returns ``float("inf")`` when the base
    name has no ``eigen_<digits>_`` token. Before, a non-numeric token raised
    ValueError.
    """
    match = re.search(r"eigen_(\d+)_", os.path.basename(filename))
    if match:
        return int(match.group(1))
    return float("inf")


# def check_and_load_processed_eig(dataset_dir, dataset_name, dataset_dir_with_err):
#     pecfg = find_pos_cfg()
#     all_dataset_eig_files = glob.glob(
#         os.path.join(dataset_dir, f"{dataset_name}_eigen_*_processed.pt")
#     )
#     all_dataset_eig_files = [
#         file for file in all_dataset_eig_files if "num_cluster" not in file
#     ]
#     if len(all_dataset_eig_files) == 0:
#         # check in folder with error allowed
#         all_dataset_eig_files = glob.glob(
#             os.path.join(dataset_dir_with_err, f"{dataset_name}_eigen_*_processed.pt")
#         )
#         all_dataset_eig_files = [
#             file for file in all_dataset_eig_files if "num_cluster" not in file
#         ]
#         dataset_dir = dataset_dir_with_err

#     if len(all_dataset_eig_files) != 0:
#         max_eig_file = sorted(
#             all_dataset_eig_files, key=extract_eig_freq, reverse=True
#         )[0]
#         max_eig_ava = int(extract_eig_freq(max_eig_file))
#     else:
#         max_eig_ava = pecfg.eigen.max_freqs

#     if max_eig_ava >= pecfg.eigen.max_freqs:

#         processed_path = os.path.join(
#             dataset_dir,
#             f"{dataset_name}_eigen_{max_eig_ava}_processed.pt",
#         )
#         print(processed_path)
#         if os.path.exists(processed_path):
#             # Adjust the loading mechanism based on how you saved the dataset
#             with open(processed_path, "rb") as f:
#                 dataset = pickle.load(f)
#             print(f"Loaded processed dataset from {processed_path}")

#             dataset.data.eigvecs_sn = dataset.data.eigvecs_sn[
#                 :, : pecfg.eigen.max_freqs
#             ]
#             dataset.data.eigvals_sn = dataset.data.eigvals_sn[
#                 :, : pecfg.eigen.max_freqs, :
#             ]

#             return dataset
#     return None


def check_and_load_processed_eig(
    dataset_dir, dataset_name, dataset_dir_with_err, partiton=None, max_freqs=32
):
    """Load the cached eigen-processed dataset with the most frequencies, if any.

    Cache files are named ``<dataset_name>_eigen_<N>_processed.pt``. When
    ``partiton`` is given they are named
    ``<dataset_name>_num_part_<partiton>_eigen_<N>_processed.pt``.

    Only exact-name matches count. ``Cora`` does not pick up ``CoraFull``
    files, partition ``1`` does not pick up partition ``10`` files, and
    ``..._num_cluster_...`` files are ignored.

    ``dataset_dir`` is searched first. ``dataset_dir_with_err`` is used only
    when ``dataset_dir`` has no matching file at all.

    Args:
        dataset_dir: primary cache directory.
        dataset_name: dataset name used as the file-name prefix.
        dataset_dir_with_err: fallback cache directory.
        partiton: optional partition id (misspelled name kept for callers).
        max_freqs: number of eigenpairs to keep. The largest cached ``N`` must
            be at least this large. Defaults to 32, the previously hard-coded
            value.

    Returns:
        The unpickled dataset, or ``None`` when no suitable file exists.
        ``eigvecs_sn`` and ``eigvals_sn`` are truncated to ``max_freqs`` along
        dim 1. They are read from ``dataset.data`` when ``partiton`` is None,
        and from ``dataset`` itself otherwise.
    """
    if partiton is None:
        prefix = f"{dataset_name}_eigen_"
    else:
        prefix = f"{dataset_name}_num_part_{partiton}_eigen_"
    # N must be the whole token between prefix and suffix.
    name_re = re.compile(re.escape(prefix) + r"(\d+)_processed\.pt")

    def _available(directory):
        # (N, basename) for every exactly-matching cache file in `directory`.
        pattern = os.path.join(
            glob.escape(os.fspath(directory)), glob.escape(prefix) + "*_processed.pt"
        )
        found = []
        for path in glob.glob(pattern):
            base = os.path.basename(path)
            match = name_re.fullmatch(base)
            if match:
                found.append((int(match.group(1)), base))
        return found

    candidates = _available(dataset_dir)
    if not candidates:
        # Check the folder with error allowed.
        dataset_dir = dataset_dir_with_err
        candidates = _available(dataset_dir)
    if not candidates:
        return None

    max_eig_ava, base_name = max(candidates)
    if max_eig_ava < max_freqs:
        return None

    processed_path = os.path.join(dataset_dir, base_name)
    if not os.path.exists(processed_path):
        return None
    # NOTE(review): pickle.load can execute arbitrary code; only use on
    # trusted, locally generated cache files.
    with open(processed_path, "rb") as f:
        dataset = pickle.load(f)
    print(f"Loaded processed dataset from {processed_path}")

    if partiton is None:
        dataset.data.eigvecs_sn = dataset.data.eigvecs_sn[:, :max_freqs]
        dataset.data.eigvals_sn = dataset.data.eigvals_sn[:, :max_freqs, :]
    else:
        dataset.eigvecs_sn = dataset.eigvecs_sn[:, :max_freqs]
        dataset.eigvals_sn = dataset.eigvals_sn[:, :max_freqs, :]
    return dataset


def check_and_load_processed_cluster(dataset_dir, dataset_name, dataset_dir_with_err):
    """Load a cached clustered dataset, trying ``dataset_dir`` before the fallback.

    The file name is ``<dataset_name>_eigen_<max_freqs>_num_cluster_<num_latents>_processed.pt``.
    ``max_freqs`` comes from ``find_pos_cfg().eigen.max_freqs`` and
    ``num_latents`` from ``cfg.model.num_latents``. The first existing
    candidate is unpickled and returned; ``None`` is returned if neither
    directory has the file.
    """
    pecfg = find_pos_cfg()
    file_name = f"{dataset_name}_eigen_{pecfg.eigen.max_freqs}_num_cluster_{cfg.model.num_latents}_processed.pt"
    # Both candidate paths are built up front, primary directory first.
    candidates = [
        os.path.join(directory, file_name)
        for directory in (dataset_dir, dataset_dir_with_err)
    ]
    for candidate in candidates:
        if not os.path.exists(candidate):
            continue
        # NOTE(review): pickle.load can execute arbitrary code; trusted caches only.
        with open(candidate, "rb") as handle:
            dataset = pickle.load(handle)
        print(f"Loaded processed dataset from {candidate}")
        return dataset
    return None


def save_processed_eig(dataset, dataset_dir, dataset_name):
    """Pickle ``dataset`` to ``<dataset_dir>/<dataset_name>_eigen_<max_freqs>_processed.pt``.

    ``max_freqs`` is read from the enabled positional-encoding config via
    ``find_pos_cfg()``. ``dataset_dir`` is created (with parents) if missing.
    An existing file of the same name is overwritten.
    """
    # exist_ok avoids the exists()/makedirs() race when several processes
    # save into the same directory concurrently.
    os.makedirs(dataset_dir, exist_ok=True)

    pecfg = find_pos_cfg()
    processed_path = os.path.join(
        dataset_dir,
        f"{dataset_name}_eigen_{pecfg.eigen.max_freqs}_processed.pt",
    )
    with open(processed_path, "wb") as f:
        pickle.dump(dataset, f)
    print(f"Saved processed dataset to {processed_path}")


def save_processed_cluster(dataset, dataset_dir, dataset_name):
    """Pickle a clustered dataset to the cache directory.

    The target path is
    ``<dataset_dir>/<dataset_name>_eigen_<max_freqs>_num_cluster_<num_latents>_processed.pt``.
    ``max_freqs`` comes from ``find_pos_cfg().eigen.max_freqs`` and
    ``num_latents`` from ``cfg.model.num_latents``. ``dataset_dir`` is
    created (with parents) if missing. An existing file is overwritten.
    """
    # exist_ok avoids the exists()/makedirs() race under concurrent saves.
    os.makedirs(dataset_dir, exist_ok=True)
    pecfg = find_pos_cfg()
    processed_path = os.path.join(
        dataset_dir,
        f"{dataset_name}_eigen_{pecfg.eigen.max_freqs}_num_cluster_{cfg.model.num_latents}_processed.pt",
    )
    with open(processed_path, "wb") as f:
        pickle.dump(dataset, f)
    print(f"Saved processed dataset to {processed_path}")


def find_pos_cfg():
    """Return the first enabled positional-encoding section of the global ``cfg``.

    Sections are checked in the order ``posenc_LapPE``, ``posenc_SignNet``,
    ``posenc_RotPE``. A section counts as enabled when its ``enable`` attribute
    compares equal to ``True``.

    Raises:
        ValueError: if none of the sections is enabled.
    """
    section_names = ("posenc_LapPE", "posenc_SignNet", "posenc_RotPE")
    # Lazy generator: sections are fetched only until the first enabled one.
    sections = (getattr(cfg, section) for section in section_names)
    enabled = next(
        (section for section in sections if section.enable == True),  # noqa: E712
        None,
    )
    if enabled is None:
        raise ValueError("No valid cfg enabled")
    return enabled
