import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import numpy as np
from utils.COP_slover import PCTSP_ILS
from utils.utils import create_folder_overwrite_if_exist, create_folder_if_not_exist, load_existing_raw_data, merge_to_dataset
from environment.used.BaseEnv_COP import RawData
import multiprocessing
import pickle
import math
import time
import random
import concurrent.futures
from functools import partial
from tqdm import tqdm

MAX_LENGTHS = {
    10: 1.5,
    20: 2.,
    50: 3.,
    100: 4.
}

def check_data(show_info=False, check_data_num=5000):
    ''' 检查生成的数据是否存在问题 '''
    # 加载源数据
    dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for data_type in ['problem', 'train', 'prompt']:
        path = f'{data_base_path}/data/used/_raw/pctsp/pctsp{node_num}_{data_type}.pkl'
        if os.path.exists(path):
            with open(path, 'rb') as f:
                sub_dataset = pickle.load(f)
            dataset.seed_list.extend(sub_dataset.seed_list)
            dataset.problem_list.extend(sub_dataset.problem_list)
            dataset.answer_list.extend(sub_dataset.answer_list)
            dataset.cost_list.extend(sub_dataset.cost_list)

    assert len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list), '数据长度不匹配'
    assert len(set(dataset.seed_list)) == len(dataset.seed_list), '存在使用相同随机种子的数据生成进程'
    positions = np.array([problem['pos_node'] for problem in dataset.problem_list])         # (data_num, node_num, 2)
    assert positions.shape[0] == np.unique(positions, axis=0).shape[0], '存在完全相同的问题'

    if show_info:
        # 打印数据集信息
        print('-'*80)
        print(f'There are [{len(dataset.answer_list)}] raw data generated with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
        print('-'*80)

    # 随机检查 check_data_num 个数据
    if check_data_num > 0:
        dataset_size = len(dataset.answer_list)
        check_idx = random.sample(range(dataset_size), min(dataset_size, check_data_num))
        for idx in tqdm(check_idx, desc=f"Check data validity"):
            pos_depot = dataset.problem_list[idx]['pos_depot']
            pos_node = dataset.problem_list[idx]['pos_node']
            penalty = dataset.problem_list[idx]['penalty']
            prize = dataset.problem_list[idx]['prize']
            data_answer = dataset.answer_list[idx]
            data_cost = dataset.cost_list[idx]
            cost, answer, _ = PCTSP_ILS(
                pos_depot.tolist(), 
                pos_node.tolist(), 
                penalty.tolist(), 
                prize.tolist(),
            )
            assert abs(cost - data_cost) < 1e-4
            assert answer + [0] == data_answer

def gen_raw_data(data_num_per_process, seed, node_num, tqdm_pos, lock, penalty_factor=3):
    ''' 数据生成进程的目标方法 '''
    # 生成问题，求解并写入 dataset
    dataset = RawData(seed_list=[seed,], problem_list=[], answer_list=[], cost_list=[])
    rng = np.random.RandomState(seed)  # 创建独立的随机数生成器对象
    with tqdm(total=data_num_per_process, desc=f'Gen PCTSP data with seed {seed}', position=tqdm_pos) as pbar:
        for _ in range(data_num_per_process):
            #lock.acquire()  
            real_answer = None
            cost = float('inf')
            # 为了加速生成，runs参数设为2，求解精度较低
            # 这导致生成的解有概率不满足约束或得到次优解（20城市的情况下根据经验设cost>5）这种情况下重新生成
            while real_answer is None or cost > 5:  
                pos_depot = rng.uniform(0, 1, size=(2, ))
                pos_node = rng.uniform(0, 1, size=(node_num, 2))
                penalty_max = MAX_LENGTHS[node_num] * (penalty_factor) / float(node_num)
                penalty = np.random.uniform(size=(node_num, )) * penalty_max
                prize = np.random.uniform(size=(node_num,)) * 4 / float(node_num)
                while prize.sum() < 1:
                    prize = np.random.uniform(size=(node_num,)) * 4 / float(node_num)

                # PCTSP 问题的解从仓库出发，经过若干站点后在仓库结束
                # 仓库索引为 0，站点索引从 1 开始
                # 调用 ILS 方法求得的解格式中首尾的仓库都不包含
                cost, real_answer, _ = PCTSP_ILS(
                    pos_depot.tolist(), 
                    pos_node.tolist(), 
                    penalty.tolist(), 
                    prize.tolist(),
                )
            real_answer += [0]  # answer 格式中包含终止的仓库, 不包含出发的仓库
            #lock.release()

            dataset.problem_list.append({'pos_depot': pos_depot, 'pos_node': pos_node, 'prize': prize, 'penalty': penalty})
            dataset.answer_list.append(real_answer)
            dataset.cost_list.append(cost)
            pbar.update()

    # 保存该进程生成的原始数据
    dataset_path = f'{data_base_path}/data/used/_raw/pctsp/temp/num{data_num_per_process}_seed{seed}.pkl'
    with open(dataset_path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset_path

if __name__ == "__main__":
    # 设置参数
    node_num = 20                       # PCTSP城市数量
    data_num_per_process = 10         # 每个进程保存的样本数量
    worker_num = os.cpu_count()        # 同时执行的生成进程数量
    overwrite = False                   # 是否覆盖已有的数据
    data_num_dict = {
        'train': 250000,
        'problem': 15000,
        'prompt': 5000,
    }

    data_base_path = base_path
    #data_base_path = '/data3/XXX/toy-gato'    

    # 在临时文件夹保存求解器可能产生的中间文件
    create_folder_overwrite_if_exist(f'{data_base_path}/temp') 

    # 在临时文件夹保存每个子进程生成的数据文件
    create_folder_if_not_exist(f'{data_base_path}/data/used/_raw/pctsp/temp') 

    # 加载已经存在的数据
    ex_datasets = {}
    ex_seeds = []
    for data_type in ['problem', 'train', 'prompt']:
        ex_data_path = f'{data_base_path}/data/used/_raw/pctsp/pctsp{node_num}_{data_type}.pkl'
        ex_dataset = load_existing_raw_data(ex_data_path, overwrite)
        ex_datasets[data_type] = ex_dataset
        ex_seeds.extend(ex_dataset.seed_list)

    # 生成新数据
    for data_type in ['problem', 'train', 'prompt']:
        target_data_num = data_num_dict[data_type]                              # 该类目标样本量   
        ex_dataset = ex_datasets[data_type]                                     # 该类已存在的数据集
        ex_data_num = len(ex_dataset.answer_list)                               # 该类已存在的样本量
        data_to_go = target_data_num - ex_data_num                              # 该类还要生成的样本量
        processes_num = math.ceil(data_to_go/data_num_per_process)              # 所需的进程数量（每个进程生成data_num_per_process个样本）
        real_data_num = ex_data_num + data_num_per_process * processes_num      # 执行完毕后得到的该类真实样本量

        # 打印数据集信息
        if data_to_go > 0:    
            print('-'*100)
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, [{data_to_go}] to go')
            print('-'*100)
        else:
            print('-'*100)
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, the requirements of [{target_data_num}] have already been met')
            print('-'*100)
            continue

        # 分阶段生成所有数据，每个阶段最多worker_num个进程并行，直到总进程数量达到processes_num
        # 每个阶段最后集合该阶段所有并行进程生成的数据，保存到target_dataset_path
        target_dataset_path = f'{data_base_path}/data/used/_raw/pctsp/pctsp{node_num}_{data_type}.pkl'
        target_dataset_backup_path = f'{data_base_path}/data/used/_raw/pctsp/pctsp{node_num}_{data_type}_backup.pkl'
        gen_data_num = ex_data_num
        stage_num = 0
        while processes_num > 0:
            # 该阶段各个进程使用的随机种子
            max_seed = -1 if ex_seeds == [] else max(ex_seeds)
            stage_worker_num = min(worker_num, processes_num)
            stage_data_num = data_num_per_process * stage_worker_num
            processes_num -= stage_worker_num
            seeds = list(range(max_seed+1, max_seed+1+stage_worker_num))
            ex_seeds.extend(seeds)
            
            # 创建一个 ThreadPoolExecutor 对象，在循环中提交任务
            print('='*50 + f' STAGE {stage_num}: seeds = [{max_seed+1}, {max_seed+1+stage_worker_num}) data {gen_data_num} ~ {gen_data_num + stage_data_num}/{target_data_num}' + '='*50)
            lock = multiprocessing.Lock()
            dataset_path_list = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=stage_worker_num) as executor:
                futures = {}
                for i, seed in enumerate(seeds):
                    pos = i % worker_num
                    futures[executor.submit(partial(gen_raw_data, data_num_per_process, seed, node_num, pos, lock, 3))] = seed

                for future in concurrent.futures.as_completed(futures):
                    try:
                        dataset_path = future.result()
                        dataset_path_list.append(dataset_path)
                    except Exception as e:
                        print(f"{e}")

            # 等待所有进程执行完毕
            for future in concurrent.futures.as_completed(futures):
                future.result()
            time.sleep(2)

            # 将该阶段生成的所有子数据集合并
            gen_data_num += stage_data_num
            stage_num += 1
            merge_to_dataset(dataset_path_list, target_dataset_path, target_dataset_backup_path)

            # 数据合法性检查
            check_data(show_info=False, check_data_num=0)
            os.remove(target_dataset_backup_path)   # 通过检查则删除备份文件

    print('='*55 + f' ALL DONE ' + '='*55)
    check_data(show_info=True, check_data_num=200)