from ltsgns_mp.algorithms.posterior_learners.abstract_posterior_learner import AbstractPosteriorLearner
from ltsgns_mp.envs.env import Env
from ltsgns_mp.util.own_types import ConfigDict


def get_posterior_learner(config: ConfigDict, env: Env,
                          device: str) -> AbstractPosteriorLearner:
    """Construct the posterior learner selected by ``config.name``.

    Each learner class is imported inside its own branch, so this function
    only imports the module of the learner that was actually requested.

    Args:
        config: Learner configuration. Its ``name`` attribute selects the
            learner; the whole object is forwarded to the learner constructor.
        env: Environment providing ``train_iterator`` (all learners) and
            ``eval_iterators`` (multi-DAFT learner only), from which the
            number of tasks to track is derived.
        device: Device identifier passed through unchanged to the learner.

    Returns:
        A ``ConstantLearner``, ``MultiDaftLearner`` or ``TaskPropLearner``
        instance, depending on ``config.name``.

    Raises:
        ValueError: If ``config.name`` is not ``"constant_learner"``,
            ``"multi_daft_learner"`` or ``"task_prop_learner"``.
    """
    learner_name = config.name

    if learner_name == "constant_learner":
        from ltsgns_mp.algorithms.posterior_learners.constant_learner import ConstantLearner
        # NOTE(review): this branch counts tasks via len(train_iterator), whereas
        # the multi-DAFT branch uses train_iterator.num_tasks — confirm the two agree.
        return ConstantLearner(config, len(env.train_iterator), device=device)

    if learner_name == "multi_daft_learner":
        from ltsgns_mp.algorithms.posterior_learners.multi_daft_learner import MultiDaftLearner

        train_task_count = env.train_iterator.num_tasks
        # One task count per evaluation iterator, keyed the same way as
        # env.eval_iterators (presumably by context size — verify against Env).
        eval_task_counts = {
            size: eval_iterator.num_tasks
            for size, eval_iterator in env.eval_iterators.items()
        }
        return MultiDaftLearner(
            config,
            n_all_train_tasks=train_task_count,
            n_all_eval_tasks=eval_task_counts,
            device=device,
        )

    if learner_name == "task_prop_learner":
        # The task count is read before the import, matching the original order.
        train_task_count = len(env.train_iterator)
        from ltsgns_mp.algorithms.posterior_learners.task_prop_learner import TaskPropLearner
        return TaskPropLearner(config, n_all_tasks=train_task_count, device=device)

    raise ValueError(f"Unknown posterior learner name {learner_name}")
