import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import random
import numpy as np
from utils.COP_slover import knapsack_dc
from environment.used.Env_bp_v1 import BinBackpack
from gym.utils.env_checker import check_env
from tqdm import tqdm
import matplotlib.pyplot as plt

def greedy_knapsack(capacity, item_volumes, item_values):
    """Greedy heuristic for the 0/1 knapsack problem.

    Visits items in descending order of value and takes every item whose
    volume still fits in the remaining capacity. Items with zero volume are
    skipped (presumably empty/padding slots in the observation -- confirm
    against the environment).

    Args:
        capacity: Remaining knapsack capacity (a plain Python number).
        item_volumes: Per-item volumes; any array-like, flattened to 1-D.
        item_values: Per-item values; same length as ``item_volumes`` after
            flattening.

    Returns:
        ``(total_value, selection)``: the summed value of the chosen items and
        the list of chosen item indices, in the order they were picked.
    """
    # Flatten so argsort runs over items even for (n, 1)-shaped inputs or lists.
    item_volumes = np.asarray(item_volumes).ravel()
    item_values = np.asarray(item_values).ravel()

    selection = []
    get_value = 0
    # Highest-value items first.
    for idx in np.argsort(-item_values):
        value = item_values[idx].item()
        volume = item_volumes[idx].item()

        # `<=`: an item that exactly fills the remaining space still fits.
        if volume != 0 and volume <= capacity:
            selection.append(idx)
            capacity -= volume
            get_value += value

    return get_value, selection

def _sample_nonempty_instance(env):
    """Reset *env* until the exact solver selects at least one item.

    Each ``env.reset()`` yields a new instance; instances whose optimal
    selection from ``knapsack_dc`` is empty are discarded and resampled.

    Returns:
        ``(volumes, values, capacity, dc_max_value, dc_selection)`` for the
        first accepted instance, read from the observation keys
        ``item_volumes``, ``item_values`` and ``capacity_left``.
    """
    while True:
        observation, _info = env.reset()
        volumes = observation['item_volumes']
        values = observation['item_values']
        capacity = observation['capacity_left']
        dc_max_value, dc_selection = knapsack_dc(capacity, volumes, values)
        # len() works whether the solver returns a list or an ndarray,
        # unlike comparing against [].
        if len(dc_selection) > 0:
            return volumes, values, capacity, dc_max_value, dc_selection


# Study how the quality of the simple greedy heuristic, relative to the exact
# optimum from knapsack_dc, changes as the number of items grows.
iters = 200
qulities = {num_node: [] for num_node in range(10, 11)}
for num in qulities.keys():
    with tqdm(total=iters, desc=f'Test qulity with {num} node') as pbar:
        env = BinBackpack(max_quantity=num)
        check_env(env.unwrapped)  # verify the environment conforms to the gym API

        for _ in range(iters):
            # Randomly generate a 01BP instance and solve it exactly, making sure
            # the optimal solution selects at least one item.
            volumes, values, capacity, dc_max_value, dc_selection = _sample_nonempty_instance(env)

            # Validate the exact solution on every instance: it must match the
            # environment's reference answer (compared as index lists, order-
            # insensitive; env.real_answer presumably holds the env's own optimal
            # selection -- confirm) and must respect the capacity.
            assert sorted(dc_selection) == sorted(env.real_answer)
            assert volumes[dc_selection].sum() <= capacity.item()

            # Heuristic solution: always take the highest-value item that still fits.
            greedy_max_value, greedy_selection = greedy_knapsack(capacity.item(), volumes, values)

            # Quality of the heuristic relative to the exact optimum; the greedy
            # value must never exceed the optimum (up to float tolerance).
            assert greedy_max_value < dc_max_value or abs(greedy_max_value - dc_max_value) < 1e-4
            if dc_max_value == 0:
                # No positive optimum to normalise by; score 1 instead of dividing by zero.
                qulity = 1
            else:
                qulity = 1 - abs(dc_max_value - greedy_max_value) / dc_max_value
                qulity = 1 if abs(qulity - 1) < 1e-4 else qulity
            qulities[num].append(qulity)

            pbar.set_postfix({'qulity': np.mean(qulities[num])})
            pbar.update()

# Show (not save) a bar chart of the mean heuristic quality for each item count.
nodes = list(qulities)
qulities = dict(zip(nodes, map(np.mean, qulities.values())))

fig = plt.figure(figsize=(6, 4))
a1 = fig.add_subplot(1, 1, 1, label='a1')

bar_x = list(qulities.keys())
bar_heights = list(qulities.values())
a1.bar(bar_x, bar_heights, capsize=5)
a1.set_xlabel('node num')
a1.set_ylabel('qulity')
a1.set_title('greedy vs dc')
# Annotate each bar with its mean quality, rounded to two decimals.
for x_pos, height in zip(bar_x, bar_heights):
    a1.text(x_pos, height, str(round(height, 2)), ha='center', va='bottom')
plt.show()
