from solver import ThreePartitionSolver

def _verdict(result, reason=None):
    """Build a fresh ``{'result': ..., 'reason': ...}`` dict returned by the verifier."""
    return {
        'result': result,
        'reason': reason,
    }


def _check_partition(input_sample, output_sample):
    """Check that *output_sample* splits *input_sample* into equal-sum triplets.

    Called only after the solver has reported that such a partition exists.

    Args:
        input_sample: the input array of numbers.
        output_sample: the candidate partition, expected to be a list/tuple of
            triplets (each a list/tuple of numbers).

    Returns:
        The verifier's result dict (see ``_verdict``).
    """
    n = len(input_sample)
    # The solver said a partition exists, which is only possible when the
    # length is a multiple of 3; treat a violation as an internal error.
    assert n % 3 == 0
    number_of_triplets = n // 3
    # An empty input has zero triplets: avoid dividing by zero. The target sum
    # is then irrelevant, since the multiset check below rejects any triplet.
    triplet_sum = sum(input_sample) // number_of_triplets if number_of_triplets else 0

    if not isinstance(output_sample, (list, tuple)):
        return _verdict(False, 'Output should be NO or a list of triplets')

    # Every triplet must be checked; returning success is deferred until
    # after the loop.
    for triplet in output_sample:
        if not isinstance(triplet, (list, tuple)):
            return _verdict(False, f'Each triplet should be a list of numbers, got: {triplet!r}')
        shown = ' '.join(map(str, triplet))
        # Check the size first so a wrong-sized group is reported as such
        # rather than as a sum mismatch.
        if len(triplet) != 3:
            return _verdict(False, f"The partition: ({shown}) should have 3 elements")
        if sum(triplet) != triplet_sum:
            return _verdict(False, f"The partition: ({shown}) does not have the required sum")

    # The triplets must use exactly the input's elements (as a multiset).
    # A flattening comprehension is linear and works for list or tuple
    # triplets, unlike ``sum(..., start=[])``.
    output_elements = sorted(x for triplet in output_sample for x in triplet)
    if output_elements != sorted(input_sample):
        return _verdict(False, 'Output Triplets should have all elements of the input array, nothing more nothing less')
    return _verdict(True)


def ThreePartitionVerifier(input_sample, output_sample, **kwargs):
    """Verify a candidate answer to the 3-partition problem.

    The reference ``ThreePartitionSolver`` decides whether *input_sample* can
    be split into triplets that all have the same sum. The candidate answer is
    accepted when it agrees with that decision:

    * if the solver answers ``'NO'``, *output_sample* must be ``'NO'``;
    * otherwise *output_sample* must be a list of triplets in which every
      triplet has exactly three elements, every triplet sums to
      ``sum(input_sample) // (len(input_sample) // 3)``, and together they
      use exactly the elements of *input_sample*.

    Args:
        input_sample: the input array of numbers.
        output_sample: the candidate answer, either ``'NO'`` or a list of
            triplets.
        **kwargs: forwarded unchanged to ``ThreePartitionSolver``.

    Returns:
        A dict ``{'result': bool, 'reason': str | None}``. ``reason`` explains
        a rejection and is ``None`` on success.
    """
    # Only the first element of the solver's return value is used; it is
    # compared against the literal 'NO'.
    answer = ThreePartitionSolver(input_sample, **kwargs)[0]
    if answer == 'NO':
        if output_sample != 'NO':
            return _verdict(False, 'Input array does not have a partition into triplets of the same sum')
        return _verdict(True)

    if output_sample == 'NO':
        return _verdict(False, 'Input array has a partition into triplets of the same sum')
    return _check_partition(input_sample, output_sample)
            

def MyVerifier():
    """Return the verification callable used for the 3-partition task."""
    verifier_fn = ThreePartitionVerifier
    return verifier_fn

if __name__ == '__main__':
    # Each demo pairs an input array with a candidate answer; the trailing
    # comment records the verdict the verifier is expected to print.
    demo_cases = (
        ([1, 2, 5, 6, 7, 9], [[1, 5, 9], [2, 6, 7]]),    # True
        ([1, 2, 5, 6, 7, 9], [[1, 6, 9], [2, 5, 7]]),    # False: wrong sums
        ([1, 2, 5, 6, 7, 9], "NO"),                      # False: a partition exists
        ([1, 2, 5, 6, 7, 9], [[1, 6, 8], [2, 6, 7]]),    # False: 8 is not in the input
        ([1, 1, 1, 1, 1, 5], [[1, 1, 1, 1, 1], [5]]),    # False
        ([1, 1, 1, 1, 1, 5], 'NO'),                      # True
        ([1, 2, 3, 1, 2, 3], [[1, 2, 1, 2], [3, 3]]),    # False: groups are not triplets
    )
    for arr, candidate in demo_cases:
        print(ThreePartitionVerifier(arr, candidate))