import os
import sys
# Repository root = parent of the directory holding this script; it is put on
# sys.path so the project-local imports below resolve when run as a script.
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(base_path)

import torch
import json
import wandb
import shutil
from pathlib import Path
from environment.wrapper import LMPromptEnv
from environment.DDP_wrapper import DDP_LMPromptEnv
from train_test.config import parse_args
from train_test.DDP_trainer import Trainer
from utils.utils import create_folder_overwrite_if_exist, create_folder_if_not_exist
from data.used.make_data import *
from environment.used.Env_bp_v1 import BP_V1, DDP_BP_V1
from environment.used.Env_bp_v2 import BP_V2, DDP_BP_V2
from environment.used.Env_cvrp_v1 import CVRP_V1, DDP_CVRP_V1
from environment.used.Env_cvrp_v2 import CVRP_V2, DDP_CVRP_V2
from environment.used.Env_cvrp_v3 import CVRP_V3, DDP_CVRP_V3
from environment.used.Env_ffsp_v1 import DDP_FFSP_V1
from environment.used.Env_atsp_v1 import ATSP_V1, DDP_ATSP_V1
from environment.used.Env_atsp_v2 import ATSP_V2, DDP_ATSP_V2
from environment.used.Env_tsp_v1 import TSP_V1, DDP_TSP_V1
from environment.used.Env_tsp_v2 import TSP_V2, DDP_TSP_V2
from environment.used.Env_tsp_v3 import TSP_V3, DDP_TSP_V3
from environment.used.Env_tsp_v4 import TSP_V4, DDP_TSP_V4
from environment.used.Env_op_v1 import OP_V1, DDP_OP_V1
from environment.used.Env_op_v2 import OP_V2, DDP_OP_V2
from environment.used.Env_op_v3 import OP_V3, DDP_OP_V3
from environment.used.Env_op_v4 import OP_V4, DDP_OP_V4
from environment.used.Env_pctsp_v1 import PCTSP_V1, DDP_PCTSP_V1
from environment.used.Env_pctsp_v2 import PCTSP_V2, DDP_PCTSP_V2
from environment.used.Env_pctsp_v3 import PCTSP_V3, DDP_PCTSP_V3
from environment.used.Env_ffsp_v1 import DDP_FFSP_V1
from environment.used.Env_spctsp_v2 import SPCTSP_V2, DDP_SPCTSP_V2
from environment.used.Env_spctsp_v3 import SPCTSP_V3, DDP_SPCTSP_V3
from torch.distributed import init_process_group, destroy_process_group
import setproctitle
# Intended to turn off tqdm's default shared progress-bar behaviour.
# NOTE(review): confirm that tqdm actually reads TQDM_AUTO_KILL.
os.environ['TQDM_AUTO_KILL'] = '1'
# Give the process a recognizable title in process listings.
setproctitle.setproctitle("GATO-pretrain@XXX")

# Launch example (single node, GPUs 3 and 6):
# CUDA_VISIBLE_DEVICES=3,6 torchrun --standalone --nproc_per_node=gpu --master_addr=127.0.0.2 --master_port=29100 ./train_test/DDP_PFT2TSP100.py 

def ddp_setup():
    """Initialise the NCCL process group and bind this process to its local GPU.

    Meant to run under ``torchrun``, which exports ``RANK``, ``WORLD_SIZE``,
    ``LOCAL_RANK``, ``MASTER_ADDR`` and ``MASTER_PORT`` for every worker.
    The rendezvous address/port below are only fallbacks. Values already
    provided by the launcher are kept, so ``--master_addr``/``--master_port``
    passed to torchrun are honoured instead of being silently overwritten.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
    """
    # Single-machine experiment: fall back to a loopback address and an arbitrary free port.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.2")
    os.environ.setdefault("MASTER_PORT", "29111")
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

