import os
import sys
# Resolve the directory four levels above this file's directory and append it to
# sys.path before the project imports below run.
# NOTE(review): presumably this is the repository root, which is what lets the
# `utils.*` / `environment.*` imports resolve -- confirm against the repo layout.
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import pickle
import random
from utils.COP_slover import knapsack_dp
from utils.utils import split_rawdata
from tqdm import tqdm
import numpy as np
from environment.used.BaseEnv_COP import RawData
from environment.used.Env_bp_v1 import MAX_CAPACITY
        
def check_data(num_items, load_data_num=None, check_data_num=5000):
    """Check the generated raw BP data for ``num_items`` items for problems.

    Loads whichever of the ``problem`` / ``train`` / ``prompt`` pickle files
    exist under ``{base_path}/data/used/_raw/bp/``, merges them into a single
    ``RawData``, runs global consistency checks, prints a per-split summary and
    finally re-solves a random sample of instances with ``knapsack_dp`` to
    confirm the stored cost and answer.

    Args:
        num_items: Item count used both in the file names
            (``bp{num_items}_{split}.pkl``) and as the key into ``MAX_CAPACITY``.
        load_data_num: When not ``None``, every loaded split is passed through
            ``split_rawdata(sub_dataset, 0, load_data_num)`` before merging
            (presumably keeping only the first ``load_data_num`` entries --
            confirm against ``split_rawdata``).
        check_data_num: Upper bound on the number of randomly sampled instances
            that are re-solved. Values ``<= 0`` skip the re-solve step.

    Raises:
        ValueError: If no data could be loaded for ``num_items`` at all.
        AssertionError: If any consistency check or re-solved sample fails.
    """
    # Load every split file that exists and merge them into one dataset.
    datasets = {}
    dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for data_type in ('problem', 'train', 'prompt'):
        path = f'{base_path}/data/used/_raw/bp/bp{num_items}_{data_type}.pkl'
        if not os.path.exists(path):
            continue
        # NOTE(security): pickle.load can execute arbitrary code; only point this
        # at data files produced by this project.
        with open(path, 'rb') as f:
            sub_dataset = pickle.load(f)
        if load_data_num is not None:
            sub_dataset = split_rawdata(sub_dataset, 0, load_data_num)
        datasets[data_type] = sub_dataset
        dataset.seed_list.extend(sub_dataset.seed_list)
        dataset.problem_list.extend(sub_dataset.problem_list)
        dataset.answer_list.extend(sub_dataset.answer_list)
        dataset.cost_list.extend(sub_dataset.cost_list)

    # Fail early with a clear message instead of an opaque "zero-size array"
    # ValueError from the reductions below.
    if not dataset.problem_list:
        raise ValueError(
            f'No raw data found for num_items={num_items} under '
            f'{base_path}/data/used/_raw/bp'
        )

    assert len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list), '数据长度不匹配'
    assert len(set(dataset.seed_list)) == len(dataset.seed_list), '存在使用相同随机种子的数据生成进程'
    # One row per problem; the trailing dimensions depend on how each problem
    # stores 'item_volumes' / 'capacity_left' (not visible from this file).
    all_volumes = np.array([problem['item_volumes'] for problem in dataset.problem_list])
    all_capacities = np.array([problem['capacity_left'] for problem in dataset.problem_list])
    assert np.all(all_capacities == MAX_CAPACITY[num_items]), '存在总容量有误的问题'
    assert all_volumes.min() > 0, '存在0体积物品'

    # Print dataset information. Splits whose file was missing are reported
    # instead of raising KeyError.
    print('-'*80)
    print(f'There are [{len(dataset.answer_list)}] raw data generated with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
    for data_type in ('train', 'problem', 'prompt'):
        # Pad to 18 columns so the output matches the original aligned layout.
        header = f'{data_type.capitalize()} dataset:'
        sub_dataset = datasets.get(data_type)
        if sub_dataset is None:
            print(f'\t{header:<18}[not found]')
            continue
        print(f'\t{header:<18}[{len(sub_dataset.answer_list)}] raw data with [{len(sub_dataset.seed_list)}] different random seeds, ave cost = [{np.mean(sub_dataset.cost_list):.3f}]')
    print('-'*80)

    # Re-solve up to check_data_num randomly chosen instances and compare the
    # result with the stored cost / answer.
    if check_data_num > 0:
        dataset_size = len(dataset.answer_list)
        check_idx = random.sample(range(dataset_size), min(dataset_size, check_data_num))
        for idx in tqdm(check_idx, desc="Check data validity"):
            problem = dataset.problem_list[idx]
            data_answer = dataset.answer_list[idx]
            data_cost = dataset.cost_list[idx]
            # .item() converts the stored capacity to a plain Python scalar.
            cost, answer = knapsack_dp(
                problem['capacity_left'].item(),
                problem['item_volumes'],
                problem['item_values'],
            )
            assert abs(cost - data_cost) < 1e-5, \
                f'cost mismatch at index {idx}: solver={cost}, stored={data_cost}'
            assert answer == data_answer, f'answer mismatch at index {idx}'


if __name__ == "__main__":
    num_items = 20                      # number of items per BP instance
    # check_data has no return value; it prints a summary and raises if a check
    # fails, so its result is not bound to a name.
    check_data(num_items, load_data_num=None)
    


    

    