from omegaconf import DictConfig
from fair_dp_sgd.models.logistic_regression_flax import (
    create_train_state as create_train_state_lr,
)
from fair_dp_sgd.models.resnet import create_train_state as create_train_state_resnet
from functools import partial
import jax


def get_model(cfg: DictConfig, rng: jax.random.PRNGKey):
    """Build the initial train state for the model named in ``cfg.model.name``.

    Args:
        cfg: Run configuration. ``cfg.model.name`` selects the model, and the
            whole ``cfg`` is forwarded to that model's factory.
        rng: Forwarded unchanged to the factory as its second positional
            argument.

    Returns:
        Whatever the selected ``create_train_state`` factory returns.

    Raises:
        ValueError: If ``cfg.model.name`` is not one of the supported names.
    """
    model_name = cfg.model.name

    if model_name == "logistic_regression":
        return create_train_state_lr(cfg, rng)

    # Both ResNet variants share one factory, which receives the full cfg.
    if model_name in ("resnet16", "resnet50"):
        return create_train_state_resnet(cfg, rng)

    raise ValueError(f"Unknown model: {model_name}")
