import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.append(base_path)

import numpy as np
import os
from typing import List, Dict, Union
from environment.used.BaseEnv_COP import DataProblem
from dataloader.code.dataset import RLFullDataset, DatasetAdapter
from dataloader.code.DDP_dataset import DDP_RLFullDataset
from environment.used.Env_bp_v1 import BinBackpack_logger
from environment.used.Env_tsp_v1 import TSP_logger
from environment.used.Env_tsp_v2 import TSP_logger_V2
from environment.used.Env_tsp_v3 import TSP_logger_V3
from environment.used.Env_tsp_v4 import TSP_logger_V4
from environment.used.Env_pctsp_v1 import PCTSP_logger_V1
from environment.used.Env_pctsp_v2 import PCTSP_logger_V2
from environment.used.Env_pctsp_v3 import PCTSP_logger_V3
from environment.used.Env_op_v1 import OP_logger_V1
from environment.used.Env_op_v2 import OP_logger_V2
from environment.used.Env_op_v3 import OP_logger_V3
from environment.used.Env_cvrp_v1 import CVRP_logger_V1
from environment.used.Env_cvrp_v2 import CVRP_logger_V2
from utils.utils import split_dataproblem, load_data

# Registry of zero-argument factories keyed by environment name. Every entry is
# a lambda, so the referenced *_logger class is only instantiated when the
# factory is actually called.
EXAMPLE_RENDER = {
    # 0-1 knapsack
    'Env_01BP_V1': lambda: BinBackpack_logger(env_name='Env_01BP_V1', dataset_name='01BP_V1'),
    # TSP
    'Env_TSP_V1': lambda: TSP_logger(env_name='Env_TSP_V1', dataset_name='TSP_V1'),
    'Env_TSP_V2': lambda: TSP_logger_V2(env_name='Env_TSP_V2', dataset_name='TSP_V2'),
    'Env_TSP_V3': lambda: TSP_logger_V3(env_name='Env_TSP_V3', dataset_name='TSP_V3'),
    'Env_TSP_V4': lambda: TSP_logger_V4(env_name='Env_TSP_V4', dataset_name='TSP_V4'),
    # PCTSP
    'Env_PCTSP_V1': lambda: PCTSP_logger_V1(env_name='Env_PCTSP_V1', dataset_name='PCTSP_V1'),
    'Env_PCTSP_V2': lambda: PCTSP_logger_V2(env_name='Env_PCTSP_V2', dataset_name='PCTSP_V2'),
    'Env_PCTSP_V3': lambda: PCTSP_logger_V3(env_name='Env_PCTSP_V3', dataset_name='PCTSP_V3'),
    # OP
    'Env_OP_V1': lambda: OP_logger_V1(env_name='Env_OP_V1', dataset_name='OP_V1'),
    'Env_OP_V2': lambda: OP_logger_V2(env_name='Env_OP_V2', dataset_name='OP_V2'),
    'Env_OP_V3': lambda: OP_logger_V3(env_name='Env_OP_V3', dataset_name='OP_V3'),
    # CVRP
    'Env_CVRP_V1': lambda: CVRP_logger_V1(env_name='Env_CVRP_V1', dataset_name='CVRP_V1'),
    'Env_CVRP_V2': lambda: CVRP_logger_V2(env_name='Env_CVRP_V2', dataset_name='CVRP_V2'),
    # Debug only: the FFSP entry reuses the TSP logger class.
    'Env_FFSP_V1': lambda: TSP_logger(env_name='Env_FFSP_V1', dataset_name='FFSP_V1'),
}

class DatasetAdapterGeneral(DatasetAdapter):
    """Adapter that uses ``DatasetAdapter`` as is.

    The constructor adds nothing of its own: it forwards every argument, in
    the same order, to the base-class constructor.
    """

    def __init__(
        self,
        dataset_name,
        epi_obs,
        epi_act,
        epi_prefix=None,
        disable_visited_obs=False,
    ):
        super().__init__(
            dataset_name,
            epi_obs,
            epi_act,
            epi_prefix,
            disable_visited_obs,
        )
        
