import os
import pickle
import glob
import logging
import os.path as osp
import time
from functools import partial


import numpy as np
import torch
import torch_geometric
import torch_geometric.transforms as T
from numpy.random import default_rng

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import load_ogb, set_dataset_attr
from torch_geometric.graphgym.register import register_loader
import torch_geometric.graphgym.register as register

from torch_geometric.graphgym.loader import create_dataset, get_loader
from einops import repeat
from torch_geometric.loader import DataLoader
from custom_modules.loader.node_loader import NodeDataset

from custom_modules.loader.custom_loaders import *
from custom_modules.loader.utils import *
from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
from custom_modules.transform.posenc_stats import compute_posenc_stats
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)

from torch.utils.data import DataLoader as torch_DataLoader
from torch.utils.data import ConcatDataset


from custom_modules.loader.synthetic_dataset import SyntheticDataset

from sklearn.cluster import KMeans
from multiprocessing import Manager


def log_loaded_dataset(dataset, format, name):
    """Log a short summary of a freshly loaded dataset.

    Reports the collated data object, directedness of the first graph, the
    number of graphs, the average number of nodes per graph, feature
    dimensions, and a label summary (class count, a regression hint, or
    edge-label statistics for link-prediction tasks).

    Args:
        dataset: PyG-style dataset exposing ``.data``, ``len()``, indexing,
            ``num_node_features`` and ``num_edge_features``.
        format: dataset format identifier (used only in the log message).
        name: dataset name (used only in the log message).
    """
    logging.info(f"[*] Loaded dataset '{name}' from '{format}':")
    logging.info(f"  {dataset.data}")
    num_graphs = len(dataset)
    if num_graphs == 0:
        # Nothing to index or average over; avoid IndexError / ZeroDivisionError.
        logging.info("  num graphs: 0")
        return
    logging.info(f"  undirected: {dataset[0].is_undirected()}")
    logging.info(f"  num graphs: {num_graphs}")

    data = dataset.data
    # `num_nodes` may exist but be None (e.g. when it cannot be inferred), so
    # fall back to the node feature matrix before giving up.
    total_num_nodes = getattr(data, "num_nodes", None)
    if total_num_nodes is None and getattr(data, "x", None) is not None:
        total_num_nodes = data.x.size(0)
    if total_num_nodes is None:
        total_num_nodes = 0
    logging.info(f"  avg num_nodes/graph: {total_num_nodes // num_graphs}")
    logging.info(f"  num node features: {dataset.num_node_features}")
    logging.info(f"  num edge features: {dataset.num_edge_features}")
    if hasattr(dataset, "num_tasks"):
        logging.info(f"  num tasks: {dataset.num_tasks}")

    y = getattr(data, "y", None)
    if y is not None:
        if isinstance(y, list):
            # A special case for ogbg-code2 dataset.
            logging.info("  num classes: n/a")
        elif y.numel() == y.size(0) and torch.is_floating_point(y):
            logging.info("  num classes: (appears to be a regression task)")
        else:
            logging.info(f"  num classes: {dataset.num_classes}")
    elif hasattr(data, "train_edge_label") or hasattr(data, "edge_label"):
        # Edge/link prediction task: transductive tasks store train labels
        # separately, inductive ones use a single `edge_label`.
        if hasattr(data, "train_edge_label"):
            labels = data.train_edge_label
        else:
            labels = data.edge_label
        if labels.numel() == labels.size(0) and torch.is_floating_point(labels):
            logging.info("  num edge classes: (probably a regression task)")
        else:
            logging.info(f"  num edge classes: {len(torch.unique(labels))}")


def _synthetic_node_task_config(graph_dataset_name, data):
    """Return the task-config dict describing one synthetic node-classification graph.

    Reads ``data.config["num_clusters"]`` (task output dim) and
    ``data.config["feature_dim"]`` (input feature dim).
    """
    return {
        "dataset_name": graph_dataset_name,
        "task": "node",
        "task_type": "classification",
        "loss_fun": "cross_entropy",
        "task_dim": data.config["num_clusters"],
        "feat_dim": data.config["feature_dim"],
        "node_feat_encoder_name": "LinearNode",
    }


def load_synthetic_data(dataset_dir, num_samples="all"):
    """Load pickled synthetic graphs from ``dataset_dir`` as node-level datasets.

    For each ``*.pkl`` graph, a processed cache (``check_and_load_processed_eig``)
    is used when present; otherwise the graph is wrapped in ``SyntheticDataset``
    and processed and cached with ``process_synthetic_data``.

    Args:
        dataset_dir: directory holding the ``*.pkl`` synthetic graphs.
        num_samples: number of graphs to draw at random (seeded by ``cfg.seed``,
            without replacement). ``"all"``, ``-1`` or ``None`` loads every graph.

    Returns:
        Tuple ``(node_dataset_list, graph_dataset_dict, syn_graph_config)``:
        one ``NodeDataset`` per graph, a map from ``<dataset_dir>/<graph_name>``
        to the graph's processed data object, and one task-config dict per
        graph (in the same order as ``node_dataset_list``).

    Raises:
        ValueError: (from numpy) if ``num_samples`` exceeds the number of files.
    """
    # Sort so that seeded sampling does not depend on filesystem listing order.
    syn_graph_files = sorted(glob.glob(os.path.join(dataset_dir, "*.pkl")))
    if num_samples not in ("all", -1, None):
        rng = default_rng(cfg.seed)
        syn_graph_files = rng.choice(syn_graph_files, num_samples, replace=False)

    node_dataset_list = []
    graph_dataset_dict = {}
    syn_graph_config = []
    for graph_file_name in syn_graph_files:
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(graph_file_name, "rb") as f:
            raw_graph = pickle.load(f)
        graph_name = os.path.basename(graph_file_name).split(".")[0]
        # Build the key from the directory + stem; splitting the full path on
        # "." would truncate at any dot in the directory (e.g. "./data").
        graph_dataset_name = os.path.join(dataset_dir, graph_name)
        dataset_final = check_and_load_processed_eig(dataset_dir, graph_name)

        if dataset_final is None:
            dataset = SyntheticDataset("placeholder", [raw_graph])
            dataset.data.x = dataset.data.x.to(torch.float32)
            process_synthetic_data(dataset, dataset_dir, graph_name)
            # Read `.data` only after processing so the stored object is the
            # processed one, matching what the cached branch below stores.
            data = dataset.data
        else:
            data = dataset_final.data
            # Cluster caches are not consulted here; pos_latents are recomputed.
            if cfg.model.num_latents > 0:
                kmeans = KMeans(
                    n_clusters=cfg.model.num_latents,
                    n_init="auto",
                    random_state=cfg.seed,
                )
                kmeans.fit(data.eigvecs_sn)
                centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
                # normalize clamps tiny norms at eps, so an all-zero centre
                # stays zero instead of becoming NaN.
                data.pos_latents = torch.nn.functional.normalize(centers, p=2, dim=1)
            data.x = data.x.to(torch.float32)

        graph_dataset_dict[graph_dataset_name] = data
        syn_graph_config.append(_synthetic_node_task_config(graph_dataset_name, data))
        node_dataset_list.append(NodeDataset(data, mask=None))
    return node_dataset_list, graph_dataset_dict, syn_graph_config


