from .cifar10 import CIFAR10
from .cifar10_noisy import CIFAR10Noisy
from .mnist import MNIST
from .mnist_noisy import MNISTNoisy
from .heloc import HELOC
from .heloc_noisy import HELOCNoisy
from .qnli import QNLI
from .qnli_noisy import QNLINoisy
from .clustering import (
    equal_clustering, kmeans_clustering, grad_kmeans_clustering, repr_kmeans_clustering
)

"""Wrapper functions for dataset, model, and train function loading."""
def LoadDataset(dataset_name: str, train: bool, **kwargs):
    """
    Wrapper function for loading datasets.

    Finds the dataset class registered for ``dataset_name`` and returns
    ``dataset_cls(train, **kwargs)``.

    Args:
        dataset_name: Registry key. 'heloc_wd' is accepted as an alias and
            builds the same ``HELOC`` class as 'heloc'.
        train: Passed positionally as the first constructor argument.
        **kwargs: Forwarded unchanged to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``dataset_name`` matches no registered key.
    """
    # Each row pairs the accepted name(s) with the class they construct;
    # rows are checked top to bottom and the first match wins.
    registry = (
        (('cifar10',), CIFAR10),
        (('cifar10_noisy',), CIFAR10Noisy),
        (('mnist',), MNIST),
        (('mnist_noisy',), MNISTNoisy),
        (('heloc', 'heloc_wd'), HELOC),
        (('heloc_noisy',), HELOCNoisy),
        (('qnli',), QNLI),
        (('qnli_noisy',), QNLINoisy),
    )
    for aliases, dataset_cls in registry:
        if dataset_name in aliases:
            return dataset_cls(train, **kwargs)
    raise ValueError(f"Dataset {dataset_name} not recognized.")
    
def LoadTrainModel(dataset_name: str, model_name: str):
    """
    Wrapper function for loading model trainers.

    Imports ``data.<dataset_name>.train`` on demand, so only the requested
    dataset's training module is loaded, and returns the entries it
    registers for ``model_name``.

    Args:
        dataset_name: One of 'cifar10', 'cifar10_noisy', 'mnist',
            'mnist_noisy', 'heloc', 'heloc_noisy', 'qnli', 'qnli_noisy'.
        model_name: Key looked up in that module's ``train_fns`` and
            ``train_params``.

    Returns:
        A ``(train_fn, train_params)`` tuple: ``train_fns[model_name]`` and
        ``train_params[model_name]`` from the dataset's train module.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
        KeyError: If ``model_name`` is not registered for the dataset.
    """
    import importlib

    # Every supported dataset keeps its trainers in data/<dataset_name>/train.py,
    # so the module path is derived from the name instead of one branch each.
    # NOTE(review): LoadDataset also accepts 'heloc_wd' (built as HELOC), but it
    # is not listed here and therefore raises ValueError — confirm intended.
    supported = (
        'cifar10', 'cifar10_noisy',  # CIFAR-10 models
        'mnist', 'mnist_noisy',      # MNIST models
        'heloc', 'heloc_noisy',      # HELOC models
        'qnli', 'qnli_noisy',        # QNLI models
    )
    if dataset_name not in supported:
        raise ValueError(f"Dataset {dataset_name} not recognized.")

    train_module = importlib.import_module(f"data.{dataset_name}.train")
    train_fns = train_module.train_fns
    train_params = train_module.train_params

    if model_name not in train_fns:
        # Same exception type as a plain lookup, but with context for the caller.
        raise KeyError(
            f"Model {model_name} not recognized for dataset {dataset_name}; "
            f"available: {list(train_fns)}"
        )
    return train_fns[model_name], train_params[model_name]