"""
Code modified from https://www.kaggle.com/code/jamesmcguigan/cryptarithmetic-solver
"""
from z3 import *

"""
input_words when combined with operator gives target word
eg. 
input_words = [SEND, MORE] (in our examples the size of this list will be limited to two)
target_word = MONEY
operator = +
SEND + MORE = MONEY
limit determines the maximum number of solutions to find for this input,
unique determines whether the letters must correspond to unique digits
"""


def CryptarithmeticSolver(input_sample, **kwargs):
    """Solve a cryptarithmetic puzzle such as SEND + MORE == MONEY with Z3.

    Args:
        input_sample: Mapping with three entries: ``'operands'`` (iterable of
            words), ``'operator'`` (string placed between the operands) and
            ``'target'`` (the word the combined operands must equal).
        **kwargs: Optional settings.
            limit: Maximum number of solutions to collect. ``None`` (the
                default) or any other falsy value means no cap.
            unique: When true (the default), distinct letters must take
                distinct digits. The constraint is only applied when the
                puzzle has at most 10 different letters.

    Returns:
        A list with one dict per solution, mapping each letter to the value
        Z3's model reports for it (the model's own value objects; they are
        not converted to Python ints). The list is empty when the solver
        never reports ``sat``.
    """
    # Imported locally and aliased: the local name ``operator`` below holds
    # the puzzle's operator string and would shadow the stdlib module.
    import functools
    import operator as op_funcs

    input_words = list(input_sample['operands'])
    target_word = input_sample['target']
    operator = input_sample['operator']
    limit = kwargs.get('limit')
    unique = kwargs.get('unique', True)
    solver = Solver()

    all_words = input_words + [target_word]  # every word in the equation
    letters = {ch: Int(ch) for ch in "".join(all_words)}  # letter -> Z3 symbol
    words = {w: Int(w) for w in all_words}  # word -> Z3 symbol

    # Constraint: every letter stands for a single decimal digit.
    for symbol in letters.values():
        solver.add(0 <= symbol, symbol <= 9)

    # Constraint: letters must be unique (optional). With more than 10
    # letters this is skipped, since 10 digits cannot cover them distinctly.
    if unique and len(letters) <= 10:
        solver.add(Distinct(*letters.values()))

    # Constraint: first letter of words must not be zero.
    for word in words:
        solver.add(letters[word[0]] != 0)

    # Constraint: each word symbol equals the decimal value of its letters.
    for word, word_symbol in words.items():
        digits = [letters[ch] for ch in reversed(word)]  # least significant first
        solver.add(word_symbol == Sum(*[
            digit * 10**place for place, digit in enumerate(digits)
        ]))

    # Constraint: problem definition as defined by input.
    # The usual arithmetic operators are applied directly to the Z3 symbols
    # (left to right, matching how Python evaluates the joined expression),
    # so no caller-supplied text is evaluated and words do not need to be
    # valid Python identifiers.
    arithmetic = {
        '+': op_funcs.add,
        '-': op_funcs.sub,
        '*': op_funcs.mul,
        '/': op_funcs.truediv,
        '%': op_funcs.mod,
    }
    combine = arithmetic.get(operator.strip()) if input_words else None
    if combine is not None:
        combined = functools.reduce(combine, [words[w] for w in input_words])
        solver.add(combined == words[target_word])
    else:
        # NOTE(review): fallback kept for backward compatibility with other
        # operator strings. This eval()s text built from the input, so it
        # must never be reached with untrusted words/operators.
        eval_string = f"{operator.join(input_words)}=={target_word}"
        solver.add(eval(eval_string, None, words))

    solutions = []
    while str(solver.check()) == 'sat':
        model = solver.model()  # fetch once per solution instead of per symbol
        solutions.append({letter: model[symbol] for letter, symbol in letters.items()})
        # Exclude this solution from later checks. Blocking on the word values
        # is enough: a word's value and length fix the digit of each letter.
        solver.add(Or(*[symbol != model[symbol] for symbol in words.values()]))
        if limit and len(solutions) >= limit:
            break

    return solutions

def MySolver():
    """Return the solver callable (the CryptarithmeticSolver function itself)."""
    solver_fn = CryptarithmeticSolver
    return solver_fn

if __name__ == '__main__':
    # Demo: each entry pairs a puzzle with the maximum number of solutions
    # to print for it.
    demo_puzzles = [
        ({'operands': ["SEND", "MORE"], 'target': "MONEY", 'operator': '+'}, 1),
        ({'operands': ["ONE", "TWO"], 'target': "THREE", 'operator': '*'}, 10),
    ]
    for puzzle, max_solutions in demo_puzzles:
        print(CryptarithmeticSolver(input_sample=puzzle, unique=True, limit=max_solutions))
    