import os, shutil
import os.path as osp
import pickle
import json
import numpy as np
import click
import torch
import wandb

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv, CameraWrapper
from rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.torch.networks import FlattenMlp, MlpEncoder
from rlkit.torch.sac.agent import PEARLAgent
from configs.default import default_config
from launch_pearl import deep_update_dict
from rlkit.torch.sac.policies import MakeDeterministic
from rlkit.samplers.util import rollout
from tqdm import tqdm
import re



def sim_policy(variant, path_to_exp, test_env, num_trajs=1, debug=False, seed=0):
    '''
    simulate a trained policy adapting to the test task(s) of an environment,
    print the per-trajectory mean return and (unless debugging) log to wandb

    :variant: experiment configuration dict
    :path_to_exp: folder containing `context_encoder.pt` and `policy.pt`
    :test_env: name of the test environment; selects which task setter is
        called on the env and the wandb project name
    :num_trajs: number of trajectories to simulate per task (default 1)
    :debug: if True, skip wandb initialisation and logging (default False)
    :seed: seed passed to the env and used in the wandb run name (default 0)
    '''
    algo_params = variant['algo_params']
    max_path_length = algo_params['max_path_length']
    # per-task cap on the number of environment steps collected
    max_steps_per_task = 200000

    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    # pin the env to a fixed held-out task for the known test environments
    if test_env == 'cheetah-vel':
        env.set_velocity(-2.0)
    elif test_env == 'cheetah-dir':
        env.set_direction(-1)
    elif test_env == 'ant-goal':
        env.set_goal_position(1.5 * np.pi, 3)  # (angle, radius) per the original note
    elif test_env == 'ant-dir':
        env.set_direction(1.5 * np.pi)
    elif test_env is not None and 'params' in test_env:
        env.set_test_task()
    env.set_seed(seed)
    eval_tasks = list(env.get_all_task_idx())
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    print('testing on {} test tasks, {} trajectories each'.format(len(eval_tasks), num_trajs))

    # instantiate networks
    latent_dim = variant['latent_size']
    # with the information bottleneck the encoder emits twice latent_dim values
    if algo_params['use_information_bottleneck']:
        context_encoder_output_dim = latent_dim * 2
    else:
        context_encoder_output_dim = latent_dim
    reward_dim = 1
    net_size = variant['net_size']
    context_encoder_input_dim = obs_dim + action_dim + reward_dim
    if algo_params['use_next_obs_in_context']:
        context_encoder_input_dim += obs_dim
    context_encoder = MlpEncoder(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        latent_dim,
        context_encoder,
        policy,
        **algo_params
    )
    if not debug:
        if test_env == 'cheetah-dir':
            project = 'Meta Test cheetah-vel -> cheetah-dir'
        else:
            project = f'Meta Test {test_env}'
        wandb.init(
            project=project,
            name=f'PEARL({seed})',
            group="PEARL",
        )

    agent = MakeDeterministic(agent)

    # load trained weights (otherwise simulate random policy)
    # map_location='cpu': the networks above are built on the CPU, so the
    # checkpoints must be loadable on hosts without CUDA as well
    context_encoder.load_state_dict(
        torch.load(os.path.join(path_to_exp, 'context_encoder.pt'), map_location='cpu'))
    policy.load_state_dict(
        torch.load(os.path.join(path_to_exp, 'policy.pt'), map_location='cpu'))

    # loop through tasks collecting rollouts
    all_rets = []
    for idx in tqdm(eval_tasks):
        env.reset_task(idx)
        agent.clear_z()
        paths = []
        total_steps = 0
        for n in tqdm(range(num_trajs)):
            path = rollout(env, agent,
                           max_path_length=max_path_length,
                           accum_context=True,
                           save_frames=False)
            total_steps += len(path['agent_infos'])
            paths.append(path)
            # after the exploration trajectories, re-infer the posterior from
            # the context accumulated so far
            if n >= algo_params['num_exp_traj_eval']:
                agent.infer_posterior(agent.context)
            if total_steps >= max_steps_per_task:
                break
        all_rets.append([sum(p['rewards']) for p in paths])

    if not all_rets:
        print('no test tasks available, nothing to evaluate')
        return

    # compute average returns across tasks; tasks may have collected different
    # numbers of trajectories (step cap above), so only the common prefix is
    # averaged and reported
    n_common = min(len(task_rets) for task_rets in all_rets)
    mean_rets = np.mean(np.stack([task_rets[:n_common] for task_rets in all_rets]), axis=0)

    for i in range(n_common):  # i: trajectory index
        step = (i + 1) * max_path_length
        if not debug:
            for task_idx, task_rets in zip(eval_tasks, all_rets):
                wandb.log({f"Task {task_idx} returns": task_rets[i]}, step=step)
            wandb.log({"Task mean return": mean_rets[i]}, step=step)
        print('trajectory {}, avg return: {} \n'.format(i, mean_rets[i]))

        
@click.command()
@click.option('--test_env', default=None)
@click.option('--num_trajs', default=1000)
@click.option('--debug', is_flag=True, default=False)
@click.option('--seed', default=0)
def main(test_env, num_trajs, debug, seed):
    '''
    command line entry point: pick the config file and saved-model folder for
    `test_env`, merge the config into the default config and run `sim_policy`

    :test_env: name of the test environment (required, see table below)
    :num_trajs: number of trajectories to simulate per task (default 1000)
    :debug: if set, wandb is not used
    :seed: random seed forwarded to `sim_policy` (default 0)
    '''
    # test_env -> (experiment config json, folder with the trained weights)
    # NOTE(review): cheetah-dir loads the cheetah-vel weights; this matches the
    # 'cheetah-vel -> cheetah-dir' wandb project name, so it looks deliberate.
    experiments = {
        'cheetah-dir': ('./configs/cheetah-dir.json', './pearl_saved_model/cheetah-vel'),
        'cheetah-vel': ('./configs/cheetah-vel.json', './pearl_saved_model/cheetah-vel'),
        'ant-goal': ('./configs/ant-goal.json', './pearl_saved_model/ant-goal'),
        'ant-dir': ('./configs/ant-dir.json', './pearl_saved_model/ant-dir'),
        'humanoid-dir': ('./configs/humanoid-dir.json', './pearl_saved_model/humanoid-dir'),
        'hopper-rand-params': ('./configs/hopper_rand_params.json', './pearl_saved_model/hopper-rand-params'),
        'walker-rand-params': ('./configs/walker_rand_params.json', './pearl_saved_model/walker-rand-params'),
    }
    if test_env not in experiments:
        # previously this fell through to an UnboundLocalError on `config`
        raise click.UsageError(
            '--test_env must be one of: {} (got {!r})'.format(
                ', '.join(sorted(experiments)), test_env))
    config, path = experiments[test_env]

    with open(config) as f:
        exp_params = json.load(f)
    variant = deep_update_dict(exp_params, default_config)
    sim_policy(variant, path, test_env, num_trajs, debug, seed)


# Script entry point: invoke the click command when this file is run directly.
if __name__ == "__main__":
    main()
