from torch.utils.data import Dataset
from torch_geometric.utils import to_dense_adj
import torch









class GraphDataset(Dataset):
    """Dataset of ``(features, dense_adjacency)`` pairs built once up front.

    One sample is created per index along the first dimension of ``X``.
    All samples reference the same dense adjacency tensor object, obtained
    by passing ``edge_index`` through ``to_dense_adj``.
    """

    def __init__(self, X, edge_index):
        # num_graphs must be set first: generate_graphs reads it from self.
        self.num_graphs = X.size(0)
        self.graphs = self.generate_graphs(X, edge_index)

    def __len__(self):
        """Return the number of samples (the size of ``X`` along dim 0)."""
        return self.num_graphs

    def __getitem__(self, idx):
        """Return the precomputed sample stored at position ``idx``."""
        return self.graphs[idx]

    def generate_graphs(self, X, edge_index):
        """Build the list of ``(X[i].T, dense_adjacency)`` tuples.

        Args:
            X: Tensor indexed along dim 0, one slice per sample. Each slice
                is transposed with ``.T`` (presumably 2-D slices — confirm
                against the caller).
            edge_index: Edge description accepted by ``to_dense_adj``.

        Returns:
            A list with ``self.num_graphs`` tuples. The adjacency is
            converted once and the same tensor is reused in every tuple.
        """
        dense_adj = to_dense_adj(edge_index)
        return [
            (X[graph_id].T, dense_adj)
            for graph_id in range(self.num_graphs)
        ]




class MyData:
    """Lightweight container bundling features, edges and a batch vector.

    The ``batch`` attribute is derived from ``x`` at construction time and
    is allocated on the same device as ``x``.
    """

    def __init__(self, x, edge_index):
        # x must have at least two dimensions: create_batch reads x.size(0)
        # and x.size(1).
        self.x = x
        self.edge_index = edge_index
        self.batch = self.create_batch(x)

    def create_batch(self, x):
        """Return a 1-D index vector assigning each dim-1 entry to its dim-0 slot.

        For ``x`` with ``x.size(0) == G`` and ``x.size(1) == N`` the result is
        ``[0]*N + [1]*N + ... + [G-1]*N`` as an integer tensor of length
        ``G * N``. Dim 0 is treated as the graph axis; dim 1 is presumably
        the per-graph node count — confirm against the caller.

        Args:
            x: Tensor with at least two dimensions.

        Returns:
            Integer tensor of length ``x.size(0) * x.size(1)`` located on
            ``x.device``.
        """
        num_graphs = x.size(0)
        entries_per_graph = x.size(1)
        # Allocate on x's device so x and batch never end up on different
        # devices when x was already moved before construction.
        graph_ids = torch.arange(num_graphs, device=x.device)
        return graph_ids.repeat_interleave(entries_per_graph)

    def to(self, device):
        """Move ``x``, ``edge_index`` and ``batch`` to ``device`` in place.

        Args:
            device: Any target accepted by ``torch.Tensor.to``.

        Returns:
            This same object, so calls can be chained.
        """
        self.x = self.x.to(device)
        self.edge_index = self.edge_index.to(device)
        self.batch = self.batch.to(device)
        return self