def process_synthetic_data(dataset, dataset_dir, graph_name):
    """Preprocess a synthetic dataset in place and cache the result.

    Always applies ``task_specific_preprocessing``. Then, only when at least
    one ``posenc_*`` config is enabled, it does the following:

    - precomputes positional-encoding statistics;
    - attaches K-means ``pos_latents`` when ``cfg.model.num_latents > 0``;
    - tags the data with ``dataset_name``, ``dataset_task_name`` and ``node_id``;
    - writes the eigen and cluster caches for ``graph_name`` in ``dataset_dir``.

    Args:
        dataset: in-memory PyG-style dataset holding the synthetic graph.
        dataset_dir: directory where processed caches are written.
        graph_name: base name used for the cache files and ``dataset_name``.

    Returns:
        The same ``dataset`` object.
    """
    pre_transform_in_memory(dataset, partial(task_specific_preprocessing, cfg=cfg))

    pe_enabled_list = []
    for key, pecfg in cfg.items():
        if key.startswith("posenc_") and pecfg.enable:
            pe_name = key.split("_", 1)[1]
            pe_enabled_list.append(pe_name)
            if hasattr(pecfg, "kernel"):
                # Generate kernel times if functional snippet is set.
                # NOTE: `eval` runs arbitrary code from the config; trusted configs only.
                if pecfg.kernel.times_func:
                    pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
                logging.info(
                    f"Parsed {pe_name} PE kernel times / steps: "
                    f"{pecfg.kernel.times}"
                )

    # NOTE(review): tagging and cache saving below only happen when a PE is
    # enabled; without one the dataset is returned untagged and unsaved.
    if not pe_enabled_list:
        return dataset

    start = time.perf_counter()
    logging.info(
        f"Precomputing Positional Encoding statistics: "
        f"{pe_enabled_list} for all graphs..."
    )
    # Estimate directedness based on 10 graphs to save time.
    is_undirected = all(d.is_undirected() for d in dataset[:10])
    logging.info(f"  ...estimated to be undirected: {is_undirected}")
    pre_transform_in_memory(
        dataset,
        partial(
            compute_posenc_stats,
            pe_types=pe_enabled_list,
            is_undirected=is_undirected,
            cfg=cfg,
        ),
        show_progress=True,
    )
    elapsed = time.perf_counter() - start
    timestr = (
        time.strftime("%H:%M:%S", time.gmtime(elapsed)) + f"{elapsed:.2f}"[-3:]
    )
    logging.info(f"Done! Took {timestr}")

    if cfg.model.num_latents > 0:
        kmeans = KMeans(
            n_clusters=cfg.model.num_latents, n_init="auto", random_state=cfg.seed
        )
        kmeans.fit(dataset.data.eigvecs_sn)
        centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
        # normalize clamps tiny norms at eps, so an all-zero centre stays zero
        # instead of becoming NaN.
        dataset.data.pos_latents = torch.nn.functional.normalize(centers, p=2, dim=1)

    dataset_graph_name = os.path.join(dataset_dir, f"{graph_name}")
    set_dataset_attr(
        dataset,
        "dataset_name",
        [dataset_graph_name] * len(dataset),
        len(dataset),
    )
    set_dataset_attr(
        dataset,
        "dataset_task_name",
        [f"{dataset_graph_name}_node_classification"] * len(dataset),
        len(dataset),
    )
    set_dataset_attr(
        dataset,
        "node_id",
        torch.arange(len(dataset.data.y), dtype=torch.long),
        len(dataset),
    )
    save_processed_eig(dataset, dataset_dir, f"{graph_name}")
    save_processed_cluster(dataset, dataset_dir, f"{graph_name}")
    return dataset


