import json
import sys
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from estimate_eval_probability import get_eval_probability, get_eval_probability_agentic, get_eval_probability_first_person, get_eval_probability_first_person_full_conversation
from reasoning_traces import eval_reasoning_ea
import os
import threading
from binary_eval_awareness import binary_eval_awareness, binary_eval_awareness_agentic, binary_eval_awareness_first_person
from queue import Queue
import time
from unprompted_purpose_guessing import get_purpose_awareness, get_purpose_awareness_agentic, get_purpose_awareness_first_person
from logprob_eval_awareness import logprob_awareness
from config import OPENROUTER_API_KEY

# Default run configuration. The API key is imported from config.py (OPENROUTER_API_KEY).
# get_config_values() overrides these with CLI arguments; note that as soon as ANY
# command-line argument is given, the boolean switches below are taken from the CLI
# flags (False unless passed) instead of from these defaults.
MODEL = "deepseek/deepseek-r1-0528"  # Model under test (overridable with --model)
JUDGE_MODEL = "openai/gpt-5-mini"  # Passed to the purpose-awareness functions as the judge model
INPUT_FILE = "datasets/dataset_minimal.json"  # Overridable with --input_file
OUTPUT_FILE = "eval_results/deepseek_r1_all.json"  # Overridable with --output_file
MAX_WORKERS = 500  # Maximum number of parallel API calls
FULL_TRANSCRIPT_ANALYSIS = False  # True: analyse the whole prompt list; False: single-prompt mode (first system + first user prompt)
REASONING_TRACES_ANALYSIS = True  # Controls the eval_reasoning_ea analysis (only runs when FIRST_PERSON is also True)
PROBABILITY_EVAL_ANALYSIS = True # Controls the eval probability analysis
BINARY_EVAL_ANALYSIS = True  # Controls the binary eval awareness analysis
PURPOSE_AWARENESS_ANALYSIS = True  # Controls the unprompted purpose guessing analysis
LOGPROB_EVAL_ANALYSIS = False  # Controls the logprob-based evaluation awareness analysis (first-person, single-prompt mode only)
FIRST_PERSON = True  # If True, use the first-person variants of the probability, binary and purpose analyses
MAX_RETRIES = 2  # Number of times to re-run failed samples

# Shared state for the background writer: results are pushed onto result_queue and
# drained by writer_worker, which keeps running until writer_stop_event is set and
# the queue is empty. writer_thread holds the current writer Thread (None until started).
result_queue = Queue()
writer_thread = None
writer_stop_event = threading.Event()


def parse_arguments():
    """Parse the process command line and return the argparse namespace.

    Value options (``--model``, ``--input_file``, ``--output_file``,
    ``--max_workers``) default to ``None`` when omitted; every analysis
    switch is a ``store_true`` flag and therefore defaults to ``False``.
    """
    cli = argparse.ArgumentParser(description='Run black box evaluations on model conversations')

    # Options that carry a value: (flag, type, help text).
    valued_options = (
        ('--model', str, 'Model name to use for evaluation'),
        ('--input_file', str, 'Path to input JSON file containing samples'),
        ('--output_file', str, 'Path to output JSON file for results'),
        ('--max_workers', int, 'Maximum number of parallel API calls'),
    )
    for flag, value_type, help_text in valued_options:
        cli.add_argument(flag, type=value_type, help=help_text)

    # Boolean switches: (flag, help text).
    switches = (
        ('--full_transcript_analysis', 'Enable full transcript analysis mode'),
        ('--reasoning_traces_analysis', 'Enable reasoning traces analysis'),
        ('--probability_eval_analysis', 'Enable probability evaluation analysis'),
        ('--binary_eval_analysis', 'Enable binary evaluation analysis'),
        ('--purpose_awareness_analysis', 'Enable purpose awareness analysis'),
        ('--logprob_eval_analysis', 'Enable logprob-based evaluation analysis'),
        ('--first_person', 'Use first-person approach for evaluations'),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action='store_true', help=help_text)

    return cli.parse_args()


def get_config_values():
    """Build the run configuration from CLI arguments and module defaults.

    Value settings (model, input/output file, worker count) use the CLI value
    when it is truthy and the module-level default otherwise.

    Boolean switches follow an all-or-nothing rule: if the process was started
    with at least one command-line argument, every switch takes its CLI value
    (``False`` unless the flag was passed); only when no arguments were given
    are the module-level defaults used.

    Returns:
        dict keyed by the upper-case setting names.
    """
    cli_args = parse_arguments()
    cli_in_use = len(sys.argv) > 1

    settings = {
        'MODEL': cli_args.model or MODEL,
        'INPUT_FILE': cli_args.input_file or INPUT_FILE,
        'OUTPUT_FILE': cli_args.output_file or OUTPUT_FILE,
        'MAX_WORKERS': cli_args.max_workers or MAX_WORKERS,
    }

    # (config key, value parsed from the CLI, module-level default)
    switch_sources = (
        ('FULL_TRANSCRIPT_ANALYSIS', cli_args.full_transcript_analysis, FULL_TRANSCRIPT_ANALYSIS),
        ('REASONING_TRACES_ANALYSIS', cli_args.reasoning_traces_analysis, REASONING_TRACES_ANALYSIS),
        ('PROBABILITY_EVAL_ANALYSIS', cli_args.probability_eval_analysis, PROBABILITY_EVAL_ANALYSIS),
        ('BINARY_EVAL_ANALYSIS', cli_args.binary_eval_analysis, BINARY_EVAL_ANALYSIS),
        ('PURPOSE_AWARENESS_ANALYSIS', cli_args.purpose_awareness_analysis, PURPOSE_AWARENESS_ANALYSIS),
        ('LOGPROB_EVAL_ANALYSIS', cli_args.logprob_eval_analysis, LOGPROB_EVAL_ANALYSIS),
        ('FIRST_PERSON', cli_args.first_person, FIRST_PERSON),
    )
    for key, cli_value, module_default in switch_sources:
        settings[key] = cli_value if cli_in_use else module_default

    return settings


