import numpy as np
import torch
import argparse
import pickle
import random
import sys
import os
from decision_transformer.evaluation.evaluate_episodes import evaluate_episode, evaluate_episode_rtg
from decision_transformer.models.decision_transformer import DecisionTransformer
from decision_transformer.models.mlp_bc import MLPBCModel
from decision_transformer.training.act_trainer import ActTrainer
from decision_transformer.training.seq_trainer import SequenceTrainer
from Environment.env import Environment
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# Prefer the second GPU (index 1) when the host actually has one. On
# single-GPU or CPU-only machines the unconditional call would raise at import
# time, so fall back to PyTorch's default device instead.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

# Device string used for every tensor and for the model in this script.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def _load_state_statistics(dataset_path):
    """Return the per-feature mean and standard deviation of the dataset states.

    Args:
        dataset_path: Path to a pickled sequence of trajectory dicts, each
            storing its observations under the ``'observations'`` key.

    Returns:
        A ``(state_mean, state_std)`` tuple computed over all observations
        concatenated along axis 0; ``1e-6`` is added to the std.
    """
    # NOTE: pickle.load can execute arbitrary code from the file; only point
    # this at dataset files you trust.
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)
    observations = np.concatenate([path['observations'] for path in trajectories], axis=0)
    return np.mean(observations, axis=0), np.std(observations, axis=0) + 1e-6


