import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import pickle
import numpy as np
from environment.used.BaseEnv_COP import RawData
from tqdm import tqdm
from data.used.CVRP_V1.load_raw_data import load_existing_raw_data, raw_data_to_problem
from utils.utils import create_folder_if_not_exist
from utils.COP_slover import CVRP_lkh, calc_vrp_distance
from utils.utils import split_rawdata
import multiprocessing as mp

# Vehicle capacity handed to the LKH re-solve, keyed by the number of customer
# nodes in an instance (raw_data_check looks it up via `pos_node.shape[0]`).
# NOTE(review): values appear to follow the common CVRP benchmark convention — confirm.
CAPACITIES = {10: 20.0, 20: 30.0, 50: 40.0, 100: 50.0}

def load_raw_data(num_nodes:int, check_data_num:dict):
    """Load the pickled CVRP raw datasets for every requested split.

    Args:
        num_nodes: Problem size; selects ``cvrp{num_nodes}_{split}.pkl`` under
            ``{base_path}/data/used/_raw/cvrp/``.
        check_data_num: Maps a split name (e.g. ``'train'``) to the number of
            leading samples to keep via ``split_rawdata``; ``0`` keeps all.

    Returns:
        dict mapping each split name to its (possibly truncated) raw dataset.

    Raises:
        AssertionError: if answer/problem/cost list lengths disagree.
    """
    separator = '-' * 80
    loaded = {}
    print(separator)
    for split_name, keep_num in check_data_num.items():
        pkl_path = f'{base_path}/data/used/_raw/cvrp/cvrp{num_nodes}_{split_name}.pkl'
        with open(pkl_path, 'rb') as fh:
            raw = pickle.load(fh)
            if keep_num != 0:
                raw = split_rawdata(raw, 0, keep_num)
        loaded[split_name] = raw
        assert len(raw.answer_list) == len(raw.problem_list) == len(raw.cost_list)
        print(f'\t{split_name} dataset:   \t[{len(raw.answer_list)}] raw data with [{len(raw.seed_list)}] different random seeds, ave cost = [{np.mean(raw.cost_list):.3f}]')
    print(separator)
    return loaded

def raw_data_check(info):
    """Validate one raw CVRP sample; return its index if it is bad, else ``None``.

    Intended to run as a ``multiprocessing`` pool task. The sample is looked
    up in the module-level ``datasets`` dict (split name -> raw dataset),
    which must be visible in the worker process.

    Two checks are performed:
      1. The recorded cost must match the tour length recomputed from the
         recorded answer with ``calc_vrp_distance``.
      2. The recorded cost must match a fresh ``CVRP_lkh`` solve. If LKH
         returns no route on every retry, the sample cannot be verified and
         is reported as bad (instead of comparing a meaningless cost).

    Args:
        info: ``(i, data_type)`` tuple — sample index and split name.

    Returns:
        ``i`` when the sample fails a check, otherwise ``None``.
    """
    global datasets
    i, data_type = info
    dataset = datasets[data_type]
    cost_tol = 5e-4     # absolute tolerance on a tour-cost mismatch
    max_lkh_tries = 5   # LKH may return no route; retry a few times

    data_problem = dataset.problem_list[i]
    data_answer = dataset.answer_list[i]
    data_cost = dataset.cost_list[i]
    pos_depot = data_problem['pos_depot']
    pos_node = data_problem['pos_node']

    # Check 1: the recorded answer must reproduce the recorded cost.
    # The depot is stacked as row 0, followed by the customer node positions.
    position = np.vstack((pos_depot[None, :], pos_node))
    real_cost = calc_vrp_distance(position, data_answer)
    if abs(real_cost - data_cost) > cost_tol:
        print(f'Find Bad Case! {data_type}-{i} Record error: real cost {real_cost} != data cost {data_cost}')
        return i

    # Check 2: re-solve with LKH and compare against the recorded cost.
    demand = data_problem['demand']
    capacity = CAPACITIES[pos_node.shape[0]]
    real_answer = None
    for _attempt in range(max_lkh_tries):
        real_cost, real_answer, _ = CVRP_lkh(pos_depot.tolist(), pos_node.tolist(), demand.tolist(), capacity)
        if real_answer is not None:
            break
    if real_answer is None:
        # No route was produced, so the returned cost cannot be trusted;
        # flag the sample rather than comparing (or crashing on) that cost.
        print(f'Find Bad Case! {data_type}-{i}: LKH returned no route after {max_lkh_tries} tries')
        return i
    if abs(real_cost - data_cost) > cost_tol:
        print(f'Find Bad Case! {data_type}-{i}: real cost {real_cost} != data cost {data_cost}')
        return i

    return None