def to_d4rl_format(episodes: List) -> Dict:
    """Regroup a list of per-episode dicts into one dict of per-key lists.

    Args:
        episodes: Sequence of episode mappings. Each one is indexed with the
            keys 'observations', 'actions', 'rewards', 'terminals' and
            'prefix'; 'prefix_masks' is read as well when prefixes are
            collected.

    Returns:
        A dict with the keys 'prefixes', 'prefix_masks', 'observations',
        'actions', 'rewards' and 'terminals'. The last four are lists with one
        entry per episode. 'prefixes' and 'prefix_masks' are lists too when
        the first episode carries a non-None 'prefix', and None otherwise
        (including for an empty ``episodes``).
    """
    obss = [epi['observations'] for epi in episodes]
    acts = [epi['actions'] for epi in episodes]
    rwds = [epi['rewards'] for epi in episodes]
    terminals = [epi['terminals'] for epi in episodes]

    prefixes, prefix_masks = None, None
    # Only the first episode decides whether prefixes are collected. The length
    # guard keeps an empty input from raising IndexError on ``episodes[0]``.
    if len(episodes) > 0 and episodes[0]['prefix'] is not None:
        prefixes = [epi['prefix'] for epi in episodes]
        prefix_masks = [epi['prefix_masks'] for epi in episodes]

    return {
        'prefixes': prefixes,
        'prefix_masks': prefix_masks,
        'observations': obss,
        'actions': acts,
        'rewards': rwds,
        'terminals': terminals,
    }

# The annotation of ``data`` is quoted so that defining the function does not
# require ``DataProblem`` to be resolvable at definition time.
def get_data_by_num(data: "Union[List, DataProblem]", data_type: str, eval_iter: int, data_num: int, is_train=True):
    """Cut a loaded dataset down to the amount required for ``data_type``.

    Args:
        data: For the trajectory types ('train', 'prompt', 'test') a sequence
            that supports ``len`` and slicing. For the '*problem' types an
            object that has an ``answer_list`` and is accepted by
            ``split_dataproblem``.
        data_type: 'train', 'prompt', 'test', 'problem' or 'train_problem'.
        eval_iter: Number of evaluation items. 0 means "all problems" and is
            only accepted for the '*problem' types.
        data_num: Number of training items. 0 means "all trajectories" and is
            only accepted for the non-problem types.
        is_train: Only used for 'problem': True takes the leading problems,
            False the trailing ones, so the two selections do not overlap.

    Returns:
        ``data[:num]`` for trajectory types, otherwise the result of
        ``split_dataproblem``.

    Raises:
        NotImplementedError: ``data_type`` is none of the supported values.
        AssertionError: a 0 count was passed for a type that does not accept
            it, or more trajectories were requested than ``data`` holds.
    """
    def _get_data_num():
        world_size = int(os.environ.get("WORLD_SIZE", default='1'))

        # Problem data: the count is adjusted here so that it can be split
        # evenly across all processes.
        if data_type == 'problem':
            num = eval_iter                             # unseen problems: load exactly what evaluation needs
            num -= (num % world_size)                   # round down to a multiple of the process count
        elif data_type == 'train_problem':
            train_data_num = int(data_num * 0.9)        # training problems (debug): size of the training part first
            num = min(eval_iter, train_data_num)
            num -= (num % world_size)                   # round down to a multiple of the process count

        # Trajectory data: the dataloader later adjusts the amount so that it
        # can be split evenly across processes.
        elif data_type == 'train':
            num = data_num                          # training trajectories
        elif data_type == 'prompt':
            num = min(int(data_num * 0.1), 1500)    # trajectories used as prompt sequences
        elif data_type == 'test':
            num = eval_iter                         # trajectories belonging to unseen problems
        else:
            raise NotImplementedError(f'unsupported data_type: {data_type!r}')

        return num

    # Resolve the "0 means everything" conventions.
    if data_num == 0:
        assert not data_type.endswith('problem')    # load all trajectories
        data_num = len(data)
    if eval_iter == 0:
        assert data_type.endswith('problem')        # load all test problems
        eval_iter = len(data.answer_list)
    num = _get_data_num()

    # Return the requested amount of data.
    if data_type == 'train_problem':
        return split_dataproblem(data, 0, num)
    elif data_type == 'problem':
        if is_train:
            return split_dataproblem(data, 0, num)      # training: leading problems act as validation set
        if num == 0:
            # ``-0`` is ``0``: a tail range starting at ``-num`` would begin at
            # the first problem instead of selecting nothing (assuming
            # split_dataproblem slices like ``seq[start:end]``). Ask for the
            # same empty range the training branch produces for num == 0.
            return split_dataproblem(data, 0, 0)
        return split_dataproblem(data, -num, None)      # evaluation: trailing problems act as test set
    else:
        assert num <= len(data), f'{data_type} data num {num} > {len(data)}'
        return data[:num]

