import numpy as np
from z3 import *

# Integer codes for cell contents, used both in the input board and in the
# solved board returned by the solver:
#   EMPTY = 0, FILLED = 1, BALLOON = 2, IRON_BALL = 3
EMPTY, FILLED, BALLOON, IRON_BALL = range(4)

def DosunFuwariSolver(input_sample, **kwargs):
    """Solve a Dosun-Fuwari puzzle with Z3.

    Every region (subgrid) must receive exactly one balloon and exactly one
    iron ball. A balloon must be in the top row or directly below a filled
    cell or another balloon; an iron ball must be in the bottom row or
    directly above a filled cell or another iron ball. Cells that are open
    (EMPTY) in the input may become EMPTY, BALLOON or IRON_BALL, but never
    FILLED.

    Args:
        input_sample: dict with
            'input_board': 2-D list of ints in the EMPTY/FILLED/BALLOON/
                IRON_BALL encoding; every non-EMPTY entry is a fixed given.
            'subgrids': list of regions, each a list of flat cell indices in
                row-major order (``cell = row * n_cols + col``).
        **kwargs: ignored; accepted so all solvers share one call signature.

    Returns:
        ``[solution_board]`` (a one-element list holding a 2-D list of ints)
        when the puzzle is satisfiable, otherwise ``None``.

    Raises:
        ValueError: if ``input_board`` is not two-dimensional.
        IndexError: if a subgrid references a cell outside the board.
    """
    input_board = np.array(input_sample['input_board'])
    subgrids = input_sample['subgrids']
    if input_board.ndim != 2:
        raise ValueError(f"input_board must be 2-D, got shape {input_board.shape}")
    n_rows, n_cols = input_board.shape

    solver = Solver()
    variables = [[Int(f"x_{i}_{j}") for j in range(n_cols)] for i in range(n_rows)]

    ### initial constraints
    for i in range(n_rows):
        for j in range(n_cols):
            cell_var = variables[i][j]
            # Plain Python int so Z3 never sees a numpy scalar.
            given = int(input_board[i][j])
            if given != EMPTY:
                # Givens (filled cells or pre-placed balls) are fixed.
                solver.add(cell_var == given)
            else:
                # An open cell may hold nothing, a balloon or an iron ball, but
                # must never become FILLED. Stating this per cell (instead of
                # counting FILLED cells inside the subgrids) stays correct even
                # when filled cells are not listed in any subgrid, or a cell
                # appears in more than one subgrid.
                solver.add(cell_var >= EMPTY, cell_var <= IRON_BALL, cell_var != FILLED)

    ### subgrid constraints: exactly one balloon and one iron ball per region
    n_cells = n_rows * n_cols
    for subgrid_idx, subgrid in enumerate(subgrids):
        is_balloon = []
        is_iron_ball = []
        for cell in subgrid:
            # Reject bad indices explicitly; a negative index would otherwise
            # silently wrap around via Python's negative list indexing.
            if not 0 <= cell < n_cells:
                raise IndexError(
                    f"subgrid {subgrid_idx} references cell {cell}, "
                    f"outside the {n_rows}x{n_cols} board"
                )
            x, y = divmod(cell, n_cols)
            is_balloon.append(If(variables[x][y] == BALLOON, 1, 0))
            is_iron_ball.append(If(variables[x][y] == IRON_BALL, 1, 0))
        solver.add(Sum(*is_balloon) == 1)
        solver.add(Sum(*is_iron_ball) == 1)

    ### balloon and iron ball constraints
    for i in range(n_rows):
        for j in range(n_cols):
            if i > 0:
                # Balloons float: outside the top row they need a balloon or a
                # filled cell directly above them.
                solver.add(Implies(variables[i][j] == BALLOON,
                                   Or(variables[i - 1][j] == BALLOON, variables[i - 1][j] == FILLED)))
            if i < n_rows - 1:
                # Iron balls sink: outside the bottom row they need an iron
                # ball or a filled cell directly below them.
                solver.add(Implies(variables[i][j] == IRON_BALL,
                                   Or(variables[i + 1][j] == IRON_BALL, variables[i + 1][j] == FILLED)))

    if solver.check() != sat:
        return None
    model = solver.model()
    output_board = [
        [model.evaluate(variables[i][j], model_completion=True).as_long() for j in range(n_cols)]
        for i in range(n_rows)
    ]
    return [output_board]

def MySolver():
    """Return the puzzle solver callable (``DosunFuwariSolver``) itself, uncalled."""
    solver_fn = DosunFuwariSolver
    return solver_fn

if __name__ == '__main__':
    # Two small hand-made puzzles; each result is printed exactly as returned
    # by the solver (a one-element list of boards, or None if unsatisfiable).
    demo_samples = [
        {
            'input_board': [
                [0, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 1],
            ],
            'subgrids': [
                [0, 1],
                [2, 3, 6, 10],
                [4, 8, 12],
                [5, 9, 13, 14, 15],
                [7, 11],
            ],
        },
        {
            'input_board': [
                [0, 0, 0],
                [0, 0, 1],
                [0, 0, 0],
            ],
            'subgrids': [
                [0, 1, 2, 5],
                [3, 4],
                [6, 7, 8],
            ],
        },
    ]
    for demo_sample in demo_samples:
        print(DosunFuwariSolver(input_sample=demo_sample))