from z3 import *


def SujikoSolver(input_sample, **kwargs):
    """Solve a generalized (m x n) Sujiko puzzle with Z3.

    The board is filled with a permutation of 1..m*n such that every 2x2
    window of adjacent cells sums to the corresponding entry of
    ``quadrant_sums``.

    Args:
        input_sample: dict with keys
            ``'board'`` -- list of m rows, each a list of n ints; 0 marks an
            empty cell, any other value is a fixed clue.
            ``'quadrant_sums'`` -- sequence of (m-1)*(n-1) target sums, one
            per 2x2 window, ordered row-major by the window's top-left cell.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        ``[solution]``, where ``solution`` is an m x n list of ints (one of
        possibly several solutions). Returns ``None`` if Z3 does not report
        the puzzle as satisfiable; "Infeasible solution!" is printed in that case.

    Raises:
        AssertionError: if the board has fewer than two rows or columns, its
            rows differ in length, or the number of quadrant sums is not
            (m-1)*(n-1).
    """
    input_board = input_sample['board']
    quadrant_sums = input_sample['quadrant_sums']

    # Validate explicitly instead of with ``assert`` so the checks are not
    # stripped under ``python -O``; AssertionError is kept so callers that
    # relied on the original assert-based contract are unaffected.
    if len(input_board) <= 1:
        raise AssertionError("Board has only one row")
    m, n = len(input_board), len(input_board[0])
    if n <= 1:
        raise AssertionError("Board has only one column")
    # Without this, a longer row would have its extra clues silently ignored
    # and a shorter row would crash with IndexError while reading clues.
    if any(len(row) != n for row in input_board):
        raise AssertionError("All board rows must have the same length")
    n_quadrants = (m - 1) * (n - 1)
    if len(quadrant_sums) != n_quadrants:
        raise AssertionError(f"There should be {n_quadrants} Quadrant Constraints")

    solver = Solver()

    # One integer variable per board cell.
    x = [[Int('x_%d_%d' % (i, j)) for j in range(n)] for i in range(m)]
    cells = [cell for row in x for cell in row]

    # Values lie in 1..m*n and are pairwise distinct, so together they force
    # the board to be a permutation of 1..m*n.
    for cell in cells:
        solver.add(cell >= 1, cell <= m * n)
    solver.add(Distinct(cells))

    # Pin the given clues (0 means "empty").
    for i, row in enumerate(input_board):
        for j, value in enumerate(row):
            if value != 0:
                solver.add(x[i][j] == value)

    # Each 2x2 window must hit its target sum; quadrant_sums is laid out
    # row-major over the windows' top-left corners.
    for i in range(m - 1):
        for j in range(n - 1):
            window_sum = x[i][j] + x[i][j + 1] + x[i + 1][j] + x[i + 1][j + 1]
            solver.add(window_sum == quadrant_sums[i * (n - 1) + j])

    # Anything other than sat (unsat or unknown) is reported as infeasible.
    if solver.check() != sat:
        print("Infeasible solution!")
        return None

    model = solver.model()
    solution = [[model.evaluate(x[i][j]).as_long() for j in range(n)] for i in range(m)]
    return [solution]  # one of possibly several solutions

def MySolver():
    """Return the solver callable exposed by this module (``SujikoSolver``)."""
    solver_entry_point = SujikoSolver
    return solver_entry_point

if __name__ == "__main__":
    # Demo puzzles, solved in order; each result (or None) is printed.
    demo_samples = [
        {
            'board': [[0, 0],
                      [0, 4]],
            'quadrant_sums': [10],
        },
        {
            # One valid solution: [[3, 1, 2], [4, 6, 5]]
            'board': [[0, 0, 2],
                      [4, 0, 0]],
            'quadrant_sums': [14, 14],
        },
        {
            'board': [[4, 0, 0],
                      [0, 0, 0],
                      [0, 0, 9]],
            'quadrant_sums': [17, 18, 15, 15],
        },
    ]
    for sample in demo_samples:
        print(SujikoSolver(sample))