import numpy as np
from typing import Tuple, Dict

import numpy as np
from typing import Tuple, Dict

def create_zigzag_pattern(n):
    """Return an ``n x n`` grid of one-row horizontal stripes in two colors.

    Two distinct colors are drawn from 1-9; even rows take the first and
    odd rows the second.

    NOTE(review): despite the name there is no zigzag -- every row is a
    single solid color, so the result is plain stripes. Confirm whether a
    real zigzag was intended.
    """
    palette = np.random.choice(range(1, 10), size=2, replace=False)
    grid = np.empty((n, n), dtype=int)
    # Row r gets palette[r % 2]; broadcasting copies it across all columns.
    grid[:] = palette[np.arange(n) % 2][:, np.newaxis]
    return grid

def create_diamond_pattern(n):
    """Return an ``n x n`` three-color diamond built from diagonal bounds.

    Cells beyond either anti-diagonal bound (``i + j < n // 2`` or
    ``i + j >= 3n // 2``) or far from the main diagonal
    (``|i - j| >= n // 2``) take the first color. A one-cell border just
    inside those bounds takes the second color, and the interior the third.
    """
    palette = np.random.choice(range(1, 10), size=3, replace=False)
    rows, cols = np.indices((n, n))
    diag_sum = rows + cols
    diag_gap = np.abs(rows - cols)
    half, upper = n // 2, n * 3 // 2

    outside = (diag_sum < half) | (diag_sum >= upper) | (diag_gap >= half)
    border = (diag_sum == half) | (diag_sum == upper - 1) | (diag_gap == half - 1)

    grid = np.full((n, n), palette[2], dtype=int)
    grid[border] = palette[1]
    grid[outside] = palette[0]  # "outside" takes precedence over "border"
    return grid

def create_spiral_pattern(n):
    """Walk an ``n x n`` grid in a clockwise spiral, alternating two colors.

    The walk starts at the top-left corner heading right. It turns
    clockwise whenever the next cell is off-grid or already painted. The
    k-th visited cell gets ``palette[k % 2]``.

    NOTE(review): each step moves to an orthogonally adjacent cell, so the
    step parity always equals ``(row + col) % 2`` and the visible result is
    a two-color checkerboard, not a spiral. This module also defines
    ``create_spiral_pattern`` again further down; that later definition
    replaces this one at import time.
    """
    grid = np.zeros((n, n), dtype=int)
    palette = np.random.choice(range(1, 10), size=2, replace=False)
    moves = ((0, 1), (1, 0), (0, -1), (-1, 0))  # right, down, left, up
    row, col, heading = 0, -1, 0
    for step in range(n * n):
        d_row, d_col = moves[heading]
        nxt_row, nxt_col = row + d_row, col + d_col
        blocked = not (0 <= nxt_row < n and 0 <= nxt_col < n
                       and grid[nxt_row, nxt_col] == 0)
        if blocked:
            heading = (heading + 1) % 4
            d_row, d_col = moves[heading]
            nxt_row, nxt_col = row + d_row, col + d_col
        row, col = nxt_row, nxt_col
        grid[row, col] = palette[step % 2]
    return grid

def create_concentric_squares_pattern(n):
    """Return an ``n x n`` grid of nested one-cell-wide square rings.

    A cell's ring depth is its distance to the nearest edge. Ring ``k`` is
    colored ``colors[k % num_colors]``, with at most nine distinct colors.
    For odd ``n`` the single centre cell shares the color of the innermost
    full ring. Grids with ``n < 2`` have no rings and stay all zeros.
    """
    pattern = np.zeros((n, n), dtype=int)
    ring_count = n // 2
    num_colors = min(ring_count, 9)  # only nine non-zero colors exist
    colors = np.random.choice(range(1, 10), size=num_colors, replace=False)
    if ring_count == 0:
        return pattern
    rows, cols = np.indices((n, n))
    depth = np.minimum(np.minimum(rows, cols),
                       np.minimum(n - 1 - rows, n - 1 - cols))
    # The odd-n centre cell (depth n // 2) belongs to the last painted ring.
    depth = np.minimum(depth, ring_count - 1)
    pattern[:] = colors[depth % num_colors]
    return pattern

def create_checkerboard_gradient_pattern(n):
    """Return an ``n x n`` grid of diagonal color bands.

    Up to nine distinct colors are drawn. Cell ``(i, j)`` takes
    ``colors[(i + j) % len(colors)]``, so each anti-diagonal is one solid
    color and the bands cycle through the palette.
    """
    pattern = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=min(n, 9), replace=False)
    if len(colors) == 0:
        return pattern
    band = np.add.outer(np.arange(n), np.arange(n)) % len(colors)
    pattern[:] = colors[band]
    return pattern

