import numpy as np
import os
import jax.numpy as jnp
from functools import partial
from src.datasets.dataset_creator.creators.big_four.create_object_v1 import (
    create_dataset as create_dataset_object,
)
from src.datasets.dataset_creator.creators.big_four.create_geometry_v1 import (
    create_dataset as create_dataset_geometry,
)
from src.datasets.dataset_creator.creators.big_four.create_counting_v1 import (
    create_dataset as create_dataset_counting,
)
from src.datasets.dataset_creator.creators.big_four.create_agency_v1 import (
    create_dataset as create_dataset_agency,
)

from src.datasets.dataset_creator.creators.big_four.create_object_v1 import (
    TRANSFORMATIONS as TRANSFORMATIONS_object,
)
from src.datasets.dataset_creator.creators.big_four.create_geometry_v1 import (
    TRANSFORMATIONS as TRANSFORMATIONS_geometry,
)
from src.datasets.dataset_creator.creators.big_four.create_counting_v1 import (
    TRANSFORMATIONS as TRANSFORMATIONS_counting,
)
from src.datasets.dataset_creator.creators.big_four.create_agency_v1 import (
    TRANSFORMATIONS as TRANSFORMATIONS_agency,
)

from src.data_utils import DATASETS_BASE_PATH

def concatenate_datasets(create_dataset_funcs, n_train, n_test, n_samples, seed):
    """Run several dataset creators and concatenate their outputs field by field.

    The train/test budgets are split evenly across the creators: each one is
    called as ``func(n_train // k, n_test // k, n_samples, seed)`` where ``k``
    is the number of creators, so any remainder of the integer division is
    dropped.

    Args:
        create_dataset_funcs: Non-empty sequence of callables, each returning a
            ``(train_fields, test_fields)`` pair where both elements are
            sequences of arrays (in this script: grids, shapes, program ids).
        n_train: Total number of training examples requested.
        n_test: Total number of test examples requested.
        n_samples: Forwarded unchanged to every creator.
        seed: Forwarded unchanged to every creator.

    Returns:
        ``(train_fields, test_fields)``: two lists holding, for each field, the
        concatenation along axis 0 of that field across all creators, in the
        order the creators were given.

    Raises:
        ValueError: If ``create_dataset_funcs`` is empty.
    """
    num_creators = len(create_dataset_funcs)
    if num_creators == 0:
        raise ValueError("at least one dataset creator is required")
    # Derive each creator's share from the number of creators rather than a
    # hard-coded 4, so adding/removing a creator keeps the total budget right.
    per_train = n_train // num_creators
    per_test = n_test // num_creators
    results = [func(per_train, per_test, n_samples, seed) for func in create_dataset_funcs]
    train_parts = [train for train, _ in results]
    test_parts = [test for _, test in results]
    return (
        [np.concatenate(field) for field in zip(*train_parts)],
        [np.concatenate(field) for field in zip(*test_parts)],
    )


def check_value_range(name, values, low, high):
    """Raise ``ValueError`` unless every entry of ``values`` lies in ``[low, high]``.

    Unlike ``assert``, this check is not stripped when Python runs with ``-O``.

    Args:
        name: Label used in the error message.
        values: Array-like of numbers to check.
        low: Inclusive lower bound.
        high: Inclusive upper bound.

    Raises:
        ValueError: If ``values`` is empty or any entry falls outside the range.
    """
    values = np.asarray(values)
    if values.size == 0:
        raise ValueError(f"{name} is empty")
    found_min, found_max = values.min(), values.max()
    if found_min < low or found_max > high:
        raise ValueError(
            f"{name} out of range [{low}, {high}]: min={found_min}, max={found_max}"
        )


if __name__ == "__main__":
    # Set the parameters for the dataset
    B_train = 100_000  # Number of training examples
    B_test = 1_000  # Number of testing examples
    N = 4  # Number of samples per batch
    seed = 42

    # Output directories for dataset version v1_main. These must be f-strings:
    # plain strings would create a literal "{DATASETS_BASE_PATH}" directory
    # while np.save below writes to the interpolated path.
    train_dir = f"{DATASETS_BASE_PATH}/storage/v1_main_train"
    test_dir = f"{DATASETS_BASE_PATH}/storage/v1_main_test"
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    print("Starting main dataset generation...")

    # Seed the global NumPy RNG to control randomness
    np.random.seed(seed)

    # Each creator's baseline_program_id is the total number of transformations
    # of the creators that precede it.
    n_object = len(TRANSFORMATIONS_object)
    n_geometry = len(TRANSFORMATIONS_geometry)
    n_counting = len(TRANSFORMATIONS_counting)
    creators = [
        partial(create_dataset_object, baseline_program_id=0),
        partial(create_dataset_geometry, baseline_program_id=n_object),
        partial(create_dataset_counting, baseline_program_id=n_object + n_geometry),
        partial(
            create_dataset_agency,
            baseline_program_id=n_object + n_geometry + n_counting,
        ),
    ]

    # Generate and concatenate the datasets
    train_sets, test_sets = concatenate_datasets(creators, B_train, B_test, N, seed)
    dataset_train, grid_shapes_train, program_ids_train = train_sets
    dataset_test, grid_shapes_test, program_ids_test = test_sets

    # Check we generated valid data before writing anything. Everything is
    # stored as uint8, so program ids are also bounded above to catch a silent
    # wrap-around in the cast.
    uint8_max = int(np.iinfo(np.uint8).max)
    check_value_range("train grids", dataset_train, 0, 9)
    check_value_range("test grids", dataset_test, 0, 9)
    check_value_range("train shapes", grid_shapes_train, 1, 30)
    check_value_range("test shapes", grid_shapes_test, 1, 30)
    check_value_range("train program ids", program_ids_train, 0, uint8_max)
    check_value_range("test program ids", program_ids_test, 0, uint8_max)

    # Save the datasets
    file_names = ("grids.npy", "shapes.npy", "program_ids.npy")
    for out_dir, arrays in (
        (train_dir, (dataset_train, grid_shapes_train, program_ids_train)),
        (test_dir, (dataset_test, grid_shapes_test, program_ids_test)),
    ):
        for file_name, array in zip(file_names, arrays):
            np.save(f"{out_dir}/{file_name}", array.astype(np.uint8))

    print("Train dataset saved to 'v1_main_train' folder.")
    print("Test dataset saved to 'v1_main_test' folder.")
    print("Train dataset of shape:", dataset_train.shape)
    print("Train shapes of shape:", grid_shapes_train.shape)
    print("Train program IDs of shape:", program_ids_train.shape)
    print("Test dataset of shape:", dataset_test.shape)
    print("Test shapes of shape:", grid_shapes_test.shape)
    print("Test program IDs of shape:", program_ids_test.shape)
