"""
    Example script: build the denoised CIFAR-10 train/test data loaders
    used to train your models.
"""
# torch...
import torch

# custom libs
from utils.datasets import load_denoised_dataset


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
_use_cuda = torch.cuda.is_available()


# ------------------------------------------------------------------------------
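#   Load the denoised data. The code below expects the file at
#   : datasets/denoised/cifar10/denoised_w_sigma_1.0.pt
#   If your copy lives elsewhere (e.g. under datasets/denoised/cifar10/clean/),
#   update `denoised_data` accordingly.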
# ------------------------------------------------------------------------------
# Extra DataLoader options: worker processes and pinned host memory are
# requested only when CUDA is available; otherwise an empty dict is passed
# and the loader's own defaults apply.
if _use_cuda:
    kwargs = dict(num_workers=8, pin_memory=True)
else:
    kwargs = {}

# Location of the saved denoised dataset (the filename suggests it was
# produced with sigma = 1.0 -- confirm against the denoising script).
denoised_data = 'datasets/denoised/cifar10/denoised_w_sigma_1.0.pt'

# load_denoised_dataset takes its arguments positionally, in this order:
#   dataset name, batch size,
#   augmentation flag (random crop + horizontal flip),
#   normalization flag, denoised data file, extra DataLoader kwargs.
# Parentheses already allow the call to span lines, so no backslash is needed.
train_loader, test_loader = load_denoised_dataset(
    'cifar10',
    256,
    True,
    True,
    denoised_data,
    kwargs,
)

# load_denoised_dataset returns a pair of PyTorch DataLoaders (train, test);
#   the loader and augmentation code live in utils/datasets.py