from lfrl.models.dynamics_models.probabilistic_ensemble import ProbabilisticEnsemble
from lfrl.models.dynamics_models.pt_model import PtModel
from lfrl.policies.mpc.mpc import MPCPolicy
from lfrl.trainers.mbrl.mbrl import MBRLTrainer
from lfrl.trainers.mpc.mpc_trainer import MPPITrainer


def get_config(
        variant,
        expl_env,
        eval_env,
        obs_dim,
        action_dim,
        replay_buffer,
):
    """Assemble the components for an MPC-based model-based RL run.

    Builds a ``PtModel`` dynamics model, an ``MBRLTrainer`` that trains it,
    an ``MPCPolicy`` that plans with it, and an ``MPPITrainer`` wrapping the
    policy, then returns them in a single dict.

    Args:
        variant: Experiment configuration mapping. Must contain
            ``'mbrl_kwargs'`` (with at least ``'ensemble_size'``) and
            ``'mpc_kwargs'``; ``'algorithm_kwargs'`` is optional.
        expl_env: Environment handed to the policy and returned as
            ``exploration_env``.
        eval_env: Environment returned as ``evaluation_env``.
        obs_dim: Observation dimension passed to the dynamics model.
        action_dim: Action dimension passed to the dynamics model and used
            as the policy's ``plan_dim``.
        replay_buffer: Replay buffer passed through unchanged.

    Returns:
        A dict with keys ``trainer``, ``model_trainer``,
        ``exploration_policy``, ``evaluation_policy``, ``exploration_env``,
        ``evaluation_env``, ``replay_buffer``, ``dynamics_model`` and
        ``algorithm_kwargs``.

    Raises:
        KeyError: If ``variant`` lacks ``'mbrl_kwargs'``, ``'mpc_kwargs'``,
            or ``'mbrl_kwargs'['ensemble_size']``.
    """
    mbrl_kwargs = variant['mbrl_kwargs']

    dynamics_model = PtModel(
        ensemble_size=mbrl_kwargs['ensemble_size'],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )
    # The full mbrl_kwargs mapping (including 'ensemble_size') is forwarded
    # to the model trainer as keyword arguments.
    model_trainer = MBRLTrainer(
        ensemble=dynamics_model,
        **mbrl_kwargs,
    )

    policy = MPCPolicy(
        env=expl_env,
        dynamics_model=dynamics_model,
        plan_dim=action_dim,
        **variant['mpc_kwargs'],
    )
    trainer = MPPITrainer(
        policy=policy,
    )

    # The same policy object serves for both exploration and evaluation.
    return dict(
        trainer=trainer,
        model_trainer=model_trainer,
        exploration_policy=policy,
        evaluation_policy=policy,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        replay_buffer=replay_buffer,
        dynamics_model=dynamics_model,
        algorithm_kwargs=variant.get('algorithm_kwargs', dict()),
    )