def _init_check_worker(raw_data, data_type):
    """Pool initializer: expose ``raw_data`` as ``datasets[data_type]`` in the worker.

    ``raw_data_check`` reads the module-global ``datasets``. Under the
    ``spawn``/``forkserver`` start methods the ``__main__``-only global does
    not exist in workers, so it is published here explicitly.
    """
    global datasets
    datasets = {data_type: raw_data}


def raw_data_check_multiprocessing(raw_data:RawData, process_num:int=10, data_num:int=0, data_type:str='train'):
    """Check samples of ``raw_data`` in parallel with ``raw_data_check``.

    Args:
        raw_data: Dataset to check; handed to every worker via the pool
            initializer.
        process_num: Number of worker processes.
        data_num: Check only the first ``data_num`` samples; ``0`` checks all.
            Values above the dataset size are clamped.
        data_type: Split name, used as the key under which workers see
            ``raw_data`` and in the log messages.

    Returns:
        Sorted list of indices of samples that failed a check.
    """
    total = len(raw_data.answer_list)
    data_num = total if data_num == 0 else min(data_num, total)
    tasks = [(i, data_type) for i in range(data_num)]
    with mp.Pool(processes=process_num, initializer=_init_check_worker, initargs=(raw_data, data_type)) as pool:
        # imap_unordered returns a plain iterator without len(), so tqdm
        # needs `total` to render a real progress bar.
        results = list(tqdm(pool.imap_unordered(raw_data_check, tasks), total=data_num))
    # imap_unordered yields in completion order; sort for deterministic output.
    return sorted(res for res in results if res is not None)

if __name__ == "__main__":
    # Configuration
    num_nodes = 20
    worker_num = os.cpu_count()
    # Per split: number of leading samples to check (0 = the whole split).
    check_data_num = {
        'problem': 0,
        'prompt': 0,
        'train': 0,
    }

    # Output folder for the filtered datasets
    create_folder_if_not_exist(f'{base_path}/data/used/_raw/cvrp/fixed')

    # Load data. NOTE: `datasets` stays a module-level name because
    # `raw_data_check` reads it through `global datasets`.
    datasets = load_raw_data(num_nodes, check_data_num)

    # Basic sanity checks on ALL splits combined, so duplicate seeds or
    # identical problems shared between splits are detected too.
    dataset_all = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for dataset in datasets.values():
        dataset_all.seed_list.extend(dataset.seed_list)
        dataset_all.problem_list.extend(dataset.problem_list)
        dataset_all.answer_list.extend(dataset.answer_list)
        dataset_all.cost_list.extend(dataset.cost_list)
    # Messages: "data length mismatch" / "some generation processes share a
    # random seed" / "identical problems exist".
    assert len(dataset_all.problem_list) == len(dataset_all.answer_list) == len(dataset_all.cost_list), '数据长度不匹配'
    assert len(set(dataset_all.seed_list)) == len(dataset_all.seed_list), '存在使用相同随机种子的数据生成进程'
    positions = np.array([problem['pos_node'] for problem in dataset_all.problem_list])  # presumably (data_num, node_num, 2)
    assert positions.shape[0] == np.unique(positions, axis=0).shape[0], '存在完全相同的问题'

    # Re-solve and check every sample; drop the bad ones.
    # NOTE(review): only splits that actually lost samples are written to fixed/.
    for data_type, dataset in datasets.items():
        bad_idx_list = raw_data_check_multiprocessing(dataset, worker_num, data_type=data_type)
        print(f'Removed {len(bad_idx_list)} error sample in {data_type} dataset')
        if len(bad_idx_list) != 0:
            bad_idx_set = set(bad_idx_list)  # O(1) membership while filtering
            dataset.answer_list = [v for i, v in enumerate(dataset.answer_list) if i not in bad_idx_set]
            dataset.problem_list = [v for i, v in enumerate(dataset.problem_list) if i not in bad_idx_set]
            dataset.cost_list = [v for i, v in enumerate(dataset.cost_list) if i not in bad_idx_set]
            assert len(dataset.answer_list) == len(dataset.problem_list) == len(dataset.cost_list)

            with open(f'{base_path}/data/used/_raw/cvrp/fixed/cvrp{num_nodes}_{data_type}.pkl', 'wb') as f:
                pickle.dump(dataset, f)

    # Summary of the filtered data
    print('-'*80)
    for data_type, dataset in datasets.items():
        print(f'\t{data_type}   dataset:\t[{len(dataset.answer_list)}] raw data with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
    print('-'*80)