from src.datasets.dataset_creator.dsl_0 import *
from src.datasets.dataset_creator.constants import *
from src.datasets.dataset_creator.arc_types import *
import numpy as np
from typing import Dict
import random
import math


def generate_random_colored_square(max_size: int = 30, *, num_colors: int = 10) -> tuple[np.ndarray, dict]:
    """Generate a square grid whose cells are uniformly random colors.

    The side length is drawn uniformly from ``2..max_size`` (inclusive) and every
    cell is drawn uniformly from ``0..num_colors - 1``.

    Args:
        max_size: Largest allowed side length. Must be at least 2.
        num_colors: Number of colors to draw from (colors ``0..num_colors - 1``).

    Returns:
        tuple[np.ndarray, dict]: The grid and an empty extra dict.

    Raises:
        ValueError: If ``max_size < 2`` or ``num_colors < 1`` (raised by
            ``np.random.randint`` because the sampling range would be empty).
    """
    grid_size = np.random.randint(2, max_size + 1)

    return np.random.randint(0, num_colors, (grid_size, grid_size)), {}


def generate_random_colored_rectangle(max_size: int = 30, *, num_colors: int = 10) -> tuple[np.ndarray, dict]:
    """Generate a rectangular grid whose cells are uniformly random colors.

    The row and column counts are drawn independently and uniformly from
    ``2..max_size`` (inclusive); every cell is drawn uniformly from
    ``0..num_colors - 1``.

    Args:
        max_size: Largest allowed number of rows and of columns. Must be at least 2.
        num_colors: Number of colors to draw from (colors ``0..num_colors - 1``).

    Returns:
        tuple[np.ndarray, dict]: The grid and an empty extra dict.

    Raises:
        ValueError: If ``max_size < 2`` or ``num_colors < 1`` (raised by
            ``np.random.randint`` because the sampling range would be empty).
    """
    row_size = np.random.randint(2, max_size + 1)
    col_size = np.random.randint(2, max_size + 1)

    return np.random.randint(0, num_colors, (row_size, col_size)), {}


def generate_limited_color_grid(num_colors: int = 10, max_per_color: int = 30) -> tuple[np.ndarray, dict]:
    """
    Generate a random square grid in which no color appears more than ``max_per_color`` times.

    A side length is drawn uniformly from 1..30 and then clamped to
    ``isqrt(num_colors * max_per_color)``, so the grid never has more cells than the
    colors can fill without exceeding their quota (17x17 = 289 cells with the defaults).
    Note that with the defaults every draw of 18..30 is clamped to 17, so a side of 17
    is much more likely than any other size.

    Cells are filled row by row. Each cell gets a color chosen uniformly at random
    among the colors that have not yet reached their quota.

    Args:
        num_colors: Number of colors to use (colors ``0..num_colors - 1``). Defaults to 10.
        max_per_color: Maximum number of occurrences of any single color. Defaults to 30.

    Returns:
        tuple[np.ndarray, dict]: The grid and an empty extra dict.

    Raises:
        ValueError: If ``num_colors`` or ``max_per_color`` is less than 1.
    """
    if num_colors < 1 or max_per_color < 1:
        raise ValueError("num_colors and max_per_color must both be at least 1")

    grid_size = np.random.randint(1, 31)

    # Clamp the side so that side**2 <= num_colors * max_per_color. While fewer cells
    # than that have been filled, the counts sum to less than the total quota, so at
    # least one color is always still available below.
    max_area = num_colors * max_per_color
    adjusted_size = min(grid_size, math.isqrt(max_area))

    grid = np.zeros((adjusted_size, adjusted_size), dtype=int)
    color_counts = dict.fromkeys(range(num_colors), 0)

    for i in range(adjusted_size):
        for j in range(adjusted_size):
            # Colors that are still below their quota (never empty, see above).
            available_colors = [color for color, count in color_counts.items() if count < max_per_color]
            color = random.choice(available_colors)
            grid[i, j] = color
            color_counts[color] += 1

    return grid, {}


