import torch
import networkx as nx
import numpy as np
from dgl.data import GINDataset, TUDataset
from dgl.dataloading import GraphDataLoader
from sklearn.model_selection import train_test_split

# Public dataset names mapped to the name each DGL loader expects.
# The TU Dortmund collection has no dataset called plain "PTC" (only PTC_MR,
# PTC_FM, PTC_FR, PTC_MM); PTC_MR is presumably the intended one.
_TU_DATASET_NAMES = {
    'MUTAG': 'MUTAG',
    'PTC': 'PTC_MR',
    'NCI1': 'NCI1',
    'PROTEINS': 'PROTEINS',
}
# GINDataset resolves datasets by folder names written without hyphens,
# so the conventional hyphenated names must be translated.
_GIN_DATASET_NAMES = {
    'IMDB-BINARY': 'IMDBBINARY',
    'IMDB-MULTI': 'IMDBMULTI',
    'COLLAB': 'COLLAB',
    'REDDIT-BINARY': 'REDDITBINARY',
    'REDDIT-MULTI-5K': 'REDDITMULTI5K',
}


def resolve_dataset_name(dataset_name):
    """
    Map a public dataset name to the DGL loader and loader-specific name.

    Args:
        dataset_name (str): Public name as accepted by ``load_dataset``.

    Returns:
        tuple[str, str]: ``('tu', name)`` when the dataset is loaded with
        ``TUDataset``, ``('gin', name)`` when it is loaded with ``GINDataset``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name in _TU_DATASET_NAMES:
        return 'tu', _TU_DATASET_NAMES[dataset_name]
    if dataset_name in _GIN_DATASET_NAMES:
        return 'gin', _GIN_DATASET_NAMES[dataset_name]
    raise ValueError(f'Unknown dataset: {dataset_name}')


def load_dataset(dataset_name):
    """
    Load and preprocess the specified dataset.

    Args:
        dataset_name (str): Name of the dataset to load. One of 'MUTAG', 'PTC',
            'NCI1', 'PROTEINS' (loaded via ``TUDataset``) or 'IMDB-BINARY',
            'IMDB-MULTI', 'COLLAB', 'REDDIT-BINARY', 'REDDIT-MULTI-5K'
            (loaded via ``GINDataset`` with self-loops added).

    Returns:
        dataset (dgl.data.DGLDataset): Loaded dataset.
        num_classes (int): Number of classes in the dataset.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # Resolve first so unknown names fail before any download is attempted.
    backend, backend_name = resolve_dataset_name(dataset_name)
    if backend == 'tu':
        dataset = TUDataset(backend_name)
    else:
        dataset = GINDataset(backend_name, self_loop=True, degree_as_nlabel=False)

    num_classes = dataset.num_classes
    return dataset, num_classes

def _fill_to(samples, graph, num_samples):
    """Truncate or pad ``samples`` with ``graph`` to exactly ``num_samples`` items."""
    samples = list(samples)[:num_samples]
    return samples + [graph] * (num_samples - len(samples))


def _sample_paths(graph, nodes, num_samples):
    """Induced subgraph on one simple path (<= 3 edges) between random node pairs."""
    samples = []
    for _ in range(num_samples):
        # Sample positions rather than passing ``nodes`` to np.random.choice,
        # which needs a 1-D array and fails for tuple-valued node labels.
        i, j = np.random.choice(len(nodes), 2, replace=len(nodes) < 2)
        # Only the first path is used, so stop the enumeration right there.
        path = next(
            nx.all_simple_paths(graph, source=nodes[i], target=nodes[j], cutoff=3),
            None,
        )
        samples.append(graph.subgraph(path) if path else graph)
    return samples


def _sample_trees(graph, nodes, num_samples):
    """BFS tree rooted at a random node, keeping only the tree edges."""
    samples = []
    for _ in range(num_samples):
        root = nodes[np.random.choice(len(nodes))]
        tree_edges = list(nx.bfs_edges(graph, root))
        # The node-induced subgraph of a BFS tree is the root's entire
        # connected component, non-tree edges included; restricting to the
        # BFS edges makes the sample an actual tree.
        if tree_edges:
            samples.append(graph.edge_subgraph(tree_edges))
        else:
            # Isolated root: the tree is the single node.
            samples.append(graph.subgraph([root]))
    return samples


def _sample_cycles(graph, num_samples):
    """Induced subgraphs on (at most ``num_samples``) cycles of a cycle basis."""
    cycles = nx.cycle_basis(graph)[:num_samples]
    return _fill_to([graph.subgraph(c) for c in cycles], graph, num_samples)


def _sample_cliques(graph, num_samples):
    """Induced subgraphs on the first maximal cliques with >= 3 nodes."""
    from itertools import islice

    # find_cliques lazily yields every maximal clique, of which there can be
    # exponentially many; stop after the first ``num_samples`` matches.
    large_cliques = (c for c in nx.find_cliques(graph) if len(c) >= 3)
    found = [graph.subgraph(c) for c in islice(large_cliques, num_samples)]
    return _fill_to(found, graph, num_samples)


