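"""
PEARL + AWAC experiment launcher for the roboverse SawyerAffordancesMetaV0
environment.

Loads the default AWAC, offline-pretraining and bullet-offline configs, then
sweeps seeds, evaluation task splits and reward-supervision settings on top of
them and launches every resulting variant.
"""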

from pathlib import Path

import click

import rlkit.pythonplusplus as ppp
from rlkit.launchers.launcher_util import load_pyhocon_configs
from rlkit.torch.pearl.awac_launcher import pearl_awac_experiment, process_args

from roboverse.envs.sawyer_affordances_meta_v0 import SawyerAffordancesMetaV0

from rlkit.launchers.arglauncher import run_variants
import rlkit.misc.hyperparameter as hyp
import pickle
import copy
from rlkit.torch.networks import Clamp

def main():
    """Build the PEARL-AWAC hyperparameter sweep and launch every variant.

    The base variant is built from three pyhocon configs. They are resolved
    relative to the parent of this script's directory:
    ``default_awac.conf``, ``offline_pretraining.conf`` and
    ``bullet_offline.conf``.

    ``search_space`` is overlaid on that base variant through
    ``DeterministicHyperparameterSweeper``. All generated variant dicts are
    passed to ``run_variants`` with ``pearl_awac_experiment`` as the
    experiment function and ``process_args`` as the argument post-processor.
    """
    base_dir = Path(__file__).parent.parent
    configs = [
        str(base_dir / 'configs/default_awac.conf'),
        str(base_dir / 'configs/offline_pretraining.conf'),
        str(base_dir / 'configs/bullet_offline.conf'),
    ]
    # Kept under its own name so the swept variants below do not rebind it.
    base_variant = ppp.recursive_to_dict(load_pyhocon_configs(configs))

    # TODO: replace this placeholder with the directory of generated
    # MACAW-format buffers before launching; it is not a real path.
    PATH_TO_MACAW_BUFFERS = "path/to/generated/buffer_directory"

    nseeds = 5
    search_space = {
        'seed': list(range(nseeds)),
        'env_class': [SawyerAffordancesMetaV0],
        # NOTE(review): relative path, presumably resolved against the
        # current working directory, unlike the configs above which are
        # anchored at base_dir. Confirm before launching from elsewhere.
        'env_params': [dict(fixed_tasks="task_data/sawyer_manipulation_tasks.pkl")],
        'flatten': [True],
        'trainer_kwargs.beta': [0.3, ],
        'macaw_format_base_path': [PATH_TO_MACAW_BUFFERS],
        'load_buffer_kwargs.is_macaw_buffer_path': [True],
        'trainer_kwargs.train_context_decoder': [True,],
        'trainer_kwargs.backprop_q_loss_into_encoder': [False,],
        'trainer_kwargs.awr_use_mle_for_vf': [True, ],
        'trainer_kwargs.clip_score': [2, ],
        'trainer_kwargs.reward_scale': [1.0, ],
        'trainer_kwargs.reward_transform_kwargs': [None, ],
        'trainer_kwargs.terminal_transform_kwargs': [dict(m=0, b=0),],
        # Training uses tasks 0-49.
        'train_task_idxs': [list(range(50)), ],
        # Two evaluation splits. Tasks 40-49 overlap the training tasks;
        # tasks 50-59 are disjoint from them.
        'eval_task_idxs': [list(range(40, 50)), list(range(50, 60)), ],
        'algo_kwargs.num_iterations_with_reward_supervision': [None, 0],
        'algo_kwargs.exploration_resample_latent_period': [1,],
        'algo_kwargs.encoder_buffer_matches_rl_buffer': [True,],
        'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [False,],
        'algo_kwargs.num_tasks_sample': [10,],
        'algo_kwargs.max_path_length': [50,],
        'algo_kwargs.num_initial_steps': [0,],
        'algo_kwargs.num_train_steps_per_itr': [3000],
        'algo_kwargs.num_steps_prior': [100,],
        'algo_kwargs.num_steps_posterior': [100,],
        'algo_kwargs.num_steps_per_eval': [100,],
        'algo_kwargs.num_exp_traj_eval': [1,],
        'algo_kwargs.clear_encoder_buffer_before_every_update': [False,],
        'launcher_config': [dict(unpack_variant=True, region='us-west-1', )],
        # NOTE(review): presumably clamps Q-function outputs to <= 0, e.g.
        # because rewards are non-positive. Confirm against the env's reward.
        'qf_kwargs.output_activation': [Clamp(max=0)],
    }

    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=base_variant,
    )

    # Materialize every swept combination up front; run_variants takes a list.
    variants = list(sweeper.iterate_hyperparameters())

    run_variants(pearl_awac_experiment, variants, run_id=0, process_args_fn=process_args)


if __name__ == "__main__":
    # Launch the sweep only when run as a script, never on import.
    main()
