from torch_geometric.graphgym.config import assert_cfg, cfg
from torch_geometric.loader import (
    ClusterLoader,
    DataLoader,
    GraphSAINTEdgeSampler,
    GraphSAINTNodeSampler,
    GraphSAINTRandomWalkSampler,
    NeighborLoader,
    RandomNodeLoader,
)

from torch_geometric.graphgym.loader import create_dataset

import os 
from yacs.config import CfgNode as CN

from torch_geometric.graphgym.config import set_cfg

def get_loader(dataset, sampler, batch_size, node_split_name, shuffle=True):
    """
    Build a single data loader for the requested sampling strategy.

    Whole-graph batching through ``DataLoader`` is used when ``sampler`` is
    ``"full_batch"`` or when the dataset holds more than one graph. In every
    other case the first element of the dataset (``dataset[0]``) is handed to
    the loader that belongs to the sampler name.

    Args:
        dataset: Dataset to load from
        sampler (str): One of ``"full_batch"``, ``"neighbor"``,
            ``"random_node"``, ``"saint_rw"``, ``"saint_node"``,
            ``"saint_edge"`` or ``"cluster"``
        batch_size: Batch size forwarded to the loader (the
            ``"random_node"`` sampler does not receive it)
        node_split_name (str): Split prefix; the ``"neighbor"`` sampler uses
            the ``f"{node_split_name}_mask"`` attribute of the graph as its
            input nodes
        shuffle (bool): Forwarded to the loader

    Returns: The constructed data loader

    Raises:
        AssertionError: If ``sampler`` is ``"neighbor"`` and
            ``node_split_name`` is empty or ``None``
        NotImplementedError: If the sampler name is unknown and the
            whole-graph batching case does not apply

    """
    # Whole-graph batching takes precedence over any sampler name.
    if sampler == "full_batch" or len(dataset) > 1:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
            persistent_workers=cfg.num_workers > 0,
        )

    if sampler == "neighbor":
        assert node_split_name, "NeighborLoader is only valid for node tasks"
        return NeighborLoader(
            dataset[0],
            # One fan-out value per message passing layer.
            num_neighbors=cfg.train.neighbor_sizes[:cfg.gnn.layers_mp],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
            input_nodes=getattr(dataset[0], f"{node_split_name}_mask"),
        )

    if sampler == "random_node":
        return RandomNodeLoader(
            dataset[0],
            num_parts=cfg.train.train_parts,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )

    if sampler == "saint_rw":
        return GraphSAINTRandomWalkSampler(
            dataset[0],
            batch_size=batch_size,
            walk_length=cfg.train.walk_length,
            num_steps=cfg.train.iter_per_epoch,
            sample_coverage=0,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )

    if sampler in ("saint_node", "saint_edge"):
        # Both GraphSAINT variants are configured with the same arguments.
        if sampler == "saint_node":
            saint_cls = GraphSAINTNodeSampler
        else:
            saint_cls = GraphSAINTEdgeSampler
        return saint_cls(
            dataset[0],
            batch_size=batch_size,
            num_steps=cfg.train.iter_per_epoch,
            sample_coverage=0,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )

    if sampler == "cluster":
        return ClusterLoader(
            dataset[0],
            num_parts=cfg.train.train_parts,
            save_dir="{}/{}".format(cfg.dataset.dir,
                                    cfg.dataset.name.replace("-", "_")),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )

    raise NotImplementedError("%s sampler is not implemented!" % sampler)


def create_loader():
    """
    Create one data loader per dataset split.

    The training loader comes first and is followed by
    ``cfg.share.num_splits - 1`` evaluation loaders, taken in the order
    validation, test. The training loader uses ``cfg.train.sampler`` with
    shuffling; the evaluation loaders use ``cfg.val.sampler`` without
    shuffling. All loaders use ``cfg.train.batch_size``.

    Returns: List of PyTorch data loaders

    """
    dataset = create_dataset()

    def _split_loader(index_key, mask_name, sampler, shuffle):
        # Graph tasks select the graphs of the split through the index stored
        # on the dataset's data object and remove that index once it is used.
        if cfg.dataset.task == 'graph':
            graph_ids = dataset.data[index_key]
            loader = get_loader(dataset[graph_ids], sampler,
                                cfg.train.batch_size, node_split_name=None,
                                shuffle=shuffle)
            delattr(dataset.data, index_key)
            return loader
        # Other tasks keep the full dataset and pass the split name along.
        return get_loader(dataset, sampler, cfg.train.batch_size,
                          node_split_name=mask_name, shuffle=shuffle)

    # train loader
    loaders = [
        _split_loader('train_graph_index', 'train', cfg.train.sampler, True)
    ]

    # val and test loaders
    eval_splits = (
        ('val_graph_index', 'val'),
        ('test_graph_index', 'test'),
    )
    for i in range(cfg.share.num_splits - 1):
        # Indexed access on purpose: only two evaluation splits are defined.
        index_key, mask_name = eval_splits[i]
        loaders.append(
            _split_loader(index_key, mask_name, cfg.val.sampler, False))

    return loaders


