import os
import numpy as np

import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class _VectorDataset(Dataset):
    """Map-style dataset whose items are the rows of a tensor.

    Defined at module level (not inside ``load_data``) so instances can be
    pickled, which ``DataLoader`` needs to ship the dataset to worker
    processes when ``num_workers > 0`` under the ``spawn`` start method.
    """

    def __init__(self, vecs):
        super().__init__()
        # Tensor whose first dimension indexes samples.
        self.vecs = vecs

    def __getitem__(self, index):
        return self.vecs[index]

    def __len__(self):
        return len(self.vecs)


def load_data(data_dir, batch_size, *, shuffle=True, **loader_kwargs):
    """Load numeric rows from ``<data_dir>/data.txt`` into a DataLoader.

    The file is parsed with ``np.loadtxt`` (whitespace-delimited by default)
    and converted to a float32 tensor. Each row becomes one sample. A
    one-dimensional result (a single-column file, a single-row file, or a
    single value) is reshaped to ``(N, 1)``, i.e. one feature per sample.

    Args:
        data_dir: Directory containing ``data.txt``.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch (default True,
            matching the previous hard-coded behaviour).
        **loader_kwargs: Extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``).

    Returns:
        A ``DataLoader`` yielding float32 tensors of shape
        ``(batch, features)``.

    Raises:
        FileNotFoundError: If ``data.txt`` does not exist.
        ValueError: If the file contains no samples, or cannot be parsed
            as numbers.
    """
    path = os.path.join(data_dir, 'data.txt')
    # ndmin=1 keeps a single-value file as shape (1,) rather than a 0-d
    # array; a 0-d tensor would make len() on the dataset raise TypeError.
    data = np.loadtxt(path, ndmin=1)
    data = torch.as_tensor(data, dtype=torch.float32)
    if data.dim() == 1:
        # (N,) -> (N, 1): treat each value as a sample with one feature.
        data = data.unsqueeze(-1)
    if len(data) == 0:
        # Fail with a clear message instead of an opaque sampler error.
        raise ValueError(f'no samples found in {path}')
    return DataLoader(
        _VectorDataset(data),
        batch_size,
        shuffle=shuffle,
        **loader_kwargs,
    )
