from benchmarks.GPQA_Diamond.loader import GPQADataset

import argparse
import json
import os

from workflow.qwq.workflow_manager import WorkflowManager
from memory.memory_manager import MemoryManager


# The number of questions evaluated per run is set with the --max_test flag (default: 10).

def solve_with_retries(example: dict, question_number: int, solve, max_attempts: int = 3) -> dict:
    """Answer one benchmark question, retrying when the solver raises.

    Args:
        example: Mapping with at least the keys ``"id"``, ``"question"`` and
            ``"subject"``.
        question_number: 1-based position of the question in this run; used
            only in the progress messages.
        solve: Callable taking the question text and returning an object with
            a dict-style ``.get`` (``"final_answer"`` is preferred, then
            ``"current_summary"``).
        max_attempts: Maximum number of times ``solve`` is called.

    Returns:
        A result record. On success it carries ``"answer"`` and ``"attempt"``
        (the attempt number that succeeded). When every attempt fails it
        carries ``"answer": "FAILED_ALL_ATTEMPTS"``, ``"error"`` (text of the
        last exception) and ``"attempts"`` (the number of attempts allowed).
        The two key spellings differ on purpose: they are kept as they have
        always been written so existing result files stay comparable.
    """
    last_error = None

    for attempt in range(1, max_attempts + 1):
        try:
            print(f"Processing question {question_number} (attempt {attempt}/{max_attempts}): {example['question'][:100]}...")
            result = solve(example["question"])
            answer = result.get('final_answer', result.get('current_summary', 'No answer found'))
            record = {
                "id": example["id"],
                "question": example["question"],
                "subject": example["subject"],
                "answer": answer,
                "attempt": attempt,
            }
        except Exception as e:
            # Retry boundary: any failure of the attempt is reported and the
            # question is tried again until the attempts are used up.
            last_error = e
            print(f"❌ Attempt {attempt} failed: {e}")
            if attempt < max_attempts:
                print(f"🔄 Retrying... ({attempt + 1}/{max_attempts})")
            else:
                print(f"💥 All {max_attempts} attempts failed for question: {example['question'][:100]}")
        else:
            # Kept outside the ``try`` so a problem while reporting success
            # can never trigger another attempt for an already-answered question.
            print(f"✅ Success on attempt {attempt}")
            return record

    # Record the failed question so it still shows up in the results file.
    return {
        "id": example["id"],
        "question": example["question"],
        "subject": example["subject"],
        "answer": "FAILED_ALL_ATTEMPTS",
        "error": str(last_error),
        "attempts": max_attempts,
    }


def main(argv=None) -> None:
    """Run the QwQ workflow over GPQA questions and persist answers and memory.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Side effects:
        Writes ``qwq_gpqa_results_<subject>_with_memory.json`` in the current
        directory and asks the memory manager to save
        ``workflow_execution_memory.json`` inside
        ``qwq_gpqa_workspace_<subject>``.
    """
    parser = argparse.ArgumentParser(
        description="Run the QwQ workflow (with memory recording) on GPQA questions."
    )
    parser.add_argument(
        "--subject",
        type=str,
        default="All",
        help='Subject to evaluate; "All" runs the full set (default: All).',
    )
    parser.add_argument(
        "--max_test",
        type=int,
        default=10,
        help="Maximum number of questions to process (default: 10).",
    )
    args = parser.parse_args(argv)

    # Load the dataset.
    loader = GPQADataset()
    if args.subject == "All":
        dataset = loader.get_full_set()
    else:
        dataset = loader.get_by_subject(args.subject)

    workspace = f"qwq_gpqa_workspace_{args.subject}"
    results_path = f"qwq_gpqa_results_{args.subject}_with_memory.json"

    # One memory manager is shared by every workflow of the run.
    memory_manager = MemoryManager()

    def solve(question):
        # A fresh workflow per attempt, all recording into the shared memory.
        workflow_manager = WorkflowManager(
            question,
            record_memory=True,
            base_workspace=workspace,
            memory_manager=memory_manager,
        )
        return workflow_manager.run()

    results = []
    try:
        for question_number, example in enumerate(dataset, start=1):
            if question_number > args.max_test:
                break
            results.append(solve_with_retries(example, question_number, solve))
    finally:
        # Save the results even when the run is interrupted, so answers that
        # were already computed are not thrown away.
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f)

    # Save the memory manager state.
    # NOTE(review): assumes save_workflow_execution_memory may not create a
    # missing parent directory (e.g. when no question was processed) — confirm;
    # creating it here is harmless either way.
    os.makedirs(workspace, exist_ok=True)
    memory_manager.save_workflow_execution_memory(
        workflow_execution_memory_path=os.path.join(workspace, "workflow_execution_memory.json")
    )


if __name__ == "__main__":
    main()