import torch
from torch.utils.data import DataLoader

from loaders.datasets.param_datasets import PairedParamDataset
from loaders.datasets.param_datasets import UnpairedParamDataset


def load_paired_params(data_dir, param_names=None, src_tgt_dict=None, batch_size=None,
                       *, num_workers=4, shuffle=True):
    """Build a DataLoader over a ``PairedParamDataset`` rooted at *data_dir*.

    Prints a short summary (path and dataset length) before returning.

    Args:
        data_dir: Root directory, passed to the dataset as ``param_root_dir``.
        param_names: Forwarded unchanged to ``PairedParamDataset``.
        src_tgt_dict: Forwarded unchanged to ``PairedParamDataset``
            (presumably a source->target pairing map -- confirm against the
            dataset implementation).
        batch_size: Passed to ``DataLoader``. Note that ``None`` (the default)
            disables PyTorch's automatic batching, so the loader yields
            individual samples rather than batches.
        num_workers: Number of DataLoader worker processes. Defaults to 4,
            the value previously hard-coded here. Use 0 to load in the main
            process (useful for debugging).
        shuffle: Whether the DataLoader reshuffles the data every epoch.
            Defaults to True, the value previously hard-coded here.

    Returns:
        A ``torch.utils.data.DataLoader`` over the paired dataset.
    """
    # No transform is applied to the loaded params.
    data_set = PairedParamDataset(param_root_dir=data_dir,
                                  param_names=param_names,
                                  src_tgt_dict=src_tgt_dict,
                                  transform=None)

    print('-' * 50)
    print('Paired Param Path:', data_dir)
    print('Param Pair Number', len(data_set))
    print('-' * 50)

    return DataLoader(dataset=data_set,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      shuffle=shuffle)


def load_unpaired_params(data_dir, param_names=None, batch_size=None, types=None,
                         *, num_workers=4, shuffle=True):
    """Build a DataLoader over an ``UnpairedParamDataset`` rooted at *data_dir*.

    Prints a short summary (path, types, names and dataset length) before
    returning.

    Args:
        data_dir: Root directory, passed to the dataset as ``param_root_dir``.
        param_names: Forwarded unchanged to ``UnpairedParamDataset``.
        batch_size: Passed to ``DataLoader``. Note that ``None`` (the default)
            disables PyTorch's automatic batching, so the loader yields
            individual samples rather than batches.
        types: Forwarded unchanged to ``UnpairedParamDataset``; its meaning is
            defined by that dataset (TODO confirm).
        num_workers: Number of DataLoader worker processes. Defaults to 4,
            the value previously hard-coded here. Use 0 to load in the main
            process (useful for debugging).
        shuffle: Whether the DataLoader reshuffles the data every epoch.
            Defaults to True, the value previously hard-coded here.

    Returns:
        A ``torch.utils.data.DataLoader`` over the unpaired dataset.
    """
    # No transform is applied to the loaded params.
    data_set = UnpairedParamDataset(param_root_dir=data_dir,
                                    param_names=param_names,
                                    transform=None,
                                    types=types)

    print('-' * 50)
    print('Unpaired Param Path:', data_dir)
    print('Param Types:', types)
    print('Param Names:', param_names)
    print('Param Number', len(data_set))
    print('-' * 50)

    return DataLoader(dataset=data_set,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      shuffle=shuffle)
