import numpy as np
from z3 import *

def KenKenSolver(input_sample):
    """Solve a KenKen puzzle by encoding it as Z3 integer constraints.

    Args:
        input_sample: dict with keys
            'input_board' -- square N x N grid; a 0 marks an empty cell and
                any other value pins that cell to the given digit.
            'subgrids'    -- list of cages, each a list of flat cell indices
                (index = row * N + col).
            'operations'  -- one operator per cage: '+', '-', '*', '%'
                (division) or '?' (first cell fixed to the target).
                Any other operator adds no constraint.
            'targets'     -- one target value per cage.

    Returns:
        A one-element list containing the solved board (list of row lists of
        ints), or None when the constraints are unsatisfiable.
    """
    input_board = np.array(input_sample['input_board'])
    cages = input_sample['subgrids']
    operations = input_sample['operations']
    targets = input_sample['targets']
    assert input_board.shape[0] == input_board.shape[1]
    N = input_board.shape[0]

    solver = Solver()
    variables = [[Int(f"x_{i}_{j}") for j in range(N)] for i in range(N)]

    ### initial constraints: pin given cells, bound empty cells to 1..N
    for i in range(N):
        for j in range(N):
            # int() turns the numpy scalar into a plain Python int for Z3.
            given = int(input_board[i][j])
            if given != 0:
                solver.add(variables[i][j] == given)
            else:
                solver.add(variables[i][j] >= 1, variables[i][j] <= N)

    ### row constraints
    for i in range(N):
        solver.add(Distinct([variables[i][j] for j in range(N)]))

    ### column constraints
    for i in range(N):
        solver.add(Distinct([variables[j][i] for j in range(N)]))

    ### cage constraints
    for cage, operation, target in zip(cages, operations, targets):
        numbers = [variables[ele // N][ele % N] for ele in cage]
        constraint = _cage_constraint(numbers, operation, target)
        # Unrecognised operators contribute nothing, matching earlier behaviour.
        if constraint is not None:
            solver.add(constraint)

    if solver.check() != sat:
        return None
    model = solver.model()
    output_board = [[model.evaluate(variables[i][j]).as_long() for j in range(N)]
                    for i in range(N)]
    return [output_board]


def _cage_constraint(numbers, operation, target):
    """Return the Z3 constraint for one cage, or None for an unknown operator.

    ``numbers`` are the cage's Z3 Int cells. '-' and '%' read only the first
    two cells (they assume two-cell cages); '?' reads only the first cell.
    The Z3 term is kept on the left of every arithmetic operator so a numpy
    scalar ``target`` cannot hijack operator dispatch.
    """
    if operation == '?':
        return numbers[0] == target
    if operation == '-':
        return Or(numbers[1] - numbers[0] == target,
                  numbers[0] - numbers[1] == target)
    if operation == '%':
        # Z3 `/` on Int terms is integer division (5 / 2 == 2), which would
        # accept inexact quotients. Cross-multiplying demands exact division;
        # it is safe because every cell value is at least 1.
        return Or(numbers[1] == numbers[0] * target,
                  numbers[0] == numbers[1] * target)
    if operation == '+':
        return Sum(numbers) == target
    if operation == '*':
        return Product(numbers) == target
    return None

def MySolver():
    """Factory hook: hand back the callable that solves a KenKen input sample."""
    solve = KenKenSolver
    return solve

if __name__ == '__main__':
    # Demo 1: 3x3 board using addition cages and fixed single-cell cages.
    sample_3x3 = {
        'input_board': [
            [1, 0, 0],
            [3, 0, 0],
            [0, 0, 0],
        ],
        'subgrids': [[0, 1], [2], [3], [4, 5, 8], [6, 7]],
        'operations': ['+', '?', '?', '+', '+'],
        'targets': [3, 3, 3, 4, 5],
    }

    # Demo 2: 4x4 board exercising division, subtraction and product cages.
    sample_4x4 = {
        'input_board': [
            [0, 2, 0, 0],
            [0, 1, 0, 0],
            [3, 0, 2, 0],
            [0, 3, 0, 0],
        ],
        'subgrids': [
            [0, 1], [2, 6], [3], [4, 8],
            [5, 9], [7, 11], [10, 14, 15], [12, 13],
        ],
        'operations': ['%', '+', '?', '-', '-', '-', '*', '-'],
        'targets': [2, 7, 4, 1, 3, 2, 4, 1],
    }

    for sample in (sample_3x3, sample_4x4):
        print(KenKenSolver(input_sample=sample))
        