import os
import json
import time
import argparse
from typing import Dict, Any, List, Optional

from crypto_agent import OptimizedCryptoAgent
from feedback import binary_feedback
from critic_agent import CriticAgent
from test_data import CHALLENGE_DATA
from opto.optimizers import OptoPrime
from opto import trace
from opto.trace.utils import render_opt_step

def past_history(agent_code, prev_critic_feedback):
    """Build the feedback text used to optimize the critic agent.

    Args:
        agent_code: The agent's current code, interpolated into the prompt
            via f-string formatting.
        prev_critic_feedback: The critique the critic produced on the
            previous iteration, interpolated the same way.

    Returns:
        A prompt string asking the critic to tailor its system prompt based
        on how the agent reacted to its previous critique.
    """
    fb = (
        "After one optimization loop, this is how the agent decided to change "
        f"the code: {agent_code} based off this feedback that you gave "
        f"previously: {prev_critic_feedback}."
    )
    # Leading space keeps the two sentences from running together.
    fb += " Based on this result, tailor the system prompt for the critic agent."
    return fb

def _node_data(value):
    """Return ``value.data`` when *value* carries a ``.data`` payload, else *value*.

    Used so that text which is printed, saved, or sent to an LLM is the node's
    payload rather than ``str(node)``.
    """
    return value.data if hasattr(value, "data") else value


def _update_critic(critic_optimizer, agent_code, prev_critic_feedback):
    """Optimize the critic on how the agent's code changed after its last critique.

    No-op while there is no previous critique yet (i.e. on the first iteration).
    """
    if not prev_critic_feedback:
        return
    print("critic update")
    critic_fb = past_history(_node_data(agent_code), _node_data(prev_critic_feedback))
    critic_optimizer.zero_feedback()
    # The previous critique node is the critic output being graded.
    critic_optimizer.backward(prev_critic_feedback, critic_fb)
    critic_optimizer.step(verbose=False)


def run_docstring_experiment(challenge_key, max_iterations=30):
    """Run the actor-critic optimization loop on one challenge.

    Each iteration the agent attempts a decryption. If the flag is correct the
    loop stops immediately. Otherwise (wrong flag, or the agent's code raised a
    ``trace.ExecutionError``):

    1. the critic is optimized on how the agent responded to its previous
       critique;
    2. the critic writes a new critique;
    3. that critique is back-propagated into the agent's code.

    Args:
        challenge_key: Key into ``CHALLENGE_DATA``.
        max_iterations: Maximum number of decryption attempts.

    Returns:
        dict with ``challenge``, ``feedback_type``, ``success``,
        ``iterations_completed``, ``iterations_to_success``, ``final_flag``
        (only set on success) and ``final_agent_code``.
    """
    print(f"\nRunning experiment for: {challenge_key}")

    challenge_data = CHALLENGE_DATA[challenge_key]
    ciphertext = challenge_data["ciphertext"]
    source_code = challenge_data["source_code"]
    solution = challenge_data["solution"]
    solution_code = challenge_data["solution_code"]
    task_directions = challenge_data["task_directions"]

    agent = OptimizedCryptoAgent(task_name=challenge_key)
    agent_optimizer = OptoPrime(agent.parameters(), memory_size=5)
    critic = CriticAgent()
    critic_optimizer = OptoPrime(critic.parameters(), memory_size=5)

    results = {
        "challenge": challenge_key,
        "feedback_type": "actor_critic",
        "success": False,
        "iterations_completed": 0,
        "iterations_to_success": None,
        "final_flag": "",
        "final_agent_code": "",
    }

    prev_critic_feedback = ""
    for iteration in range(max_iterations):
        print(f"\nIteration {iteration + 1}/{max_iterations}")
        results["iterations_completed"] = iteration + 1

        # Guard only the agent's own call. An ExecutionError raised later by
        # the critic must not be mistaken for (and trained into) the agent.
        error_text = None
        try:
            print("test decryption")
            decrypted_node = agent.decrypt(ciphertext, source_code)
        except trace.ExecutionError as e:
            print(f"\nerror encountered: {e}")
            decrypted_node = e.exception_node
            error_text = str(e)

        agent_code = agent.decrypt.parameters()[0]

        if error_text is None:
            decrypted_flag = str(_node_data(decrypted_node))
            print(f"Flag: {decrypted_flag}")
            print("\nagent code:")
            print(_node_data(agent_code))

            print("check correct")
            if binary_feedback(decrypted_flag, solution, source_code, iteration):
                # Stop before any optimizer step so the saved code is the
                # code that actually produced the correct flag.
                results["success"] = True
                results["iterations_to_success"] = iteration + 1
                results["final_flag"] = decrypted_flag
                print("\nsolved!")
                break

        _update_critic(critic_optimizer, agent_code, prev_critic_feedback)

        print("asses the solution")
        solution_assessment = critic.assess_solution(solution, solution_code, task_directions)
        print("\nsolution assess:")
        print(_node_data(solution_assessment))

        if error_text is None:
            agent_assessment = critic.assess_agent(decrypted_flag, agent_code, task_directions)
            critic_feedback = critic.master_critic(
                solution_assessment, agent_assessment, task_directions
            )
        else:
            agent_assessment = critic.assess_agent(error_text, agent_code, task_directions)
            critic_feedback = critic.master_critic(
                solution_assessment, agent_assessment, task_directions, error_text
            )
        prev_critic_feedback = critic_feedback

        print("update agent with feedback")
        agent_optimizer.zero_feedback()
        # Use the critique's payload on both the success and error paths.
        agent_optimizer.backward(decrypted_node, _node_data(critic_feedback), visualize=True)
        agent_optimizer.step(verbose=False)

    # Store the parameter's payload (the code) rather than str() of the node.
    results["final_agent_code"] = str(_node_data(agent.decrypt.parameters()[0]))
    return results