def _load_raw_dataset(format, name, dataset_dir):
    """Instantiate the raw (unprocessed) dataset for ``format`` / ``name``.

    For ``PyG-<id>`` formats ``dataset_dir`` must already include the ``<id>``
    sub-directory.

    Raises:
        ValueError: unknown format / identifier, or a wrong name for ``Actor``.
        NotImplementedError: WikipediaNetwork ``crocodile``.
    """
    if format.startswith("PyG-"):
        pyg_dataset_id = format.split("-", 1)[1]
        if pyg_dataset_id == "Actor":
            if name != "actor":
                raise ValueError("Actor class provides only one dataset.")
            return Actor(dataset_dir)
        if pyg_dataset_id == "GNNBenchmarkDataset":
            return preformat_GNNBenchmarkDataset(dataset_dir, name)
        if pyg_dataset_id == "MalNetTiny":
            return preformat_MalNetTiny(dataset_dir, feature_set=name)
        if pyg_dataset_id == "Planetoid":
            return Planetoid(dataset_dir, name)
        if pyg_dataset_id == "TUDataset":
            return preformat_TUDataset(dataset_dir, name)
        if pyg_dataset_id == "WebKB":
            return WebKB(dataset_dir, name)
        if pyg_dataset_id == "WikipediaNetwork":
            if name == "crocodile":
                raise NotImplementedError("crocodile not implemented")
            return WikipediaNetwork(dataset_dir, name, geom_gcn_preprocess=True)
        if pyg_dataset_id == "ZINC":
            return preformat_ZINC(dataset_dir, name)
        if pyg_dataset_id == "AQSOL":
            return preformat_AQSOL(dataset_dir, name)
        if pyg_dataset_id == "VOCSuperpixels":
            return preformat_VOCSuperpixels(
                dataset_dir, name, cfg.dataset.slic_compactness
            )
        if pyg_dataset_id == "COCOSuperpixels":
            return preformat_COCOSuperpixels(
                dataset_dir, name, cfg.dataset.slic_compactness
            )
        raise ValueError(f"Unexpected PyG Dataset identifier: {format}")

    if format == "PyG":
        # GraphGym default loader for Pytorch Geometric datasets.
        return load_pyg(name, dataset_dir)

    if format == "OGB":
        if name.startswith("ogbg"):
            return preformat_OGB_Graph(dataset_dir, name.replace("_", "-"))
        if name.startswith("ogbn"):
            return preformat_OGB_Node(dataset_dir, name.replace("_", "-"))
        if name.startswith("PCQM4Mv2-"):
            return preformat_OGB_PCQM4Mv2(dataset_dir, name.split("-", 1)[1])
        if name.startswith("peptides-"):
            return preformat_Peptides(dataset_dir, name)
        if name.startswith("ogbl-"):
            # Link prediction via the GraphGym default loader. These are binary
            # classification tasks but the loader creates float labels, so
            # convert them to int.
            dataset = load_ogb(name, dataset_dir)
            for prop in ("train_edge_label", "val_edge_label", "test_edge_label"):
                labels = getattr(dataset.data, prop).int()
                set_dataset_attr(dataset, prop, labels, len(labels))
            return dataset
        if name.startswith("PCQM4Mv2Contact-"):
            return preformat_PCQM4Mv2Contact(dataset_dir, name)
        raise ValueError(f"Unsupported OGB(-derived) dataset: {name}")

    raise ValueError(f"Unknown data format: {format}")


def _enabled_posenc_types():
    """Return the names of enabled ``posenc_*`` configs.

    Side effect: for configs with a ``kernel``, a non-empty
    ``kernel.times_func`` is evaluated into ``kernel.times`` (mutates ``cfg``).
    """
    pe_enabled_list = []
    for key, pecfg in cfg.items():
        if key.startswith("posenc_") and pecfg.enable:
            pe_name = key.split("_", 1)[1]
            pe_enabled_list.append(pe_name)
            if hasattr(pecfg, "kernel"):
                # NOTE: `eval` runs arbitrary code from the config; trusted configs only.
                if pecfg.kernel.times_func:
                    pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
                logging.info(
                    f"Parsed {pe_name} PE kernel times / steps: "
                    f"{pecfg.kernel.times}"
                )
    return pe_enabled_list


def _precompute_posenc_stats(dataset, pe_types):
    """Run ``compute_posenc_stats`` over every graph in ``dataset`` (in place)."""
    start = time.perf_counter()
    logging.info(
        f"Precomputing Positional Encoding statistics: "
        f"{pe_types} for all graphs..."
    )
    # Estimate directedness based on 10 graphs to save time.
    is_undirected = all(d.is_undirected() for d in dataset[:10])
    logging.info(f"  ...estimated to be undirected: {is_undirected}")
    pre_transform_in_memory(
        dataset,
        partial(
            compute_posenc_stats,
            pe_types=pe_types,
            is_undirected=is_undirected,
            cfg=cfg,
        ),
        show_progress=True,
    )
    elapsed = time.perf_counter() - start
    timestr = (
        time.strftime("%H:%M:%S", time.gmtime(elapsed)) + f"{elapsed:.2f}"[-3:]
    )
    logging.info(f"Done! Took {timestr}")


def _kmeans_pos_latents(eigvecs, num_latents):
    """K-means-cluster ``eigvecs`` and return L2-normalised float32 centres.

    Returns a tensor of shape ``(num_latents, n_features)``.
    """
    kmeans = KMeans(n_clusters=num_latents, n_init="auto", random_state=cfg.seed)
    kmeans.fit(eigvecs)
    centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    # normalize clamps tiny norms at eps, so an all-zero centre stays zero
    # instead of becoming NaN.
    return torch.nn.functional.normalize(centers, p=2, dim=1)


