from datasets import get_mnist, get_kmnist, get_kuzushiji, get_Fashion_mnist, get_kuzushiji_49, get_cifar10, get_SVHN,  make_training_dataset
from RC_loss import RCCC_loss
from model import MLP, MLP_dropout, resnet20
from gene_matrix import generate_dia_dominate_matrix, gene_noise_diff, generate_random_matrix
from gene_matrix import generate_20_random_matrix, generate_off_diagonal_same_matrix

def load_dataset(dataset):
    """Load the train/test split of the dataset selected by name.

    Args:
        dataset: Name of the dataset. Recognized values are ``'mnist'``,
            ``'kmnist'``, ``'fashion'``, ``'cifar10'`` and ``'svhn'``.

    Returns:
        A pair ``((x_train, y_train), (x_test, y_test))`` exactly as produced
        by the matching ``get_*`` loader.

    Raises:
        ValueError: If ``dataset`` is not one of the recognized names.
    """
    # NOTE: the 'kuzushiji' and 'kuzushiji_49' loaders were commented out in
    # an earlier revision and remain disabled here.
    if dataset == 'mnist':
        loader = get_mnist
    elif dataset == 'kmnist':
        loader = get_kmnist
    elif dataset == 'fashion':
        loader = get_Fashion_mnist
    elif dataset == 'cifar10':
        loader = get_cifar10
    elif dataset == 'svhn':
        loader = get_SVHN
    else:
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an unrelated unpacking error in the caller.
        raise ValueError(
            "Unknown dataset {!r}; expected one of 'mnist', 'kmnist', "
            "'fashion', 'cifar10', 'svhn'.".format(dataset)
        )

    # Unpack and repack so the returned structure is always a pair of pairs.
    (x_train, y_train), (x_test, y_test) = loader()
    return (x_train, y_train), (x_test, y_test)


def get_matrix(matrix, set_nums):
    """Build the matrix selected by name.

    Args:
        matrix: Name of the matrix generator. Recognized values are
            ``'dia_dominate_matrix'``, ``'random'``, ``'random_20'`` and
            ``'off_same'``.
        set_nums: Forwarded to the ``'random_20'`` and ``'off_same'``
            generators; ignored by the other two.

    Returns:
        Whatever the selected ``generate_*`` function returns.

    Raises:
        ValueError: If ``matrix`` is not one of the recognized names.
    """
    if matrix == 'dia_dominate_matrix':
        return generate_dia_dominate_matrix()
    if matrix == 'random':
        return generate_random_matrix()
    if matrix == 'random_20':
        return generate_20_random_matrix(set_nums)
    if matrix == 'off_same':
        return generate_off_diagonal_same_matrix(set_nums)

    # Fail loudly instead of implicitly returning None to the caller.
    raise ValueError(
        "Unknown matrix {!r}; expected one of 'dia_dominate_matrix', "
        "'random', 'random_20', 'off_same'.".format(matrix)
    )

def get_model(model):
    """Instantiate the network selected by name.

    Args:
        model: Name of the architecture. Recognized values are ``'mlp'``
            and ``'resnet'``.

    Returns:
        A freshly constructed ``MLP`` or ``resnet20`` instance.

    Raises:
        ValueError: If ``model`` is not one of the recognized names.
    """
    if model == 'mlp':
        # Sizes are hard-coded; presumably 28*28 inputs, four hidden layers
        # of 300 units and 10 outputs -- TODO confirm against model.MLP.
        return MLP(28 * 28, 300, 300, 300, 300, 10)
    if model == 'resnet':
        return resnet20()

    # Fail loudly instead of implicitly returning None to the caller.
    raise ValueError(
        "Unknown model {!r}; expected 'mlp' or 'resnet'.".format(model)
    )


def get_loss(loss, matrix, device):
    """Construct the loss object selected by name.

    Args:
        loss: Name of the loss. Recognized values are ``'RC'`` and ``'CC'``.
        matrix: Forwarded unchanged to ``RCCC_loss``.
        device: Forwarded unchanged to ``RCCC_loss``.

    Returns:
        An ``RCCC_loss`` instance built from ``matrix`` and ``device``.

    Raises:
        ValueError: If ``loss`` is not one of the recognized names.
    """
    # NOTE(review): both names construct RCCC_loss with identical arguments,
    # so the choice between 'RC' and 'CC' has no effect here -- confirm this
    # is intended.
    if loss in ('RC', 'CC'):
        return RCCC_loss(matrix, device)

    # Fail loudly instead of implicitly returning None to the caller.
    raise ValueError(
        "Unknown loss {!r}; expected 'RC' or 'CC'.".format(loss)
    )