def _sample_graphlets(graph, nodes, num_samples):
    """Induced subgraphs on 3-5 distinct random nodes (fewer for tiny graphs)."""
    samples = []
    for _ in range(num_samples):
        sample_size = min(np.random.randint(3, 6), len(nodes))
        picked = np.random.choice(len(nodes), sample_size, replace=False)
        samples.append(graph.subgraph([nodes[k] for k in picked]))
    return samples


def _sample_wheels(graph, nodes, num_samples):
    """Hub node plus a cycle among its neighbors (a wheel), from random hubs."""
    samples = []
    for idx in np.random.permutation(len(nodes)):
        if len(samples) >= num_samples:
            break
        hub = nodes[idx]
        ring = graph.subgraph([n for n in graph[hub] if n != hub])
        # Any cycle of length >= 3 among the hub's neighbors forms a wheel
        # with the hub; cycle_basis reports self-loops as 1-node cycles.
        rim = next((c for c in nx.cycle_basis(ring) if len(c) >= 3), None)
        if rim is not None:
            samples.append(graph.subgraph([hub, *rim]))
    return _fill_to(samples, graph, num_samples)


def _sample_stars(graph, nodes, num_samples):
    """Random center plus 2-5 of its neighbors (a star)."""
    neighbors = {n: [m for m in graph[n] if m != n] for n in nodes}
    centers = [n for n in nodes if len(neighbors[n]) >= 2]
    if not centers:
        return [graph] * num_samples
    samples = []
    for _ in range(num_samples):
        center = centers[np.random.choice(len(centers))]
        around = neighbors[center]
        leaf_count = min(np.random.randint(2, 6), len(around))
        leaves = [around[k] for k in np.random.choice(len(around), leaf_count, replace=False)]
        samples.append(graph.subgraph([center, *leaves]))
    return samples


def subgraph_sampling(graph, num_samples=10):
    """
    Sample subgraphs of different patterns from the given graph.

    Every pattern list has exactly ``num_samples`` entries (0 for negative
    values). When a pattern cannot be found, ``graph`` itself is used as a
    placeholder so the lists stay aligned. Sampled subgraphs are networkx
    subgraph views of ``graph``. Path, cycle, clique, graphlet, wheel and
    star samples are node-induced, so they may contain extra edges between
    the chosen nodes; tree samples keep only BFS tree edges.

    Args:
        graph (networkx.Graph): Input graph. It must be a simple undirected
            graph, since ``nx.cycle_basis`` rejects directed graphs and
            multigraphs.
        num_samples (int): Number of subgraphs to sample for each pattern.

    Returns:
        pattern_subgraphs (dict): Maps 'path', 'tree', 'cycle', 'clique',
            'graphlet', 'wheel' and 'star' to lists of sampled subgraphs.
    """
    num_samples = max(num_samples, 0)
    nodes = list(graph.nodes())
    pattern_names = ('path', 'tree', 'cycle', 'clique', 'graphlet', 'wheel', 'star')
    if not nodes:
        # Nothing can be sampled from an empty graph (np.random.choice would
        # raise), so every slot gets the placeholder.
        return {name: [graph] * num_samples for name in pattern_names}

    # Dict literals evaluate left to right, so patterns draw random numbers
    # in a fixed order.
    pattern_subgraphs = {
        'path': _sample_paths(graph, nodes, num_samples),
        'tree': _sample_trees(graph, nodes, num_samples),
        'cycle': _sample_cycles(graph, num_samples),
        'clique': _sample_cliques(graph, num_samples),
        'graphlet': _sample_graphlets(graph, nodes, num_samples),
        'wheel': _sample_wheels(graph, nodes, num_samples),
        'star': _sample_stars(graph, nodes, num_samples),
    }
    return pattern_subgraphs

def get_dataloader(dataset, batch_size, split_ratio=0.8, seed=42):
    """
    Randomly split ``dataset`` into training/validation parts and wrap each in a loader.

    Args:
        dataset (dgl.data.DGLDataset): Dataset to split and load from.
        batch_size (int): Number of graphs per batch in both loaders.
        split_ratio (float): Fraction of samples assigned to training; the
            training count is ``int(len(dataset) * split_ratio)`` and the
            remainder goes to validation.
        seed (int): Seed of the torch generator driving the random split, so
            the same seed reproduces the same partition.

    Returns:
        train_loader (dgl.dataloading.GraphDataLoader): Shuffling loader over the training part.
        val_loader (dgl.dataloading.GraphDataLoader): Non-shuffling loader over the validation part.
    """
    total = len(dataset)
    n_train = int(total * split_ratio)
    lengths = [n_train, total - n_train]

    # A dedicated, seeded generator keeps the split independent of global RNG state.
    split_rng = torch.Generator().manual_seed(seed)
    train_part, val_part = torch.utils.data.random_split(dataset, lengths, generator=split_rng)

    # Only the training loader reshuffles between epochs.
    return (
        GraphDataLoader(train_part, batch_size=batch_size, shuffle=True),
        GraphDataLoader(val_part, batch_size=batch_size),
    )