import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset

class NPYGraphDataset(InMemoryDataset):
    """In-memory PyG dataset built from a pickled dict stored in a ``.npy`` file.

    The file must hold a ``dict`` (loaded via ``np.load(...).item()``) with
    these per-graph sequences, all of the same length:

    * ``'edge_index'`` -- one edge array per graph, shaped ``(2, E)`` or
      ``(E, 2)``; an empty array means the graph has no edges;
    * ``'label'`` -- one integer label per graph (stored as ``y``, shape ``(1,)``);
    * ``'node_features'`` -- optional; one node-feature array per graph,
      shaped ``(N, F)`` or ``(N,)`` (the latter is treated as ``F == 1``).

    Args:
        npy_path: Path to the ``.npy`` file.
        task: Stored as ``self.task``; not read anywhere in this class.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, task='graph', transform=None):
        # No root directory: all data comes from ``npy_path`` below.
        super().__init__(None, transform)
        self.npy_path = npy_path
        self.task = task

        data, slices = self.load_npy(npy_path)
        self.data, self.slices = data, slices

        # Read ``x`` from the local collated object instead of ``self.data``:
        # newer PyG versions emit a warning when the ``data`` property of an
        # InMemoryDataset is read directly.
        self._num_features = data.x.size(1) if data.x is not None else 0

    @property
    def num_features(self):
        """Number of features per node; 0 when the file had no node features."""
        return self._num_features

    def load_npy(self, npy_path):
        """Load every graph from ``npy_path`` and collate them.

        Args:
            npy_path: Path to a ``.npy`` file holding a pickled dict (see the
                class docstring for the expected keys).

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate``.

        Raises:
            KeyError: If ``'edge_index'`` or ``'label'`` is missing.
            ValueError: If the file holds no graphs, the per-graph sequences
                differ in length, or an edge array cannot be put in
                ``(2, E)`` form.
        """
        # SECURITY: allow_pickle=True can execute arbitrary code while
        # unpickling; only load files from trusted sources.
        np_data = np.load(npy_path, allow_pickle=True).item()

        edge_indices = np_data['edge_index']
        labels = np_data['label']
        # Node features are optional; without them each graph gets x=None.
        node_features = np_data.get('node_features')

        num_graphs = len(edge_indices)
        if num_graphs == 0:
            raise ValueError(f"{npy_path!r} contains no graphs")
        # zip() would silently drop trailing graphs on a length mismatch.
        if len(labels) != num_graphs:
            raise ValueError(
                f"'label' has {len(labels)} entries but 'edge_index' has "
                f"{num_graphs}"
            )
        if node_features is None:
            node_features = [None] * num_graphs
        elif len(node_features) != num_graphs:
            raise ValueError(
                f"'node_features' has {len(node_features)} entries but "
                f"'edge_index' has {num_graphs}"
            )

        data_list = []
        for idx, (feats, edges, label) in enumerate(
            zip(node_features, edge_indices, labels)
        ):
            x = None
            if feats is not None:
                x = torch.tensor(feats, dtype=torch.float)
                if x.dim() == 1:
                    # One scalar per node -> (N, 1), so x.size(1) is defined.
                    x = x.unsqueeze(-1)
            y = torch.tensor([label], dtype=torch.long)
            data_list.append(
                Data(x=x, edge_index=self._to_edge_index(edges, idx), y=y)
            )

        return self.collate(data_list)

    @staticmethod
    def _to_edge_index(edges, graph_idx):
        """Convert one graph's edge array into a contiguous ``(2, E)`` long tensor."""
        edge_index = torch.tensor(edges, dtype=torch.long)
        if edge_index.numel() == 0:
            # Any empty array (e.g. shape (0,) or (0, 2)) means "no edges".
            return edge_index.reshape(2, 0)
        if edge_index.dim() == 2 and edge_index.size(0) != 2:
            # Stored as (E, 2) pairs; transpose to (2, E).
            edge_index = edge_index.t()
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError(
                f"graph {graph_idx}: edge_index has shape "
                f"{tuple(edge_index.shape)}; expected (2, E) or (E, 2)"
            )
        return edge_index.contiguous()
