import os
import time
import json
import copy
import numpy as np
import pandas as pd
from datetime import datetime
import torch

from RHOSolver import RHOSolver
from TRHOSolver import TRHOSolver
from network.JobShopSetTransformer import JobShopSetTransformer
from params import configs


def parse_jsplib_instance(file_path):
    """Parse a JSPLIB-format instance file (e.g. Taillard ``ta*``) into ``[time_mat, machine_mat]``.

    File layout (follows test_benchmark):

    - Header line: ``n m`` (number of jobs, number of machines).
    - Next ``n`` data lines: at least ``2*m`` integers each, read as
      ``(machine, time)`` pairs in operation order; extra tokens are ignored.
    - Blank lines and lines starting with ``#`` are skipped. JSPLIB files
      ship with a ``#`` comment header.
    - Machine indices are 0-based in the file. They are stored ``+1`` to stay
      consistent with test_benchmark.

    Args:
        file_path: Path to the instance file.

    Returns:
        ``[time_mat, machine_mat]``: two ``(n, m)`` integer numpy arrays.

    Raises:
        ValueError: if the file has no data lines, has fewer than ``n`` job
            lines, or a job line has fewer than ``2*m`` integers.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        raw_lines = [ln.strip() for ln in f]
    # Drop blank and '#' comment lines so files with a JSPLIB header parse too.
    # Files without comments produce the same data lines as before.
    lines = [ln for ln in raw_lines if ln and not ln.startswith('#')]
    if not lines:
        raise ValueError(f'empty JSPLIB instance file: {file_path}')

    n, m = map(int, lines[0].split())
    job_lines = lines[1:n + 1]
    if len(job_lines) < n:
        raise ValueError(
            f'{file_path}: expected {n} job lines, found {len(job_lines)}'
        )

    machine_mat = np.zeros((n, m), dtype=int)
    time_mat = np.zeros((n, m), dtype=int)

    for i, line in enumerate(job_lines):
        tokens = [int(tok) for tok in line.split()]
        if len(tokens) < 2 * m:
            raise ValueError(
                f'{file_path}: job {i} has {len(tokens)} integers, expected {2 * m}'
            )
        pairs = np.asarray(tokens[:2 * m], dtype=int)
        machine_mat[i] = pairs[0::2] + 1  # align with test_benchmark: machine id +1
        time_mat[i] = pairs[1::2]

    return [time_mat, machine_mat]


def build_rho_solver_for_instance(machines: int, jobs: int) -> RHOSolver:
    """Create an ``RHOSolver`` backed by a ``JobShopSetTransformer`` for one instance size.

    Weights are loaded (on CPU) from ``model/jsst<gen_instance_num>_<machines>_<machines>_
    <gen_time_limit>_AU<augment_times>.pth`` when that file exists. Otherwise a warning is
    printed and the randomly initialised network is used.

    Args:
        machines: Number of machines; sizes the network and the solver.
        jobs: Number of jobs; only used to bound the RHO window ``k``.

    Returns:
        The configured ``RHOSolver``.
    """
    # NOTE(review): the machine count appears twice in the checkpoint name
    # (presumably the model is trained on square m x m instances) -- confirm.
    weights_file = os.path.join(
        'model',
        f"jsst{configs.gen_instance_num!s}_{machines!s}_{machines!s}"
        f"_{configs.gen_time_limit!s}_AU{configs.augment_times!s}.pth",
    )

    net = JobShopSetTransformer(m=machines)
    if not os.path.exists(weights_file):
        print(f"[警告] 未找到模型权重 {weights_file}，将使用随机初始化模型进行RHO测试。")
    else:
        net.load_state_dict(torch.load(weights_file, map_location='cpu'))

    # RHO window k: min(machines + 5, jobs - 1), falling back to 1 for single-job instances.
    window = 1 if jobs <= 1 else min(machines + 5, jobs - 1)
    return RHOSolver(numofMachines=machines, k=window, net=net)


def eval_instance_from_path(instance_name: str, file_path: str, time_limit: int, run_trho=None):
    """Solve one JSPLIB instance with RHO and, optionally, TRHO; return the metrics.

    Args:
        instance_name: Label stored under the ``'instance'`` key of the result.
        file_path: Path to the instance file, parsed with ``parse_jsplib_instance``.
        time_limit: Time budget forwarded to the solvers. RHO receives
            ``time_limit / n_jobs`` as ``time_limit``. TRHO receives
            ``time_limit / 4`` as ``time_limit`` and ``time_limit`` as
            ``global_time_limit``.
        run_trho: Whether to run TRHO as well. ``None`` (the default) follows
            the ``skip_trho`` option in ``configs`` (TRHO runs when that option
            is absent), so the evaluation matches what ``main`` reports.

    Returns:
        dict with keys ``instance``, ``jobs``, ``machines``, ``rho_makespan``,
        ``rho_time_s``, ``trho_makespan`` and ``trho_time_s``. The two TRHO
        entries are ``None`` when TRHO is skipped. Times are elapsed seconds
        measured around ``solve`` + ``fixRecord``; solver construction and
        ``reset`` are not timed.
    """
    if run_trho is None:
        run_trho = not getattr(configs, 'skip_trho', False)

    data = parse_jsplib_instance(file_path)
    time_mat, _machine_mat = data
    n_jobs, n_machines = time_mat.shape

    # RHO
    rho_solver = build_rho_solver_for_instance(n_machines, n_jobs)
    rho_solver.reset(data)
    t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement
    rho_schedule = rho_solver.solve(
        data,
        model='net',
        bws=False,
        initChoose='GRGC',
        time_limit=time_limit / n_jobs,
        initsolver=False,
        detail=False,
        returnSchedule=True,
    )
    # fixRecord gets its own copy so it cannot alter `data`, which TRHO reuses below.
    rho_schedule.fixRecord(copy.deepcopy(data))
    rho_time = time.perf_counter() - t0
    rho_makespan = rho_schedule.cal_makespan()

    # TRHO (optional; controlled by run_trho / configs.skip_trho)
    trho_makespan = None
    trho_time = None
    if run_trho:
        trho_solver = TRHOSolver(numofMachines=n_machines)
        trho_solver.reset(copy.deepcopy(data))
        t1 = time.perf_counter()
        trho_schedule = trho_solver.solve(
            copy.deepcopy(data),
            bws=False,
            time_limit=time_limit / 4,
            global_time_limit=time_limit,
        )
        trho_schedule.fixRecord(copy.deepcopy(data))
        trho_time = time.perf_counter() - t1
        trho_makespan = trho_schedule.cal_makespan()

    return {
        'instance': instance_name,
        'jobs': n_jobs,
        'machines': n_machines,
        'rho_makespan': rho_makespan,
        'rho_time_s': rho_time,
        'trho_makespan': trho_makespan,
        'trho_time_s': trho_time,
    }


def main():
    """Evaluate RHO (and TRHO unless ``configs.skip_trho``) on every ``ta*`` JSPLIB instance.

    Reads ``benchmark/JSPLIB/instances.json`` and evaluates each entry whose name
    starts with ``'ta'``. Entries whose files are missing are skipped.
    Per-instance failures are reported with a traceback and do not stop the run.
    Results are written to ``testlog/ta_rho_trho_<timestamp>.csv``, and average
    makespan and time are printed.
    """
    import traceback  # local: only needed to report per-instance failures

    # Read once; a missing option means "run TRHO".
    skip_trho = bool(getattr(configs, 'skip_trho', False))

    # Read the JSPLIB metadata to locate the TA instance files.
    meta_path = os.path.join('benchmark', 'JSPLIB', 'instances.json')
    if not os.path.exists(meta_path):
        print('未找到 benchmark/JSPLIB/instances.json')
        return

    with open(meta_path, 'r', encoding='utf-8') as f:
        inst = json.load(f)

    ta_instances = [x for x in inst if x['name'].startswith('ta')]
    if not ta_instances:
        print('未在 instances.json 中找到 TA 实例')
        return

    results = []
    for meta in ta_instances:
        name = meta['name']
        file_path = os.path.join('benchmark', 'JSPLIB', meta['path'])
        if not os.path.exists(file_path):
            print(f"跳过 {name}: 文件不存在 {file_path}")
            continue

        print(f"评测 {name} ...")
        try:
            res = eval_instance_from_path(
                name,
                file_path,
                time_limit=configs.run_time_limit,
                run_trho=not skip_trho,
            )
        except Exception as e:  # top-level boundary: report and move on to the next instance
            print(f"  评测 {name} 出错: {e}")
            traceback.print_exc()
            continue

        results.append(res)
        msg = f"  RHO makespan={res['rho_makespan']} time={res['rho_time_s']:.3f}s"
        # TRHO fields are None when TRHO was skipped; formatting None with :.3f would raise.
        if res['trho_time_s'] is not None:
            msg += f", TRHO makespan={res['trho_makespan']} time={res['trho_time_s']:.3f}s"
        print(msg)

    if not results:
        print('没有结果可保存。')
        return

    df = pd.DataFrame(results)
    os.makedirs('testlog', exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_csv = os.path.join('testlog', f'ta_rho_trho_{timestamp}.csv')
    df.to_csv(out_csv, index=False)

    print('\n评测完成:')
    print(f'  结果CSV: {out_csv}')
    print(
        '  RHO 平均makespan: '
        + f"{df['rho_makespan'].mean():.2f} | 平均时间: {df['rho_time_s'].mean():.3f}s"
    )
    if not skip_trho:
        print(
            '  TRHO 平均makespan: '
            + f"{df['trho_makespan'].mean():.2f} | 平均时间: {df['trho_time_s'].mean():.3f}s"
        )


if __name__ == '__main__':
    # Run the TA benchmark evaluation only when executed as a script, not on import.
    main()


