from lfrl.models.dynamics_models.probabilistic_ensemble import ProbabilisticEnsemble
from lfrl.trainers.mbrl.mbrl import MBRLTrainer
from lfrl.policies.mpc.mpc import MPCPolicy
from lfrl.trainers.mpc.mpc_trainer import MPPITrainer

from experiments.mbrl_experiment import get_algorithm


# Rebind the imported ``get_algorithm`` to itself. This is a no-op at runtime;
# presumably it marks the import as an intentional re-export so the name can be
# looked up on this module next to ``get_config`` -- TODO confirm with callers.
get_algorithm = get_algorithm


def get_config(
        variant,
        expl_env,
        eval_env,
        obs_dim,
        action_dim,
        replay_buffer,
):
    """Build the components of an MPPI experiment and return them in a dict.

    Constructs a probabilistic-ensemble dynamics model, its model trainer,
    two MPC policies (exploration and evaluation) and the MPPI trainer.

    Args:
        variant: Mapping that must provide ``'mbrl_kwargs'`` (itself holding
            at least ``'layer_size'`` and ``'ensemble_size'``),
            ``'mppi_kwargs'`` and ``'algorithm_kwargs'``.
        expl_env: Stored unchanged under ``'exploration_env'``.
        eval_env: Stored unchanged under ``'evaluation_env'``.
        obs_dim: Forwarded to the dynamics model as ``obs_dim``.
        action_dim: Forwarded to the dynamics model as ``action_dim`` and to
            both policies as ``plan_dim``.
        replay_buffer: Stored unchanged under ``'replay_buffer'``.

    Returns:
        A dict with the keys ``trainer``, ``model_trainer``,
        ``exploration_policy``, ``evaluation_policy``, ``exploration_env``,
        ``evaluation_env``, ``replay_buffer`` and ``algorithm_kwargs``.
    """

    # --- Model-based RL (MBRL) dynamics model and its trainer ---

    mbrl_kwargs = variant['mbrl_kwargs']
    width = mbrl_kwargs['layer_size']

    ensemble = ProbabilisticEnsemble(
        ensemble_size=mbrl_kwargs['ensemble_size'],
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[width] * 4,  # four hidden layers of identical width
    )
    # The complete mbrl_kwargs mapping (including 'layer_size' and
    # 'ensemble_size') is forwarded to the model trainer as keyword arguments.
    ensemble_trainer = MBRLTrainer(ensemble=ensemble, **mbrl_kwargs)

    # --- MPPI policies ---

    # Two separate policy objects are created with identical arguments; both
    # reference the same dynamics model instance.
    exploration_policy = MPCPolicy(
        dynamics_model=ensemble,
        plan_dim=action_dim,
        **variant['mppi_kwargs'],
    )
    evaluation_policy = MPCPolicy(
        dynamics_model=ensemble,
        plan_dim=action_dim,
        **variant['mppi_kwargs'],
    )
    # Only the exploration policy is handed to the MPPI trainer.
    mppi_trainer = MPPITrainer(policy=exploration_policy)

    # --- Resulting config dict ---

    return {
        'trainer': mppi_trainer,
        'model_trainer': ensemble_trainer,
        'exploration_policy': exploration_policy,
        'evaluation_policy': evaluation_policy,
        'exploration_env': expl_env,
        'evaluation_env': eval_env,
        'replay_buffer': replay_buffer,
        'algorithm_kwargs': variant['algorithm_kwargs'],
    }