def generate_interesting_pattern(n):
    """Build an ``n x n`` pattern with one generator picked uniformly at random.

    The pick uses a single integer draw from NumPy's global RNG. This is the
    same draw ``np.random.choice`` makes over a list of this length.

    NOTE(review): ``create_sierpinski_triangle``, ``create_mosaic_pattern``,
    ``create_fractal_tree_pattern``, ``create_spiral_galaxy_pattern`` and
    ``create_penrose_tiling_pattern`` are defined in this module but are not
    offered here -- confirm whether that is intentional.
    """
    generators = (
        create_zigzag_pattern,
        create_diamond_pattern,
        create_spiral_pattern,
        create_concentric_squares_pattern,
        create_checkerboard_gradient_pattern,
        create_radial_pattern,
        create_maze_pattern,
        create_wave_interference_pattern,
        create_voronoi_pattern,
        create_cellular_automaton_pattern,
        create_mandelbrot_pattern,
    )
    pick = np.random.randint(len(generators))
    return generators[pick](n)

def missing_pattern_input_generator() -> Tuple[np.ndarray, Dict]:
    """Create a random pattern with a random subset of its cells blanked out.

    Returns:
        ``(input_pattern, extra)``. ``input_pattern`` is the pattern with
        hidden cells set to 0. ``extra`` holds the untouched
        ``'original_pattern'`` and the boolean ``'mask'`` (True = visible).
    """
    side = np.random.randint(10, 21)  # grid side length in [10, 20]
    full = generate_interesting_pattern(side)

    cell_count = side * side
    # Hide between a quarter (inclusive) and a half (exclusive) of the cells.
    hidden_count = np.random.randint(cell_count // 4, cell_count // 2)
    hidden_cells = np.random.choice(cell_count, hidden_count, replace=False)
    visible = np.ones(full.shape, dtype=bool)
    visible.flat[hidden_cells] = False

    masked = np.where(visible, full, 0)
    return masked, {'original_pattern': full, 'mask': visible}




def create_spiral_pattern(n):
    """Color an ``n x n`` grid along its clockwise spiral, alternating two colors.

    The spiral starts at the top-left corner heading right. The k-th cell
    visited gets ``palette[k % 2]``.

    NOTE(review): consecutive spiral cells are orthogonal neighbours, so
    ``k % 2 == (row + col) % 2`` and the output is visually a two-color
    checkerboard. Confirm whether a visible spiral was intended.
    """
    grid = np.zeros((n, n), dtype=int)
    palette = np.random.choice(range(1, 10), size=2, replace=False)

    # Spiral visiting order: take the top row, then rotate the remainder
    # counter-clockwise so the next spiral segment becomes the new top row.
    visit_order = []
    remaining = np.arange(n * n).reshape(n, n)
    while remaining.size:
        visit_order.extend(remaining[0].tolist())
        remaining = np.rot90(remaining[1:])

    visit_order = np.asarray(visit_order, dtype=np.intp)
    grid.flat[visit_order] = palette[np.arange(visit_order.size) % 2]
    return grid

def create_concentric_squares_pattern(n):
    """Fill an ``n x n`` grid with ``n // 2`` nested squares, outermost first.

    Each square overwrites the inside of the previous one, so square ``k``
    survives as a one-cell ring colored ``palette[k % len(palette)]``, with
    at most nine distinct colors. The centre cell of an odd grid keeps the
    innermost square's color. Grids smaller than 2x2 are returned all zeros.
    """
    grid = np.zeros((n, n), dtype=int)
    ring_total = n // 2
    palette = np.random.choice(range(1, 10), size=min(ring_total, 9), replace=False)
    for depth in range(ring_total):
        lo, hi = depth, n - depth
        grid[lo:hi, lo:hi] = palette[depth % len(palette)]
    return grid

def create_checkerboard_gradient_pattern(n):
    """Return an ``n x n`` grid whose anti-diagonals cycle through a palette.

    ``min(n, 9)`` distinct colors are drawn. Each row is the palette read
    cyclically, shifted one step further than the row above.
    """
    grid = np.zeros((n, n), dtype=int)
    palette = np.random.choice(range(1, 10), size=min(n, 9), replace=False)
    band_count = len(palette)
    columns = np.arange(n)
    for row in range(n):
        grid[row] = palette[(row + columns) % band_count]
    return grid

def create_radial_pattern(n):
    """Return an ``n x n`` grid of three concentric color zones.

    Distance is Euclidean from cell ``(n // 2, n // 2)``. Cells closer than
    ``n / 4`` form the inner zone, cells closer than ``n / 2`` the middle
    zone, and everything else the outer zone.
    """
    grid = np.zeros((n, n), dtype=int)
    palette = np.random.choice(range(1, 10), size=3, replace=False)
    mid = n // 2
    rows, cols = np.indices((n, n))
    radius = np.sqrt((rows - mid) ** 2 + (cols - mid) ** 2)
    zone = np.where(radius < n / 4, 0, np.where(radius < n / 2, 1, 2))
    grid[:] = palette[zone]
    return grid

def create_sierpinski_triangle(n):
    """Return an ``n x n`` two-color Sierpinski-style fractal.

    A binary mask starts as ``[[1]]`` and is repeatedly doubled as
    ``[[s, s], [s, 0]]`` until it is at least ``n`` wide. It is then cropped
    to ``n x n``. Mask cells equal to 1 get ``colors[0]``; the rest get
    ``colors[1]``.
    """
    base = np.ones((1, 1), dtype=int)
    while base.shape[0] < n:
        top = np.hstack([base, base])
        bottom = np.hstack([base, np.zeros_like(base)])
        base = np.vstack([top, bottom])
    # Crop so sizes that are not powers of two still match the grid exactly.
    base = base[:n, :n]

    colors = np.random.choice(range(1, 10), size=2, replace=False)
    pattern = np.zeros((n, n), dtype=int)
    pattern[base == 1] = colors[0]
    pattern[base == 0] = colors[1]
    return pattern

def create_maze_pattern(n):
    """Carve a maze-like wall layout into an ``n x n`` grid by recursive division.

    Every chamber that is at least 3x3 is split by one full wall line
    (``colors[0]``) containing a single gap cell (``colors[1]``). Splits
    alternate between horizontal and vertical. Both halves are processed
    depth-first, first half before second, in the same order a recursive
    implementation would use. All untouched cells end up ``colors[1]``.

    NOTE(review): a later wall can land directly beside an earlier gap, so
    the maze is not guaranteed to be connected.
    """
    grid = np.zeros((n, n), dtype=int)
    wall, passage = np.random.choice(range(1, 10), size=2, replace=False)

    # Each entry: (left, top, width, height, split_with_horizontal_wall)
    pending = [(0, 0, n, n, np.random.choice([True, False]))]
    while pending:
        left, top, w, h, split_rows = pending.pop()
        if w < 3 or h < 3:
            continue
        if split_rows:
            row = np.random.randint(top + 1, top + h - 1)
            gap = np.random.randint(left, left + w)
            grid[row, left:left + w] = wall
            grid[row, gap] = passage
            first = (left, top, w, row - top, not split_rows)
            second = (left, row + 1, w, top + h - row - 1, not split_rows)
        else:
            col = np.random.randint(left + 1, left + w - 1)
            gap = np.random.randint(top, top + h)
            grid[top:top + h, col] = wall
            grid[gap, col] = passage
            first = (left, top, col - left, h, not split_rows)
            second = (col + 1, top, left + w - col - 1, h, not split_rows)
        # LIFO: push the second half first so the first half is handled first.
        pending.append(second)
        pending.append(first)

    grid[grid == 0] = passage
    return grid

def create_mosaic_pattern(n):
    """Return an ``n x n`` mosaic of randomly styled square tiles.

    The grid is cut into tiles of side ``max(1, n // 4)``. Tiles along the
    bottom and right edges are clipped when ``n`` is not a multiple of that
    side. Each tile independently becomes a solid block, a diagonal split,
    concentric squares, or a small checkerboard.
    """
    pattern = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=4, replace=False)
    # n // 4 is 0 for n < 4, which would make range() raise; keep tiles >= 1.
    tile_size = max(1, n // 4)
    for i in range(0, n, tile_size):
        tile_h = min(tile_size, n - i)  # clip edge tiles to the grid
        for j in range(0, n, tile_size):
            tile_w = min(tile_size, n - j)
            tile_type = np.random.randint(4)
            if tile_type == 0:  # Solid color
                pattern[i:i+tile_size, j:j+tile_size] = colors[0]
            elif tile_type == 1:  # Diagonal split: above diagonal vs on/below it
                xs, ys = np.ogrid[:tile_h, :tile_w]
                pattern[i:i+tile_h, j:j+tile_w] = np.where(xs < ys, colors[1], colors[2])
            elif tile_type == 2:  # Concentric squares (slices clip at the grid edge)
                for k in range(tile_size):
                    pattern[i+k:i+tile_size-k, j+k:j+tile_size-k] = colors[k % 2 + 2]
            else:  # Checkerboard
                pattern[i:i+tile_size:2, j:j+tile_size:2] = colors[0]
                pattern[i+1:i+tile_size:2, j+1:j+tile_size:2] = colors[0]
                pattern[i+1:i+tile_size:2, j:j+tile_size:2] = colors[3]
                pattern[i:i+tile_size:2, j+1:j+tile_size:2] = colors[3]
    return pattern

def create_wave_interference_pattern(n):
    """Quantize the sum of three sine waves into five color bands.

    ``sin(x) + sin(y) + sin(x + y)`` is sampled on an ``n x n`` grid over
    ``[0, 2*pi]`` in both directions. Each value is bucketed into one of the
    five half-open bins ``[-3 + 1.2k, -3 + 1.2(k+1))``; bin ``k`` gets the
    k-th color. Values outside all bins would stay 0.
    """
    pattern = np.zeros((n, n), dtype=int)
    palette = np.random.choice(range(1, 10), size=5, replace=False)
    axis = np.linspace(0, 2*np.pi, n)
    X, Y = np.meshgrid(axis, axis)
    field = np.sin(X) + np.sin(Y) + np.sin(X + Y)

    edges = np.array([-3 + 1.2*k for k in range(6)])
    bucket = np.digitize(field, edges)  # edges[b-1] <= v < edges[b]
    inside = (bucket >= 1) & (bucket <= 5)
    pattern[inside] = palette[bucket[inside] - 1]
    return pattern

def create_voronoi_pattern(n):
    """Color each cell of an ``n x n`` grid by its nearest of five random seeds.

    Seeds are uniform in ``[0, n)^2``. Each seed has its own color, and
    ties go to the lower-index seed.
    """
    pattern = np.zeros((n, n), dtype=int)
    palette = np.random.choice(range(1, 10), size=5, replace=False)
    seeds = np.random.rand(5, 2) * n
    cells = np.indices((n, n)).transpose(1, 2, 0)  # (n, n, 2) holding [i, j]
    offsets = seeds[np.newaxis, np.newaxis, :, :] - cells[:, :, np.newaxis, :]
    dist = np.sqrt((offsets ** 2).sum(axis=-1))  # (n, n, 5)
    pattern[:] = palette[dist.argmin(axis=-1)]
    return pattern

def create_cellular_automaton_pattern(n):
    """Run a 1-D XOR (rule 150) cellular automaton down an ``n x n`` grid.

    Row 0 is a random mix of the two colors. In each later row, a cell is
    "on" when an odd number of {left, centre, right} cells in the row above
    are on, with wraparound at the edges. ``colors[1]`` marks on cells and
    ``colors[0]`` off cells.
    """
    pattern = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    if n == 0:
        return pattern
    pattern[0] = np.random.choice(colors, n)
    # Evolve 0/1 states rather than raw color values; summing colors made the
    # rule depend on the parity of whichever colors happened to be drawn.
    state = (pattern[0] == colors[1]).astype(int)
    for i in range(1, n):
        state = (np.roll(state, 1) + state + np.roll(state, -1)) % 2
        pattern[i] = colors[state]
    return pattern

def create_fractal_tree_pattern(n):
    """Draw a binary fractal tree of line segments on an ``n x n`` grid.

    Requires scikit-image (``skimage.draw.line``, imported lazily). Each
    branch spawns two children at 70% of its length, rotated by -/+ 45
    degrees, for up to 7 levels. Line pixels get ``colors[0]`` and all other
    cells ``colors[1]``. Pixels that fall outside the grid are dropped.

    NOTE(review): ``(x, y)`` is passed to ``line`` in its first (row) and
    second (column) positions. With the start point ``(n//2, n-1)`` and
    angle ``-pi/2``, the trunk therefore runs leftward along row ``n//2``
    rather than up the grid. Confirm the intended orientation.
    """
    pattern = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=2, replace=False)

    def draw_branch(x, y, length, angle, depth):
        if depth == 0 or length < 1:
            return
        x2 = int(x + length * np.cos(angle))
        y2 = int(y + length * np.sin(angle))
        rr, cc = line(x, y, x2, y2)
        # Filter both coordinates with one joint mask so they stay paired.
        # Also drop negatives, which would otherwise wrap to the far edge.
        inside = (rr >= 0) & (rr < n) & (cc >= 0) & (cc < n)
        pattern[rr[inside], cc[inside]] = colors[0]
        draw_branch(x2, y2, length * 0.7, angle - np.pi/4, depth - 1)
        draw_branch(x2, y2, length * 0.7, angle + np.pi/4, depth - 1)

    from skimage.draw import line
    draw_branch(n//2, n-1, n//3, -np.pi/2, 7)
    pattern[pattern == 0] = colors[1]
    return pattern

def create_spiral_galaxy_pattern(n):
    """Return an ``n x n`` "galaxy": a solid core plus three spiral arms.

    Cells within ``n / 3`` of ``(n // 2, n // 2)`` form the core
    (``colors[0]``). Outside the core the angle is
    ``theta = atan2(dy, dx) + 0.5 * ln(r)``. Cells with
    ``sin(3 * theta) > 0`` become arm (``colors[1]``); the rest become gap
    (``colors[2]``).
    """
    pattern = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=3, replace=False)
    center = n // 2
    rows, cols = np.indices((n, n))
    dx, dy = rows - center, cols - center
    r = np.sqrt(dx**2 + dy**2)
    core = r < n / 3
    # ln(r) is only needed outside the core, where r >= n/3 > 0. Substituting
    # 1.0 inside the core avoids evaluating log(0) at the centre cell.
    theta = np.arctan2(dy, dx) + 0.5 * np.log(np.where(core, 1.0, r))
    arm = np.sin(3 * theta) > 0
    pattern[:] = np.where(core, colors[0], np.where(arm, colors[1], colors[2]))
    return pattern