def get_args_ready(WORLD_SIZE, RANK):
    """Build the experiment configuration for this run.

    Starts from ``parse_args()`` and overrides hyper-parameters in place. The
    current settings fine-tune ``args.pretrained_ckpt`` on ``Env_TSP_V3`` with
    100 cities. On rank 0 it also creates the wandb folder and the checkpoint
    folder, dumps the final config to ``<save_dir>/config.json``, and copies
    this launch script next to it.

    Args:
        WORLD_SIZE: number of DDP processes.
        RANK: global rank of this process. File-system side effects happen on rank 0 only.

    Returns:
        ``(exp_name, args)``: the experiment name (used for the checkpoint folder
        and the wandb group) and the populated argument namespace.
    """
    # NOTE(XXX): Only a part of paras in args structure are used now
    args = parse_args()
    args.world_size = WORLD_SIZE

    # core paras
    args.model = 'llama'   # 'llama' or 'transformer_xl'
    args.n_embed = 768              # embedding dimension
    args.n_q_head = 8               # attention query head num (for llama GQA)
    args.n_kv_head = 8              # attention key/value head num (for llama GQA)
    args.n_head = args.n_q_head     # attention head num (for TransformerXL)
    args.n_position = 500           # model input sequence length (max context length)
    args.n_layer = 10               # transformer block num
    args.auto_batch_len = True      # Automatically clip the sample length to the maximum length in the batch
    args.rms_norm_eps = 1e-6        # RMS Norm epsilon (for llama)
    args.num_workers = 0
    if args.model == 'llama':
        assert args.n_q_head % args.n_kv_head == 0

    # training paras
    args.is_cop_pretrain = False
    args.is_cop_cmp = True

    args.pretrained_ckpt = 'FT-7All20(embv2-7x20w)_500_768_8|8_10/best/0.96_seed42_epoch120.pt'
    #args.pretrained_ckpt = None
    args.skip_first_eval = True
    args.dataset_weights = '{"Env_TSP_V3":1}'
    args.lr_max = 1.0e-4
    args.lr_begin = 1.0e-4
    args.eval_interval = 1                 # epoch interval for eval loss calculating
    args.eval_policy_interval = 1           # epoch interval for policy evaluation
    args.problem_batch_size = 200           # eval problem batch_size (per GPU)
    args.problem_batch_num = 3
    common_data_num = 150000

    args.is_obs_pretrain = False
    args.train_iters = 160                      # training epoch num
    args.batch_size_vaild = 100*WORLD_SIZE      # training batch size (all GPU)
    args.batch_num = 75*3                       # training batch num per epoch (all GPU)
    args.start_grad_accum = 3                   # The equivalent batch_size is [batch_size_vaild * grad_accum_step]
    args.end_grad_accum = 3                     # increase to maximum
    args.grad_accum_step_incr_style = "linear"   # "constant" or "linear" or "power"
    args.batch_size = int(args.batch_size_vaild / WORLD_SIZE)   # training batch size (per GPU)
    if args.is_obs_pretrain:
        assert args.pretrained_ckpt is None

    # loss eval paras
    args.eval_batch_size_vaild = 200*WORLD_SIZE  # eval batch size (all GPU)
    args.eval_batch_num = 5                    # eval batch num (all GPU)
    assert args.train_iters % args.eval_interval == 0
    args.use_early_stopping = True
    args.early_stopping_patience = 20
    args.early_stopping_delta = 0.0
    # Per-GPU eval batch size, derived from the all-GPU value above, exactly like args.batch_size.
    # (Previously this divided args.eval_batch_size itself, ignoring eval_batch_size_vaild.)
    args.eval_batch_size = int(args.eval_batch_size_vaild / WORLD_SIZE)

    # policy eval paras
    assert args.train_iters % args.eval_policy_interval == 0
    args.eval_iters_COP = WORLD_SIZE * args.problem_batch_size * args.problem_batch_num
    args.eval_max_step_size = 1000      # max rollout timestep for policy evaluation
    args.use_default_policy_obj = False  # Whether to use the default random policy obj value to calculate epi quality in evaluation
    args.use_ddp_env = True
    args.use_mem = False
    if args.is_obs_pretrain:
        args.eval_policy_interval = args.train_iters

    # DDP paras
    args.snapshot_save_interval = 1
    args.use_amp = False
    args.dataloader_type = "DDP"
    assert args.train_iters % args.snapshot_save_interval == 0

    # prefix & prompt paras
    args.prompt_strategy = "stochastic_subseq;moving_prompt"
    args.prompt_prob = 0.25
    args.prompt_ratio = 0.5
    args.prompt_at_final_transition_prob = 0.5
    args.use_prefix = True
    args.use_dynamic_prefix = True
    args.use_prompt = False
    assert args.use_prompt ^ args.use_prefix
    if args.use_prefix:
        assert args.use_ddp_env, 'prefix is not supported by basic env currently'
    if not args.use_prefix:
        assert not args.use_dynamic_prefix

    # optimizer paras
    args.lr_warmup_ratio = 0.05
    args.lr_decay_ratio = 0.75
    args.lr_decay_factor = 10
    args.lr_decay_style = "cosine"
    args.start_weight_decay = 0.0001
    args.end_weight_decay = args.start_weight_decay
    args.weight_decay_incr_style = "constant"
    args.use_checkpoint_opt_param_scheduler = False
    args.override_opt_param_scheduler = not args.use_checkpoint_opt_param_scheduler

    # embedding paras
    args.tokenizer_ver = 'v2'
    args.discretize_mu = 15
    args.discretize_M = 4
    args.num_continuous_bin = 1800
    args.num_discrete_values = 200
    # Example of an MLP embedding config. Each item corresponds to a linear layer for embedding,
    # and 'item_name' lists the obs item names of the MDP episode data:
    # args.mlp_emb_items = {
    #     'position': {'dim': 2, 'item_name': ['position', 'pos_depot', 'pos_node']},
    # }
    args.mlp_emb_items = {}
    if args.auto_batch_len:
        assert args.mlp_emb_items == {}

    # env paras
    args.atsp_city_num = 20
    args.tsp_city_num = 100
    args.op_node_num = 20
    args.pctsp_node_num = 20
    args.spctsp_node_num = 20
    args.cvrp_node_num = 20
    args.bp_item_num = 20
    args.ffsp_job_num = 20

    args.data_num_bp = common_data_num
    args.data_num_atsp = common_data_num
    args.data_num_tsp = common_data_num
    args.data_num_op = common_data_num
    args.data_num_cvrp = common_data_num
    args.data_num_pctsp = common_data_num
    args.data_num_spctsp = common_data_num

    # Special token ids are placed right after the discrete and continuous value ranges.
    args.special_tokens = {
        "<|>": args.num_discrete_values + args.num_continuous_bin,
        "<X>": args.num_discrete_values + args.num_continuous_bin + 1
    }

    # ckpt paras
    #args.exp_profile = 'TEST'
    args.exp_profile = f"FT-to-TSP100-(emb{args.tokenizer_ver})"
    #args.exp_profile = f"Cmp2FT-TSP100-(emb{args.tokenizer_ver})"

    exp_name = f'{args.exp_profile}_{args.n_position}_{args.n_embed}_{args.n_head}_{args.n_layer}' if args.model == 'transformer_xl' else \
                f'{args.exp_profile}_{args.n_position}_{args.n_embed}_{args.n_q_head}|{args.n_kv_head}_{args.n_layer}'
    args.save_dir = f'{base_path}/ckpt/{exp_name}' if not args.is_obs_pretrain else f'{base_path}/ckpt/pretrain/{exp_name}'
    args.save_strategy = 'best'
    args.save_interval = args.eval_policy_interval
    assert args.save_interval % args.eval_policy_interval == 0

    # other paras
    args.traj_type = 'all'
    args.dataset_distribution = 'uniform'
    args.eval_problem_set = 'problem'       # 'train_problem' or 'problem'
    args.disable_visited_obs = True         # whether to remove visited obs when constructing MDP episode like DB1
    args.seeds = [42, ]                     # random seeds
    args.wandb = True                      # log the exp curve to wandb or not
    args.policy_logger = False              # render the episodes generated during training or not
    args.traindata_logger = False           # whether to log the sample idx during training
    args.save_ckpt = True                  # save model paras ckpt during training or not
    args.save_snapshot = True              # save snapshot during DDP training or not

    # get ready for wandb logging
    if RANK == 0:
        create_folder_if_not_exist(f'{base_path}/Wandb')
    if not args.wandb:
        os.environ['WANDB_MODE'] = 'offline'

    # create folder to save ckpts and hyperparas if we need
    if (args.save_ckpt or args.save_snapshot) and RANK == 0:
        #create_folder_overwrite_if_exist(f'{args.save_dir}/{args.save_strategy}')
        create_folder_if_not_exist(f'{args.save_dir}/{args.save_strategy}')
        with open(f'{args.save_dir}/config.json', 'w') as f:
            f.write(json.dumps(vars(args), indent=4))
        # Snapshot the script that is actually running (not a hard-coded sibling script),
        # so the saved train.py matches the config written above.
        shutil.copy2(
            src=os.path.abspath(__file__),
            dst=f'{args.save_dir}/train.py',
        )

    return exp_name, args