def get_RLFullDataset(args, data, data_name, env_name, get_dataset=True, get_ddp_dataset=False, adapter_builder=DatasetAdapterGeneral):
    """Convert episodes with ``to_d4rl_format`` and build the requested dataset objects.

    The adapter is built by ``adapter_builder`` from the first episode's
    observation, action and (if present) prefix, plus
    ``args.disable_visited_obs``.

    Returns:
        Two single-element lists ``([dataset], [ddp_dataset])``. An element is
        an ``RLFullDataset`` / ``DDP_RLFullDataset`` when the matching flag is
        true and None otherwise.
    """
    d4rl = to_d4rl_format(data)

    # The first episode serves as the sample handed to the adapter builder.
    first_obs = d4rl['observations'][0]
    first_act = d4rl['actions'][0]
    if d4rl['prefixes'] is not None:
        first_prefix = d4rl['prefixes'][0]
    else:
        first_prefix = None

    adapter = adapter_builder(
        dataset_name=data_name,
        epi_obs=first_obs,
        epi_act=first_act,
        epi_prefix=first_prefix,
        disable_visited_obs=args.disable_visited_obs,
    )

    dataset = None
    if get_dataset:
        dataset = RLFullDataset(args, d4rl, adapter, data_name, env_name)

    ddp_dataset = None
    if get_ddp_dataset:
        ddp_dataset = DDP_RLFullDataset(args, d4rl, adapter, data_name, env_name)

    return [dataset], [ddp_dataset]

