from .nn import build_npe_model, get_embedding_net, MeanEmbedding, IdentityEmbedding
from .npe import (
    NPE_TrainConfig,
    NPETrainer,
    StandardNPETrainer,
    train_NPE_estimator,
    load_npe_model,
    sample_npe_posterior,
)
from .npe_noisy import (
    NoisyNPE_TrainConfig,
    NoisyNPETrainer,
    train_NPE_estimator_noisy,
)
from .npe_rs import (
    NPE_MMD_TrainConfig,
    MMDNPETrainer,
    train_NPE_estimator_mmd,
)

def get_method(cfg):
    """Return the trainer selected by ``cfg.method_type``.

    The trainer object itself is returned (it is not called here).

    Args:
        cfg: Any object; its ``method_type`` attribute is read if present.
            When the attribute is missing, ``"npe"`` is assumed.

    Returns:
        ``StandardNPETrainer`` for ``"npe"``, ``NoisyNPETrainer`` for
        ``"npe_noisy"``, and ``MMDNPETrainer`` for either ``"npe_rs"`` or
        ``"npe_mmd"``.

    Raises:
        ValueError: If ``method_type`` is none of the values listed above.
    """
    # Configs without an explicit method_type fall back to plain NPE.
    selected = getattr(cfg, "method_type", "npe")

    if selected == "npe":
        return StandardNPETrainer
    if selected == "npe_noisy":
        return NoisyNPETrainer
    # "npe_rs" and "npe_mmd" are two spellings for the same trainer.
    if selected in ("npe_rs", "npe_mmd"):
        return MMDNPETrainer

    raise ValueError(f"Unknown method type: {selected}")