import argparse
import gym
import numpy as np
import torch
import json
import wandb
import time
import glob
# Wall-clock reference for the elapsed-time report printed at the end of the
# script; taken before the project-local imports below, so their import time
# is included in the reported figure.
start_time = time.time()

import pickle 
import json 
from collections import OrderedDict

from configs import args_point_robot, args_half_cheetah_vel, args_half_cheetah_dir, args_ant_dir, args_ant_goal, args_walker, args_hopper, args_reach
from meta_dt.model import MetaDecisionTransformer
from meta_dt.dataset import MetaDT_Dataset, append_context_to_data
from meta_dt.evaluation import meta_evaluate_episode_rtg
from decision_transformer.dataset import convert_data_to_trajectories
from context.model import RNNContextEncoder


# Script-level options. Flags not declared here are collected in `rest_args`
# (parse_known_args) so they can be forwarded to another parser later.
parser = argparse.ArgumentParser()
parser.add_argument("--env_type", default="walker", type=str)
parser.add_argument("--device", default="cuda:0", type=str)
parser.add_argument("--context_horizon", default=4, type=int)
args, rest_args = parser.parse_known_args()

env_type, context_horizon = args.env_type, args.context_horizon

# Honour the requested device only when CUDA is actually available;
# otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device(args.device)
else:
    device = torch.device('cpu')


# Config module responsible for parsing the remaining command-line arguments
# of each supported --env_type.
_ENV_CONFIGS = {
    'point_robot': args_point_robot,
    'cheetah_vel': args_half_cheetah_vel,
    'cheetah_dir': args_half_cheetah_dir,
    'ant_dir': args_ant_dir,
    'ant_goal': args_ant_goal,
    'walker': args_walker,
    'reach': args_reach,
    'hopper': args_hopper,
}
if env_type not in _ENV_CONFIGS:
    raise NotImplementedError(
        f'Unsupported env_type {env_type!r}; expected one of {sorted(_ENV_CONFIGS)}'
    )

# NOTE: from here on `args` is the environment-specific namespace returned by
# get_args(), not the small namespace produced by the script-level parser.
args = _ENV_CONFIGS[env_type].get_args(rest_args)
args.context_horizon = context_horizon
if env_type == 'cheetah_dir':
    # args.env_name is what the dataset and checkpoint paths in this script are
    # built from, so cheetah_dir runs read the HalfCheetahVel-v0 artefacts.
    # NOTE(review): presumably a deliberate cheetah-vel -> cheetah-dir transfer
    # setup -- confirm.
    args.env_name = 'HalfCheetahVel-v0'

# ---------------------------------------------------------------------------
# Hard-coded evaluation settings: these overwrite whatever the config parser
# produced for the same fields.
# ---------------------------------------------------------------------------
args.data_quality = 'expert'
args.num_tasks, args.num_train_tasks = 30, 25
args.zero_shot = True

# Seed both RNG sources used by this script and make numpy arrays print
# compactly (3 decimals, no scientific notation).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
np.set_printoptions(suppress=True, precision=3)

from pathlib import Path
import sys
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv

# Key into the rlkit ENVS registry for every env_type this script can evaluate.
_RLKIT_ENV_NAMES = {
    'cheetah_vel': 'cheetah-vel',
    'cheetah_dir': 'cheetah-dir',
    'ant_dir': 'ant-dir',
    'ant_goal': 'ant-goal',
    'walker': 'walker-rand-params',
    'hopper': 'hopper-rand-params',
}
if env_type not in _RLKIT_ENV_NAMES:
    # Some env_type values (e.g. 'point_robot', 'reach') have a config module
    # but no registry mapping here. Fail with an explicit message instead of a
    # NameError on an unbound `env_name`.
    raise NotImplementedError(
        f'No rlkit environment is mapped for env_type {env_type!r}; '
        f'supported: {sorted(_RLKIT_ENV_NAMES)}'
    )
