import os
import json
import time
import random
from tqdm import tqdm

# Import the core workflow object and dependencies from your ProAgent definition file
# Ensure that pa_main.py is in Python's search path
try:
    from pa_main import workflow
    from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as import_error:
    print("Error: Failed to import 'workflow' from 'pa_main.py'.")
    print("Please ensure that 'evaluate_proagent.py' and 'pa_main.py' are in the same directory, or that the directory containing pa_main.py is in PYTHONPATH.")
    # Either import above can be the one that failed (pa_main or the langgraph
    # SQLite checkpointer), so show the real cause alongside the hint.
    print(f"Underlying import error: {import_error}")
    # Exit with a non-zero status so shells/CI see the failure; the bare
    # exit() helper reports success (0) and is only provided by the site module.
    raise SystemExit(1)

# ================================
# User Configuration
# ================================

# 1. Define the tasks to be evaluated and their corresponding test file paths.
# Maps task name -> path of that task's test set. Each path is opened and parsed
# as a single JSON file, so despite the "_DIR" placeholder names the value must
# point at a JSON file. Tasks whose path does not exist are skipped with a warning.
# Please modify according to your actual paths.
TASKS = {
    'PQA': 'YOUR_PQA_TEST_DIR',
    'LABQA': 'YOUR_LASBQA_TEST_DIR', 
    'ORD': 'YOUR_ORD_TEST_DIR',
    'ERR': 'YOUR_ERR_TEST_DIR',
    'GEN': 'YOUR_GEN_TEST_DIR',
}

# 2. Define the output directory for evaluation results.
# Created at startup if it does not exist; one result JSON file per task is written here.
OUTPUT_DIR = 'evaluation_results/YOUR_TEST_NAME'

# 3. Define an identifier for ProAgent to be used in filenames.
# Embedded in each result filename as "<task>_<AGENT_NAME>_evaluation_delay1s.json".
AGENT_NAME = "proagent_v2_hyperbird" 

# 4. Set the checkpoint saving frequency (save once every N samples processed).
# Results are also saved once more after a task's last sample.
CHECKPOINT_INTERVAL = 10

# ================================
# Helper Functions (borrowed from your rag_inference.py)
# ================================

def get_test_data(file_path):
    """
    Read the UTF-8 encoded JSON file at `file_path` and return the parsed content.
    """
    with open(file_path, encoding='utf-8', mode='r') as json_file:
        return json.load(json_file)

def save_checkpoint(data, filename):
    """
    Save intermediate results to a JSON file.

    The parent directory is created automatically when `filename` has one.
    The data is first written to a sibling "<filename>.tmp" file and then
    moved over `filename`, so an interruption during the write does not
    leave a truncated checkpoint behind.

    Args:
        data: JSON-serializable object to store.
        filename: Destination path; may be a bare filename in the current directory.

    Raises:
        TypeError / ValueError: If `data` cannot be serialized by `json.dump`.
        OSError: If the directory or file cannot be created or replaced.
    """
    target_dir = os.path.dirname(filename)
    # dirname is '' for a bare filename, and os.makedirs('') raises FileNotFoundError
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False)
        # Swap the fully written file into place in a single step
        os.replace(tmp_filename, filename)
    except BaseException:
        # Do not leave a half-written temp file around; the previous checkpoint stays intact
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise

# ================================
# Core ProAgent Processing Function
# ================================