def load_cfg(cfg, args):
    r"""
    Load configurations from the file system and the command line.

    The config file ``args.cfg_file`` is merged first, then the command line
    overrides in ``args.opts``, and finally the result is checked with
    ``assert_cfg``.

    The 'parser_drop_eq' option enables compatibility with the wandb sweep,
    which specifies cli args with an equal sign. When 'parser_drop_eq' is
    present in ``args.opts`` it is removed and every remaining option is
    split at its first equal sign, e.g. 'param1=value1' becomes
    'param1 value1', which can be readily parsed by the graphgym cli parser.

    Note:
        The 'parser_drop_eq' MUST be specified as the first argument AFTER the
        predefined cli arguments, such as '--cfg'. Example:
        ``python main.py --cfg config.yaml parser_drop_eq param1=value1``

    Args:
        cfg (CfgNode): Configuration node, updated in place
        args: Parsed command line arguments providing ``cfg_file`` and
            ``opts``

    """
    cfg.merge_from_file(args.cfg_file)

    overrides = args.opts
    if "parser_drop_eq" in overrides:
        # "param1=value1" -> "param1", "value1"; the marker itself is dropped
        overrides = [
            token
            for opt in args.opts
            if opt != "parser_drop_eq"
            for token in opt.split("=", 1)
        ]

    cfg.merge_from_list(overrides)
    assert_cfg(cfg)


# MODIFIED FOR EIGENVECTOR-LEARNING
# Handles the GPSE architecture configs that depend on the eigenvector-learning parameters.
def load_gpse_cfg(cfg, gpse_path):
    r"""
    Load the settings of a pretrained GPSE model into the downstream config.

    Used when a pretrained GPSE model provides the structural encodings. The
    function reads ``config.yaml`` from ``gpse_path`` and updates ``cfg`` in
    place:

    - ``cfg.posenc_GPSE.gnn_cfg`` is set to the ``gnn`` section of the GPSE
      config (base GPSE architecture).
    - ``cfg.posenc_GPSE.model_dir`` is set to a checkpoint file inside
      ``<gpse_path>/0/ckpt``. If the folder holds several entries, the
      lexicographically first one is used.
    - ``cfg.out_dir`` is extended by the name of the GPSE results folder to
      distinguish the downstream result save path.
    - ``MLP_style``, ``eigen.max_freqs`` and ``concat_max_nodes`` of
      ``cfg.posenc_LapPE`` are copied from the GPSE config (architecture of
      the eigenvalue learner).

    Args:
        cfg (CfgNode): Downstream configuration node, updated in place
        gpse_path (str): Path to the RESULTS folder of the GPSE model to use

    Raises:
        FileNotFoundError: If the checkpoint folder does not exist or does
            not contain any entry

    """
    # Resolve and validate everything that can fail before touching cfg, so
    # that a failure does not leave cfg partially updated.

    # checkpoint path
    ckpt_folder = os.path.join(gpse_path, "0", "ckpt")
    # os.listdir returns entries in arbitrary order; sort them so the same
    # checkpoint is selected on every run and platform.
    ckpt_names = sorted(os.listdir(ckpt_folder))
    if not ckpt_names:
        raise FileNotFoundError(
            f"No checkpoint found in '{ckpt_folder}'")
    model_dir = os.path.join(ckpt_folder, ckpt_names[0])

    # Normalize first: a trailing path separator would otherwise yield an
    # empty folder name.
    gpse_name = os.path.basename(os.path.normpath(gpse_path))

    # new_allowed=True lets the node take the keys found in the saved config.
    gpse_cfg = CN(new_allowed=True)
    gpse_cfg.merge_from_file(os.path.join(gpse_path, "config.yaml"))

    gnn_cfg = gpse_cfg.gnn
    gpse_lap_pe = gpse_cfg.posenc_LapPE
    mlp_style = gpse_lap_pe.MLP_style
    max_freqs = gpse_lap_pe.eigen.max_freqs
    concat_max_nodes = gpse_lap_pe.concat_max_nodes

    # base GPSE architecture configs
    cfg.posenc_GPSE.gnn_cfg = gnn_cfg
    cfg.posenc_GPSE.model_dir = model_dir

    # to distinguish the downstream result save path
    cfg.out_dir = os.path.join(cfg.out_dir, gpse_name)

    # specific architecture of the eigval learner
    cfg.posenc_LapPE.MLP_style = mlp_style
    cfg.posenc_LapPE.eigen.max_freqs = max_freqs
    cfg.posenc_LapPE.concat_max_nodes = concat_max_nodes
    

    