import datetime
import dgl
import errno
import numpy as np
import os
import pickle
import random
import torch
from scipy.io import loadmat
from dgl.data.utils import download, get_download_dir, _get_dgl_url
from pprint import pprint
from scipy import sparse
from scipy import io as sio
from torch_geometric.data import InMemoryDataset, Data, download_url
def get_binary_mask(total_size, indices):
    """Return a uint8 tensor of size ``total_size`` with ones at ``indices``.

    Every position not selected by ``indices`` is zero.
    """
    # Allocate the result directly as uint8 rather than casting afterwards.
    flags = torch.zeros(total_size, dtype=torch.uint8)
    flags[indices] = 1
    return flags

# Per-dataset layout of the .mat files:
#   (file path, feature key, meta-path adjacency keys merged into one graph).
# Every file is also expected to provide 'label', 'train_idx', 'val_idx' and
# 'test_idx' entries, which load_mat reads for all datasets.
_DATASET_SPECS = {
    'ACM': ("data/ACM3025.mat", 'feature', ('PAP', 'PLP')),
    'DBLP': ("data/DBLP4057_GAT_with_idx.mat", 'features',
             ('net_APTPA', 'net_APCPA')),
    'IMDB': ("data/imdb5k.mat", 'feature', ('MAM', 'MDM', 'MYM')),
}


def load_mat(dataset):
    """Load one of the bundled .mat datasets as a single merged DGL graph.

    The meta-path adjacency matrices of the dataset are each reduced by the
    identity matrix and summed; the non-zero entries of the result become the
    edges of the returned graph.

    Args:
        dataset: One of ``'ACM'``, ``'DBLP'`` or ``'IMDB'``.

    Returns:
        A 10-tuple ``(graph, features, labels, num_classes, train_idx,
        val_idx, test_idx, train_mask, val_mask, test_mask)`` where

        * ``graph`` is the graph built by ``dgl.from_scipy`` from the merged
          adjacency matrix,
        * ``features`` is the feature matrix as a float tensor,
        * ``labels`` holds the column index of each non-zero entry of the
          label matrix (presumably one class id per node, assuming the label
          matrix is one-hot -- verify against the data files),
        * ``num_classes`` is the number of columns of the label matrix,
        * ``*_idx`` are long tensors of node indices for each split,
        * ``*_mask`` are the matching masks from ``get_binary_mask``.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    try:
        path, feature_key, adj_keys = _DATASET_SPECS[dataset]
    except (KeyError, TypeError):
        # Fail early with a clear message instead of an UnboundLocalError
        # further down when no branch matched.
        raise ValueError(
            f"Unknown dataset {dataset!r}; "
            f"expected one of {sorted(_DATASET_SPECS)}"
        ) from None

    data = loadmat(path)

    label_matrix = torch.from_numpy(data['label']).long()
    features = torch.from_numpy(data[feature_key]).float()
    num_classes = label_matrix.shape[1]
    labels = label_matrix.nonzero()[:, 1]

    # Build the dense identity once; it is subtracted from every meta-path
    # adjacency matrix before they are accumulated.
    identity = np.eye(data['label'].shape[0])
    first_key, *other_keys = adj_keys
    merged = data[first_key] - identity
    for key in other_keys:
        # Same left-to-right evaluation order as writing the sum out by hand.
        merged = merged + data[key] - identity
    graph = dgl.from_scipy(sparse.csr_matrix(merged))

    # squeeze(0) drops the leading axis when it has length 1.
    train_idx = torch.from_numpy(data['train_idx']).long().squeeze(0)
    val_idx = torch.from_numpy(data['val_idx']).long().squeeze(0)
    test_idx = torch.from_numpy(data['test_idx']).long().squeeze(0)

    num_nodes = graph.number_of_nodes()
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    return graph, features, labels, num_classes, train_idx, val_idx, test_idx, \
           train_mask, val_mask, test_mask

