import numpy as np
import torch 
from .adult import ContinualAdultData, AdultDataset
from .bank import ContinualBankData, BankDataset
from .bank_addition import ContinualBankAddData, BankAddDataset
from .mushroom import ContinualMushroomData, MushroomDataset
from .covertype import ContinualCoverData, CoverDataset
from .retail import ContinualRetailData, RetailDataset


def dataloader(dataset_name, model_config, env_config):
    """Build the continual-data container for ``dataset_name``.

    Args:
        dataset_name: One of ``'adult'``, ``'bank'``, ``'bank_add'``,
            ``'mushroom'``, ``'covertype'`` or ``'retail'``.
        model_config: Forwarded unchanged to the container constructor as
            the ``model_config`` keyword argument.
        env_config: Forwarded unchanged to the container constructor as
            the ``env_config`` keyword argument.

    Returns:
        A ``(db, dataset_cls)`` pair: ``db`` is the constructed
        ``Continual*Data`` instance and ``dataset_cls`` is the matching
        ``*Dataset`` class (the class object itself, not an instance).

    Raises:
        NotImplementedError: If ``dataset_name`` is not a supported name.
    """
    # Each branch only selects (data root, container class, dataset class);
    # the single constructor call at the bottom is shared by every dataset.
    # Data roots are relative paths, so they resolve against the current
    # working directory of the process.
    if dataset_name == 'adult':
        spec = ('./data/tabular_data/adult/',
                ContinualAdultData, AdultDataset)
    elif dataset_name == 'bank':
        spec = ('./data/tabular_data/bank/',
                ContinualBankData, BankDataset)
    elif dataset_name == 'bank_add':
        # NOTE(review): 'bank_add' points at the same directory as 'bank' --
        # confirm this is intentional and not a copy-paste slip.
        spec = ('./data/tabular_data/bank/',
                ContinualBankAddData, BankAddDataset)
    elif dataset_name == 'mushroom':
        spec = ('./data/tabular_data/mushroom/',
                ContinualMushroomData, MushroomDataset)
    elif dataset_name == 'covertype':
        spec = ('./data/tabular_data/covertype/',
                ContinualCoverData, CoverDataset)
    elif dataset_name == 'retail':
        spec = ('./data/time_tabular_data/retail/',
                ContinualRetailData, RetailDataset)
    else:
        # Name the offending value so a typo in a config is easy to spot.
        raise NotImplementedError(
            "Unknown dataset_name {!r}; expected one of 'adult', 'bank', "
            "'bank_add', 'mushroom', 'covertype', 'retail'.".format(
                dataset_name))

    data_root, container_cls, dataset_cls = spec
    db = container_cls(data_root,
                       model_config=model_config,
                       env_config=env_config)
    return db, dataset_cls
