import os

import numpy as np
from tqdm import tqdm
import jax.numpy as jnp

from src.datasets.dataset_creator.utils import pad_grid
from src.datasets.dataset_creator.programs.v2_programs import TRANSFORMATIONS


def generate_dataset(
    B: int,
    N: int,
    grid_sizes: list,
    colors: int,
    max_grid_size: int = 30,
    max_attempts_per_batch: int = 10_000,
) -> tuple:
    """
    Generate a dataset of ARC-style tasks from randomly chosen v2 transformations.

    For each of the ``B`` batches, one transformation is drawn uniformly from
    ``TRANSFORMATIONS`` and used to create ``N`` input/output pairs. The row and column
    counts of each input grid are drawn independently from ``grid_sizes``. If any pair
    has an output identical to its input, the whole batch is discarded and a new
    transformation is drawn.

    Args:
        B: Number of batches (tasks) to generate.
        N: Number of input-output pairs per batch.
        grid_sizes: Candidate row/column counts for the generated input grids.
        colors: Number of colors, forwarded to the transformation's input generator.
        max_grid_size: Side length every grid is padded to via ``pad_grid`` (default 30).
        max_attempts_per_batch: Number of transformation draws allowed for a single
            batch before giving up. A degenerate configuration (e.g. every program
            returning its input unchanged) then raises instead of looping forever.

    Returns:
        data: int32 array of shape (B, N, max_grid_size, max_grid_size, 2); the last
            axis holds the padded input grid (index 0) and output grid (index 1).
        grid_shapes: int32 array of shape (B, N, 2, 2) where
            ``grid_shapes[b, n, 0] == [input_rows, output_rows]`` and
            ``grid_shapes[b, n, 1] == [input_cols, output_cols]``.
        program_ids: int32 array of shape (B,) holding, for each batch, the index into
            ``TRANSFORMATIONS`` of the transformation that produced it.

    Raises:
        RuntimeError: If no valid batch is produced within ``max_attempts_per_batch``
            transformation draws.
    """
    pad_shape = (max_grid_size, max_grid_size)
    data = np.zeros((B, N, max_grid_size, max_grid_size, 2), dtype=np.int32)
    grid_shapes = np.zeros((B, N, 2, 2), dtype=np.int32)
    program_ids = np.zeros(B, dtype=np.int32)

    for b in tqdm(range(B), desc="Generating batches"):
        for _ in range(max_attempts_per_batch):
            idx = np.random.randint(len(TRANSFORMATIONS))
            transformation = TRANSFORMATIONS[idx]
            input_generator = transformation["input_generator"]
            program = transformation["program"]

            batch_valid = True
            batch_data = np.zeros((N, max_grid_size, max_grid_size, 2), dtype=np.int32)
            batch_shapes = np.zeros((N, 2, 2), dtype=np.int32)

            for n in range(N):
                row_size = np.random.choice(grid_sizes)
                col_size = np.random.choice(grid_sizes)

                # asarray so `.shape` works even if a generator returns nested tuples.
                input_grid = np.asarray(input_generator(row_size, col_size, colors))
                # Programs are called with a tuple-of-tuples grid, as in the original code.
                output_grid = np.array(program(tuple(tuple(row) for row in input_grid)))

                # An identity mapping carries no signal about the program; reject the batch.
                if np.array_equal(input_grid, output_grid):
                    batch_valid = False
                    break

                batch_data[n, :, :, 0] = pad_grid(input_grid, pad_shape)
                batch_data[n, :, :, 1] = pad_grid(output_grid, pad_shape)
                # Layout: row 0 = row counts, row 1 = column counts; column = input/output,
                # mirroring the input/output channel on the last axis of `data`.
                batch_shapes[n] = np.array(
                    [[input_grid.shape[0], output_grid.shape[0]], [input_grid.shape[1], output_grid.shape[1]]]
                )

            if batch_valid:
                data[b] = batch_data
                grid_shapes[b] = batch_shapes
                program_ids[b] = idx
                break
        else:
            raise RuntimeError(
                f"Could not generate a valid batch {b} after {max_attempts_per_batch} attempts; "
                "every drawn transformation produced an output identical to its input."
            )

    return data, grid_shapes, program_ids


if __name__ == "__main__":

    # Build train/test datasets of ARC-style tasks. Grid row and column counts are drawn
    # uniformly from `grid_sizes`, and transformations come from the v2 programs.

    train_dir = "src/datasets/v2_program_train"
    test_dir = "src/datasets/v2_program_test"

    # Create the output directories for the v2 program dataset
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Set the parameters for the dataset
    B_train = 100000  # Number of training examples
    B_test = 1000  # Number of testing examples
    N = 3  # Number of samples per batch, set to 1 for individual examples
    grid_sizes = list(range(5, 31))
    colors = 10
    seed = 42

    print("Starting dataset generation...")

    # Seed the global NumPy RNG (used by generate_dataset) so the datasets are reproducible
    np.random.seed(seed)

    # Train is generated before test: both draw from the same seeded RNG stream,
    # so this order determines the exact contents of each split.
    dataset_train, grid_shapes_train, program_ids_train = generate_dataset(B_train, N, grid_sizes, colors)
    dataset_test, grid_shapes_test, program_ids_test = generate_dataset(B_test, N, grid_sizes, colors)

    def _to_uint8(array, name):
        """Cast `array` to uint8, raising instead of silently wrapping out-of-range values."""
        info = np.iinfo(np.uint8)
        if array.size and (array.min() < info.min or array.max() > info.max):
            raise ValueError(
                f"{name} contains values outside the uint8 range [{info.min}, {info.max}]; "
                "saving it as uint8 would corrupt the data."
            )
        return array.astype(np.uint8)

    # Convert (and range-check) everything before writing, so a failure cannot leave
    # a half-written dataset on disk. uint8 suffices for colors and sizes <= 30, but
    # program IDs overflow once there are more than 256 transformations.
    outputs = {
        train_dir: {
            "grids": _to_uint8(dataset_train, "train grids"),
            "shapes": _to_uint8(grid_shapes_train, "train shapes"),
            "program_ids": _to_uint8(program_ids_train, "train program_ids"),
        },
        test_dir: {
            "grids": _to_uint8(dataset_test, "test grids"),
            "shapes": _to_uint8(grid_shapes_test, "test shapes"),
            "program_ids": _to_uint8(program_ids_test, "test program_ids"),
        },
    }

    # Save the datasets
    for out_dir, arrays in outputs.items():
        for file_stem, array in arrays.items():
            np.save(os.path.join(out_dir, f"{file_stem}.npy"), array)

    print(f"Dataset, grid shapes, and program IDs saved to '{train_dir}' and '{test_dir}'.")
    print("Train dataset of shape:", dataset_train.shape)
    print("Train shapes of shape:", grid_shapes_train.shape)
    print("Train program IDs of shape:", program_ids_train.shape)
    print("Test dataset of shape:", dataset_test.shape)
    print("Test shapes of shape:", grid_shapes_test.shape)
    print("Test program IDs of shape:", program_ids_test.shape)