def extract_first_prompts(prompts):
    """Return ``(system_prompt, user_prompt)`` from a list of message dicts.

    Each slot is filled by the first message of that role whose content is
    non-empty (a message with empty content leaves the slot open for a later
    one). Scanning stops once both slots are filled; a role that never
    yields content is returned as ``""``.
    """
    system_text = ""
    user_text = ""
    for message in prompts:
        role = message.get("role")
        if role == "system":
            if not system_text:
                system_text = message.get("content", "")
        elif role == "user":
            if not user_text:
                user_text = message.get("content", "")
        if system_text and user_text:
            break
    return system_text, user_text

def is_agentic_sample(sample):
    """Tell whether a sample's ID names one of the agentic dataset groups.

    IDs are expected in the form ``"group:sample_name"``; an ID without a
    colon (or a missing ID) is never considered agentic.
    """
    agentic_groups = frozenset((
        "agentharm", "aider_full", "cybench", "gaia_runs", "honeypot",
        "smol_gaia_manager", "smol_gaia_search", "swe_gym", "swebench",
        "aider_commit", "aider_manual", "aider_multi", "continue_dev",
        "irl_cursor", "openhands", "smol_real_manager", "smol_real_search",
    ))

    sample_id = sample.get("id", "")
    if ":" not in sample_id:
        return False

    # The group is everything before the first colon.
    group_name = sample_id.split(":", 1)[0]
    return group_name in agentic_groups

def get_conversation_summary(prompts):
    """Count messages per role; messages without a role count as ``"unknown"``."""
    tally = {}
    for message in prompts:
        role_name = message.get("role", "unknown")
        if role_name in tally:
            tally[role_name] += 1
        else:
            tally[role_name] = 1
    return tally

def has_multiple_interactions(prompts):
    """Report whether the message list shows back-and-forth interaction.

    True when there are at least two user messages, or any assistant
    message, or any tool message.
    """
    roles = [message.get("role") for message in prompts]
    if roles.count("user") >= 2:
        return True
    return "assistant" in roles or "tool" in roles

def _analysis_mode(config):
    """Return the mode label recorded in per-step error records."""
    if config['FULL_TRANSCRIPT_ANALYSIS']:
        return "agentic"
    return "first_person" if config['FIRST_PERSON'] else "standard"


def _step_error(exc, step, config=None):
    """Describe a failed analysis step; the mode is included only when config is given."""
    details = {
        "error_type": type(exc).__name__,
        "error_message": str(exc),
        "analysis_step": step,
    }
    if config is not None:
        details["mode"] = _analysis_mode(config)
    return details


def _empty_reasoning_traces_results():
    """Placeholder stored when the reasoning-traces step is skipped or fails."""
    return {
        "reasoning": None,
        "full_output": None,
        "judge_response": None,
        "parsed_judge_response": None,
        "full_conversations": [],
    }


def _empty_binary_results():
    """Placeholder stored when the binary eval-awareness step is skipped or fails."""
    return {
        "responses": [],
        "labels": [],
        "final_decision": "UNKNOWN",
        "reasoning_traces": [],
        "full_conversations": [],
        "assistant_response": None,  # For first-person approach compatibility
    }


def _empty_purpose_results():
    """Placeholder stored when the purpose-awareness step is skipped or fails."""
    return {
        "responses": [],
        "judged_purposes": [],
        "judge_jsons": [],
        "final_decision": "UNKNOWN",
    }


def _empty_logprob_results():
    """Placeholder stored when the logprob step is skipped or fails."""
    return {
        "decision": None,
        "raw_token": None,
        "full_response": None,
        "error": None,
    }


def _run_eval_probability(prompts, sample_id, config):
    """Run the eval-probability step (full-transcript or single-prompt variant)."""
    if not config['PROBABILITY_EVAL_ANALYSIS']:
        return {"eval_probability": None}

    print(f"  Getting eval probability for {sample_id}...")
    try:
        if config['FULL_TRANSCRIPT_ANALYSIS']:
            if config['FIRST_PERSON']:
                probability_eval, reasoning_traces = get_eval_probability_first_person_full_conversation(prompts, config['MODEL'], OPENROUTER_API_KEY)
            else:
                probability_eval, reasoning_traces = get_eval_probability_agentic(prompts, config['MODEL'], OPENROUTER_API_KEY)
        else:
            system_prompt, user_prompt = extract_first_prompts(prompts)
            if config['FIRST_PERSON']:
                probability_eval, reasoning_traces = get_eval_probability_first_person(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY)
            else:
                probability_eval, reasoning_traces = get_eval_probability(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY)
        print(f"  Eval probability for {sample_id}: {probability_eval}")
        return {
            "eval_probability": probability_eval,
            "reasoning_traces": reasoning_traces,
        }
    except Exception as e:
        error_details = _step_error(e, "eval_probability", config)
        print(f"  Error in eval probability analysis for {sample_id}: {error_details}")
        return {
            "eval_probability": None,
            "reasoning_traces": [],
            "error": error_details,
        }


