import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import torch
import json
import setproctitle
import argparse
from environment.wrapper import LMPromptEnv
from environment.DDP_wrapper import DDP_LMPromptEnv
from environment.used.Env_bp_v1 import BP_V1, DDP_BP_V1
from environment.used.Env_bp_v2 import BP_V2, DDP_BP_V2
from environment.used.Env_cvrp_v1 import CVRP_V1, DDP_CVRP_V1
from environment.used.Env_cvrp_v2 import CVRP_V2, DDP_CVRP_V2
from environment.used.Env_cvrp_v3 import CVRP_V3, DDP_CVRP_V3
from environment.used.Env_atsp_v1 import ATSP_V1, DDP_ATSP_V1
from environment.used.Env_tsp_v1 import TSP_V1, DDP_TSP_V1
from environment.used.Env_tsp_v2 import TSP_V2, DDP_TSP_V2
from environment.used.Env_tsp_v3 import TSP_V3, DDP_TSP_V3
from environment.used.Env_tsp_v4 import TSP_V4, DDP_TSP_V4
from environment.used.Env_op_v1 import OP_V1, DDP_OP_V1
from environment.used.Env_op_v2 import OP_V2, DDP_OP_V2
from environment.used.Env_op_v3 import OP_V3, DDP_OP_V3
from environment.used.Env_op_v4 import OP_V4, DDP_OP_V4
from environment.used.Env_pctsp_v1 import PCTSP_V1, DDP_PCTSP_V1
from environment.used.Env_pctsp_v2 import PCTSP_V2, DDP_PCTSP_V2
from environment.used.Env_pctsp_v3 import PCTSP_V3, DDP_PCTSP_V3
from environment.used.Env_spctsp_v2 import SPCTSP_V2, DDP_SPCTSP_V2
from environment.used.Env_spctsp_v3 import SPCTSP_V3, DDP_SPCTSP_V3
from utils.utils import create_folder_overwrite_if_exist, str2bool, load_model
from evaluate_test.DDP_evaluater import Evaluater
from data.used.make_data import *
setproctitle.setproctitle("GATO-eval@XXX")

# Libraries used for DDP (DistributedDataParallel) parallelism
from torch.utils.data.distributed import DistributedSampler         
from torch.nn.parallel import DistributedDataParallel as DDP         
from torch.distributed import init_process_group, destroy_process_group 

def ddp_setup():
    """Initialise the NCCL process group and select this process's GPU.

    ``MASTER_ADDR`` and ``MASTER_PORT`` are only filled in when they are not
    already present in the environment, so rendezvous settings exported by
    the launcher are respected instead of being overwritten.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
    """
    # Single-machine experiment: fall back to localhost when nothing is set.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    # Fallback rendezvous port; any free port works.
    os.environ.setdefault("MASTER_PORT", "12358")
    init_process_group(backend="nccl")
    # Bind this process to the GPU whose index equals its local rank.
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

