import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import pprint
import json
import torch
import argparse
import math
from environment.wrapper import LMPromptEnv
from environment.DDP_wrapper import DDP_LMPromptEnv
from environment.used.Env_bp_v1 import BP_V1, DDP_BP_V1
from environment.used.Env_cvrp_v1 import CVRP_V1, DDP_CVRP_V1
from environment.used.Env_cvrp_v2 import CVRP_V2, DDP_CVRP_V2
from environment.used.Env_cvrp_v3 import CVRP_V3, DDP_CVRP_V3
from environment.used.Env_tsp_v1 import TSP_V1, DDP_TSP_V1
from environment.used.Env_tsp_v2 import TSP_V2, DDP_TSP_V2
from environment.used.Env_tsp_v3 import TSP_V3, DDP_TSP_V3
from environment.used.Env_tsp_v4 import TSP_V4, DDP_TSP_V4
from environment.used.Env_op_v1 import OP_V1, DDP_OP_V1
from environment.used.Env_op_v2 import OP_V2, DDP_OP_V2
from environment.used.Env_op_v3 import OP_V3, DDP_OP_V3
from environment.used.Env_op_v4 import OP_V4, DDP_OP_V4
from environment.used.Env_pctsp_v1 import PCTSP_V1, DDP_PCTSP_V1
from environment.used.Env_pctsp_v2 import PCTSP_V2, DDP_PCTSP_V2
from environment.used.Env_pctsp_v3 import PCTSP_V3, DDP_PCTSP_V3
from utils.utils import set_seed, create_folder_overwrite_if_exist, str2bool, load_model
from data.used.make_data import *
import numpy as np
from argparse import Namespace
from typing import Optional, Union, List
from tqdm import tqdm
from dataloader.code.dataset import RLFullDataset, BlendableDataset
from evaluate_test.evaluate_utils import eval_policy
from dataloader.code.data_samplers import build_training_data_loader
from environment.wrapper import LMPromptEnv
import setproctitle
setproctitle.setproctitle("GATO-eval@XXX")

def build_dataloader(
    args: Namespace,
    eval_args: Namespace,
    datasets: List[RLFullDataset],
    dataset_weights: List[float],
    envs_problems: Optional[dict] = None,
):
    """Split every dataset into train/val parts and wrap both parts in data loaders.

    Args:
        args: training configuration; only ``args.split`` is read here and it is
            forwarded to ``dataset.split_dataset``.
        eval_args: evaluation configuration. Reads ``batch_size``,
            ``eval_batch_size``, ``batch_num``, ``eval_batch_num`` and
            ``traindata_logger``. ``batch_num`` / ``eval_batch_num`` are
            overwritten in place when they are 0 (see below).
        datasets: datasets to split; each must provide ``split_dataset``.
        dataset_weights: blending weights handed to ``BlendableDataset``.
        envs_problems: unused. Kept (now optional) so that both 4-argument and
            5-argument call sites keep working.

    Returns:
        ``(dataloader_val, dataloader_train, dataset_val, dataset_train)`` --
        note that the validation objects come first.
    """
    # split training set and evaluation set
    datasets_train, datasets_val = [], []
    for dataset in datasets:
        train_part, val_part = dataset.split_dataset(args.split)
        datasets_train.append(train_part)
        datasets_val.append(val_part)

    # build one blended dataset per split
    dataset_train = BlendableDataset(
        datasets_train,
        dataset_weights,
        batch_size=eval_args.batch_size,
        check_visited=False,
        log_data=eval_args.traindata_logger,
    )
    dataset_val = BlendableDataset(
        datasets_val,
        dataset_weights,
        batch_size=eval_args.eval_batch_size,
        check_visited=False,
        log_data=eval_args.traindata_logger,
    )

    # A batch count of 0 means "one full pass": the count is derived from the
    # dataset length and written back to eval_args (callers use it as the
    # progress-bar total), and no per-epoch sample cap is passed to the loader.
    if eval_args.batch_num == 0:
        eval_args.batch_num = math.ceil(len(dataset_train) / eval_args.batch_size)
        sample_num_per_training_epoch = None
    else:
        sample_num_per_training_epoch = eval_args.batch_num * eval_args.batch_size
    dataloader_train = build_training_data_loader(
        eval_args,
        dataset_train,
        epoch_total_samples=sample_num_per_training_epoch,
        is_eval=False,
    )

    if eval_args.eval_batch_num == 0:
        eval_args.eval_batch_num = math.ceil(len(dataset_val) / eval_args.eval_batch_size)
        sample_num_per_evaluation_epoch = None
    else:
        sample_num_per_evaluation_epoch = eval_args.eval_batch_num * eval_args.eval_batch_size
    dataloader_val = build_training_data_loader(
        eval_args,
        dataset_val,
        epoch_total_samples=sample_num_per_evaluation_epoch,
        is_eval=True,
    )

    return dataloader_val, dataloader_train, dataset_val, dataset_train

