import torch
import torch_geometric


def get_dataloder(args):
    """Build the train and test dataloaders selected by ``args.dataset``.

    The loader class is chosen first by ``args.dataset.name`` and otherwise
    by the second-to-last ``/``-separated component of
    ``args.dataset.dataset_path`` (referred to below as the dataset family).
    Loader modules are imported lazily inside each branch so that only the
    module needed for the selected dataset is loaded.

    Args:
        args: Configuration object exposing a ``dataset`` attribute. Depending
            on the branch taken, the following attributes of ``args.dataset``
            are read: ``name``, ``dataset_path``, ``num_samples``,
            ``file_prefix``, ``edge_file_prefix`` and ``dataset_prefix_path``.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.

    Raises:
        ValueError: If neither the dataset name nor the dataset family
            matches a supported dataset.
    """
    dataset = args.dataset

    if dataset.name in ('ZINC', 'proteins'):
        from dataloaders.combine_with_edgewiseloader import EdgeWiseDataLoader
        train_dataloader = EdgeWiseDataLoader(dataset, mode='train')
        test_dataloader = EdgeWiseDataLoader(dataset, mode='test')
        return train_dataloader, test_dataloader

    # Split the path once; the family is the parent directory component.
    # A path with fewer than two components has no family and is rejected
    # below instead of failing with an IndexError.
    path_parts = dataset.dataset_path.split('/')
    family = path_parts[-2] if len(path_parts) >= 2 else None

    if family in (
        'ising_dataset',
        'junction_tree',
        'path_graph',
        'prufer_tree',
        'prufer_tree_graphbank',
        'd_regular_ising',
    ):
        from dataloaders.belief_dataloader import InductiveFolderDataLoader
        print(path_parts[-1])
        train_dataloader = InductiveFolderDataLoader(
            dataset_folder=dataset.dataset_path,
            num_samples=dataset.num_samples,
            mode='train',
            file_prefix=dataset.file_prefix,
            edge_file_prefix=dataset.edge_file_prefix,
        )
        # The test split requests 20% of the configured sample count.
        test_dataloader = InductiveFolderDataLoader(
            dataset_folder=dataset.dataset_path,
            num_samples=int(0.2 * dataset.num_samples),
            mode='test',
            file_prefix=dataset.file_prefix,
            edge_file_prefix=dataset.edge_file_prefix,
        )
    elif family in ('flipflop',):
        from dataloaders.inductive_folder_dataloader import FolderDataLoader
        train_dataloader = FolderDataLoader(
            dataset_name=dataset.name,
            num_samples=dataset.num_samples,
            folder_path=dataset.dataset_prefix_path,
            mode='train',
        )
        # NOTE(review): unlike the train loader, no ``folder_path`` is passed
        # here and ``num_samples`` is not scaled -- confirm this is intended.
        test_dataloader = FolderDataLoader(
            dataset_name=dataset.name,
            num_samples=dataset.num_samples,
            mode='test',
        )
    elif family in (
        'star_graph',
        'star_graph_random',
        'star_graph_node_random',
    ):
        from dataloaders.belief_dataloader import InductiveFolderDataLoaderStarGraph
        print(path_parts[-1])
        train_dataloader = InductiveFolderDataLoaderStarGraph(
            dataset_folder=dataset.dataset_path,
            num_samples=dataset.num_samples,
            mode='train',
            file_prefix=dataset.file_prefix,
            edge_file_prefix=dataset.edge_file_prefix,
        )
        # The test split requests 20% of the configured sample count.
        test_dataloader = InductiveFolderDataLoaderStarGraph(
            dataset_folder=dataset.dataset_path,
            num_samples=int(0.2 * dataset.num_samples),
            mode='test',
            file_prefix=dataset.file_prefix,
            edge_file_prefix=dataset.edge_file_prefix,
        )
    else:
        # Actually raise: a bare ``ValueError(...)`` expression is discarded
        # and the function would then fail on the unbound loader names.
        raise ValueError("Invalid Dataset Name")

    return train_dataloader, test_dataloader