@register_loader("custom_master_loader")
def load_dataset_master(format, name, dataset_dir, data_cfg=None):
    """
    Master loader that controls loading of all datasets, overshadowing execution
    of any default GraphGym dataset loader. Default GraphGym dataset loader are
    instead called from this function, the format keywords `PyG` and `OGB` are
    reserved for these default GraphGym loaders.

    Custom transforms and dataset splitting is applied to each loaded dataset.
    A previously processed copy (``check_and_load_processed_eig``) is returned
    directly when available, skipping raw loading and processing.

    Args:
        format: dataset format name that identifies Dataset class
        name: dataset name to select from the class identified by `format`
        dataset_dir: path where to store the processed dataset
        data_cfg: optional per-dataset config. When given, it supplies
            ``dir_with_error`` for the cache lookup, is passed to
            ``prepare_splits``, and drives the ``dataset_name`` /
            ``dataset_task_name`` / ``node_id`` attributes.

    Returns:
        PyG dataset object with applied perturbation transforms and data splits
    """
    if format.startswith("PyG-"):
        dataset_dir = osp.join(dataset_dir, format.split("-", 1)[1])

    # Try the processed cache first so expensive raw loading is skipped on a hit.
    if data_cfg is not None:
        dataset_loaded = check_and_load_processed_eig(
            dataset_dir, name, data_cfg.dir_with_error
        )
    else:
        dataset_loaded = check_and_load_processed_eig(dataset_dir, name)
    if dataset_loaded is not None:
        return dataset_loaded

    dataset = _load_raw_dataset(format, name, dataset_dir)
    pre_transform_in_memory(dataset, partial(task_specific_preprocessing, cfg=cfg))
    log_loaded_dataset(dataset, format, name)

    # Precompute necessary statistics for positional encodings.
    pe_enabled_list = _enabled_posenc_types()
    if pe_enabled_list:
        _precompute_posenc_stats(dataset, pe_enabled_list)

    # Set standard dataset train/val/test splits.
    if hasattr(dataset, "split_idxs"):
        set_dataset_splits(dataset, dataset.split_idxs)
        delattr(dataset, "split_idxs")

    # Precompute latent cluster centres for the model.
    if cfg.model.num_latents > 0:
        dataset.data.pos_latents = _kmeans_pos_latents(
            dataset.data.eigvecs_sn, cfg.model.num_latents
        )

    prepare_splits(dataset, data_cfg=data_cfg)
    if data_cfg is not None:
        set_dataset_attr(
            dataset,
            "dataset_name",
            [data_cfg.dataset_name] * len(dataset),
            len(dataset),
        )
        set_dataset_attr(
            dataset,
            "dataset_task_name",
            [f"{data_cfg.dataset_name}_{data_cfg.task}_{data_cfg.task_type}"]
            * len(dataset),
            len(dataset),
        )
        if data_cfg.task == "node":
            set_dataset_attr(
                dataset,
                "node_id",
                torch.arange(len(dataset.data.y), dtype=torch.long),
                len(dataset),
            )

    # Save the processed dataset for future use.
    save_processed_eig(dataset, dataset_dir, name)
    save_processed_cluster(dataset, dataset_dir, name)
    return dataset


def get_loader(dataset, sampler, batch_size, shuffle=True):
    """Wrap ``dataset`` in a PyG ``DataLoader`` configured from ``cfg``.

    This definition shadows the ``get_loader`` imported from
    ``torch_geometric.graphgym.loader`` at the top of the module.

    Args:
        dataset: dataset to iterate over.
        sampler: sampler name; only ``"full_batch"`` is accepted for datasets
            with at most one element.
        batch_size: loader batch size.
        shuffle: whether to shuffle each epoch.

    Returns:
        A ``DataLoader`` with pinned memory. Workers are persistent when
        ``cfg.num_workers > 0``.

    Raises:
        NotImplementedError: ``sampler`` is not ``"full_batch"`` and the
            dataset has at most one element.
    """
    use_persistent_workers = cfg.num_workers > 0
    if sampler != "full_batch" and len(dataset) <= 1:
        raise NotImplementedError(f"'{sampler}' is not implemented")
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=True,
        persistent_workers=use_persistent_workers,
    )


def load_dataset_from_pt(data_cfg, partition=None):
    """Load the processed (eigen-decomposed) dataset described by ``data_cfg``.

    For ``PyG-<id>`` formats the lookup directory is ``<data_cfg.dir>/<id>``;
    otherwise it is ``data_cfg.dir``.

    Args:
        data_cfg: per-dataset config with ``dir``, ``format``, ``dataset_name``
            and ``dir_with_error``.
        partition: forwarded unchanged to ``check_and_load_processed_eig``.

    Returns:
        Whatever ``check_and_load_processed_eig`` returns.
    """
    lookup_dir = data_cfg.dir
    data_format = data_cfg.format
    if data_format.startswith("PyG-"):
        lookup_dir = osp.join(lookup_dir, data_format.split("-", 1)[1])
    return check_and_load_processed_eig(
        lookup_dir,
        data_cfg.dataset_name,
        data_cfg.dir_with_error,
        # NOTE(review): keyword is spelled `partiton`; presumably it matches the
        # callee's parameter name in custom_modules.loader.utils - verify.
        partiton=partition,
    )


def check_and_load_processed_eig_network_repo(
    dataset_dir, dataset_name, dataset_dir_with_err, num_eig=32
):
    """Load a pickled, eigen-processed network-repository dataset if present.

    Looks for ``<dataset_dir>/<dataset_name>_eigen_<num_eig>_processed.pt``.
    When the file is found, truncates ``eigvecs_sn`` to its first ``num_eig``
    columns and ``eigvals_sn`` to its first ``num_eig`` entries along dim 1.

    Args:
        dataset_dir: directory containing the processed pickle files.
        dataset_name: base name of the dataset file.
        dataset_dir_with_err: currently unused; kept for signature compatibility.
        num_eig: number of eigenpairs to keep; also selects the file name
            (default 32).

    Returns:
        The unpickled dataset object, or ``None`` if the file does not exist.
    """
    processed_path = os.path.join(
        dataset_dir,
        f"{dataset_name}_eigen_{num_eig}_processed.pt",
    )
    logging.debug("Looking for processed dataset at %s", processed_path)
    if not os.path.exists(processed_path):
        return None

    # NOTE: pickle can execute arbitrary code; only point this at trusted caches.
    with open(processed_path, "rb") as f:
        dataset = pickle.load(f)
    logging.info("Loaded processed dataset from %s", processed_path)

    dataset.data.eigvecs_sn = dataset.data.eigvecs_sn[:, :num_eig]
    dataset.data.eigvals_sn = dataset.data.eigvals_sn[:, :num_eig, :]
    return dataset


def load_dataset_from_pt_network_repo(data_cfg, partition=None):
    """Load a processed network-repository dataset from the fixed ``/dev/shm`` cache.

    Only ``data_cfg.dataset_name`` is read. ``partition`` is accepted but
    unused; it mirrors the signature of ``load_dataset_from_pt``.

    Returns:
        The dataset returned by ``check_and_load_processed_eig_network_repo``,
        or ``None`` when no cached file exists.
    """
    eigen_dir = "/dev/shm/graph-datasets/eigen_processed/"
    eigen_dir_with_err = "/dev/shm/graph-datasets/eigen_processed_with_err/"
    return check_and_load_processed_eig_network_repo(
        dataset_dir=eigen_dir,
        dataset_name=data_cfg.dataset_name,
        dataset_dir_with_err=eigen_dir_with_err,
    )