# ------------------------------ 01BP --------------------------------------
def get_bp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'BP_V1' data selected by ``data_type``.

    The data file name is ``bp{args.bp_item_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``bp{args.bp_item_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_bp``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.

    NOTE(review): EXAMPLE_RENDER registers this family as
    'Env_01BP_V1' / '01BP_V1' while this loader uses 'Env_BP_V1' / 'BP_V1' —
    confirm which naming is intended.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'BP_V1', 'Env_BP_V1'
    file_stem = f'bp{args.bp_item_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'bp{args.bp_item_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_bp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_bp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'BP_V2' data selected by ``data_type``.

    The data file name is ``bp{args.bp_item_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_bp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'BP_V2', 'Env_BP_V2'
    file_stem = f'bp{args.bp_item_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_bp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ FFSP --------------------------------------
def get_ffsp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'FFSP_V1' data selected by ``data_type`` (prompt-merging variant).

    The data file name is ``ffsp{args.ffsp_job_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``ffsp{args.ffsp_job_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_ffsp``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.

    NOTE(review): this module defines ``get_ffsp_data_v1`` a second time
    further down, without the prompt-problem merge. That later definition
    rebinds the name, so after import this version is no longer reachable as
    ``get_ffsp_data_v1`` — confirm which one is intended and drop the other.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'FFSP_V1', 'Env_FFSP_V1'
    file_stem = f'ffsp{args.ffsp_job_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'ffsp{args.ffsp_job_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_ffsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ ATSP --------------------------------------
def get_atsp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'ATSP_V1' data selected by ``data_type``.

    The data file name is ``atsp{args.atsp_city_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``atsp{args.atsp_city_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_atsp``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'ATSP_V1', 'Env_ATSP_V1'
    file_stem = f'atsp{args.atsp_city_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'atsp{args.atsp_city_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_atsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_atsp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'ATSP_V2' data selected by ``data_type``.

    The data file name is ``atsp{args.atsp_city_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``atsp{args.atsp_city_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_atsp``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'ATSP_V2', 'Env_ATSP_V2'
    file_stem = f'atsp{args.atsp_city_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'atsp{args.atsp_city_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_atsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ TSP --------------------------------------
def get_tsp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'TSP_V1' data selected by ``data_type``.

    The data file name is ``tsp{args.tsp_city_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_tsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'TSP_V1', 'Env_TSP_V1'
    file_stem = f'tsp{args.tsp_city_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_tsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_tsp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'TSP_V2' data selected by ``data_type``.

    The data file name is ``tsp{args.tsp_city_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_tsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'TSP_V2', 'Env_TSP_V2'
    file_stem = f'tsp{args.tsp_city_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_tsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_tsp_data_v3(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'TSP_V3' data selected by ``data_type``.

    The data file name is ``tsp{args.tsp_city_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``tsp{args.tsp_city_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_tsp``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'TSP_V3', 'Env_TSP_V3'
    file_stem = f'tsp{args.tsp_city_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'tsp{args.tsp_city_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_tsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_tsp_data_v4(args, data_type='train', distribution=None, get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'TSP_V4' data selected by ``data_type`` and ``distribution``.

    The data file name is ``tsp{args.tsp_city_num}_{data_type}`` when
    ``distribution`` is None and
    ``tsp{args.tsp_city_num}({distribution})_{data_type}`` otherwise. The
    loaded data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_tsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'TSP_V4', 'Env_TSP_V4'

    # A distribution tag, when given, is embedded in parentheses after the size.
    if distribution is None:
        file_stem = f'tsp{args.tsp_city_num}_{data_type}'
    else:
        file_stem = f'tsp{args.tsp_city_num}({distribution})_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_tsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset
# ------------------------------ PCTSP --------------------------------------
def get_pctsp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'PCTSP_V1' data selected by ``data_type``.

    The data file name is ``pctsp{args.pctsp_node_num}_{data_type}``. The
    loaded data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_pctsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'PCTSP_V1', 'Env_PCTSP_V1'
    file_stem = f'pctsp{args.pctsp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_pctsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset
    
def get_pctsp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'PCTSP_V2' data selected by ``data_type``.

    The data file name is ``pctsp{args.pctsp_node_num}_{data_type}``. The
    loaded data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_pctsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'PCTSP_V2', 'Env_PCTSP_V2'
    file_stem = f'pctsp{args.pctsp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_pctsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset
    
def get_pctsp_data_v3(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'PCTSP_V3' data selected by ``data_type``.

    The data file name is ``pctsp{args.pctsp_node_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``pctsp{args.pctsp_node_num}_prompt_problem`` file are appended in place;
    a missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_pctsp``). '*problem' types
    return the trimmed data; other types return the result of
    ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'PCTSP_V3', 'Env_PCTSP_V3'
    file_stem = f'pctsp{args.pctsp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'pctsp{args.pctsp_node_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_pctsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ SPCTSP ---------------------------------
def get_spctsp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'SPCTSP_V2' data selected by ``data_type``.

    The data file name is ``spctsp{args.spctsp_node_num}_{data_type}``. The
    loaded data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_spctsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'SPCTSP_V2', 'Env_SPCTSP_V2'
    file_stem = f'spctsp{args.spctsp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_spctsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_spctsp_data_v3(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'SPCTSP_V3' data selected by ``data_type``.

    The data file name is ``spctsp{args.spctsp_node_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``spctsp{args.spctsp_node_num}_prompt_problem`` file are appended in
    place; a missing file is ignored. The data is then trimmed by
    ``get_data_by_num`` (``args.eval_iters_COP``, ``args.data_num_spctsp``).
    '*problem' types return the trimmed data; other types return the result
    of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'SPCTSP_V3', 'Env_SPCTSP_V3'
    file_stem = f'spctsp{args.spctsp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'spctsp{args.spctsp_node_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_spctsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ OP --------------------------------------
def get_op_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'OP_V1' data selected by ``data_type``.

    The data file name is ``op{args.op_node_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_op``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'OP_V1', 'Env_OP_V1'
    file_stem = f'op{args.op_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_op, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_op_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'OP_V2' data selected by ``data_type``.

    The data file name is ``op{args.op_node_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_op``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'OP_V2', 'Env_OP_V2'
    file_stem = f'op{args.op_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_op, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_op_data_v3(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'OP_V3' data selected by ``data_type``.

    The data file name is ``op{args.op_node_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_op``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'OP_V3', 'Env_OP_V3'
    file_stem = f'op{args.op_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_op, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_op_data_v4(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'OP_V4' data selected by ``data_type``.

    The data file name is ``op{args.op_node_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``op{args.op_node_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_op``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'OP_V4', 'Env_OP_V4'
    file_stem = f'op{args.op_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'op{args.op_node_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_op, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ CVRP --------------------------------------
def get_cvrp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'CVRP_V1' data selected by ``data_type``.

    The data file name is ``cvrp{args.cvrp_node_num}_{data_type}``. The
    loaded data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_cvrp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'CVRP_V1', 'Env_CVRP_V1'
    file_stem = f'cvrp{args.cvrp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_cvrp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_cvrp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'CVRP_V2' data selected by ``data_type``.

    The data file name is ``cvrp{args.cvrp_node_num}_{data_type}``. The
    loaded data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_cvrp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'CVRP_V2', 'Env_CVRP_V2'
    file_stem = f'cvrp{args.cvrp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_cvrp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_cvrp_data_v3(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'CVRP_V3' data selected by ``data_type``.

    The data file name is ``cvrp{args.cvrp_node_num}_{data_type}``. For
    ``data_type == 'problem'`` the lists of the optional
    ``cvrp{args.cvrp_node_num}_prompt_problem`` file are appended in place; a
    missing file is ignored. The data is then trimmed by ``get_data_by_num``
    (``args.eval_iters_COP``, ``args.data_num_cvrp``). '*problem' types return
    the trimmed data; other types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'CVRP_V3', 'Env_CVRP_V3'
    file_stem = f'cvrp{args.cvrp_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)

    # For Prefix models: extend the evaluation problem set with prompt
    # problems that are otherwise unused. The extra file is optional.
    if data_type == 'problem':
        try:
            extra = load_data(data_name=dataset_key, data_file_name=f'cvrp{args.cvrp_node_num}_prompt_problem')
            raw.problem_list.extend(extra.problem_list)
            raw.answer_list.extend(extra.answer_list)
            if raw.prefix_list is not None:
                raw.prefix_list.extend(extra.prefix_list)
        except FileNotFoundError:
            pass  # no prompt-problem file: continue with the base problems only

    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_cvrp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

# ------------------------------ FFSP --------------------------------------
def get_ffsp_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'FFSP_V1' data selected by ``data_type``.

    The data file name is ``ffsp{args.ffsp_job_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_ffsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.

    NOTE(review): an earlier definition of ``get_ffsp_data_v1`` in this
    module (which additionally merges the optional prompt-problem file) is
    replaced by this one when the module is imported — confirm that this is
    the intended version.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'FFSP_V1', 'Env_FFSP_V1'
    file_stem = f'ffsp{args.ffsp_job_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_ffsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_ffsp_data_v2(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'FFSP_V2' data selected by ``data_type``.

    The data file name is ``ffsp{args.ffsp_job_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_ffsp``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'FFSP_V2', 'Env_FFSP_V2'
    file_stem = f'ffsp{args.ffsp_job_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_ffsp, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset

def get_mis_data_v1(args, data_type='train', get_dataset=True, get_ddp_dataset=False, is_train=True):
    """Load the 'MIS_V1' data selected by ``data_type``.

    The data file name is ``mis{args.mis_node_num}_{data_type}``. The loaded
    data is trimmed by ``get_data_by_num`` (``args.eval_iters_COP``,
    ``args.data_num_mis``). '*problem' types return the trimmed data; other
    types return the result of ``get_RLFullDataset``.
    """
    assert data_type in ['train', 'prompt', 'test', 'problem', 'train_problem']
    dataset_key, env_key = 'MIS_V1', 'Env_MIS_V1'
    file_stem = f'mis{args.mis_node_num}_{data_type}'

    raw = load_data(data_name=dataset_key, data_file_name=file_stem)
    subset = get_data_by_num(raw, data_type, args.eval_iters_COP, args.data_num_mis, is_train)

    # Trajectory types are wrapped into datasets; problem types are returned directly.
    if not data_type.endswith('problem'):
        return get_RLFullDataset(args, subset, dataset_key, env_key, get_dataset, get_ddp_dataset)
    return subset