def generate_grid_with_repeating_pattern() -> Tuple[np.ndarray, Dict]:
    """
    Generate a grid containing a 3x3 key pattern, some exact repetitions of it,
    other random 3x3 patterns and a few 2x2 distractors.

    Shape cells use colors 1..9 on a background of 0. Shapes never overlap and
    never touch: each placement requires a one-cell empty margin (diagonals
    included) around it.

    The key pattern is always placed at the top-left corner (0, 0). Repetitions,
    other patterns and distractors go to random free positions. Each placement is
    tried up to 100 times and skipped if no room is found, so small grids may hold
    fewer shapes than requested. "Other" patterns are guaranteed to differ from
    the key pattern.

    Returns:
        Tuple[np.ndarray, Dict]: The grid and an extra dict with
            - "pattern_size": side of the key pattern (3),
            - "key_pattern": the key pattern as nested lists,
            - "num_repetitions": number of repetitions actually placed
              (excluding the copy at (0, 0)); always equals len(repetitions),
            - "repetitions": (row, col) top-left corners of those repetitions.
    """
    num_colors = 10

    # NOTE(review): the lower bound of 10 is not explained by the original author;
    # it leaves room for at least one more 3x3 shape beside the key pattern at (0, 0).
    grid_size = np.random.randint(10, 31)

    grid = np.zeros((grid_size, grid_size), dtype=int)
    occupied = np.zeros((grid_size, grid_size), dtype=bool)

    pattern_size = 3  # Fixed size for simplicity
    distractor_size = 2

    def create_pattern(size=pattern_size):
        return np.random.randint(1, num_colors, (size, size))

    def is_space_available(x, y, size):
        # Also rejects shapes touching the bottom/right border, matching the sampling range below.
        if x + size + 1 > grid_size or y + size + 1 > grid_size:
            return False
        return not np.any(occupied[max(0, x - 1) : x + size + 1, max(0, y - 1) : y + size + 1])

    def place_pattern(pattern, x, y):
        size = pattern.shape[0]
        grid[x : x + size, y : y + size] = pattern
        occupied[x : x + size, y : y + size] = True

    def find_available_space(size):
        for _ in range(100):
            x = random.randint(0, grid_size - size - 1)
            y = random.randint(0, grid_size - size - 1)
            if is_space_available(x, y, size):
                return x, y
        return None

    # Key pattern at the top-left corner.
    key_pattern = create_pattern()
    place_pattern(key_pattern, 0, 0)

    # Repetitions of the key pattern; only successfully placed ones are recorded.
    requested_repetitions = random.randint(1, 3)
    repetitions = []
    for _ in range(requested_repetitions):
        position = find_available_space(pattern_size)
        if position is not None:
            place_pattern(key_pattern, *position)
            repetitions.append(position)

    # Other random 3x3 patterns, which must never be copies of the key pattern.
    num_other_patterns = random.randint(2, 4)
    for _ in range(num_other_patterns):
        position = find_available_space(pattern_size)
        if position is not None:
            other_pattern = create_pattern()
            while np.array_equal(other_pattern, key_pattern):
                other_pattern = create_pattern()
            place_pattern(other_pattern, *position)

    # Small 2x2 distractors.
    num_distractors = random.randint(2, 4)
    for _ in range(num_distractors):
        position = find_available_space(distractor_size)
        if position is not None:
            place_pattern(create_pattern(distractor_size), *position)

    extra = {
        "pattern_size": pattern_size,
        "key_pattern": key_pattern.tolist(),
        "num_repetitions": len(repetitions),
        "repetitions": repetitions,
    }

    return grid, extra


