import os
import torch

from src.dataloaders.imbalance.minst_biased import download_mnist_biased, MNISTBias
from src.dataloaders.imbalance.mnist_long_tailed import download_mnist_longtailed, IMAGE_PATH, MNISTLongTailDataset



def color_mnist_longtail():
    """Return the (train, test) datasets for the color long-tailed MNIST variant.

    If ``IMAGE_PATH + '/color_mnist_longtailed'`` does not exist,
    ``download_mnist_longtailed(option='color')`` is called first
    (presumably it creates that directory and its ``.pt`` files -- confirm
    against the helper). ``train.pt`` and ``test.pt`` are then read with
    ``torch.load`` and each payload is wrapped in ``MNISTLongTailDataset``.

    Returns:
        A ``(trainset, testset)`` tuple of ``MNISTLongTailDataset`` objects.
    """
    data_dir = IMAGE_PATH + '/color_mnist_longtailed'

    # Only the directory is checked, not the individual split files.
    if not os.path.exists(data_dir):
        download_mnist_longtailed(option='color')

    # Train is loaded and wrapped before test, in that order.
    wrapped = [
        MNISTLongTailDataset(dataset=torch.load(data_dir + split_file))
        for split_file in ('/train.pt', '/test.pt')
    ]
    trainset, testset = wrapped
    return trainset, testset


def color_rot_mnist_longtail():
    """Return the (train, test) datasets for the color+rotation long-tailed MNIST variant.

    If ``IMAGE_PATH + '/color_rot_mnist_longtailed'`` does not exist,
    ``download_mnist_longtailed(option='both')`` is called first
    (presumably it populates that directory -- confirm against the helper).
    ``train.pt`` and ``test.pt`` are then read with ``torch.load`` and each
    payload is wrapped in ``MNISTLongTailDataset``.

    Returns:
        A ``(trainset, testset)`` tuple of ``MNISTLongTailDataset`` objects.
    """
    data_dir = IMAGE_PATH + '/color_rot_mnist_longtailed'
    train_file = data_dir + '/train.pt'
    test_file = data_dir + '/test.pt'

    # Only the directory is checked, not the individual split files.
    if not os.path.exists(data_dir):
        download_mnist_longtailed(option='both')

    train_payload = torch.load(train_file)
    trainset = MNISTLongTailDataset(dataset=train_payload)

    test_payload = torch.load(test_file)
    testset = MNISTLongTailDataset(dataset=test_payload)

    return trainset, testset


def color_rot_mnist_bias(color_std=5.0, rot_std=5.0):
    """Return the (train, test) datasets for the biased color+rotation MNIST variant.

    The data directory name embeds both parameters via ``str.format``, so
    values that format differently (e.g. ``5`` vs ``5.0``) map to different
    directories. If the directory does not exist,
    ``download_mnist_biased(option='both', color_std=..., rot_std=...)`` is
    called first (presumably it populates that directory -- confirm against
    the helper). ``train.pt`` and ``test.pt`` are then read with
    ``torch.load`` and each payload is wrapped in ``MNISTBias``.

    Args:
        color_std: Forwarded to ``download_mnist_biased`` and used in the
            directory name.
        rot_std: Forwarded to ``download_mnist_biased`` and used in the
            directory name.

    Returns:
        A ``(trainset, testset)`` tuple of ``MNISTBias`` objects.
    """
    suffix = '/color_rot_mnist_biased_{}_{}'.format(color_std, rot_std)
    data_dir = IMAGE_PATH + suffix

    # Only the directory is checked, not the individual split files.
    if not os.path.exists(data_dir):
        download_mnist_biased(option='both', color_std=color_std, rot_std=rot_std)

    def _load_split(split_file):
        # Read one serialized split and wrap it in the dataset class.
        return MNISTBias(dataset=torch.load(data_dir + split_file))

    trainset = _load_split('/train.pt')
    testset = _load_split('/test.pt')
    return trainset, testset