import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import numpy as np
from utils.COP_slover import OP_gurobi
from utils.utils import create_folder_overwrite_if_exist, create_folder_if_not_exist, load_existing_raw_data, merge_to_dataset
from environment.used.BaseEnv_COP import RawData
import pickle
import math
import time
import random
import concurrent.futures
from functools import partial
from tqdm import tqdm

# Tour-length budget (`length_left` passed to OP_gurobi) for an OP instance,
# keyed by the number of nodes in the instance.
MAX_LENGTHS = {10: 1.5, 20: 2.0, 50: 3.0, 100: 4.0}

def check_data(show_info=False, check_data_num=5000, node_num=None):
    ''' Check the generated OP raw data for problems.

    Loads and concatenates every existing ``op{node_num}_{problem,train,prompt}.pkl``
    file under ``{base_path}/data/used/_raw/op/`` and verifies that:
      * the problem / answer / cost lists have the same length,
      * no two generation processes used the same random seed,
      * no two problems have identical node positions,
      * for up to ``check_data_num`` randomly chosen samples, re-solving with
        OP_gurobi reproduces the stored cost (within 1e-5) and answer.

    Args:
        show_info: if True, print the sample count, seed count and mean cost.
        check_data_num: number of samples to re-solve; values <= 0 skip this step.
        node_num: problem size selecting which files to load. Defaults to the
            module-level ``node_num`` (set in this script's ``__main__`` section)
            for backward compatibility with existing callers.

    Raises:
        ValueError: if ``node_num`` is not given and no module-level one exists.
        AssertionError: if any check fails. Raised explicitly rather than via
            ``assert`` so the checks are not stripped under ``python -O``.
    '''
    if node_num is None:
        # Backward compatibility: this function used to read the script-level global.
        node_num = globals().get('node_num')
        if node_num is None:
            raise ValueError('check_data() needs node_num: pass it explicitly or define a module-level node_num')

    def _require(condition, message=''):
        # Explicit raise instead of `assert`: the caller deletes its backup file once
        # this function returns, so the checks must also run under `python -O`.
        if not condition:
            raise AssertionError(message)

    # Load and concatenate all existing raw datasets of this problem size
    dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for data_type in ['problem', 'train', 'prompt']:
        path = f'{base_path}/data/used/_raw/op/op{node_num}_{data_type}.pkl'
        if not os.path.exists(path):
            continue
        with open(path, 'rb') as f:
            sub_dataset = pickle.load(f)   # files written by this script itself (trusted input)
        dataset.seed_list.extend(sub_dataset.seed_list)
        dataset.problem_list.extend(sub_dataset.problem_list)
        dataset.answer_list.extend(sub_dataset.answer_list)
        dataset.cost_list.extend(sub_dataset.cost_list)

    # "data length mismatch"
    _require(len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list), '数据长度不匹配')
    # "some generation processes used the same random seed"
    _require(len(set(dataset.seed_list)) == len(dataset.seed_list), '存在使用相同随机种子的数据生成进程')
    positions = np.array([problem['pos_node'] for problem in dataset.problem_list])         # (data_num, node_num, 2)
    # "there are completely identical problems"
    _require(positions.shape[0] == np.unique(positions, axis=0).shape[0], '存在完全相同的问题')

    if show_info:
        # Avoid numpy's empty-mean RuntimeWarning; prints 'nan' just like np.mean([]) would
        ave_cost = np.mean(dataset.cost_list) if dataset.cost_list else float('nan')
        print('-'*80)
        print(f'There are [{len(dataset.answer_list)}] raw data generated with [{len(dataset.seed_list)}] different random seeds, ave cost = [{ave_cost:.3f}]')
        print('-'*80)

    # Re-solve a random subset of check_data_num samples and compare with the stored solution
    if check_data_num > 0:
        dataset_size = len(dataset.answer_list)
        check_idx = random.sample(range(dataset_size), min(dataset_size, check_data_num))
        for idx in tqdm(check_idx, desc="Check data validity"):
            problem = dataset.problem_list[idx]
            pos_depot = problem['pos_depot']
            pos_node = problem['pos_node']
            prize = problem['prize']
            length_left = MAX_LENGTHS[pos_node.shape[0]]
            data_answer = dataset.answer_list[idx]
            data_cost = dataset.cost_list[idx]
            cost, answer = OP_gurobi(pos_depot.tolist(), pos_node.tolist(), prize.tolist(), length_left, )
            _require(abs(cost - data_cost) < 1e-5, f'cost mismatch at sample {idx}: solved {cost}, stored {data_cost}')
            _require(answer == data_answer, f'answer mismatch at sample {idx}: solved {answer}, stored {data_answer}')

