from z3 import *


def SurvoSolver(input_sample, **kwargs):
    """Solve a Survo puzzle with Z3.

    The board is an m x n grid to be filled with the integers 1..m*n, each
    used exactly once, such that every row and every column adds up to its
    prescribed sum.

    Args:
        input_sample: Mapping with the keys ``'input_board'`` (list of m rows
            of n ints, where 0 marks a cell that is not given),
            ``'row_sums'`` (indexed 0..m-1) and ``'column_sums'`` (indexed
            0..n-1).
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        ``[solved_board]`` where ``solved_board`` is an m x n list of ints,
        or ``[None]`` (after printing ``"Problem is unsolvable!"``) when no
        solution is found. If several solutions exist, whichever one Z3
        reports is returned.
    """
    input_board = input_sample['input_board']
    row_sums = input_sample['row_sums']
    column_sums = input_sample['column_sums']
    m, n = len(input_board), len(input_board[0])
    cell_count = m * n

    # Any valid filling is a permutation of 1..m*n, so the row sums and the
    # column sums must each total 1 + 2 + ... + m*n. Rejecting a mismatch
    # here yields the same result as an unsat answer without invoking Z3.
    # Only the first m / n entries are read, exactly as the constraints
    # below do.
    expected_total = cell_count * (cell_count + 1) // 2
    row_total = sum(row_sums[i] for i in range(m))
    column_total = sum(column_sums[j] for j in range(n))
    if row_total != expected_total or column_total != expected_total:
        print("Problem is unsolvable!")
        return [None]

    solver = Solver()

    # One integer variable per cell.
    cells = [[Int('cell_%d_%d' % (i, j)) for j in range(n)] for i in range(m)]
    flat_cells = [cell for row in cells for cell in row]

    # Constraint: pin the given (non-zero) values.
    for i in range(m):
        for j in range(n):
            given = input_board[i][j]
            if given != 0:
                solver.add(cells[i][j] == given)

    # Constraint: 1 <= cell <= m*n.
    for cell in flat_cells:
        solver.add(cell >= 1, cell <= cell_count)

    # Constraint: each number from 1 to m*n occurs exactly once. With m*n
    # cells all bounded to 1..m*n, pairwise distinctness is equivalent and
    # avoids building m*n separate counting sums over every cell.
    # A single cell is trivially unique, so the constraint is skipped then.
    if cell_count > 1:
        solver.add(Distinct(*flat_cells))

    # Constraint: rows' sums.
    for i in range(m):
        solver.add(Sum(cells[i]) == row_sums[i])

    # Constraint: columns' sums.
    for j in range(n):
        solver.add(Sum([cells[i][j] for i in range(m)]) == column_sums[j])

    # Solve the problem.
    if solver.check() != sat:
        print("Problem is unsolvable!")
        return [None]

    model = solver.model()
    solved_board = [[model[cell].as_long() for cell in row] for row in cells]
    return [solved_board]

def MySolver():
    """Return the solver callable provided by this module (``SurvoSolver``)."""
    solver_fn = SurvoSolver
    return solver_fn

if __name__ == "__main__":
    # Demo run on a 2x3 puzzle: the 3 and the 6 are given, 0 marks a cell
    # the solver has to fill in.
    demo_board = [
        [0, 0, 3],
        [0, 6, 0],
    ]
    input_sample = {
        'input_board': demo_board,
        'row_sums': [9, 12],
        'column_sums': [9, 7, 5],
    }
    solution = SurvoSolver(input_sample)
    print(solution)