def generate_grid_with_unique_pattern() -> Tuple[np.ndarray, Dict]:
    """
    Generate a grid with exactly one unique 3x3 pattern among repeated 3x3 patterns
    and repeated 2x2 distractors.

    Four mutually distinct 3x3 patterns (colors 1..9) are drawn. One of them is
    chosen as the unique pattern and placed exactly once. Each of the other three
    3x3 patterns and each of three 2x2 distractor patterns is then placed exactly
    twice or, if there is no room for both copies, not at all. Because no lone
    copy is ever left, the chosen pattern is the only 3x3 pattern that occurs
    exactly once.

    Shapes never overlap or touch: a one-cell empty margin, diagonals included,
    separates them. The background color is 0.

    Returns:
        Tuple[np.ndarray, Dict]: The grid and an extra dict with
            - "pattern_size": side of the main patterns (3),
            - "unique_pattern": the unique pattern as nested lists,
            - "unique_position": (row, col) of its top-left corner.

    Raises:
        RuntimeError: If the unique pattern cannot be placed. This is not expected,
            because the grid is empty and at least 10x10 at that point.
    """
    num_colors = 10

    grid_size = np.random.randint(10, 31)

    grid = np.zeros((grid_size, grid_size), dtype=int)
    occupied = np.zeros((grid_size, grid_size), dtype=bool)

    pattern_size = 3  # Size for main patterns
    distractor_size = 2  # Size for distractor patterns
    copies_per_pattern = 2  # Non-unique patterns appear exactly this often, or not at all

    def create_pattern(size):
        return np.random.randint(1, num_colors, (size, size))

    def is_space_available(x, y, size):
        # Also rejects shapes touching the bottom/right border, matching the sampling range below.
        if x + size + 1 > grid_size or y + size + 1 > grid_size:
            return False
        return not np.any(occupied[max(0, x - 1) : x + size + 1, max(0, y - 1) : y + size + 1])

    def place_pattern(pattern, x, y):
        size = pattern.shape[0]
        grid[x : x + size, y : y + size] = pattern
        occupied[x : x + size, y : y + size] = True

    def erase_pattern(size, x, y):
        # Shapes never overlap, so clearing this square removes only this shape.
        grid[x : x + size, y : y + size] = 0
        occupied[x : x + size, y : y + size] = False

    def find_available_space(size):
        for _ in range(100):
            x = random.randint(0, grid_size - size - 1)
            y = random.randint(0, grid_size - size - 1)
            if is_space_available(x, y, size):
                return x, y
        return None

    def all_distinct(patterns):
        return not any(
            np.array_equal(a, b) for i, a in enumerate(patterns) for b in patterns[i + 1 :]
        )

    # Main 3x3 patterns. Redraw in the (astronomically rare) case of a duplicate,
    # which would otherwise merge two "different" patterns into one.
    main_patterns = [create_pattern(pattern_size) for _ in range(4)]
    while not all_distinct(main_patterns):
        main_patterns = [create_pattern(pattern_size) for _ in range(4)]
    unique_pattern_index = random.randint(0, 3)
    unique_pattern = main_patterns[unique_pattern_index]

    # Place the unique pattern first, on the still-empty grid.
    unique_position = find_available_space(pattern_size)
    if unique_position is None:
        raise RuntimeError("could not place the unique pattern")
    place_pattern(unique_pattern, *unique_position)

    # Create distractor patterns (2x2)
    distractor_patterns = [create_pattern(distractor_size) for _ in range(3)]

    other_patterns = [
        (pattern, pattern_size) for i, pattern in enumerate(main_patterns) if i != unique_pattern_index
    ] + [(pattern, distractor_size) for pattern in distractor_patterns]

    # Place each other pattern all-or-nothing.
    for pattern, size in other_patterns:
        positions = []
        for _ in range(copies_per_pattern):
            position = find_available_space(size)
            if position is None:
                break  # No room left for another copy of this pattern
            place_pattern(pattern, *position)
            positions.append(position)
        if len(positions) < copies_per_pattern:
            # A lone copy would be a second pattern occurring exactly once; roll it back.
            for x, y in positions:
                erase_pattern(size, x, y)

    extra = {
        "pattern_size": pattern_size,
        "unique_pattern": unique_pattern.tolist(),
        "unique_position": unique_position,
    }

    return grid, extra


def generate_two_binary_bitmaps(max_size: int = 14) -> Tuple[np.ndarray, Dict]:
    """
    Generate two random NxN binary bitmaps side by side, separated by one black column.

    N is drawn uniformly from ``2..max_size`` (inclusive). The combined grid has shape
    ``(N, 2 * N + 1)``: the first bitmap fills columns ``0..N-1``, column ``N`` is
    all zeros, and the second bitmap fills columns ``N+1..2N``. With the default
    ``max_size`` of 14 the grid is at most 29 columns wide.

    Args:
        max_size: Maximum side length N of each bitmap. Must be at least 2.

    Returns:
        Tuple[np.ndarray, Dict]: The combined grid and an extra dict with
            "bitmap_size" set to N.

    Raises:
        ValueError: If ``max_size < 2`` (raised by ``np.random.randint`` because the
            sampling range would be empty).
    """
    N = np.random.randint(2, max_size + 1)

    # Create two random binary bitmaps
    bitmap1 = np.random.randint(0, 2, (N, N), dtype=int)
    bitmap2 = np.random.randint(0, 2, (N, N), dtype=int)

    # Column N stays 0 and acts as the separator.
    combined_grid = np.zeros((N, 2 * N + 1), dtype=int)
    combined_grid[:, :N] = bitmap1
    combined_grid[:, N + 1 :] = bitmap2

    extra = {"bitmap_size": N}

    return combined_grid, extra
