from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import GeneratorName
from tabicl.data.synthetic_generator_forest import synthetic_dataset_generator_forest
from tabicl.data.synthetic_generator_mix import synthetic_dataset_generator_mix
from tabicl.data.synthetic_generator_neighbor import synthetic_dataset_generator_neighbor
from tabicl.data.synthetic_generator_perlin import synthetic_dataset_generator_perlin
from tabicl.data.synthetic_generator_tabpfn import synthetic_dataset_generator_tabpfn


def select_synthetic_dataset_generator(cfg: ConfigPretrain):

    n_samples = cfg.data.n_samples_query + cfg.data.max_samples_support

    match cfg.data.generator:
        case GeneratorName.TABPFN:
            return synthetic_dataset_generator_tabpfn(
                n_samples=n_samples,
                min_features=cfg.data.min_features,
                max_features=cfg.data.max_features,
                max_classes=cfg.data.max_classes,
                task=cfg.data.task,
            )
        case GeneratorName.FOREST:
            return synthetic_dataset_generator_forest(
                n_samples=n_samples,
                min_features=cfg.data.min_features,
                max_features=cfg.data.max_features,
                max_classes=cfg.data.max_classes,
                task=cfg.data.task,
                base_size=cfg.data.generator_hyperparams['base_size'],
                min_depth=cfg.data.generator_hyperparams['min_depth'],
                max_depth=cfg.data.generator_hyperparams['max_depth'],
                categorical_x=cfg.data.generator_hyperparams['categorical_x'],
            )
        case GeneratorName.NEIGHBOR:
            return synthetic_dataset_generator_neighbor(
                n_samples=n_samples,
                min_features=cfg.data.min_features,
                max_features=cfg.data.max_features,
                max_classes=cfg.data.max_classes,
                min_neighbors=cfg.data.generator_hyperparams['min_neighbors'],
                max_neighbors=cfg.data.generator_hyperparams['max_neighbors'],
            )
        case GeneratorName.MIX:
            return synthetic_dataset_generator_mix(
                n_samples=n_samples,
                min_features=cfg.data.min_features,
                max_features=cfg.data.max_features,
                max_classes=cfg.data.max_classes,
                task=cfg.data.task,
                base_size=cfg.data.generator_hyperparams['base_size'],
                min_depth=cfg.data.generator_hyperparams['min_depth'],
                max_depth=cfg.data.generator_hyperparams['max_depth'],
                categorical_x=cfg.data.generator_hyperparams['categorical_x'],
            )
        case GeneratorName.PERLIN:
            return synthetic_dataset_generator_perlin(
                n_samples=n_samples,
                min_features=cfg.data.min_features,
                max_features=cfg.data.max_features,
                max_classes=cfg.data.max_classes,
                min_complexity=cfg.data.generator_hyperparams['min_complexity'],
                max_complexity=cfg.data.generator_hyperparams['max_complexity'],
                n_octaves=cfg.data.generator_hyperparams['n_octaves'],
            )