import os
import time
import numpy as np
import pandas as pd
from datetime import datetime
import torch

from RHOSolver import RHOSolver
from TRHOSolver import TRHOSolver
from network.JobShopSetTransformer import JobShopSetTransformer
from params import configs
import copy


def load_la_instance(instance_path):
    """Read an LA benchmark ``.npy`` file and return the first entry it holds.

    Args:
        instance_path: Path to the ``.npy`` file.

    Returns:
        Element 0 (along the first axis) of the loaded array. Callers in this
        file unpack it as ``(time_mat, machine_mat)`` with a 2-D ``time_mat``;
        presumably the file stores a batch of instances -- TODO confirm
        against the generator that wrote it.
    """
    # allow_pickle=True permits object arrays but also runs the unpickler;
    # only point this at trusted, locally generated benchmark files.
    stacked = np.load(instance_path, allow_pickle=True)
    first_instance = stacked[0]
    return first_instance


def build_rho_solver_for_instance(machines: int, jobs: int) -> RHOSolver:
    """Construct an RHO solver driven by a JobShopSetTransformer policy network.

    The checkpoint name is derived from the config (kept consistent with the
    existing scripts):
    ``model/jsst{gen_instance_num}_{machines}_{machines}_{gen_time_limit}_AU{augment_times}.pth``.
    If that file is missing, a warning is printed and the randomly initialised
    network is used as-is.

    Args:
        machines: Number of machines in the instance. It is also used as the
            RHO window size ``k`` and fills both size slots of the checkpoint
            name.
        jobs: Number of jobs; currently unused by this function.

    Returns:
        An ``RHOSolver`` with ``numofMachines=machines``, ``k=machines`` and
        the network attached.
    """
    size_parts = [
        str(configs.gen_instance_num),
        str(machines),
        str(machines),
        str(configs.gen_time_limit),
    ]
    checkpoint_name = 'jsst' + '_'.join(size_parts) + '_AU' + str(configs.augment_times)
    weights_file = os.path.join('model', checkpoint_name + '.pth')

    policy_net = JobShopSetTransformer(m=machines)
    if not os.path.exists(weights_file):
        print(f"[警告] 未找到模型权重 {weights_file}，将使用随机初始化模型进行RHO测试。")
    else:
        policy_net.load_state_dict(torch.load(weights_file, map_location='cpu'))

    return RHOSolver(numofMachines=machines, k=machines, net=policy_net)


def eval_instance(instance_name: str, time_limit: int):
    """Run RHO (network-guided) and TRHO on one LA benchmark instance.

    Args:
        instance_name: File stem under ``benchmark/la`` (e.g. ``'la01'``).
        time_limit: Overall time budget. RHO receives ``time_limit / n_jobs``;
            TRHO receives ``time_limit / 4`` with ``global_time_limit=time_limit``.

    Returns:
        ``None`` if the instance file does not exist. Otherwise a dict with keys
        ``instance``, ``jobs``, ``machines``, ``rho_makespan``, ``rho_time_s``,
        ``trho_makespan`` and ``trho_time_s``. Times are wall-clock seconds and
        include the ``fixRecord`` call on the returned schedule.
    """
    instance_path = os.path.join('benchmark', 'la', f'{instance_name}.npy')
    if not os.path.exists(instance_path):
        return None

    data = load_la_instance(instance_path)
    time_mat, _machine_mat = data
    n_jobs, n_machines = time_mat.shape

    # RHO (net). Each solver gets its own deep copy, presumably so that any
    # in-place changes a solver makes cannot leak into `data` (used by
    # fixRecord below) or into the other solver's run.
    data0 = copy.deepcopy(data)
    rho_solver = build_rho_solver_for_instance(n_machines, n_jobs)
    rho_solver.reset(data0)
    t0 = time.perf_counter()
    rho_schedule = rho_solver.solve(
        data0,
        model='net',
        bws=False,
        initChoose='GRGC',
        time_limit=time_limit / n_jobs,
        initsolver=False,
        detail=False,
        returnSchedule=True,
    )
    rho_schedule.fixRecord(data)
    rho_time = time.perf_counter() - t0
    rho_makespan = rho_schedule.cal_makespan()

    # TRHO. This always runs; configs.skip_trho is not consulted here.
    # The solver is reset AND solved on its private copy, mirroring the RHO
    # path above (previously it was solved on the original `data`).
    data2 = copy.deepcopy(data)
    trho_solver = TRHOSolver(numofMachines=n_machines)
    trho_solver.reset(data2)
    t1 = time.perf_counter()
    trho_schedule = trho_solver.solve(
        data2,
        bws=False,
        time_limit=time_limit / 4,
        global_time_limit=time_limit,
    )
    trho_schedule.fixRecord(data)
    trho_time = time.perf_counter() - t1
    trho_makespan = trho_schedule.cal_makespan()

    return {
        'instance': instance_name,
        'jobs': n_jobs,
        'machines': n_machines,
        'rho_makespan': rho_makespan,
        'rho_time_s': rho_time,
        'trho_makespan': trho_makespan,
        'trho_time_s': trho_time,
    }


def _format_run(label, makespan, seconds):
    """Render one solver's result for the per-instance progress line.

    Returns ``'<label> 未运行'`` when either value is ``None`` (e.g. a solver
    that was not run), so the ``:.3f`` time format never receives ``None``.
    """
    if makespan is None or seconds is None:
        return f'{label} 未运行'
    return f'{label} makespan={makespan} time={seconds:.3f}s'


def main():
    """Evaluate RHO and TRHO on LA01-LA40 and write a CSV to ``testlog/``.

    Missing instance files are skipped. An exception raised while evaluating an
    instance is reported with its traceback, and the sweep continues with the
    next instance. After the sweep, mean makespan and time are printed for RHO,
    and also for TRHO unless ``configs.skip_trho`` is set.
    """
    import traceback

    instances = [f'la{i:02d}' for i in range(1, 41)]
    results = []

    for name in instances:
        print(f'评测 {name} ...')
        # Only the evaluation itself is guarded. Reporting happens outside the
        # try so a formatting problem cannot be misreported as a solver failure.
        try:
            res = eval_instance(name, time_limit=configs.run_time_limit)
        except Exception as e:
            # Top-level boundary: report and keep going so one bad instance
            # does not abort the whole sweep.
            print(f'  评测 {name} 出错: {e}')
            traceback.print_exc()
            continue
        if res is None:
            print(f'跳过 {name}: 文件不存在')
            continue
        results.append(res)
        print(
            '  '
            + _format_run('RHO', res['rho_makespan'], res['rho_time_s'])
            + ', '
            + _format_run('TRHO', res['trho_makespan'], res['trho_time_s'])
        )

    if not results:
        print('没有结果可保存。')
        return

    df = pd.DataFrame(results)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    os.makedirs('testlog', exist_ok=True)
    out_csv = os.path.join('testlog', f'la_rho_trho_{timestamp}.csv')
    df.to_csv(out_csv, index=False)

    # Summary
    print('\n评测完成:')
    print(f'  结果CSV: {out_csv}')
    print(
        '  RHO 平均makespan: '
        + f"{df['rho_makespan'].mean():.2f} | 平均时间: {df['rho_time_s'].mean():.3f}s"
    )
    if not configs.skip_trho:
        print(
            '  TRHO 平均makespan: '
            + f"{df['trho_makespan'].mean():.2f} | 平均时间: {df['trho_time_s'].mean():.3f}s"
        )


# Script entry point: run the LA benchmark sweep only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()