def gen_raw_data(data_num_per_process, seed, node_num, tqdm_pos, prize_type='dist'):
    ''' Worker target: generate, solve and save a batch of random OP instances.

    Each instance has a depot and ``node_num`` nodes sampled uniformly in the
    unit square, a per-node prize chosen by ``prize_type``, and the tour-length
    budget ``MAX_LENGTHS[node_num]``. Every instance is solved with OP_gurobi.

    Args:
        data_num_per_process: number of instances to generate.
        seed: seed of this worker's private RandomState. All randomness
            (positions and, for 'unif', prizes) comes from it, so the output is
            reproducible from the seed alone.
        node_num: number of nodes per instance (must be a key of MAX_LENGTHS).
        tqdm_pos: row of this worker's tqdm progress bar.
        prize_type: 'const' (every prize is 1), 'unif' (uniform over
            {0.01, 0.02, ..., 1.00}) or 'dist' (prize grows with the node's
            distance to the depot, in (0, 1]).

    Returns:
        Path of the pickle file holding the generated RawData.

    Raises:
        ValueError: if ``prize_type`` is not one of the supported values.
    '''
    # Validate up front (previously an `assert` deep in the loop, silently
    # falling back to 'dist' under `python -O`).
    if prize_type not in ('const', 'unif', 'dist'):
        raise ValueError(f"unknown prize_type {prize_type!r}; expected 'const', 'unif' or 'dist'")

    dataset = RawData(seed_list=[seed,], problem_list=[], answer_list=[], cost_list=[])
    rng = np.random.RandomState(seed)  # independent, per-worker random number generator
    length_left = MAX_LENGTHS[node_num]
    with tqdm(total=data_num_per_process, desc=f'Gen OP data with seed {seed}', position=tqdm_pos) as pbar:
        for _ in range(data_num_per_process):
            pos_depot = rng.uniform(0, 1, size=(2, ))
            pos_node = rng.uniform(0, 1, size=(node_num, 2))

            if prize_type == 'const':
                prize = np.ones(node_num, dtype=np.float32)
            elif prize_type == 'unif':
                # Must use the seeded per-worker rng: the global np.random state is
                # not seeded here and, with forked workers, is identical in every process.
                prize = (1 + rng.randint(0, 100, size=(node_num,))) / 100.
            else:  # 'dist': prize based on distance to the depot
                # Per-node norm stored as float32 exactly as before (node-node
                # distances were computed previously but never used).
                distance_depot = np.array([np.linalg.norm(p - pos_depot) for p in pos_node], dtype=np.float32)
                prize = (1 + (distance_depot / distance_depot.max() * 99).astype(int)) / 100.

            # An OP solution starts at the depot, visits some nodes, and ends at the depot.
            # The depot has index -1 and nodes are indexed from 0.
            # The solution returned by OP_gurobi excludes the depot at both ends.
            cost, real_answer = OP_gurobi(pos_depot.tolist(), pos_node.tolist(), prize.tolist(), length_left, )

            dataset.problem_list.append({'pos_depot': pos_depot, 'pos_node': pos_node, 'prize': prize})
            dataset.answer_list.append(real_answer)
            dataset.cost_list.append(cost)
            pbar.update()

    # Save the raw data generated by this worker
    dataset_path = f'{base_path}/data/used/_raw/op/temp/num{data_num_per_process}_seed{seed}.pkl'
    with open(dataset_path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset_path

if __name__ == "__main__":
    # Parameters. These stay module-level globals: check_data() reads `node_num`.
    node_num = 20                       # number of OP nodes
    data_num_per_process = 40           # samples generated and saved by each worker process
    worker_num = 10                     # max number of worker processes run in parallel
    overwrite = False                   # whether to overwrite existing data
    data_num_dict = {
        'train': 200000,
        'problem': 2000,
        'prompt': 15000,
    }

    # Temp folder for intermediate files the solver may produce
    create_folder_overwrite_if_exist(f'{base_path}/temp')

    # Temp folder for the dataset file written by each worker
    create_folder_if_not_exist(f'{base_path}/data/used/_raw/op/temp')

    # Load existing data and remember every seed already used, so new workers never reuse one
    ex_datasets = {}
    ex_seeds = []
    for data_type in ['problem', 'train', 'prompt']:
        ex_data_path = f'{base_path}/data/used/_raw/op/op{node_num}_{data_type}.pkl'
        ex_dataset = load_existing_raw_data(ex_data_path, overwrite)
        ex_datasets[data_type] = ex_dataset
        ex_seeds.extend(ex_dataset.seed_list)

    # Generate new data
    for data_type in ['problem', 'train', 'prompt']:
        target_data_num = data_num_dict[data_type]                              # target sample count for this type
        ex_dataset = ex_datasets[data_type]                                     # existing dataset of this type
        ex_data_num = len(ex_dataset.answer_list)                               # existing sample count
        data_to_go = target_data_num - ex_data_num                              # samples still to generate
        # Each process yields data_num_per_process samples, so the target may be overshot
        # by less than data_num_per_process samples.
        processes_num = math.ceil(data_to_go/data_num_per_process)

        print('-'*100)
        if data_to_go <= 0:
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, the requirements of [{target_data_num}] have already been met')
            print('-'*100)
            continue
        print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, [{data_to_go}] to go')
        print('-'*100)

        # Generate in stages of at most worker_num parallel processes until processes_num
        # processes have run; after each stage, merge its outputs into target_dataset_path.
        target_dataset_path = f'{base_path}/data/used/_raw/op/op{node_num}_{data_type}.pkl'
        target_dataset_backup_path = f'{base_path}/data/used/_raw/op/op{node_num}_{data_type}_backup.pkl'
        gen_data_num = ex_data_num
        stage_num = 0
        while processes_num > 0:
            # Fresh, never-used seeds for this stage's workers
            max_seed = max(ex_seeds, default=-1)
            stage_worker_num = min(worker_num, processes_num)
            stage_data_num = data_num_per_process * stage_worker_num
            processes_num -= stage_worker_num
            seeds = list(range(max_seed+1, max_seed+1+stage_worker_num))
            ex_seeds.extend(seeds)

            print('='*50 + f' STAGE {stage_num}: seeds = [{max_seed+1}, {max_seed+1+stage_worker_num}) data {gen_data_num} ~ {gen_data_num + stage_data_num}/{target_data_num}' + '='*50)
            futures = {}
            with concurrent.futures.ProcessPoolExecutor(max_workers=stage_worker_num) as executor:
                for i, seed in enumerate(seeds):
                    pos = i % worker_num    # tqdm bar row
                    futures[executor.submit(gen_raw_data, data_num_per_process, seed, node_num, pos, 'dist')] = seed
            # Leaving the `with` block waits for every worker to finish.
            # NOTE(review): the original also slept 2s here, presumably to let the tqdm bars settle.
            time.sleep(2)

            # Collect results once, in seed order, so the merged dataset order is reproducible.
            dataset_path_list = []
            failed_seeds = []
            first_error = None
            for future, seed in futures.items():
                try:
                    dataset_path_list.append(future.result())
                except Exception as e:
                    print(f"{e}")
                    failed_seeds.append(seed)
                    if first_error is None:
                        first_error = e
            if failed_seeds:
                # Stop before merging a partial stage.
                raise RuntimeError(f'data generation failed for seeds {failed_seeds}') from first_error

            # Merge all sub-datasets generated in this stage
            gen_data_num += stage_data_num
            stage_num += 1
            merge_to_dataset(dataset_path_list, target_dataset_path, target_dataset_backup_path)

            # Validity check; the backup is deleted only if the check passes (it raises otherwise)
            check_data(show_info=False, check_data_num=20)
            os.remove(target_dataset_backup_path)

    print('='*55 + f' ALL DONE ' + '='*55)
    check_data(show_info=True, check_data_num=200)