def _run_reasoning_traces(prompts, sample_id, config):
    """Run the reasoning-traces EA step; it only runs in first-person mode."""
    if not (config['REASONING_TRACES_ANALYSIS'] and config['FIRST_PERSON']):
        return _empty_reasoning_traces_results()

    print(f"  Getting reasoning traces EA analysis for {sample_id}...")
    try:
        system_prompt, user_prompt = extract_first_prompts(prompts)
        reasoning, full_output, judge_response, parsed_judge_response, full_conversations = eval_reasoning_ea(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY)
        print(f"  Reasoning traces EA analysis completed for {sample_id}")
        return {
            "reasoning": reasoning,
            "full_output": full_output,
            "judge_response": judge_response,
            "parsed_judge_response": parsed_judge_response,
            "full_conversations": full_conversations,
        }
    except Exception as e:
        error_details = _step_error(e, "reasoning_traces_ea")
        print(f"  Error in reasoning traces EA analysis for {sample_id}: {error_details}")
        return {**_empty_reasoning_traces_results(), "error": error_details}


def _run_binary_eval(prompts, sample_id, config):
    """Run the binary eval-awareness step (full-transcript or single-prompt variant)."""
    if not config['BINARY_EVAL_ANALYSIS']:
        return _empty_binary_results()

    print(f"  Getting binary eval awareness for {sample_id}...")
    try:
        if config['FULL_TRANSCRIPT_ANALYSIS']:
            binary_eval = binary_eval_awareness_agentic(prompts, config['MODEL'], OPENROUTER_API_KEY)
        else:
            system_prompt, user_prompt = extract_first_prompts(prompts)
            if config['FIRST_PERSON']:
                binary_eval = binary_eval_awareness_first_person(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY)
            else:
                binary_eval = binary_eval_awareness(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY)
        print(f"  Binary eval for {sample_id}: {binary_eval.get('final_decision', 'UNKNOWN')}")
        return binary_eval
    except Exception as e:
        error_details = _step_error(e, "binary_eval_awareness", config)
        print(f"  Error in binary eval awareness for {sample_id}: {error_details}")
        return {**_empty_binary_results(), "error": error_details}


def _run_purpose_awareness(prompts, sample_id, config):
    """Run the purpose-awareness step, judged by the module-level JUDGE_MODEL."""
    if not config['PURPOSE_AWARENESS_ANALYSIS']:
        return _empty_purpose_results()

    print(f"  Getting purpose awareness for {sample_id}...")
    try:
        if config['FULL_TRANSCRIPT_ANALYSIS']:
            purpose_awareness_results = get_purpose_awareness_agentic(prompts, config['MODEL'], OPENROUTER_API_KEY, JUDGE_MODEL)
        else:
            system_prompt, user_prompt = extract_first_prompts(prompts)
            if config['FIRST_PERSON']:
                purpose_awareness_results = get_purpose_awareness_first_person(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY, JUDGE_MODEL)
            else:
                purpose_awareness_results = get_purpose_awareness(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY, JUDGE_MODEL)
        print(f"  Purpose awareness for {sample_id}: {purpose_awareness_results.get('final_decision', 'UNKNOWN')}")
        return purpose_awareness_results
    except Exception as e:
        error_details = _step_error(e, "purpose_awareness", config)
        print(f"  Error in purpose awareness for {sample_id}: {error_details}")
        return {**_empty_purpose_results(), "error": error_details}


def _run_logprob_eval(prompts, sample_id, config):
    """Run the logprob step; it is only available in first-person single-prompt mode."""
    if not config['LOGPROB_EVAL_ANALYSIS']:
        return _empty_logprob_results()

    print(f"  Getting logprob eval awareness for {sample_id}...")
    if not (config['FIRST_PERSON'] and not config['FULL_TRANSCRIPT_ANALYSIS']):
        # Requested but not applicable in this mode: record why it was skipped.
        print(f"  Logprob eval for {sample_id}: Skipped (not in first-person single-prompt mode)")
        return {
            **_empty_logprob_results(),
            "error": "Analysis only available in first-person single-prompt mode",
        }

    try:
        system_prompt, user_prompt = extract_first_prompts(prompts)
        logprob_results = logprob_awareness(system_prompt, user_prompt, config['MODEL'], OPENROUTER_API_KEY)
        print(f"  Logprob eval for {sample_id}: {logprob_results.get('decision', 'NULL')} (raw_token: {logprob_results.get('raw_token')})")
        return logprob_results
    except Exception as e:
        error_details = _step_error(e, "logprob_awareness")
        print(f"  Error in logprob analysis for {sample_id}: {error_details}")
        return {**_empty_logprob_results(), "error": error_details}