def get_args_ready(WORLD_SIZE):
    """Parse the command line and apply the hard-coded evaluation preset.

    NOTE: many values parsed from the command line (``ckpt_path``,
    ``eval_envs``, ``traj_type``, the batch settings, the control flags, ...)
    are overwritten unconditionally by the preset in the second half of this
    function. Edit the preset to change them.

    Args:
        WORLD_SIZE: number of DDP processes; used to size ``eval_iters_COP``.

    Returns:
        The populated ``argparse.Namespace``, including the derived
        ``ckpt_performance_path``.

    Raises:
        AssertionError: unless exactly one of ``ckpt_path`` /
            ``snapshot_path`` is set, or if ``dataloader_type`` is not
            ``'DDP'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-distribution", type=str, default='uniform')
    parser.add_argument("--data-num-tsp", type=int, default=0)
    parser.add_argument("--data-num-bp", type=int, default=0)
    parser.add_argument("--data-num-pctsp", type=int, default=0)
    parser.add_argument("--data-num-op", type=int, default=0)
    parser.add_argument("--data-num-cvrp", type=int, default=0)
    parser.add_argument("--traj-type", type=str, default="all")
    parser.add_argument("--ckpt-path", type=str, default=None)
    parser.add_argument("--snapshot-path", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--eval-envs", nargs='+', type=str, default=['Env_01BP_V1'])
    # Sequences (not sets) keep the help/usage ordering deterministic.
    parser.add_argument(
        "--eval-problem-set",
        type=str,
        default='problem',
        choices=('problem', 'train_problem'),
    )
    # 0 means: use the whole evaluation problem set.
    parser.add_argument("--eval-iters-COP", type=int, default=0)
    parser.add_argument("--eval-iters-RL", type=int, default=5)
    parser.add_argument("--batch-num", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--eval-batch-num", type=int, default=0)
    parser.add_argument("--eval-batch-size", type=int, default=100)
    parser.add_argument("--test-batch-num", type=int, default=0)
    parser.add_argument("--test-batch-size", type=int, default=100)
    parser.add_argument("--problem-batch-size", type=int, default=100)
    parser.add_argument("--problem-batch-num", type=int, default=20)
    parser.add_argument("--regen-times", type=int, default=5)
    parser.add_argument("--policy-logger", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--traindata-logger", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--use-ddp-env", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--use-prefix", type=str2bool, default=False)
    parser.add_argument("--use-prompt", type=str2bool, default=True)
    parser.add_argument("--use-mem", type=str2bool, default=True)
    parser.add_argument("--use-default-policy-obj", type=str2bool, default=False)
    parser.add_argument("--check-loss", type=str2bool, default=False)
    parser.add_argument(
        "--dataloader-type",
        type=str,
        default="DDP",
        choices=["sequential", "random", "DDP"],
        help="Fetch data sequentially or out-of-order",
    )
    parser.add_argument(
        "--prompt-strategy",
        type=str,
        default="stochastic_subseq;moving_prompt",
        choices=(
            "stochastic_timestep;moving_prompt",
            "stochastic_subseq;moving_prompt",
            "stochastic_timestep;fixed_prompt",
            "stochastic_subseq;fixed_prompt",
        ),
    )
    eval_args = parser.parse_args()

    # ------------------------------------------------------------------
    # Alternative presets, kept for reference (uncomment one group to use).
    # ------------------------------------------------------------------
    #eval_args.ckpt_path = "ScaleLaw20(embv1-4x15w-llama)_800_1024_16|16_12/best/0.97_seed42_epoch220.pt"
    #eval_args.snapshot_path = "ScaleLaw20(embv1-4x15w-llama)_800_1024_16|16_12/snapshot_seed42.pt"
    #eval_args.eval_envs = ['Env_PCTSP_V1', 'Env_TSP_V2', 'Env_OP_V2', 'Env_CVRP_V1']

    #eval_args.ckpt_path = 'Cmp-BP20(embv2-15w-ours)_500_768_8|8_10/best/0.97_seed42_epoch72.pt'
    #eval_args.ckpt_path = 'FT-BP20(embv2-15w)_500_768_8|8_10/best/0.98_seed42_epoch64.pt'
    #eval_args.eval_envs = ['Env_BP_V1',]

    #eval_args.ckpt_path = 'Cmp-TSP20(embv2-15w-ours)_500_768_8|8_10/best/0.97_seed42_epoch68.pt'
    #eval_args.ckpt_path = 'Cmp-TSP20(embv2-25w-ours)_500_768_8|8_10/best/0.98_seed42_epoch110.pt'
    #eval_args.ckpt_path = 'FT-TSP20(embv2-15w)_500_768_8|8_10/best/0.97_seed42_epoch60.pt'
    #eval_args.eval_envs = ['Env_TSP_V3',]

    #eval_args.ckpt_path = 'Cmp-PCTSP20(embv2-15w-ours)_500_768_8|8_10/best/0.86_seed42_epoch50.pt'
    #eval_args.ckpt_path = 'Cmp-PCTSP20(embv2-25w-ours)_500_768_8|8_10/best/0.9_seed42_epoch70.pt'
    #eval_args.ckpt_path = 'Cmp-PCTSP20(embv2-50w-ours)_500_768_8|8_10/best/0.92_seed42_epoch110.pt'
    #eval_args.ckpt_path = 'FT-PCTSP20(embv2-15w)_500_768_8|8_10/best/0.85_seed42_epoch40.pt'
    #eval_args.eval_envs = ['Env_PCTSP_V3',]

    #eval_args.ckpt_path = 'Cmp-CVRP20(embv2-15w-ours)_500_768_8|8_10/best/0.87_seed42_epoch48.pt'
    #eval_args.ckpt_path = 'FT-CVRP20(embv2-15w)_500_768_8|8_10/best/0.88_seed42_epoch44.pt'
    #eval_args.eval_envs = ['Env_CVRP_V3',]

    #eval_args.ckpt_path = 'Cmp-OP20(embv2-15w-ours)_500_768_8|8_10/best/0.78_seed42_epoch40.pt'
    #eval_args.ckpt_path = 'FT-OP20(embv2-15w)_500_768_8|8_10/best/0.88_seed42_epoch36.pt'
    #eval_args.eval_envs = ['Env_OP_V4',]

    #eval_args.ckpt_path = 'Cmp-7All20(embv2-7x20w-ours)_500_768_8|8_10/best/0.95_seed42_epoch112.pt'
    #eval_args.ckpt_path = 'FT-7All20(embv2-7x20w)_500_768_8|8_10/best/0.96_seed42_epoch120.pt'
    #eval_args.eval_envs = ["Env_BP_V1","Env_ATSP_V1","Env_SPCTSP_V3",'Env_OP_V4','Env_PCTSP_V3','Env_TSP_V3','Env_CVRP_V3',]
    #eval_args.eval_envs = ['Env_CVRP_V3',]
    #eval_args.dataset_weights = '{"Env_CVRP_V3":1}'

    # ------------------------------------------------------------------
    # Active preset: overrides whatever was given on the command line.
    # ------------------------------------------------------------------
    eval_args.ckpt_path = 'Cmp-6All20(embv2-6x20w-GATO)_1000_768_8_10/best/0.958_seed42_epoch100.pt'
    eval_args.eval_envs = ["Env_BP_V2", "Env_SPCTSP_V2", "Env_TSP_V2", 'Env_PCTSP_V1', 'Env_OP_V2', 'Env_CVRP_V1']
    #eval_args.eval_envs = ['Env_CVRP_V1',]
    #eval_args.dataset_weights = '{"Env_CVRP_V1":1}'
    #eval_args.eval_envs = ['Env_TSP_V2',]
    #eval_args.dataset_weights = '{"Env_TSP_V2":1}'

    # Same data count for every problem family (spctsp/atsp have no CLI flag).
    common_data_num = 1000
    for problem in ('spctsp', 'atsp', 'bp', 'tsp', 'op', 'cvrp', 'pctsp'):
        setattr(eval_args, f'data_num_{problem}', common_data_num)

    # Settings for evaluating the train / validation / test losses.
    eval_args.traj_type = 'all'
    eval_args.eval_problem_set = 'problem'  # train_problem or problem
    eval_args.batch_size = 40
    eval_args.batch_num = 20
    eval_args.eval_batch_size = 200     # total batch_size
    eval_args.eval_batch_num = 20
    eval_args.test_batch_size = 40      # total batch_size
    eval_args.test_batch_num = 15

    # Settings for evaluating on problems.
    eval_args.problem_batch_size = 100  # batch_size on a single GPU
    eval_args.problem_batch_num = 10
    eval_args.eval_iters_COP = WORLD_SIZE * eval_args.problem_batch_size * eval_args.problem_batch_num

    # Control flags.
    eval_args.policy_logger = False
    eval_args.traindata_logger = False
    eval_args.use_mem = False
    eval_args.use_ddp_env = True
    eval_args.use_default_policy_obj = False
    eval_args.check_loss = False

    assert (eval_args.ckpt_path is None) ^ (eval_args.snapshot_path is None), \
        'exactly one of ckpt_path / snapshot_path must be set'
    model_path = eval_args.ckpt_path if eval_args.ckpt_path is not None else eval_args.snapshot_path
    # The experiment name is the first path component. ``split`` (unlike
    # slicing with ``find``) does not drop a character when there is no '/'.
    exp_profile = model_path.split('/', 1)[0]
    eval_args.ckpt_performance_path = f'{base_path}/ckpt/{exp_profile}/performance'
    assert eval_args.dataloader_type == 'DDP', 'only the DDP dataloader is supported here'
    return eval_args

# Wiring for every supported evaluation environment. Each value is a
# zero-argument callable returning
#   (data loader, DDP env class, basic env class,
#    size keyword of the env constructor, attribute of ``args`` holding the size,
#    whether the problem loader is called with ``is_train=False``).
# The entries are resolved lazily, so only the environments that are actually
# requested need their loader to be available in this module's namespace.
_ENV_SPECS = {
    'Env_BP_V1': lambda: (get_bp_data_v1, DDP_BP_V1, BP_V1, 'item_num', 'bp_item_num', True),
    'Env_BP_V2': lambda: (get_bp_data_v2, DDP_BP_V2, BP_V2, 'item_num', 'bp_item_num', True),
    'Env_ATSP_V1': lambda: (get_atsp_data_v1, DDP_ATSP_V1, ATSP_V1, 'num_nodes', 'atsp_city_num', False),
    'Env_TSP_V1': lambda: (get_tsp_data_v1, DDP_TSP_V1, TSP_V1, 'num_nodes', 'tsp_city_num', True),
    'Env_TSP_V2': lambda: (get_tsp_data_v2, DDP_TSP_V2, TSP_V2, 'num_nodes', 'tsp_city_num', True),
    'Env_TSP_V3': lambda: (get_tsp_data_v3, DDP_TSP_V3, TSP_V3, 'num_nodes', 'tsp_city_num', True),
    'Env_TSP_V4': lambda: (get_tsp_data_v4, DDP_TSP_V4, TSP_V4, 'num_nodes', 'tsp_city_num', True),
    'Env_PCTSP_V1': lambda: (get_pctsp_data_v1, DDP_PCTSP_V1, PCTSP_V1, 'node_num', 'pctsp_node_num', True),
    'Env_PCTSP_V2': lambda: (get_pctsp_data_v2, DDP_PCTSP_V2, PCTSP_V2, 'node_num', 'pctsp_node_num', True),
    'Env_PCTSP_V3': lambda: (get_pctsp_data_v3, DDP_PCTSP_V3, PCTSP_V3, 'node_num', 'pctsp_node_num', True),
    'Env_SPCTSP_V2': lambda: (get_spctsp_data_v2, DDP_SPCTSP_V2, SPCTSP_V2, 'node_num', 'spctsp_node_num', False),
    'Env_SPCTSP_V3': lambda: (get_spctsp_data_v3, DDP_SPCTSP_V3, SPCTSP_V3, 'node_num', 'spctsp_node_num', False),
    'Env_OP_V1': lambda: (get_op_data_v1, DDP_OP_V1, OP_V1, 'node_num', 'op_node_num', True),
    'Env_OP_V2': lambda: (get_op_data_v2, DDP_OP_V2, OP_V2, 'node_num', 'op_node_num', True),
    'Env_OP_V3': lambda: (get_op_data_v3, DDP_OP_V3, OP_V3, 'node_num', 'op_node_num', True),
    'Env_OP_V4': lambda: (get_op_data_v4, DDP_OP_V4, OP_V4, 'node_num', 'op_node_num', True),
    'Env_CVRP_V1': lambda: (get_cvrp_data_v1, DDP_CVRP_V1, CVRP_V1, 'node_num', 'cvrp_node_num', True),
    'Env_CVRP_V2': lambda: (get_cvrp_data_v2, DDP_CVRP_V2, CVRP_V2, 'node_num', 'cvrp_node_num', True),
    'Env_CVRP_V3': lambda: (get_cvrp_data_v3, DDP_CVRP_V3, CVRP_V3, 'node_num', 'cvrp_node_num', True),
}


def _make_env_builder(env_cls, args, size_kwarg, size_attr, with_batch_size):
    """Return a zero-argument factory that instantiates ``env_cls``.

    The size (and, for DDP envs, ``args.problem_batch_size``) is read from
    ``args`` when the factory is called, not when it is created.
    """
    def build():
        kwargs = {size_kwarg: getattr(args, size_attr)}
        if with_batch_size:
            kwargs['batch_size'] = args.problem_batch_size
        return env_cls(**kwargs)
    return build


def get_train_components(eval_args, WORLD_SIZE, RANK):
    """Load the model, the datasets and the environments needed for evaluation.

    Args:
        eval_args: namespace produced by ``get_args_ready``.
        WORLD_SIZE: number of DDP processes; stored on ``args.world_size``.
        RANK: global rank of this process; only rank 0 creates output folders.

    Returns:
        Tuple ``(args, datasets_train, datasets_test, dataset_weights, gato,
        envs, envs_problems, logger)``. ``datasets_train`` is only filled when
        ``eval_args.check_loss`` is set, ``datasets_test`` is never filled
        here, and ``logger`` is ``None`` unless ``eval_args.policy_logger``.

    Raises:
        NotImplementedError: if ``eval_args.eval_envs`` names an unsupported
            environment.
    """
    # load model; the experiment name is the first component of the model path
    model_path = eval_args.ckpt_path if eval_args.ckpt_path is not None else eval_args.snapshot_path
    exp_name = model_path.split('/', 1)[0]
    args, gato, _ = load_model(
        config_path=f'{base_path}/ckpt/{exp_name}/config.json',
        ckpt_path=None if eval_args.ckpt_path is None else f'{base_path}/ckpt/{eval_args.ckpt_path}',
        snapshot_path=None if eval_args.snapshot_path is None else f'{base_path}/ckpt/{eval_args.snapshot_path}'
    )

    # Overwrite selected fields of the loaded training config with the
    # evaluation-time settings.
    args.seed = eval_args.seed
    args.world_size = WORLD_SIZE
    #args.use_prefix = eval_args.use_prefix
    #args.use_prompt = eval_args.use_prompt
    args.use_mem = eval_args.use_mem
    args.use_ddp_env = eval_args.use_ddp_env
    args.traj_type = eval_args.traj_type

    args.data_num_atsp = eval_args.data_num_atsp
    args.data_num_spctsp = eval_args.data_num_spctsp
    args.data_num_tsp = eval_args.data_num_tsp
    args.data_num_bp = eval_args.data_num_bp
    args.data_num_op = eval_args.data_num_op
    args.data_num_pctsp = eval_args.data_num_pctsp
    args.data_num_cvrp = eval_args.data_num_cvrp
    #args.dataset_weights = eval_args.dataset_weights

    args.eval_iters_COP = eval_args.eval_iters_COP
    args.problem_batch_size = eval_args.problem_batch_size
    args.problem_batch_num = eval_args.problem_batch_num
    args.use_default_policy_obj = eval_args.use_default_policy_obj

    # load prompt dataset and eval problem
    use_ddp_env = eval_args.use_ddp_env
    basic_env_builders, ddp_env_builders = [], []
    datasets_train, datasets_prompt, datasets_prompt_ddp, datasets_test, envs_problems = [], [], [], [], {}
    for env_name in eval_args.eval_envs:
        if env_name not in _ENV_SPECS:
            raise NotImplementedError(f'Unsupported evaluation environment: {env_name}')
        loader, ddp_env_cls, basic_env_cls, size_kwarg, size_attr, takes_is_train = _ENV_SPECS[env_name]()

        # The problem set always comes from the evaluation settings, so the
        # choice made in ``eval_args`` is honoured for every environment.
        problem_kwargs = {'is_train': False} if takes_is_train else {}
        envs_problems[env_name] = loader(args, data_type=eval_args.eval_problem_set, **problem_kwargs)

        # Only the dataset flavour matching the env type is requested.
        dataset, ddp_dataset = loader(
            args, data_type='prompt', get_dataset=not use_ddp_env, get_ddp_dataset=use_ddp_env
        )
        datasets_prompt += dataset
        datasets_prompt_ddp += ddp_dataset

        ddp_env_builders.append(_make_env_builder(ddp_env_cls, args, size_kwarg, size_attr, True))
        basic_env_builders.append(_make_env_builder(basic_env_cls, args, size_kwarg, size_attr, False))

        if eval_args.check_loss:
            # Training data comes from the loader of this very environment version.
            datasets_train += loader(args, data_type='train')[0]

    weight = json.loads(args.dataset_weights)
    dataset_weights = list(weight.values())
    args.eval_env_names = list(envs_problems.keys())
    # Dataset name = env name without its first four characters ('Env_').
    args.eval_dataset_names = [name[4:] for name in args.eval_env_names]

    # Disabled sanity check, kept for reference:
    # # make sure all data are tokenized properly
    # if int(os.environ["RANK"]) == 0:
    #     for dataset in datasets_train:
    #         for i in tqdm(range(len(dataset)), total=len(dataset), desc=f'Checking tokenize format of {dataset.dataset_name}_train'):
    #             dataset.check_token_list_format(dataset.get(i))
    #     for dataset in datasets_prompt:
    #         for i in tqdm(range(len(dataset)), total=len(dataset), desc=f'Checking tokenize format of {dataset.dataset_name}_prompt'):
    #             dataset.check_token_list_format(dataset.get(i))

    # build envs for evaluation
    eval_prompt_strat = eval_args.prompt_strategy.split(";")[-1]     # e.g. moving_prompt
    if use_ddp_env:
        envs = [
            DDP_LMPromptEnv(env_builder(), args, prompt_dataset, eval_prompt_strat)
            for env_builder, prompt_dataset in zip(ddp_env_builders, datasets_prompt_ddp)
        ]
    else:
        envs = [
            LMPromptEnv(env_builder(), args, prompt_dataset, eval_prompt_strat)
            for env_builder, prompt_dataset in zip(basic_env_builders, datasets_prompt)
        ]

    # build episode render if we need to check generated episodes during eval
    logger = None
    if eval_args.policy_logger:
        # NOTE(review): EXAMPLE_RENDER is not imported explicitly in this file;
        # presumably it comes from the star import of data.used.make_data - verify.
        logger = {env_name: EXAMPLE_RENDER[env_name]() for env_name in args.eval_env_names}
        if RANK == 0:
            for env_name, dataset_name in zip(args.eval_env_names, args.eval_dataset_names):
                create_folder_overwrite_if_exist(f'{base_path}/visualize/eval/log/{env_name}/{dataset_name}')

    if RANK == 0:
        create_folder_overwrite_if_exist(eval_args.ckpt_performance_path)

    return args, datasets_train, datasets_test, dataset_weights, gato, envs, envs_problems, logger

def main():
    """Run one distributed evaluation from process-group setup to teardown."""
    # Bring up the DDP process group first, then read this process's place in it.
    ddp_setup()
    world_size = int(os.environ.get("WORLD_SIZE", '1'))
    rank = int(os.environ.get("RANK", '0'))

    # Evaluation hyper-parameters (command line + hard-coded preset).
    eval_args = get_args_ready(world_size)

    # Model, datasets, environments and optional episode logger.
    components = get_train_components(eval_args, world_size, rank)
    (
        args,
        datasets_train,
        datasets_test,
        dataset_weights,
        gato,
        envs,
        envs_problems,
        logger,
    ) = components

    # Build the evaluater and run it.
    runner = Evaluater(
        args,
        eval_args,
        gato,
        envs,
        logger,
        datasets_train,
        datasets_test,
        dataset_weights,
        envs_problems,
    )
    runner.evaluate()

    # Tear the DDP process group down again.
    destroy_process_group()

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
    
    

    

        
    