from RHOSolver import RHOSolver
import numpy as np
import os
import pandas as pd
from datetime import datetime
from network.JobShopSetTransformer import JobShopSetTransformer
import torch
from params import configs
import copy
import schedule
import time
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

def test_instance(solver, instance_path, num_runs=10):
    """Solve one benchmark instance with the network policy and average the runs.

    Args:
        solver: Object exposing ``reset(data)`` and
            ``solve(data, model='net')``.
        instance_path: Path to a ``.npy`` file. Element 0 of the loaded array
            is used as the instance data.
        num_runs: Number of solve calls to average over. Must be at least 1.

    Returns:
        A ``(mean_result, mean_seconds)`` tuple: the mean of the values
        returned by ``solver.solve`` and the mean wall-clock time of those
        calls (``solver.reset`` is not timed).

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # benchmark files that come from a trusted source.
    instance = np.load(instance_path, allow_pickle=True)[0]

    net_results = []
    net_times = []

    # Network results: one fresh reset + solve per run.
    for _ in range(num_runs):
        # Work on a private copy so a run cannot be affected by whatever the
        # previous run did to its input (presumably the solver may modify the
        # data in place - the script section of this file also keeps a
        # deepcopy around before solving).
        run_data = copy.deepcopy(instance)
        solver.reset(run_data)

        started = time.perf_counter()
        result = solver.solve(run_data, model='net')
        net_times.append(time.perf_counter() - started)
        net_results.append(result)

    return np.mean(net_results), np.mean(net_times)


# This is a benchmark helper, not a unit test: stop pytest from collecting it
# (its ``test_`` prefix would otherwise make pytest call it and fail on the
# missing ``solver`` / ``instance_path`` fixtures).
test_instance.__test__ = False

def main():
    """Benchmark the network solver on instances la01..la15 and save a CSV.

    Builds a ``JobShopSetTransformer``-backed ``RHOSolver``, loads the weights
    named after the current ``configs`` values, runs ``test_instance`` on each
    ``benchmark/la/laXX.npy`` file and writes one row per successfully solved
    instance to ``testlog/benchmark_results_<timestamp>.csv``.

    Instances that raise are skipped (best effort), but the failure is printed
    so that missing rows in the CSV can be explained.
    """
    # Initialise the solver.
    model = JobShopSetTransformer()
    solver = RHOSolver(numofMachines=5, k=7, net=model)
    # NOTE(review): gen_machine_num appears twice in the checkpoint name; the
    # script section of this file uses the same "<m>_<m>" pattern, so this is
    # kept as is - confirm against the training script's naming.
    modelname = (
        f'jsst{configs.gen_instance_num}'
        f'_{configs.gen_machine_num}_{configs.gen_machine_num}'
        f'_{configs.gen_time_limit}_AU{configs.augment_times}'
    )
    solver.loadNet('model/' + modelname + '.pth')

    # Instances to test.
    instances = [f'la{i:02d}' for i in range(1, 16)]

    # Collected result rows.
    results = []

    # Test every instance.
    for instance in instances:
        instance_path = f'benchmark/la/{instance}.npy'
        try:
            net_result, net_time = test_instance(solver, instance_path)
        except Exception as e:
            # Keep the benchmark going, but do not hide the failure.
            print(f'[skip] {instance}: {type(e).__name__}: {e}')
            continue

        results.append({
            'Instance': instance,
            'Makespan': net_result,
            'Time': net_time
        })

    # to_csv does not create missing parent directories.
    os.makedirs('testlog', exist_ok=True)
    df = pd.DataFrame(results)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f'testlog/benchmark_results_{timestamp}.csv'
    df.to_csv(csv_filename, index=False)

if __name__ == '__main__':
    # Script entry point: run the network solver on the large JSPLIB "ta"
    # instances (at least 50 jobs) and append one log line per instance with
    # the resulting makespan and solve time.
    import json

    path = 'benchmark/JSPLIB/'
    # Read the JSPLIB JSON instance metadata.
    with open(path + 'instances.json', 'r', encoding='utf-8') as f:
        inst = json.load(f)

    ta_instances = [x for x in inst if x['name'].startswith('ta')]
    for instance in ta_instances:
        with open(path + instance['path'], 'r') as f:
            lines = f.read().strip().split('\n')
        # First line holds "<jobs> <machines>".
        n, m = map(int, lines[0].split())

        # Only instances with at least 50 jobs are evaluated; decide that
        # before spending time on parsing the operation table.
        if n < 50:
            continue

        # Each of the next n lines holds m (machine, duration) pairs; any
        # trailing extra tokens on a line are ignored.
        rows = np.array(
            [list(map(int, lines[i + 1].split()))[:2 * m] for i in range(n)],
            dtype=int,
        )
        # Machine ids are shifted by one (presumably file ids are 0-based and
        # the solver expects 1-based - confirm against RHOSolver).
        machine_mat = rows[:, 0::2] + 1
        time_mat = rows[:, 1::2].copy()

        data = [time_mat, machine_mat]

        model = JobShopSetTransformer(m=m)
        modelname = 'jsst'+str(configs.gen_instance_num) + '_' + str(m)+ '_' + str(m) + '_' + str(configs.gen_time_limit) +'_AU'+str(configs.augment_times)
        model.load_state_dict(torch.load('model/' + modelname + '.pth'))
        # Pristine copy of the instance, handed to fixRecord after solving.
        datacopy = copy.deepcopy(data)
        solver = RHOSolver(numofMachines=m, k=1, net=model)
        solver.reset(data)

        solve_start_time = time.time()
        # Named net_schedule (not "schedule") so the module imported at the
        # top of the file via "import schedule" is not rebound.
        net_schedule = solver.solve(data, model='net', bws=False, initChoose='random',time_limit=1,initsolver=False,detail=True,returnSchedule=True)
        solve_end_time = time.time()
        solve_time = solve_end_time - solve_start_time

        net_schedule.fixRecord(datacopy)
        makespan_net = net_schedule.cal_makespan()

        log_path = 'log/benchmark_results_'+configs.result_path+'.log'
        solver.reset(data)
        # Append mode creates the file when needed; only the directory has to
        # exist beforehand.
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Explicit UTF-8: the message contains non-ASCII characters and must
        # not depend on the platform's default encoding.
        with open(log_path, 'a', encoding='utf-8') as f:
            f.write(f"{instance['name']} makespan: {makespan_net:.2f}, 运行时间: {solve_time:.2f}秒\n")
        

