from typing import List, Optional
from z3 import ArithRef, Int, Solver, Distinct, And, sat, IntVal


def make_9x9_z3_board(board_text: str, solver: Solver) -> List[List[ArithRef]]:
    """
    Create a 9x9 grid of z3 integer variables from a text Sudoku board.

    Each non-blank line of ``board_text`` is one row. Within a row, the ASCII
    digits ``0``-``9`` are cells and every other character (spaces, ``|``,
    ``-``, ...) is ignored, so separators may be used freely. A ``0`` marks an
    unknown cell, which is constrained to the range 1-9; a digit ``1``-``9``
    is a given and is constrained to equal that value. All constraints are
    added to ``solver``.

    Variables are named ``cell_<row>_<col>`` using the logical row/column
    position (blank lines and separator characters are not counted).

    Raises:
        ValueError: if a row does not contain exactly 9 cells, or there are
            not exactly 9 non-blank rows.
    """
    board = []
    for line in board_text.splitlines():
        # Skip blank/whitespace-only lines so that, e.g., a triple-quoted board
        # that begins with a newline is accepted.
        if not line.strip():
            continue
        row_index = len(board)
        row = []
        for character in line:
            # Only plain ASCII digits are cells: str.isdigit() also accepts
            # characters such as "²" that int() cannot parse.
            if character not in "0123456789":
                continue
            cell = Int(f"cell_{row_index}_{len(row)}")
            num = int(character)
            if num == 0:
                # 0 is unknown: the solver may choose any digit 1-9.
                solver.add(And(cell >= 1, cell <= 9))
            else:
                # A given digit is fixed to its value.
                solver.add(cell == IntVal(num))
            row.append(cell)
        if len(row) != 9:
            raise ValueError(
                f"Invalid column count of board, must be 9, got {len(row)}")
        board.append(row)

    if len(board) != 9:
        raise ValueError(
            f"Invalid row count of board, must be 9, got {len(board)}")

    return board


def assert_uniq(solver: Solver, z3_board: List[List[ArithRef]]):
    """
    Add Sudoku all-different constraints to ``solver``.

    One ``Distinct`` constraint is added per row, then per column, then per
    3x3 box of ``z3_board``. The box indexing (0-8) assumes a 9x9 grid.
    """
    def _groups():
        # Rows first, then columns, then boxes scanned left-to-right,
        # top-to-bottom.
        yield from z3_board
        yield from zip(*z3_board)
        for top in range(0, 9, 3):
            for left in range(0, 9, 3):
                yield [z3_board[r][c]
                       for r in range(top, top + 3)
                       for c in range(left, left + 3)]

    for group in _groups():
        solver.add(Distinct(group))


def print_board(board: List[List[int]]):
    """Print ``board`` to stdout, one row per line in list form."""
    for rendered in map(str, board):
        print(rendered)


def check_valid(board: List[List[int]]) -> bool:
    """
    Return True if ``board`` is a completely and correctly solved Sudoku.

    A valid board has exactly 9 rows of exactly 9 cells. Every row, every
    column and every 3x3 box must contain each of the digits 1-9 exactly once.
    Boards with the wrong dimensions yield False rather than raising.
    """
    digits = set(range(1, 10))

    # Reject malformed grids up front; otherwise zip() would silently
    # truncate columns and box indexing could raise IndexError.
    if len(board) != 9 or any(len(row) != 9 for row in board):
        return False

    # Comparing against the exact digit set (not just "9 distinct values")
    # rejects boards containing 0 or out-of-range numbers.
    if not all(set(row) == digits for row in board):
        return False

    if not all(set(col) == digits for col in zip(*board)):
        return False

    for i in range(0, 9, 3):
        for j in range(0, 9, 3):
            square = {board[x][y]
                      for x in range(i, i + 3) for y in range(j, j + 3)}
            if square != digits:
                return False
    return True


def solve(board_text: str) -> Optional[List[List[int]]]:
    """
    Solve the Sudoku described by ``board_text`` using z3.

    Returns the completed grid as a list of rows of ints, or ``None`` when the
    solver's check does not report ``sat``. A ``ValueError`` raised by
    ``make_9x9_z3_board`` for a malformed board propagates to the caller.
    """
    solver = Solver()
    z3_board = make_9x9_z3_board(board_text, solver)
    assert_uniq(solver, z3_board)

    if solver.check() != sat:
        return None

    model = solver.model()
    return [
        [model.evaluate(cell).as_long() for cell in cells]  # type: ignore
        for cells in z3_board
    ]