from z3 import *
import numpy as np

def InshiNoHeyaSolver(input_sample, **kwargs):
    """
    Solve an Inshi no Heya puzzle with the z3 SMT solver.

    The board is an N x N grid in which every cell holds a value in 1..N and
    every row and every column contains pairwise-distinct values. The grid is
    partitioned into rooms; the product of the values in each room must equal
    that room's given product.

    Args:
        input_sample: Dict{
            input_board (list[list[int]]): partially filled square input board, empty cells denoted by zero
            rooms (list[list[int]]): each item in this list represents a room, by the indices of the cells (given in row major order starting from 0)
            room_products (list[int]): the products corresponding to each room in rooms
        }
        **kwargs: accepted for interface compatibility; unused.
    Returns:
        Union[list[list[list[int]]], NoneType]: a one-element list holding the
        solved board (list of rows) when the constraints are satisfiable,
        otherwise None.
    Raises:
        AssertionError: if rooms and room_products differ in length, or the
            board is not a 2-D square grid.
        ValueError: if a room references a cell index outside 0..N*N-1.
    """
    input_board, rooms, room_products = input_sample['input_board'], input_sample['rooms'], input_sample['room_products']
    assert len(rooms) == len(room_products)
    input_board = np.array(input_board)
    # ndim check first so a 1-D (e.g. empty) board fails the assert instead of
    # raising IndexError on shape[1].
    assert input_board.ndim == 2 and input_board.shape[0] == input_board.shape[1]

    N = input_board.shape[0]

    solver = Solver()
    variables = [[Int(f"x_{i}_{j}") for j in range(N)] for i in range(N)]

    ### cell domain and given clues
    for i in range(N):
        for j in range(N):
            # Every cell, including pre-filled clues, must lie in 1..N; an
            # out-of-range clue then makes the puzzle unsatisfiable instead of
            # leaking into the returned board.
            solver.add(variables[i][j] >= 1, variables[i][j] <= N)
            # Convert the numpy scalar to a plain int before handing it to z3.
            clue = int(input_board[i][j])
            if clue != 0:
                solver.add(variables[i][j] == clue)

    ### distinct constraints (each row and each column)
    for i in range(N):
        solver.add(Distinct([variables[i][j] for j in range(N)]))
        solver.add(Distinct([variables[j][i] for j in range(N)]))

    ### room and room products constraints
    for room, room_product in zip(rooms, room_products):
        cells = []
        for x in room:
            # Reject bad indices explicitly: a negative index would otherwise
            # silently wrap around via Python's negative list indexing.
            if not 0 <= x < N * N:
                raise ValueError(f"room cell index {x} is outside 0..{N * N - 1}")
            cells.append(variables[x // N][x % N])
        solver.add(Product(cells) == room_product)

    ### get solution from z3 solver
    if solver.check() == sat:
        model = solver.model()
        output_board = [
            [model.evaluate(variables[i][j], model_completion=True).as_long() for j in range(N)]
            for i in range(N)
        ]
        return [output_board]
    return None

def MySolver():
    """Return the callable that solves Inshi no Heya puzzles."""
    solver_fn = InshiNoHeyaSolver
    return solver_fn


if __name__ == "__main__":
    # Two 5x5 sample puzzles; each is solved and its result printed in order.
    demo_samples = [
        {
            'input_board': [
                [3, 0, 0, 5, 0],
                [0, 1, 0, 0, 0],
                [5, 0, 0, 1, 0],
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
            ],
            'rooms': [
                [0, 5],
                [1, 2],
                [3],
                [4, 9, 14],
                [6, 7],
                [8, 13],
                [10, 11],
                [12, 17, 22],
                [15, 20],
                [16, 21],
                [18, 19],
                [23, 24],
            ],
            'room_products': [6, 4, 5, 40, 3, 4, 15, 40, 4, 10, 6, 3],
        },
        {
            'input_board': [
                [3, 0, 0, 0, 4],
                [0, 5, 0, 1, 0],
                [0, 0, 0, 0, 0],
                [0, 2, 4, 3, 0],
                [2, 0, 0, 0, 5],
            ],
            'rooms': [
                [0, 1],
                [2, 3],
                [4, 9],
                [5, 6, 7],
                [8, 13, 18],
                [10, 11],
                [12, 17, 22],
                [14, 19],
                [15, 20],
                [16, 21],
                [23, 24],
            ],
            'room_products': [3, 10, 8, 60, 15, 4, 8, 3, 10, 6, 20],
        },
    ]
    for demo in demo_samples:
        print(InshiNoHeyaSolver(input_sample=demo))
       
            
    
    
    