env_name = _RLKIT_ENV_NAMES[env_type]

# Instantiate the registered environment class and wrap it.
env = NormalizedBoxEnv(ENVS[env_name]())

# Always bind `logger`: it is handed to the evaluation routine further down,
# so leaving it undefined in debug mode raises a NameError there.
if args.debug:
    # A disabled run accepts logging calls but records nothing.
    logger = wandb.init(mode='disabled')
else:
    # cheetah-dir runs are reported under a dedicated transfer project name.
    if env_name == 'cheetah-dir':
        _wandb_project = 'Meta Test cheetah-vel -> cheetah-dir'
    else:
        _wandb_project = f'Meta Test {env_name}'
    logger = wandb.init(project=_wandb_project,
                        name=f'Meta-DT ({args.seed})',
                        group='Meta-DT')

# Put the environment on the fixed test task used by this script. Each entry
# is a zero-argument callable so only the selected setter is ever invoked.
_fixed_task_setters = {
    'cheetah-vel': lambda: env.set_velocity(-2),                # velocity -2
    'cheetah-dir': lambda: env.set_direction(-1),               # backward
    'ant-goal': lambda: env.set_goal_position(1.5*np.pi, 3),    # angle 1.5 pi, radius 3
    'ant-dir': lambda: env.set_direction(1.5*np.pi),            # angle 1.5 pi
}
if env_name in _fixed_task_setters:
    _fixed_task_setters[env_name]()
elif 'params' in env_name:
    # The *-rand-params environments select their test task themselves.
    env.set_test_task()
env.set_seed(args.seed)

# Sizes of the observation and action vectors (first axis of each space).
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Load the per-task metadata. The `with` block closes the file on exit, so no
# explicit close() is needed afterwards.
with open(f'datasets/{args.env_name}/{args.data_quality}/task_info_0.json', 'r') as f:
    task_info = json.load(f)
train_task_ids = np.arange(args.num_train_tasks)

# Target return = mean over all tasks of the second entry of each task's
# 'return_scale'.
# NOTE(review): presumably return_scale is (min, max), making this the mean
# best return -- confirm against the dataset generator.
target_ret = sum(
    task_info[f'task {i}']['return_scale'][1] for i in range(args.num_tasks)
) / args.num_tasks

### load the pretrained context encoder
# The constructor call is the same for every environment type, so no branching
# on env_type is needed.
context_encoder = RNNContextEncoder(state_dim, action_dim, args.context_dim, args.context_hidden_dim).to(device)

# The wildcard stands for one directory level under the data-quality folder.
# Matches are sorted so the choice is reproducible when several exist.
_ckpt_pattern = f'./saves/{args.env_name}/context/{args.data_quality}/*/horizon{args.context_horizon}/context_models_best.pt'
_ckpt_matches = sorted(glob.glob(_ckpt_pattern))
if not _ckpt_matches:
    raise FileNotFoundError(
        f'No context encoder checkpoint matches {_ckpt_pattern!r}'
    )
load_path = _ckpt_matches[0]

# map_location lets a checkpoint saved on a GPU load when `device` is the CPU.
context_encoder.load_state_dict(torch.load(load_path, map_location=device)['context_encoder'])

# The encoder is only used for inference here: freeze all of its parameters.
for param in context_encoder.parameters():
    param.requires_grad = False

print('Load context encoder from {}'.format(load_path))


########################## load train data ##########################
# Field names expected in every per-task pickle. The cheetah_dir / medium
# combination uses a different naming scheme, which is aliased back to the
# standard names after loading.
_alt_schema = env_type == 'cheetah_dir' and args.data_quality == 'medium'
if _alt_schema:
    keys = ['states', 'actions', 'rewards', 'next_states', 'dones', 'masks']
else:
    keys = ['observations', 'actions', 'rewards', 'next_observations', 'terminals','masks']

