import numpy as np
import os
import jax.numpy as jnp
from src.datasets.dataset_creator.programs.v1_agency_programs import TRANSFORMATIONS
from functools import partial
from src.datasets.dataset_creator.creators.utils import create_dataset, create_data


# Define dataset creator
# Define dataset creator: bind the v1-agency program set (TRANSFORMATIONS) into the
# imported `create_dataset` helper, rebinding the module-level name to the partial.
# NOTE(review): the __main__ block below calls `create_data` directly and never uses
# this partial; presumably it is kept for modules that import it from here — verify.
create_dataset = partial(create_dataset, transformations=TRANSFORMATIONS)

from src.data_utils import DATASETS_BASE_PATH

if __name__ == "__main__":
    # Generate the v1-agency train/test splits and store each split as three uint8
    # .npy files (grids.npy, shapes.npy, program_ids.npy) under
    # <DATASETS_BASE_PATH>/storage/v1_agency_{train,test}/.

    # Set the parameters for the dataset
    B_train = 1_000  # Number of training examples
    B_test = 1_000  # Number of testing examples
    N = 4  # Number of samples per batch

    seed = 420

    # Output directories. The paths must be interpolated: a plain "{DATASETS_BASE_PATH}"
    # literal would create a directory with that literal name instead of the real target.
    storage_dir = os.path.join(DATASETS_BASE_PATH, "storage")
    train_dir = os.path.join(storage_dir, "v1_agency_train")
    test_dir = os.path.join(storage_dir, "v1_agency_test")
    for out_dir in (train_dir, test_dir):
        os.makedirs(out_dir, exist_ok=True)

    print("Starting dataset generation...")

    # Seed the global NumPy RNG (create_data also receives an explicit per-split seed).
    np.random.seed(seed)

    # Train and test use different seeds so the two splits are not identical.
    dataset_train, grid_shapes_train, program_ids_train = create_data(B_train, N, seed, TRANSFORMATIONS)
    dataset_test, grid_shapes_test, program_ids_test = create_data(B_test, N, seed + 1, TRANSFORMATIONS)

    # Save every split in the same order as before: grids, shapes, program IDs.
    # NOTE(review): astype(np.uint8) silently wraps values outside [0, 255]; this assumes
    # grid cell values, grid dimensions and program IDs all fit in a byte — confirm.
    splits = (
        (train_dir, (dataset_train, grid_shapes_train, program_ids_train)),
        (test_dir, (dataset_test, grid_shapes_test, program_ids_test)),
    )
    filenames = ("grids.npy", "shapes.npy", "program_ids.npy")
    for out_dir, arrays in splits:
        for filename, array in zip(filenames, arrays):
            np.save(os.path.join(out_dir, filename), array.astype(np.uint8))

    print("Grids, shapes, and program IDs saved to 'v1_agency_train' and 'v1_agency_test' folders.")
    print("Train dataset of shape:", dataset_train.shape)
    print("Train shapes of shape:", grid_shapes_train.shape)
    print("Test dataset of shape:", dataset_test.shape)
    print("Test shapes of shape:", grid_shapes_test.shape)
