# %%
import os
import pickle
import numpy as np
# Convert a pickled random QP dataset into three .npz files (train / valid / test).
#
# Problem dimensions. These must match the values the dataset was generated with:
# they are baked into the input file name and used to reshape the stored arrays.
n_var = 1000   # length of the solution vector Y (Q is n_var x n_var; columns of A and G)
n_ineq = 500   # number of inequality rows (rows of G, entries of h)
n_eq = 500     # number of equality rows (rows of A, length of each input X)
n_ex = 10000   # number of examples; only used here to build the file name
# Renamed from `type`, which shadowed the Python builtin of the same name.
problem_type = 'simple'
filename = f"random_{problem_type}_dataset_var{n_var}_ineq{n_ineq}_eq{n_eq}_ex{n_ex}"
filepath = os.path.join('datasets', problem_type, filename)
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only run this on dataset files you generated yourself or otherwise trust.
with open(filepath, 'rb') as f:
    data = pickle.load(f)

# Override the split fractions on the loaded object before reading the splits below.
# NOTE(review): presumably these control how data.trainX/validX/testX partition the
# examples (with n_ex = 10000, 0.1024 would give 1024 valid and 1024 test examples);
# confirm against the dataset class, and that these attributes are writable there.
data.valid_frac = 0.1024
data.test_frac = 0.1024
# Extract the data matrices. The shared problem data gets a leading batch axis of 1;
# vectors are stored as column vectors (trailing axis of 1).
Q = np.array(data.Q).reshape(1, n_var, n_var)
p = np.array(data.p).reshape(1, n_var, 1)
A = np.array(data.A).reshape(1, n_eq, n_var)
G = np.array(data.G).reshape(1, n_ineq, n_var)
h = np.array(data.h).reshape(1, n_ineq, 1)
# Per-example inputs X (n_eq entries each) and reference solutions Ystar (n_var each),
# shaped (num_examples, dim, 1).
trainX = np.array(data.trainX).reshape(data.trainX.shape[0], n_eq, 1)
validX = np.array(data.validX).reshape(data.validX.shape[0], n_eq, 1)
testX = np.array(data.testX).reshape(data.testX.shape[0], n_eq, 1)
trainYstar = np.array(data.trainY).reshape(data.trainY.shape[0], n_var, 1)
validYstar = np.array(data.validY).reshape(data.validY.shape[0], n_var, 1)
testYstar = np.array(data.testY).reshape(data.testY.shape[0], n_var, 1)
# %%
outfile_path = os.path.join('datasets', problem_type, f"dc3_{filename}")
# Every split file carries the same problem data (Q, p, A, G, h); only X/Ystar differ.
# Written in the order train, valid, test.
split_arrays = {
    "train": (trainX, trainYstar),
    "valid": (validX, validYstar),
    "test": (testX, testYstar),
}
for split_name, (split_X, split_Ystar) in split_arrays.items():
    # The split name is appended with no separator (e.g. "...ex10000train.npz");
    # kept as-is so existing loaders still find the files.
    np.savez(
        outfile_path + f"{split_name}.npz",
        Q=Q, p=p, A=A, G=G, h=h, X=split_X, Ystar=split_Ystar,
    )
