import os
import sys
# Make the directory four levels above this script's directory importable
# (presumably the repository root) so the project-local `utils` and `data`
# packages imported below can be resolved.
_script_dir = os.path.dirname(__file__)
base_path = os.path.abspath(os.path.join(_script_dir, '..', '..', '..', '..'))
sys.path.append(base_path)

import numpy as np
from utils.COP_slover import PCTSP_ILS
from data.used._raw.pctsp.gen_pctsp import MAX_LENGTHS
import time
from tqdm import tqdm

if __name__ == "__main__":
    # 设置参数
    node_num = 20                       # TSP城市数量
    data_distribution = 'uniform'       # TSP城市分布类型
    data_num = 100

    process_data = []
    for i in range(data_num):
        penalty_factor = 3
        rng = np.random.RandomState(42)
        pos_depot = rng.uniform(0, 1, size=(2, ))
        pos_node = rng.uniform(0, 1, size=(node_num, 2))
        penalty_max = MAX_LENGTHS[node_num] * (penalty_factor) / float(node_num)
        penalty = np.random.uniform(size=(node_num, )) * penalty_max
        prize = np.random.uniform(size=(node_num,)) * 4 / float(node_num)
        while prize.sum() < 1:
            prize = np.random.uniform(size=(node_num,)) * 4 / float(node_num)
        process_data.append((pos_depot.tolist(), pos_node.tolist(), penalty.tolist(), prize.tolist()))

    start = time.time()
    with tqdm(total=data_num, desc=f'PCTSP Speed Testing') as pbar:
        for pos_depot, pos_node, penalty, prize in process_data:
            cost, answer, _ = PCTSP_ILS(pos_depot, pos_node, penalty, prize,)
            pbar.update()
    duration = time.time() - start
    print(duration/data_num)