train_trajectories = []
for task_id in train_task_ids:
    # One list per expected field; a field in the pickle that is not listed in
    # `keys` raises KeyError below.
    train_data = OrderedDict((key, []) for key in keys)

    # NOTE: unpickling can execute arbitrary code -- only load trusted datasets.
    with open(f'datasets/{args.env_name}/{args.data_quality}/dataset_task_{task_id}.pkl', "rb") as f:
        data = pickle.load(f)

    for key, values in data.items():
        train_data[key].append(values)
    for key in keys:
        train_data[key] = np.concatenate(train_data[key], axis=0)

    if _alt_schema:
        # Expose the alternative field names under the standard ones.
        train_data['observations'] = train_data['states']
        train_data['next_observations'] = train_data['next_states']
        train_data['terminals'] = train_data['dones']

    # Attach context-encoder outputs to the transitions, then split the flat
    # data into trajectories and pool them across tasks.
    train_data = append_context_to_data(train_data, context_encoder, horizon=args.context_horizon, device=device,args=args)
    train_trajectories.extend(convert_data_to_trajectories(train_data, args))

# The dataset is built here for its state normalisation statistics, which are
# passed to the evaluation call later in the script.
train_dataset = MetaDT_Dataset(
    train_trajectories,
    args.dt_horizon,
    args.max_episode_steps,
    args.dt_return_scale,
    device,
    prompt_trajectories_list=None,
    args=args,
    world_model=None,
)
state_mean = train_dataset.state_mean
state_std = train_dataset.state_std
########################## load train data ##########################

########################## load transformer ##########################
# Build the policy network with the hyper-parameters from the config namespace.
model = MetaDecisionTransformer(
    state_dim=state_dim,
    act_dim=action_dim,
    max_length=args.dt_horizon,
    max_ep_len=args.max_episode_steps,
    context_dim=args.context_dim,
    hidden_size=args.dt_embed_dim,
    n_layer=args.dt_n_layer,
    n_head=args.dt_n_head,
    n_inner=4*args.dt_embed_dim,
    activation_function=args.dt_activation_function,
    n_positions=1024,
    resid_pdrop=args.dt_dropout,
    attn_pdrop=args.dt_dropout,
).to(device)

# cheetah_dir has no policy directory of its own in this script: it loads the
# weights stored under cheetah_vel. `env_type` is rebound on purpose, exactly
# as before, so any later reader of the variable sees the same value.
if env_type == 'cheetah_dir':
    env_type = 'cheetah_vel'
save_path = f'meta_dt_policy/{env_type}/meta_dt_model.pt'

# map_location lets a checkpoint saved on a GPU load when `device` is the CPU.
model.load_state_dict(torch.load(save_path, map_location=device))
########################## load transformer ##########################

# Bookkeeping for the single evaluation pass below.
global_step = 0
scale = args.scale
total_steps = 0

# Switch the policy to inference mode before rolling it out.
model.eval()

# Keyword options for the evaluation routine, gathered in one place.
_eval_options = dict(
    max_episode_steps=args.max_episode_steps,
    scale=args.dt_return_scale,
    state_mean=state_mean,
    state_std=state_std,
    device=device,
    # The target return is passed in the same scaled units as `scale`.
    target_return=target_ret/args.dt_return_scale,
    horizon=args.context_horizon,
    num_eval_episodes=100000,
    prompt=None,
    args=args,
    epoch=global_step,
    eval=True,
    wandb=logger,
)
epi_return, epi_length, traj_per_test, length = meta_evaluate_episode_rtg(
    env, state_dim, action_dim, model, context_encoder, **_eval_options
)

total_steps = total_steps + length
print(epi_return,total_steps)

# Report wall-clock time since start_time, in minutes.
_elapsed_minutes = (time.time()-start_time)/60.
print(f'\nElapsed time: {_elapsed_minutes:.2f} minutes')
