import numpy as np
import os
import jax.numpy as jnp
from functools import partial
from src.datasets.dataset_creator.programs.v1_geometry_programs import TRANSFORMATIONS
from jax import random
from src.datasets.dataset_creator.creators.utils import create_dataset, create_data

# Rebind create_dataset with this module's geometry TRANSFORMATIONS pre-bound
# as the `transformations` keyword, so callers only pass the remaining args.
# NOTE(review): the __main__ script in this file calls create_data directly and
# never uses this binding; presumably it is kept for modules that import
# create_dataset from here — confirm before removing.
create_dataset = partial(create_dataset, transformations=TRANSFORMATIONS)

from src.data_utils import DATASETS_BASE_PATH

if __name__ == "__main__":
    # Generate the v1 geometry train/test datasets and save each split as
    # uint8 .npy files (grids.npy, shapes.npy, program_ids.npy) under
    # DATASETS_BASE_PATH/storage/v1_geometry_<split>.

    # Set the parameters for the dataset
    B_train = 1_000  # Number of training examples
    B_test = 1_000  # Number of testing examples
    N = 4  # Number of samples per batch, set to 1 for individual examples
    seed = 420

    # Every array is stored as uint8. Values are range-checked before the cast
    # because astype() silently wraps out-of-range integers instead of raising.
    storage_dtype = np.uint8
    storage_max = int(np.iinfo(storage_dtype).max)

    # Per-split (number of examples, seed); the test split gets seed + 1 so it
    # is not generated from the same seed as the training split.
    split_config = {
        "train": (B_train, seed),
        "test": (B_test, seed + 1),
    }
    split_dirs = {
        split: os.path.join(DATASETS_BASE_PATH, "storage", f"v1_geometry_{split}")
        for split in split_config
    }

    # Create the output directories for the v1 geometry dataset
    for split_dir in split_dirs.values():
        os.makedirs(split_dir, exist_ok=True)

    print("Starting dataset generation...")

    # Seed NumPy's global RNG to control randomness
    np.random.seed(seed)

    # Generate each split in order (train, then test). create_data returns a
    # 3-tuple that is unpacked below as (grids, grid_shapes, program_ids).
    splits = {
        split: create_data(num_examples, N, split_seed, TRANSFORMATIONS)
        for split, (num_examples, split_seed) in split_config.items()
    }

    # Validate every split before saving anything, so a bad split never leaves
    # a partially written dataset on disk. Explicit raises (rather than
    # assert) keep these checks active under `python -O`.
    for split, (grids, grid_shapes, program_ids) in splits.items():
        bounds = (
            ("grids", grids, 0, 9),
            ("shapes", grid_shapes, 1, 30),
            # Upper bound guards the uint8 cast below against silent wraparound.
            ("program_ids", program_ids, 0, storage_max),
        )
        for label, values, low, high in bounds:
            v_min, v_max = np.min(values), np.max(values)
            if v_min < low or v_max > high:
                raise ValueError(
                    f"{split} {label} out of range [{low}, {high}]: "
                    f"min={v_min}, max={v_max}"
                )

    # Save the datasets
    for split, (grids, grid_shapes, program_ids) in splits.items():
        split_dir = split_dirs[split]
        np.save(os.path.join(split_dir, "grids.npy"), grids.astype(storage_dtype))
        np.save(os.path.join(split_dir, "shapes.npy"), grid_shapes.astype(storage_dtype))
        np.save(os.path.join(split_dir, "program_ids.npy"), program_ids.astype(storage_dtype))

    print("Grids, shapes, and program IDs saved to 'v1_geometry_train' and 'v1_geometry_test' folders.")
    for split, (grids, grid_shapes, program_ids) in splits.items():
        title = split.capitalize()  # "Train" / "Test"
        print(f"{title} dataset of shape:", grids.shape)
        print(f"{title} shapes of shape:", grid_shapes.shape)
        print(f"{title} program IDs of shape:", program_ids.shape)
