import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import numpy as np
from utils.COP_slover import knapsack_dp
from utils.utils import create_folder_overwrite_if_exist, create_folder_if_not_exist, load_existing_raw_data, merge_to_dataset
from environment.used.BaseEnv_COP import RawData
import multiprocessing
import pickle
import math
import time
import random
import concurrent.futures
from functools import partial
from tqdm import tqdm

# Keyed by problem size (number of items). Despite its name, this is the
# exclusive upper bound passed to ``rng.randint`` when drawing item *values* in
# ``gen_raw_data``; item volumes are then derived from those values.
MAX_ITEM_VOLUME = {20: 20}

# Keyed by problem size (number of items): the knapsack capacity given to
# every generated instance of that size.
MAX_CAPACITY = {20: 30}

def check_data(show_info=False, check_data_num=5000, num_item=None):
    ''' Sanity-check the generated 01BP raw data.

    Loads and concatenates every existing
    ``{base_path}/data/used/_raw/bp/bp{num_item}_{problem,train,prompt}.pkl``
    file, then checks that:
      * the problem / answer / cost lists all have the same length;
      * no random seed was used by two generation workers;
      * no two problems are identical (same item values, item volumes AND
        capacity -- sharing only one of these does not make two problems equal);
      * for ``check_data_num`` randomly chosen samples, re-solving with
        ``knapsack_dp`` reproduces the stored cost (within 1e-4) and answer.

    Args:
        show_info: if True, print the sample count, seed count and mean cost.
        check_data_num: number of samples to re-solve; values <= 0 skip this
            step. Capped at the dataset size.
        num_item: items per problem, selecting which files are loaded. When
            None, falls back to the module-level ``num_item`` defined by the
            ``__main__`` block (the original behavior).

    Raises:
        ValueError: if ``num_item`` is None and no module-level ``num_item`` exists.
        AssertionError: if any check fails. Raised explicitly instead of via
            ``assert`` so the checks still run under ``python -O``; the script
            deletes its backup file only after this function returns.
    '''
    if num_item is None:
        # Backward compatibility: the script defines ``num_item`` as a module global.
        num_item = globals().get('num_item')
        if num_item is None:
            raise ValueError('num_item must be passed explicitly when no module-level num_item is defined')

    # Load and concatenate all existing raw datasets for this problem size.
    dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for data_type in ['problem', 'train', 'prompt']:
        path = f'{base_path}/data/used/_raw/bp/bp{num_item}_{data_type}.pkl'
        if not os.path.exists(path):
            continue
        with open(path, 'rb') as f:
            sub_dataset = pickle.load(f)    # locally generated file; never point this at untrusted data
        dataset.seed_list.extend(sub_dataset.seed_list)
        dataset.problem_list.extend(sub_dataset.problem_list)
        dataset.answer_list.extend(sub_dataset.answer_list)
        dataset.cost_list.extend(sub_dataset.cost_list)

    if not (len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list)):
        raise AssertionError('数据长度不匹配')
    if len(set(dataset.seed_list)) != len(dataset.seed_list):
        raise AssertionError('存在使用相同随机种子的数据生成进程')

    if dataset.problem_list:
        values = np.array([problem['item_values'] for problem in dataset.problem_list])         # (data_num, num_item)
        volumes = np.array([problem['item_volumes'] for problem in dataset.problem_list])       # (data_num, num_item)
        capacities = np.array([problem['capacity_left'] for problem in dataset.problem_list]).reshape(-1, 1)
        # One row per problem; only rows that match on every column are true duplicates.
        problems = np.concatenate([values, volumes, capacities], axis=1)
        if np.unique(problems, axis=0).shape[0] != problems.shape[0]:
            raise AssertionError('存在完全相同的问题')

    if show_info:
        # Print dataset summary
        print('-'*80)
        print(f'There are [{len(dataset.answer_list)}] raw data generated with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
        print('-'*80)

    # Re-solve a random subset of at most check_data_num samples and compare.
    if check_data_num > 0:
        dataset_size = len(dataset.answer_list)
        check_idx = random.sample(range(dataset_size), min(dataset_size, check_data_num))
        for idx in tqdm(check_idx, desc="Check data validity"):
            problem = dataset.problem_list[idx]
            data_value, data_answer = dataset.cost_list[idx], dataset.answer_list[idx]
            value, real_answer = knapsack_dp(problem['capacity_left'].item(), problem['item_volumes'], problem['item_values'])
            if abs(value - data_value) >= 1e-4:
                raise AssertionError(f'sample {idx}: stored cost {data_value} != re-solved cost {value}')
            if real_answer != data_answer:
                raise AssertionError(f'sample {idx}: stored answer differs from re-solved answer')

def gen_raw_data(data_num_per_process, seed, num_item, tqdm_pos):
    ''' Worker target: generate, solve and pickle one shard of 01BP instances.

    Draws ``data_num_per_process`` random instances with ``num_item`` items from a
    private ``np.random.RandomState(seed)``, solves each with ``knapsack_dp``, and
    pickles the resulting RawData to
    ``{base_path}/data/used/_raw/bp/temp/num{data_num_per_process}_seed{seed}.pkl``.

    Args:
        data_num_per_process: number of instances in this shard.
        seed: seed of this worker's RNG; also recorded in the shard's ``seed_list``.
        num_item: items per instance; key into MAX_ITEM_VOLUME / MAX_CAPACITY.
        tqdm_pos: row used for this worker's progress bar.

    Returns:
        Path of the pickle file that was written.
    '''
    shard = RawData(seed_list=[seed], problem_list=[], answer_list=[], cost_list=[])
    generator = np.random.RandomState(seed)     # private RNG: the shard is reproducible from its seed alone
    progress_label = f'Gen 01BP data with seed {seed}'

    with tqdm(total=data_num_per_process, desc=progress_label, position=tqdm_pos) as progress:
        for _ in range(data_num_per_process):
            # Item values: ints in [2, MAX_ITEM_VOLUME[num_item]) -- randint's upper bound is exclusive.
            item_values = generator.randint(2, MAX_ITEM_VOLUME[num_item], num_item)
            # Each volume is 0.5x or 1.5x its value (random sign), truncated to int32,
            # so volumes are positively correlated with values.
            signs = generator.choice([-1, 1], size=(num_item,))
            item_volumes = (item_values + 0.5 * (item_values * signs)).astype(np.int32)
            assert item_volumes.min() > 0
            capacity = np.array(MAX_CAPACITY[num_item], dtype=np.int32)

            best_value, solution = knapsack_dp(capacity.item(), item_volumes, item_values)
            shard.problem_list.append({'item_values': item_values, 'item_volumes': item_volumes, 'capacity_left': capacity})
            shard.answer_list.append(solution)
            shard.cost_list.append(best_value)
            progress.update()

    out_path = f'{base_path}/data/used/_raw/bp/temp/num{data_num_per_process}_seed{seed}.pkl'
    with open(out_path, 'wb') as handle:
        pickle.dump(shard, handle)
    return out_path

if __name__ == "__main__":
    # Parameters
    num_item = 20                       # items per 01BP instance (check_data also reads this module global)
    data_num_per_process = 1000         # samples generated and saved by each worker
    worker_num = 10                     # max number of workers run concurrently per stage
    overwrite = False                   # whether to overwrite existing data
    data_num_dict = {
        'train': 500000,
        'problem': 10000,
        'prompt': 10000,
    }

    # Scratch folder for intermediate files the solver may produce
    create_folder_overwrite_if_exist(f'{base_path}/temp')

    # Scratch folder for the per-worker shard files
    create_folder_if_not_exist(f'{base_path}/data/used/_raw/bp/temp')

    # Load existing data and collect every seed already used, so new seeds never collide
    ex_datasets = {}
    ex_seeds = []
    for data_type in ['problem', 'train', 'prompt']:
        ex_data_path = f'{base_path}/data/used/_raw/bp/bp{num_item}_{data_type}.pkl'
        ex_dataset = load_existing_raw_data(ex_data_path, overwrite)
        ex_datasets[data_type] = ex_dataset
        ex_seeds.extend(ex_dataset.seed_list)

    # Generate new data
    for data_type in ['problem', 'train', 'prompt']:
        target_data_num = data_num_dict[data_type]                              # target sample count for this split
        ex_dataset = ex_datasets[data_type]                                     # already existing dataset for this split
        ex_data_num = len(ex_dataset.answer_list)                               # already existing sample count
        data_to_go = target_data_num - ex_data_num                              # samples still to generate
        # Only whole shards are generated, so the final count may exceed the target
        # by up to data_num_per_process - 1 samples.
        processes_num = math.ceil(data_to_go/data_num_per_process)

        # Print dataset status
        if data_to_go > 0:
            print('-'*100)
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, [{data_to_go}] to go')
            print('-'*100)
        else:
            print('-'*100)
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, the requirements of [{target_data_num}] have already been met')
            print('-'*100)
            continue

        # Generate in stages of at most worker_num concurrent workers until processes_num shards are done.
        # At the end of each stage, merge that stage's shards into target_dataset_path.
        target_dataset_path = f'{base_path}/data/used/_raw/bp/bp{num_item}_{data_type}.pkl'
        target_dataset_backup_path = f'{base_path}/data/used/_raw/bp/bp{num_item}_{data_type}_backup.pkl'
        gen_data_num = ex_data_num
        stage_num = 0
        while processes_num > 0:
            # Fresh, consecutive seeds for this stage's workers
            max_seed = max(ex_seeds, default=-1)
            stage_worker_num = min(worker_num, processes_num)
            stage_data_num = data_num_per_process * stage_worker_num
            processes_num -= stage_worker_num
            seeds = list(range(max_seed+1, max_seed+1+stage_worker_num))
            ex_seeds.extend(seeds)

            print('='*50 + f' STAGE {stage_num}: seeds = [{max_seed+1}, {max_seed+1+stage_worker_num}) data {gen_data_num} ~ {gen_data_num + stage_data_num}/{target_data_num}' + '='*50)
            shard_paths = {}        # seed -> shard file path
            failed_seeds = []
            # NOTE(review): these are threads, so pure-Python work inside knapsack_dp is
            # serialized by the GIL; confirm whether a ProcessPoolExecutor would be safe here.
            with concurrent.futures.ThreadPoolExecutor(max_workers=stage_worker_num) as executor:
                # The progress-bar position is the worker's index within this stage (< worker_num).
                futures = {
                    executor.submit(gen_raw_data, data_num_per_process, seed, num_item, pos): seed
                    for pos, seed in enumerate(seeds)
                }
                for future in concurrent.futures.as_completed(futures):
                    seed = futures[future]
                    try:
                        shard_paths[seed] = future.result()
                    except Exception as e:
                        print(f"{e}")
                        failed_seeds.append(seed)
            # Leaving the `with` block has already waited for every worker.
            # Abort before merging so the target file never receives a partial stage.
            if failed_seeds:
                raise RuntimeError(f'Data generation failed for seeds {sorted(failed_seeds)}; stage {stage_num} was not merged')

            # Merge in seed order (not completion order) so the merged dataset order is deterministic.
            dataset_path_list = [shard_paths[seed] for seed in seeds]
            gen_data_num += stage_data_num
            stage_num += 1
            merge_to_dataset(dataset_path_list, target_dataset_path, target_dataset_backup_path)

            # Validity check; the backup is deleted only when the check passes.
            check_data(show_info=False, check_data_num=20)
            # NOTE(review): whether merge_to_dataset always writes a backup is not visible here,
            # so guard the removal instead of crashing after a successful check.
            if os.path.exists(target_dataset_backup_path):
                os.remove(target_dataset_backup_path)

    print('='*55 + f' ALL DONE ' + '='*55)
    check_data(show_info=True, check_data_num=200)