import autogen
import opto.trace as trace
from opto.trace import node, bundle, model
from opto.optimizers import OptoPrime, OPRO
from trace_mapper.trace_template import DSLMapperGenerator
from trace_mapper.test_mapper import TestMapper
import json
import os
import sys
import random

def feedback_for_program(program, app_info, iter, repeat_idx):
    """Evaluate a generated mapper program and return its feedback.

    Delegates entirely to ``TestMapper``: an evaluator is built for this
    program / application / step / repetition and its ``gen_feedback()``
    result is returned unchanged.
    """
    return TestMapper(program, app_info, iter, repeat_idx).gen_feedback()

def optimize_step(app_info, iter, repeat_idx):
    """Run one Trace optimization step for the DSL mapper of an application.

    Builds a DSLMapperGenerator from ``app_info`` and resumes from checkpoint
    ``iter - 1`` when that file exists. Generates a mapper program, collects
    feedback for it, and applies one OptoPrime update. Finally writes
    checkpoint ``iter`` and a text dump of every parameter's current value.

    Args:
        app_info: Application description. This function reads the keys
            'application', 'tasks', 'regions', 'index_tasks', 'single_tasks'
            and 'index_task_specification'. The whole dict is also passed on
            to TestMapper via ``feedback_for_program``.
        iter: Index of this step; selects which checkpoint is loaded/written.
        repeat_idx: Index of the independent repetition, used in file names.

    Raises:
        Any exception from mapper construction, checkpoint loading, optimizer
        creation/update, or file output. A ``trace.ExecutionError`` raised
        while generating or evaluating the program is not propagated. It is
        converted into feedback for the optimizer instead.
    """
    app = app_info['application']
    load_from_file = f"result/{app}/repeat{repeat_idx}t_ckpt{iter-1}.pkl"
    save_to_file = f"result/{app}/repeat{repeat_idx}t_ckpt{iter}.pkl"
    dsl_dump_file = f"result/{app}/repeat{repeat_idx}t_code{iter}.py"

    # Construct the mapper and optimizer outside the try block. A failure
    # here then propagates as-is, instead of being swallowed and leaving
    # `mapper`/`optimizer` unbound for the code after the except clause.
    mapper = DSLMapperGenerator(
        tasks=app_info['tasks'],
        regions=app_info['regions'],
        index_tasks=app_info['index_tasks'],
        single_tasks=app_info['single_tasks'],
        index_task_specification=app_info['index_task_specification'],
    )
    if iter > 0 and os.path.exists(load_from_file):
        mapper.load(load_from_file)
    optimizer = OptoPrime(
        mapper.parameters(),
        config_list=autogen.config_list_from_json("OAI_CONFIG_LIST"),
        memory_size=3,
        max_tokens=16383,
    )
    # Alternative optimizer with the same arguments: OPRO(mapper.parameters(), ...)

    try:
        # Generate the mapping program and evaluate it.
        program = mapper.generate_mapper()
        feedback = feedback_for_program(program.data, app_info, iter, repeat_idx)
    except trace.ExecutionError as e:
        print("Error trace.ExecutionError during execution, generating feedback from the error node.")
        # Back-propagate from the node that raised, so the optimizer sees the error.
        program = e.exception_node
        feedback = e.exception_node.create_feedback()
        feedback = f"Please do not modify function signatures (the function name, arguments, and return types). Error message is {feedback}."
        print(feedback)

    optimizer.zero_feedback()
    optimizer.backward(program, feedback)
    # Only update trainable parameters that actually received feedback.
    optimizer.parameters = [p for p in optimizer.parameters if p.trainable and len(p.feedback) > 0]
    optimizer.step(verbose=False)

    # The output directory may not exist yet on a fresh run.
    os.makedirs(os.path.dirname(save_to_file), exist_ok=True)
    # Save the mapper for checkpointing.
    mapper.save(save_to_file)
    # Dump the current value of every parameter, one per line.
    with open(dsl_dump_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{param.data}\n" for param in mapper.parameters_dict().values())

def optimize(app_info, num_iterations, repeat_idx):
    """Run ``num_iterations`` optimization steps, retrying failed steps.

    The step counter only advances when ``optimize_step`` succeeds:

    - If step 0 fails, it is simply retried.
    - If step ``k > 0`` fails, step ``k - 1`` is re-run once, which rewrites
      checkpoint ``k - 1``. Step ``k`` is then attempted again.

    There is no retry limit, so a step that fails forever loops forever.

    The process exits with status 1 if the fallback re-run fails with an
    error containing "file is not a database". The code treats that as a
    broken checkpoint file.
    """
    step = 0  # avoid shadowing the builtin `iter`
    while step < num_iterations:
        try:
            optimize_step(app_info, step, repeat_idx)
        except Exception as e:
            if step == 0:
                # Keep retrying the first iteration until it succeeds.
                print(f"Error at iteration {step}: {e}. Retrying the first iteration until it succeeds.")
                continue
            print(f"Error at iteration {step}: {e}. Falling back to iteration {step-1}.")
            try:
                # Retry the previous iteration to regenerate its checkpoint.
                optimize_step(app_info, step - 1, repeat_idx)
            except Exception as fallback_error:
                if "file is not a database" in str(fallback_error):
                    print("Checkpointing file is wrong!")
                    # sys.exit, not the site-provided exit(), which may be absent.
                    sys.exit(1)
                print(f"Error during fallback to iteration {step-1}: {fallback_error}. Retrying...")
        else:
            step += 1  # Move to the next iteration only on success.

def run(application_name, config_name, repeat_idx, num_iterations):
    """Load an application's settings from app_config.json and optimize it.

    Args:
        application_name: Top-level key in app_config.json.
        config_name: Key under that application whose value is stored as
            ``app_info['config']`` (e.g. "config0").
        repeat_idx: Index of the independent repetition (used in file names).
        num_iterations: Number of optimization steps to run.

    Raises:
        KeyError: if the application, the config, or one of 'tasks',
            'regions', 'index_tasks', 'single_tasks' is missing.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open("app_config.json", "r", encoding="utf-8") as f:
        json_dict = json.load(f)

    app_cfg = json_dict[application_name]
    app_info = {
        'application': application_name,
        'config': app_cfg[config_name],
        'tasks': app_cfg['tasks'],
        'regions': app_cfg['regions'],
        'index_tasks': app_cfg['index_tasks'],
        'single_tasks': app_cfg['single_tasks'],
        # Optional key; defaults to an empty specification.
        'index_task_specification': app_cfg.get('index_task_specification', ""),
    }

    # Start the iterative optimization process.
    optimize(app_info, num_iterations, repeat_idx)

if __name__ == '__main__':
    # Usage: <script> APP_NAME [REPEAT_IDX=0] [STEPS=10] [CONF_IDX=0]
    if len(sys.argv) < 2:
        # Report usage and exit with status 1 instead of an IndexError traceback.
        sys.exit(f"Usage: {sys.argv[0]} APP_NAME [REPEAT_IDX] [STEPS] [CONF_IDX]")
    # Seeds Python's `random` module only; this presumably does not make
    # the LLM calls deterministic.
    random.seed(42)
    app_name = sys.argv[1]
    repeat_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
    steps = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    conf_idx = int(sys.argv[4]) if len(sys.argv) > 4 else 0

    run(app_name, f"config{conf_idx}", repeat_idx, steps)