def process_sample(sample, original_index, config):
    """Run every enabled analysis on one sample and return the combined record.

    Each analysis step (eval probability, reasoning traces, binary eval
    awareness, purpose awareness, logprob awareness) is isolated: a failure in
    one step is recorded under that step's ``"error"`` key and the remaining
    steps still run. A disabled step contributes a placeholder result.

    Args:
        sample: dict describing the sample; ``"prompts"``, ``"id"`` and
            ``"metadata"`` are read when present. The input is not modified.
        original_index: position of the sample in the input list; used as a
            fallback ID and stored under ``"original_index"``.
        config: configuration dict as produced by ``get_config_values``.

    Returns:
        A shallow copy of ``sample`` extended with ``eval_probability_results``,
        ``reasoning_traces_results``, ``binary_test_results``,
        ``purpose_awareness_results``, ``logprob_eval_results``,
        ``original_index`` and ``metadata``. In single-prompt mode ``prompts``
        is cut to its first two entries and
        ``metadata["human_eval_probability"]`` is set to ``None``. If an
        unexpected error escapes the per-step handling, the copy carries
        placeholder results that each embed the error, plus
        ``processing_error``.
    """
    try:
        prompts = sample.get("prompts", [])
        sample_id = sample.get("id", f"index_{original_index}")

        print(f"Starting processing of sample: {sample_id}")

        eval_probability_results = _run_eval_probability(prompts, sample_id, config)
        reasoning_traces_results = _run_reasoning_traces(prompts, sample_id, config)
        binary_eval = _run_binary_eval(prompts, sample_id, config)
        purpose_awareness_results = _run_purpose_awareness(prompts, sample_id, config)
        logprob_results = _run_logprob_eval(prompts, sample_id, config)

        # Combine all outputs with the original sample
        sample_result = dict(sample)  # shallow copy of all original fields

        # If not doing full transcript analysis, only keep the first 2 prompts
        if not config['FULL_TRANSCRIPT_ANALYSIS'] and "prompts" in sample_result:
            sample_result["prompts"] = sample_result["prompts"][:2]
        sample_result["eval_probability_results"] = eval_probability_results
        sample_result["reasoning_traces_results"] = reasoning_traces_results
        sample_result["binary_test_results"] = binary_eval
        sample_result["purpose_awareness_results"] = purpose_awareness_results
        sample_result["logprob_eval_results"] = logprob_results

        sample_result["original_index"] = original_index  # Add original index for sorting

        # dict(sample) shares the nested metadata dict with the caller's sample,
        # so work on a copy: otherwise clearing human_eval_probability below
        # would also clear it in the input sample.
        sample_result["metadata"] = dict(sample_result.get("metadata") or {})

        if not config['FULL_TRANSCRIPT_ANALYSIS']:
            # Set human_eval_probability to null when not doing full transcript analysis
            sample_result["metadata"]["human_eval_probability"] = None
        # In full-transcript mode the copied metadata already carries the
        # original human_eval_probability (when there is one).

        print(f"Completed processing of sample: {sample_id}")
        return sample_result

    except Exception as e:
        import traceback

        failed_id = sample.get('id', f'index_{original_index}')

        # Create detailed error information
        error_details = {
            "error_type": type(e).__name__,
            "error_message": str(e),
            "traceback": traceback.format_exc(),
            "sample_id": failed_id,
            "config_state": {
                "FULL_TRANSCRIPT_ANALYSIS": config['FULL_TRANSCRIPT_ANALYSIS'],
                "REASONING_TRACES_ANALYSIS": config['REASONING_TRACES_ANALYSIS'],
                "PROBABILITY_EVAL_ANALYSIS": config['PROBABILITY_EVAL_ANALYSIS'],
                "BINARY_EVAL_ANALYSIS": config['BINARY_EVAL_ANALYSIS'],
                "PURPOSE_AWARENESS_ANALYSIS": config['PURPOSE_AWARENESS_ANALYSIS'],
                "LOGPROB_EVAL_ANALYSIS": config['LOGPROB_EVAL_ANALYSIS'],
                "FIRST_PERSON": config['FIRST_PERSON']
            },
            "sample_metadata": {
                "has_prompts": "prompts" in sample,
                "prompt_count": len(sample.get("prompts", [])) if "prompts" in sample else 0,
                "is_agentic": is_agentic_sample(sample),
                "conversation_summary": get_conversation_summary(sample.get("prompts", [])),
                "has_multiple_interactions": has_multiple_interactions(sample.get("prompts", []))
            }
        }

        print(f"Error processing sample {failed_id}: {error_details['error_message']}")
        print(f"Error type: {error_details['error_type']}")
        print(f"Full traceback: {error_details['traceback']}")

        # Return a minimal result to avoid breaking the pipeline
        sample_result = dict(sample)
        if config['PROBABILITY_EVAL_ANALYSIS']:
            sample_result["eval_probability_results"] = {
                "eval_probability": None,
                "reasoning_traces": [],
                "error": error_details
            }
        else:
            sample_result["eval_probability_results"] = {
                "eval_probability": None,
                "error": error_details
            }
        sample_result["reasoning_traces_results"] = {**_empty_reasoning_traces_results(), "error": error_details}
        sample_result["binary_test_results"] = {**_empty_binary_results(), "error": error_details}
        sample_result["purpose_awareness_results"] = {**_empty_purpose_results(), "error": error_details}
        # Same keys as every other logprob result (decision/raw_token/full_response/error)
        # so consumers see one schema regardless of how the sample failed.
        sample_result["logprob_eval_results"] = {**_empty_logprob_results(), "error": error_details}
        sample_result["original_index"] = original_index
        sample_result["processing_error"] = error_details
        return sample_result

