import os
import pickle

import jax
import numpy as np

from memento.environments.cvrp.environment import MementoCVRP
from memento.environments.tsp.environment import MementoTSP
from memento.utils.data import prepare_problem_batch

# Registry mapping a short problem name to the Memento environment class
# that is instantiated for it.
envs = {
    "tsp": MementoTSP,
    "vrp": MementoCVRP,
}

if __name__ == "__main__":
    # Script entry point: for every requested instance size, build a batch of
    # problems with `prepare_problem_batch` and pickle it as a numpy array.

    # define the number of problems to create
    num_problems = 64  # 512
    seed = 1235
    env_name = "tsp"  # vrp or tsp

    # define the env type
    env_type = envs[env_name]

    instance_sizes = [500]  # [100]

    # define a key
    key = jax.random.PRNGKey(seed)

    # Derive the start key once: `key` never changes, so splitting it inside
    # the loop would yield the same `start_key` on every iteration anyway.
    # NOTE(review): the parent `key` is passed to `prepare_problem_batch`
    # together with `start_key`, which was split from it, and the same `key`
    # is reused for every instance size. JAX recommends not reusing a key
    # after splitting it. This is left untouched on purpose so that the data
    # generated for a given seed does not change - confirm whether it is
    # intended.
    _, start_key = jax.random.split(key, 2)

    # Make sure the output folder exists; otherwise `open(..., "wb")` below
    # raises FileNotFoundError on a fresh checkout.
    folder = "data/validation/"
    os.makedirs(folder, exist_ok=True)

    # loop over the instance sizes
    for instance_size in instance_sizes:
        # create the environment
        environment = env_type(instance_size)

        # create the problems
        # NOTE(review): the two trailing `1` arguments are presumably the
        # number of devices and the number of agents/starting points -
        # confirm against the signature of `prepare_problem_batch`.
        # Only `problems` is saved; the other two return values are unused.
        problems, start_positions, acting_keys = prepare_problem_batch(
            key,
            start_key,
            environment,
            num_problems,
            1,
            1,
        )

        print(problems.shape)

        # convert problems to numpy array
        problems = np.array(problems)

        filename = folder + f"{env_name}{instance_size}_test_small_seed{seed}.pkl"

        # save problems with pickle
        with open(filename, "wb") as f:
            pickle.dump(problems, f)
