from opto import trace
import autogen
from opto.optimizers import OptoPrime
from opto import trace
from opto import utils
import missing_bits_agent_pt1
from resources import secret


# Module-level flag; get_feedback() sets it to True once a prediction equals
# the ground truth, and train_and_test() reads it to stop the loop early.
test_passed = False
# Expected output string that the agent's result is compared against.
# NOTE(review): looks like hex-encoded key material hard-coded in source —
# confirm this is intended to live in the repository.
test_ground_truth = 'c70881814053b5a4b252f76ba9f80011db91767cd16ca73012178cc70203010001028201000a9a25ba23342c18ab29339f4167dacc4cb21a383e14e13d0717efd006a92c55371be258ad865c44cfed2b27bd61a64571c846747a2463a1c8755e7530d86713803a197ce4592c55061551aca07fd1297a48dafef4a3e4e13c34870074406bdda67a81d2f18beab56b92a1d67126b05e858981a9b0f90deb7a2f6e8c3f08ab7837b51026a644430dfcd2b340d2c3907a79b35d114ea5703c3d115b1ba8ad19188fd6007db697509d68204399fec7cff693e9370915baff120ff5d3005598768d5a266de9cbc609e14022416d8691e08f01d33cd8a97a857de1c230d97ea499a698572fb36644ef57eaca346277a643a92e065ed4ae45033450becb510541698102818100e4188b37b163c203ae4f814ac457738b376afede66bd192c6e604ffc95a4defc5061e0a63223d0c6b414a9d1c61b11bc88c2301fb76784d68f6bcd65acff5f08aca28ff71a11b7b8966126f9574cff165017e3e209cd5fce2617a61cfdc2e01cf3efa4bcac4cd846a20def05db99ec5d6f856b13685ca6c9834cd340bcb32a2102818100f122e285b030a36974cda97c18c21e9b3afe00fc4bc3e7e67786abc9500ca2ff003641bc283330766bef927deb2fb59a2b29b97da92abd7478e7f4063def27895cf1ea869619249f8a37956a7dbba46fdbcd5ab2ca614764ff5c4611075b81cde7b84ea57cb491416a55ac49582b3eb611f42d36684e801ea60facafcd8569e7028181009c99e943346c1c3130c4ef2ca8fead9a8052da67513c3503574892c204efac14a90f7a1040ce6c317043db8fe5a68217e20108cf4db929efb18a3efbb6b9a1c2afdbf061a3a9546ccbfa29aaf8f0291665a8f151288fb35e32b2fbde5daf24f55bd9454e3031962f7065869c8bf7f2daf78d2f45a4da2fbc9c83eb6a3fe3f66102818100d28bba2c2ece1930f02aceecab157bf1e04d7e80cc8acec63de119c7da1c5e89b71ea8c0e87ce1b8d4bfc73229136c46a39f77f5cfaaa0e7bcd6130d0c30f05d91092dd4bbe45244f7435734377a8ab69cbe2e8e640512788e3b7c764bed4b66850839ab673aa685ee99c5a7f8d3319b7a0bb11a8b72d78d6b43ce6d664f324b0281806275b9649c00af1229c351007c76f125cd13aa974654317f6140eb12ae0d23bb7b26ba7f2aba37133c942b59cb68370d195e74230e4573689c6883517742429483a075467d90d1162e08078d0e7f8aa819134f5f03bbbcb0f8653fe14e84330fae2fb74d8dab03d638f1f74ae8e919830c19860d0ba5b9d3b41bb87d5d0ce5f9'
# Echo the target once at start-up so it appears at the top of the run log.
print(test_ground_truth)

# Upper bound on optimization iterations; passed to train_and_test() at the
# bottom of the script.
epoch = 15
# The agent being optimized; its parameters() are handed to the optimizer.
agent = missing_bits_agent_pt1.SimpleDecryptorAgent()
# print(agent.parameters()[0].data)
# Write the freshly constructed agent to agent.pkl (relative to the current
# working directory) and immediately read it back.
# NOTE(review): presumably a serialization round-trip / snapshot of the
# initial state — confirm it is still needed.
agent.save("agent.pkl")
agent.load("agent.pkl")
# Built after the load above, from whatever agent.parameters() returns then.
optimizer = OptoPrime(agent.parameters()) 

def get_feedback(predict, target):
    """Compare a prediction with the target and return a feedback message.

    Prints both values, then prints and returns the feedback string. When the
    two compare equal, the module-level ``test_passed`` flag is set to True;
    it is never reset to False here.

    Args:
        predict: The value produced by the agent.
        target: The expected value.

    Returns:
        ``"test case passed!"`` on a match, otherwise the failure message
        (which ends with a trailing space).
    """
    global test_passed

    # Log both sides of the comparison, label first and value on its own line.
    for heading, value in (("Prediction:", predict), ("Target:", target)):
        print(heading)
        print(str(value))

    # Start from the failure message and override it on a match.
    message = "test case failed! As always, make sure the format of your output is accurate. "
    if predict == target:
        test_passed = True
        message = "test case passed!"

    print("FEEDBACK: " + message)
    return message

def train_and_test(epoch : int):
    """Run the agent repeatedly, feeding feedback to the optimizer each time.

    Each iteration prints the first agent parameter's data, runs the agent,
    and scores the result with ``get_feedback`` against ``test_ground_truth``.
    The loop stops as soon as the module-level ``test_passed`` flag is set;
    otherwise the feedback is propagated through the optimizer and a step is
    taken.

    Args:
        epoch: Maximum number of iterations to run.
    """
    for iteration in range(epoch):
        print(f"Training Epoch {iteration}")
        print(agent.parameters()[0].data)

        try:
            # Normal path: run the agent and score its output.
            test_output = agent.run()
            print(test_output)
            feedback = get_feedback(test_output.data, test_ground_truth)
        except trace.ExecutionError as err:
            # The run raised inside traced code: use the exception node as the
            # output to optimize against, and its data as the feedback.
            failed_node = err.exception_node
            feedback = failed_node.data
            test_output = failed_node

        # get_feedback flips this flag on an exact match; nothing left to do.
        if test_passed:
            break

        print("here")
        # Clear previous feedback, propagate the new feedback from the
        # output node, then let the optimizer update the parameters.
        optimizer.zero_feedback()
        optimizer.backward(test_output, feedback)
        optimizer.step()

        print(test_output.data)

# Run the optimization loop for at most `epoch` iterations (it stops early
# once a prediction matches the ground truth).
train_and_test(epoch) # number of epochs
# Persist the agent after training, whether or not the test passed.
# NOTE(review): absolute, machine-specific path — this will only work where
# that directory exists; consider a relative or configurable location.
agent.save("/Users/anonymous/trace-security/security/crypto/missing-bits/pt1_final.pkl")
# print(test_ground_truth)