from RHOSolver import RHOSolver
from CPSolver import CPSolver
from uniform_instance_gen import uni_instance_gen
from network.JobShopSetTransformer import JobShopSetTransformer
import numpy as np
from params import *
from TRHOSolver import TRHOSolver
import torch
import os
import time
import argparse
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

if __name__ == '__main__':
    # 解析命令行参数
    
    num_instances = configs.instances  # 生成指定数量的随机实例
    skip_trho = configs.skip_trho
    
    all_net_results = []  # 纯网络推理的结果
    all_greedy_results = []
    all_random_results = []  # 随机选择后的结果
    all_cp_results = []  # CP 求解结果
    m = configs.gen_machine_num
    model = JobShopSetTransformer(m = m)
    modelname = 'jsst'+str(configs.gen_instance_num) + '_' + str(configs.gen_machine_num) + '_' + str(configs.gen_machine_num) + '_' + str(configs.gen_time_limit) +'_AU'+str(configs.augment_times)
    model.load_state_dict(torch.load('model/'+modelname+'.pth'))
    solver = RHOSolver(numofMachines=m, k=m,net=model)
    
    # 只有在不跳过TRHO测试时才初始化TRHOSolver
    if not skip_trho:
        tsolver = TRHOSolver(numofMachines=m)
    
    # 结果统计列表已在上方初始化
    
    print(f"开始测试 {num_instances} 个实例...")
    if skip_trho:
        print("已启用跳过TRHO测试选项")
    for i in range(num_instances):
        print(f"\n开始测试实例 {i+1}/{num_instances}...")
        BWS = False
        # 生成15个工件5个机器的随机实例,加工时间在1-99之间
        np.random.seed(i)
        N = configs.test_GenData_N
        times, machines = uni_instance_gen(N, m, 1, 100)
        data = (times, machines)

        if True:
            # 测试神经网络模型
            print(f"  测试神经网络模型...")
            solver.reset(data)
            start_time = time.time()
            net_result = solver.solve(data, model='net', initsolver=False, time_limit=12, returnSchedule=False, initChoose='random',bws=BWS)
            net_time = time.time() - start_time
            print(f"  神经网络完成，耗时: {net_time:.4f}秒")

            # 测试TRHOSolver (可选)
            if not skip_trho:
                print(f"  测试TRHOSolver...")
                tsolver.reset(data)
                start_time = time.time()
                t_result = tsolver.solve(data, PTmask=None, time_limit=20*N,bws=BWS).cal_makespan()  # 减少time_limit
                t_time = time.time() - start_time
                print(f"  TRHOSolver完成，耗时: {t_time:.4f}秒")
            else:
                t_result = None  # 跳过TRHO测试时设为None

            # 测试随机
            print(f"  测试随机选择...")
            solver.reset(data)
            solver.randomChoose(k=10)
            start_time = time.time()
            random_result = solver.solve(data, model='random', initsolver=False, time_limit=12, returnSchedule=False, initChoose='random',bws=BWS)
            rand_time = time.time() - start_time
            print(f"  随机选择完成，耗时: {rand_time:.4f}秒")

            # 测试 CP 求解器（时间上限固定为不超过 1 小时）
            print(f"  测试CP求解器(≤1h)...")
            cp_solver = CPSolver()
            start_time = time.time()
            cp_schedule = cp_solver.solve_blocking_job_shop(data, mask=None, random_seed=2025, time_limit=3600, bws=BWS)
            cp_time = time.time() - start_time
            cp_result = cp_schedule.cal_makespan()
            print(f"  CP求解器完成，耗时: {cp_time:.4f}秒")

            all_net_results.append(net_result)
            if not skip_trho:
                all_greedy_results.append(t_result)
            all_random_results.append(random_result)
            all_cp_results.append(cp_result)
            
            print(f"\n实例 {i+1} 的结果:")
            print(f"  神经网络模型结果: {net_result}")
            if not skip_trho:
                print(f"  Tsolver算法结果: {t_result}")
            print(f"  随机选择结果: {random_result}")
            print(f"  CP求解器结果: {cp_result}")
            
        
    print("\n所有实例统计:")
    print("神经网络模型: 平均=", float(np.mean(all_net_results)), 
          " 最小=", int(np.min(all_net_results)), 
          " 最大=", int(np.max(all_net_results)), 
          " 标准差=", float(np.std(all_net_results)))
    if not skip_trho and all_greedy_results:
        print("Tsolver算法: 平均=", float(np.mean(all_greedy_results)), 
              " 最小=", int(np.min(all_greedy_results)), 
              " 最大=", int(np.max(all_greedy_results)), 
              " 标准差=", float(np.std(all_greedy_results)))
    print("随机选择: 平均=", float(np.mean(all_random_results)), 
          " 最小=", int(np.min(all_random_results)), 
          " 最大=", int(np.max(all_random_results)), 
          " 标准差=", float(np.std(all_random_results)))
    if all_cp_results:
        print("CP求解器: 平均=", float(np.mean(all_cp_results)), 
              " 最小=", int(np.min(all_cp_results)), 
              " 最大=", int(np.max(all_cp_results)), 
              " 标准差=", float(np.std(all_cp_results)))

    # 相对改进统计（随机 -> 网络，负数表示更好）
    if len(all_net_results) == len(all_random_results) and len(all_net_results) > 0:
        improvements = np.array(all_net_results) - np.array(all_random_results)
        print("网络相对随机的改进(负数=更优): 平均=", float(np.mean(improvements)),
              " 最小=", float(np.min(improvements)),
              " 最大=", float(np.max(improvements)))
        # 占比统计：网络更好（更小）
        net_arr = np.array(all_net_results)
        rand_arr = np.array(all_random_results)
        pct_net_better_rand = float(np.mean(net_arr < rand_arr) * 100.0)
        pct_net_equal_rand = float(np.mean(net_arr == rand_arr) * 100.0)
        print(f"网络优于随机的占比: {pct_net_better_rand:.2f}%  (相等: {pct_net_equal_rand:.2f}%)")

        # 若存在Tsolver结果，统计网络优于Tsolver的占比，以及网络为最优方案的占比
        if not skip_trho and len(all_greedy_results) == len(all_net_results):
            t_arr = np.array(all_greedy_results)
            pct_net_better_t = float(np.mean(net_arr < t_arr) * 100.0)
            pct_net_equal_t = float(np.mean(net_arr == t_arr) * 100.0)
            print(f"网络优于Tsolver的占比: {pct_net_better_t:.2f}%  (相等: {pct_net_equal_t:.2f}%)")

            # 网络在三者中严格最优的占比（严格小于随机和Tsolver）
            pct_net_best_strict = float(np.mean((net_arr < rand_arr) & (net_arr < t_arr)) * 100.0)
            print(f"网络在(网络/随机/Tsolver)中严格最优的占比: {pct_net_best_strict:.2f}%")

        # 若存在CP结果，统计网络优于CP的占比，以及网络/随机/CP的最优占比
        if len(all_cp_results) == len(all_net_results) and len(all_cp_results) > 0:
            cp_arr = np.array(all_cp_results)
            pct_net_better_cp = float(np.mean(net_arr < cp_arr) * 100.0)
            pct_net_equal_cp = float(np.mean(net_arr == cp_arr) * 100.0)
            print(f"网络优于CP的占比: {pct_net_better_cp:.2f}%  (相等: {pct_net_equal_cp:.2f}%)")
            pct_net_best_vs_cp_rand = float(np.mean((net_arr < rand_arr) & (net_arr < cp_arr)) * 100.0)
            print(f"网络在(网络/随机/CP)中严格最优的占比: {pct_net_best_vs_cp_rand:.2f}%")
