from .utility import *
from torch.utils.data import TensorDataset, DataLoader
# NOTE: this shadows torch.utils.data.DataLoader imported on the line above.
from torch_geometric.loader import DataLoader
import os
import pickle

def load_dataset(args, data_dir='./data'):
    """Load a pickled dataset selected by ``args`` and return its components.

    The file read is ``<data_dir>/<args.data>_<args.feat>_<args.lab>.pickle``.

    Args:
        args: Object exposing ``data``, ``feat`` and ``lab`` attributes; each is
            converted with ``str()`` to build the file name.
        data_dir: Directory holding the pickle files. Defaults to ``'./data'``,
            which is resolved relative to the current working directory.

    Returns:
        Tuple ``(A, X, Y, L, EIGVAL, EIGVEC)`` read from the entries with those
        keys in the unpickled object.

    Raises:
        FileNotFoundError: if the pickle file does not exist.
        KeyError: if the unpickled object is a dict missing one of the keys.
    """
    file_name = f'{args.data!s}_{args.feat!s}_{args.lab!s}.pickle'
    final_path = os.path.join(data_dir, file_name)

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(final_path, 'rb') as fh:
        dataset = pickle.load(fh)

    keys = ('A', 'X', 'Y', 'L', 'EIGVAL', 'EIGVEC')
    return tuple(dataset[key] for key in keys)

def build_data_loader(args, idx_pair, A, X, Y, L, EIGVAL, EIGVEC):
    """Build full-batch train/test loaders from index-selected tensors.

    Args:
        args: Not used by this function; kept for call-site compatibility.
        idx_pair: ``(idx_train, idx_test)``. Each entry may be anything tensor
            indexing accepts along dim 0 (integer tensor, list of ints, or a
            boolean mask).
        A, X, Y, EIGVAL, EIGVEC: Tensors indexed along their first dimension;
            the selected rows must agree in length (required by TensorDataset).
        L: Not used by this function; kept for call-site compatibility.

    Returns:
        ``(data_loader_train, data_loader_test)``. Each loader yields the whole
        split as a single batch with fields ordered
        ``(A, X, Y, EIGVEC, EIGVAL)``; only the training loader shuffles.
    """
    idx_train, idx_test = idx_pair

    def _subset(idx):
        # Field order matters to consumers: EIGVEC comes before EIGVAL.
        return TensorDataset(A[idx], X[idx], Y[idx], EIGVEC[idx], EIGVAL[idx])

    data_train = _subset(idx_train)
    data_test = _subset(idx_test)

    # `DataLoader` resolves to torch_geometric.loader.DataLoader, since that
    # import comes after (and shadows) torch.utils.data.DataLoader.
    # Full batch: size from the dataset itself rather than idx.shape[0], which
    # fails for plain lists and overcounts boolean masks. max(1, ...) keeps an
    # empty split from tripping DataLoader's positive batch_size check.
    data_loader_train = DataLoader(
        data_train, batch_size=max(1, len(data_train)), shuffle=True
    )
    data_loader_test = DataLoader(
        data_test, batch_size=max(1, len(data_test)), shuffle=False
    )

    return data_loader_train, data_loader_test