import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

import argparse
import time
import os
import random

from .tool import load_arguments
from .dataset import TableDataset, LocalTableDataset
import sys
sys.path.append('./')
sys.path.append('../')


def pai_load_data(args, local_args):
    '''
    Data loader entry point for the PAI environment (not implemented).

    The original body returned ``train_loader, test_loader, test_loader``.
    None of those names is defined in this module, so any call failed with a
    ``NameError``. This version fails fast with an explicit message instead.

    Args:
        args: run configuration (unused here).
        local_args: environment-specific configuration (unused here).

    Returns:
        Intended: a ``(train_loader, test_loader, test_loader)`` triple.
        TODO confirm once the loader construction is implemented.

    Raises:
        NotImplementedError: always, until the loaders are actually built.
    '''
    raise NotImplementedError(
        "pai_load_data: train/test loader construction is not implemented"
    )

def local_load_data(args):
    '''
    Data loader entry point for local runs (not implemented).

    The original body returned ``train_iterator, test_iterator,
    test_iterator``. None of those names is defined in this module, so any
    call failed with a ``NameError``. This version fails fast with an explicit
    message instead.

    Args:
        args: run configuration (unused here).

    Returns:
        Intended: a ``(train_iterator, test_iterator, test_iterator)`` triple.
        TODO confirm once the iterator construction is implemented.

    Raises:
        NotImplementedError: always, until the iterators are actually built.
    '''
    raise NotImplementedError(
        "local_load_data: train/test iterator construction is not implemented"
    )

def dnn_data_load_process(args, res, batch_size):
    '''
    Build the training and prediction DataLoaders over ``LocalTableDataset``.

    Two datasets are built from the same ``args``/``res`` pair:

    - one with ``if_predict=False``, wrapped as the training loader;
    - one with ``if_predict=True``, wrapped as the test loader.

    Neither loader shuffles.

    Args:
        args: configuration object, forwarded unchanged to ``LocalTableDataset``.
        res: forwarded unchanged to ``LocalTableDataset``. Its meaning is not
            visible here; presumably the raw table data — confirm against the
            dataset class.
        batch_size: batch size used by both loaders.

    Returns:
        A ``(train_iterator, test_iterator)`` tuple of ``DataLoader`` objects.
    '''
    def _reseed_worker(worker_id):
        # Every worker reseeds Python's ``random`` with the same constant 1.
        # NOTE(review): DataLoader only calls worker_init_fn when
        # num_workers > 0. num_workers is left at its default here, so this
        # currently has no effect.
        random.seed(1)

    train_loader = DataLoader(
        LocalTableDataset(args, res, if_predict=False),
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=_reseed_worker,
    )
    predict_loader = DataLoader(
        LocalTableDataset(args, res, if_predict=True),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, predict_loader

