import os
import pathlib
import numpy as np
import click
import json
import torch
from torch import nn as nn
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.sac import Flap
from rlkit.launchers.launcher_util import setup_logger
import rlkit.torch.pytorch_util as ptu
from configs.default import default_config
from rlkit.torch.modifiedmlp import ModifiedFlattenMlp
from rlkit.torch.sac.policies import MakeDeterministic
from rlkit.samplers.util import rollout

def evaluate_policy(variant, num_trajs=1, deterministic=False):
    """Roll out a saved policy on the evaluation tasks and report returns.

    Builds the environment, policy and adapter network from ``variant``,
    loads their weights from ``variant['path_to_weights']`` and collects
    ``num_trajs`` rollouts on each evaluation task.

    Args:
        variant: Configuration dict. Keys read here: ``util_params``
            (``gpu_id``, ``use_gpu``), ``env_name``, ``env_params``,
            ``n_train_tasks``, ``n_eval_tasks``, ``net_size``,
            ``last_policy_layer_size``, ``adapter_net_size``,
            ``path_to_weights`` and ``algo_params['max_path_length']``.
        num_trajs: Number of rollouts collected per evaluation task.
        deterministic: If true, the policy is wrapped in
            ``MakeDeterministic`` before rolling out.

    Returns:
        Array with one entry per trajectory index holding the return of
        that trajectory averaged over all evaluation tasks. Empty when
        there are no evaluation tasks.
    """
    util_params = variant['util_params']
    gpu = util_params['gpu_id']

    # create multi-task environment and sample tasks
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))

    # simply range over number tasks
    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1
    num_training = len(list(tasks[:variant['n_train_tasks']]))
    net_size = variant['net_size']
    last_policy_layer_size = variant['last_policy_layer_size']
    adapter_net_size = variant['adapter_net_size']

    # The last n_eval_tasks indices are the evaluation tasks. A plain
    # tasks[-0:] would select *every* task, so zero is handled explicitly.
    n_eval_tasks = variant['n_eval_tasks']
    eval_tasks = list(tasks[-n_eval_tasks:]) if n_eval_tasks > 0 else []
    print('testing on {} test tasks, {} trajectories each'.format(len(eval_tasks), num_trajs))
    if not eval_tasks:
        # Nothing to roll out; avoid building networks and failing on min([]).
        return np.array([])

    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size],
        obs_dim=obs_dim,
        action_dim=action_dim,
        last_layer_dim=last_policy_layer_size,
        tasks_num=num_training,
        output_size=action_dim,
        gpu_id=gpu,
    )

    # specify dim of policy for the adapter neural net
    hyper_dim = last_policy_layer_size

    adapter = FlattenMlp(
        hidden_sizes=[adapter_net_size, adapter_net_size, adapter_net_size],
        input_size=obs_dim + action_dim + obs_dim + reward_dim,
        output_size=action_dim * hyper_dim + action_dim,
    )

    # GPU
    ptu.set_gpu_mode(util_params['use_gpu'], gpu)

    # Load weights for evaluation, set path in default.py configuration.
    # map_location='cpu' lets checkpoints saved on a GPU load on any machine;
    # load_state_dict then copies the values into the existing parameters.
    # NOTE(review): torch.load unpickles the file - only load trusted weights.
    weights_dir = variant['path_to_weights']
    adapter.load_state_dict(
        torch.load(os.path.join(weights_dir, 'extra_weights.pth'), map_location='cpu'))
    policy.load_state_dict(
        torch.load(os.path.join(weights_dir, 'policy.pth'), map_location='cpu'))

    # Wrap only after the weights are loaded so the checkpoint keys are
    # matched against the raw policy rather than against the wrapper.
    if deterministic:
        policy = MakeDeterministic(policy)

    max_path_length = variant['algo_params']['max_path_length']
    all_rets = []
    for idx in eval_tasks:
        env.reset_task(idx)
        # The third positional argument (0) is kept from the original call;
        # its meaning is defined by rollout - TODO confirm.
        # gpu_id follows the configured device, matching the policy above.
        task_paths = [
            rollout(
                env,
                policy,
                0,
                max_path_length=max_path_length,
                adapter=adapter,
                testing=True,
                action_dim=action_dim,
                hyperparam_dim=hyper_dim,
                gpu_id=gpu,
            )
            for _ in range(num_trajs)
        ]
        all_rets.append([sum(p['rewards']) for p in task_paths])

    # compute average returns across tasks, truncating every task to the
    # shortest list so the per-task returns can be stacked
    shortest = min(len(task_rets) for task_rets in all_rets)
    rets = np.mean(np.stack([task_rets[:shortest] for task_rets in all_rets]), axis=0)
    for i, ret in enumerate(rets):
        print('trajectory {}, avg return: {} \n'.format(i, ret))
    return rets
    

def deep_update_dict(fr, to):
    """Recursively merge the override values of ``fr`` into ``to``.

    Used to apply an experiment configuration on top of the defaults.
    Nested dicts are merged key by key; every other value in ``fr``
    replaces the value stored under the same key in ``to``. A nested dict
    whose key is missing from ``to`` (or whose counterpart in ``to`` is not
    a dict) is stored as-is instead of raising.

    Args:
        fr: Dict holding the overriding values.
        to: Dict that is updated in place.

    Returns:
        The same ``to`` object, after the update.
    """
    for key, value in fr.items():
        # Only recurse when both sides are dicts; otherwise there is
        # nothing to merge into and the override simply wins.
        if isinstance(value, dict) and isinstance(to.get(key), dict):
            deep_update_dict(value, to[key])
        else:
            to[key] = value
    return to

@click.command()
# required=False is explicit: a None default alone does not make a click
# argument optional, which would leave the `if config:` branch unreachable.
@click.argument('config', default=None, required=False)
@click.option('--num_trajs', default=10,
              help='Number of trajectories to roll out per evaluation task.')
# On/off switch pair so the default-on behaviour can actually be disabled;
# plain `--deterministic` keeps working exactly as before.
@click.option('--deterministic/--no-deterministic', default=True,
              help='Evaluate with the deterministic version of the policy.')
def main(config, num_trajs, deterministic):
    """Evaluate a trained policy using the default config plus overrides.

    CONFIG is an optional path to a JSON file whose values override the
    default configuration.
    """
    # Local import keeps this block self-contained.
    import copy

    # Start from a private copy so the overrides below do not mutate the
    # module-level default_config shared with other importers.
    variant = copy.deepcopy(default_config)

    # modify the default configs
    if config:
        with open(config, encoding='utf-8') as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)

    # Evaluate
    evaluate_policy(variant, num_trajs, deterministic)

# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

