import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import pickle
from utils.COP_slover import OP_gurobi
from tqdm import tqdm
from utils.utils import split_rawdata
import numpy as np
from environment.used.BaseEnv_COP import RawData
import multiprocessing as mp
from utils.COP_slover import calc_op_distance, calc_op_total
from utils.utils import create_folder_if_not_exist

# Maximum allowed tour length, keyed by the number of nodes in a problem
# instance (looked up with ``pos_node.shape[0]`` when a sample is checked).
MAX_LENGTHS = {20: 2.0, 50: 3.0, 100: 4.0}

def load_raw_data(num_nodes:int, check_data_num:dict):
    """Load the pickled raw OP datasets for one problem size and print a summary.

    Args:
        num_nodes: Problem size; selects the files ``op{num_nodes}_{data_type}.pkl``.
        check_data_num: Maps each data type to a sample count. ``0`` keeps the
            dataset exactly as loaded; any other value is forwarded to
            ``split_rawdata(dataset, 0, count)`` (presumably keeping the first
            ``count`` samples -- confirm against ``split_rawdata``).

    Returns:
        Dict mapping each data type to its loaded dataset.
    """
    separator = '-'*80
    loaded = {}

    print(separator)
    for data_type, check_num in check_data_num.items():
        path = f'{base_path}/data/used/_raw/op/op{num_nodes}_{data_type}.pkl'
        # NOTE: pickle is only safe for trusted, locally generated files.
        with open(path, 'rb') as f:
            dataset = pickle.load(f)
        if check_num != 0:
            dataset = split_rawdata(dataset, 0, check_num)
        loaded[data_type] = dataset

        # Every sample needs a problem, an answer and a cost.
        sample_count = len(dataset.answer_list)
        assert sample_count == len(dataset.problem_list) == len(dataset.cost_list)

        seed_count = len(dataset.seed_list)
        mean_cost = np.mean(dataset.cost_list)
        print(f'\t{data_type} dataset:   \t[{sample_count}] raw data with [{seed_count}] different random seeds, ave cost = [{mean_cost:.3f}]')
    print(separator)

    return loaded

def raw_data_check(info):
    """Validate one stored sample; return its index when it is bad, else ``None``.

    The sample is read from the module-level ``datasets`` mapping. Three checks
    run in order, and the first failure is printed and reported:

    1. The cost recomputed from the stored answer (``-calc_op_total``) must
       match the recorded cost.
    2. The tour length from ``calc_op_distance`` must not exceed the limit in
       ``MAX_LENGTHS`` for this problem size.
    3. The cost returned by a fresh ``OP_gurobi`` solve must match the
       recorded cost.

    Args:
        info: ``(index, data_type)`` tuple identifying the sample.

    Returns:
        The sample index if any check fails, otherwise ``None``.
    """
    tol = 5e-4  # absolute tolerance shared by all three checks
    sample_idx, data_type = info
    dataset = datasets[data_type]

    problem = dataset.problem_list[sample_idx]
    answer = dataset.answer_list[sample_idx]
    recorded_cost = dataset.cost_list[sample_idx]

    pos_depot = problem['pos_depot']
    pos_node = problem['pos_node']
    prize = problem['prize']
    max_length = MAX_LENGTHS[pos_node.shape[0]]

    # Stored answers index the nodes only; shifting by one leaves index 0 for
    # the depot row that is stacked on top of the node coordinates below.
    visited = [idx + 1 for idx in answer]

    # Check 1: recorded cost vs. cost recomputed from the stored answer.
    recomputed_cost = -calc_op_total(np.array(prize), np.array(visited) - 1)
    if abs(recomputed_cost - recorded_cost) > tol:
        print(f'Find Bad Case! {data_type}-{sample_idx} Record error: real cost {recomputed_cost} != data cost {recorded_cost}')
        return sample_idx

    # Check 2: the stored tour has to respect the length budget.
    coords = np.vstack((pos_depot[None,:], pos_node))
    tour_len = calc_op_distance(coords, visited)
    if tour_len > max_length + tol:
        print(f"Find Bad Case! {data_type}-{sample_idx}: Tour exceeds max_length! {tour_len} > {max_length}")
        return sample_idx

    # Check 3: re-solve the instance and compare with the recorded cost.
    solved_cost, _ = OP_gurobi(pos_depot.tolist(), pos_node.tolist(), prize.tolist(), max_length)
    if abs(solved_cost - recorded_cost) > tol:
        print(f'Find Bad Case! {data_type}-{sample_idx}: real cost {solved_cost} != data cost {recorded_cost}')
        return sample_idx

    return None
    
    
