from torch.utils import data
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

from .FF import FFDataset
from .utils import collate_fn
from .CustomFF import AugmentedDataset
from .CustomFF import CLASSES

# Public API of this package: the loader factory plus the ``CLASSES``
# constant re-exported from ``.CustomFF``.
__all__ = [
    'get_loaders',
    'CLASSES',
]


def get_loaders(cnf):
    """Build the data loaders selected by ``cnf.dataset.dataset_name``.

    Supported names:
      * ``'FF++'``  -> loaders over the ``FFDataset`` train and test splits.
      * ``'augFF'`` -> a loader over ``AugmentedDataset``.

    Args:
        cnf: Config object; ``cnf.dataset.dataset_name`` selects the dataset
            and the whole object is forwarded to the specific builder.

    Returns:
        tuple: ``(train_loader, test_loader)`` as produced by the builder.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    name = cnf.dataset.dataset_name
    if name == 'FF++':
        return _get_FF_loaders(cnf)
    if name == 'augFF':
        return _get_custom_loaders(cnf)
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an opaque unpacking error at the call site.
    raise ValueError(
        f"Unknown dataset_name {name!r}; expected 'FF++' or 'augFF'"
    )


def _get_custom_loaders(cnf):
    """Build a single loader over ``AugmentedDataset`` and return it twice.

    The same ``DataLoader`` object fills both slots of the returned pair, so
    callers that unpack ``(train_loader, test_loader)`` iterate one dataset.

    Args:
        cnf: Config object; reads ``dataset.root``, ``debug``, ``DDP`` and
            ``training.batch_size``.

    Returns:
        tuple: ``(loader, loader)`` — the identical DataLoader twice.
    """
    dataset = AugmentedDataset(
        img_dir=cnf.dataset.root,
        transform=None,
        target_transform=None,
    )
    if cnf.debug:
        # Keep only the first 10 samples for a quick debug run.
        dataset.dataset = dataset.dataset[:10]

    sampler = DistributedSampler(dataset) if cnf.DDP else None

    loader = data.DataLoader(
        dataset,
        batch_size=cnf.training.batch_size,
        collate_fn=collate_fn,
        num_workers=2,
        sampler=sampler,
    )
    return loader, loader


def _get_FF_loaders(cnf):
    """Build loaders over the ``FFDataset`` train and test splits.

    Args:
        cnf: Config object; reads ``dataset.root``, ``dataset.binary``,
            ``debug``, ``DDP`` and ``training.batch_size``.

    Returns:
        tuple: ``(train_loader, test_loader)`` — DataLoaders over the
        ``'train'`` and ``'test'`` splits respectively.
    """
    # NOTE(review): ``detailed_lbl`` receives ``cnf.dataset.binary`` as-is;
    # the names read as if an inversion might be intended — confirm against
    # FFDataset before relying on it.
    train_transform = None
    train_dataset = FFDataset(
        root=cnf.dataset.root,
        split='train',
        transform=train_transform,
        detailed_lbl=cnf.dataset.binary
    )
    test_transform = None
    test_dataset = FFDataset(
        root=cnf.dataset.root,
        split='test',
        transform=test_transform,
        detailed_lbl=cnf.dataset.binary
    )

    if cnf.debug:
        # Truncate both splits to 10 samples for a quick debug run.
        train_dataset.dataset = train_dataset.dataset[:10]
        test_dataset.dataset = test_dataset.dataset[:10]

    if cnf.DDP:
        # DistributedSampler shuffles by default: wanted for training, but
        # evaluation does not need a shuffled order.
        train_sampler = DistributedSampler(train_dataset)
        test_sampler = DistributedSampler(test_dataset, shuffle=False)
    else:
        train_sampler = None
        test_sampler = None

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=cnf.training.batch_size,
        collate_fn=collate_fn,
        num_workers=2,
        sampler=train_sampler,
        # Shuffle in single-process runs too, matching the DDP path.
        # DataLoader rejects shuffle=True combined with a sampler, hence
        # the condition.
        shuffle=train_sampler is None,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=cnf.training.batch_size,
        collate_fn=collate_fn,
        num_workers=2,
        sampler=test_sampler,
    )
    return train_loader, test_loader