@torch.no_grad()
def _run_episode(model, env, seed, state_dim, act_dim, max_ep_len, target_return, ep_writer):
    """Roll out one evaluation episode and return the list of per-step regrets.

    Args:
        model: Model exposing ``get_action(states, actions, rewards,
            target_return, timesteps)``.
        env: Environment exposing ``reset(seed=..., new_ls=True)`` and
            ``step(action)`` returning ``(state, reward, done, regret)``.
        seed: Seed forwarded to ``env.reset``.
        state_dim: Length of one flattened state vector.
        act_dim: Length of one action vector.
        max_ep_len: Maximum number of environment steps.
        target_return: Scalar return conditioning value given to the model.
        ep_writer: TensorBoard writer that receives one ``'Regret'`` scalar
            per step.

    Returns:
        The regrets reported by ``env.step``, one entry per executed step.
    """
    state = env.reset(seed=seed, new_ls=True)
    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    # NOTE(review): the return-to-go and timestep tensors are created once and
    # never extended inside the loop, so the model sees a single entry for
    # each on every step. Kept as-is; confirm this matches how the checkpoint
    # was trained.
    return_to_go = torch.tensor(target_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    ep_regret = []
    for t in range(max_ep_len):
        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = model.get_action(
            states.to(dtype=torch.float32),
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            return_to_go.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )
        actions[-1] = action
        # .item() extracts the single predicted value explicitly; calling
        # int() directly on a 1-element array is deprecated in recent NumPy.
        action_index = int(action.detach().cpu().numpy().item())

        state, reward, done, regret = env.step(action_index)
        ep_regret.append(regret)
        ep_writer.add_scalar('Regret', regret, t)

        # Cast explicitly so the state history keeps a single dtype.
        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        if done:
            break
    return ep_regret


def test(variant):
    """Evaluate a pretrained Decision Transformer on a set of test functions.

    For every checkpoint and every test function this runs
    ``variant['num_eval_episodes']`` episodes, logs the per-step regret of each
    episode and the mean regret across episodes to TensorBoard, and saves each
    episode's regret list as a ``.npy`` file.

    Args:
        variant: Mapping of run settings. Keys read here: ``'f_num'``,
            ``'domain_size'``, ``'K'``, ``'num_eval_episodes'``,
            ``'embed_dim'``, ``'n_layer'``, ``'n_head'``,
            ``'activation_function'`` and ``'dropout'``.
    """
    dataset_path = 'Decision_Transformer/decision_transformer/envs/data/Mix.pkl'
    # NOTE(review): these normalization statistics are computed but never
    # applied to the states fed to the model below. Kept so the dataset is
    # still loaded as before; confirm whether evaluation should normalize.
    state_mean, state_std = _load_state_statistics(dataset_path)

    state_dim = variant['domain_size']*2*variant['f_num']+variant['f_num']+1
    act_dim = 1

    K = variant['K']
    max_ep_len = 100
    num_eval_episodes = variant['num_eval_episodes']

    model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            max_ep_len=max_ep_len,
            hidden_size=variant['embed_dim'],
            n_layer=variant['n_layer'],
            n_head=variant['n_head'],
            n_inner=4*variant['embed_dim'],
            activation_function=variant['activation_function'],
            action_tanh=False,
            n_positions=1024,
            resid_pdrop=variant['dropout'],
            attn_pdrop=variant['dropout'],
        )

    log_save_path = 'Decision_Transformer/tb_record'
    save_path = "./Q_value_Transformer/result/test/DT"
    # The result directory is nested several levels deep; makedirs creates the
    # missing parents, which os.mkdir would not.
    os.makedirs(save_path, exist_ok=True)

    for train_ep in [9]:
        target_return = 0
        pretrain_path = 'Decision_Transformer/preTrained/BO_2/{}.pth'.format(train_ep)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(pretrain_path, map_location=device))
        model.to(device)
        # Evaluation mode disables dropout so rollouts are not perturbed.
        model.eval()
        print('Success load model from {}'.format(pretrain_path))

        for test_f in ["ARa", "AR", "BC", "DR", "RBF_0.2", "RBF_0.05", "matern52_0.2", "matern52_0.05"]:

            env = Environment(T=100, domain_num= variant['domain_size'], f_num= variant['f_num'], function_type= test_f, seed=0)
            regrets = []
            print('Test on function {}'.format(test_f))

            for e in tqdm(range(num_eval_episodes)):
                ep_writer = SummaryWriter(log_save_path + '/DT_{}/{}/{}'.format(train_ep, test_f, e))
                try:
                    ep_regret = _run_episode(
                        model, env, 3100+e*10, state_dim, act_dim,
                        max_ep_len, target_return, ep_writer,
                    )
                finally:
                    # Close the writer even if the rollout raises.
                    ep_writer.close()
                regrets.append(ep_regret)
                np.save(save_path + "/DT_random_{}_{}".format(test_f, e), ep_regret)

            # NOTE(review): averaging along axis 0 assumes every episode ran
            # the same number of steps — confirm `done` cannot end one early.
            mean_regrets = np.mean(regrets, axis = 0)
            print(mean_regrets)
            writer = SummaryWriter(log_save_path + '/DT_{}/{}/mean'.format(train_ep, test_f))
            try:
                for i, reg in enumerate(mean_regrets.tolist()):
                    writer.add_scalar('Regret', reg, i)
            finally:
                writer.close()
            



def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` in argparse treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_arg_parser():
    """Build the command-line parser for this evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace, passed through
        ``vars()``, is the ``variant`` mapping given to ``test``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='BO')
    parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--K', type=int, default=100)
    parser.add_argument('--pct_traj', type=float, default=1.)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--model_type', type=str, default='dt')  # dt for decision transformer, bc for behavior cloning
    parser.add_argument('--embed_dim', type=int, default=500)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10)
    parser.add_argument('--num_eval_episodes', type=int, default=100)
    parser.add_argument('--max_iters', type=int, default=20)
    parser.add_argument('--num_steps_per_iter', type=int, default=5)
    parser.add_argument('--device', type=str, default='cuda')
    # Accepts `-w`, `-w True` and `-w False`; a bare `-w` means True.
    parser.add_argument('--log_to_wandb', '-w', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--f_num', type=int, default=2)
    parser.add_argument('--domain_size', type=int, default=1000)
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    test(variant=vars(args))