import random
import copy
import numpy as np
from typing import overload, Literal
from .impl.sudoku_utils import Board, empty_board, is_valid_fill_number, repr_board


def find_empty(board: Board) -> tuple[int, int]:
    """Return the (row, col) of the first empty cell, scanning row by row.

    A cell is empty when its value is 0.

    Raises:
        ValueError: if the board contains no empty cell.
    """
    for row in range(9):
        for col in range(9):
            if board[row, col] == 0:
                return (row, col)
    # Explicit raise instead of ``assert False``: asserts are stripped under
    # ``python -O``, which would make this silently return None.
    raise ValueError("board has no empty cell")


def fill_board(board: Board):
    """Fill every empty (0) cell of *board* in place using randomized backtracking.

    For the first empty cell, the digits 1-9 accepted by ``is_valid_fill_number``
    are tried in random order, recursing on the rest of the board after each
    placement. Calling this on an empty board therefore yields a random
    complete grid.

    Args:
        board: the grid to fill; it is mutated in place.

    Returns:
        True if every cell ends up filled. False if no completion consistent with
        ``is_valid_fill_number`` exists from the given state; in that case every
        cell this call filled has been reset to 0.
    """
    if not np.any(board == 0):
        return True

    row, col = find_empty(board)
    candidates = [num for num in range(1, 10) if is_valid_fill_number(board, row, col, num)]
    random.shuffle(candidates)  # randomize so each generated grid differs
    for num in candidates:
        board[row, col] = num
        if fill_board(board):
            return True
        board[row, col] = 0  # undo this placement before trying the next digit

    return False

def remove_numbers(board: Board, k: int):
    """Return a copy of *board* with ``k`` distinct, randomly chosen cells set to 0.

    Positions are drawn without replacement from the 81 row-major flat indices
    using numpy's global random generator. The input board is left untouched.
    """
    blanked = board.copy()
    flat_positions = np.random.choice(81, size=k, replace=False)
    rows, cols = np.divmod(flat_positions, 9)
    blanked[rows, cols] = 0
    return blanked


@overload
def generate_sudoku_puzzle(guess_num: int, require_solution: Literal[False] = False) -> str: ...
@overload
def generate_sudoku_puzzle(guess_num: int, require_solution: Literal[True]) -> tuple[str, str]: ...
@overload
def generate_sudoku_puzzle(guess_num: int, require_solution: bool = False) -> str | tuple[str, str]: ...
def generate_sudoku_puzzle(guess_num: int, require_solution: bool = False) -> str | tuple[str, str]:
    """Generate a random Sudoku puzzle.

    A complete grid is built with ``fill_board``. Then ``guess_num`` distinct
    cells are blanked with ``remove_numbers``.

    Args:
        guess_num: number of cells to blank out. Values outside 0..81 make
            numpy's ``random.choice`` (inside ``remove_numbers``) raise ValueError.
        require_solution: when True, also return the completed grid.

    Returns:
        The puzzle rendered by ``repr_board`` with ``rowsep=''`` and ``colsep=''``,
        or ``(puzzle, solution)`` when ``require_solution`` is True.

    Raises:
        RuntimeError: if ``fill_board`` fails to complete the grid.
    """
    board = empty_board()
    if not fill_board(board):
        # Explicit raise: an ``assert`` here would be stripped under ``python -O``.
        raise RuntimeError("failed to generate a complete Sudoku grid")
    puzzle = remove_numbers(board, guess_num)
    puzzle_str = repr_board(puzzle, rowsep='', colsep='')
    if not require_solution:
        return puzzle_str
    # NOTE(review): the solution is rendered with repr_board's default separators,
    # unlike the puzzle above; confirm callers expect the two formats to differ.
    solution = repr_board(board)
    return puzzle_str, solution


def pprint_board(board: Board):
    """Write the ``pstr_board`` rendering of *board* to standard output."""
    rendered = pstr_board(board)
    print(rendered)


def pstr_board(board: Board):
    """Render *board* as a multi-line grid string.

    Empty cells (value 0) are drawn as ``.``. Every cell is followed by one
    space, ``| `` separates each group of three columns, and a line of eleven
    ``- `` pairs separates each group of three rows. Lines are joined by
    newlines, with no trailing newline.
    """
    def render_row(r):
        cells = []
        for c in range(9):
            value = board[r, c]
            cells.append("." if value == 0 else str(value))
        boxes = (" ".join(cells[start:start + 3]) for start in (0, 3, 6))
        return " | ".join(boxes) + " "

    divider = "- " * 11
    rendered = []
    for band in (0, 3, 6):
        if band > 0:
            rendered.append(divider)
        rendered.extend(render_row(r) for r in range(band, band + 3))
    return "\n".join(rendered)


if __name__ == "__main__":
    # Demo: print a freshly generated puzzle with 20 blanked cells.
    print(generate_sudoku_puzzle(20))