def load_dataset_from_pt_syn(processed_path):
    """Unpickle and return the object stored at ``processed_path``.

    Returns ``None`` when the path does not exist. The file must come from a
    trusted source, since unpickling can execute arbitrary code.
    """
    if not os.path.exists(processed_path):
        return None
    with open(processed_path, "rb") as handle:
        return pickle.load(handle)


def create_dataset_multi(cfg):
    r"""
    Create a dataset object from a per-dataset config.

    Every registered loader is tried in turn with
    ``(cfg.format, cfg.dataset_name, cfg.dir, cfg)``; the first non-``None``
    result wins. Otherwise the GraphGym ``PyG`` / ``OGB`` loaders are used.

    Returns: PyG dataset object

    Raises:
        ValueError: if no loader handles the format and it is neither
            ``"PyG"`` nor ``"OGB"``.
    """
    data_format = cfg.format
    dataset_name = cfg.dataset_name
    root_dir = cfg.dir

    for loader_fn in register.loader_dict.values():
        loaded = loader_fn(data_format, dataset_name, root_dir, cfg)
        if loaded is not None:
            return loaded

    if data_format == "PyG":
        return load_pyg(dataset_name, root_dir)
    if data_format == "OGB":
        return load_ogb(dataset_name.replace("_", "-"), root_dir)
    raise ValueError("Unknown data format: {}".format(data_format))


# def load_dataset(cfg):


import copy
from operator import itemgetter


# def custom_collate_fn(batch, graph_dataset_dict,saint_sampler, hop_cutoff):

#     # Extract unique dataset names from the batch
#     graph_dataset_keys = np.unique([item["node_dataset_name"] for item in batch])
#     # Fetch corresponding graph data for the unique dataset names
#     # extracted_values = itemgetter(*graph_dataset_keys)(graph_dataset_dict)

#     # if not isinstance(extracted_values, tuple):
#     #     extracted_values = (extracted_values,)

#     # extracted_values_dict = dict(zip(graph_dataset_keys, extracted_values))

#     node_ids = torch.tensor([item["node_id"] for item in batch], dtype=torch.long)
#     try:
#         node_labels = torch.tensor(
#             [item["node_label"] for item in batch], dtype=torch.long
#         )
#     except:
#         print([item["node_dataset_name"] for item in batch])
#         print([item["node_label"] for item in batch])
#         raise ValueError("not working")
#     dataset_name_to_index = {name: idx for idx, name in enumerate(graph_dataset_keys)}
#     node_dataset_indices = torch.tensor(
#         [dataset_name_to_index[item["node_dataset_name"]] for item in batch],
#         dtype=torch.long,
#     )

#     from collections import defaultdict

#     # Efficiently creating the dictionary
#     dataset_indices_efficient = defaultdict(list)
#     _ = [
#         dataset_indices_efficient[item["node_dataset_task_name"]].append(index)
#         for index, item in enumerate(batch)
#     ]

#     dict(dataset_indices_efficient)

#     output_values = {}
#     for key in dataset_indices_efficient:
#         output_values[key] = node_labels[dataset_indices_efficient[key]]
#     k_hop_neigh_idx, seq_mask = hops_multi(
#         main_dict=saint_sampler,
#         # main_dict=graph_dataset_dict,
#         subset_keys=graph_dataset_keys,
#         dataset_name_to_index=dataset_name_to_index,
#         node_dataset_id=node_dataset_indices,
#         node_idx=node_ids,
#         num_hops=1,
#         hop_cutoff=hop_cutoff,
#         get_undirected_hops=True,
#     )

#     # pos_latents_list = []
#     # for graph_key in extracted_values_dict.keys():
#     #     pos_latents = extracted_values_dict[graph_key].pos_latents
#     #     pos_latents_list.append(pos_latents)

#     # # Stack all pos_latents tensors into a single tensor
#     # pos_latents = torch.stack(pos_latents_list)

#     # token_id_self = torch.full(x.unsqueeze(1).shape[:-1], 0, device=x.device)
#     # token_id_latents = torch.full(output.shape[:-1], 1, device=x.device)
#     # token_id_hops = torch.full(k_hops_feat.shape[:-1], 2, device=x.device)

#     # Return a combined batch.

#     return {
#         "node_id": node_ids,
#         "node_label": node_labels,
#         "node_dataset_indices": node_dataset_indices,
#         # "batch_graph": extracted_values_dict,
#         "main_graph_dcit": graph_dataset_dict,
#         "subset_graph_keys": graph_dataset_keys,
#         # "pos_latents": pos_latents,
#         "output_task_indices": dataset_indices_efficient,
#         "output_values": output_values,
#         "k_hop_neigh_idx": k_hop_neigh_idx,
#         "k_hop_seq_mask": seq_mask,
#     }


from custom_modules.loader.utils import GraphSAINTRandomWalkSampler_custom


# class Collate:
#     def __init__(self, graph_dataset_dict, hop_cutoff):
#         self.graph_dataset_dict = graph_dataset_dict
#         manager = Manager()
#         self.graph_dataset_dict = manager.dict(
#             {
#                 key: load_dataset_from_pt(value).data
#                 for key, value in graph_dataset_dict.items()
#             }
#         )
#         self.saint_samplers = manager.dict(
#             {
#                 key: GraphSAINTRandomWalkSampler_custom(
#                     self.graph_dataset_dict[key],
#                     batch_size=hop_cutoff,  # Controls the size of the subgraph
#                     walk_length=hop_cutoff,  # Length of the random walks
#                     num_steps=5,  # Number of steps (subgraphs) to sample
#                     sample_coverage=0,  # Set to 0 for no specific coverage; adjust based on needs
#                     save_dir=None,  # Temporary directory to save the walks
#                 )
#                 for key in self.graph_dataset_dict
#             }
#         )
#         self.hop_cutoff = hop_cutoff