def save_experiment_result(result, output_dir, challenge_key):
    """Persist a summary of one experiment run as pretty-printed JSON.

    The file is named ``<challenge_key>_<timestamp>.json`` inside *output_dir*
    (created if missing). ``timestamp`` comes from ``result`` when present,
    otherwise the current Unix time in whole seconds. Only a fixed set of
    summary keys is written; missing keys fall back to defaults.

    Returns:
        The bare file name (not the full path) that was written.
    """
    os.makedirs(output_dir, exist_ok=True)
    stamp = result.get("timestamp", int(time.time()))
    out_name = f"{challenge_key}_{stamp}.json"
    out_path = os.path.join(output_dir, out_name)

    # Keyword order below fixes the key order in the written JSON.
    payload = dict(
        challenge=result.get("challenge"),
        feedback_type=result.get("feedback_type"),
        success=result.get("success", False),
        iterations_completed=result.get("iterations_completed", 0),
        iterations_to_success=result.get("iterations_to_success"),
        timestamp=stamp,
        final_flag=result.get("final_flag", ""),
        final_agent_code=result.get("final_agent_code", ""),
    )

    with open(out_path, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"Results saved to {out_path}")
    return out_name

def run_all_experiments(output_dir="results_docstring", max_iterations=30):
    """Run every challenge in ``CHALLENGE_DATA`` and save each result to disk.

    Args:
        output_dir: Directory the per-challenge JSON result files go to.
        max_iterations: Iteration budget passed to each experiment.

    Returns:
        List of result dicts, in ``CHALLENGE_DATA`` iteration order.
    """
    os.makedirs(output_dir, exist_ok=True)
    all_results = []

    for challenge_key in CHALLENGE_DATA:
        print(f"\nRunning experiment for challenge: {challenge_key}")

        result = run_docstring_experiment(challenge_key, max_iterations)
        print(result)
        all_results.append(result)

        save_experiment_result(result, output_dir, challenge_key)

    return all_results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test crypto agent with docstring-based hints and binary feedback")
    parser.add_argument("--challenge", "-c", type=str, help="Specify which challenge to run")
    parser.add_argument("--iterations", "-i", type=int, default=30, help="Maximum iterations to run")
    parser.add_argument("--output-dir", "-o", type=str, default="results_docstring", help="Directory to save results")
    args = parser.parse_args()

    if args.challenge:
        if args.challenge not in CHALLENGE_DATA:
            print(f"Error: Challenge '{args.challenge}' not found.")
            print(f"These are the challenges to choose from: {', '.join(CHALLENGE_DATA.keys())}")
            # SystemExit works even when the site-provided exit() is absent.
            raise SystemExit(1)

        print(f"running experiment for: {args.challenge}")
        result = run_docstring_experiment(args.challenge, args.iterations)
        save_experiment_result(result, args.output_dir, args.challenge)
    else:
        print("Running all experiments with docstring-based hints")
        # Honour --output-dir and --iterations in batch mode too.
        run_all_experiments(args.output_dir, args.iterations)
        print("\nExperiment completed!")
        print(f"Results saved to '{args.output_dir}' directory")