"""
PEARL Experiment
"""

import pickle
import click
from pathlib import Path

from rlkit.launchers.launcher_util import load_pyhocon_configs
import rlkit.pythonplusplus as ppp
import rlkit.misc.hyperparameter as hyp
from rlkit.torch.pearl.sac_launcher import pearl_sac_experiment
from rlkit.launchers.launcher_util import run_experiment


@click.command()
@click.option('--debug', is_flag=True, default=False)
@click.option('--dry', is_flag=True, default=False)
@click.option('--suffix', default=None)
@click.option('--nseeds', default=1)
def main(debug, dry, suffix, nseeds):
    """Build the experiment variants and run each one via ``run_experiment``.

    The experiment name is derived from this script's parent directory name
    and file name (underscores replaced by dashes), prefixed with
    ``pearl-awac-`` and optionally followed by ``--<suffix>``.

    Args:
        debug: If set, prefix the experiment name with ``dev--``, force a
            single seed and additionally load ``configs/debug.conf``.
        dry: If set, prefix the experiment name with ``dev--`` and force a
            single seed. It has no other effect in this function.
        suffix: Optional string appended to the experiment name.
        nseeds: Number of seeds; the sweep uses seeds ``0 .. nseeds - 1``.
    """
    # `joblib` is used below to read the task file but is not imported at the
    # top of this file; import it here so the call does not raise NameError.
    import joblib

    # Use an absolute path so that the parent-directory lookups below work
    # even when `__file__` is a bare relative file name, and regardless of
    # the platform's path separator.
    script_path = Path(__file__).absolute()
    base_dir = script_path.parent.parent

    suffix = '' if suffix is None else '--{}'.format(suffix)
    exp_name = 'pearl-awac-{}--{}{}'.format(
        script_path.parent.name.replace('_', '-'),
        # Keep only the part of the file name before the first dot.
        script_path.name.split('.')[0].replace('_', '-'),
        suffix,
    )

    if debug or dry:
        exp_name = 'dev--' + exp_name
        nseeds = 1

    configs = [
        base_dir / 'configs/default_sac.conf',
        base_dir / 'configs/ant_dir_120_online.conf',
    ]
    if debug:
        configs.append(base_dir / 'configs/debug.conf')
    default_variant = ppp.recursive_to_dict(load_pyhocon_configs(configs))

    # NOTE: this path is relative to the current working directory, not to
    # this script's location.
    tasks = joblib.load("task_data/ant_tasks.joblib")['tasks']

    # Every entry is a list of candidate values; only 'seed' has more than
    # one candidate (when nseeds > 1).
    search_space = {
        'seed': list(range(nseeds)),
        'trainer_kwargs.train_context_decoder': [
            True,
        ],
        'trainer_kwargs.backprop_q_loss_into_encoder': [
            True,
        ],
        'train_task_idxs': [
            list(range(100)),
        ],
        'eval_task_idxs': [
            list(range(100, 120))
        ],
        'env_params.fixed_tasks': [
            [t['goal'] for t in tasks],
        ],
        'env_params.direction_in_degrees': [
            True,
        ],
        'algo_kwargs.num_iterations_with_reward_supervision': [
            None,
        ],
        'algo_kwargs.exploration_resample_latent_period': [1],
        'algo_kwargs.encoder_buffer_matches_rl_buffer': [False],
        'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [False],
        'algo_kwargs.clear_encoder_buffer_before_every_update': [True],
        'tags.encoder_buffer_mode': ['keep_latest_exploration_only'],
    }

    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=default_variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        # The launch mode is fixed to 'here_no_doodad' for every run,
        # including --debug / --dry runs.
        run_experiment(
            pearl_sac_experiment,
            unpack_variant=True,
            exp_name=exp_name,
            mode='here_no_doodad',
            variant=variant,
            use_gpu=True,
        )


# Script entry point: `main` is a click command, so calling it with no
# arguments lets click parse the options from the command line.
if __name__ == "__main__":
    main()

