from z3 import *
import math
import numpy as np

def SudokuSolver(board, **kwargs):
    """Solve a square Sudoku board with the Z3 SMT solver.

    Args:
        board: An n x n grid indexable as ``board[i][j]``. ``0`` marks an
            empty cell; any other entry is treated as a given digit. Entries
            are expected to be integers (givens are converted with ``int``).
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        ``None`` when the constraints are unsatisfiable; otherwise a
        one-element list holding the solved grid as an n x n NumPy array.

    Note:
        The box side is ``int(math.sqrt(n))``. When ``n`` is not a perfect
        square, the box constraints only cover the top-left
        ``sub_n*sub_n`` by ``sub_n*sub_n`` region of the grid.
    """
    n = len(board)
    sub_n = int(math.sqrt(n))

    # One integer variable per cell, named with 1-based coordinates.
    X = [[Int("x_%s_%s" % (i + 1, j + 1)) for j in range(n)]
         for i in range(n)]

    # Every cell holds a digit between 1 and n.
    cells_c = [And(1 <= X[i][j], X[i][j] <= n)
               for i in range(n) for j in range(n)]

    # Each row has a digit only once.
    rows_c = [Distinct(X[i]) for i in range(n)]

    # Each column has a digit only once.
    cols_c = [Distinct([X[j][i] for j in range(n)])
              for i in range(n)]

    # Each sub_n x sub_n box has a digit only once.
    sq_c = [Distinct([X[sub_n * i0 + i][sub_n * j0 + j]
                      for i in range(sub_n) for j in range(sub_n)])
            for i0 in range(sub_n) for j0 in range(sub_n)]

    # Pin the given cells. Blanks (0) are skipped at the Python level rather
    # than being wrapped in a trivial If(..., True, ...) term, and values are
    # converted to built-in ints so the constraint does not depend on how z3
    # coerces non-builtin scalars (e.g. NumPy integers) -- presumably the
    # reason array-backed boards can fail; verify against the z3 version used.
    instance_c = [X[i][j] == int(board[i][j])
                  for i in range(n) for j in range(n)
                  if board[i][j] != 0]

    s = Solver()
    s.add(cells_c + rows_c + cols_c + sq_c + instance_c)
    if s.check() != sat:
        return None

    m = s.model()
    r = [[m.evaluate(X[i][j]).as_long() for j in range(n)]
         for i in range(n)]
    return [np.array(r)]

def MySolver():
    """Return the solver callable provided by this module (``SudokuSolver``)."""
    solver = SudokuSolver
    return solver

if __name__ == '__main__':
    # Demo run on a 4x4 puzzle; 0 marks a blank cell.
    input_board = [
        [0, 2, 3, 1],
        [3, 0, 4, 2],
        [0, 0, 1, 3],
        [0, 0, 0, 4],
    ]
    result = SudokuSolver(input_board)
    print(result)