from datasets import load_dataset
from workflow.qwq.workflow_manager import WorkflowManager
from memory.memory_manager import MemoryManager

import argparse
import json
import os

# Number of times a single question is attempted before it is recorded as
# FAILED_ALL_ATTEMPTS (overridable with --max_attempts).
DEFAULT_MAX_ATTEMPTS = 3


def _workspace_dir(level):
    """Return the workspace directory used for GAIA runs at *level*."""
    return f"qwq_gaia_workspace_level{level}"


def _results_path(level):
    """Return the JSON file that the run results for *level* are written to."""
    return f"qwq_gaia_results_level{level}_with_memory.json"


def _extract_answer(result):
    """Pick the answer out of a workflow result dict.

    Prefers ``final_answer``, then ``current_summary``, then the placeholder
    ``"No answer found"``. ``dict.get`` only falls back on *missing* keys, so
    a key that is present with value ``None`` is returned as ``None``.
    """
    return result.get("final_answer", result.get("current_summary", "No answer found"))


def _solve_with_retries(example, test_count, run_workflow, max_attempts=DEFAULT_MAX_ATTEMPTS):
    """Run one dataset example through *run_workflow*, retrying on failure.

    Args:
        example: dataset row; the ``task_id`` and ``Question`` keys are read.
        test_count: 1-based index of this question, used only in log output.
        run_workflow: callable that takes the question text and returns the
            workflow's result dict.
        max_attempts: maximum number of tries before giving up.

    Returns:
        A result record. On success it holds ``id``, ``question``, ``answer``
        and the 1-based ``attempt`` that succeeded. If every attempt raised,
        ``answer`` is ``"FAILED_ALL_ATTEMPTS"``, ``error`` is the string form
        of the last exception and ``attempts`` is *max_attempts*.
    """
    question = example["Question"]
    last_error = None
    for attempt in range(1, max_attempts + 1):
        print(f"Processing question {test_count} (attempt {attempt}/{max_attempts}): {question[:100]}...")
        try:
            result = run_workflow(question)
            answer = _extract_answer(result)
        except Exception as e:  # script boundary: any workflow failure is logged and retried
            last_error = e
            print(f"❌ Attempt {attempt} failed: {e}")
            if attempt < max_attempts:
                print(f"🔄 Retrying... ({attempt + 1}/{max_attempts})")
            else:
                print(f"💥 All {max_attempts} attempts failed for question: {question[:100]}")
            continue
        print(f"✅ Success on attempt {attempt}")
        return {
            "id": example["task_id"],
            "question": question,
            "answer": answer,
            "attempt": attempt,
        }

    # Record the failure so the question still appears in the results file.
    return {
        "id": example["task_id"],
        "question": question,
        "answer": "FAILED_ALL_ATTEMPTS",
        "error": str(last_error),
        "attempts": max_attempts,
    }


def _write_results(results, path):
    """Write *results* to *path* as UTF-8 JSON, replacing any previous file.

    The data is written to a temporary sibling file first and then moved into
    place with ``os.replace``, so an interruption mid-write never leaves a
    truncated results file behind.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def main(argv=None):
    """Run the QwQ workflow over text-only GAIA questions and save the outputs.

    Args:
        argv: optional argument list for ``argparse``; ``None`` reads
            ``sys.argv``.

    Side effects: writes the results JSON (checkpointed after every question)
    and saves the workflow execution memory inside the level's workspace
    directory.
    """
    parser = argparse.ArgumentParser(description="Run the QwQ workflow on text-only GAIA questions.")
    parser.add_argument("--level", type=int, default=1)
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--max_test", type=int, default=10)
    parser.add_argument("--max_attempts", type=int, default=DEFAULT_MAX_ATTEMPTS)
    args = parser.parse_args(argv)
    if args.max_attempts < 1:
        parser.error("--max_attempts must be at least 1")

    workspace = _workspace_dir(args.level)
    results_path = _results_path(args.level)

    # Load the dataset and the memory manager shared by every workflow run.
    dataset = load_dataset("./benchmarks/GAIA/GAIA.py", f"2023_level{args.level}", split=args.split)
    memory_manager = MemoryManager()

    def run_workflow(question):
        # A fresh WorkflowManager per attempt; memory accumulates in memory_manager.
        workflow_manager = WorkflowManager(
            first_user_message=question,
            record_memory=True,
            base_workspace=workspace,
            memory_manager=memory_manager,
        )
        return workflow_manager.run()

    results = []
    test_count = 0
    for example in dataset:
        # Only questions without an attached file are evaluated.
        if example["file_path"]:
            continue
        if test_count >= args.max_test:
            break
        test_count += 1
        results.append(_solve_with_retries(example, test_count, run_workflow, args.max_attempts))
        # Checkpoint after every question so a crash or Ctrl-C keeps completed work.
        _write_results(results, results_path)

    # Final write also covers the zero-question case, and happens before the
    # memory save so a failure there cannot discard the results.
    _write_results(results, results_path)

    # The workspace may not exist yet if no workflow ever ran.
    os.makedirs(workspace, exist_ok=True)
    memory_manager.save_workflow_execution_memory(
        workflow_execution_memory_path=os.path.join(workspace, "workflow_execution_memory.json")
    )


if __name__ == "__main__":
    main()