def _init_check_worker(shared_datasets:dict):
    """Pool initializer: publish the datasets that ``raw_data_check`` reads."""
    global datasets
    datasets = shared_datasets


def raw_data_check_multiprocessing(raw_data:RawData, process_num:int=10, data_num:int=0, data_type:str='train'):
    """Run ``raw_data_check`` over a dataset in a process pool.

    ``raw_data_check`` looks its sample up in the module-level ``datasets``
    mapping. That name is only guaranteed to exist in workers created with the
    ``fork`` start method; with ``spawn`` the workers re-import the module and
    would never see it. The pool initializer therefore installs
    ``{data_type: raw_data}`` in every worker, which also guarantees that the
    data being checked is exactly the ``raw_data`` passed in.

    Args:
        raw_data: Dataset whose samples are checked.
        process_num: Number of worker processes.
        data_num: Number of leading samples to check; ``0`` checks all of
            them. Values beyond the dataset size are clamped.
        data_type: Key under which ``raw_data`` is exposed to the workers;
            also used in the messages printed by ``raw_data_check``.

    Returns:
        List of indices of the samples that failed a check, in completion
        order (not sorted).
    """
    available = len(raw_data.answer_list)
    data_num = available if data_num == 0 else min(data_num, available)
    tasks = [(i, data_type) for i in range(data_num)]

    with mp.Pool(
        processes=process_num,
        initializer=_init_check_worker,
        initargs=({data_type: raw_data},),
    ) as pool:
        # imap_unordered has no length, so tqdm needs an explicit total.
        results = list(tqdm(pool.imap_unordered(raw_data_check, tasks), total=data_num))

    return [res for res in results if res is not None]

if __name__ == "__main__":
    # Run configuration
    num_nodes = 20
    worker_num = os.cpu_count()
    # Samples to check per data type; 0 means the dataset is used as loaded
    check_data_num = {
        'problem': 0,
        'prompt': 0,
        'train': 0,
    }

    # Folder that receives the filtered datasets
    create_folder_if_not_exist(f'{base_path}/data/used/_raw/op/fixed')

    # Load the data
    datasets = load_raw_data(num_nodes, check_data_num)

    # Basic sanity checks. They run on the merged data of all data types so
    # that duplicated seeds or problems across data types are detected too
    # (checking only the loop variable would cover the last data type alone).
    dataset_all = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for dataset in datasets.values():
        dataset_all.seed_list.extend(dataset.seed_list)
        dataset_all.problem_list.extend(dataset.problem_list)
        dataset_all.answer_list.extend(dataset.answer_list)
        dataset_all.cost_list.extend(dataset.cost_list)
    assert len(dataset_all.problem_list) == len(dataset_all.answer_list) == len(dataset_all.cost_list), '数据长度不匹配'
    assert len(set(dataset_all.seed_list)) == len(dataset_all.seed_list), '存在使用相同随机种子的数据生成进程'
    positions = np.array([problem['pos_node'] for problem in dataset_all.problem_list])     # (data_num, node_num, 2)
    assert positions.shape[0] == np.unique(positions, axis=0).shape[0], '存在完全相同的问题'

    # Re-solve and check every sample, dropping the ones that fail
    for data_type, dataset in datasets.items():
        bad_idx_list = raw_data_check_multiprocessing(dataset, worker_num, data_type=data_type)
        print(f'Removed {len(bad_idx_list)} error sample in {data_type} dataset')
        if len(bad_idx_list) != 0:
            bad_idx = set(bad_idx_list)     # O(1) membership while filtering
            dataset.answer_list = [v for i, v in enumerate(dataset.answer_list) if i not in bad_idx]
            dataset.problem_list = [v for i, v in enumerate(dataset.problem_list) if i not in bad_idx]
            dataset.cost_list = [v for i, v in enumerate(dataset.cost_list) if i not in bad_idx]
            assert len(dataset.answer_list) == len(dataset.problem_list) == len(dataset.cost_list)

            # Only datasets that actually lost samples are written out
            with open(f'{base_path}/data/used/_raw/op/fixed/op{num_nodes}_{data_type}.pkl', 'wb') as f:
                pickle.dump(dataset, f)

    # Summary of the data after filtering
    print('-'*80)
    for data_type, dataset in datasets.items():
        print(f'\t{data_type}   dataset:\t[{len(dataset.answer_list)}] raw data with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
    print('-'*80)