def get_train_objs(args, RANK):
    """Build training datasets, evaluation envs and optional episode renderers.

    ``args.dataset_weights`` is a JSON object mapping env name to weight. For
    every env name it holds, this function:

    - loads the evaluation problems, training datasets and prompt datasets via the
      matching ``get_*_data_v*`` helper;
    - records zero-argument builders for the DDP env and the basic env.

    It then instantiates one prompt env per dataset weight entry.

    Args:
        args: experiment namespace (see ``get_args_ready``). This function also
            sets ``args.eval_env_names`` and ``args.eval_dataset_names``.
        RANK: global rank of this process. Only rank 0 creates visualization folders.

    Returns:
        ``(datasets_train, dataset_weights, envs, envs_problems, logger)``.
        ``logger`` is None unless ``args.policy_logger`` or ``args.traindata_logger`` is set.

    Raises:
        NotImplementedError: for an unknown env name, or when basic (non-DDP) envs
            are requested for an env that has no basic implementation.
    """
    # load datasets & datasets weight
    weight = json.loads(args.dataset_weights)

    basic_env_builders, ddp_env_builders = [], []
    datasets_train, datasets_prompt, datasets_prompt_ddp, envs_problems = [], [], [], {}

    def load_env(env_name, get_data, ddp_env_builder, basic_env_builder):
        # Load eval problems, training data and prompt data for one env, and register its env builders.
        envs_problems[env_name] = get_data(args, data_type=args.eval_problem_set)
        datasets_train.extend(get_data(args, data_type='train')[0])
        dataset, ddp_dataset = get_data(
            args, data_type='prompt',
            get_dataset=not args.use_ddp_env, get_ddp_dataset=args.use_ddp_env,
        )
        datasets_prompt.extend(dataset)
        datasets_prompt_ddp.extend(ddp_dataset)
        # The env construction below zips builders with prompt datasets positionally,
        # so every env appends exactly one entry to each builder list (None if unavailable).
        ddp_env_builders.append(ddp_env_builder)
        basic_env_builders.append(basic_env_builder)

    for env_name in weight:
        if env_name == 'Env_BP_V1':
            load_env(env_name, get_bp_data_v1,
                     lambda: DDP_BP_V1(item_num=args.bp_item_num, batch_size=args.problem_batch_size),
                     lambda: BP_V1(item_num=args.bp_item_num))
        elif env_name == 'Env_BP_V2':
            load_env(env_name, get_bp_data_v2,
                     lambda: DDP_BP_V2(item_num=args.bp_item_num, batch_size=args.problem_batch_size),
                     lambda: BP_V2(item_num=args.bp_item_num))
        elif env_name == 'Env_FFSP_V1':
            # There is no basic (non-DDP) FFSP env yet; None keeps the builder lists aligned.
            load_env(env_name, get_ffsp_data_v1,
                     lambda: DDP_FFSP_V1(job_cnt=args.ffsp_job_num, batch_size=args.problem_batch_size),
                     None)
        elif env_name == 'Env_ATSP_V1':
            load_env(env_name, get_atsp_data_v1,
                     lambda: DDP_ATSP_V1(num_nodes=args.atsp_city_num, batch_size=args.problem_batch_size),
                     lambda: ATSP_V1(num_nodes=args.atsp_city_num))
        elif env_name == 'Env_ATSP_V2':
            load_env(env_name, get_atsp_data_v2,
                     lambda: DDP_ATSP_V2(num_nodes=args.atsp_city_num, batch_size=args.problem_batch_size),
                     lambda: ATSP_V2(num_nodes=args.atsp_city_num))
        elif env_name == 'Env_TSP_V1':
            load_env(env_name, get_tsp_data_v1,
                     lambda: DDP_TSP_V1(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                     lambda: TSP_V1(num_nodes=args.tsp_city_num))
        elif env_name == 'Env_TSP_V2':
            load_env(env_name, get_tsp_data_v2,
                     lambda: DDP_TSP_V2(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                     lambda: TSP_V2(num_nodes=args.tsp_city_num))
        elif env_name == 'Env_TSP_V3':
            load_env(env_name, get_tsp_data_v3,
                     lambda: DDP_TSP_V3(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                     lambda: TSP_V3(num_nodes=args.tsp_city_num))
        elif env_name == 'Env_TSP_V4':
            load_env(env_name, get_tsp_data_v4,
                     lambda: DDP_TSP_V4(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size),
                     lambda: TSP_V4(num_nodes=args.tsp_city_num))
        elif env_name == 'Env_PCTSP_V1':
            load_env(env_name, get_pctsp_data_v1,
                     lambda: DDP_PCTSP_V1(node_num=args.pctsp_node_num, batch_size=args.problem_batch_size),
                     lambda: PCTSP_V1(node_num=args.pctsp_node_num))
        elif env_name == 'Env_PCTSP_V2':
            load_env(env_name, get_pctsp_data_v2,
                     lambda: DDP_PCTSP_V2(node_num=args.pctsp_node_num, batch_size=args.problem_batch_size),
                     lambda: PCTSP_V2(node_num=args.pctsp_node_num))
        elif env_name == 'Env_PCTSP_V3':
            load_env(env_name, get_pctsp_data_v3,
                     lambda: DDP_PCTSP_V3(node_num=args.pctsp_node_num, batch_size=args.problem_batch_size),
                     lambda: PCTSP_V3(node_num=args.pctsp_node_num))
        elif env_name == 'Env_SPCTSP_V2':
            load_env(env_name, get_spctsp_data_v2,
                     lambda: DDP_SPCTSP_V2(node_num=args.spctsp_node_num, batch_size=args.problem_batch_size),
                     lambda: SPCTSP_V2(node_num=args.spctsp_node_num))
        elif env_name == 'Env_SPCTSP_V3':
            load_env(env_name, get_spctsp_data_v3,
                     lambda: DDP_SPCTSP_V3(node_num=args.spctsp_node_num, batch_size=args.problem_batch_size),
                     lambda: SPCTSP_V3(node_num=args.spctsp_node_num))
        elif env_name == 'Env_OP_V1':
            load_env(env_name, get_op_data_v1,
                     lambda: DDP_OP_V1(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                     lambda: OP_V1(node_num=args.op_node_num))
        elif env_name == 'Env_OP_V2':
            load_env(env_name, get_op_data_v2,
                     lambda: DDP_OP_V2(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                     lambda: OP_V2(node_num=args.op_node_num))
        elif env_name == 'Env_OP_V3':
            load_env(env_name, get_op_data_v3,
                     lambda: DDP_OP_V3(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                     lambda: OP_V3(node_num=args.op_node_num))
        elif env_name == 'Env_OP_V4':
            load_env(env_name, get_op_data_v4,
                     lambda: DDP_OP_V4(node_num=args.op_node_num, batch_size=args.problem_batch_size),
                     lambda: OP_V4(node_num=args.op_node_num))
        elif env_name == 'Env_CVRP_V1':
            load_env(env_name, get_cvrp_data_v1,
                     lambda: DDP_CVRP_V1(node_num=args.cvrp_node_num, batch_size=args.problem_batch_size),
                     lambda: CVRP_V1(node_num=args.cvrp_node_num))
        elif env_name == 'Env_CVRP_V2':
            load_env(env_name, get_cvrp_data_v2,
                     lambda: DDP_CVRP_V2(node_num=args.cvrp_node_num, batch_size=args.problem_batch_size),
                     lambda: CVRP_V2(node_num=args.cvrp_node_num))
        elif env_name == 'Env_CVRP_V3':
            load_env(env_name, get_cvrp_data_v3,
                     lambda: DDP_CVRP_V3(node_num=args.cvrp_node_num, batch_size=args.problem_batch_size),
                     lambda: CVRP_V3(node_num=args.cvrp_node_num))
        else:
            raise NotImplementedError(f'Unknown env name in dataset_weights: {env_name}')

    dataset_weights = list(weight.values())
    args.eval_env_names = list(weight.keys())
    args.eval_dataset_names = [dataset.dataset_name for dataset in datasets_train]

    # Disabled sanity check: make sure all data are tokenized properly
    # (only for use_prompt without mlp_emb).
    # if int(os.environ["RANK"]) == 0:
    #     for dataset in datasets_train:
    #         for i in tqdm(range(len(dataset)), total=len(dataset), desc=f'Checking tokenize format of {dataset.dataset_name}_train'):
    #             dataset.check_token_list_format(dataset.get(i))
    #     if not args.use_ddp_env:
    #         for dataset in datasets_prompt:
    #             for i in tqdm(range(len(dataset)), total=len(dataset), desc=f'Checking tokenize format of {dataset.dataset_name}_prompt'):
    #                 dataset.check_token_list_format(dataset.get(i))

    # build envs for evaluation
    eval_prompt_strat = args.prompt_strategy.split(";")[-1]  # last strategy, e.g. 'moving_prompt'
    if args.use_ddp_env:
        envs = [DDP_LMPromptEnv(build_env(), args, prompt_dataset, eval_prompt_strat)
                for build_env, prompt_dataset in zip(ddp_env_builders, datasets_prompt_ddp)]
    else:
        unsupported = [name for name, build_env in zip(args.eval_env_names, basic_env_builders) if build_env is None]
        if unsupported:
            raise NotImplementedError(f'No basic (non-DDP) env available for {unsupported}; set use_ddp_env=True')
        envs = [LMPromptEnv(build_env(), args, prompt_dataset, eval_prompt_strat)
                for build_env, prompt_dataset in zip(basic_env_builders, datasets_prompt)]

    # build episode render if we need to check generated episodes during training
    logger = None
    if args.policy_logger or args.traindata_logger:
        # NOTE(review): EXAMPLE_RENDER presumably comes from the `make_data` star import — confirm.
        logger = {env_name: EXAMPLE_RENDER[env_name]() for env_name in args.eval_env_names}

        if RANK == 0:
            prompts = datasets_prompt if not args.use_ddp_env else datasets_prompt_ddp
            for dataset in prompts:
                create_folder_overwrite_if_exist(f'{base_path}/visualize/train/{dataset.env_name}/{dataset.dataset_name}')

    return datasets_train, dataset_weights, envs, envs_problems, logger

if __name__ == "__main__":
    # Script entry point; meant to be launched with torchrun, one process per GPU.

    # Debug print of the rendezvous address provided by the launcher. Use .get():
    # these variables may be absent before ddp_setup() fills them in, and a
    # diagnostic print must not abort the run with a KeyError.
    print(os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"))
    # init DDP process group
    ddp_setup()
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", default='1'))
    RANK = int(os.environ.get("RANK", default='0'))
    # Address/port in effect after ddp_setup().
    print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])

    # get hyper paras ready
    exp_name, args = get_args_ready(WORLD_SIZE, RANK)

    # load datasets, eval envs and (optional) episode renderers
    datasets_train, dataset_weights, envs, envs_problems, logger = get_train_objs(args, RANK)

    # one training run per seed
    for seed in args.seeds:
        if args.save_dir and args.save_strategy == 'interval' and RANK == 0:
            create_folder_overwrite_if_exist(f'{args.save_dir}/interval/{seed}')
        args.seed = seed

        # This unique id is necessary for log resuming
        wandb_id = wandb.util.generate_id()

        # build trainer
        trainer = Trainer(
            args, seed, wandb_id,
            envs, logger,
            datasets_train,
            dataset_weights,
            envs_problems
        )

        # wandb log only on rank0; the other ranks just train
        if RANK == 0:
            with wandb.init(
                # set the wandb project where this run will be logged
                project="gato-compare",
                #project="gato-ddp-test",
                dir=Path(f'{base_path}/Wandb'),
                group=exp_name,
                name=f"seed_{seed}",
                id=trainer.wandb_id,
                resume='allow',
                config=args
            ):
                # Unwrap a `.module`-style wrapper (e.g. DistributedDataParallel) so wandb watches the raw model.
                raw_model = trainer.gato.module if hasattr(trainer.gato, "module") else trainer.gato
                wandb.watch(raw_model, log='all', log_freq=100)
                trainer.train()
        else:
            trainer.train()

        # The wandb run must be closed (by the `with` block above) before the next seed starts.
        assert wandb.run is None

    # destroy DDP process group
    destroy_process_group()