from opto import trace
import autogen
from opto.optimizers import OptoPrime
from opto import trace
from opto import utils
import missing_bits_agent_pt3
from resources import secret


# Candidate answers. Presumably the two RSA prime factors (the names suggest
# so -- confirm); get_feedback accepts a prediction equal to either one.
p = "0x00e4188b37b163c203ae4f814ac457738b376afede66bd192c6e604ffc95a4defc5061e0a63223d0c6b414a9d1c61b11bc88c2301fb76784d68f6bcd65acff5f08aca28ff71a11b7b8966126f9574cff165017e3e209cd5fce2617a61cfdc2e01cf3efa4bcac4cd846a20def05db99ec5d6f856b13685ca6c9834cd340bcb32a21"
q = "0x00f122e285b030a36974cda97c18c21e9b3afe00fc4bc3e7e67786abc9500ca2ff003641bc283330766bef927deb2fb59a2b29b97da92abd7478e7f4063def27895cf1ea869619249f8a37956a7dbba46fdbcd5ab2ca614764ff5c4611075b81cde7b84ea57cb491416a55ac49582b3eb611f42d36684e801ea60facafcd8569e7"
test_ground_truth = p  # passed to get_feedback as the displayed "Target"
print(test_ground_truth)

# Flipped to True by get_feedback once the agent outputs p or q;
# train_and_test stops early when it sees this flag.
test_passed = False

epoch = 15  # maximum number of optimization rounds handed to train_and_test

agent = missing_bits_agent_pt3.SimpleDecryptorAgent()
# Save and immediately reload the agent via a file in the current working
# directory -- presumably a serialization round-trip check; confirm intent.
_roundtrip_path = "agent.pkl"
agent.save(_roundtrip_path)
agent.load(_roundtrip_path)
optimizer = OptoPrime(agent.parameters())

def get_feedback(predict, target):
    """Score a prediction and return the feedback string for the optimizer.

    Prints the prediction and the target. The prediction is accepted when it
    equals either module-level value ``p`` or ``q``. ``target`` is only
    displayed, never compared. A match also sets the module-level
    ``test_passed`` flag to True.

    Args:
        predict: the agent's output value.
        target: the expected value, printed for reference.

    Returns:
        "test case passed!" on a match, otherwise a failure message with
        output-formatting hints for the agent.
    """
    global test_passed

    for heading, value in (("Prediction:", predict), ("Target:", target)):
        print(heading)
        print(str(value))

    if predict == p or predict == q:
        test_passed = True
        message = "test case passed!"
    else:
        message = "test case failed! As always, make sure the format of your output is accurate. Remember to strip the delimeters that refer to the type (integer) and the indicator of the long-form length - just output the content. It should be something starting with 0x00..."

    print(f"FEEDBACK: {message}")
    return message

def train_and_test(epoch : int):
    """Optimize the agent for up to ``epoch`` rounds, stopping early on success.

    In each round the function does the following:

    * Prints the agent's first parameter.
    * Runs the agent.
    * Scores the output with ``get_feedback`` against ``test_ground_truth``.
      If the run raises ``trace.ExecutionError``, the exception node becomes
      the output and its data becomes the feedback.
    * Exits the loop once the module-level ``test_passed`` flag is set.
      Otherwise it propagates the feedback backward through ``optimizer``
      and takes an update step.

    Args:
        epoch: maximum number of rounds to run.
    """
    for round_no in range(epoch):
        print(f"Training Epoch {round_no}")
        print(agent.parameters()[0].data)

        try:
            output_node = agent.run()
            print(output_node)
            fb = get_feedback(output_node.data, test_ground_truth)
        except trace.ExecutionError as err:
            # Presumably the exception node carries the error details; its
            # data is used as feedback so the optimizer can react to the crash.
            fb = err.exception_node.data
            output_node = err.exception_node

        if test_passed:
            break

        print("here")
        # Clear previously attached feedback, attach this round's feedback
        # to the output node, then apply the optimizer's update.
        optimizer.zero_feedback()
        optimizer.backward(output_node, fb)
        optimizer.step()

        print(output_node.data)

# NOTE(review): machine-specific absolute path; consider deriving it from the
# script location or a CLI argument so the run is portable.
FINAL_CHECKPOINT = "/Users/anonymous/trace-security/security/crypto/missing-bits/pt3_final.pkl"

# Run the optimization loop, then persist the final agent state.
train_and_test(epoch)
agent.save(FINAL_CHECKPOINT)