from z3 import *

def LatinSquareSolver(input_sample, **kwargs):
    """Complete a partially filled N x N Latin square with the Z3 SMT solver.

    Args:
        input_sample: An N x N grid given as a list of N rows. A value of 0
            marks an empty cell; any other value is a given that the solution
            must keep unchanged.
        **kwargs: Accepted for compatibility with the calling interface;
            ignored.

    Returns:
        A one-element list holding the completed grid (N rows of ints in
        1..N, with every row and every column containing each value exactly
        once), or None when no completion exists -- including when a given
        lies outside 1..N or when Z3 cannot decide the problem.
    """
    n = len(input_sample)
    board = input_sample

    solver = Solver()

    # One integer unknown per cell.
    cells = [[Int('cell_%s_%s' % (i, j)) for j in range(n)] for i in range(n)]

    for i in range(n):
        for j in range(n):
            cell = cells[i][j]
            # Every cell -- given or empty -- must lie in 1..N. Range-checking
            # the givens too makes an out-of-range clue unsatisfiable instead
            # of yielding a grid that is not a Latin square.
            solver.add(And(1 <= cell, cell <= n))
            if board[i][j] != 0:
                solver.add(cell == board[i][j])

    # N pairwise-distinct values drawn from 1..N means each value 1..N appears
    # exactly once in every row and in every column.
    for i in range(n):
        solver.add(Distinct(cells[i]))
    for j in range(n):
        solver.add(Distinct([cells[i][j] for i in range(n)]))

    if solver.check() == sat:
        model = solver.model()
        solved_board = [
            [model.evaluate(cells[i][j]).as_long() for j in range(n)]
            for i in range(n)
        ]
        # Wrapped in a list to keep the established return shape.
        return [solved_board]
    # unsat (or unknown): no valid completion was found.
    return None
    
def MySolver():
    """Return the solving callable exported by this module (LatinSquareSolver)."""
    solve = LatinSquareSolver
    return solve
    

if __name__ == '__main__':
    # Demo: a 6 x 6 puzzle; 0 marks an empty cell.
    input_sample = [
        [0, 0, 0, 0, 2, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 5, 3, 6, 0],
        [0, 0, 0, 0, 0, 0],
    ]
    result = LatinSquareSolver(input_sample)
    print(result)