from experiment_utils.launch_experiment import launch_experiment

from algorithms.algorithms.mbrl import get_algorithm
from algorithms.algorithms.offline_mbrl import get_offline_algorithm
from algorithms.configs.dads.dads_model_prior_config import get_config
from algorithms.configs.mpc.make_mpc_policy import make_get_config


# Environment identifier forwarded to the algorithm via ``variant['env_name']``.
ENV_NAME = 'ContinualAnt'

# Launcher options, splatted into ``launch_experiment(**experiment_kwargs)``.
# NOTE(review): ``use_gpu`` is True while the instance type is a CPU-only
# c4 family type; confirm which of the two ``launch_experiment`` honours.
experiment_kwargs = {
    'exp_name': 'rf-cant-1',
    'num_seeds': 3,
    'instance_type': 'c4.4xlarge',
    'use_gpu': True,
    'include_date': True,
}


if __name__ == "__main__":
    # Full hyperparameter specification handed to ``launch_experiment``.
    # Key order mirrors the original dict(...) construction.
    variant = {
        'algorithm': 'LiSP',
        'env_name': ENV_NAME,
        'env_kwargs': {
            'terminates': False,
        },
        'teacher_data_files': ['sac_cant_full'],

        'collector_type': 'rf',
        'do_offline_training': True,
        'do_online_training': True,
        'replay_buffer_size': 10 ** 6,

        'generated_replay_buffer_size': 5000,
        'use_as_eval_policy': 'uniform',

        'policy_kwargs': {
            'layer_size': 256,
            'latent_dim': 4,
        },
        'discriminator_kwargs': {
            'layer_size': 512,
            'num_layers': 2,
            'restrict_input_size': 0,
        },
        # Per the original author's note, the effective rollout length is always one.
        'rollout_len_schedule': [-1, -1, 1, 1],
        # The entries flagged "keep in sync" were marked by the original author
        # to be edited together with rollout_len_schedule.
        'trainer_kwargs': {
            'num_model_samples': 400,          # keep in sync with rollout_len_schedule
            'num_prior_samples': 16,
            'num_discrim_updates': 4,          # keep in sync with rollout_len_schedule
            'num_policy_updates': 8,
            'discrim_learning_rate': 3e-4,
            'policy_batch_size': 256,
            'reward_bounds': (-30, 30),
            'empowerment_horizon': 1,
            'reward_scale': 5,
            'disagreement_threshold': 0.1,
            'relabel_rewards': True,
            'train_every': 1,                  # keep in sync with rollout_len_schedule
            'prior_batch_size': 256,
            'prior_train_steps': 4,            # keep in sync with rollout_len_schedule
        },
        'policy_trainer_kwargs': {
            'discount': 0.99,
            'policy_lr': 3e-4,
            'qf_lr': 3e-4,
            'soft_target_tau': 5e-3,
        },
        'prior_trainer_kwargs': {
            'discount': 0.99,
            'policy_lr': 3e-4,
            'qf_lr': 3e-4,
            'soft_target_tau': 5e-3,
            'use_automatic_entropy_tuning': True,
        },
        'mppi_kwargs': {
            'discount': 0.99,
            'horizon': 60,
            'repeat_length': 3,
            'plan_every': 1,
            'temperature': 0.01,
            'noise_std': 1,
            'num_rollouts': 400,
            'num_particles': 5,
            'planning_iters': 3,
            'polyak': 0,
            'sampling_mode': 'ts',
            'filter_coefs': (0.2, 0.6, 0),
        },
        'mbrl_kwargs': {
            'ensemble_size': 4,
            'num_elites': 4,
            'layer_size': 256,
            'learning_rate': 1e-3,
            'batch_size': 256,
            'train_call_freq': 10,
        },
        'offline_kwargs': {
            'num_epochs': 1000,
            'num_eval_steps_per_epoch': 1000,
            'num_trains_per_train_loop': 100,
            'model_train_freq': 250,
            'model_batch_size': 256,
            'max_path_length': 200,
            'batch_size': 256,
            'save_snapshot_freq': 50,
        },
        'algorithm_kwargs': {
            'num_epochs': 10000,
            'num_eval_steps_per_epoch': 0,
            'num_trains_per_train_loop': 10,
            'num_expl_steps_per_train_loop': 10,
            'min_num_steps_before_training': 0,
            'num_model_trains_per_train_loop': 1,
            'max_path_length': 200,
            'reset_free': False,  # marked "ignore" by the original author
            'batch_size': 256,
            'model_batch_size': 256,
            'save_snapshot_freq': 250,
        },
    }

    # No hyperparameter sweep: a single configuration is launched.
    sweep_values = {}

    launch_experiment(
        get_config=make_get_config(get_config),
        get_algorithm=get_algorithm,
        get_offline_algorithm=get_offline_algorithm,
        variant=variant,
        sweep_values=sweep_values,
        **experiment_kwargs
    )
