import pickle

import jax
import numpy as np

from memento.environments.cvrp.environment import MementoCVRP
from memento.environments.tsp.environment import MementoTSP
from memento.utils.data import prepare_problem_batch

def split_dataset_by_scale(env="vrp", data_dir="data", output_dir=None):
    """Split a pickled ``{env}100_shifted_dataset.pkl`` into one pickle per scale.

    The combined dataset is indexed along its first axis.
    Each entry ``obj[scale]`` is written to
    ``{output_dir}/{env}100_shifted_scale{scale}.pkl``.

    Args:
        env: Problem family used in the file names (e.g. "tsp" or "vrp").
        data_dir: Directory containing ``{env}100_shifted_dataset.pkl``.
        output_dir: Destination directory for the per-scale files. Defaults to
            ``{data_dir}/validation``. Created if it does not already exist.

    Returns:
        The paths of the written per-scale pickle files, in scale order.

    Raises:
        FileNotFoundError: If the combined dataset file does not exist.
    """
    import os

    if output_dir is None:
        output_dir = os.path.join(data_dir, "validation")

    filename = os.path.join(data_dir, f"{env}100_shifted_dataset.pkl")

    # NOTE: pickle.load can execute arbitrary code; only run this on trusted,
    # locally generated dataset files.
    with open(filename, "rb") as f:
        obj = pickle.load(f)

    # Without this, writing into a missing output folder raises FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    written = []
    # len() equals shape[0] for arrays and also works for plain sequences.
    for scale in range(len(obj)):
        # Extract the set of problems for this scale and save it on its own.
        new_filename = os.path.join(output_dir, f"{env}100_shifted_scale{scale}.pkl")
        with open(new_filename, "wb") as f:
            pickle.dump(obj[scale], f)
        written.append(new_filename)
    return written


if __name__ == "__main__":
    env = "vrp"  # tsp or vrp
    split_dataset_by_scale(env)
