from experiment_utils.launch_experiment import launch_experiment

from algorithms.algorithms.mbrl import get_algorithm
from algorithms.configs.mpc.mpc import get_config


# Environment identifier; stored in the variant under 'env_name'.
ENV_NAME = 'ContinualHopper'

# Launcher options, unpacked as keyword arguments into launch_experiment().
# NOTE(review): 'c4.4xlarge' looks like a CPU-only instance type while
# use_gpu is True -- confirm this pairing is intended.
experiment_kwargs = {
    'exp_name': 'pets-chopper-4',
    'num_seeds': 3,
    'instance_type': 'c4.4xlarge',
    'use_gpu': True,
    'include_date': True,
}


if __name__ == "__main__":
    # Entry point: assemble the single experiment variant piece by piece and
    # hand it to launch_experiment().

    # Stored in the variant under 'mpc_kwargs'.
    mpc_settings = {
        'discount': 0.99,
        'horizon': 160,
        'repeat_length': 1,
        'plan_every': 1,
        'temperature': 0.01,
        'noise_std': 1,
        'num_rollouts': 400,
        'num_particles': 5,
        'planning_iters': 10,
        'polyak': 0.2,
        'sampling_mode': 'ts',
        'sampling_kwargs': {
            'reward_penalty': -20,
            'disagreement_threshold': 0.1,
        },
        # Careful: before changing risk_mode, double-check whether
        # sampling_mode is the setting you actually meant to edit.
        'risk_mode': 'cvar',
        'risk_kwargs': {
            'alpha': 1,
        },
        'filter_coefs': (0.05, 0.8, 0),
    }

    # Stored in the variant under 'mbrl_kwargs'.
    mbrl_settings = {
        'ensemble_size': 4,
        'num_elites': 4,
        'layer_size': 256,
        'num_layers': 4,
        'learning_rate': 1e-3,
        'batch_size': 256,
        'noise_clip': 1,
    }

    # Stored in the variant under 'algorithm_kwargs'.
    # NOTE(review): save_snapshot_freq (250) is larger than num_epochs (50);
    # if the frequency is counted in epochs, no periodic snapshot would be
    # written -- confirm against the launcher.
    algorithm_settings = {
        'num_epochs': 50,
        'num_eval_steps_per_epoch': 0,
        'num_trains_per_train_loop': 0,
        'num_expl_steps_per_train_loop': 1000,
        'min_num_steps_before_training': 0,
        'num_model_trains_per_train_loop': 1,
        'max_path_length': 1000,
        # TODO: fix this -- reset_free currently has to be set to False.
        'reset_free': False,
        'batch_size': 256,
        'model_batch_size': 256,
        'save_snapshot_freq': 250,
    }

    variant = {
        'algorithm': 'PETS',
        'env_name': ENV_NAME,
        'env_kwargs': {},
        'collector_type': 'step',
        'do_online_training': True,
        'do_offline_training': False,
        'replay_buffer_size': int(1e6),
        'teacher_data_files': ['chopper_sac_full'],
        'trainer_kwargs': {},
        'mpc_kwargs': mpc_settings,
        'mbrl_kwargs': mbrl_settings,
        'algorithm_kwargs': algorithm_settings,
    }

    # No swept hyperparameters: an empty mapping is passed as sweep_values.
    sweep_values = {}

    launch_experiment(
        get_config=get_config,
        get_algorithm=get_algorithm,
        variant=variant,
        sweep_values=sweep_values,
        **experiment_kwargs
    )
