import os
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset

class PaperDataset(InMemoryDataset):
    """In-memory graph dataset built from one pickled ``.npy`` dictionary.

    The file must hold a dict with per-graph sequences under the keys
    ``'features'`` (node feature rows) and ``'role_id'`` (per-node integer
    labels). ``'edge_index'`` (per-graph edge lists) is optional; when it is
    absent every graph is created without edges. Each edge receives a
    constant attribute of ``1.0``.
    """

    def __init__(self, npy_path, transform=None):
        """Load every graph from ``npy_path`` and collate them in memory.

        Args:
            npy_path: Path to the ``.npy`` file; its directory is used as the
                dataset ``root``.
            transform: Optional per-sample transform forwarded to the base
                class.
        """
        # Must be set before the base initializer runs: raw_file_names and
        # processed_file_names read it and the base class may query them.
        self.npy_path = os.path.abspath(npy_path)
        root = os.path.dirname(self.npy_path)
        super().__init__(root=root, transform=transform)

        data, slices = self.load_npy(self.npy_path)
        self.data, self.slices = data, slices
        # Read the feature width from the local object rather than going
        # back through ``self.data``.
        self._num_features = data.x.size(1) if data.x is not None else 0

    @property
    def raw_file_names(self):
        """Return the base name of the source ``.npy`` file as a one-item list."""
        return [os.path.basename(self.npy_path)]

    @property
    def processed_file_names(self):
        """Return the source file's base name with its extension swapped for ``.pt``."""
        # splitext only touches the real extension; str.replace would also
        # rewrite a '.npy' occurring in the middle of the name.
        stem, _ext = os.path.splitext(os.path.basename(self.npy_path))
        return [stem + '.pt']

    def download(self):
        """Do nothing: the data is read from the local file given at construction."""
        pass

    @staticmethod
    def _as_edge_index(edges):
        """Convert one graph's edge list to a contiguous ``(2, E)`` long tensor."""
        if isinstance(edges, torch.Tensor):
            ei = edges.long()
        else:
            ei = torch.tensor(edges, dtype=torch.long)

        # An empty edge list (e.g. ``[]``) becomes a 1-D tensor of shape (0,);
        # normalise every empty input so that ``size(1)`` is well defined.
        if ei.numel() == 0:
            return torch.empty((2, 0), dtype=torch.long)

        if ei.dim() != 2 or 2 not in ei.shape:
            raise ValueError(
                "edge_index must have shape (2, E) or (E, 2), got {}".format(
                    tuple(ei.shape)
                )
            )
        # (E, 2) input is transposed. A (2, 2) array is ambiguous and is
        # taken as already being in (2, E) layout.
        if ei.size(0) != 2:
            ei = ei.t()
        return ei.contiguous()

    def load_npy(self, npy_path):
        """Read the ``.npy`` dictionary and collate its graphs.

        Args:
            npy_path: Path of the file to read.

        Returns:
            Whatever ``self.collate`` returns for the list of built graphs;
            ``__init__`` unpacks it into ``(data, slices)``.

        Raises:
            KeyError: If ``'features'`` or ``'role_id'`` is missing.
            ValueError: If an edge list has an unusable shape or the file
                yields no graphs at all.
        """
        # NOTE(security): allow_pickle=True unpickles arbitrary objects and
        # can execute code; only load files from a trusted source.
        np_data = np.load(npy_path, allow_pickle=True).item()
        features_list = np_data.get('features')
        roles_list = np_data.get('role_id')

        if features_list is None or roles_list is None:
            raise KeyError("NPY file must contain 'features' and 'role_id' keys")

        edge_indices = np_data.get('edge_index')
        if edge_indices is None:
            # No edge information stored: build every graph without edges
            # instead of silently producing an empty dataset.
            edge_indices = [[] for _ in range(len(features_list))]

        data_list = []
        # zip stops at the shortest of the three sequences, so surplus
        # entries in a longer one are ignored.
        for edges, feats, roles in zip(edge_indices, features_list, roles_list):
            ei = self._as_edge_index(edges)

            x = torch.tensor(feats, dtype=torch.float)
            if x.dim() == 1:
                # One scalar per node: treat it as a single feature column so
                # that ``x.size(1)`` is defined.
                x = x.unsqueeze(-1)
            num_nodes = x.size(0)

            # Force exactly one label per node. Padding uses 0, which cannot
            # be told apart from a genuine role id of 0.
            y = torch.tensor(roles, dtype=torch.long)
            missing = num_nodes - y.numel()
            if missing > 0:
                y = torch.cat([y, torch.zeros(missing, dtype=torch.long)], dim=0)
            elif missing < 0:
                y = y[:num_nodes]

            edge_attr = torch.ones(ei.size(1), dtype=torch.float)
            data_list.append(Data(x=x, edge_index=ei, edge_attr=edge_attr, y=y))

        if not data_list:
            raise ValueError("No graphs could be built from {!r}".format(npy_path))

        return self.collate(data_list)

    def len(self) -> int:
        """Return the number of graphs held by the dataset."""
        # NOTE(review): PyG 2.x appears to return ``slices=None`` when a
        # single graph is collated — confirm against the installed version.
        if self.slices is None:
            return 1
        if hasattr(self.data, 'num_graphs'):
            return self.data.num_graphs
        return len(self.slices['y']) - 1

    @property
    def num_node_features(self):
        """Return the node feature width recorded when the file was loaded."""
        return self._num_features