#     def __call__(self, batch):
#         return custom_collate_fn(
#             batch, self.graph_dataset_dict, self.saint_samplers, self.hop_cutoff
#         )


class Collate:
    """Callable that merges node samples drawn from several graphs into one batch.

    Presumably used as a DataLoader ``collate_fn``. Each sample is read as a
    mapping with keys ``"node_dataset_main_graph_dict_key"``, ``"node_id"``,
    ``"node_dataset_task_name"`` and ``"node_label"`` (labels must be
    tensors, since they are ``torch.stack``-ed per task).
    """

    def __init__(
        self, graph_dataset_dict, saint_samplers, hop_cutoff, model_num_latents
    ):
        # Graph key -> graph data object (its `.num_nodes` is read per batch).
        self.graph_dataset_dict = graph_dataset_dict
        # Passed to `hops_multi` as `main_dict` for neighbourhood sampling.
        self.saint_samplers = saint_samplers
        self.hop_cutoff = hop_cutoff
        # Number of latent tokens appended to each sequence (token type 2).
        self.model_num_latents = model_num_latents

    def __call__(self, batch):
        """Build the batch dict for a list of node samples."""
        # Unique graph keys (sorted by np.unique) plus, per sample, the
        # position of its graph among those keys.
        graph_names, node_dataset_indices = np.unique(
            [sample["node_dataset_main_graph_dict_key"] for sample in batch],
            return_inverse=True,
        )
        name_to_position = {key: pos for pos, key in enumerate(graph_names)}
        node_ids = torch.tensor(
            [sample["node_id"] for sample in batch], dtype=torch.long
        )

        # Group batch positions and labels by task, in first-seen order.
        task_positions = {}
        task_labels = {}
        for pos, sample in enumerate(batch):
            task = sample["node_dataset_task_name"]
            task_positions.setdefault(task, []).append(pos)
            task_labels.setdefault(task, []).append(sample["node_label"])
        stacked_labels = {
            task: torch.stack(labels) for task, labels in task_labels.items()
        }

        graph_seq_len = [
            self.graph_dataset_dict[key].num_nodes for key in graph_names
        ]
        # offsets[i] = number of nodes in graphs 0..i-1 (concatenation offset).
        offsets = np.cumsum([0] + graph_seq_len)

        # NOTE(review): the mask returned here is discarded; if it marks padded
        # neighbour slots, those slots are shifted by the offset below as well -
        # confirm hops_multi's padding convention.
        k_hop_neigh_idx, _unused_mask = hops_multi(
            main_dict=self.saint_samplers,
            subset_keys=graph_names,
            dataset_name_to_index=name_to_position,
            node_dataset_id=node_dataset_indices,
            node_idx=node_ids,
            num_hops=1,
            hop_cutoff=self.hop_cutoff,
            get_undirected_hops=True,
        )
        # Map per-graph neighbour indices into the concatenated node space.
        per_sample_offset = torch.LongTensor(offsets[node_dataset_indices])
        batch_wise_neigh_index = k_hop_neigh_idx + per_sample_offset.unsqueeze(1)

        # Token types: 0 = the node itself, 1 = its sampled neighbours,
        # 2 = latent tokens.
        neighbour_count = k_hop_neigh_idx.shape[1] - 1
        type_row = torch.LongTensor(
            [0] + [1] * neighbour_count + [2] * self.model_num_latents
        )
        token_type_id = repeat(type_row, "n -> b n", b=len(node_ids))

        return {
            "graph_names": graph_names,
            "graph_seq_len": graph_seq_len,
            "node_dataset_indices": torch.LongTensor(node_dataset_indices),
            "batch_wise_neigh_index": batch_wise_neigh_index,
            "token_type_id": token_type_id,
            "main_graph_dict": self.graph_dataset_dict,
            "output_task_indices": task_positions,
            "output_values": stacked_labels,
        }


class print_pid:
    """DataLoader ``worker_init_fn`` that reports the worker's process id.

    A new instance is created in each worker with that worker's id. It prints
    the current PID and writes it, newline-terminated, to
    ``worker_<worker_id>_pid.txt`` in the current working directory. Any
    existing file with that name is overwritten.
    """

    def __init__(self, worker_id):
        process_id = os.getpid()
        print(f" PID {process_id}")
        pid_path = f"worker_{worker_id}_pid.txt"
        with open(pid_path, "w") as pid_file:
            # print() appends the trailing newline, matching "<pid>\n".
            print(process_id, file=pid_file)


def build_snake_sampler_inputs(node_datasets, warn_empty=False):
    """Collect the per-node and per-dataset metadata passed to ``UGTSnakeSampler``.

    Each element of ``node_datasets`` must support ``len()`` and expose a
    ``node_level_data_dict`` mapping with the keys ``"main_graph_dict_key"``
    and ``"graph_ratio_in_epoch"``.

    Args:
        node_datasets: Node-level datasets, in the same order in which they
            are concatenated into the ``ConcatDataset`` fed to the sampler.
        warn_empty: If True, print a notice for every dataset of length 0.

    Returns:
        Tuple ``(graph_ids, num_nodes, ratio_in_epoch, seq_len)`` of 1-D
        tensors built with ``torch.Tensor`` (default floating dtype):

        * ``graph_ids`` and ``num_nodes`` have one entry per node. Each entry
          holds the integer id assigned to the node's ``main_graph_dict_key``
          (ids are assigned in order of first appearance) and the length of
          the node's dataset.
        * ``ratio_in_epoch`` and ``seq_len`` have one entry per dataset.

        NOTE: with a float32 default dtype, integer values above 2**24 are not
        exactly representable.
    """
    graph_key_to_id = {}
    graph_ids = []
    num_nodes = []
    ratio_in_epoch = []
    seq_len = []
    for node_dataset in node_datasets:
        meta = node_dataset.node_level_data_dict
        graph_key = meta["main_graph_dict_key"]
        size = len(node_dataset)
        if warn_empty and size == 0:
            print("No labels present in: ", graph_key)
        # New keys get the next consecutive id; repeated keys reuse theirs.
        graph_id = graph_key_to_id.setdefault(graph_key, len(graph_key_to_id))
        graph_ids.extend([graph_id] * size)
        num_nodes.extend([size] * size)
        ratio_in_epoch.append(meta["graph_ratio_in_epoch"])
        seq_len.append(size)

    return (
        torch.Tensor(graph_ids),
        torch.Tensor(num_nodes),
        torch.Tensor(ratio_in_epoch),
        torch.Tensor(seq_len),
    )


