import numpy as np
import jax.numpy as jnp

def create_sparse_point_grid(
    min_distance_points: int, min_distance_right: int
) -> "tuple[np.ndarray, dict]":
    """
    Create a grid with randomly placed colored points, ensuring a minimum distance between
    points and from the right border. All non-black points share one randomly chosen color.

    The grid size is currently fixed at 30x30 and the color is drawn uniformly from 1..9.
    Placement is done by rejection sampling with ``rows * cols`` attempts, so the number of
    points placed varies from call to call.

    :param min_distance_points: Minimum number of empty cells separating two points; a candidate
        is rejected when it lies within this many cells of an existing point in both x and y
    :param min_distance_right: Minimum distance from the right border; points are only placed in
        columns ``x <= cols - min_distance_right``. Values <= 0 allow every column of the grid
    :return: Tuple ``(grid, info)`` where ``grid`` is a 2D numpy integer array (0 = black
        background) and ``info`` is an empty dict
    :raises ValueError: If ``min_distance_right`` exceeds the grid width, leaving no column
        in which a point could be placed
    """
    max_colors = 10  # Colors are drawn from 1 .. max_colors - 1 (0 is the background)

    rows = 30  # TODO: Work on getting this for smaller grids without error
    cols = 30

    # Rightmost column a point may occupy. Clamp to the last real column so that
    # min_distance_right <= 0 cannot produce an out-of-bounds column index.
    max_x = min(cols - min_distance_right, cols - 1)
    if max_x < 0:
        raise ValueError(
            f"min_distance_right={min_distance_right} leaves no valid column "
            f"in a grid that is {cols} columns wide"
        )

    # Ensure there is a gap of at least one cell on top of the requested distance
    min_gap = min_distance_points + 1

    # Initialize grid with zeros (black background)
    grid = np.zeros((rows, cols), dtype=int)

    # (x, y) coordinates of every point placed so far
    placed_points = []

    def is_valid_position(x, y):
        # A candidate collides only if it is too close on BOTH axes, i.e. it falls inside
        # the square exclusion zone centred on an already placed point.
        return not any(
            abs(x - px) < min_gap and abs(y - py) < min_gap
            for px, py in placed_points
        )

    # Choose a single random color for all points
    color = np.random.randint(1, max_colors)

    # Bounded number of attempts so a crowded grid cannot loop forever
    for _ in range(rows * cols):
        x = np.random.randint(0, max_x + 1)  # upper bound is exclusive
        y = np.random.randint(0, rows)

        if is_valid_position(x, y):
            grid[y, x] = color
            placed_points.append((x, y))

    return grid, {}