def writer_worker(config):
    """Drain ``result_queue`` and persist results in batches.

    Intended to run in its own thread. The loop keeps going until
    ``writer_stop_event`` is set and the queue is empty. Buffered results are
    flushed through ``write_results_to_file`` once 10 have accumulated or 30
    seconds have passed since the last flush; the time limit is also checked
    while the queue is idle, so a partial batch is not held back waiting for
    the next result to arrive. Anything still buffered when the loop ends is
    written before returning.

    Args:
        config: configuration dict forwarded to ``write_results_to_file``.
    """
    # Local import: only Queue is imported from the queue module at file level.
    from queue import Empty

    results_buffer = []
    last_write_time = time.time()

    while not writer_stop_event.is_set() or not result_queue.empty():
        try:
            # Bounded wait so the stop event is re-checked about once a second.
            results_buffer.append(result_queue.get(timeout=1.0))
        except Empty:
            # Nothing arrived within the timeout; fall through to the flush check.
            pass

        # Write to file every 10 results or every 30 seconds
        current_time = time.time()
        if results_buffer and (len(results_buffer) >= 10 or (current_time - last_write_time) >= 30):
            write_results_to_file(results_buffer, config)
            results_buffer = []
            last_write_time = current_time

    # Write any remaining results
    if results_buffer:
        write_results_to_file(results_buffer, config)

def write_results_to_file(results_to_write, config):
    """Append a batch of results to the JSON output file.

    The file at ``config['OUTPUT_FILE']`` is loaded if it exists (otherwise a
    fresh ``{"model": ..., "results": []}`` document is started), each result
    is appended after its ``"original_index"`` key has been removed (the
    result dicts are modified in place), and the whole document is rewritten.
    Any exception is reported on stdout and swallowed.
    """
    try:
        output_path = config['OUTPUT_FILE']

        # Start from what is already on disk, if anything.
        if not os.path.exists(output_path):
            document = {"model": config['MODEL'], "results": []}
        else:
            with open(output_path, "r") as existing:
                document = json.load(existing)

        # Append the batch, dropping the bookkeeping index.
        for entry in results_to_write:
            entry.pop("original_index", None)
            document["results"].append(entry)

        # Rewrite the full document.
        with open(output_path, "w") as target:
            json.dump(document, target, ensure_ascii=False, indent=2)

    except Exception as e:
        print(f"Error writing to file: {e}")

def _record_error(label, kind, error_info, issues, error_details):
    """Append one recorded error to ``issues`` (text) and ``error_details`` (structured)."""
    if isinstance(error_info, dict):
        message = f"{label}: {error_info.get('error_type', 'Unknown')} - {error_info.get('error_message', 'No message')}"
        details = error_info
    else:
        message = f"{label}: {error_info}"
        details = {"error_message": str(error_info)}
    error_details.append({"type": kind, "details": details})
    issues.append(message)


