from z3 import *

def makeIntVar(sol, name, min_val, max_val):
    """Create a Z3 integer variable *name* bounded to ``[min_val, max_val]``.

    The two bound constraints are added to *sol*; the new variable is
    returned so the caller can build further constraints on it.
    """
    var = Int(name)
    lower_bound = var >= min_val
    upper_bound = var <= max_val
    sol.add(lower_bound, upper_bound)
    return var

def element(sol, ix, x, v, n=None):
    """Constrain ``v == x[ix]`` where ``ix`` may be a Z3 integer expression.

    A Python list of Z3 terms cannot be indexed by a symbolic index, so the
    constraint is decomposed into one implication per position:
    ``ix == i  ->  v == x[i]``.

    Args:
        sol: Z3 solver that receives the constraints.
        ix: Index expression (Z3 integer term or plain int).
        x: Sequence of Z3 terms to select from.
        v: Term or value that the selected entry must equal.
        n: Number of leading entries of ``x`` that ``ix`` may select;
            defaults to ``len(x)``.
    """
    if n is None:
        n = len(x)
    # Without an explicit domain, an ``ix`` outside 0..n-1 satisfies every
    # implication below vacuously and leaves ``v`` unconstrained.
    sol.add(ix >= 0, ix < n)
    for i in range(n):
        sol.add(Implies(ix == i, v == x[i]))

def HidatoSolver(input_sample, **kwargs):
    """Solve a Hidato puzzle with Z3.

    A board of ``r`` rows and ``c`` columns must be filled with the numbers
    ``1 .. r*c``, each exactly once, so that every number ``k < r*c`` is
    adjacent (horizontally, vertically or diagonally) to ``k + 1``.

    Args:
        input_sample: Board as a list of rows; the column count is taken
            from the first row. Cells holding a positive integer are fixed
            clues; cells ``<= 0`` are blanks to fill in.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        ``[board]``, a one-element list holding the solved board as a list
        of rows of ints. Returns ``None`` when Z3 does not report ``sat``.
    """
    board = input_sample
    s = Solver()
    num_rows = len(board)
    num_cols = len(board[0])
    n_cells = num_rows * num_cols

    # One variable per cell: the number placed in that cell.
    x = {
        (i, j): makeIntVar(s, f"x({i},{j})", 1, n_cells)
        for i in range(num_rows)
        for j in range(num_cols)
    }
    # n_cells distinct values drawn from 1..n_cells form a permutation.
    s.add(Distinct(list(x.values())))

    # Pin the given clues.
    for (i, j), var in x.items():
        if board[i][j] > 0:
            s.add(var == board[i][j])

    # Path rule: every cell except the one holding n_cells must have a
    # king-move neighbour holding its value + 1. Because the values are a
    # permutation, this is equivalent to "k and k + 1 are adjacent for every
    # k < n_cells". It needs only O(cells) constraints and no auxiliary
    # variables, whereas an element-based encoding needs O(cells^2).
    for (i, j), var in x.items():
        successors = [
            x[(i + di, j + dj)] == var + 1
            for di in (-1, 0, 1)
            for dj in (-1, 0, 1)
            if (di, dj) != (0, 0) and (i + di, j + dj) in x
        ]
        s.add(Or(var == n_cells, *successors))

    if s.check() != sat:
        return None

    m = s.model()
    output_board = [
        [m.eval(x[(i, j)]).as_long() for j in range(num_cols)]
        for i in range(num_rows)
    ]
    return [output_board]
        
        
        
def MySolver():
    """Return the solver entry point exposed by this module (``HidatoSolver``)."""
    solver_fn = HidatoSolver
    return solver_fn
    
    
if __name__ == '__main__':
    # Small 3x3 demo board; 0 marks a blank cell to be filled in.
    input_sample = [
        [6, 0, 9],
        [0, 2, 8],
        [1, 0, 0],
    ]
    print(HidatoSolver(input_sample))