"""
Launcher for experiments
This launcher is based on the rlkit and PEARL (oyster) repositories on GitHub:

https://github.com/vitchyr/rlkit
https://github.com/katerakelly/oyster
"""
import os
import pathlib
import numpy as np
import click
import json
import torch
from torch import nn as nn
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.sac import Flap
from rlkit.launchers.launcher_util import setup_logger
import rlkit.torch.pytorch_util as ptu
from configs.default import default_config
from rlkit.torch.modifiedmlp import ModifiedFlattenMlp

def experiment(variant, gpu):
    """Build the environment, networks and Flap algorithm, then train.

    Args:
        variant: Nested configuration dict. Keys read here: ``env_name``,
            ``env_params``, ``n_train_tasks``, ``n_eval_tasks``, ``net_size``,
            ``last_layer_size``, ``last_policy_layer_size``,
            ``adapter_net_size``, ``adapt_steps``, ``algo_params`` (including
            ``dump_eval_paths``), ``path_to_weights`` and ``util_params``
            (``use_gpu``, ``gpu_id``, ``debug``, ``base_log_dir``).
        gpu: GPU id forwarded as ``gpu_id`` to both policies and to ``Flap``.
    """
    # create multi-task environment and sample tasks
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))

    # The first n_train_tasks indices are used for training and the last
    # n_eval_tasks for evaluation.
    # NOTE(review): with n_eval_tasks == 0 the slice tasks[-0:] selects every
    # task for a list/range -- confirm that case never occurs.
    tasks = env.get_all_task_idx()
    train_tasks = list(tasks[:variant['n_train_tasks']])
    eval_tasks = list(tasks[-variant['n_eval_tasks']:])
    num_training = len(train_tasks)

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    net_size = variant['net_size']
    last_layer_size = variant['last_layer_size']
    last_policy_layer_size = variant['last_policy_layer_size']
    adapter_net_size = variant['adapter_net_size']
    adapt_steps = variant['adapt_steps']

    def make_policy():
        # A fresh hidden_sizes list is built per call so instances never
        # share a mutable argument.
        return TanhGaussianPolicy(
            hidden_sizes=[net_size, net_size],
            obs_dim=obs_dim,
            action_dim=action_dim,
            last_layer_dim=last_policy_layer_size,
            tasks_num=num_training,
            output_size=action_dim,
            gpu_id=gpu,
        )

    def make_q_net():
        # Modified MLP that takes a per-task count (tasks_num) and a
        # last-layer width; presumably one output head per training task.
        return ModifiedFlattenMlp(
            hidden_sizes=[net_size, net_size],
            input_size=obs_dim + action_dim,
            last_layer_dim=last_layer_size,
            tasks_num=num_training,
            output_size=1,
        )

    # Construction order (policy, test policy, adapter, q1, q2, target q1,
    # target q2) is kept fixed so random initialisation is reproducible.
    policy = make_policy()

    # Separate policy instance handed to Flap as ``test_policy``
    # (original note: "to use only 0th layer for testing").
    test_policy = make_policy()

    # dimension of the policy's last layer, used to size the adapter output
    hyper_dim = last_policy_layer_size

    # Input width is obs + action + obs + reward (presumably a full
    # (s, a, s', r) transition -- confirm against Flap). Output width is
    # action_dim * hyper_dim + action_dim (presumably the weights and biases
    # of the policy's final layer -- confirm against Flap).
    adapter = FlattenMlp(
        hidden_sizes=[adapter_net_size, adapter_net_size, adapter_net_size],
        input_size=obs_dim + action_dim + obs_dim + reward_dim,
        output_size=action_dim * hyper_dim + action_dim,
    )

    q1_net = make_q_net()
    q2_net = make_q_net()
    target_q1 = make_q_net()
    target_q2 = make_q_net()

    algorithm = Flap(
        env=env,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        q1_net=q1_net,
        q2_net=q2_net,
        policy=policy,
        target_q1=target_q1,
        target_q2=target_q2,
        tasks_num=len(tasks),
        adapter=adapter,
        action_dim=action_dim,
        hyperparam_dim=hyper_dim,
        test_policy=test_policy,
        gpu_id=gpu,
        adapt_steps=adapt_steps,
        **variant['algo_params']
    )

    # optionally load pre-trained weights
    weights_dir = variant['path_to_weights']
    if weights_dir is not None:
        def load_state(filename):
            # Map storages to CPU so checkpoints written on a GPU machine
            # also load on a CPU-only host; load_state_dict copies the values
            # into the existing parameters whatever device they live on.
            return torch.load(
                os.path.join(weights_dir, filename), map_location='cpu'
            )

        # Each Q checkpoint is read once and shared with its target network.
        qf1_state = load_state('qf1.pth')
        qf2_state = load_state('qf2.pth')
        q1_net.load_state_dict(qf1_state)
        q2_net.load_state_dict(qf2_state)
        target_q1.load_state_dict(qf1_state)
        target_q2.load_state_dict(qf2_state)
        adapter.load_state_dict(load_state('adapter.pth'))
        policy.load_state_dict(load_state('policy.pth'))
        # NOTE(review): test_policy receives no pre-trained weights here --
        # confirm this is intended.

    # optional GPU mode
    util_params = variant['util_params']
    ptu.set_gpu_mode(util_params['use_gpu'], util_params['gpu_id'])
    if util_params['use_gpu']:
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    debug = util_params['debug']
    os.environ['DEBUG'] = str(int(debug))

    # create logging directory
    exp_id = 'debug' if debug else None
    experiment_log_dir = setup_logger(
        variant['env_name'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=util_params['base_log_dir'],
    )

    # optionally create the directory for eval trajectories saved as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = os.path.join(experiment_log_dir, 'eval_trajectories')
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()


# Helper that merges experiment-specific overrides into the configuration dict
def deep_update_dict(fr, to):
    """Recursively merge the values of ``fr`` into ``to`` and return ``to``.

    ``to`` is modified in place. Nested dicts are merged key by key; any
    other value in ``fr`` overwrites the corresponding entry of ``to``.

    Args:
        fr: Dict (possibly nested) holding the new values.
        to: Dict to update.

    Returns:
        The same ``to`` object, after the update.
    """
    for key, value in fr.items():
        if isinstance(value, dict):
            target = to.get(key)
            if isinstance(target, dict):
                deep_update_dict(value, target)
            else:
                # Key is missing from ``to`` or holds a non-dict (e.g. None):
                # install a fresh copy of the subtree instead of raising, and
                # avoid aliasing the nested dicts of ``fr``.
                to[key] = deep_update_dict(value, {})
        else:
            to[key] = value
    return to

@click.command()
@click.argument('config', default=None)
@click.option('--gpu', default=0)
@click.option('--debug', is_flag=True, default=False)
def main(config, gpu, debug):
    """Command-line entry point: assemble the variant and run the experiment.

    Args:
        config: Path to a JSON file whose values override ``default_config``;
            when falsy, the defaults are used unchanged.
        gpu: GPU id, stored under ``variant['util_params']['gpu_id']`` and
            passed on to ``experiment``.
        debug: When the ``--debug`` flag is given, forces
            ``variant['util_params']['debug']`` to True; otherwise the value
            from the defaults / config file is kept.
    """
    # Start from the default configuration.
    # NOTE: default_config is updated in place, not copied.
    variant = default_config

    # overlay the experiment-specific settings from the JSON file
    if config:
        with open(config) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)

    # gpu id for training
    variant['util_params']['gpu_id'] = gpu

    # Honour the command-line flag, which was previously accepted but ignored.
    if debug:
        variant['util_params']['debug'] = True

    # run the experiment with the configurations
    experiment(variant, gpu)

# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()

