"""
Launcher for experiments with PEARL

"""
import os
import pathlib
import numpy as np
import click
import json
import torch


from rlkit.envs import ENVS
# from rlkit.envs.wrappers import NormalizedBoxEnv, CameraWrapper
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.torch.networks import FlattenMlp, MlpEncoder, RecurrentEncoder
from rlkit.torch.sac.sac import PEARLSoftActorCritic
from rlkit.torch.sac.agent import PEARLAgent
from rlkit.launchers.launcher_util import setup_logger
import rlkit.torch.pytorch_util as ptu
from configs.default import default_config


def experiment(variant):

    # create multi-task environment and sample tasks
    env_name = variant["env_name"]
    env = NormalizedBoxEnv(ENVS[variant["env_name"]](**variant["env_params"]))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    tasks = env.get_all_task_idx()
    train_goals = env.get_train_goals(variant["n_train_tasks"])
    reward_dim = 1

    # instantiate networks
    latent_dim = variant["latent_size"]
    context_encoder_input_dim = (
        2 * obs_dim + action_dim + reward_dim
        if variant["algo_params"]["use_next_obs_in_context"]
        else obs_dim + action_dim + reward_dim
    )
    context_encoder_output_dim = (
        latent_dim * 2
        if variant["algo_params"]["use_information_bottleneck"]
        else latent_dim
    )
    net_size = variant["net_size"]
    recurrent = variant["algo_params"]["recurrent"]
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder
    n_train_tasks = variant["n_train_tasks"]

    if "vel" in variant["env_name"]:
        goal_dim = 1
    else:
        goal_dim = len(train_goals[0])

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    rf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
    )
    multitask_qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
    )
    multitask_qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    multitask_vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + goal_dim,
        output_size=1,
    )
    # return_log_prob = (variant["relabel_method"] == "hipi");
    return_log_prob = False
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
        return_log_prob=return_log_prob,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy, **variant["algo_params"])
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[: variant["n_train_tasks"]]),
        eval_tasks=list(tasks[-variant["n_eval_tasks"] :]),
        train_goals=train_goals,
        nets=[agent, qf1, qf2, vf, rf, multitask_qf1, multitask_qf2, multitask_vf],
        # nets=[agent, qf1, qf2, vf, multitask_qf1, multitask_qf2, multitask_vf],
        latent_dim=latent_dim,
        return_log_prob=return_log_prob,
        **variant["algo_params"]
    )

    # optionally load pre-trained weights
    if variant["path_to_weights"] is not None:
        path = variant["path_to_weights"]
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, "context_encoder.pth"))
        )
        qf1.load_state_dict(torch.load(os.path.join(path, "qf1.pth")))
        qf2.load_state_dict(torch.load(os.path.join(path, "qf2.pth")))
        vf.load_state_dict(torch.load(os.path.join(path, "vf.pth")))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(
            torch.load(os.path.join(path, "target_vf.pth"))
        )
        policy.load_state_dict(torch.load(os.path.join(path, "policy.pth")))

    # optional GPU mode
    ptu.set_gpu_mode(
        variant["util_params"]["use_gpu"], variant["util_params"]["gpu_id"]
    )
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant["util_params"]["debug"]
    os.environ["DEBUG"] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = "debug" if DEBUG else None
    experiment_log_dir = setup_logger(
        variant["env_name"],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant["util_params"]["base_log_dir"],
    )

    # optionally save eval trajectories as pkl files
    if variant["algo_params"]["dump_eval_paths"]:
        pickle_dir = experiment_log_dir + "/eval_trajectories"
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()


def deep_update_dict(fr, to):
    """ update dict of dicts with new values """
    # assume dicts have same keys
    for k, v in fr.items():
        if type(v) is dict:
            deep_update_dict(v, to[k])
        else:
            to[k] = v
    return to


@click.command()
@click.argument("config", default=None)
@click.option("--gpu", default=0)
@click.option("--docker", is_flag=True, default=False)
@click.option("--debug", is_flag=True, default=False)
def main(config, gpu, docker, debug):

    variant = default_config
    if config:
        with open(os.path.join(config)) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    variant["util_params"]["gpu_id"] = gpu

    experiment(variant)


if __name__ == "__main__":
    main()
