import numpy as np
from typing import List, Tuple
from tqdm import tqdm
import os
import jax.numpy as jnp
from utils import generate_random_grid, pad_grid, apply_transformation

from v0_programs import TRANSFORMATIONS


def generate_dataset(
    B: int, N: int, grid_sizes: List[int], colors: int, max_size: int = 30
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate a dataset of transformed grid pairs using varying grid sizes.

    For every batch entry one transformation is drawn at random from
    TRANSFORMATIONS and applied to all N input grids of that entry. The size of
    each input grid is drawn at random from `grid_sizes`.

    Randomness comes from the global `np.random` state, so seed it beforehand
    for reproducible output.

    Args:
        B: Batch size
        N: Number of samples (input-output pairs) per batch
        grid_sizes: List of possible grid sizes (forwarded to generate_random_grid)
        colors: Number of colors in the grids (forwarded to generate_random_grid)
        max_size: Side length of the square canvas every grid is padded to

    Returns:
        data: Numpy array of shape (B, N, max_size, max_size, 2); the last axis
            holds the padded input grid at index 0 and the padded output grid
            at index 1
        grid_shapes: Numpy array of shape (B, N, 2, 2) where grid_shapes[b, n] is
            [[input_rows, output_rows], [input_cols, output_cols]]
        program_ids: Numpy array of shape (B,) containing program IDs
            (indices into TRANSFORMATIONS)
    """
    canvas_shape = (max_size, max_size)
    data = np.zeros((B, N, max_size, max_size, 2), dtype=np.int32)
    grid_shapes = np.zeros((B, N, 2, 2), dtype=np.int32)
    program_ids = np.zeros(B, dtype=np.int32)

    for b in tqdm(range(B), desc="Generating batches"):
        # One program per batch entry: all N pairs share the same transformation.
        program_id = np.random.randint(len(TRANSFORMATIONS))
        transformation = TRANSFORMATIONS[program_id]
        program_ids[b] = program_id

        for n in range(N):
            # Each pair gets its own grid size.
            grid_size = grid_sizes[np.random.randint(len(grid_sizes))]

            input_grid = generate_random_grid(grid_size, colors)
            output_grid = apply_transformation(input_grid, transformation).astype(np.int32)

            data[b, n, :, :, 0] = pad_grid(input_grid, canvas_shape)
            data[b, n, :, :, 1] = pad_grid(output_grid, canvas_shape)

            # Row 0 holds the row counts, row 1 the column counts;
            # column 0 is the input grid, column 1 the output grid.
            grid_shapes[b, n] = [
                [input_grid.shape[0], output_grid.shape[0]],
                [input_grid.shape[1], output_grid.shape[1]],
            ]

    return data, grid_shapes, program_ids


if __name__ == "__main__":

    # This creates a dataset of ARC-style tasks whose grid sizes are drawn uniformly
    # from a set of specified grid sizes and whose transformations come from the
    # v0 programs (TRANSFORMATIONS).

    # Output directories for dataset version 0
    train_dir = "src/datasets/v0_program_train"
    test_dir = "src/datasets/v0_program_test"
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Set the parameters for the dataset
    B_train = 100000  # Number of training batches (first axis of the saved arrays)
    B_test = 1000  # Number of testing batches (first axis of the saved arrays)
    N = 3  # Number of samples per batch, set to 1 for individual examples
    grid_sizes = [3, 4, 5, 6, 7, 8, 9, 10]
    colors = 10
    seed = 42

    # Program IDs are indices into TRANSFORMATIONS and are saved as uint8 below.
    # Fail before the expensive generation rather than silently wrapping IDs.
    if len(TRANSFORMATIONS) > np.iinfo(np.uint8).max + 1:
        raise ValueError(
            f"{len(TRANSFORMATIONS)} transformations do not fit in uint8 program IDs"
        )

    print("Starting dataset generation...")

    # Seed the global NumPy RNG once; train and test are drawn from the same
    # stream, in this order, so the generation order must not change.
    np.random.seed(seed)

    # Generate the training dataset
    dataset_train, grid_shapes_train, program_ids_train = generate_dataset(B_train, N, grid_sizes, colors)
    # Generate the testing dataset
    dataset_test, grid_shapes_test, program_ids_test = generate_dataset(B_test, N, grid_sizes, colors)

    # Save the datasets; every array is stored as uint8 to keep the files small
    splits = (
        (train_dir, dataset_train, grid_shapes_train, program_ids_train),
        (test_dir, dataset_test, grid_shapes_test, program_ids_test),
    )
    for out_dir, grids, shapes, program_ids in splits:
        np.save(f"{out_dir}/grids.npy", grids.astype(np.uint8))
        np.save(f"{out_dir}/shapes.npy", shapes.astype(np.uint8))
        np.save(f"{out_dir}/program_ids.npy", program_ids.astype(np.uint8))

    print(f"Dataset, grid shapes, and program IDs saved to '{train_dir}' and '{test_dir}'.")
    print("Train dataset of shape:", dataset_train.shape)
    print("Train shapes of shape:", grid_shapes_train.shape)
    print("Train program IDs of shape:", program_ids_train.shape)
    print("Test dataset of shape:", dataset_test.shape)
    print("Test shapes of shape:", grid_shapes_test.shape)
    print("Test program IDs of shape:", program_ids_test.shape)
