import json
import pdb
import os
import sys
from os.path import join
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import logging
import argparse

# Expose only GPU index 3 to CUDA, with device indices following PCI bus order.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '3',
})

# Put the directory one level above this script at the front of the import
# path so the project packages imported below can be resolved.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.insert(0, parent_dir)

from Environment.env import Environment
import trajectory.utils as utils
import trajectory.datasets as datasets
from trajectory.search import (
    beam_plan,
    make_prefix,
    extract_actions,
    update_context,
)

# Name of the dataset this script evaluates on. It is the default for the
# ``dataset`` argument below and is reused further down to build the output
# paths and the per-episode result filenames.
Dataset = 'Mix_EHI_random_train_f2_100'

# Script-specific argument parser: extends ``utils.Parser`` and overrides the
# defaults of two arguments:
#   * ``dataset`` -> the module-level ``Dataset`` constant
#   * ``config``  -> 'config.offline' (presumably the dotted name of the config
#                    module read by ``utils.Parser`` -- verify there)
class Parser(utils.Parser):
    dataset: str = Dataset
    config: str = 'config.offline'
# NOTE(review): 'plan' is presumably the experiment/section name that
# ``utils.Parser.parse_args`` uses to pick config values -- confirm in utils.
args = Parser().parse_args('plan')

# Load the dataset object described by 'data_config.pkl'; it is looked up via
# the log base, dataset name and GPT load path taken from the parsed arguments.
dataset = utils.load_from_config(
    args.logbase,
    args.dataset,
    args.gpt_loadpath,
    'data_config.pkl',
)
timer = utils.timer.Timer()

# Quantities the planner needs, read once from the loaded dataset object.
discretizer, discount = dataset.discretizer, dataset.discount
observation_dim, action_dim = dataset.observation_dim, dataset.action_dim


def value_fn(x):
    """Evaluate ``x`` with the discretizer's value function at ``args.percentile``."""
    return discretizer.value_fn(x, args.percentile)

## environment settings -- all forwarded to `Environment(...)` in the test loop
T = 100            # also the number of steps executed per episode
domain_num = 1000
f_num = 2
seed = 0
view_num = 100
perturb_x = 0

## evaluation settings
episodes = 100               # episodes run for each test function
train_epochs = [40]
test_functions = ["ARa"]     # function types the environment is built with

## planner overrides applied on top of the parsed arguments
args.beam_width = 8
args.horizon = 1

## output locations derived from the dataset name and planner settings
action_trace_save_path = f'TT_test_data/{Dataset}'
log_save_path = f'tb_record_3/{Dataset}_test_domain_h{args.horizon}_b{args.beam_width}'

# Load the GPT model saved at epoch 40 onto the configured device; the call
# also returns the epoch of the checkpoint that was loaded.
gpt, gpt_epoch = utils.load_model(
    args.logbase, args.dataset, args.gpt_loadpath, epoch=40, device=args.device
)

## Evaluation loop: for every test function, run `episodes` episodes of up to
## `T` steps. At each step the GPT model is queried through beam search for a
## plan, the first planned action is executed in the environment, and the
## regret reported by the environment is recorded. One regret trace per episode
## is written to `save_path` as a .npy file.

## Output directory for the per-episode regret traces. The path is nested, so
## it is created with makedirs (os.mkdir raises FileNotFoundError when an
## intermediate directory is missing), and only once instead of every episode.
save_path = "./Q_value_Transformer/result/test/TT"
os.makedirs(save_path, exist_ok=True)

for function_type in test_functions:
    env = Environment(T, domain_num, view_num, f_num, function_type, seed, perturb_x=perturb_x)

    ep_regrets = []
    ep_actions = []
    print("Test on {}".format(function_type))
    for e in range(episodes):
        ## each episode resets the environment with its own fixed seed
        observation = env.reset(seed=3200 + e * 10)

        total_reward = 0

        ## observations for rendering
        rollout = [observation.copy()]

        ## previous (tokenized) transitions for conditioning transformer
        context = []
        actions = []
        regrets = []
        for t in tqdm(range(T)):
            ## flatten on every step, not only on planning steps, so that
            ## `update_context` below also gets a flat observation when a
            ## previous plan is reused (plan_freq > 1); on planning steps this
            ## is exactly what the code did before
            observation = observation.reshape(-1)

            if t % args.plan_freq == 0:
                ## concatenate previous transitions and current observations to input to model
                prefix = make_prefix(discretizer, context, observation, args.prefix_context)

                ## sample sequence from model beginning with `prefix`
                sequence = beam_plan(
                    gpt, value_fn, prefix,
                    args.horizon, args.beam_width, args.n_expand, observation_dim, action_dim,
                    discount, args.max_context_transitions, verbose=args.verbose,
                    k_obs=args.k_obs, k_act=args.k_act, cdf_obs=args.cdf_obs, cdf_act=args.cdf_act,
                )
            else:
                ## no replanning on this step: drop the transition already executed
                sequence = sequence[1:]

            ## [ horizon x transition_dim ] convert sampled tokens to continuous trajectory
            sequence_recon = discretizer.reconstruct(sequence)

            ## [ action_dim ] index into sampled trajectory to grab first action
            action = extract_actions(sequence_recon, observation_dim, action_dim, t=0)

            ## execute action in environment
            next_observation, reward, terminal, regret = env.step(action)
            actions.append(int(action))
            total_reward += reward

            rollout.append(next_observation.copy())
            context = update_context(context, discretizer, observation, action, reward, args.max_context_transitions)

            regrets.append(regret)

            if terminal:
                break

            observation = next_observation

        print(f'E: {e} | R: {regrets[-1]}')
        ep_regrets.append(regrets)
        ep_actions.append(actions)

        ## np.save appends the ".npy" extension to the given name
        np.save(f"{save_path}/TT_{Dataset}_{function_type}_{e}", regrets)

    ## NOTE(review): averaging across episodes assumes every episode recorded
    ## the same number of steps; episodes cut short by `terminal` would make
    ## `ep_regrets` ragged -- confirm the environment never terminates early.
    mean_regrets = np.mean(ep_regrets, axis=0)
        #writer = SummaryWriter(log_save_path + '/{}/{}'.format(train_epoch, function_type))
        #os.makedirs(action_trace_save_path + "/{}".format(train_epoch,function_type), exist_ok=True)
        #np.save(action_trace_save_path + "/{}/{}".format(train_epoch,function_type), ep_actions)

        # #writer.close()
        
        # # save result as a json file
        # json_path = join(args.savepath, 'rollout.json')
        # json_data = {'step': t, 'return': total_reward, 'term': terminal, 'gpt_epoch': gpt_epoch}
        # json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True)
