'''
Writing a Trace program involves three simple steps:
   1. Write a normal python program.
   2. Decorate the program with Trace primitives: node & @bundle
   3. Use a Trace-graph-compatible optimizer to optimize the program.
'''

'''
The function 'strange_sort_list' - a coding problem from OpenAI's Human-Eval dataset.
'''

# Iteration 1: Attempt at writing code to perform strange_sort_list, but the code is wrong.
# [REMEMBER TO COMMENT OUT WHEN TINKERING WITH CODE]
# def strange_sort_list(lst : list) -> list:
#     '''
#     Given a list of integers, return list in strange order.
#     Strange sorting, is when you start with the minimum value, then 
#     maximum of the remaining integers, then minimum and so on.

#     Examples:
#     strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
#     strange_sort_list([5, 5, 5, 5]) = [5, 5, 5, 5]
#     '''
#     lst = sorted(lst) # first attempt at writing code
#     return lst

# Iteration 2: 'Tracing' the program.
# We use decorators like '@bundle' to wrap Python functions. A bundled function behaves like any
# other Python function.
# Even though strange_sort_list is bundled, it executes NORMALLY (like a normal Python function).
# Returned value: MessageNode -> Trace automatically converts the input to a Node and starts tracing operations.

from opto.trace import node, bundle

# Known-wrong starting point for the tutorial: the body only sorts in ascending order, and
# sorted([1, 2, 3, 4]) is [1, 2, 3, 4], not the expected [1, 4, 2, 3].
# The function is bundled with trainable=True; its .parameters() are handed to the optimizer
# further down in this file.
@bundle(trainable=True)
def strange_sort_list(lst : list): # No return annotation on purpose: per the notes above, calling the bundled function yields a Node rather than a plain list.
    '''
    Given a list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value, then 
    maximum of the remaining integers, then minimum and so on.

    Examples:
    strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
    strange_sort_list([5, 5, 5, 5]) = [5, 5, 5, 5]
    '''
    lst = sorted(lst) # first attempt: plain ascending sort, not yet the strange order
    return lst

# Run the bundled function once on a small example and print what comes back.
test_input = [1, 2, 3, 4]
test_output = strange_sort_list(test_input)  # per the notes above, a MessageNode rather than a plain list
print(test_output)

# Compare the result with the expected "strange" ordering through the node's .eq() method,
# then send the feedback string "test failed" backward from that comparison result.
# NOTE(review): visualize=True presumably renders the traced graph and print_limit presumably
# caps how much of each node's value is displayed — confirm against the Trace documentation.
correctness = test_output.eq([1, 4, 2, 3])
correctness.backward("test failed", visualize=True, print_limit=25)

# get_feedback: compares a prediction with the target and returns a short text verdict.
# The training loop further down passes this string to optimizer.backward() as the feedback signal.
def get_feedback(predict, target):
    '''
    Turn a prediction/target comparison into a feedback message.

    Returns "test case passed!" when `predict == target` evaluates truthy,
    and "test case failed!" otherwise.
    '''
    is_match = predict == target
    if not is_match:
        return "test case failed!"
    return "test case passed!"


# A new snippet of code: 'Optimizing' the program with feedback.
# Again, we look at '@bundle'd functions such that they are modified and run in a more adaptable way.
# Even though strange_sort_list is bundled, it executes NORMALLY (like a normal Python function).
# Returned value: MessageNode -> Trace automatically converts the input to a Node and starts tracing operations.

# Optimize with Feedback

from opto import trace
import autogen
from opto.optimizers import OptoPrime
from opto import trace
from opto import utils

test_ground_truth = [1, 4, 2, 3]  # expected output for test_input (the first example in the docstring)
test_input = [1, 2, 3, 4]  # input passed to strange_sort_list on every epoch

epoch = 2  # upper bound on the number of optimization rounds in the training loop

# We can use Trace and an optimizer defined in 'opto.optimizers'
# to learn a program that can pass the above test. First, we need
# to provide some signal to the optimizer. We call this signal FEEDBACK.
# The more intelligent the optimizer is, the less guidance we need to
# give in the feedback. Think of this as us providing instructions to help
# LLM-based optimizers come up with better solutions.

# 4 Steps:
#   1. Import an optimizer from opto.optimizers.
#   2. Use the .parameters() method of the bundled function to grab all trainable parameters.
#   3. Perform a 'backward pass' on a MessageNode object (here, `correctness`), passing the
#      feedback to the optimizer as well.
#   4. Call optimizer.step().

# TODO: adjust OAI_CONFIG_LIST
# Build the optimizer over the trainable parameters exposed by strange_sort_list.
# NOTE(review): going by the helper's name, the config list is presumably assembled from
# environment variables — confirm how credentials are expected to be supplied.
optimizer = OptoPrime(strange_sort_list.parameters(),
                      config_list=utils.llm.auto_construct_oai_config_list_from_env())

for i in range(epoch):  # at most `epoch` rounds; the loop exits early once the test passes
    print(f"Training Epoch {i}")
    try:
        test_output = strange_sort_list(test_input)  # run the bundled function on the test input
        feedback = get_feedback(test_output, test_ground_truth)  # "test case passed!" or "test case failed!"
    except trace.ExecutionError as e:
        # A trace.ExecutionError was raised inside the try block: use the data carried by the
        # exception node as the feedback, and the exception node itself as the output that is
        # compared and back-propagated from below.
        feedback = e.exception_node.data
        test_output = e.exception_node

    correctness = test_output.eq(test_ground_truth)  # comparison through the node's .eq() method

    if correctness:  # test passed, nothing left to optimize
        break

    optimizer.zero_feedback()  # reset feedback before this round's backward pass
    optimizer.backward(correctness, feedback)  # propagate the feedback starting from `correctness`
    optimizer.step()  # presumably lets the optimizer update the parameters it was constructed with