def create_penrose_tiling_pattern(n):
    """Draw recursively branching kite outlines on an ``n x n`` grid.

    Requires scikit-image (``skimage.draw.line``, imported lazily). Each
    kite is the closed polyline through five fixed points, scaled by
    ``size``, rotated by ``angle`` and placed at ``(x, y)``. Two child kites
    at 60% size sprout from two of its vertices, rotated by +pi/5 and -pi/5
    respectively, for up to 5 levels. Outline pixels get ``colors[0]`` and
    all other cells ``colors[1]``. Pixels outside the grid are dropped.

    NOTE(review): despite the name, this does not build an aperiodic Penrose
    tiling.
    """
    def rotate(x, y, angle):
        return x * np.cos(angle) - y * np.sin(angle), x * np.sin(angle) + y * np.cos(angle)

    def draw_kite(x, y, size, angle, depth):
        if depth == 0 or size < 1:
            return
        points = [(0, 0), (1, 0), (0.5, 0.3), (0.2, 0.6), (0.8, 0.6)]
        # Pair each vertex with the next one, wrapping to close the outline.
        for (px1, py1), (px2, py2) in zip(points, points[1:] + points[:1]):
            x1, y1 = rotate(px1 * size, py1 * size, angle)
            x2, y2 = rotate(px2 * size, py2 * size, angle)
            rr, cc = line(int(x+x1), int(y+y1), int(x+x2), int(y+y2))
            # One joint mask keeps rr/cc paired and drops negative indices.
            inside = (rr >= 0) & (rr < n) & (cc >= 0) & (cc < n)
            pattern[rr[inside], cc[inside]] = colors[0]
        left_dx, left_dy = rotate(0.2 * size, 0.6 * size, angle)
        draw_kite(x + left_dx, y + left_dy, size * 0.6, angle + np.pi/5, depth-1)
        right_dx, right_dy = rotate(0.8 * size, 0.6 * size, angle)
        draw_kite(x + right_dx, y + right_dy, size * 0.6, angle - np.pi/5, depth-1)

    pattern = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    from skimage.draw import line
    draw_kite(n//2, n//2, n//2, 0, 5)
    pattern[pattern == 0] = colors[1]
    return pattern

def create_mandelbrot_pattern(n):
    """Color an ``n x n`` view of the Mandelbrot set by escape time (mod 5).

    Rows sample the real axis over ``[-2, 1]`` and columns the imaginary
    axis over ``[-1, 1]``. A point's escape time is the first iteration
    (0-19) at which ``|z| > 2``. Points that never escape are treated as
    escaping at iteration 20. The cell color is
    ``colors[escape_time % 5]``.
    """
    max_iter = 20
    colors = np.random.choice(range(1, 10), size=5, replace=False)
    x = np.linspace(-2, 1, n)
    y = np.linspace(-1, 1, n)
    c = x[:, np.newaxis] + 1j * y[np.newaxis, :]
    z = np.zeros_like(c)
    # Sentinel is max_iter, not 0: with 0, points escaping on iteration 0
    # still looked "not yet escaped" and were re-stamped on every iteration.
    divtime = np.full(c.shape, max_iter, dtype=int)
    for i in range(max_iter):
        z = z**2 + c
        diverge = np.abs(z) > 2
        div_now = diverge & (divtime == max_iter)
        divtime[div_now] = i
        z[diverge] = 2  # keep escaped values bounded between iterations
    return colors[divtime % 5]




def create_partial_stripe_pattern(n: int) -> Tuple[np.ndarray, Dict]:
    """Paint two-color horizontal stripes over the top part of an ``n x n`` grid.

    Stripes are ``stripe_width`` rows tall (1-3). Only the first
    ``visible_height`` rows are painted: at least two stripes and at least
    half the grid, capped at ``n``. The remaining rows stay 0.

    Returns:
        ``(pattern, extra)``, where ``extra`` records the pattern type,
        stripe width, colors (as plain ints) and visible height.
    """
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    stripe_width = np.random.randint(1, 4)
    pattern = np.zeros((n, n), dtype=int)
    visible_height = min(n, max(stripe_width * 2, n // 2))

    stripe_index = np.arange(visible_height) // stripe_width
    pattern[:visible_height] = colors[stripe_index % 2][:, np.newaxis]

    extra = {
        'pattern_type': 'stripe',
        'stripe_width': stripe_width,
        'colors': colors.tolist(),
        'visible_height': visible_height,
    }
    return pattern, extra

def create_partial_checkerboard_pattern(n: int) -> Tuple[np.ndarray, Dict]:
    """Paint a two-color checkerboard into the top-left square of an ``n x n`` grid.

    Squares are ``square_size`` cells wide (1-3). The painted region is
    ``visible_size x visible_size``: at least two squares and at least half
    the grid, capped at ``n``. All other cells stay 0.

    Returns:
        ``(pattern, extra)``, where ``extra`` records the pattern type,
        square size, colors (as plain ints) and visible size.
    """
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    square_size = np.random.randint(1, 4)
    pattern = np.zeros((n, n), dtype=int)
    visible_size = min(n, max(square_size * 2, n // 2))

    block = np.arange(visible_size) // square_size
    parity = np.add.outer(block, block) % 2
    pattern[:visible_size, :visible_size] = colors[parity]

    extra = {
        'pattern_type': 'checkerboard',
        'square_size': square_size,
        'colors': colors.tolist(),
        'visible_size': visible_size,
    }
    return pattern, extra

def create_partial_wave_pattern(n: int) -> Tuple[np.ndarray, Dict]:
    """Paint a sinusoidal two-color boundary over the top rows of an ``n x n`` grid.

    In each painted row, columns up to ``n // 2 + shift`` take the first
    color and the rest the second. ``shift`` is
    ``int(amplitude * sin(2*pi*row / wave_length))``. Only the first
    ``visible_height`` rows are painted (at least one wavelength and at
    least half the grid, capped at ``n``).

    Returns:
        ``(pattern, extra)``, where ``extra`` records the pattern type, wave
        length, amplitude, colors (as plain ints) and visible height.
    """
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    wave_length = np.random.randint(4, 8)
    amplitude = np.random.randint(1, 4)
    pattern = np.zeros((n, n), dtype=int)
    visible_height = min(n, max(wave_length, n // 2))

    column = np.arange(n)
    midline = n // 2
    for row in range(visible_height):
        shift = int(amplitude * np.sin(2 * np.pi * row / wave_length))
        pattern[row] = np.where(column <= midline + shift, colors[0], colors[1])

    extra = {
        'pattern_type': 'wave',
        'wave_length': wave_length,
        'amplitude': amplitude,
        'colors': colors.tolist(),
        'visible_height': visible_height,
    }
    return pattern, extra

def create_partial_zigzag_pattern(n: int) -> Tuple[np.ndarray, Dict]:
    """Paint alternating-column bands over the top rows of an ``n x n`` grid.

    Rows are grouped into bands of ``zigzag_width`` rows (2-4). Within a
    band, columns alternate between the two colors, and the phase flips
    from one band to the next. Only the first ``visible_height`` rows are
    painted (at least two bands and at least half the grid, capped at
    ``n``).

    Returns:
        ``(pattern, extra)``, where ``extra`` records the pattern type, band
        width, colors (as plain ints) and visible height.
    """
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    zigzag_width = np.random.randint(2, 5)
    pattern = np.zeros((n, n), dtype=int)
    visible_height = min(n, max(zigzag_width * 2, n // 2))

    band = (np.arange(visible_height) // zigzag_width) % 2
    phase = (np.arange(n)[np.newaxis, :] + band[:, np.newaxis]) % 2
    pattern[:visible_height] = colors[phase]

    extra = {
        'pattern_type': 'zigzag',
        'zigzag_width': zigzag_width,
        'colors': colors.tolist(),
        'visible_height': visible_height,
    }
    return pattern, extra

def create_partial_diamond_pattern(n: int) -> Tuple[np.ndarray, Dict]:
    """Paint repeating two-color triangular tiles into the top-left of an ``n x n`` grid.

    Each ``diamond_size``-wide tile (3-5) is split along its anti-diagonal:
    cells where ``(i % d) + (j % d) < d`` take the first color, the others
    the second. The painted region is ``visible_size x visible_size`` (at
    least one tile and at least half the grid, capped at ``n``).

    Returns:
        ``(pattern, extra)``, where ``extra`` records the pattern type, tile
        size, colors (as plain ints) and visible size.
    """
    colors = np.random.choice(range(1, 10), size=2, replace=False)
    diamond_size = np.random.randint(3, 6)
    pattern = np.zeros((n, n), dtype=int)
    visible_size = min(n, max(diamond_size, n // 2))

    offset = np.arange(visible_size) % diamond_size
    upper_left = np.add.outer(offset, offset) < diamond_size
    pattern[:visible_size, :visible_size] = np.where(upper_left, colors[0], colors[1])

    extra = {
        'pattern_type': 'diamond',
        'diamond_size': diamond_size,
        'colors': colors.tolist(),
        'visible_size': visible_size,
    }
    return pattern, extra

def generate_partial_pattern() -> Tuple[np.ndarray, Dict]:
    """Pick a random partial-pattern builder and run it on a random-size grid.

    The grid side is drawn from 10-30 first. A builder is then chosen
    uniformly with one integer draw, the same draw ``np.random.choice``
    makes. Returns whatever the chosen ``create_partial_*`` function
    returns: ``(pattern, extra)``.
    """
    side = np.random.randint(10, 31)  # 10..30 inclusive
    builders = (
        create_partial_stripe_pattern,
        create_partial_checkerboard_pattern,
        create_partial_wave_pattern,
        create_partial_zigzag_pattern,
        create_partial_diamond_pattern,
    )
    builder = builders[np.random.randint(len(builders))]
    return builder(side)


def generate_scattered_points(n: int, num_points: int) -> Tuple[np.ndarray, Dict]:
    """Place ``num_points`` colored points on distinct random cells of an ``n x n`` grid.

    Colors are drawn from 1-9 with replacement, so points may share a
    color. Positions are resampled until an empty cell is found.

    Returns:
        ``(grid, extra)``, where ``extra['points']`` lists
        ``(row, col, color)`` tuples in placement order.

    Raises:
        ValueError: if ``num_points`` exceeds the ``n * n`` cells available.
            Without this check the rejection sampling would never terminate.
    """
    if num_points > n * n:
        raise ValueError(
            f"cannot place {num_points} distinct points on a {n}x{n} grid"
        )
    grid = np.zeros((n, n), dtype=int)
    colors = np.random.choice(range(1, 10), size=num_points, replace=True)
    points = []

    for i in range(num_points):
        y, x = np.random.randint(0, n), np.random.randint(0, n)
        while grid[y, x] != 0:  # colors are 1-9, so non-zero means occupied
            y, x = np.random.randint(0, n), np.random.randint(0, n)
        grid[y, x] = colors[i]
        points.append((y, x, colors[i]))

    extra = {
        'points': points
    }
    return grid, extra

def generate_beam_extension_input() -> Tuple[np.ndarray, Dict]:
    """Generate a beam-extension input: 3-9 scattered colored points.

    The grid side (10-30) is drawn first, then the point count; both are
    passed to ``generate_scattered_points``.
    """
    side = np.random.randint(10, 31)
    point_count = np.random.randint(3, 10)
    grid, extra = generate_scattered_points(side, point_count)
    return grid, extra



import numpy as np
from typing import Tuple, Dict

def generate_complex_input(n: int, num_points: int) -> Tuple[np.ndarray, Dict]:
    """Scatter same-colored beam points on an ``n x n`` grid, then add distractors.

    ``num_points`` distinct cells get a single randomly drawn
    ``beam_color``. After that, checkerboard, diagonal-stripe, random-noise
    and circular distractor patterns are painted onto cells that are still
    empty.

    Returns:
        ``(grid, extra)``, where ``extra`` holds ``'points'`` (a list of
        ``(row, col, beam_color)``) and ``'beam_color'``.

    Raises:
        ValueError: if ``num_points`` exceeds the ``n * n`` cells available.
            Without this check the rejection sampling would never terminate.
    """
    if num_points > n * n:
        raise ValueError(
            f"cannot place {num_points} distinct points on a {n}x{n} grid"
        )
    grid = np.zeros((n, n), dtype=int)
    beam_color = np.random.randint(1, 10)  # Single color for beam points
    points = []

    # Generate beam points on distinct empty cells
    for _ in range(num_points):
        y, x = np.random.randint(0, n), np.random.randint(0, n)
        while grid[y, x] != 0:
            y, x = np.random.randint(0, n), np.random.randint(0, n)
        grid[y, x] = beam_color
        points.append((y, x, beam_color))

    # Add irrelevant patterns.
    # NOTE(review): each helper draws its own color from 1-9 independently of
    # beam_color, so a distractor may share the beam color. Confirm whether
    # consumers rely on beam_color being unique in the grid.
    add_checkerboard_pattern(grid)
    add_diagonal_stripes(grid)
    add_random_noise(grid)
    add_circular_pattern(grid)

    extra = {
        'points': points,
        'beam_color': beam_color
    }
    return grid, extra

def add_checkerboard_pattern(grid):
    """Paint every empty (even row, even column) cell with one random color, in place.

    One color is drawn from 1-9. Non-zero cells are left untouched. The grid
    is treated as square: its row count bounds both axes. Returns None.
    """
    side = grid.shape[0]
    color = np.random.randint(1, 10)
    lattice = grid[0:side:2, 0:side:2]  # basic slice -> view; writes reach grid
    lattice[lattice == 0] = color

def add_diagonal_stripes(grid):
    """Paint empty cells on every fifth anti-diagonal (``(i + j) % 5 == 0``), in place.

    One color is drawn from 1-9. Non-zero cells are left untouched. The grid
    is treated as square (row count bounds both axes). Returns None.
    """
    side = grid.shape[0]
    color = np.random.randint(1, 10)
    rows, cols = np.ogrid[:side, :side]
    region = grid[:side, :side]  # view into grid
    on_stripe = ((rows + cols) % 5 == 0) & (region == 0)
    region[on_stripe] = color

def add_random_noise(grid):
    """Paint one noise color onto ``n*n // 10`` distinct random cells, in place.

    One color is drawn from 1-9, then the cells are sampled without
    replacement. Sampled cells that are already non-zero keep their value.
    The grid is treated as square. Returns None.
    """
    side = grid.shape[0]
    color = np.random.randint(1, 10)
    picks = np.random.choice(side*side, size=side*side//10, replace=False)
    rows, cols = np.divmod(picks, side)
    # Picks are distinct, so checking emptiness up front matches a per-cell loop.
    empty = grid[rows, cols] == 0
    grid[rows[empty], cols[empty]] = color

def add_circular_pattern(grid):
    """Fill empty cells inside a centred disc of radius ``n // 4``, in place.

    The disc is centred on ``(n // 2, n // 2)`` and includes its boundary
    (squared distance ``<= radius**2``). One color is drawn from 1-9, and
    non-zero cells are left untouched. The grid is treated as square.
    Returns None.
    """
    side = grid.shape[0]
    color = np.random.randint(1, 10)
    mid, radius = side // 2, side // 4
    rows, cols = np.ogrid[:side, :side]
    region = grid[:side, :side]  # view into grid
    in_disc = ((rows - mid) ** 2 + (cols - mid) ** 2 <= radius ** 2) & (region == 0)
    region[in_disc] = color



def generate_complex_beam_input() -> Tuple[np.ndarray, Dict]:
    """Generate a cluttered beam input: 3-6 beam points on a 20-30 sized grid.

    The grid side is drawn first, then the beam-point count; both are passed
    to ``generate_complex_input``.
    """
    side = np.random.randint(20, 31)
    beam_count = np.random.randint(3, 7)
    grid, extra = generate_complex_input(side, beam_count)
    return grid, extra