import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import pickle
import random
from utils.COP_slover import TSP_lkh
from utils.utils import split_rawdata
from tqdm import tqdm
import numpy as np
from environment.used.BaseEnv_COP import RawData

        
def _require(condition, message):
    ''' Raise ``AssertionError(message)`` when ``condition`` is falsy.

    Used instead of bare ``assert`` statements so the data checks still run
    when Python is started with ``-O`` (which strips ``assert`` statements).
    '''
    if not condition:
        raise AssertionError(message)


def check_data(num_nodes, data_distribution, load_data_num=None, check_data_num=5000):
    ''' Check the generated raw TSP data for problems.

    Loads the ``problem`` / ``train`` / ``prompt`` split files for the given
    problem size and city distribution (split files that do not exist are
    skipped), merges them into one ``RawData`` and verifies that:

    * the merged problem / answer / cost lists have the same length;
    * no random seed was used by more than one data-generation run;
    * no two problems have identical city positions;
    * for up to ``check_data_num`` randomly chosen samples, re-solving the
      instance with ``TSP_lkh`` reproduces the stored cost and tour.

    Args:
        num_nodes: number of TSP cities; part of the data file name.
        data_distribution: city distribution tag; part of the data file name.
        load_data_num: if not None, each split is truncated with
            ``split_rawdata(sub_dataset, 0, load_data_num)`` before merging.
        check_data_num: maximum number of samples to re-solve; a value
            ``<= 0`` disables the re-solve check.

    Returns:
        The merged ``RawData`` built from all loaded splits.

    Raises:
        FileNotFoundError: if none of the split files exist.
        AssertionError: if any consistency check fails.
    '''
    # Load every raw split that exists on disk and merge them into one RawData.
    datasets = {}
    dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for data_type in ['problem', 'train', 'prompt']:
        path = f'{base_path}/data/used/_raw/tsp/tsp{num_nodes}({data_distribution})_{data_type}.pkl'
        if not os.path.exists(path):
            continue    # missing splits are tolerated; they are reported below
        # NOTE: pickle.load can execute arbitrary code; only use on locally
        # generated, trusted data files.
        with open(path, 'rb') as f:
            sub_dataset = pickle.load(f)
        if load_data_num is not None:
            sub_dataset = split_rawdata(sub_dataset, 0, load_data_num)
        datasets[data_type] = sub_dataset
        dataset.seed_list.extend(sub_dataset.seed_list)
        dataset.problem_list.extend(sub_dataset.problem_list)
        dataset.answer_list.extend(sub_dataset.answer_list)
        dataset.cost_list.extend(sub_dataset.cost_list)

    if not datasets:
        raise FileNotFoundError(
            f'No raw TSP data found for tsp{num_nodes}({data_distribution}) '
            f'under {base_path}/data/used/_raw/tsp'
        )

    # Global consistency checks on the merged data.
    # "data lengths do not match"
    _require(len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list), '数据长度不匹配')
    # "some data-generation processes used the same random seed"
    _require(len(set(dataset.seed_list)) == len(dataset.seed_list), '存在使用相同随机种子的数据生成进程')
    positions = np.array([problem['position'] for problem in dataset.problem_list])         # (data_num, node_num, 2)
    # "there are completely identical problems"
    _require(positions.shape[0] == np.unique(positions, axis=0).shape[0], '存在完全相同的问题')

    # Print dataset summary; splits that were not found are reported instead
    # of crashing on a missing dict key.
    print('-'*80)
    print(f'There are [{len(dataset.answer_list)}] raw data generated with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
    split_labels = (
        ('train', 'Train dataset:    '),
        ('problem', 'Problem dataset:  '),
        ('prompt', 'Prompt dataset:   '),
    )
    for data_type, label in split_labels:
        sub_dataset = datasets.get(data_type)
        if sub_dataset is None:
            print(f'\t{label}[not found, skipped]')
            continue
        print(f'\t{label}[{len(sub_dataset.answer_list)}] raw data with [{len(sub_dataset.seed_list)}] different random seeds, ave cost = [{np.mean(sub_dataset.cost_list):.3f}]')
    print('-'*80)

    # Re-solve up to check_data_num random samples and compare with stored results.
    if check_data_num > 0:
        dataset_size = len(dataset.answer_list)
        check_idx = random.sample(range(dataset_size), min(dataset_size, check_data_num))
        for idx in tqdm(check_idx, desc="Check data validity"):
            data_pos = dataset.problem_list[idx]['position']
            data_answer = dataset.answer_list[idx]
            data_cost = dataset.cost_list[idx]
            cost, answer = TSP_lkh(np.array(data_pos, np.float32))
            _require(
                abs(cost - data_cost) < 1e-5,
                f'Cost mismatch for sample {idx}: solver={cost}, stored={data_cost}',
            )
            # np.array_equal compares element-wise and works for lists as well
            # as numpy arrays (plain `==` on arrays has an ambiguous truth value).
            _require(
                np.array_equal(answer, data_answer),
                f'Tour mismatch for sample {idx}',
            )

    return dataset


if __name__ == "__main__":
    # Dataset to validate: TSP problem size (number of cities) and the
    # distribution the city positions were sampled from.
    num_nodes, data_distribution = 20, 'uniform'
    source_data = check_data(
        num_nodes,
        data_distribution,
        load_data_num=None,
    )
    


    

    