"""
Launcher for experiments with PEARL

"""
import os
import pathlib
import numpy as np
import click
import json
import torch
import torch.nn as nn


from rlkit.envs import ENVS
# from rlkit.envs.wrappers import NormalizedBoxEnv, CameraWrapper
from rlkit.envs.wrappers import NormalizedBoxEnv, FrameStackAndFrameSkip
from rlkit.torch.sac.policies import TanhGaussianPolicy, TanhGaussianCnnPolicy
from rlkit.torch.networks import FlattenMlp, MlpEncoder, RecurrentEncoder
from rlkit.torch.networks import CnnEncoder, CnnPolicyNetwork, CnnQFunction, CnnVf, \
    CnnContextEncoder
from rlkit.torch.sac.sac import PEARLSoftActorCritic
from rlkit.torch.sac.agent import PEARLAgent
from rlkit.launchers.launcher_util import setup_logger
import rlkit.torch.pytorch_util as ptu
from configs.default import default_config


def experiment(variant):
    """Build the environment, networks and PEARL algorithm, then train.

    Args:
        variant: nested configuration dict. Keys read here are
            ``env_name``, ``env_params``, ``n_train_tasks``, ``n_eval_tasks``,
            ``path_to_weights``, ``algo_params`` (also forwarded as keyword
            arguments to ``PEARLAgent`` and ``PEARLSoftActorCritic``) and
            ``util_params`` (``use_gpu``, ``gpu_id``, ``debug``,
            ``base_log_dir``). ``latent_size`` and ``net_size`` are only
            read on the non-visual path.

    Side effects:
        Sets the ``DEBUG`` environment variable, creates the logging
        directory (plus trajectory sub-directories) and runs training.
    """
    # create multi-task environment and sample tasks
    env_name = variant["env_name"]
    algo_params = variant["algo_params"]
    # Image observations are currently the only fully wired-up mode; see the
    # note on the non-visual branch below before changing this.
    visual = True
    base_env = ENVS[env_name](**variant["env_params"])
    if visual:
        frame_stack = 3
        frame_skip = 1
        env = FrameStackAndFrameSkip(base_env, frame_stack, frame_skip)
    else:
        env = NormalizedBoxEnv(base_env)
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    tasks = env.get_all_task_idx()
    n_train_tasks = variant["n_train_tasks"]
    train_goals = env.get_train_goals(n_train_tasks)

    goal_dim = len(train_goals[0])
    # goal_dim = 1

    # instantiate networks
    if visual:
        obs_shape = (1 * frame_stack, 64, 64)
        feature_dim = 50
        num_layers = 4
        num_filters = 32
        hidden_dim = 1024
        # image to observation
        embedding_dim = 10
        # NOTE: the visual pipeline pins the latent size and does not read
        # variant["latent_size"].
        latent_dim = 5
        # The networks are replicated over GPUs 0 and 1. When CUDA reports
        # fewer devices than that, shrink the list so DataParallel does not
        # fail on a device id that does not exist; with no CUDA device the
        # list is left untouched, exactly as before.
        device_ids = [0, 1]
        visible_gpus = torch.cuda.device_count()
        if 0 < visible_gpus < len(device_ids):
            device_ids = device_ids[:visible_gpus]
        # image_embedding_encoder is the encoder that turns images into "observation" vectors
        # doesn't have fc layers though
        image_embedding_encoder = nn.DataParallel(
            CnnEncoder(obs_shape, num_layers, num_filters, hidden_dim),
            device_ids=device_ids,
        )
        context_encoder = CnnContextEncoder(
            obs_shape,
            action_dim,
            embedding_dim,
            num_layers,
            num_filters,
            hidden_dim,
            latent_dim,
        )
        qf1 = nn.DataParallel(
            CnnQFunction(obs_shape, action_dim, latent_dim, feature_dim,
                         num_layers, num_filters, hidden_dim),
            device_ids=device_ids,
        )
        qf2 = nn.DataParallel(
            CnnQFunction(obs_shape, action_dim, latent_dim, feature_dim,
                         num_layers, num_filters, hidden_dim),
            device_ids=device_ids,
        )
        # same architecture as the Q-functions, but conditioned on the goal
        # (goal_dim) instead of the latent
        rf = nn.DataParallel(
            CnnQFunction(obs_shape, action_dim, goal_dim, feature_dim,
                         num_layers, num_filters, hidden_dim),
            device_ids=device_ids,
        )
        vf = CnnVf(obs_shape, latent_dim, feature_dim, num_layers, num_filters, hidden_dim)
        policy = TanhGaussianCnnPolicy(
            obs_shape,
            action_dim,
            latent_dim,
            feature_dim,
            num_layers,
            num_filters,
            hidden_dim,
        )
        return_log_prob = False
    else:
        # NOTE(review): this state-based path defines neither `rf` nor
        # `image_embedding_encoder`, both of which the agent/algorithm
        # construction below requires, and the multitask_* networks are
        # built but not passed on. It needs wiring up before `visual` can
        # be switched off.
        latent_dim = variant["latent_size"]
        reward_dim = 1
        context_encoder_input_dim = (
            2 * obs_dim + action_dim + reward_dim
            if algo_params["use_next_obs_in_context"]
            else obs_dim + action_dim + reward_dim
        )
        context_encoder_output_dim = (
            latent_dim * 2
            if algo_params["use_information_bottleneck"]
            else latent_dim
        )
        net_size = variant["net_size"]
        recurrent = algo_params["recurrent"]
        encoder_model = RecurrentEncoder if recurrent else MlpEncoder
        context_encoder = encoder_model(
            hidden_sizes=[200, 200, 200],
            input_size=context_encoder_input_dim,
            output_size=context_encoder_output_dim,
        )
        qf1 = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + action_dim + latent_dim,
            output_size=1,
        )
        qf2 = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + action_dim + latent_dim,
            output_size=1,
        )
        multitask_qf1 = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
        )
        multitask_qf2 = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
        )
        vf = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + latent_dim,
            output_size=1,
        )
        multitask_vf = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + goal_dim,
            output_size=1,
        )
        # return_log_prob = (variant["relabel_method"] == "hipi");
        return_log_prob = False
        policy = TanhGaussianPolicy(
            hidden_sizes=[net_size, net_size, net_size],
            obs_dim=obs_dim + latent_dim,
            latent_dim=latent_dim,
            action_dim=action_dim,
            return_log_prob=return_log_prob,
        )
    agent = PEARLAgent(
        latent_dim, context_encoder, policy, image_embedding_encoder, **algo_params
    )
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:n_train_tasks]),
        eval_tasks=list(tasks[-variant["n_eval_tasks"]:]),
        train_goals=train_goals,
        nets=[agent, qf1, qf2, vf, rf, image_embedding_encoder],
        latent_dim=latent_dim,
        return_log_prob=return_log_prob,
        **algo_params
    )

    # optionally load pre-trained weights
    weights_dir = variant["path_to_weights"]
    if weights_dir is not None:

        def load_weights(filename):
            # Deserialize onto the CPU so checkpoints saved from a GPU run
            # load on any machine; load_state_dict copies the values into
            # the existing parameters and algorithm.to() below takes care of
            # device placement.
            return torch.load(
                os.path.join(weights_dir, filename), map_location="cpu"
            )

        context_encoder.load_state_dict(load_weights("context_encoder.pth"))
        qf1.load_state_dict(load_weights("qf1.pth"))
        qf2.load_state_dict(load_weights("qf2.pth"))
        vf.load_state_dict(load_weights("vf.pth"))
        # TODO hacky, revisit after model refactor
        # NOTE(review): relies on the target value network sitting at
        # position -2 of algorithm.networks -- confirm against the
        # algorithm's `networks` ordering.
        algorithm.networks[-2].load_state_dict(load_weights("target_vf.pth"))
        policy.load_state_dict(load_weights("policy.pth"))

    # optional GPU mode
    util_params = variant["util_params"]
    ptu.set_gpu_mode(util_params["use_gpu"], util_params["gpu_id"])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    debug = util_params["debug"]
    os.environ["DEBUG"] = str(int(debug))

    # create logging directory
    # TODO support Docker
    exp_id = "debug" if debug else None
    experiment_log_dir = setup_logger(
        env_name,
        variant=variant,
        exp_id=exp_id,
        base_log_dir=util_params["base_log_dir"],
    )

    # optionally save eval trajectories as pkl files
    log_dir = pathlib.Path(experiment_log_dir)
    if algo_params["dump_eval_paths"]:
        (log_dir / "eval_trajectories").mkdir(parents=True, exist_ok=True)
    # the training trajectory directory is always created
    (log_dir / "trajectories").mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()


