import os
import os.path as osp
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, download_url, Data

class HeteroDataset(InMemoryDataset):
    def __init__(self, root, name, transform=None, pre_transform=None):
        self.root = root
        self.name = name.lower()
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return f'{self.name}.npz'

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        data = np.load(self.raw_paths[0])
        x = torch.tensor(data['node_features'])
        y = torch.tensor(data['node_labels'])
        edge_index = torch.tensor(data['edges']).T
        train_mask = torch.tensor(data['train_masks']).to(torch.bool).T
        val_mask = torch.tensor(data['val_masks']).to(torch.bool).T
        test_mask = torch.tensor(data['test_masks']).to(torch.bool).T
        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                        val_mask=val_mask, test_mask=test_mask)
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])