import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import random
import numpy as np
from utils.COP_slover import TSP_lkh
from environment.used.Env_tsp_v1 import TSP_V1
from gym.utils.env_checker import check_env
from tqdm import tqdm
import matplotlib.pyplot as plt

def calc_distance(position, answer):
    """Return the length of the closed tour that visits ``position`` in ``answer`` order.

    Args:
        position: array of node coordinates indexed by node id, one point per
            row (this script passes shape ``(num_nodes, 2)``).
        answer: sequence of integer node indices giving the visiting order.

    Returns:
        Sum of Euclidean distances between consecutive nodes, including the
        closing edge from the last visited node back to the first one
        (``answer[0]``). An empty ``answer`` yields ``0.0``.
    """
    tour = np.asarray(answer)
    if tour.size == 0:
        return 0.0
    points = np.asarray(position)[tour]
    # Rolling by -1 pairs every node with its successor; the last node wraps
    # around to the tour's first node, so tours need not start at node 0.
    return np.linalg.norm(np.roll(points, -1, axis=0) - points, axis=-1).sum()

# Measure how the quality of a random tour compares with the LKH tour as the
# number of cities grows. For every node count, `iters` instances are drawn
# from the environment, solved with LKH, and compared with a random tour.
iters = 200
qulities = {num_node: [] for num_node in range(10, 21)}
for num in qulities:
    with tqdm(total=iters, desc=f'Test qulity with {num} node') as pbar:
        env = TSP_V1(num_nodes=num)
        try:
            check_env(env.unwrapped)  # verify the environment follows the gym API

            for _ in range(iters):
                # Draw a random TSP instance and solve it with LKH.
                observation, _reset_info = env.reset()
                position = observation['position'].reshape((num, 2))
                distance, real_answer = TSP_lkh(position)
                # np.array_equal works whether the tours are lists or arrays;
                # `==` on NumPy arrays yields an array and makes `assert` raise.
                assert np.array_equal(real_answer, env.real_answer)

                # Random baseline tour: node 0 first, remaining nodes shuffled.
                model_answer = list(range(1, num))
                random.shuffle(model_answer)
                model_answer = [0] + model_answer

                # Compare the random tour with the LKH tour.
                real_distance = calc_distance(position, real_answer)
                model_distance = calc_distance(position, model_answer)

                # Cross-check LKH's reported length against the recomputed one.
                assert abs(distance - real_distance) < 1e-4
                # 1 means the random tour is as short as the LKH tour; lower is
                # worse (negative once the random tour is more than 2x longer).
                qulity = 1 - (model_distance - real_distance) / real_distance
                qulities[num].append(qulity)

                info = {'qulity': np.mean(qulities[num])}
                pbar.set_postfix(info)
                pbar.update()
        finally:
            # Release environment resources even if a sanity check fails.
            env.close()

# Show (not save) a bar chart of the mean random-vs-LKH quality per node count.
nodes = list(qulities.keys())
qulities = {num: np.mean(q) for num, q in qulities.items()}
means = [qulities[num] for num in nodes]

fig = plt.figure(figsize=(6, 4))
a1 = fig.add_subplot(1, 1, 1, label='a1')

# No `capsize`: it only styles error bars, and no `yerr` is supplied.
a1.bar(nodes, means)
a1.set_xlabel('node num')
a1.set_ylabel('qulity')
a1.set_title('random vs lkh')
# Annotate each bar with its mean quality, rounded to two decimals.
for index, q in zip(nodes, means):
    a1.text(index, q, str(round(q, 2)), ha='center', va='bottom')
plt.show()
