# test.py
import torch
import ray
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from aux.seed import set_seed
from aux.ppo import PPOBinaryActorCritic, ppo_update
import numpy as np
from scipy.optimize import linprog
from aux.CPD import CPD
from aux.solver import solve_ilp

def _solve_lp_and_ilp(A_np, b_np, cost, dim, device):
    """Solve the LP relaxation (for duals) and the ILP (for the primal) for ``cost``.

    Returns ``(x_sol, mu_sol, lambda_sol)`` as float32 tensors on ``device``:
    the ILP solution, the LP bound marginals and the LP inequality marginals.
    Returns ``None`` when the LP relaxation is not solved successfully; the ILP
    is then not attempted, since its result would be discarded anyway.
    """
    sol = linprog(cost, A_ub=A_np, b_ub=b_np, bounds=[(0, None)] * dim, method='highs')
    if not sol.success:
        return None
    _, x_val = solve_ilp(A_np, b_np, cost)
    return (
        torch.tensor(x_val, dtype=torch.float32, device=device),
        torch.tensor(sol.lower.marginals, dtype=torch.float32, device=device),
        torch.tensor(sol.ineqlin.marginals, dtype=torch.float32, device=device),
    )


def test_run(idx, data, c, gamma=0.9, dim=100, con_num=50, model_path='result.pth', seed=42):
    """Roll out a trained retraining policy on one test instance.

    At each step ``t >= 1`` the policy loaded from ``model_path`` observes the
    current solution/dual state plus change-point statistics and decides
    whether to re-solve (``action == 1``) using the mean cost vector since the
    last detected change point.

    Args:
        idx: Identifier echoed back as ``run_id``.
        data: Instance dict; reads keys 'A', 'b', 'c_all', 'f_all', 'f_all_IP'.
        c: Per-step penalty added to the cost as ``c * action``.
        gamma: Unused here; kept for call compatibility.
        dim: Number of decision variables.
        con_num: Number of inequality constraints.
        model_path: Path to the saved policy ``state_dict``.
        seed: Seed passed to ``set_seed``.

    Returns:
        dict with ``run_id``, ``total_cost``, ``retrain_num`` and
        ``retrain_steps``, or ``None`` if the initial LP (all-ones cost) is not
        solved successfully.
    """
    set_seed(seed)
    # NOTE(review): the horizon comes from 'f_all' while per-step reference
    # values come from 'f_all_IP'; presumably both have length T — confirm.
    T = data['f_all'].shape[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = PPOBinaryActorCritic(3 + 3 * dim + con_num).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only worker.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    A = torch.tensor(data['A'], dtype=torch.float32, device=device)
    b = torch.tensor(data['b'], dtype=torch.float32, device=device)
    c_all = torch.tensor(data['c_all'], dtype=torch.float32, device=device)
    f_all = torch.tensor(data['f_all_IP'], dtype=torch.float32, device=device)

    # NumPy copies for the solvers / CPD, converted once instead of every step.
    A_np = A.cpu().numpy()
    b_np = b.cpu().numpy()
    c_all_f64 = c_all.cpu().numpy().astype(np.float64)

    init = _solve_lp_and_ilp(A_np, b_np, np.ones(dim), dim, device)
    if init is None:
        return None
    x_sol, mu_sol, lambda_sol = init

    model_time = 0
    model_sample_size = 0
    # Observation layout: [model_time - t, start - t, model_sample_size,
    # c_avg - lambda @ A - mu (dim), x (dim), mu (dim), lambda (con_num)].
    state = torch.cat([
        torch.zeros(3, device=device),
        torch.ones(dim, device=device) - lambda_sol @ A - mu_sol,
        x_sol,
        mu_sol,
        lambda_sol,
    ])
    cost_sum = torch.tensor(0.0, device=device)
    retrain_list = []
    for t in tqdm(range(T)):
        if t == 0:
            # First step: the initial solution is deployed without a decision.
            cost_sum = torch.dot(c_all[t], x_sol) - f_all[t]
            continue
        # Give CPD its own array each call in case it mutates its input.
        cp = CPD(c_all_f64[:t].copy())
        start = cp[-1] if cp else 0
        # Guard: if CPD reports a breakpoint at t, c_all[start:t] would be
        # empty and its mean NaN; keep at least the latest observation.
        start = min(start, t - 1)
        c_avg = c_all[start:t].mean(dim=0)

        state[0] = float(model_time - t)
        state[1] = float(start - t)
        state[2] = float(model_sample_size)
        state[3:3 + dim] = c_avg - lambda_sol @ A - mu_sol
        state[3 + dim:3 + 2 * dim] = x_sol
        state[3 + 2 * dim:3 + 3 * dim] = mu_sol
        state[3 + 3 * dim:3 + 3 * dim + con_num] = lambda_sol

        # Pure inference: no autograd graph is needed for log_prob / value.
        with torch.no_grad():
            action, _, _ = model.act(state.clone())

        if action == 1:
            resolved = _solve_lp_and_ilp(A_np, b_np, c_avg.cpu().numpy(), dim, device)
            if resolved is not None:
                x_sol, mu_sol, lambda_sol = resolved
            # The step counts as a retrain even if the re-solve failed.
            model_time = t
            model_sample_size = t - start
            retrain_list.append(t)

        cost_sum += torch.dot(c_all[t], x_sol) - f_all[t] + c * action

    return {
        "run_id": idx,
        "total_cost": cost_sum.item(),
        "retrain_num": len(retrain_list),
        "retrain_steps": retrain_list,
    }

# Ray remote wrapper: each test run reserves a small GPU fraction so many
# runs can share one device.
test_run_ray = ray.remote(num_gpus=0.05)(test_run)

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    c = 10.0
    gamma = 0.9

    # === Load the test dataset ===
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('test_dataset_01.pkl', 'rb') as f:
        dataset = pickle.load(f)

    # A GPU request can never be scheduled on a cluster without GPUs (tasks
    # would hang as infeasible), so drop it there; test_run falls back to CPU.
    runner = test_run_ray
    if not ray.cluster_resources().get("GPU", 0):
        runner = test_run_ray.options(num_gpus=0)

    futures = [
        runner.remote(i, data, c, gamma)
        for i, data in enumerate(dataset)
    ]
    results = ray.get(futures)

    # test_run returns None when the initial LP is not solved; skip those
    # instead of crashing on r['total_cost'].
    valid = [r for r in results if r is not None]
    skipped = len(results) - len(valid)
    if skipped:
        print(f"⚠️ Skipped {skipped} test case(s) whose initial LP was not solved.")
    if valid:
        avg_cost = sum(r['total_cost'] for r in valid) / len(valid)
        print(f"✅ Average Total Cost over {len(valid)} test cases: {avg_cost:.2f}")
    else:
        print("No test case produced a result.")

    