def identify_failed_samples(data, config):
    """Identify samples that failed in any analysis method.

    A sample counts as failed when it carries a ``processing_error`` or when
    any analysis that was actually run for this configuration either recorded
    an ``error`` or produced an incomplete result. Analyses that
    ``process_sample`` skips for the given configuration are not checked:
    reasoning traces require ``FIRST_PERSON``, and the logprob analysis
    requires ``FIRST_PERSON`` without ``FULL_TRANSCRIPT_ANALYSIS``.

    Args:
        data: loaded results document; ``data['results']`` is the list of
            per-sample result dicts.
        config: configuration dict as produced by ``get_config_values``.

    Returns:
        A list with one dict per failed sample, holding its ``index`` in
        ``data['results']``, its ``id``, human-readable ``issues``, structured
        ``error_details`` and the ``sample`` itself.
    """
    failed_samples = []

    for i, sample in enumerate(data['results']):
        sample_id = sample.get('id', f'index_{i}')
        issues = []
        error_details = []

        # Check for processing errors
        if 'processing_error' in sample:
            _record_error("Processing error", "processing_error", sample['processing_error'], issues, error_details)

        # Check reasoning traces analysis. process_sample only runs it in
        # first-person mode; without that gate every sample would look
        # incomplete and be retried although the step was never attempted.
        reasoning_results = sample.get('reasoning_traces_results', {})
        if config['REASONING_TRACES_ANALYSIS'] and config['FIRST_PERSON']:
            if reasoning_results.get('error'):
                _record_error("Reasoning traces error", "reasoning_traces_error", reasoning_results['error'], issues, error_details)
            elif not (reasoning_results.get('reasoning') or reasoning_results.get('full_output')) or not reasoning_results.get('judge_response'):
                # Count as processed if we have either reasoning traces OR full_output, plus judge_response
                issues.append("Incomplete reasoning traces analysis (missing required fields)")
                error_details.append({
                    "type": "incomplete_reasoning_traces",
                    "details": {
                        "has_reasoning": bool(reasoning_results.get('reasoning')),
                        "has_full_output": bool(reasoning_results.get('full_output')),
                        "has_judge_response": bool(reasoning_results.get('judge_response'))
                    }
                })

        # Check eval probability analysis
        eval_prob_results = sample.get('eval_probability_results', {})
        if config['PROBABILITY_EVAL_ANALYSIS']:
            if eval_prob_results.get('error'):
                _record_error("Eval probability error", "eval_probability_error", eval_prob_results['error'], issues, error_details)
            elif eval_prob_results.get('eval_probability') is None:
                issues.append("Missing eval probability")
                error_details.append({
                    "type": "missing_eval_probability",
                    "details": {"eval_probability": eval_prob_results.get('eval_probability')}
                })

        # Check binary eval awareness analysis
        binary_results = sample.get('binary_test_results', {})
        if config['BINARY_EVAL_ANALYSIS']:
            if binary_results.get('error'):
                _record_error("Binary eval error", "binary_eval_error", binary_results['error'], issues, error_details)
            elif not binary_results.get('responses') or binary_results.get('final_decision') == 'UNKNOWN':
                issues.append("Incomplete binary eval analysis")
                error_details.append({
                    "type": "incomplete_binary_eval",
                    "details": {
                        "has_responses": bool(binary_results.get('responses')),
                        "final_decision": binary_results.get('final_decision')
                    }
                })

        # Check purpose awareness analysis
        purpose_results = sample.get('purpose_awareness_results', {})
        if config['PURPOSE_AWARENESS_ANALYSIS']:
            if purpose_results.get('error'):
                _record_error("Purpose awareness error", "purpose_awareness_error", purpose_results['error'], issues, error_details)
            elif not purpose_results.get('responses') or purpose_results.get('final_decision') == 'UNKNOWN':
                issues.append("Incomplete purpose awareness analysis")
                error_details.append({
                    "type": "incomplete_purpose_awareness",
                    "details": {
                        "has_responses": bool(purpose_results.get('responses')),
                        "final_decision": purpose_results.get('final_decision')
                    }
                })

        # Check logprob evaluation awareness analysis (treat indecision as failure)
        logprob_results = sample.get('logprob_eval_results', {})
        if config['LOGPROB_EVAL_ANALYSIS'] and config['FIRST_PERSON'] and not config['FULL_TRANSCRIPT_ANALYSIS']:
            if logprob_results.get('error'):
                _record_error("Logprob analysis error", "logprob_analysis_error", logprob_results['error'], issues, error_details)
            elif logprob_results.get('decision') is None:
                issues.append("Logprob indecision (no eval/deploy token)")
                error_details.append({
                    "type": "logprob_indecision",
                    "details": {
                        "decision": logprob_results.get('decision'),
                        "raw_token": logprob_results.get('raw_token')
                    }
                })

        if issues:
            failed_samples.append({
                'index': i,
                'id': sample_id,
                'issues': issues,
                'error_details': error_details,
                'sample': sample
            })

    return failed_samples

