import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import numpy as np
from utils.COP_slover import OP_gurobi
from data.used._raw.op.gen_op import MAX_LENGTHS
import time
from tqdm import tqdm

def _make_prize(rng, distance_depot, prize_type):
    """Return the per-node prizes for one OP instance.

    Args:
        rng: ``np.random.RandomState`` used for the random ``'unif'`` prizes,
            so every prize type is reproducible from the script's seed.
        distance_depot: 1-D array of node-to-depot Euclidean distances.
        prize_type: ``'const'`` (every prize is 1), ``'unif'`` (random
            multiples of 0.01 in [0.01, 1.00]) or ``'dist'`` (prize grows
            with the distance to the depot, also in [0.01, 1.00]).

    Returns:
        1-D NumPy array with one prize per node.

    Raises:
        ValueError: if ``prize_type`` is not one of the three values above.
    """
    node_num = len(distance_depot)
    if prize_type == 'const':
        return np.ones(node_num, dtype=np.float32)
    if prize_type == 'unif':
        # Draw from the seeded ``rng`` (not the global ``np.random``) so the
        # benchmark data is reproducible for this prize type too.
        return (1 + rng.randint(0, 100, size=(node_num,))) / 100.
    if prize_type == 'dist':
        # Scale distances to [0, 99], truncate, shift to [1, 100], then map
        # to [0.01, 1.00]: the farthest node from the depot gets prize 1.
        return (1 + (distance_depot / distance_depot.max() * 99).astype(int)) / 100.
    raise ValueError(
        f"unknown prize_type {prize_type!r}; expected 'const', 'unif' or 'dist'"
    )


def _generate_instance(rng, node_num, prize_type):
    """Sample one OP instance with depot and nodes uniform in the unit square.

    Args:
        rng: ``np.random.RandomState`` supplying all randomness.
        node_num: number of nodes, excluding the depot.
        prize_type: prize scheme, forwarded to ``_make_prize``.

    Returns:
        ``(pos_depot, pos_node, prize)`` as NumPy arrays of shapes ``(2,)``,
        ``(node_num, 2)`` and ``(node_num,)``.
    """
    pos_depot = rng.uniform(0, 1, size=(2,))
    pos_node = rng.uniform(0, 1, size=(node_num, 2))
    # Only depot distances feed the prizes, so no node-to-node matrix is built.
    # Stored as float32, matching the precision the prize formula sees.
    distance_depot = np.array(
        [np.linalg.norm(node - pos_depot) for node in pos_node], dtype=np.float32
    )
    prize = _make_prize(rng, distance_depot, prize_type)
    return pos_depot, pos_node, prize


if __name__ == "__main__":
    # Benchmark parameters
    node_num = 20          # number of OP nodes (the depot is extra)
    data_num = 100         # number of random instances to solve
    prize_type = 'dist'    # 'const', 'unif' or 'dist' (see _make_prize)

    rng = np.random.RandomState(42)
    # Length budget passed to the solver; the same for every instance of this size.
    length_left = MAX_LENGTHS[node_num]

    # Build every instance up front so the timed section covers only solver calls.
    process_data = []
    for _ in range(data_num):
        pos_depot, pos_node, prize = _generate_instance(rng, node_num, prize_type)
        process_data.append(
            (pos_depot.tolist(), pos_node.tolist(), prize.tolist(), length_left)
        )

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for pos_depot, pos_node, prize, length_left in tqdm(process_data, desc='OP Speed Testing'):
        OP_gurobi(pos_depot, pos_node, prize, length_left)
    duration = time.perf_counter() - start
    # Average solve time per instance, in seconds.
    print(duration / data_num)