import click
from pathlib import Path

from rlkit.launchers.launcher_util import run_experiment, load_pyhocon_configs
import rlkit.pythonplusplus as ppp
import rlkit.misc.hyperparameter as hyp
from rlkit.torch.pearl.awac_launcher import pearl_awac_experiment
from rlkit.launchers.launcher_util import run_experiment


@click.command()
@click.option('--debug', is_flag=True, default=False,
              help="Add configs/debug.conf, prefix the experiment name "
                   "with 'dev--' and run a single seed.")
@click.option('--dry', is_flag=True, default=False,
              help="Prefix the experiment name with 'dev--' and run a "
                   "single seed.")
@click.option('--suffix', default=None,
              help="Appended to the experiment name as '--SUFFIX'.")
@click.option('--nseeds', default=1, type=int,
              help="Number of seeds to sweep over (forced to 1 by "
                   "--debug/--dry).")
def main(debug, dry, suffix, nseeds):
    """Launch a hyperparameter sweep of pearl_awac_experiment.
    \f
    Loads the pyhocon configs under ``<repo>/configs``, expands the
    ``search_space`` grid below (one entry per seed), and calls
    ``run_experiment`` once per resulting variant, locally
    (``mode='here_no_doodad'``) with GPU enabled.

    Args:
        debug: also load ``configs/debug.conf``, prefix the experiment name
            with ``'dev--'`` and force a single seed.
        dry: prefix the experiment name with ``'dev--'`` and force a single
            seed.
        suffix: optional string appended to the experiment name as
            ``'--<suffix>'``.
        nseeds: number of seeds; each seed becomes a separate variant.
    """
    script_path = Path(__file__)
    base_dir = script_path.parent.parent

    # Derive the experiment name from this script's file name, e.g.
    # ``my_exp.py`` -> ``my-exp``. Using Path instead of splitting on '/'
    # keeps this working with Windows path separators.
    exp_name = script_path.name.split('.')[0].replace('_', '-')

    if debug or dry:
        exp_name = 'dev--' + exp_name
        nseeds = 1
        # NOTE(review): --dry does not skip launching; run_experiment is
        # still called for every variant below.

    if suffix is not None:
        # Let the user tag runs so different invocations can be told apart.
        exp_name += '--{}'.format(suffix)

    configs = [
        base_dir / 'configs/smac.conf',
        base_dir / 'configs/half_cheetah_120_v1.conf',
    ]
    if debug:
        configs.append(base_dir / 'configs/debug.conf')
    default_variant = ppp.recursive_to_dict(load_pyhocon_configs(configs))

    # Each key is a dotted path into the variant dict; the sweeper takes
    # the cartesian product of all value lists.
    search_space = {
        'load_buffer_kwargs.pretrain_buffer_path': [
            # Placeholder: must be replaced with a real snapshot path.
            "path/to/generated/extra_snapshot_itr50.cpkl",  # TODO
        ],
        'saved_tasks_path': [
            "task_data/half_cheetah_tasks.joblib",
        ],
        'load_buffer_kwargs.end_idx': [
            1200,
        ],
        'algo_kwargs.train_encoder_decoder_in_unsupervised_phase': [
            False,
        ],
        'algo_kwargs.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase': [
            True,
        ],
        'algo_kwargs.use_rl_buffer_for_enc_buffer': [
            False,
        ],
        'trainer_kwargs.awr_use_mle_for_vf': [
            False,
        ],
        'pretrain_offline_algo_kwargs.num_batches': [
            50000,
        ],
        'algo_kwargs.num_iterations': [
            50,
        ],
        'seed': list(range(nseeds)),
    }
    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=default_variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            pearl_awac_experiment,
            unpack_variant=True,
            exp_name=exp_name,
            mode='here_no_doodad',
            variant=variant,
            use_gpu=True,
        )


if __name__ == "__main__":
    # click parses sys.argv and supplies main()'s parameters itself,
    # so the command is invoked here without arguments.
    main()