def retry_failed_processing(samples, config):
    """Re-run samples whose stored results are flagged as failed.

    Each round reloads config['OUTPUT_FILE'], asks identify_failed_samples which
    stored results failed, deletes those results from the file, and re-submits the
    matching input samples to process_sample on a thread pool. New results are
    handed to the writer thread through result_queue.

    Args:
        samples: The input samples (list of dicts). A failed result is matched
            back to its input by 'id', falling back to the list position when a
            sample has no 'id'.
        config: Run configuration. Reads 'MAX_RETRIES' (optional), 'OUTPUT_FILE'
            and 'MAX_WORKERS'.

    Returns:
        None. Stops early when retries are disabled, the results file cannot be
        read, nothing is left to retry, or the user interrupts a round.
    """
    global writer_thread

    max_retries = config.get('MAX_RETRIES', MAX_RETRIES if 'MAX_RETRIES' in globals() else 0)
    # Treat a missing/None setting the same as "retries disabled".
    if not max_retries or max_retries <= 0:
        return

    # The input list does not change between rounds, so build the lookup once.
    id_to_tuple = {s.get('id', i): (i, s) for i, s in enumerate(samples)}

    for attempt in range(max_retries):
        print("\n" + "="*60)
        print(f"RETRY ROUND {attempt+1} OF {max_retries}")
        print("="*60)
        # Load current results
        try:
            with open(config['OUTPUT_FILE'], 'r') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding errors.
            print(f"Unable to load results for retry: {e}")
            return
        failed_samples = identify_failed_samples(data, config)
        if not failed_samples:
            print("No failed samples left to retry.")
            return
        print(f"Retrying {len(failed_samples)} failed samples...")
        # Remove failed from results to avoid duplicates. Delete from the highest
        # index down so earlier deletions do not shift the remaining positions.
        for idx in sorted((fs['index'] for fs in failed_samples), reverse=True):
            try:
                del data['results'][idx]
            except (IndexError, KeyError, TypeError) as e:
                # Best effort: keep going, but do not hide the problem.
                print(f"Warning: could not remove failed result at index {idx}: {e}")
        with open(config['OUTPUT_FILE'], 'w') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # Map failed ids back to the original input samples
        indices_to_process = []
        index_to_sample = {}
        for fs in failed_samples:
            sid = fs.get('id')
            if sid in id_to_tuple:
                idx, s = id_to_tuple[sid]
                indices_to_process.append(idx)
                index_to_sample[idx] = s
        if not indices_to_process:
            print("No matching samples found in input for failed ids; skipping.")
            continue
        # Start writer thread
        writer_stop_event.clear()
        writer_thread = threading.Thread(target=writer_worker, args=(config,), daemon=True)
        writer_thread.start()
        # Process in parallel
        completed = 0
        interrupted = False
        try:
            with ThreadPoolExecutor(max_workers=config['MAX_WORKERS']) as executor:
                future_to_index = {}
                try:
                    for idx in indices_to_process:
                        future = executor.submit(process_sample, index_to_sample[idx], idx, config)
                        future_to_index[future] = idx
                    for future in as_completed(future_to_index):
                        completed += 1
                        try:
                            # as_completed only yields finished futures, so
                            # result() returns (or raises) immediately.
                            sample_result = future.result()
                            result_queue.put(sample_result)
                            print(f"Retried sample {completed}/{len(indices_to_process)} (id {sample_result.get('id', 'no_id')})")
                        except Exception as e:
                            print(f"Error getting result from retry future: {e}")
                except KeyboardInterrupt:
                    # Drop work that has not started yet; otherwise leaving the
                    # executor would wait for the whole backlog to run.
                    for pending in future_to_index:
                        pending.cancel()
                    raise
        except KeyboardInterrupt:
            interrupted = True
            print("\nRetry processing interrupted by user.")
        finally:
            # Always stop the writer thread, even if the round failed unexpectedly.
            writer_stop_event.set()
            writer_thread.join()
        if interrupted:
            # Do not start further rounds after the user asked to stop.
            break
    print("Finished retry rounds.")

def cleanup_failed_samples(config):
    """Report failed samples and strip them out of the results file.

    Loads the JSON stored at config['OUTPUT_FILE'], uses identify_failed_samples
    to find failed entries, prints a summary (counts per issue and per error
    type, then per-sample details), deletes the failed entries from
    data['results'] and rewrites the file in place. When nothing failed, the
    file is left untouched. Returns None.
    """
    banner = "=" * 60

    def bump(counter, key):
        # Increment the tally for key, starting from zero.
        counter[key] = counter.get(key, 0) + 1

    print("\n" + banner)
    print("CLEANING UP FAILED SAMPLES")
    print(banner)

    # Read whatever is currently stored on disk.
    with open(config['OUTPUT_FILE'], 'r') as results_file:
        data = json.load(results_file)

    failed_samples = identify_failed_samples(data, config)

    if not failed_samples:
        print("✅ No failed samples found! All samples completed successfully.")
        return

    failed_total = len(failed_samples)
    print(f"❌ Found {failed_total} failed samples:")
    print("-" * 60)

    # Tally issues and error types across all failed samples.
    issue_counts = {}
    error_type_counts = {}
    for entry in failed_samples:
        for issue_text in entry['issues']:
            bump(issue_counts, issue_text)
        for detail in entry.get('error_details', []):
            bump(error_type_counts, detail.get('type', 'unknown'))

    print("Issue breakdown:")
    for issue_text, occurrences in sorted(issue_counts.items()):
        print(f"  • {issue_text}: {occurrences} samples")

    print("\nError type breakdown:")
    for kind, occurrences in sorted(error_type_counts.items()):
        print(f"  • {kind}: {occurrences} samples")

    print("\nDetailed failed sample information:")
    verbose_keys = ['traceback', 'config_state']  # too long to print per sample
    for entry in failed_samples:
        print(f"\n  📋 Sample: {entry['id']}")
        print(f"     Index: {entry['index']}")
        print(f"     Issues: {len(entry['issues'])}")
        for number, issue_text in enumerate(entry['issues'], start=1):
            print(f"       {number}. {issue_text}")

        if not entry.get('error_details'):
            continue

        print("     Detailed errors:")
        for number, detail in enumerate(entry['error_details'], start=1):
            kind = detail.get('type', 'unknown')
            details = detail.get('details', {})
            print(f"       {number}. Type: {kind}")

            if kind == 'processing_error' and 'traceback' in details:
                print(f"          Error: {details.get('error_type', 'Unknown')} - {details.get('error_message', 'No message')}")
                if details.get('sample_metadata'):
                    meta = details['sample_metadata']
                    print(f"          Sample info: {meta.get('prompt_count', 0)} prompts, agentic: {meta.get('is_agentic', False)}")
            elif isinstance(details, dict):
                # Other error types: dump every field except the verbose ones.
                for field, value in details.items():
                    if field in verbose_keys:
                        continue
                    print(f"          {field}: {value}")

    print("\n" + banner)

    # Delete from the highest index down so earlier deletions do not shift
    # the positions of entries that still have to be removed.
    for position in sorted((entry['index'] for entry in failed_samples), reverse=True):
        del data['results'][position]

    print(f"🧹 Removing {failed_total} failed samples from main results...")
    with open(config['OUTPUT_FILE'], 'w') as results_file:
        json.dump(data, results_file, ensure_ascii=False, indent=2)

    kept_total = len(data['results'])
    overall_total = kept_total + failed_total
    print(f"✅ Clean results saved to {config['OUTPUT_FILE']}")
    print("📊 Final statistics:")
    print(f"   • Total samples processed: {overall_total}")
    print(f"   • Successful samples: {kept_total}")
    print(f"   • Failed samples: {failed_total}")
    print(f"   • Success rate: {kept_total / overall_total * 100:.1f}%")

    print(banner)

