import dgl
import torch
from dgl.data import DGLDataset
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
import torch.nn as nn
import torch.nn.functional as F


# Step 1: Define a Multi-Graph Dataset
class MultiGraphDataset(DGLDataset):
    """Synthetic dataset of small random directed graphs for graph classification.

    Every sample is a ``(graph, label)`` pair:

    * ``graph`` -- a ``dgl.graph`` with between ``min_nodes`` and
      ``max_nodes - 1`` nodes. ``edges_per_node * num_nodes`` candidate edges
      are drawn uniformly at random; self-loops are then dropped (so the final
      edge count can be lower), while duplicate edges are kept.
    * ``graph.ndata['feat']`` -- standard-normal node features of shape
      ``(num_nodes, feat_dim)``.
    * ``graph.edata['weight']`` -- uniform ``[0, 1)`` edge weights, one per
      edge.
    * ``label`` -- a 0-d int64 tensor with a value in ``[0, num_classes)``.

    The defaults reproduce the original configuration: 5-9 nodes, 5-D
    features, 2 candidate edges per node and 3 classes.

    Args:
        num_graphs: Number of graphs to generate.
        num_classes: Number of graph-level classes.
        feat_dim: Size of each node's feature vector.
        min_nodes: Smallest node count per graph (inclusive).
        max_nodes: Node-count upper bound (exclusive, as in ``torch.randint``).
        edges_per_node: Candidate edges drawn per node before self-loop removal.
        seed: If given, all randomness comes from a private ``torch.Generator``
            seeded with this value. This makes the dataset reproducible
            without touching the global RNG. ``None`` (default) uses the
            global PyTorch RNG.

    Raises:
        ValueError: If ``num_classes < 1``, ``min_nodes < 1`` or
            ``max_nodes <= min_nodes``.
    """

    def __init__(self, num_graphs=20, num_classes=3, feat_dim=5,
                 min_nodes=5, max_nodes=10, edges_per_node=2, seed=None):
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if min_nodes < 1 or max_nodes <= min_nodes:
            raise ValueError(
                "need 1 <= min_nodes < max_nodes, got "
                f"min_nodes={min_nodes}, max_nodes={max_nodes}"
            )
        # Configuration must be stored before super().__init__(), because
        # DGLDataset.__init__ runs process() to build the data.
        self.num_graphs = num_graphs
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.edges_per_node = edges_per_node
        self.seed = seed
        super().__init__(name='multi_graph')

    def process(self):
        """Generate ``num_graphs`` random graphs and their integer labels.

        Populates ``self.graphs`` (a list of graphs) and ``self.labels``
        (an int64 tensor of shape ``(num_graphs,)``).
        """
        # generator=None makes every torch.rand* call below use the global
        # RNG, in the same order as before, so unseeded behaviour is unchanged.
        gen = None
        if self.seed is not None:
            gen = torch.Generator().manual_seed(self.seed)

        graphs = []
        labels = []
        for _ in range(self.num_graphs):
            num_nodes = torch.randint(
                self.min_nodes, self.max_nodes, (1,), generator=gen
            ).item()
            num_edges = num_nodes * self.edges_per_node

            # Draw candidate endpoints, then drop self-loops.
            src = torch.randint(0, num_nodes, (num_edges,), generator=gen)
            dst = torch.randint(0, num_nodes, (num_edges,), generator=gen)
            not_self_loop = src != dst
            src, dst = src[not_self_loop], dst[not_self_loop]

            # num_nodes is passed explicitly so nodes without edges are kept.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            g.ndata['feat'] = torch.randn(
                num_nodes, self.feat_dim, generator=gen
            )
            g.edata['weight'] = torch.rand(g.num_edges(), generator=gen)

            graphs.append(g)
            # Keep a plain int so the label tensor below is built from Python
            # ints rather than by calling torch.tensor() on a list of tensors.
            labels.append(
                torch.randint(0, self.num_classes, (1,), generator=gen).item()
            )

        self.graphs = graphs
        # Explicit dtype keeps labels int64 even when num_graphs == 0
        # (torch.tensor([]) would otherwise default to float32).
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair at position ``idx``."""
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

# Step 2: Instantiate the synthetic dataset and a shuffling mini-batch loader
batch_size = 4  # number of samples per mini-batch
dataset = MultiGraphDataset(num_graphs=50)

# Iterates over `dataset` in shuffled mini-batches of `batch_size` samples.
dataloader = GraphDataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)