import os.path as osp
import pickle as pkl

import torch
import random
import numpy as np
from torch_geometric.data import InMemoryDataset, Data

class SPMotif(InMemoryDataset):
    """In-memory PyTorch Geometric dataset holding one split of SPMotif.

    The split named by ``mode`` is read from ``<raw_dir>/<mode>.npy`` and
    cached as ``SPNode_<mode>.pt`` in the processed directory. Every graph
    is stored with random node features of width 5, all-ones edge
    attributes, its label ``y``, per-node role ids ``z``, node positions
    ``pos`` and the per-edge ground-truth mask ``edge_gt_att``.

    Args:
        root: Dataset root directory (contains the raw/processed folders).
        mode: Which split to load; one of ``SPMotif.splits``.
        transform: Optional transform applied on access (handled by the
            base class).
        pre_transform: Optional transform applied to each graph before it
            is cached.
        pre_filter: Optional predicate; graphs for which it returns a
            falsy value are dropped before caching.
    """

    splits = ['train', 'val', 'test']
    # Width of the node feature matrix (and of the one-hot role encoding).
    _FEATURE_DIM = 5

    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None):
        assert mode in self.splits, f"mode must be one of {self.splits}, got {mode!r}"
        # `mode` has to be set before the base constructor runs, because the
        # base constructor can trigger process(), which reads self.mode.
        self.mode = mode
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self._processed_path(mode))

    @property
    def raw_file_names(self):
        """File names expected inside ``raw_dir``, one per split."""
        return ['train.npy', 'val.npy', 'test.npy']

    @property
    def processed_file_names(self):
        """File names written inside the processed directory, one per split."""
        return ['SPNode_train.pt', 'SPNode_val.pt', 'SPNode_test.pt']

    def _processed_path(self, mode):
        """Return the full path of the processed file belonging to ``mode``."""
        target = 'SPNode_{}.pt'.format(mode)
        return self.processed_paths[self.processed_file_names.index(target)]

    def download(self):
        """Report missing raw files; the data cannot be fetched automatically.

        Raises:
            FileNotFoundError: If any file listed in ``raw_file_names`` is
                absent from ``raw_dir``.
        """
        # Check exactly the files process() will read, so the error points at
        # the real problem instead of an unrelated path.
        expected = [osp.join(self.raw_dir, name) for name in self.raw_file_names]
        missing = [path for path in expected if not osp.exists(path)]
        if missing:
            raise FileNotFoundError(
                "raw data of `SPMotif` doesn't exist, please redownload from our github. "
                "Missing files: {}".format(', '.join(missing))
            )

    def process(self):
        """Build and cache the graphs of the current split (``self.mode``).

        Only the split selected by ``self.mode`` is processed. If its cache
        file already exists the method returns immediately: the base class
        asks for processing whenever *any* of the three processed files is
        missing, and rebuilding here would silently replace the cached
        random node features with new ones.
        """
        out_path = self._processed_path(self.mode)
        if osp.exists(out_path):
            return

        raw_path = osp.join(self.raw_dir, '{}.npy'.format(self.mode))
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # raw files obtained from a trusted source.
        raw = np.load(raw_path, allow_pickle=True).item()
        records = zip(
            raw.get('edge_index'),
            raw.get('label'),
            raw.get('ground_truth'),
            raw.get('role_id'),
            raw.get('pos'),
        )

        data_list = []
        for graph_idx, (edge_index, y, ground_truth, z, p) in enumerate(records):
            graph = self._build_graph(graph_idx, edge_index, y, ground_truth, z, p)
            if self.pre_filter is not None and not self.pre_filter(graph):
                continue
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)
            data_list.append(graph)

        print(out_path)
        print(len(data_list))
        torch.save(self.collate(data_list), out_path)

    def _build_graph(self, graph_idx, edge_index, y, ground_truth, z, p):
        """Assemble the ``Data`` object for a single raw graph record."""
        # Node ids are treated as the contiguous range 0..max(edge_index), so
        # ids that appear in no edge but lie below the maximum still count.
        num_nodes = int(edge_index.max()) + 1

        # When the role-id sequence and the node count disagree, keep the
        # trailing `num_nodes` role ids.
        if len(z) != num_nodes:
            z = z[len(z) - num_nodes:]

        # NOTE(review): this one-hot role encoding is discarded right below in
        # favour of random features, so at present it only acts as a check
        # that `z` supplies one in-range role id per node (the indexing raises
        # otherwise). Confirm which of the two feature sets is intended.
        one_hot = torch.zeros(num_nodes, self._FEATURE_DIM)
        one_hot[list(range(num_nodes)), z] = 1

        x = torch.rand((num_nodes, self._FEATURE_DIM))
        # One constant attribute per edge; assumes edge_index is (2, num_edges)
        # -- TODO confirm against the raw files.
        edge_attr = torch.ones(edge_index.size(1), 1)
        y = torch.tensor(y, dtype=torch.long).unsqueeze(dim=0)
        return Data(x=x, y=y, z=z,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    pos=p,
                    edge_gt_att=torch.LongTensor(ground_truth),
                    name=f'SPNode-{self.mode}-{graph_idx}', idx=graph_idx)