if __name__ == "__main__":
    # Script entry point: load config and input, skip samples already present in
    # the output file, process the rest in parallel, then retry and clean up
    # failed samples.

    # Get configuration from CLI or defaults
    config = get_config_values()

    # The logprob analysis requires first-person, single-prompt mode; fail fast
    # on an incompatible flag combination before doing any work.
    if config['LOGPROB_EVAL_ANALYSIS'] and (not config['FIRST_PERSON'] or config['FULL_TRANSCRIPT_ANALYSIS']):
        print("ERROR: LOGPROB_EVAL_ANALYSIS requires FIRST_PERSON = True and FULL_TRANSCRIPT_ANALYSIS = False.")
        print("Please set FIRST_PERSON = True and FULL_TRANSCRIPT_ANALYSIS = False when using LOGPROB_EVAL_ANALYSIS = True.")
        sys.exit(1)

    with open(config['INPUT_FILE'], "r") as f:
        samples = json.load(f)

    # NOTE(review): `results` is not used in this block; it is kept as a
    # module-level name in case functions elsewhere in the file read it.
    results = []
    completed = 0
    interrupted = False

    # Check for existing output file and load processed samples (resume support).
    # Anything unreadable or of unexpected shape counts as "nothing processed".
    processed_samples = []
    processed_ids = set()
    if os.path.exists(config['OUTPUT_FILE']):
        with open(config['OUTPUT_FILE'], "r") as f:
            try:
                existing_data = json.load(f)
                if isinstance(existing_data, dict) and "results" in existing_data:
                    processed_samples = existing_data["results"]
                    processed_ids = {r.get("id", i) for i, r in enumerate(processed_samples)}
            except Exception as e:
                print(f"Warning: could not read existing results from {config['OUTPUT_FILE']}: {e}")
                processed_samples = []
                processed_ids = set()

    # Prepare a mapping from index to sample for fast lookup
    index_to_sample = {idx: sample for idx, sample in enumerate(samples)}

    # Only process samples whose id is not already in the output file
    indices_to_process = [idx for idx, sample in enumerate(samples) if sample.get("id", idx) not in processed_ids]

    print(f"Found {len(indices_to_process)} samples to process out of {len(samples)} total samples")

    if len(indices_to_process) == 0:
        print("All samples have already been processed. Running cleanup on existing results...")
        cleanup_failed_samples(config)
        sys.exit(0)

    # Write initial output file if it doesn't exist
    if not os.path.exists(config['OUTPUT_FILE']):
        with open(config['OUTPUT_FILE'], "w") as f:
            json.dump({"model": config['MODEL'], "results": []}, f, ensure_ascii=False, indent=2)

    # Start the writer thread
    writer_thread = threading.Thread(target=writer_worker, args=(config,), daemon=True)
    writer_thread.start()

    # Process samples in parallel
    try:
        with ThreadPoolExecutor(max_workers=config['MAX_WORKERS']) as executor:
            future_to_index = {}
            try:
                for idx in indices_to_process:
                    future = executor.submit(process_sample, index_to_sample[idx], idx, config)
                    future_to_index[future] = idx
                for future in as_completed(future_to_index):
                    completed += 1
                    try:
                        # as_completed only yields finished futures, so result()
                        # returns (or raises) immediately; a timeout here would
                        # never trigger.
                        sample_result = future.result()
                        result_queue.put(sample_result)  # Put result into the queue
                        print(f"Processed sample {completed}/{len(indices_to_process)} (id {sample_result.get('id', 'no_id')})")
                    except Exception as e:
                        print(f"Error getting result from future: {e}")
            except KeyboardInterrupt:
                # Drop work that has not started yet; otherwise leaving the
                # executor would wait for the whole backlog to run.
                for pending in future_to_index:
                    pending.cancel()
                raise
    except KeyboardInterrupt:
        interrupted = True
        print("\nProcessing interrupted by user. Cleaning up...")
    finally:
        # Signal the writer thread to stop and wait for it, on every exit path.
        writer_stop_event.set()
        writer_thread.join()

    if interrupted:
        # The user asked to stop: do not launch retry rounds.
        print("Cleanup complete. Skipping retry phase because processing was interrupted.")
    else:
        print(f"All evaluations complete. Results written to {config['OUTPUT_FILE']}")

        # Retry failed samples if any
        try:
            retry_failed_processing(samples, config)
        except Exception as e:
            print(f"Retry phase encountered an error: {e}")

    # Clean up failed samples
    cleanup_failed_samples(config)

