import numpy as np
from typing import Tuple, List, Callable
from functools import partial
from src.datasets.dataset_creator.input_generation.v1_objectness import (
    object_generator,
    object_generator_with_holes,
    diverse_shapes_generator,
)
from src.datasets.dataset_creator.input_generation.v1_arc import create_advanced_arc_input
from src.datasets.dataset_creator.input_generation.v1_counting import (
    generate_random_colored_square,
    generate_random_colored_rectangle,
    generate_grid_with_repeating_pattern,
    generate_two_binary_bitmaps,
)
from src.datasets.dataset_creator.input_generation.v1_agent import (
    create_snake_input,
    create_maze_input,
    create_city_network_input,
    create_connect_dots_input,
    create_flood_fill_input,
    create_checkers_input,
    create_light_bulb_input,
    create_radio_coverage_input,
)
from src.datasets.dataset_creator.input_generation.v1_patterns import (
    missing_pattern_input_generator,
    generate_partial_pattern,
    generate_beam_extension_input,
    generate_complex_beam_input,
)
import numpy as np
from typing import Tuple, Dict
from jax import random


def v0_arc_general_input() -> Tuple[np.ndarray, dict]:
    """
    Generate a random input grid for ARC (Abstraction and Reasoning Corpus) tasks.

    This function serves as a generic input generator for ARC-style problems that don't require
    program-specific inputs. It picks one generator uniformly at random from a fixed list and
    returns whatever that generator produces.

    The available input generators include:
    - Object generation with a minimum distance of 2 (solid objects and objects with holes)
    - Diverse shape generation
    - Advanced ARC input creation
    - Random colored squares and rectangles (default sizes, and variants capped at max_size=6)
    - Grids with repeating patterns
    - Pairs of binary bitmaps
    - Random single-color binary masks (``binary_mask_pattern_generator``)
    - Game-like structures (snake, maze, city network)
    - Puzzle-like inputs (connect dots, flood fill, checkers)
    - Problem-specific inputs (light bulb placement, radio coverage)
    - Pattern-completion inputs (missing pattern, partial pattern)
    - Beam-extension inputs (simple and complex)

    Returns:
    Tuple[np.ndarray, dict]: A tuple containing two elements:
        1. np.ndarray: A 2D numpy array representing the generated grid. The dimensions and
           content of this grid will vary based on the randomly selected generator.
        2. dict: A metadata dictionary containing additional information about the generated
           input. The contents of this dictionary will vary depending on the selected generator.

    Note:
    The choice of generator is drawn from NumPy's global RNG (``np.random``), so
    ``np.random.seed`` controls which generator is selected. The specific dimensions, colors,
    and patterns in the output depend on the selected generator and its internal parameters.
    Callers should be prepared to handle variable-sized outputs and different pattern types.
    """
    # List of input generation functions; each is called with no arguments and is
    # expected to return a (grid, metadata) pair.
    input_generators: List[Callable] = [
        partial(object_generator, min_distance=2),
        create_advanced_arc_input,
        generate_random_colored_square,
        generate_random_colored_rectangle,
        partial(generate_random_colored_square, max_size=6),
        partial(generate_random_colored_rectangle, max_size=6),
        generate_grid_with_repeating_pattern,
        create_snake_input,
        create_maze_input,
        create_city_network_input,
        create_connect_dots_input,
        create_flood_fill_input,
        create_checkers_input,
        create_light_bulb_input,
        create_radio_coverage_input,
        partial(object_generator_with_holes, min_distance=2),
        diverse_shapes_generator,
        binary_mask_pattern_generator,
        missing_pattern_input_generator,
        generate_partial_pattern,
        generate_beam_extension_input,
        generate_complex_beam_input,
        generate_two_binary_bitmaps,
    ]

    # Pick an index and look it up in the plain Python list. np.random.choice would first
    # coerce the list of callables into a temporary object ndarray, which is wasted work
    # and would break if numpy ever interpreted an entry as a sequence.
    selected_index = np.random.randint(len(input_generators))
    selected_generator = input_generators[selected_index]

    # Apply the selected generator
    grid, metadata = selected_generator()

    return grid, metadata


def binary_mask_pattern_generator() -> Tuple[np.ndarray, Dict]:
    """
    Produce a random square binary mask painted in a single random color.

    Three draws are taken from NumPy's global RNG, in this order:
    the side length (2..5 inclusive), the 0/1 mask cells, then the color (1..9 inclusive).

    Returns:
        Tuple[np.ndarray, Dict]: The ``(side, side)`` grid, whose cells are either 0 or the
        chosen color, and a metadata dict with keys ``"mask_size"`` and ``"mask_color"``.
    """
    side = np.random.randint(2, 6)
    bits = np.random.choice([0, 1], size=(side, side))
    color = np.random.randint(1, 10)

    # Cells set to 1 take the chosen color; everything else stays background (0).
    painted = np.where(bits == 1, color, 0)

    return painted, dict(mask_size=side, mask_color=color)
