from seq_dataset import SeqDataset
from src.retraining_task_processing.circles import generate_circles_data
from src.retraining_task_processing.gauss import generate_gauss_data


def load_synthetic_dataset(cfg, seed, dataset_name, val_split=0.1):
    """Generate a synthetic dataset and wrap it in a ``SeqDataset``.

    Args:
        cfg: Mapping-like config indexed with ``[]``. It must provide the keys
            ``'offline_t'``, ``'N'``, ``'T'`` and ``'test_frac'``.
        seed: Seed forwarded both to the data generator and to ``SeqDataset``.
        dataset_name: Which generator to use: ``'gauss'`` or ``'circles'``.
        val_split: Validation fraction passed to ``SeqDataset`` as
            ``val_split`` (defaults to the previously hard-coded 0.1).

    Returns:
        The ``SeqDataset`` built from the generated ``(X_full, y_full)``.

    Raises:
        KeyError: If ``cfg`` lacks one of the required keys.
        ValueError: If ``dataset_name`` is not a supported generator.
    """
    offline_split = cfg['offline_t']
    # NOTE(review): 'N' and 'T' are forwarded verbatim to the generators;
    # presumably dataset size and number of time steps — confirm there.
    size_datasets = cfg['N']
    T = cfg['T']
    test_split = cfg['test_frac']

    # Use an exhaustive if/elif/else so an unsupported name fails loudly here
    # instead of leaving X_full/y_full unbound (UnboundLocalError below).
    if dataset_name == 'gauss':
        X_full, y_full = generate_gauss_data(seed, size_datasets, T)
    elif dataset_name == 'circles':
        X_full, y_full = generate_circles_data(seed, size_datasets, T)
    else:
        raise ValueError(
            f"Unknown synthetic dataset {dataset_name!r}; "
            "expected 'gauss' or 'circles'.")

    seq_dataset = SeqDataset(
        X_full, y_full, offline_split, test_split, val_split=val_split,
        seed=seed)
    return seq_dataset
