import pickle
from pathlib import Path

import jax
import numpy as np

from memento.environments.cvrp.environment import MementoCVRP
from memento.environments.tsp.environment import MementoTSP
from memento.utils.data import prepare_problem_batch

# Lookup from environment short name to its environment class.
# NOTE(review): not referenced by this script's active code path.
envs = {
    "tsp": MementoTSP,
    "vrp": MementoCVRP,
}

# Default input/output locations used when this script is run directly.
DEFAULT_SRC = "data/validation/test-1000-coords.npy"
DEFAULT_DST = "data/validation/test-1000-coords.pkl"


def convert_npy_to_pickle(src_path, dst_path=None):
    """Load an array stored in NumPy ``.npy`` format and re-save it with pickle.

    Args:
        src_path: Path to the ``.npy`` file to read.
        dst_path: Path of the pickle file to write. When ``None``, this is
            ``src_path`` with its suffix replaced by ``.pkl``.

    Returns:
        The array loaded from ``src_path`` (the same object that was pickled).

    Raises:
        OSError: if ``src_path`` cannot be read or ``dst_path`` cannot be written.
        ValueError: if ``src_path`` is not a plain ``.npy`` array (for example
            a pickle file or an object array), because ``allow_pickle`` is left
            at its default.
    """
    # A .npy file has its own binary layout (magic header + raw buffer), so it
    # must be read with np.load; pickle.load on it fails immediately.
    # allow_pickle keeps numpy's default (False) so a crafted object array in
    # the input cannot run code during loading.
    array = np.load(src_path)
    if dst_path is None:
        dst_path = Path(src_path).with_suffix(".pkl")
    with open(dst_path, "wb") as f:
        pickle.dump(array, f)
    return array


def merge_problem_components(problems):
    """Merge per-problem ``(x[0], x[1], x[2])`` records into one 3-D array.

    Legacy conversion kept from an earlier version of this script, where it
    sat in an unreachable ``if False:`` branch. It is not called by ``main``.

    For each record ``x`` in ``problems``:
      * ``x[0]`` is a 1-D vector that gets a trailing ``0`` appended;
      * ``x[2]`` is a 1-D vector that becomes a new last column of the 2-D
        ``x[1]``;
      * the padded ``x[0]`` is then prepended as the first row.

    NOTE(review): presumably this packs a CVRP instance (depot coordinates,
    customer coordinates, customer demands) into a single node table, with a
    demand of 0 for the depot. TODO confirm against the data format.

    Args:
        problems: Sequence of indexable records. All records must have
            matching shapes, and ``len(x[0]) == x[1].shape[1]``.

    Returns:
        Array of shape ``(len(problems), rows + 1, cols + 1)``, where
        ``(rows, cols)`` is the shape of each ``x[1]``.
    """
    firsts = np.array([record[0] for record in problems])
    seconds = np.array([record[1] for record in problems])
    thirds = np.array([record[2] for record in problems])

    # Pad each x[0] with a trailing zero so it has the same width as the
    # extended x[1] rows.
    firsts = np.concatenate((firsts, np.zeros((firsts.shape[0], 1))), axis=1)
    # Attach x[2] as an extra last column of x[1].
    merged = np.concatenate((seconds, thirds[:, :, None]), axis=2)
    # Prepend the padded x[0] as the first row of each problem.
    return np.concatenate((firsts[:, None, :], merged), axis=1)


def main(src_path=DEFAULT_SRC, dst_path=DEFAULT_DST):
    """Convert ``src_path`` (.npy) to ``dst_path`` (.pkl) and print a summary.

    Returns:
        The converted array.
    """
    coords = convert_npy_to_pickle(src_path, dst_path)

    # Quick sanity check of what was converted.
    print(coords[:3])
    print(coords.shape)
    print(len(coords))
    return coords


if __name__ == "__main__":
    main()