def create_loader_distributed(
    rank,
    world_size,
    shared_graph_dataset_dict,
    saint_samplers,
    node_dataset_loaded_list_train,
):
    """Create the training DataLoader for one distributed rank.

    The node datasets are concatenated and batched by a ``UGTSnakeSampler``
    sharded across ``world_size`` replicas. Batches are assembled by
    ``Collate`` from the shared graph dictionary and SAINT samplers.

    Args:
        rank: Rank of this process, forwarded to the sampler.
        world_size: Number of replicas, forwarded to the sampler.
        shared_graph_dataset_dict: Graph dictionary handed to ``Collate``.
        saint_samplers: Samplers handed to ``Collate``.
        node_dataset_loaded_list_train: Node-level training datasets.

    Returns:
        A single ``torch.utils.data.DataLoader``, or ``None`` when
        ``cfg.dataset_multi.name_list`` is empty.
    """
    from custom_modules.loader.graph_distributed_sampler import UGTSnakeSampler

    if len(cfg.dataset_multi.name_list) == 0:
        return None

    # The datasets are supplied by the caller, so the per-dataset config list
    # is only reset here.
    cfg.dataset_config_list = []

    graph_ids, num_nodes, ratio_in_epoch, seq_len = build_snake_sampler_inputs(
        node_dataset_loaded_list_train, warn_empty=True
    )
    print("Number of nodes to decode:", len(graph_ids))

    node_dataset_loaded_train = ConcatDataset(node_dataset_loaded_list_train)
    batch_sampler_train = UGTSnakeSampler(
        node_dataset_loaded_train,
        graph_ids,
        num_nodes,
        ratio_in_epoch,
        seq_len,
        cfg.train.sampler_graph_limit,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        batch_size=cfg.train.batch_size,
        seed=cfg.seed,
        accum_gradient_steps=cfg.train.accum_gradient_steps,
    )

    collate_fn = Collate(
        shared_graph_dataset_dict,
        saint_samplers,
        cfg.model.hop_cutoff,
        cfg.model.num_latents,
    )

    # batch_sampler already yields whole batches, so batch_size and shuffle
    # must not be passed to the DataLoader as well.
    return torch_DataLoader(
        node_dataset_loaded_train,
        num_workers=4,
        pin_memory=False,
        batch_sampler=batch_sampler_train,
        collate_fn=collate_fn,
        prefetch_factor=2,
        persistent_workers=True,
        worker_init_fn=print_pid,
    )


def create_loader_multi_testing(
    rank,
    world_size,
    shared_graph_dataset_dict,
    saint_samplers,
    node_dataset_loaded_list_train,
):
    """Create a single-process DataLoader that walks the node datasets in order.

    Unlike ``create_loader_distributed``, this uses plain fixed-size batching
    with the DataLoader's default sequential sampler (no ``UGTSnakeSampler``).
    Every entry is therefore visited once per pass, in concatenation order.

    Args:
        rank: Unused; kept so the signature matches ``create_loader_distributed``.
        world_size: Unused; kept so the signature matches ``create_loader_distributed``.
        shared_graph_dataset_dict: Graph dictionary handed to ``Collate``.
        saint_samplers: Samplers handed to ``Collate``.
        node_dataset_loaded_list_train: Node-level datasets to iterate.

    Returns:
        A single ``torch.utils.data.DataLoader``, or ``None`` when
        ``cfg.dataset_multi.name_list`` is empty.
    """
    if len(cfg.dataset_multi.name_list) == 0:
        return None

    cfg.dataset_config_list = []

    # Only the total node count is needed here; no per-node sampler metadata.
    total_nodes = sum(len(node_dataset) for node_dataset in node_dataset_loaded_list_train)
    print("Number of nodes to decode:", total_nodes)

    node_dataset_loaded_train = ConcatDataset(node_dataset_loaded_list_train)

    collate_fn = Collate(
        shared_graph_dataset_dict,
        saint_samplers,
        cfg.model.hop_cutoff,
        cfg.model.num_latents,
    )

    return torch_DataLoader(
        node_dataset_loaded_train,
        batch_size=cfg.train.batch_size,
        num_workers=4,
        pin_memory=False,
        collate_fn=collate_fn,
        prefetch_factor=2,
        persistent_workers=True,
        worker_init_fn=print_pid,
    )


def _create_multi_graph_loaders():
    """Build train/val/test loaders that pool graphs from every configured dataset."""
    cfg.dataset_config_list = []
    dataset_loaded_list = []
    for dataset_name in cfg.dataset_multi.name_list:
        dataset_cfg = getattr(cfg, dataset_name)
        dataset_cfg.enable = True
        cfg.dataset_config_list.append(dataset_cfg)
        dataset_loaded_list.append(create_dataset_multi(dataset_cfg))

    split_keys = ("train_graph_index", "val_graph_index", "test_graph_index")
    split_views = {key: [] for key in split_keys}
    for dataset in dataset_loaded_list:
        for key in split_keys:
            split_views[key].append(dataset[dataset.data[key]])
            # Remove the split index from the shared data once its view is taken.
            delattr(dataset.data, key)

    train_all, val_all, test_all = (
        [data for view in split_views[key] for data in view] for key in split_keys
    )

    return [
        get_loader(train_all, cfg.train.sampler, cfg.train.batch_size, shuffle=True),
        get_loader(val_all, cfg.val.sampler, cfg.train.batch_size, shuffle=False),
        get_loader(test_all, cfg.val.sampler, cfg.train.batch_size, shuffle=False),
    ]


