"""
Launch script for a PEARL SAC experiment using the half-cheetah 130-task
offline config.
"""

import pickle
import click
from pathlib import Path

from rlkit.launchers.launcher_util import load_pyhocon_configs
import rlkit.pythonplusplus as ppp
import rlkit.misc.hyperparameter as hyp
from rlkit.torch.pearl.sac_launcher import pearl_sac_experiment
from rlkit.launchers.launcher_util import run_experiment


@click.command()
@click.option('--debug', is_flag=True, default=False)
@click.option('--dry', is_flag=True, default=False)
@click.option('--suffix', default=None)
@click.option('--nseeds', default=1)
def main(debug, dry, suffix, nseeds):
    """Build the hyperparameter sweep for this script and launch every variant.

    The experiment name is derived from this file's name (underscores become
    dashes), followed by ``--<suffix>`` when a suffix is given.

    Args:
        debug: Prefix the experiment name with ``dev--``, force a single
            seed, and layer ``configs/debug.conf`` on top of the defaults.
        dry: Prefix the experiment name with ``dev--`` and force a single
            seed.
        suffix: Optional label appended to the experiment name.
        nseeds: Number of seeds to sweep over (seeds ``0 .. nseeds - 1``).
    """
    # joblib is not among the module-level imports; import it here so the
    # task-loading call below does not fail with NameError.
    import joblib

    script_path = Path(__file__)
    base_dir = script_path.parent.parent

    suffix = '' if suffix is None else '--{}'.format(suffix)
    # Use pathlib for the file name instead of splitting on '/', and actually
    # append the suffix (previously it was formatted but never used).
    exp_name = script_path.name.split('.')[0].replace('_', '-') + suffix

    if debug or dry:
        exp_name = 'dev--' + exp_name
        nseeds = 1
    mode = 'here_no_doodad'

    # Later config files are passed after earlier ones to the loader.
    configs = [
        base_dir / 'configs/default_sac.conf',
        base_dir / 'configs/half_cheetah_130_offline.conf',
    ]
    if debug:
        configs.append(base_dir / 'configs/debug.conf')
    default_variant = ppp.recursive_to_dict(load_pyhocon_configs(configs))

    # NOTE: this path is relative to the current working directory, not to
    # base_dir. joblib.load unpickles data, so only load trusted files.
    tasks = joblib.load("task_data/half_cheetah_tasks.joblib")['tasks']

    # Every key maps to the list of values to sweep; only 'seed' can hold
    # more than one value here, so the sweep yields one variant per seed.
    search_space = {
        'seed': list(range(nseeds)),
        'trainer_kwargs.train_context_decoder': [
            True,
        ],
        'trainer_kwargs.backprop_q_loss_into_encoder': [
            True,
        ],
        'train_task_idxs': [
            list(range(100)),
        ],
        'eval_task_idxs': [
            list(range(100, 130))
        ],
        'env_params.presampled_tasks': [
            tasks,
        ],
        'algo_kwargs.num_iterations_with_reward_supervision': [
            None,
        ],
        'algo_kwargs.exploration_resample_latent_period': [
            1,
        ],
        'algo_kwargs.encoder_buffer_matches_rl_buffer': [False],
        'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [False],
        'algo_kwargs.clear_encoder_buffer_before_every_update': [True],
        'tags.encoder_buffer_mode': [
            'keep_latest_exploration_only',
        ],
    }

    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=default_variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            pearl_sac_experiment,
            unpack_variant=True,
            exp_name=exp_name,
            mode=mode,
            variant=variant,
            use_gpu=True,
            mount_point=None,
        )


# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()