def deep_update_dict(fr, to):
    """Recursively merge the nested dict ``fr`` into ``to`` and return ``to``.

    Leaf values in ``fr`` overwrite the corresponding entries of ``to``;
    dict values are merged key by key. ``to`` is modified in place.

    A nested dict in ``fr`` whose key is missing from ``to`` (or whose
    counterpart in ``to`` is not a dict) is merged into a fresh dict, so the
    two inputs no longer need to share the same key structure and ``to``
    never ends up aliasing sub-dicts of ``fr``.

    Args:
        fr: dict holding the new values.
        to: dict to update.

    Returns:
        The updated ``to`` dict (same object that was passed in).
    """
    for key, value in fr.items():
        if isinstance(value, dict):
            target = to.get(key)
            if not isinstance(target, dict):
                # nothing mergeable on the destination side: start from an
                # empty dict instead of failing on the lookup
                target = to[key] = {}
            deep_update_dict(value, target)
        else:
            to[key] = value
    return to


@click.command()
@click.argument("config", default=None)
@click.option("--gpu", default=0)
@click.option("--docker", is_flag=True, default=False)
@click.option("--debug", is_flag=True, default=False)
def main(config, gpu, docker, debug):
    """Launch an experiment from the default config plus an optional JSON file.

    CONFIG is the path to a JSON file whose values override the defaults.
    """
    import copy

    # Work on a copy so the module-level defaults are not mutated by the merge.
    variant = copy.deepcopy(default_config)
    if config:
        with open(config, encoding="utf-8") as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    variant["util_params"]["gpu_id"] = gpu
    # The flag can only switch debugging on; a debug setting coming from the
    # config file is left alone when the flag is absent.
    if debug:
        variant["util_params"]["debug"] = True
    # NOTE: `docker` is accepted for CLI compatibility but is not used yet.

    experiment(variant)


# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