def _create_multi_node_loaders():
    """Build node-level train/val/test loaders; return them with the graph dict."""
    cfg.dataset_config_list = []
    node_dataset_loaded_list_train = []
    node_dataset_loaded_list_val = []
    node_dataset_loaded_list_test = []
    graph_dataset_dict = {}
    for dataset_name in cfg.dataset_multi.name_list:
        dataset_cfg = getattr(cfg, dataset_name)
        dataset_cfg.enable = True
        cfg.dataset_config_list.append(dataset_cfg)
        full_dataset = create_dataset_multi(dataset_cfg)
        graph_dataset_dict[dataset_cfg.dataset_name] = full_dataset.data
        node_dataset_loaded_list_train.append(
            NodeDataset(full_dataset.data, mask="train_mask")
        )
        node_dataset_loaded_list_val.append(
            NodeDataset(full_dataset.data, mask="val_mask")
        )
        node_dataset_loaded_list_test.append(
            NodeDataset(full_dataset.data, mask="test_mask")
        )

    if cfg.dataset_multi.use_synthetic:
        synthetic_node_datasets = []
        syn_graph_config_list = []
        for syn_dataset_dir in cfg.dataset_multi.synthetic_data_dir:
            (
                dataset_list_dir,
                graph_dataset_dict_syn,
                syn_graph_config_list_dir,
            ) = load_synthetic_data(
                syn_dataset_dir,
                cfg.dataset_multi.num_synthetic_samples,
            )
            synthetic_node_datasets.extend(dataset_list_dir)
            syn_graph_config_list.extend(syn_graph_config_list_dir)
            graph_dataset_dict.update(graph_dataset_dict_syn)

        # Synthetic samples are added to the training split only.
        node_dataset_loaded_list_train = (
            node_dataset_loaded_list_train + synthetic_node_datasets
        )
        cfg.dataset_multi.syn_graph_config = syn_graph_config_list

    node_dataset_loaded_train = ConcatDataset(node_dataset_loaded_list_train)
    node_dataset_loaded_val = ConcatDataset(node_dataset_loaded_list_val)
    node_dataset_loaded_test = ConcatDataset(node_dataset_loaded_list_test)

    # NOTE(review): the other loader builders in this module call
    # Collate(graph_dict, saint_samplers, hop_cutoff, num_latents). Confirm
    # that Collate still accepts this two-argument form; otherwise hop_cutoff
    # lands in the saint_samplers slot.
    collate_fn = Collate(graph_dataset_dict, cfg.model.hop_cutoff)

    def _make_loader(dataset, shuffle):
        return torch_DataLoader(
            dataset,
            batch_size=cfg.train.batch_size,
            shuffle=shuffle,
            num_workers=4,
            collate_fn=collate_fn,
        )

    loaders = [
        _make_loader(node_dataset_loaded_train, shuffle=True),
        _make_loader(node_dataset_loaded_val, shuffle=False),
        _make_loader(node_dataset_loaded_test, shuffle=False),
    ]
    return loaders, graph_dataset_dict


def _create_single_dataset_loaders():
    """Build a train loader plus ``num_splits - 1`` eval loaders for ``create_dataset()``."""
    dataset = create_dataset()
    is_graph_task = cfg.dataset.task == "graph"

    if is_graph_task:
        train_index = dataset.data["train_graph_index"]
        loaders = [
            get_loader(
                dataset[train_index],
                cfg.train.sampler,
                cfg.train.batch_size,
                shuffle=True,
            )
        ]
        delattr(dataset.data, "train_graph_index")
    else:
        # Node/edge tasks: every loader wraps the full dataset.
        loaders = [
            get_loader(dataset, cfg.train.sampler, cfg.train.batch_size, shuffle=True)
        ]

    eval_split_names = ["val_graph_index", "test_graph_index"]
    for split_idx in range(cfg.share.num_splits - 1):
        if is_graph_task:
            split_name = eval_split_names[split_idx]
            split_index = dataset.data[split_name]
            loaders.append(
                get_loader(
                    dataset[split_index],
                    cfg.val.sampler,
                    cfg.train.batch_size,
                    shuffle=False,
                )
            )
            delattr(dataset.data, split_name)
        else:
            loaders.append(
                get_loader(
                    dataset, cfg.val.sampler, cfg.train.batch_size, shuffle=False
                )
            )

    return loaders


def create_loader():
    """Create the data loaders for the configured dataset(s).

    There are three modes, and each has a different return shape:

    * ``cfg.dataset_multi.name_list`` is non-empty and
      ``cfg.dataset.task == "graph"``: every listed dataset is loaded with
      ``create_dataset_multi`` and split by its ``train/val/test_graph_index``
      attributes. The graphs of each split are pooled across datasets.
      Returns ``[train_loader, val_loader, test_loader]``.
    * ``name_list`` is non-empty and the task is anything else: ``NodeDataset``
      splits are built from each dataset's ``train/val/test`` masks. When
      ``cfg.dataset_multi.use_synthetic`` is set, synthetic datasets are added
      to the training split. Returns
      ``([train_loader, val_loader, test_loader], graph_dataset_dict)``.
    * ``name_list`` is empty: the dataset from ``create_dataset()`` is used.
      Returns a list with one train loader followed by
      ``cfg.share.num_splits - 1`` evaluation loaders.

    Side effects:
        * In the multi-dataset modes, ``cfg.dataset_config_list`` is reset and
          repopulated, and each dataset config gets ``enable = True``.
        * ``cfg.dataset_multi.syn_graph_config`` is set when synthetic data is
          used.
        * In the graph modes, the ``*_graph_index`` attributes are deleted from
          ``dataset.data`` after use.
    """
    if len(cfg.dataset_multi.name_list) > 0:
        if cfg.dataset.task == "graph":
            return _create_multi_graph_loaders()
        return _create_multi_node_loaders()
    return _create_single_dataset_loaders()