def process_sample_with_proagent(sample, app, max_retries=3, initial_delay=5):
    """
    Process a single sample using the compiled ProAgent application.

    The sample dict is updated in place with a 'generated_response' key and
    returned. A sample that already has 'generated_response' is returned untouched.

    Args:
        sample: Dict describing one test sample; its 'id' (if present) is used
            as the thread ID of the invocation.
        app: Object exposing `invoke(inputs, config=...)` that returns a
            mapping with a 'generation' entry.
        max_retries: Maximum number of invocation attempts when the error
            message contains '429' (rate limit).
        initial_delay: Base delay in seconds for the exponential backoff.

    Returns:
        The same `sample` dict. On failure 'generated_response' holds a
        "PROAGENT_EXECUTION_ERROR: ..." or "PROAGENT_FAILED_AFTER_RETRIES: ..."
        marker instead of a generation; exceptions are not propagated.
    """
    if 'generated_response' in sample:
        return sample

    # Use the sample ID as the thread ID to ensure traceability for each run.
    # The fallback relies on hash() of a str, which varies between interpreter
    # runs unless PYTHONHASHSEED is fixed.
    thread_id = sample.get("id", f"thread-{hash(str(sample))}")

    last_exception = None
    for attempt in range(max_retries):
        try:
            # Prepare inputs and configuration for langgraph invocation
            config = {"configurable": {"thread_id": thread_id}}
            inputs = {"sample": sample}

            # Invoke ProAgent
            final_state = app.invoke(inputs, config=config)

            # Extract generated content from the final state
            generation = final_state.get('generation', 'Error: No generation found in final state.')
            sample['generated_response'] = generation
            return sample

        except Exception as e:
            last_exception = e
            if '429' not in str(e):
                # Anything that is not a rate limit is recorded and not retried
                print(f"\nError: Unexpected exception occurred while processing sample {sample.get('id')}: {e}")
                sample['generated_response'] = f"PROAGENT_EXECUTION_ERROR: {str(e)}"
                return sample

            if attempt + 1 >= max_retries:
                # No attempt left: waiting out the longest backoff (and announcing
                # a non-existent next attempt) would only delay the failure report
                break

            # Exponential backoff + jitter
            delay = initial_delay * (2 ** attempt) + random.uniform(0, 1)
            print(f"\nReceived API rate limit (429). Will attempt {attempt + 2}/{max_retries} in {delay:.1f} seconds...")
            time.sleep(delay)

    print(f"\nError: Sample {sample.get('id')} still failed after {max_retries} attempts.")
    sample['generated_response'] = f"PROAGENT_FAILED_AFTER_RETRIES: {str(last_exception)}"
    return sample

# ================================
# Main Processing Function
# ================================

def main():
    """
    Main function for batch processing all defined tasks.

    For every entry in TASKS this loads the test set, resumes from an existing
    result file in OUTPUT_DIR when one is present, runs each remaining sample
    through the compiled ProAgent graph, and writes the accumulated results
    back to the result file every CHECKPOINT_INTERVAL samples and at the end.
    Progress is also saved when the sample loop is interrupted (for example by
    Ctrl-C or an unexpected error) before the exception propagates.
    """
    # Pause between consecutive samples, in seconds
    REQUEST_DELAY_SECONDS = 1.25

    # Ensure the output directory exists
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Compile ProAgent ---
    # Use an in-memory SQLite database as the checkpointer; the whole task loop
    # stays inside the 'with' block so 'app' is only used while 'memory' is open.
    with SqliteSaver.from_conn_string(":memory:") as memory:
        print("Compiling ProAgent graph...")
        app = workflow.compile(checkpointer=memory)
        print("ProAgent graph compiled successfully.")

        # --- Iterate through each task ---
        for task_name, test_file_path in TASKS.items():
            print(f"\n{'='*25} Starting evaluation for task: {task_name} {'='*25}")

            output_file = os.path.join(OUTPUT_DIR, f"{task_name}_{AGENT_NAME}_evaluation_delay1s.json")

            if not os.path.exists(test_file_path):
                print(f"Warning: Cannot find test file for task '{task_name}', skipped. Path: {test_file_path}")
                continue

            test_set = get_test_data(test_file_path)

            # --- Checkpoint ---
            # Resume from a previous run: samples already present in the result
            # file (with an 'id' and a 'generated_response') are not re-run.
            processed_set = []
            if os.path.exists(output_file):
                print(f"Found existing checkpoint file, loading: {output_file}")
                with open(output_file, 'r', encoding='utf-8') as f:
                    processed_set = json.load(f)

            processed_ids = {s['id'] for s in processed_set if 'id' in s and 'generated_response' in s}
            remaining_samples = [s for s in test_set if s.get('id') not in processed_ids]

            if not remaining_samples:
                print(f"Task '{task_name}' has been fully processed.")
                continue

            print(f"Total: {len(test_set)} | Processed: {len(processed_ids)} | Remaining: {len(remaining_samples)}")

            # --- Start processing ---
            count_since_last_save = 0
            try:
                for sample in tqdm(remaining_samples, desc=f"Processing {task_name}"):
                    processed_sample = process_sample_with_proagent(sample, app)
                    processed_set.append(processed_sample)

                    count_since_last_save += 1
                    if count_since_last_save >= CHECKPOINT_INTERVAL:
                        save_checkpoint(processed_set, output_file)
                        count_since_last_save = 0

                    time.sleep(REQUEST_DELAY_SECONDS)
            finally:
                # --- Save Final Outcome ---
                # Runs on normal completion and on interruption alike, so samples
                # processed since the last periodic checkpoint are never lost.
                save_checkpoint(processed_set, output_file)

            print(f"\nTask '{task_name}' processing completed.")
            print(f"All results have been saved to: {output_file}")

if __name__ == "__main__":
    main()