import numpy as np
from z3 import *

def _same_cell_value(a, b):
    """Return True when two board cells hold the same value.

    Cells are compared by their string form first (which keeps boards of
    numeric strings such as ``'4'`` working), then numerically, so that mixed
    int/float boards compare sensibly (``4 == 4.0``, ``0 == 0.0``). Values that
    cannot be converted to ``float`` only match by string form.
    """
    if str(a) == str(b):
        return True
    try:
        return float(a) == float(b)
    except (TypeError, ValueError):
        return False


def MagicSquareVerifier(input_board, output_board, **kwargs):
    """Verify that ``output_board`` is a magic square completing ``input_board``.

    Args:
        input_board: 2D grid of clues. A cell equal to 0 is empty; any other
            value must appear unchanged at the same position in
            ``output_board``.
        output_board: the proposed N x N solution.
        **kwargs: accepted and ignored (presumably so all verifiers share one
            call signature -- confirm with the caller).

    Returns:
        A dict ``{'result': bool, 'reason': str | None}``. ``reason`` is
        ``None`` on success; otherwise it describes the first failed check.
        The checks run in this order:

        1. both boards are rectangular grids of the same shape;
        2. the output is square;
        3. every non-empty clue is preserved;
        4. the numbers 1..N**2 each appear exactly once;
        5. every row, column and both diagonals sum to N * (N**2 + 1) / 2.
    """
    try:
        input_board = np.array(input_board)
        output_board = np.array(output_board)
    except ValueError:
        # Recent NumPy raises on ragged nested lists instead of building an
        # object array; report that as a failed verification, not a crash.
        return {
            'result': False,
            'reason': "Input and Output boards must be rectangular grids"
        }

    ### shape constraints
    if input_board.shape != output_board.shape:
        return {
            'result': False,
            'reason': "Input and Output board sizes don't match"
        }
    # Anything that is not 2D (e.g. a flat list) cannot be a square board;
    # testing ndim first also avoids an IndexError on shape[1].
    if output_board.ndim != 2 or output_board.shape[0] != output_board.shape[1]:
        return {
            'result': False,
            'reason': "Output Board must be a square"
        }
    n = output_board.shape[0]

    ### initial constraints: every non-empty clue must survive unchanged
    for i, j in np.ndindex(input_board.shape):  # row-major, like nested loops
        clue = input_board[i, j]
        if _same_cell_value(clue, 0):
            continue  # empty cell, the solver may fill it freely
        if not _same_cell_value(output_board[i, j], clue):
            return {
                'result': False,
                'reason': f"Original Numbers must remain on the board, Cell at ({i}, {j}) was changed"
            }

    ### contents: the board must be a permutation of 1..n**2
    values = output_board.ravel()
    contains_all_elements = np.all(np.isin(values, np.arange(1, n**2 + 1)))
    if not contains_all_elements or len(np.unique(values)) != n**2:
        return {
            'result': False,
            'reason': f"Output Board Does not contain all elements from 1 to {n**2} exactly once"
        }

    ### sums: n * (n**2 + 1) is always even, so integer division is exact
    magic_number = n * (n**2 + 1) // 2
    row_sums = output_board.sum(axis=1)
    col_sums = output_board.sum(axis=0)
    if np.any(row_sums != magic_number) or np.any(col_sums != magic_number):
        return {
            'result': False,
            'reason': "All Rows and All Columns must sum to the same number"
        }
    main_diagonal = np.trace(output_board)
    anti_diagonal = np.trace(np.fliplr(output_board))
    if main_diagonal != magic_number or anti_diagonal != magic_number:
        return {
            'result': False,
            'reason': "Both the diagonals don't sum to expected number"
        }

    return {
        'result': True,
        'reason': None
    }



def MyVerifier():
    """Return the magic-square verifier function defined in this module.

    NOTE(review): presumably an external harness calls this to discover the
    verifier -- confirm against the caller.
    """
    verifier = MagicSquareVerifier
    return verifier
    

if __name__ == "__main__":
    # Demo: three 3x3 attempts against the same clues, then one 4x4 case.
    clues_3x3 = [[4, 0, 2],
                 [0, 0, 7],
                 [0, 0, 0]]
    demo_cases = [
        # valid completion of the clues
        (clues_3x3, [[4, 9, 2], [3, 5, 7], [8, 1, 6]]),
        # overwrites the clue at (0, 0)
        (clues_3x3, [[5, 5, 5], [5, 5, 5], [5, 5, 5]]),
        # keeps the clues but uses repeated / out-of-range numbers
        (clues_3x3, [[4, 9, 2], [-1, 9, 7], [4, -3, 6]]),
        # 4x4 case
        ([[1, 0, 4, 0], [8, 11, 0, 0], [0, 0, 0, 0], [0, 7, 0, 0]],
         [[1, 14, 4, 15], [8, 11, 5, 10], [13, 2, 16, 3], [12, 7, 9, 6]]),
    ]
    for board_in, board_out in demo_cases:
        print(MagicSquareVerifier(input_board=board_in, output_board=board_out))