def get_eval_loss(gato, valid_data_iterator, device, total=None):
    """Return the mean model loss over all batches of ``valid_data_iterator``.

    Args:
        gato: callable model; ``gato(rl_task_input)`` must return four values
            whose second item is the loss (an object exposing ``.item()``).
        valid_data_iterator: iterable of batches; ``batch[0]`` is the model
            input and must expose ``.to(device=...)``.
        device: device handed to ``rl_task_input.to``.
        total: optional number of batches, used only to size the progress bar.
            Previously this was read from a module-level ``args`` object that
            only exists when the file is run as a script.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the iterator yields nothing.
    """
    eval_losses = []
    with tqdm(total=total, desc='Calculating eval loss') as pbar:
        for batch in valid_data_iterator:
            rl_task_input = batch[0]
            # NOTE(review): the return value of .to() is discarded, so this
            # presumably moves the input in place -- confirm for this type.
            rl_task_input.to(device=device)

            # Pure evaluation: no autograd graph is needed (same as get_loss).
            with torch.set_grad_enabled(False):
                _, loss, _, _ = gato(rl_task_input)
            loss = loss.item()
            eval_losses.append(loss)

            pbar.set_postfix({
                'loss': '{:.2f}'.format(loss),
                'ave loss': '{:.2f}'.format(np.mean(eval_losses)),
            })
            pbar.update(1)

    if not eval_losses:
        # Same value np.mean would give for empty input, minus the RuntimeWarning.
        return float('nan')
    return np.array(eval_losses).mean()

