"""Implement regression data."""

import torch
import torchvision.transforms as transforms

from .consts import *
from .synthetic_regression_data import Cosine, Integration, Differentiation, Blur, Downsampling
from .synthetic_regression_data import BlurSinc, DownsamplingSinc, BlurPiecewise, DownsamplingPiecewise
from .denoising_datasets import BSDS


def _bsds_splits(data_path, noise_level, *, grayscale, sigma, downsampling, valid_split, valid_patch_size):
    """Return ``(trainset, metaset, validset)`` BSDS datasets that share one degradation setting.

    The train and meta sets are two separately constructed ``BSDS`` objects over the ``train``
    split with ``patch_size=40``; the validation set uses ``valid_split`` and ``valid_patch_size``.
    All three are created with ``download=True`` and a ``ToTensor`` transform.
    """
    settings = dict(download=True, grayscale=grayscale, noise_level=noise_level, sigma=sigma,
                    downsampling=downsampling)
    trainset = BSDS(data_path, split='train', transform=transforms.ToTensor(), patch_size=40, **settings)
    metaset = BSDS(data_path, split='train', transform=transforms.ToTensor(), patch_size=40, **settings)
    validset = BSDS(data_path, split=valid_split, transform=transforms.ToTensor(),
                    patch_size=valid_patch_size, **settings)
    return trainset, metaset, validset


def _num_workers():
    """Return the number of DataLoader worker processes (0 means loading in the main process)."""
    if not MULTITHREAD_DATAPROCESSING:
        return 0
    num_threads = torch.get_num_threads()
    # With a single torch thread stay in-process; otherwise cap at the configured maximum.
    return min(num_threads, MULTITHREAD_DATAPROCESSING) if num_threads > 1 else 0


def _clamp_batch_size(requested, dataset_length):
    """Clamp ``requested`` to ``dataset_length``, but never below 1 (DataLoader rejects 0)."""
    return max(1, min(requested, dataset_length))


def construct_dataloaders(dataset, batch_size, noise_level=None, data_path='~/data'):
    """Build the train, meta and validation dataloaders for a regression task.

    Synthetic 1-D tasks (``Cosine``, ``Integration``, ``Differentiation``, ``Blur``,
    ``Downsampling``, ``BlurSinc``, ``DownsamplingSinc``, ``BlurPiecewise``,
    ``DownsamplingPiecewise``) use one dataset object for all three splits. The ``BSDS-*``
    tasks build separate train/meta sets from the BSDS ``train`` split and a validation set
    from a test split.

    Args:
        dataset: Task name, e.g. ``'Cosine'`` or ``'BSDS-Denoising'``.
        batch_size: Batch size of the train and meta loaders; the validation loader uses
            ``batch_size // 4``. Each is clamped to its dataset's length and to at least 1.
        noise_level: Forwarded unchanged to the dataset constructors.
        data_path: Root directory passed to ``BSDS``; unused by the synthetic tasks.

    Returns:
        Tuple ``(trainloader, metaloader, validloader)`` of ``torch.utils.data.DataLoader``.
        The train and meta loaders shuffle and drop the last incomplete batch; the
        validation loader does neither.

    Raises:
        ValueError: If ``dataset`` is not a known task name.
    """
    wave_length = 50
    # Round the synthetic epoch down to whole batches, but keep at least one batch so that a
    # batch_size above 2500 does not yield an empty dataset (and hence a zero batch size).
    epoch_size = max(2500 // batch_size, 1) * batch_size
    synthetic = dict(noise_level=noise_level, wave_length=wave_length)

    # Synthetic tasks share a single dataset object between the train, meta and validation splits.
    if dataset == 'Cosine':
        trainset = validset = metaset = Cosine(epoch_size, **synthetic)
    elif dataset == 'Integration':
        trainset = validset = metaset = Integration(epoch_size, **synthetic)
    elif dataset == 'Differentiation':
        trainset = validset = metaset = Differentiation(epoch_size, **synthetic)
    elif dataset == 'Blur':
        trainset = validset = metaset = Blur(epoch_size, sigma=20, **synthetic)
    elif dataset == 'Downsampling':
        trainset = validset = metaset = Downsampling(epoch_size, sigma=20, downsampling=4, **synthetic)
    elif dataset == 'BlurSinc':
        trainset = validset = metaset = BlurSinc(epoch_size, sigma=20, **synthetic)
    elif dataset == 'DownsamplingSinc':
        trainset = validset = metaset = DownsamplingSinc(epoch_size, sigma=20, downsampling=4, **synthetic)
    elif dataset == 'BlurPiecewise':
        trainset = validset = metaset = BlurPiecewise(epoch_size, sigma=20, **synthetic)
    elif dataset == 'DownsamplingPiecewise':
        trainset = validset = metaset = DownsamplingPiecewise(epoch_size, sigma=20, downsampling=4, **synthetic)
    elif dataset == 'BSDS-Blur':
        trainset, metaset, validset = _bsds_splits(data_path, noise_level, grayscale=False, sigma=4,
                                                   downsampling=1, valid_split='test', valid_patch_size=40)
    elif dataset == 'BSDS-Denoising':  # The "classic" denoising version
        trainset, metaset, validset = _bsds_splits(data_path, noise_level, grayscale=True, sigma=0,
                                                   downsampling=1, valid_split='test68', valid_patch_size=312)
    elif dataset == 'BSDS-Downsampling':
        trainset, metaset, validset = _bsds_splits(data_path, noise_level, grayscale=False, sigma=4,
                                                   downsampling=4, valid_split='test', valid_patch_size=256)
    else:
        raise ValueError(f'Invalid dataset {dataset}.')

    num_workers = _num_workers()

    # Each loader's batch size is clamped against its *own* dataset's length.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=_clamp_batch_size(batch_size, len(trainset)),
                                              shuffle=True, drop_last=True,
                                              num_workers=num_workers, pin_memory=PIN_MEMORY)
    metaloader = torch.utils.data.DataLoader(metaset, batch_size=_clamp_batch_size(batch_size, len(metaset)),
                                             shuffle=True, drop_last=True,
                                             num_workers=num_workers, pin_memory=PIN_MEMORY)
    # Validation uses a quarter of the training batch size; the clamp keeps batch_size < 4 from giving 0.
    validloader = torch.utils.data.DataLoader(validset,
                                              batch_size=_clamp_batch_size(batch_size // 4, len(validset)),
                                              shuffle=False, drop_last=False,
                                              num_workers=num_workers, pin_memory=PIN_MEMORY)

    return trainloader, metaloader, validloader
