from lfrl.policies.mpc.mpc_policy import MPCPolicyController


def make_get_config(base_get_config):
    """Wrap a config factory so its result also carries an MPC exploration policy.

    Args:
        base_get_config: Callable taking ``(variant, expl_env, eval_env,
            obs_dim, action_dim, replay_buffer)`` and returning a mutable
            mapping. That mapping must provide the keys ``'dynamics_model'``,
            ``'control_policy'`` and ``'latent_dim'``.

    Returns:
        A ``get_config`` function with the same signature as
        ``base_get_config``. It returns the base config after adding the
        MPC policy entries described in its own docstring.
    """

    def get_config(
            variant,
            expl_env,
            eval_env,
            obs_dim,
            action_dim,
            replay_buffer,
    ):
        """Build the base config, then attach an ``MPCPolicyController``.

        The controller is always stored under ``'exploration_policy'``. It is
        also stored under ``'evaluation_policy'`` when
        ``variant['use_as_eval_policy'] == 'mppi'``. In that case the same
        instance is shared by both entries.

        ``variant['mppi_kwargs']`` is forwarded as extra keyword arguments to
        ``MPCPolicyController``.

        The config mapping returned by ``base_get_config`` is modified in place
        and returned.
        """
        # Everything the controller needs (model, inner policy, latent size)
        # comes out of the base configuration, so build that first.
        cfg = base_get_config(
            variant, expl_env, eval_env, obs_dim, action_dim, replay_buffer,
        )

        # NOTE(review): the controller is built against expl_env only; eval_env
        # is never passed to it, even when it is reused for evaluation.
        mpc_controller = MPCPolicyController(
            env=expl_env,
            dynamics_model=cfg['dynamics_model'],
            policy=cfg['control_policy'],
            latent_dim=cfg['latent_dim'],
            **variant['mppi_kwargs'],
        )
        cfg['exploration_policy'] = mpc_controller

        # Only override the evaluation policy when explicitly requested.
        use_mppi_for_eval = variant['use_as_eval_policy'] == 'mppi'
        if use_mppi_for_eval:
            cfg['evaluation_policy'] = mpc_controller

        return cfg

    return get_config
