from z3 import *


def ShinroSolver(input_sample, **kwargs):
    """Solve a Shinro puzzle with the Z3 SMT solver.

    Args:
        input_sample: Mapping with keys
            ``'input_board'`` -- square grid (N rows of N entries). An entry
            whose ``str()`` is ``'1'`` is a fixed marble, ``'0'`` is an open
            cell, and one of the eight compass strings (``'N'``, ``'S'``,
            ``'E'``, ``'W'``, ``'NE'``, ``'NW'``, ``'SE'``, ``'SW'``) is an
            arrow.
            ``'row_counts'`` / ``'column_counts'`` -- required number of
            marbles in each row / column, indexed 0..N-1.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        ``[solved_board]`` -- a one-element list holding a new board in which
        every open ``'0'`` cell chosen as a marble is replaced by ``1`` and
        every other entry is copied unchanged from ``input_board``; or
        ``None`` (after printing a message) when Z3 does not report ``sat``.
    """
    # Compass arrow -> (row step, column step). Also serves as the set of
    # arrow labels, so the clue pass and the ray pass cannot disagree.
    directions = {
        'N': (-1, 0), 'S': (1, 0), 'E': (0, 1), 'W': (0, -1),
        'NE': (-1, 1), 'NW': (-1, -1), 'SE': (1, 1), 'SW': (1, -1),
    }

    input_board = input_sample['input_board']
    row_counts = input_sample['row_counts']
    column_counts = input_sample['column_counts']

    n = len(input_board)
    # Normalise every entry once so ints (0, 1) and strings ('0', '1') match.
    labels = [[str(input_board[i][j]) for j in range(n)] for i in range(n)]
    cells = [[Bool(f'cell_{i}_{j}') for j in range(n)] for i in range(n)]
    solver = Solver()

    # Given marbles must stay; arrow squares can never hold a marble.
    for i in range(n):
        for j in range(n):
            if labels[i][j] == '1':
                solver.add(cells[i][j])
            elif labels[i][j] in directions:
                solver.add(Not(cells[i][j]))

    # Each row / column must contain exactly the requested number of marbles.
    for k in range(n):
        row_total = Sum([If(cells[k][j], 1, 0) for j in range(n)])
        col_total = Sum([If(cells[i][k], 1, 0) for i in range(n)])
        solver.add(row_total == row_counts[k])
        solver.add(col_total == column_counts[k])

    # Every arrow must point at one or more marbles along its ray to the edge.
    for i in range(n):
        for j in range(n):
            step = directions.get(labels[i][j])
            if step is None:
                continue
            di, dj = step
            ray = []
            x, y = i + di, j + dj
            while 0 <= x < n and 0 <= y < n:
                # Arrow squares are already forced empty, so they can't help.
                if labels[x][y] not in directions:
                    ray.append(cells[x][y])
                x += di
                y += dj
            # z3py's Or() raises when given no arguments. An arrow with no
            # candidate squares can never be satisfied, so assert False.
            solver.add(Or(*ray) if ray else BoolVal(False))

    if solver.check() != sat:
        print("Constraints are unsatisfiable!")
        return None

    model = solver.model()
    solved_board = []
    for i in range(n):
        row = []
        for j in range(n):
            # model_completion=True forces a concrete True/False even if Z3
            # left the variable unassigned; otherwise is_true() could see a
            # symbolic term.
            placed = is_true(model.evaluate(cells[i][j], model_completion=True))
            row.append(1 if labels[i][j] == '0' and placed else input_board[i][j])
        solved_board.append(row)
    return [solved_board]

def MySolver():
    """Return the solver entry point for this puzzle type.

    Returns:
        The ``ShinroSolver`` function object itself (it is not called here).
    """
    solver_fn = ShinroSolver
    return solver_fn


if __name__ == "__main__":
    # Example 8x8 puzzle: 0 = open cell, 1 = given marble, compass strings
    # are arrows (see ShinroSolver for the encoding).
    input_sample = {
        'column_counts': [2, 2, 1, 2, 1, 1, 1, 2],
        'row_counts': [1, 1, 2, 1, 2, 2, 2, 1],
        'input_board': [
            [0, 0, 0, 'E', 0, 0, 0, 0],
            [0, 0, 0, 0, 'W', 0, 0, 0],
            [0, 0, 0, 0, 0, 'SW', 0, 'NW'],
            [0, 0, 'S', 0, 'NW', 0, 0, 0],
            ['NE', 0, 0, 'S', 'W', 'N', 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 'NE', 0, 0, 'NW', 0],
            [0, 0, 0, 0, 0, 0, 0, 0],
        ],
    }
    kwargs = {
        'n': 8,
    }
    solutions = ShinroSolver(input_sample=input_sample, **kwargs)
    print(solutions)
    if solutions is None:
        # ShinroSolver has already printed the reason; signal failure.
        raise SystemExit(1)
    # ShinroSolver returns a *list* of boards, so render the first board as a
    # tab-separated grid instead of iterating the outer list.
    for row in solutions[0]:
        print("\t".join(str(ele) for ele in row), end="\n\n")