from z3 import *

def ThreePartitionSolver(input_sample, **kwargs):
    """Solve the 3-partition problem for ``input_sample`` with Z3.

    Decide whether the multiset can be split into groups of exactly three
    elements that all share the same sum, and if so return one such split.

    Args:
        input_sample: Iterable of integers (the multiset). Its length must be
            a multiple of 3.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``[triplets]``, where ``triplets`` is a list of ``len // 3`` lists of
        three elements each (one of possibly many valid solutions), or
        ``["NO"]`` if no such partition exists (or the solver cannot decide).
        An empty multiset is trivially partitioned into zero triplets and
        yields ``[[]]``.

    Raises:
        ValueError: if the number of elements is not a multiple of 3.
    """
    multiset = list(input_sample)
    n = len(multiset)

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if n % 3 != 0:
        raise ValueError(
            f"3-partition requires a multiple of 3 elements, got {n}"
        )
    number_of_triplets = n // 3
    if number_of_triplets == 0:
        # Zero triplets partition the empty multiset; also avoids dividing by 0.
        return [[]]

    total = sum(multiset)
    if total % number_of_triplets != 0:
        # All triplets share one sum, so the total must split evenly;
        # otherwise the instance is unsatisfiable and Z3 need not be called.
        return ["NO"]
    triplet_sum = total // number_of_triplets

    # x_i is the index of the triplet to which the i-th element is assigned.
    variables = [Int(f"x_{i}") for i in range(n)]
    solver = Solver()

    for var in variables:
        solver.add(var >= 0, var < number_of_triplets)

    # Each triplet gets exactly three elements whose values sum to triplet_sum.
    for t in range(number_of_triplets):
        members = [var == t for var in variables]
        solver.add(Sum([If(m, 1, 0) for m in members]) == 3)
        solver.add(
            Sum([If(m, value, 0) for m, value in zip(members, multiset)])
            == triplet_sum
        )

    if solver.check() != sat:
        return ["NO"]

    model = solver.model()
    triplets = [[] for _ in range(number_of_triplets)]
    for var, value in zip(variables, multiset):
        # model_completion yields a concrete value even if Z3 eliminated var.
        triplet_number = model.eval(var, model_completion=True).as_long()
        triplets[triplet_number].append(value)
    return [triplets]  # one of possibly many solutions

def MySolver():
    """Return the solver callable this module exposes (ThreePartitionSolver)."""
    solver_fn = ThreePartitionSolver
    return solver_fn

if __name__ == "__main__":
    # Demo instances with the expected verdict; for YES instances one valid
    # partition is noted as well.
    demo_cases = [
        [1, 2, 5, 6, 7, 9],  # YES: (1, 9, 5), (2, 6, 7)
        [4, 4, 4, 6, 6, 6],  # NO
        [4, 5, 5, 5, 5, 6],  # YES: (5, 5, 5), (4, 5, 6)
        [1, 2, 3],           # YES: (1, 2, 3)
    ]
    for input_sample in demo_cases:
        print(ThreePartitionSolver(input_sample))
    
        