def get_loss(gato, dataloader, dataset, eval_args, logger=None, is_train=False, device=None):
    """Return the mean model loss over one pass of ``dataloader``.

    Args:
        gato: model; ``gato(rl_task_input)`` must return four values whose
            second item is the loss.
        dataloader: iterable of batches; ``batch[0]`` is the model input and
            must expose ``.to(device=...)``.
        dataset: only used when data logging is active, via
            ``dataset.get_logged_data()``.
        eval_args: reads ``batch_num`` / ``eval_batch_num`` (progress-bar
            total), ``traindata_logger`` and ``eval_envs``.
        logger: optional mapping ``env_name -> object with log_data(...)``.
        is_train: selects the progress-bar description and total.
        device: device for the inputs. When omitted, the device of the model's
            first parameter is used (CPU if the model has no parameters), so
            the function no longer depends on a module-level ``device`` name.

    Returns:
        Mean of the per-batch losses (``nan`` for an empty dataloader).
    """
    if device is None:
        first_param = next(iter(gato.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    epoch_losses = []
    desc = 'Calculating train loss' if is_train else 'Calculating eval loss'
    total = eval_args.batch_num if is_train else eval_args.eval_batch_num
    with tqdm(total=total, desc=desc) as pbar:
        for batch in dataloader:
            rl_task_input = batch[0]
            # NOTE(review): the return value of .to() is discarded, so this
            # presumably moves the input in place -- confirm for this type.
            rl_task_input.to(device=device)

            with torch.set_grad_enabled(False):
                _, loss, _, _ = gato(rl_task_input)
            loss_value = loss.item()
            epoch_losses.append(loss_value)

            pbar.set_postfix({
                'loss': '{:.2f}'.format(loss_value),
                'ave loss (latest 20)': '{:.2f}'.format(np.array(epoch_losses[-20:]).mean()),
            })
            pbar.update()

            # log train data if necessary
            if logger is not None and eval_args.traindata_logger:
                logged_data = dataset.get_logged_data()
                for env_name in eval_args.eval_envs:
                    # NOTE(review): is_train=False and seed=42 are passed even
                    # when is_train is True -- confirm this is intended.
                    logger[env_name].log_data(logged_data, seed=42, is_train=False)

    return np.array(epoch_losses).mean()

# Registry of supported evaluation environments:
#   env name -> (name of the data-loading function, env class, DDP env class,
#                constructor keyword for the problem size,
#                attribute of the loaded config holding that size)
# The data-loading functions are referenced by name and resolved lazily, so a
# missing loader only fails when its environment is actually requested.
_ENV_SPECS = {
    'Env_BP_V1': ('get_bp_data_v1', BP_V1, DDP_BP_V1, 'item_num', 'bp_item_num'),
    'Env_TSP_V1': ('get_tsp_data_v1', TSP_V1, DDP_TSP_V1, 'num_nodes', 'tsp_city_num'),
    'Env_TSP_V2': ('get_tsp_data_v2', TSP_V2, DDP_TSP_V2, 'num_nodes', 'tsp_city_num'),
    'Env_TSP_V3': ('get_tsp_data_v3', TSP_V3, DDP_TSP_V3, 'num_nodes', 'tsp_city_num'),
    'Env_TSP_V4': ('get_tsp_data_v4', TSP_V4, DDP_TSP_V4, 'num_nodes', 'tsp_city_num'),
    'Env_PCTSP_V1': ('get_pctsp_data_v1', PCTSP_V1, DDP_PCTSP_V1, 'node_num', 'pctsp_node_num'),
    'Env_PCTSP_V2': ('get_pctsp_data_v2', PCTSP_V2, DDP_PCTSP_V2, 'node_num', 'pctsp_node_num'),
    'Env_PCTSP_V3': ('get_pctsp_data_v3', PCTSP_V3, DDP_PCTSP_V3, 'node_num', 'pctsp_node_num'),
    'Env_OP_V1': ('get_op_data_v1', OP_V1, DDP_OP_V1, 'node_num', 'op_node_num'),
    'Env_OP_V2': ('get_op_data_v2', OP_V2, DDP_OP_V2, 'node_num', 'op_node_num'),
    'Env_OP_V3': ('get_op_data_v3', OP_V3, DDP_OP_V3, 'node_num', 'op_node_num'),
    'Env_OP_V4': ('get_op_data_v4', OP_V4, DDP_OP_V4, 'node_num', 'op_node_num'),
    'Env_CVRP_V1': ('get_cvrp_data_v1', CVRP_V1, DDP_CVRP_V1, 'node_num', 'cvrp_node_num'),
    'Env_CVRP_V2': ('get_cvrp_data_v2', CVRP_V2, DDP_CVRP_V2, 'node_num', 'cvrp_node_num'),
    'Env_CVRP_V3': ('get_cvrp_data_v3', CVRP_V3, DDP_CVRP_V3, 'node_num', 'cvrp_node_num'),
}

# Settings copied verbatim from the evaluation arguments onto the loaded config.
_FORWARDED_SETTINGS = (
    'use_prefix', 'use_prompt', 'use_mem', 'use_ddp_env', 'traj_type',
    'data_num_tsp', 'data_num_bp', 'data_num_op', 'data_num_pctsp', 'data_num_cvrp',
    'eval_iters_COP', 'problem_batch_size', 'problem_batch_num', 'use_default_policy_obj',
)


def _experiment_name(relative_ckpt_path):
    """Return the first path component of a checkpoint path relative to ``ckpt/``."""
    return relative_ckpt_path.split('/', 1)[0]


def _write_report(data_name, res_return, res_safe, res_time, stream):
    """Pretty-print the averaged metrics of one dataset to ``stream``."""
    printer = pprint.PrettyPrinter(indent=4, stream=stream)
    print('='*20+f' {data_name} '+'='*20, file=stream)
    for label, values in (('return', res_return), ('safe', res_safe), ('time', res_time)):
        print(f'{label}:\t', end='', file=stream)
        printer.pprint({k: round(float(v), 3) for k, v in values.items()})
    print(file=stream)


if __name__ == "__main__":
    # ------------------------------------------------------------------ CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("--bp-max-item-num", type=int, default=10)
    parser.add_argument("--tsp-city-num", type=int, default=20)
    parser.add_argument("--pctsp-node-num", type=int, default=20)
    parser.add_argument("--op-node-num", type=int, default=20)
    parser.add_argument("--cvrp-node-num", type=int, default=20)
    parser.add_argument("--data-num-tsp", type=int, default=0)
    parser.add_argument("--data-num-01bp", type=int, default=0)
    parser.add_argument("--data-num-pctsp", type=int, default=0)
    parser.add_argument("--data-num-op", type=int, default=0)
    parser.add_argument("--data-num-cvrp", type=int, default=0)
    parser.add_argument("--traj-type", type=str, default="all")
    parser.add_argument("--ckpt-path", type=str, default=None)
    parser.add_argument("--snapshot-path", type=str, default=None)
    parser.add_argument("--seeds", nargs='+', type=int, default=[42, 43])
    parser.add_argument("--eval-envs", nargs='+', type=str, default=['Env_01BP_V1', ])
    parser.add_argument("--eval-problem-set", type=str, default='problem', choices=['problem', 'train_problem'])
    parser.add_argument("--eval-iters-COP", type=int, default=0)   # 0 means: use the whole evaluation problem set
    parser.add_argument("--eval-iters-RL", type=int, default=5)
    parser.add_argument("--batch-num", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--eval-batch-num", type=int, default=0)
    parser.add_argument("--eval-batch-size", type=int, default=100)
    parser.add_argument("--regen-times", type=int, default=5)
    parser.add_argument("--policy-logger", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--traindata-logger", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--use-prompt", type=str2bool, default=False)
    parser.add_argument("--use-mem", type=str2bool, default=True)
    parser.add_argument(
        "--dataloader-type",
        type=str,
        default="random",
        choices=["sequential", "random"],
        help="Fetch data sequentially or out-of-order",
    )
    parser.add_argument(
        "--prompt-strategy",
        type=str,
        default="stochastic_subseq;moving_prompt",
        choices=[
            "stochastic_timestep;moving_prompt",
            "stochastic_subseq;moving_prompt",
            "stochastic_timestep;fixed_prompt",
            "stochastic_subseq;fixed_prompt",
        ],
    )
    eval_args = parser.parse_args()

    # ------------------------------------------- hard-coded experiment setup
    # NOTE: everything below deliberately overrides the command-line values.
    eval_args.ckpt_path = 'Cmp-TSP20(embv2-15w-ours)_500_768_8|8_10/best/0.97_seed42_epoch68.pt'
    eval_args.eval_envs = ['Env_TSP_V3',]

    eval_args.data_num_bp = 1000
    eval_args.data_num_tsp = 1000
    eval_args.data_num_op = 1000
    eval_args.data_num_cvrp = 1000
    eval_args.data_num_pctsp = 1000

    # settings for measuring the training / validation / test loss
    eval_args.traj_type = 'all'
    eval_args.eval_problem_set = 'problem'  # train_problem or problem
    eval_args.batch_size = 40
    eval_args.batch_num = 20
    eval_args.eval_batch_size = 50      # total batch size
    eval_args.eval_batch_num = 20
    eval_args.test_batch_size = 40      # total batch size
    eval_args.test_batch_num = 15

    # settings for the evaluation problems
    eval_args.problem_batch_size = 500  # batch size on a single GPU
    eval_args.problem_batch_num = 20
    eval_args.eval_iters_COP = eval_args.problem_batch_size * eval_args.problem_batch_num

    # control flags
    eval_args.policy_logger = False
    eval_args.traindata_logger = False
    eval_args.use_prefix = True
    eval_args.use_prompt = False
    eval_args.use_mem = False
    eval_args.use_ddp_env = True
    eval_args.use_default_policy_obj = False
    eval_args.check_loss = False

    # ------------------------------------------------------------ validation
    # Explicit raises instead of asserts so the checks survive ``python -O``.
    if (eval_args.ckpt_path is None) == (eval_args.snapshot_path is None):
        raise ValueError('exactly one of ckpt_path / snapshot_path must be set')
    if eval_args.dataloader_type != "random":
        # this script only supports a single GPU
        raise ValueError('this script only supports dataloader_type == "random"')
    if not eval_args.seeds:
        raise ValueError('at least one seed is required')

    source_path = eval_args.ckpt_path if eval_args.ckpt_path is not None else eval_args.snapshot_path
    exp_name = _experiment_name(source_path)
    eval_args.ckpt_performance_path = f'{base_path}/ckpt/{exp_name}/performance'

    # ------------------------------------------------------------ load model
    # The checkpoint is loaded exactly once (the original loaded it twice).
    args, gato, current_epoch = load_model(
        config_path=f'{base_path}/ckpt/{exp_name}/config.json',
        ckpt_path=None if eval_args.ckpt_path is None else f'{base_path}/ckpt/{eval_args.ckpt_path}',
        snapshot_path=None if eval_args.snapshot_path is None else f'{base_path}/ckpt/{eval_args.snapshot_path}'
    )
    # ``eval_args`` only has ``seeds`` (a list); the first one seeds the config.
    args.seed = eval_args.seeds[0]
    for setting_name in _FORWARDED_SETTINGS:
        setattr(args, setting_name, getattr(eval_args, setting_name))

    # Inputs must live where the model's weights live. Kept as a module-level
    # name because the loss helpers of this file may look it up globally.
    # NOTE(review): assumes load_model already placed the model on the intended device.
    _first_param = next(iter(gato.parameters()), None)
    device = _first_param.device if _first_param is not None else torch.device('cpu')

    # ------------------------------- load prompt datasets and eval problems
    datasets_train, datasets_prompt, datasets_prompt_ddp, envs_problems = [], [], [], {}
    for env_name in eval_args.eval_envs:
        if env_name not in _ENV_SPECS:
            raise NotImplementedError(f'unsupported evaluation env: {env_name}')
        get_data = globals()[_ENV_SPECS[env_name][0]]

        envs_problems[env_name] = get_data(args, data_type=eval_args.eval_problem_set, is_train=False)
        dataset, ddp_dataset = get_data(
            args,
            data_type='prompt',
            get_dataset=not eval_args.use_ddp_env,
            get_ddp_dataset=eval_args.use_ddp_env,
        )
        datasets_prompt += dataset
        datasets_prompt_ddp += ddp_dataset
        # Training data is only needed for the optional loss check.
        if eval_args.check_loss:
            datasets_train += get_data(args, data_type='train')[0]

    # NOTE(review): weights/env names come from the training config and may
    # cover more environments than ``eval_args.eval_envs`` -- confirm.
    weight = json.loads(args.dataset_weights)
    dataset_weights = list(weight.values())
    args.eval_env_names = list(weight.keys())
    # Use whichever prompt-dataset list is actually populated for this run.
    active_prompt_datasets = datasets_prompt_ddp if eval_args.use_ddp_env else datasets_prompt
    args.eval_dataset_names = [dataset.dataset_name for dataset in active_prompt_datasets]

    # -------------------------------------------- build envs for evaluation
    eval_prompt_strat = eval_args.prompt_strategy.split(";")[-1]     # moving_prompt
    envs = []
    # presumably each env contributes exactly one prompt dataset, as the
    # original zip of builders and datasets implied -- TODO confirm
    for env_name, prompt_dataset in zip(eval_args.eval_envs, active_prompt_datasets):
        _, env_cls, ddp_env_cls, size_kwarg, size_attr = _ENV_SPECS[env_name]
        size_kwargs = {size_kwarg: getattr(args, size_attr)}
        if eval_args.use_ddp_env:
            raw_env = ddp_env_cls(batch_size=eval_args.problem_batch_size, **size_kwargs)
            envs.append(DDP_LMPromptEnv(raw_env, args, prompt_dataset, eval_prompt_strat))
        else:
            raw_env = env_cls(**size_kwargs)
            envs.append(LMPromptEnv(raw_env, args, prompt_dataset, eval_prompt_strat))

    # build episode render if we need to check generated episodes during eval
    logger = None
    if eval_args.policy_logger:
        logger = {env_name: EXAMPLE_RENDER[env_name]() for env_name in args.eval_env_names}
        for env_name, dataset_name in zip(args.eval_env_names, args.eval_dataset_names):
            create_folder_overwrite_if_exist(f'{base_path}/visualize/eval/log/{env_name}/{dataset_name}')
    create_folder_overwrite_if_exist(eval_args.ckpt_performance_path)

    # ------------------------------------------------- optional loss check
    if eval_args.check_loss:
        (
            dataloader_val,
            dataloader_train,
            dataset_val,
            dataset_train
        ) = build_dataloader(args, eval_args, datasets_train, dataset_weights, envs_problems)

        set_seed(42)
        gato.eval()
        gato.transformer.same_length = False        # use normal context length when loss calculating (TransformerXL back bone)
        with torch.inference_mode():
            train_loss = get_loss(gato, dataloader_train, dataset_train, eval_args, logger, is_train=True)
            eval_loss = get_loss(gato, dataloader_val, dataset_val, eval_args, is_train=False)
        print(f'train loss: {train_loss:.4f}\teval loss: {eval_loss:.4f}')

    # ------------------------------ figure out rollout count for each env
    eval_iters_dict = {
        'COPTask': eval_args.eval_iters_COP,
        'RLTask': eval_args.eval_iters_RL
    }
    eval_iters = {}
    for env in envs:
        problem_num = len(envs_problems[env.env_name].answer_list)
        requested = eval_iters_dict[env.task_type]
        # 0 means "use every available problem"; otherwise cap at what exists.
        eval_iters[env.env_name] = problem_num if requested == 0 else min(problem_num, requested)

    # ------------------------------------------------------ evaluate policy
    # Each setting is the set of keyword arguments forwarded to eval_policy.
    eval_setting = {
        'greedy_cst': dict(sample_action=False, hard_action_constraint=True),
        # Other settings, disabled by default:
        # 'sample_cst': dict(sample_action=True, hard_action_constraint=True),
        # f'vote{eval_args.regen_times}_cst': dict(sample_action=True, hard_action_constraint=True, regen_times=eval_args.regen_times),
        # 'greedy_free': dict(sample_action=False, hard_action_constraint=False),
        # 'sample_free': dict(sample_action=True, hard_action_constraint=False),
        # f'vote{eval_args.regen_times}_free': dict(sample_action=True, hard_action_constraint=False, regen_times=eval_args.regen_times),
    }
    result = {
        setting: {
            metric: {data_name: [] for data_name in args.eval_dataset_names}
            for metric in ('return', 'safe', 'time')
        } for setting in eval_setting
    }

    gato.eval()
    gato.transformer.same_length = args.use_mem         # use fixed context length when rollout with mem (TransformerXL back bone)
    with torch.inference_mode():
        for seed in eval_args.seeds:
            set_seed(seed)
            for setting, policy_kwargs in eval_setting.items():
                epi_return, epi_safe, epi_time = eval_policy(
                    args, envs, gato, eval_iters, logger, seed, envs_problems, device,
                    **policy_kwargs
                )
                for data_name in args.eval_dataset_names:
                    result[setting]['return'][data_name].extend(epi_return[data_name])
                    result[setting]['safe'][data_name].extend(epi_safe[data_name])
                    result[setting]['time'][data_name].extend(epi_time[data_name])
            print()

    # ------------------------------------------------- save the eval result
    result_dir = f'{base_path}/visualize/eval/result'
    create_folder_overwrite_if_exist(result_dir)
    with open(f'{result_dir}/config.json', 'w') as f:
        f.write(json.dumps(vars(eval_args), indent=4))

    # The report goes to the console and to result.txt through an explicit
    # stream argument, so sys.stdout is never swapped out.
    with open(f'{result_dir}/result.txt', 'w') as report_file:
        for data_name in args.eval_dataset_names:
            res_return = {setting: np.mean(ret['return'][data_name]) for setting, ret in result.items()}
            res_safe = {setting: np.mean(ret['safe'][data_name]) for setting, ret in result.items()}
            res_time = {setting: np.mean(ret['time'][data_name]) for setting, ret in result.items()}

            _write_report(data_name, res_return, res_safe, res_time, sys.stdout)
            _write_report(data_name, res_return, res_safe, res_time, report_file)
            # NOTE: a matplotlib bar-chart export per dataset used to be sketched
            # here as disabled code; re-add it at this point if it is needed.