import numpy as np
from z3 import *

def _neighbours(row, col, size):
    """Return the in-bounds orthogonal neighbours of (row, col) on a size x size grid."""
    steps = ((0, -1), (0, +1), (-1, 0), (+1, 0))
    return [
        (row + dr, col + dc)
        for dr, dc in steps
        if 0 <= row + dr < size and 0 <= col + dc < size
    ]


def NumberLinkSolver(input_sample, **kwargs):
    """Solve a square Numberlink puzzle with the Z3 SMT solver.

    Every cell is labelled with one of the clue numbers so that the two
    clues of each number are joined by a single orthogonal path of that
    number, paths never touch themselves, and the whole board is covered.

    Args:
        input_sample: Square 2-D grid (nested lists or array-like) of
            integers; 0 marks an empty cell, any other value is a clue.
            Each clue number is expected to appear exactly twice.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        A one-element list holding the solved board (list of lists of
        Python ints), or None when the board has no clues or the solver
        does not report the puzzle as satisfiable.

    Raises:
        AssertionError: if the board is not square (checked with ``assert``).
    """
    input_board = np.array(input_sample)
    assert input_board.shape[0] == input_board.shape[1]
    N = input_board.shape[0]

    # z3 coerces only built-in Python ints into integer constants; numpy
    # integer scalars are not built-in ints, so work on a plain-int copy.
    board = [[int(value) for value in row] for row in input_board]
    clues = {
        (i, j): board[i][j]
        for i in range(N)
        for j in range(N)
        if board[i][j] != 0
    }
    if not clues:
        # Without any clue there is no label a cell could take.
        return None
    numbers = sorted(set(clues.values()))

    # The first clue of each number in row-major order (dicts keep
    # insertion order) is the source its path distance is counted from.
    sources = {}
    for cell, number in clues.items():
        sources.setdefault(number, cell)
    source_cells = set(sources.values())

    solver = Solver()
    variables = [[Int(f"x_{i}_{j}") for j in range(N)] for i in range(N)]
    # Distance of each cell from its path's source, measured along the path.
    order = [[Int(f"d_{i}_{j}") for j in range(N)] for i in range(N)]

    for i in range(N):
        for j in range(N):
            value, dist = variables[i][j], order[i][j]
            neighbours = _neighbours(i, j, N)

            # Only labels that appear as clues are allowed; a plain 1..max
            # range would let unused labels fill cells as phantom loops.
            solver.add(Or([value == number for number in numbers]))

            # Endpoints have exactly one same-labelled neighbour and path
            # cells exactly two, which also stops a path from touching itself.
            same = [If(variables[a][b] == value, 1, 0) for a, b in neighbours]
            if (i, j) in clues:
                solver.add(value == clues[(i, j)])
                solver.add(Sum(*same) == 1)
            else:
                solver.add(Sum(*same) == 2)

            # The degree constraints alone still admit closed loops that are
            # detached from every endpoint. Each non-source cell must have a
            # same-labelled neighbour one step closer to the source. That rules
            # loops out, because the loop cell with the smallest distance
            # would have no such neighbour.
            solver.add(dist >= 0, dist < N * N)
            if (i, j) in source_cells:
                solver.add(dist == 0)
            else:
                solver.add(Or([
                    And(variables[a][b] == value, order[a][b] == dist - 1)
                    for a, b in neighbours
                ]))

    if solver.check() == sat:
        model = solver.model()
        output_board = [
            [model.evaluate(variables[i][j]).as_long() for j in range(N)]
            for i in range(N)
        ]
        return [output_board]
    return None
    
def MySolver():
    """Return the Numberlink solver callable itself (not the result of calling it)."""
    solver_fn = NumberLinkSolver
    return solver_fn


if __name__ == '__main__':
    # Each entry is (puzzle, expected solution); 0 marks an empty cell.
    test_cases = [
        # TC-1
        (
            [[0, 0, 0, 3, 0],
             [0, 4, 0, 1, 0],
             [0, 0, 2, 0, 0],
             [0, 4, 0, 3, 2],
             [0, 1, 0, 0, 0]],
            [[1, 1, 1, 3, 3],
             [1, 4, 1, 1, 3],
             [1, 4, 2, 3, 3],
             [1, 4, 2, 3, 2],
             [1, 1, 2, 2, 2]],
        ),
        # TC-2
        (
            [[0, 0, 0, 4, 0, 0, 0],
             [0, 3, 0, 0, 2, 5, 0],
             [0, 0, 0, 3, 1, 0, 0],
             [0, 0, 0, 5, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0],
             [0, 0, 1, 0, 0, 0, 0],
             [2, 0, 0, 0, 4, 0, 0]],
            [[2, 2, 2, 4, 4, 4, 4],
             [2, 3, 2, 2, 2, 5, 4],
             [2, 3, 3, 3, 1, 5, 4],
             [2, 5, 5, 5, 1, 5, 4],
             [2, 5, 1, 1, 1, 5, 4],
             [2, 5, 1, 5, 5, 5, 4],
             [2, 5, 5, 5, 4, 4, 4]],
        ),
        # TC-3
        (
            [[1, 2, 3, 0, 0],
             [0, 0, 0, 4, 0],
             [0, 4, 2, 0, 0],
             [0, 0, 0, 0, 0],
             [0, 0, 1, 3, 0]],
            [[1, 2, 3, 3, 3],
             [1, 2, 2, 4, 3],
             [1, 4, 2, 4, 3],
             [1, 4, 4, 4, 3],
             [1, 1, 1, 3, 3]],
        ),
    ]

    for index, (input_sample, expected_solution) in enumerate(test_cases, start=1):
        output_solution = NumberLinkSolver(input_sample=input_sample)
        # The solver returns None on failure, otherwise a list holding one board.
        assert output_solution is not None, f"TC-{index}: no solution found"
        # array_equal also checks shape. Plain `==` would silently broadcast
        # the (1, N, N) result against the (N, N) expectation.
        assert np.array_equal(np.array(output_solution[0]),
                              np.array(expected_solution)), \
            f"TC-{index}: unexpected solution {output_solution[0]}"

    print("Test Cases Passed")