from gurobipy import Model, GRB
import time

import os
import pickle

test_data_path = 'test_dataset_01.pkl'
# Fail fast: every later step reads ``test_data``, so a missing file would
# otherwise only surface after Ray has started, as an opaque NameError.
if not os.path.exists(test_data_path):
    raise FileNotFoundError(f"Test data file not found: {test_data_path}")
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(test_data_path, 'rb') as f:
    test_data = pickle.load(f)
print("Test data loaded successfully.")

import numpy as np
import time
import ray
from tqdm import tqdm

# Start a local Ray runtime exposing 64 logical CPU slots for task scheduling.
RAY_NUM_CPUS = 64
ray.init(num_cpus=RAY_NUM_CPUS)

@ray.remote
def solve_ilp(A, b, c, i, j, threads=1):
    """Solve ``min c @ x  s.t.  A @ x <= b,  x >= 0,  x integer`` with Gurobi.

    Runs as a Ray task. ``i`` and ``j`` are echoed back unchanged so the
    caller can place each result into the right (instance, cost-vector) slot.

    Args:
        A: Constraint matrix; must convert via ``np.asarray`` to a 2-D array
            of shape (m, n).
        b: Right-hand side, indexed by constraint row (at least m entries).
        c: Objective coefficients, indexed by variable (at least n entries).
        i: Instance index, returned as-is.
        j: Cost-vector index, returned as-is.
        threads: Value for Gurobi's ``Threads`` parameter. Defaults to 1
            because many solves already run concurrently as separate Ray
            tasks; Gurobi's automatic default may use many threads per
            solve and oversubscribe the machine.

    Returns:
        ``(i, j, objective_value, x)`` with ``x`` a list of ints when Gurobi
        reports an optimal solution, otherwise ``(i, j, None, None)``.
    """
    from gurobipy import quicksum  # local import: header imports stay unchanged

    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    c = np.asarray(c, dtype=float)
    _, n_dim = A.shape  # also enforces that A is 2-D

    # Plain Python floats are cheaper to feed into gurobipy than NumPy scalars.
    A_rows = A.tolist()
    b_vals = b.tolist()
    c_vals = c.tolist()

    m = Model()
    try:
        m.setParam('OutputFlag', 0)
        m.setParam('MIPGap', 0)
        m.setParam('Threads', threads)

        x_vars = [m.addVar(vtype=GRB.INTEGER, lb=0) for _ in range(n_dim)]
        # quicksum builds linear expressions far faster than Python's sum().
        m.setObjective(
            quicksum(c_vals[k] * x for k, x in enumerate(x_vars)),
            GRB.MINIMIZE,
        )
        for row, coeffs in enumerate(A_rows):
            lhs = quicksum(a * x for a, x in zip(coeffs, x_vars))
            m.addConstr(lhs <= b_vals[row])

        m.optimize()

        if m.status != GRB.OPTIMAL:
            return (i, j, None, None)
        # Integer variables may come back as e.g. 2.9999999997 (within the
        # integrality tolerance); round() avoids int() truncating that to 2.
        x_val = [round(x.X) for x in x_vars]
        return (i, j, m.objVal, x_val)
    finally:
        # Release the native model promptly; Ray workers are reused across tasks.
        m.dispose()

# Start the clock before submitting: Ray tasks begin executing as soon as
# .remote() is called, so timing only ray.get() under-reports the total.
start = time.perf_counter()

# Submit one solve per (instance, cost vector) pair.
futures = []
for i, item in tqdm(enumerate(test_data), total=len(test_data)):
    # A and b are shared by every cost vector of this instance: store them in
    # the object store once rather than re-serializing them for every task.
    # Ray resolves top-level ObjectRef arguments to their values in the task.
    A_ref = ray.put(item['A'])
    b_ref = ray.put(item['b'])
    for j, c in enumerate(item['c_all']):
        futures.append(solve_ilp.remote(A_ref, b_ref, c, i, j))

try:
    results = ray.get(futures)
    print("Total time:", time.perf_counter() - start, "seconds")

    # Pre-size the output lists so results can be placed by (i, j) index.
    for item in test_data:
        n_costs = len(item['c_all'])
        item['f_all_IP'] = [None] * n_costs
        item['x_all_IP'] = [None] * n_costs

    for i, j, obj, x in results:
        test_data[i]['f_all_IP'][j] = obj
        test_data[i]['x_all_IP'][j] = x
finally:
    # Tear down Ray even if a task raised.
    ray.shutdown()

# Non-optimal solves are stored as None; surface them instead of hiding them.
n_failed = sum(obj is None for _, _, obj, _ in results)
if n_failed:
    print(f"Warning: {n_failed} of {len(results)} solves were not optimal; stored as None.")

# save
import pickle

# Write to a temporary file, then atomically swap it in, so a crash mid-dump
# cannot corrupt the only copy of the dataset being overwritten.
tmp_path = test_data_path + '.tmp'
with open(tmp_path, 'wb') as f:
    pickle.dump(test_data, f)
os.replace(tmp_path, test_data_path)

print("Test data saved successfully to", test_data_path)