from z3 import *
from collections import Counter

def SummleSolver(input_sample, **kwargs):
    """Find a chain of ``n`` binary integer equations that reaches ``target``.

    Equation ``i`` has the form ``left_i <op_i> right_i = result_i``.  Every
    operand slot is filled either by one of the input numbers or by the result
    of an *earlier* equation; each input number and each intermediate result
    is used exactly once, and the result of the last equation must equal the
    target.  Division must be exact (``result * right == left``, ``right != 0``).

    Args:
        input_sample: dict with keys
            ``"n"``         -- number of equations,
            ``"operators"`` -- operator symbols drawn from ``+ - * /``; each
                               symbol is used exactly as many times as it
                               appears in this list,
            ``"target"``    -- required result of the final equation,
            ``"numbers"``   -- integers available as operands.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        ``[equations]`` -- a one-element list holding the equation strings of
        one solution (e.g. ``"4 + 3 = 7"``), or ``None`` when no solution
        exists (or the solver cannot decide).
    """
    n = input_sample["n"]  # number of equations
    operators = input_sample["operators"]  # operator symbols to be used
    target = input_sample["target"]  # required result of the last equation
    numbers = input_sample["numbers"]  # numbers available as operands

    # There are 2n operand slots.  Under the use-each-exactly-once rule they can
    # only be filled by the len(numbers) inputs plus the n-1 reused results, so
    # every slot is covered only when len(numbers) + (n - 1) == 2n.  With fewer
    # numbers some operand variables would be left unconstrained and Z3 could
    # invent arbitrary operands; with more, Distinct below is unsatisfiable.
    # n < 1 has no final result, and no operators means no equation can hold.
    if n < 1 or not operators or len(numbers) != n + 1:
        return None

    solver = Solver()

    # Symbolic equation components.
    operands = [[Int(f"op_{i}_{j}") for j in range(2)] for i in range(n)]
    results = [Int(f"res_{i}") for i in range(n)]
    ops = [Int(f"ops_{i}") for i in range(n)]

    # Operand slots are numbered 0 .. 2n-1; slot j is operand j % 2 of
    # equation j // 2.  used_position[i] is the slot that holds numbers[i].
    used_position = [Int(f"number_pos_{i}") for i in range(len(numbers))]
    for i, value in enumerate(numbers):
        # The number must sit in a valid slot (single use is enforced by Distinct).
        solver.add(And([used_position[i] >= 0, used_position[i] < 2 * n]))
        # Whichever slot holds the number takes the number's value.
        for j in range(2 * n):
            solver.add(Implies(used_position[i] == j, operands[j // 2][j % 2] == value))

    # result_used_position[i] is the slot that receives results[i]; the final
    # result is never reused, so only the first n-1 results get a slot.
    result_used_position = [Int(f"result_pos_{i}") for i in range(n - 1)]
    for i in range(n - 1):
        first_later_slot = 2 * (i + 1)  # a result may only feed a later equation
        solver.add(And([result_used_position[i] >= first_later_slot,
                        result_used_position[i] < 2 * n]))
        for j in range(first_later_slot, 2 * n):
            solver.add(Implies(result_used_position[i] == j,
                               operands[j // 2][j % 2] == results[i]))

    # No two numbers/results share a slot, so each is used exactly once.
    solver.add(Distinct(result_used_position + used_position))

    # Operators are encoded as integers so they can be Z3 Int variables.
    operator_dict = {"+": 1, "-": 2, "*": 3, "/": 4}
    operator_inv = {1: "+", 2: "-", 3: "*", 4: "/"}

    # Constraints for each equation.
    for i in range(n):
        left, right = operands[i]
        result = results[i]
        op = ops[i]

        # The operator must be one of the allowed operators.
        solver.add(Or([op == operator_dict[o] for o in operators]))

        # Tie the result to the operands according to the chosen operator.
        conditions = []
        if "+" in operators:
            conditions.append(And(op == operator_dict["+"], result == left + right))
        if "*" in operators:
            conditions.append(And(op == operator_dict["*"], result == left * right))
        if "-" in operators:
            conditions.append(And(op == operator_dict["-"], result == left - right))
        if "/" in operators:
            # Exact integer division: expressed multiplicatively to avoid
            # Z3's truncating integer division semantics.
            conditions.append(And(op == operator_dict["/"], result * right == left, right != 0))

        solver.add(Or(conditions))

    # Each operator symbol is used exactly as often as it appears in `operators`.
    operator_counts = Counter(operators)
    for operator, count in operator_counts.items():
        solver.add(Sum([If(ops[i] == operator_dict[operator], 1, 0) for i in range(n)]) == count)

    # The last result must be the target.
    solver.add(results[-1] == target)

    if solver.check() != sat:
        return None

    model = solver.model()

    def value_of(expr):
        # model_completion guarantees a concrete numeral, even for a term the
        # model left unassigned, so .as_long() never sees a symbolic value.
        return model.eval(expr, model_completion=True).as_long()

    equations = []
    for i in range(n):
        left = value_of(operands[i][0])
        right = value_of(operands[i][1])
        result = value_of(results[i])
        op = operator_inv[value_of(ops[i])]
        equations.append(f"{left} {op} {right} = {result}")
    return [equations]
    
def MySolver():
    """Return the solver callable exposed by this module (``SummleSolver``)."""
    solver_entry_point = SummleSolver
    return solver_entry_point

if __name__ == "__main__":
    # Demo puzzles: each one is solved with SummleSolver and the result printed.
    # NOTE(review): the fourth puzzle repeats the first verbatim -- possibly a
    # copy-paste slip; kept so the script's output is unchanged.
    demo_puzzles = (
        {'numbers': [4, 3, 5], 'target': 35, 'n': 2, 'operators': ['+', '*']},
        {'numbers': [10, 2, 7], 'target': 35, 'n': 2, 'operators': ['/', '*']},
        {'numbers': [3, 3, 3], 'target': 18, 'n': 2, 'operators': ['+', '*']},
        {'numbers': [4, 3, 5], 'target': 35, 'n': 2, 'operators': ['+', '*']},
        {
            'numbers': [25, 3, 6, 23, 2],
            'target': 29,
            'n': 4,
            'operators': ['*', '+', '-', '/'],
        },
    )
    for puzzle in demo_puzzles:
        print("Solution:", SummleSolver(puzzle))
