import asyncio
import sys
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional,Tuple
import json
import re
from dataclasses import dataclass
import argparse
import os
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from openai import OpenAI
from tqdm import tqdm
import subprocess
import tempfile
import signal
import threading
import multiprocessing

# Import the existing sandbox for safe code execution.
# NOTE(review): this import is not guarded, so a missing human-eval checkout
# raises at import time; SANDBOX_AVAILABLE is therefore always True whenever
# execution continues past this point.

import sys
sys.path.append("./codegen/human-eval")
from human_eval.test_execution import check_correctness
SANDBOX_AVAILABLE = True



# Put this file's own directory on the import path (presumably so that the
# optional api_monitor module imported below can be found next to it).
sys.path.append(str(Path(__file__).parent))

# Import API monitor if available
try:
    from api_monitor import APIMonitor, MonitoredGPTInterface
    MONITOR_AVAILABLE = True
except ImportError:
    MONITOR_AVAILABLE = False

# Control verbosity with the LOG_LEVEL environment variable (default WARNING;
# names that are not attributes of the logging module also fall back to WARNING).
log_level = os.getenv('LOG_LEVEL', 'WARNING').upper()
logging.basicConfig(level=getattr(logging, log_level, logging.WARNING))
logger = logging.getLogger(__name__)

# Quiet the httpx/openai loggers (OpenAI API requests): only WARNING and above.
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)

@dataclass
class GPTConfig:
    """Settings bundle for GPT chat-completion inference.

    Attributes:
        model_name: Chat model identifier sent with every request.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens per response.
        top_p: Nucleus-sampling probability mass.
        frequency_penalty: Frequency penalty forwarded to the API.
        presence_penalty: Presence penalty forwarded to the API.
        stop: Optional list of stop sequences.
        api_key: Explicit API key; when unset, GPTInterface reads the
            OPENAI_API_KEY environment variable instead.
        max_workers: Size of the thread pool used for batched requests.
    """

    model_name: str = "gpt-4o-mini"
    temperature: float = 0.0
    max_tokens: int = 256
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    stop: Optional[List[str]] = None
    api_key: Optional[str] = None
    max_workers: int = 32
    
class GPTInterface:
    """Interface for GPT chat-completion inference, single or batched.

    Besides issuing requests, the instance accumulates simple usage
    statistics (call count, prompt/completion tokens, wall-clock API time).
    """

    def __init__(self, config: 'GPTConfig', monitor: Optional['APIMonitor'] = None):
        """Store the configuration, reset usage counters and build the client.

        Args:
            config: Model and sampling settings.
            monitor: Optional API monitor; when given, `_generate_single`
                reports every call to it.
        """
        self.config = config
        self.client = None
        self.monitor = monitor
        # API call tracking
        self.api_call_count = 0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_api_time = 0.0
        self._initialize_client()

    def _initialize_client(self):
        """Initialize the OpenAI client from the config key or OPENAI_API_KEY."""
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

    def format_prompt(self, system_prompt: str, user_prompt: str) -> List[Dict[str, str]]:
        """Return the two-message (system, user) chat payload for GPT models."""
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

    def _generate_single(self, prompt: str, system_prompt: str, task_id: Optional[str] = None) -> str:
        """Generate a response for a single prompt and update usage tracking.

        Args:
            prompt: User message.
            system_prompt: System message.
            task_id: Optional identifier forwarded to the monitor.

        Returns:
            The stripped response text; an empty string when the API returns
            no text content.

        Raises:
            Whatever the OpenAI client raises; errors are not caught here.
        """
        start_time = time.time()
        messages = self.format_prompt(system_prompt, prompt)

        response = self.client.chat.completions.create(
            model=self.config.model_name,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty,
            stop=self.config.stop,
            timeout=100,
        )

        # `content` may be None; guard it the same way the batch worker does
        # instead of crashing on `.strip()`.
        response_text = (response.choices[0].message.content or "").strip()
        duration = time.time() - start_time
        # `usage` can be present but None, so hasattr() alone is not enough.
        usage = getattr(response, 'usage', None)

        # Update API call tracking
        self.api_call_count += 1
        self.total_api_time += duration
        if usage is not None:
            self.total_prompt_tokens += usage.prompt_tokens
            self.total_completion_tokens += usage.completion_tokens

        # Record metrics if monitor is available
        if self.monitor:
            monitored_prompt = f"System: {system_prompt}\n\nUser: {prompt}"
            if usage is not None:
                # Prefer the actual token counts reported by the API.
                self.monitor.record_call_with_usage(
                    model=self.config.model_name,
                    prompt=monitored_prompt,
                    response=response_text,
                    prompt_tokens=usage.prompt_tokens,
                    completion_tokens=usage.completion_tokens,
                    duration=duration,
                    task_id=task_id
                )
            else:
                self.monitor.record_call(
                    model=self.config.model_name,
                    prompt=monitored_prompt,
                    response=response_text,
                    duration=duration,
                    task_id=task_id
                )

        return response_text

    def _record_batch_result(self, responses: List[Optional[str]], idx: int, result: Any) -> None:
        """Store one worker result at `idx` and fold its usage into the counters."""
        if isinstance(result, tuple):
            response_text, usage_stats = result
            responses[idx] = response_text
            # Update API tracking
            self.api_call_count += 1
            self.total_prompt_tokens += usage_stats.get('prompt_tokens', 0)
            self.total_completion_tokens += usage_stats.get('completion_tokens', 0)
            self.total_api_time += usage_stats.get('api_time', 0)
        else:
            # Backward compatibility: worker returned only the text
            responses[idx] = result
            self.api_call_count += 1

    def generate_batch(self, prompts: List[str], system_prompts: Optional[List[str]] = None, show_progress: bool = True) -> List[str]:
        """Generate responses for a batch of prompts using ThreadPoolExecutor.

        Args:
            prompts: User prompts, one request each.
            system_prompts: System prompt per request; defaults to a generic
                assistant prompt for every request.
            show_progress: Whether to display a tqdm progress bar.

        Returns:
            Responses in the same order as `prompts`. A request whose result
            could not be retrieved yields an empty string.
        """
        if system_prompts is None:
            system_prompts = ["You are a helpful assistant."] * len(prompts)

        responses = [None] * len(prompts)
        if not prompts:
            # Nothing to submit: skip creating a pool and a progress bar.
            return responses

        # Plain-dict view of the config handed to the module-level worker
        config_dict = {
            'model_name': self.config.model_name,
            'temperature': self.config.temperature,
            'max_tokens': self.config.max_tokens,
            'top_p': self.config.top_p,
            'frequency_penalty': self.config.frequency_penalty,
            'presence_penalty': self.config.presence_penalty,
            'stop': self.config.stop
        }
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        log = logging.getLogger(__name__)

        # Use ThreadPoolExecutor for parallel API calls
        with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
            future_to_idx = {
                executor.submit(_gpt_single_request_worker, prompt, sys_prompt, config_dict, api_key): i
                for i, (prompt, sys_prompt) in enumerate(zip(prompts, system_prompts))
            }

            # A disabled tqdm bar is a no-op, so one loop serves both modes.
            with tqdm(total=len(prompts), desc="API calls", unit="req",
                      disable=not show_progress) as pbar:
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        self._record_batch_result(responses, idx, future.result())
                    except Exception as exc:
                        # Best effort: keep the batch going, but leave a trace.
                        log.warning("Batch request %d failed: %s", idx, exc)
                        responses[idx] = ""
                    pbar.update(1)

        return responses
    
def _gpt_single_request_worker(prompt: str, system_prompt: str, config_dict: dict, api_key: str) -> Tuple[str, Dict[str, Any]]:
    """Worker function for parallel GPT API calls.

    Defined at module level so it can be submitted to an executor without
    capturing a GPTInterface instance. Each call builds its own client.

    Args:
        prompt: User message.
        system_prompt: System message.
        config_dict: Sampling settings; `model_name` is required, the other
            keys fall back to the defaults below. An optional `timeout`
            (seconds) overrides the 100 s request timeout.
        api_key: API key passed to the OpenAI client.

    Returns:
        (response_text, usage_stats) where usage_stats has the keys
        `prompt_tokens`, `completion_tokens` and `api_time`. On any error the
        failure is printed and `("", zeroed stats)` is returned.
    """
    try:
        from openai import OpenAI

        start_time = time.time()

        # Each worker call gets its own client instance
        client = OpenAI(api_key=api_key)

        # Format messages
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # Make API call
        response = client.chat.completions.create(
            model=config_dict['model_name'],
            messages=messages,
            temperature=config_dict.get('temperature', 0.1),
            max_tokens=config_dict.get('max_tokens', 1000),
            top_p=config_dict.get('top_p', 0.95),
            frequency_penalty=config_dict.get('frequency_penalty', 0.0),
            presence_penalty=config_dict.get('presence_penalty', 0.0),
            stop=config_dict.get('stop', None),
            # Same 100 s cap as GPTInterface._generate_single, so a stalled
            # request cannot block a pool thread for the client default.
            timeout=config_dict.get('timeout', 100),
        )

        response_text = response.choices[0].message.content or ""

        # Collect usage statistics. `usage` may exist but be None; reading
        # attributes off it unguarded would raise and discard the response.
        usage = getattr(response, 'usage', None)
        usage_stats = {
            'prompt_tokens': getattr(usage, 'prompt_tokens', 0) or 0,
            'completion_tokens': getattr(usage, 'completion_tokens', 0) or 0,
            'api_time': time.time() - start_time
        }

        return response_text, usage_stats

    except Exception as e:
        print(f"GPT worker error: {e}")
        return "", {'prompt_tokens': 0, 'completion_tokens': 0, 'api_time': 0}

# Tail of the generated program run through check_correctness: failures are
# re-raised (presumably so the sandbox reports them as a failed run).
_SANDBOX_ASSERT_TAIL = '''    # If we reach here, assertion passed
    pass
except AssertionError as e:
    raise AssertionError(f"Oracle assertion failed: {e}")
except Exception as e:
    raise Exception(f"Execution error in oracle validation: {e}")
'''

# Tail of the generated program run by the subprocess fallback: the outcome is
# reported through stdout markers and the exit code.
_SUBPROCESS_ASSERT_TAIL = '''    print("ASSERTION_PASSED")
except AssertionError as e:
    print(f"ASSERTION_FAILED: {e}")
    sys.exit(1)
except Exception as e:
    print(f"EXECUTION_ERROR: {e}")
    sys.exit(2)
'''


def _oracle_program_prelude(*extra_modules: str) -> str:
    """Return the import header placed at the top of a generated validation program."""
    header = [
        "from typing import List, Dict, Any, Optional, Union, Tuple, Set, Callable",
        "import math",
        "import re",
        "import sys",
        "import functools",
        "import itertools",
        "import collections",
    ]
    header.extend(f"import {module}" for module in extra_modules)
    return "\n".join(header)


def _indent_oracle_for_try_block(oracle_assertion: str) -> str:
    """Return the oracle text indented so it can sit inside a `try:` body.

    The historical form (four spaces in front of the raw text) is kept
    whenever it is syntactically valid, so previously working oracles produce
    the same program. Only when that form does not compile (typically a
    multi-line oracle whose later lines start at column 0) is every line
    re-indented.
    """
    import textwrap

    legacy = "    " + oracle_assertion
    normalized = textwrap.indent(textwrap.dedent(oracle_assertion).strip("\n"), "    ")
    for candidate in (legacy, normalized):
        probe = "try:\n" + candidate + "\n    pass\nexcept Exception:\n    pass\n"
        try:
            compile(probe, "<oracle>", "exec")
        except (SyntaxError, ValueError):
            continue
        return candidate
    # Neither form parses: keep the historical text so the reported error is unchanged.
    return legacy


def _build_assertion_program(function_code: str, oracle_assertion: str,
                             sandboxed: bool, instantiate_solution: bool) -> str:
    """Assemble the program that defines the function and runs one assertion."""
    if sandboxed:
        extras = () if instantiate_solution else ("json",)
        parts = [_oracle_program_prelude(*extras), "", function_code, ""]
        try_comment = "# Test the assertion"
        tail = _SANDBOX_ASSERT_TAIL
    else:
        extras = ("traceback",) if instantiate_solution else ("traceback", "json")
        parts = [_oracle_program_prelude(*extras), "",
                 "# Function implementation", function_code, ""]
        try_comment = "# Test assertion"
        tail = _SUBPROCESS_ASSERT_TAIL

    if instantiate_solution:
        parts += ["# Instantiate Solution class for TestEval", "solution = Solution()", ""]

    parts += [try_comment, "try:", _indent_oracle_for_try_block(oracle_assertion), tail]
    return "\n".join(parts)


def _run_python_program(program: str, timeout: float,
                        stdin_text: Optional[str] = None) -> "subprocess.CompletedProcess":
    """Write `program` to a temp file, run it with the current interpreter, and clean up.

    subprocess.run kills the child when the timeout expires before raising
    TimeoutExpired, so no orphan process is left behind.
    """
    # Python source defaults to UTF-8, so write the file that way regardless of locale.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False,
                                     encoding='utf-8') as handle:
        handle.write(program)
        script_path = handle.name
    try:
        return subprocess.run(
            [sys.executable, script_path],
            input=stdin_text,
            capture_output=True,
            text=True,
            timeout=timeout
        )
    finally:
        try:
            os.unlink(script_path)
        except OSError:
            pass


def _validate_livecodebench_case(index: int, oracle_assertion: str, function_code: str,
                                 timeout: float, sandboxed: bool) -> Dict[str, Any]:
    """Run a LiveCodeBench JSON test case: feed `input` on stdin, compare stdout to `output`."""
    # The two call sites historically differ only by this word in their messages.
    mode = "" if sandboxed else "subprocess "
    try:
        test_case = json.loads(oracle_assertion.strip())
        test_input = test_case.get('input', '')
        expected_output = test_case.get('output', '')
        if test_input is None:
            # Never let the child inherit (and block on) the parent's stdin.
            test_input = ''

        # The function code is expected to handle stdin/stdout by itself.
        extras = ("json",) if sandboxed else ("traceback", "json")
        program = _oracle_program_prelude(*extras) + "\n\n" + function_code + "\n"

        completed = _run_python_program(program, timeout, stdin_text=test_input)
        stdout, stderr = completed.stdout, completed.stderr

        # Any stderr output counts as a failure, even with exit code 0.
        if completed.returncode != 0 or stderr.strip():
            if sandboxed:
                print(f"DEBUG: Validation failed for oracle {index}")
                print(f"  Test input: {test_input}")
                print(f"  Expected output: {expected_output}")
                print(f"  Actual stdout: {stdout}")
                print(f"  Stderr: {stderr}")
                print(f"  Return code: {completed.returncode}")
            return {
                'valid': False,
                'error': f"Direct {mode}execution failed: {stderr}",
                'output': stdout,
                'confidence': 0.85
            }

        actual_output = stdout.strip()
        expected_clean = expected_output.strip()
        if actual_output == expected_clean:
            return {
                'valid': True,
                'error': None,
                'output': f'Direct {mode}execution passed: {actual_output}',
                'confidence': 0.95
            }

        if sandboxed:
            print(f"DEBUG: Output mismatch for oracle {index}")
            print(f"  Test input: {test_input}")
            print(f"  Expected output: '{expected_clean}'")
            print(f"  Actual output: '{actual_output}'")
        return {
            'valid': False,
            'error': f"Output mismatch: expected '{expected_clean}', got '{actual_output}'",
            'output': actual_output,
            'confidence': 0.85
        }
    except Exception as e:
        # Covers malformed JSON, timeouts and any execution problem.
        return {
            'valid': False,
            'error': f"LiveCodeBench {mode}validation error: {str(e)}",
            'output': '',
            'confidence': 0.1
        }


def _validate_single_oracle_worker(index: int, oracle_assertion, function_code: str, 
                                   task_description: str, timeout: float, sandbox_available: bool,
                                   dataset_name: str = None) -> Dict[str, Any]:
    """Worker function for parallel oracle validation.

    Defined at module level so it can be handed to an executor.

    Args:
        index: Position of the oracle; used in task ids and debug output.
        oracle_assertion: A string assertion (most datasets), a JSON string
            test case, or a dict with `input`/`output` (LiveCodeBench).
        function_code: Implementation under test; dedented before use.
        task_description: Used as the `prompt` of the sandbox problem.
        timeout: Execution timeout in seconds.
        sandbox_available: When true, assertions run through
            check_correctness; otherwise through a plain subprocess.
        dataset_name: Dataset identifier; 'testeval' and 'livecodebench'
            receive special handling.

    Returns:
        A result dict. Dict oracles and non-string oracles yield
        `index`/`oracle`/`valid`/`error`; every other path yields
        `valid`/`error`/`output`/`confidence`. Exceptions are converted into
        an invalid result rather than raised.
    """
    import textwrap
    try:
        # Handle LiveCodeBench dict format
        if dataset_name == 'livecodebench' and isinstance(oracle_assertion, dict):
            # Only structural validation here: both input and output must be
            # present. Real validation would require running the code.
            is_valid = bool(oracle_assertion.get('input') and oracle_assertion.get('output'))
            return {
                'index': index,
                'oracle': oracle_assertion,
                'valid': is_valid,
                'error': None if is_valid else "Missing input or output"
            }

        if not isinstance(oracle_assertion, str):
            return {
                'index': index,
                'oracle': oracle_assertion,
                'valid': False,
                'error': f"Invalid assertion type: {type(oracle_assertion)}"
            }

        # Ensure function_code is properly dedented
        function_code = textwrap.dedent(function_code).strip()

        # LiveCodeBench JSON test cases are executed directly via stdin/stdout
        if dataset_name == 'livecodebench' and oracle_assertion.strip().startswith('{'):
            return _validate_livecodebench_case(
                index, oracle_assertion, function_code, timeout, sandboxed=sandbox_available
            )

        # TestEval solutions are methods on a Solution class that must be instantiated
        instantiate_solution = dataset_name == 'testeval' and 'class Solution' in function_code
        program = _build_assertion_program(
            function_code, oracle_assertion,
            sandboxed=sandbox_available, instantiate_solution=instantiate_solution
        )

        if sandbox_available:
            # Create a problem dict for check_correctness
            problem = {
                'task_id': f'oracle_validation_{index}',
                'prompt': task_description,
                'test': '',  # We include everything in the completion
                'entry_point': 'main'  # Not used since we're validating assertions
            }

            # Use check_correctness to safely execute the test
            result = check_correctness(problem, program, timeout)

            if result['passed']:
                return {
                    'valid': True,
                    'error': None,
                    'output': 'Assertion passed in secure sandbox',
                    'confidence': 0.95
                }
            error_msg = result.get('result', 'Unknown error')
            print(f"Sandbox validation failed: {error_msg}")
            return {
                'valid': False,
                'error': f"Sandbox execution failed: {error_msg}",
                'output': error_msg,
                'confidence': 0.85
            }

        # Fallback to subprocess validation
        completed = _run_python_program(program, timeout)
        output = completed.stdout.strip()
        error_output = completed.stderr.strip()

        if completed.returncode == 0 and "ASSERTION_PASSED" in output:
            return {
                'valid': True,
                'error': None,
                'output': output,
                'confidence': 0.85
            }
        if "ASSERTION_FAILED" in output:
            return {
                'valid': False,
                'error': f"Assertion failed: {output}",
                'output': output,
                'confidence': 0.80
            }
        return {
            'valid': False,
            'error': f"Execution error: {error_output or output}",
            'output': output,
            'confidence': 0.75
        }

    except Exception as e:
        return {
            'valid': False,
            'error': f'Worker validation error: {str(e)}',
            'output': '',
            'confidence': 0.60
        }



class CodeExecutionSandboxAgent:
    """Code Execution & Sandbox Agent (CESA) for empirical oracle validation using secure sandbox."""

    def __init__(self, timeout: float = 5.0, max_workers: Optional[int] = None):
        """Initialize CESA with execution timeout and parallel processing.

        Args:
            timeout: Execution timeout in seconds
            max_workers: Maximum number of parallel workers for validation
                (None or 0 = twice the CPU count, capped at 64)
        """
        self.timeout = timeout
        self.max_workers = max_workers or min(64, multiprocessing.cpu_count() * 2)
        self.sandbox_available = SANDBOX_AVAILABLE

    def validate_single_oracle(self, oracle: Any, function_code: str, 
                               task_description: str, dataset_name: str = None) -> Dict[str, Any]:
        """Validate a single oracle assertion.

        Args:
            oracle: Oracle assertion to validate. Normally a string; the
                validation worker also accepts a dict test case for
                LiveCodeBench, so non-string values are passed through.
            function_code: Function implementation code
            task_description: Task description for context
            dataset_name: Name of the dataset (e.g., 'testeval', 'humaneval')

        Returns:
            Validation result dict with 'valid', 'error', etc. Empty or
            whitespace-only oracles are rejected without running anything.
        """
        # Only strings can be "blank"; calling .strip() on a dict oracle would
        # raise AttributeError before the worker ever sees it.
        is_blank = not oracle or (isinstance(oracle, str) and not oracle.strip())
        if is_blank:
            return {
                'valid': False,
                'error': 'Empty oracle assertion',
                'output': '',
                'confidence': 0.0
            }

        # Use the existing validation worker function (index 0: single oracle)
        return _validate_single_oracle_worker(
            0, oracle, function_code, task_description, self.timeout, self.sandbox_available, dataset_name
        )

class SelfRefineAgent:
    """Self-Refine Agent (SRA) for debugging and fixing failed test oracles."""
    
    def __init__(self, llm_interface: GPTInterface, cesa: 'CodeExecutionSandboxAgent', max_iterations: int = 3):
        """Initialize SRA with language model interface and sandbox agent.
        
        Args:
            llm_interface: LLM interface for generating fixes
            cesa: Code execution sandbox agent for validation
            max_iterations: Maximum refinement iterations per oracle
        """
        self.llm = llm_interface
        self.cesa = cesa
        self.max_iterations = max_iterations
    
    async def refine_multiple_tasks_parallel(self, 
                                            tasks_with_failures: List[Dict[str, Any]]) -> Dict[int, List[str]]:
        """Refine failed oracles for multiple tasks in parallel across all iterations.
        
        This method implements true parallel processing inspired by NexusOracle's approach:
        - All tasks are processed in parallel within each iteration
        - All LLM calls within an iteration are batched
        - All validations within an iteration are done in parallel
        
        Args:
            tasks_with_failures: List of dicts containing:
                - task_idx: Task index
                - task_data: Task information  
                - failed_oracles: List of (oracle, error) tuples
                - candidate_code: Function implementation
                - passed_oracles: List of passing oracles (optional)
                
        Returns:
            Dict mapping task_idx to list of refined oracles
        """
        if not tasks_with_failures:
            return {}
        
        print(f"\n🚀 SelfRefineAgent: Starting PARALLEL refinement for {len(tasks_with_failures)} tasks")
        print(f"   Max iterations: {self.max_iterations}")
        
        # Initialize tracking for all tasks
        task_oracle_status = {}  # task_idx -> list of oracle status dicts
        active_tasks = {}  # task_idx -> current task data
        
        for task_info in tasks_with_failures:
            task_idx = task_info['task_idx']
            failed_oracles = task_info['failed_oracles']
            
            # Initialize status for each oracle in this task
            task_oracle_status[task_idx] = []
            for i, (oracle, error) in enumerate(failed_oracles):
                task_oracle_status[task_idx].append({
                    'original_oracle': oracle,
                    'current_oracle': oracle,
                    'original_error': error,
                    'current_error': error,
                    'fixed': False,
                    'fixed_oracle': None
                })
            
            active_tasks[task_idx] = task_info
        
        # Perform iterative refinement with full parallelization
        for iteration in range(self.max_iterations):
            print(f"\n📍 Iteration {iteration + 1}/{self.max_iterations}")
            
            # Collect all tasks that still need refinement
            tasks_to_refine = []
            for task_idx, task_info in active_tasks.items():
                # Check if this task has unfixed oracles
                unfixed_count = sum(1 for status in task_oracle_status[task_idx] if not status['fixed'])
                if unfixed_count > 0:
                    tasks_to_refine.append((task_idx, task_info, unfixed_count))
            
            if not tasks_to_refine:
                print(f"✅ All oracles fixed after {iteration} iterations!")
                break
            
            print(f"   Processing {len(tasks_to_refine)} tasks with {sum(count for _, _, count in tasks_to_refine)} total failed oracles")
            
            # Step 1: Generate fixes for ALL tasks in parallel using batch LLM calls
            all_prompts = []
            all_system_prompts = []
            prompt_task_mapping = []  # Maps prompt index to (task_idx, oracle_indices)
            
            for task_idx, task_info, _ in tasks_to_refine:
                # Collect currently failed oracles for this task
                current_failed = []
                oracle_indices = []
                
                for i, status in enumerate(task_oracle_status[task_idx]):
                    if not status['fixed']:
                        current_failed.append((status['current_oracle'], status['current_error']))
                        oracle_indices.append(i)
                
                if current_failed:
                    # Create refinement prompt for this task
                    prompt = self._create_combined_fix_prompt_with_iteration(
                        task_info['task_data'].get('task_description', ''),
                        task_info['task_data'].get('function_name', 'func'),
                        current_failed,
                        task_info['candidate_code'],
                        task_info.get('passed_oracles'),
                        iteration,
                        task_info['task_data'].get('dataset_name', '')
                    )
                    
                    dataset_name = task_info['task_data'].get('dataset_name', '')

                    system_prompt = """You are an expert debugger specializing in test oracle correction. 
Analyze ALL failed assertions together, understand the errors, and provide precise fixes for each.
Always provide syntactically correct Python assertions."""
                    
                    all_prompts.append(prompt)
                    all_system_prompts.append(system_prompt)
                    prompt_task_mapping.append((task_idx, oracle_indices))
            
            # Batch LLM call for all tasks
            print(f"   🤖 Sending {len(all_prompts)} refinement requests to LLM in batch...")
            refine_responses = self.llm.generate_batch(all_prompts, all_system_prompts, show_progress=False)
            
            # Step 2: Parse all responses and prepare validation tasks
            all_validation_tasks = []
            validation_task_mapping = []  # Maps validation index to (task_idx, oracle_idx, refined_oracle)
            
            for (prompt_idx, response) in enumerate(refine_responses):
                task_idx, oracle_indices = prompt_task_mapping[prompt_idx]
                task_info = active_tasks[task_idx]
                
                # Parse refined oracles based on dataset format
                dataset_name = task_info['task_data'].get('dataset_name', '')
                refined_oracles = self._parse_multiple_assertions(response, dataset_name)
                
                # Fallback parsing if needed
                if len(refined_oracles) != len(oracle_indices):
                    refined_oracles = self._parse_multiple_assertions_fallback(response, len(oracle_indices), dataset_name)
                
                # Prepare validation for each refined oracle
                for i, oracle_idx in enumerate(oracle_indices):
                    if i < len(refined_oracles):
                        refined_oracle = refined_oracles[i]
                        dataset_name = task_info['task_data'].get('dataset_name')
                        
                        # Add to validation queue
                        validation_idx = len(all_validation_tasks)
                        all_validation_tasks.append((
                            validation_idx,
                            refined_oracle,
                            task_info['candidate_code'],
                            task_info['task_data'].get('task_description', ''),
                            self.cesa.timeout,
                            self.cesa.sandbox_available,
                            dataset_name
                        ))
                        validation_task_mapping.append((task_idx, oracle_idx, refined_oracle))
            
            # Step 3: Validate ALL refined oracles in parallel using ProcessPoolExecutor
            print(f"   🔍 Validating {len(all_validation_tasks)} refined oracles in parallel...")
            validation_results = {}
            
            if all_validation_tasks:
                try:
                    with ProcessPoolExecutor(max_workers=self.cesa.max_workers) as executor:
                        future_to_idx = {}
                        for task in all_validation_tasks:
                            future = executor.submit(_validate_single_oracle_worker, *task)
                            future_to_idx[future] = task[0]
                        
                        # Collect results
                        for future in as_completed(future_to_idx):
                            val_idx = future_to_idx[future]
                            try:
                                result = future.result(timeout=self.cesa.timeout + 2)
                                validation_results[val_idx] = result
                            except Exception as e:
                                validation_results[val_idx] = {
                                    'valid': False,
                                    'error': f'Validation error: {str(e)}',
                                    'output': '',
                                    'confidence': 0.0
                                }
                except Exception as e:
                    print(f"   ⚠️ Error in parallel validation: {e}")
                    validation_results = {}
            
            # Step 4: Update oracle status based on validation results
            fixed_count = 0
            still_failed_count = 0
            
            for val_idx, (task_idx, oracle_idx, refined_oracle) in enumerate(validation_task_mapping):
                if val_idx in validation_results:
                    val_result = validation_results[val_idx]
                    
                    if val_result['valid']:
                        # Mark as fixed
                        task_oracle_status[task_idx][oracle_idx]['fixed'] = True
                        task_oracle_status[task_idx][oracle_idx]['fixed_oracle'] = refined_oracle
                        fixed_count += 1
                    else:
                        # Update error for next iteration
                        task_oracle_status[task_idx][oracle_idx]['current_oracle'] = refined_oracle
                        task_oracle_status[task_idx][oracle_idx]['current_error'] = val_result.get('error', 'Validation failed')
                        still_failed_count += 1
            
            print(f"   📊 Iteration {iteration + 1} results: ✅ Fixed {fixed_count}, ❌ Still failed {still_failed_count}")
        
        # Build final results
        refined_results = {}
        total_fixed = 0
        total_failed = 0
        
        for task_idx in task_oracle_status:
            refined_oracles = []
            for status in task_oracle_status[task_idx]:
                if status['fixed']:
                    refined_oracles.append(status['fixed_oracle'])
                    total_fixed += 1
                else:
                    # Keep original oracle without comment
                    refined_oracles.append(status['original_oracle'])
                    total_failed += 1
            refined_results[task_idx] = refined_oracles
        
        print(f"\n🎯 SelfRefineAgent Complete:")
        print(f"   ✅ Total fixed: {total_fixed}")
        print(f"   ❌ Total failed: {total_failed}")
        print(f"   📈 Success rate: {total_fixed/(total_fixed + total_failed)*100:.1f}%")
        
        return refined_results
    
    async def refine_failed_oracles_batch(self, task_data: Dict[str, Any], 
                                         failed_oracles: List[Tuple[str, str]], 
                                         candidate_code: str,
                                         passed_oracles: List[str] = None,
                                         regenerate_code_callback=None) -> Tuple[List[str], str]:
        """Refine multiple failed oracles through iterative refinement.
        
        Args:
            task_data: Task information
            failed_oracles: List of (oracle, error) tuples
            candidate_code: Function implementation for testing
            passed_oracles: List of oracles that passed validation (for context)
            regenerate_code_callback: Optional callback to regenerate code
            
        Returns:
            Tuple of (List of refined oracles, potentially updated candidate_code)
        """
        if not failed_oracles:
            return [], candidate_code
        
        task_description = task_data.get('task_description', '')
        function_name = task_data.get('function_name', 'func')
        dataset_name = task_data.get('dataset_name', None)
        
        # Initialize with original failed oracles
        current_failed_oracles = failed_oracles.copy()
        oracle_status = {}  # Track status of each oracle by index
        for i in range(len(failed_oracles)):
            oracle_status[i] = {'fixed': False, 'oracle': failed_oracles[i][0], 'fixed_oracle': None}
        
        # Track if we regenerated code
        current_candidate_code = candidate_code
        code_regenerated = False
        
        # Perform iterative refinement
        for iteration in range(self.max_iterations):
            if not current_failed_oracles:
                break  # All oracles fixed
                
            print(f"  🔧 Self-Refine: Iteration {iteration + 1}/{self.max_iterations} - Fixing {len(current_failed_oracles)} failed oracles")
            
            # Create prompt for current iteration (use current_candidate_code which may have been regenerated)
            combined_prompt = self._create_combined_fix_prompt_with_iteration(
                task_description, function_name, current_failed_oracles, 
                current_candidate_code, passed_oracles, iteration, dataset_name
            )
            

            system_prompt = """You are an expert debugger specializing in test oracle correction. 
Analyze ALL failed assertions together, understand the errors, and provide precise fixes for each.
Always provide syntactically correct Python assertions."""
            
            # API call to fix oracles
            response = self.llm._generate_single(combined_prompt, system_prompt)
            
            # Parse fixed oracles from response
            fixed_oracles_raw = self._parse_multiple_assertions(response, dataset_name)
            
            # If parsing fails or returns wrong number, try alternative parsing
            if len(fixed_oracles_raw) != len(current_failed_oracles):
                print(f"    ⚠️ Expected {len(current_failed_oracles)} fixes, got {len(fixed_oracles_raw)}. Attempting re-parse...")
                fixed_oracles_raw = self._parse_multiple_assertions_fallback(response, len(current_failed_oracles), dataset_name)
            
            # Validate and collect results for this iteration
            still_failed_oracles = []
            iteration_fixed_count = 0
            
            # Map current failed oracles to their original indices
            current_to_original_idx = {}
            idx_counter = 0
            for orig_idx, status in oracle_status.items():
                if not status['fixed']:
                    current_to_original_idx[idx_counter] = orig_idx
                    idx_counter += 1
            
            for i, (original_oracle, error) in enumerate(current_failed_oracles):
                orig_idx = current_to_original_idx.get(i)
                
                if i < len(fixed_oracles_raw):
                    fixed_oracle = fixed_oracles_raw[i]
                    
                    # Validate the fixed oracle (use current_candidate_code which may have been regenerated)
                    validation_result = self.cesa.validate_single_oracle(
                        fixed_oracle, current_candidate_code, task_description, dataset_name
                    )
                    
                    if validation_result['valid']:
                        # Mark as fixed in oracle_status
                        if orig_idx is not None:
                            oracle_status[orig_idx]['fixed'] = True
                            oracle_status[orig_idx]['fixed_oracle'] = fixed_oracle
                        iteration_fixed_count += 1
                        print(f"    ✅ Fixed: {fixed_oracle[:60]}...")
                    else:
                        # Still failed, add to next iteration
                        still_failed_oracles.append((original_oracle, validation_result.get('error', 'Validation failed')))
                        print(f"    ❌ Still failing: {fixed_oracle[:60]}... Error: {validation_result.get('error', 'Unknown')[:50]}")
                else:
                    # No fix generated for this oracle
                    still_failed_oracles.append((original_oracle, error))
                    print(f"    ⚠️ No fix generated for: {original_oracle[:60]}...")
            
            print(f"  📊 Iteration {iteration + 1} Result: Fixed {iteration_fixed_count}/{len(current_failed_oracles)} oracles")
            
            # NEW: After second iteration, analyze if errors might be from code, not oracles
            if iteration == 1 and still_failed_oracles and not code_regenerated:
                # Analyze error patterns
                total_oracles = len(failed_oracles)
                failed_count = len(still_failed_oracles)
                failed_ratio = failed_count / total_oracles if total_oracles > 0 else 0
                
                # Count how many failed oracles contain "== None"
                none_count = sum(1 for oracle, _ in still_failed_oracles if '== None' in oracle)
                none_ratio = none_count / failed_count if failed_count > 0 else 0
                
                print(f"\n  📈 Error Analysis after iteration 2:")
                print(f"    - Failed oracles: {failed_count}/{total_oracles} ({failed_ratio*100:.1f}%)")
                print(f"    - Contains '== None': {none_count}/{failed_count} ({none_ratio*100:.1f}%)")
                
                # Decision logic: regenerate code if >70% failures AND <50% contain "== None"
                if failed_ratio > 0.7 and none_ratio < 0.5:
                    print(f"  🔄 High failure rate ({failed_ratio*100:.1f}%) suggests code might be incorrect.")
                    print(f"     Regenerating both code and failed oracles...")
                    
                    # Regenerate code using callback if provided
                    if regenerate_code_callback:
                        new_code = await regenerate_code_callback(task_data)
                        if new_code and new_code != current_candidate_code:
                            current_candidate_code = new_code
                            code_regenerated = True
                            print(f"  ✅ Successfully regenerated code implementation")
                            
                            # Update candidate_code for validation
                            candidate_code = current_candidate_code
                        else:
                            print(f"  ⚠️ Failed to regenerate code, continuing with oracle refinement only")
                    else:
                        print(f"  ⚠️ No code regeneration callback provided, continuing with oracle refinement only")
                else:
                    print(f"  📝 Error pattern suggests oracle issues, continuing refinement...")
                print("")
            
            # Update for next iteration
            current_failed_oracles = still_failed_oracles
            
            # Early termination if all fixed
            if not current_failed_oracles:
                print(f"  🎉 All oracles successfully fixed after {iteration + 1} iteration(s)!")
                break
        
        # Build final result list maintaining original order
        refined_oracles = []
        total_fixed = 0
        
        # Build result list using oracle_status
        for idx in range(len(failed_oracles)):
            if oracle_status[idx]['fixed']:
                refined_oracles.append(oracle_status[idx]['fixed_oracle'])
                total_fixed += 1
            else:
                # Oracle couldn't be fixed after all iterations
                refined_oracles.append(f"{oracle_status[idx]['oracle']}  # FAILED: Could not fix after {self.max_iterations} iterations")
        print(f"  🔧 Self-Refine Complete: Fixed {total_fixed}/{len(failed_oracles)} oracles after {min(iteration + 1, self.max_iterations)} iteration(s)")
        return refined_oracles, current_candidate_code
    
    def _create_combined_fix_prompt(self, task_description: str, function_name: str,
                                   failed_oracles: List[Tuple[str, str]], 
                                   candidate_code: str, passed_oracles: List[str]) -> str:
        """Create a single prompt to fix all failed oracles at once."""
        
        # Format passed examples for context
        passed_examples = ""
        if passed_oracles and len(passed_oracles) > 0:
            # examples = passed_oracles[:min(3, len(passed_oracles))]
            examples = passed_oracles
            passed_examples = "\n".join([f"✓ {oracle}" for oracle in examples])
            passed_examples = f"\n\nExamples of CORRECT oracles that passed validation:\n{passed_examples}"
        
        # Format all failed oracles with their errors
        failed_list = []
        for i, (oracle, error) in enumerate(failed_oracles, 1):
            # Handle both dict (LiveCodeBench) and string (other datasets) formats
            if isinstance(oracle, dict):
                oracle_str = json.dumps(oracle)
            else:
                oracle_str = str(oracle)
            failed_list.append(f"{i}. Failed Oracle: {oracle_str}")
            failed_list.append(f"   Error: {error}")
            failed_list.append("")
        failed_oracles_formatted = "\n".join(failed_list)
        
        return f"""Fix ALL the following failed test oracles for the given function.

Task Description:
{task_description}

Function Implementation Being Tested:
```python
{candidate_code}
```
{passed_examples}

ALL Failed Oracles to Fix:
{failed_oracles_formatted}

Analyze why each oracle failed and provide corrected versions for ALL of them.
Consider:
1. The expected output type and format based on the function implementation
2. Edge cases and boundary conditions
3. The error messages for each failed oracle
4. Common patterns from the passing oracles

IMPORTANT: Provide EXACTLY {len(failed_oracles)} corrected assertions, one for each failed oracle above.
Format your response as a numbered list of assertions:
1. assert ...
2. assert ...
etc.

Provide ONLY the corrected assertions:"""
    
    def _create_combined_fix_prompt_with_iteration(self, task_description: str, function_name: str,
                                                  failed_oracles: List[Tuple[str, str]], 
                                                  candidate_code: str, passed_oracles: List[str],
                                                  iteration: int, dataset_name: str = '') -> str:
        """Create a single prompt to fix all failed oracles with iteration context."""
        
        # Format passed examples for context
        passed_examples = ""
        if passed_oracles and len(passed_oracles) > 0:
            examples = passed_oracles
            formatted_examples = []
            for oracle in examples:
                # Handle both dict (LiveCodeBench) and string (other datasets) formats
                if isinstance(oracle, dict):
                    oracle_str = json.dumps(oracle)
                else:
                    oracle_str = str(oracle)
                formatted_examples.append(f"✓ {oracle_str}")
            passed_examples = "\n".join(formatted_examples)
            passed_examples = f"\n\nExamples of CORRECT oracles that passed validation:\n{passed_examples}"
        
        # Add iteration context
        iteration_note = ""
        if iteration > 0:
            iteration_note = f"\n\n⚠️ IMPORTANT: This is refinement iteration {iteration + 1}/{self.max_iterations}. The previous fixes still failed validation.\nPlease try different approaches or check for subtle issues in the assertions."
        
        # Format all failed oracles with their errors
        failed_list = []
        for i, (oracle, error) in enumerate(failed_oracles, 1):
            # Handle both dict (LiveCodeBench) and string (other datasets) formats
            if isinstance(oracle, dict):
                oracle_str = json.dumps(oracle)
            else:
                oracle_str = str(oracle)
            failed_list.append(f"{i}. Failed Oracle: {oracle_str}")
            failed_list.append(f"   Error: {error}")
            failed_list.append("")
        failed_oracles_formatted = "\n".join(failed_list)
        
        return f"""Fix ALL the following failed test oracles for the given function.{iteration_note}

Task Description:
{task_description}

Function Implementation Being Tested:
```python
{candidate_code}
```
{passed_examples}

ALL Failed Oracles to Fix:
{failed_oracles_formatted}

Analyze why each oracle failed and provide corrected versions for ALL of them.
Consider:
1. The expected output type and format based on the function implementation
2. Edge cases and boundary conditions
3. The error messages for each failed oracle
4. Common patterns from the passing oracles
{f"5. Why the previous iteration's fixes didn't work (iteration {iteration + 1})" if iteration > 0 else ""}

Output your final {len(failed_oracles)} test assertions in this exact format:

```python
assert function_name(input1) == expected_output1
assert function_name(input2) == expected_output2
```
Each assertion on one line, no additional code."""
    
    
    def _parse_multiple_assertions(self, response: str, dataset_name: str = '') -> List:
        """Parse multiple assertions from LLM response.
        Returns List[Dict] for LiveCodeBench, List[str] for other datasets.
        """
        if not response:
            return []
        
        if dataset_name == 'livecodebench':
            # Parse JSON test cases and return as dicts
            test_cases = []
            lines = response.strip().split('\n')
            
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                
                try:
                    # Try to parse as JSON
                    if line.startswith('{') and line.endswith('}'):
                        test_case = json.loads(line)
                        if isinstance(test_case, dict) and 'input' in test_case and 'output' in test_case:
                            test_cases.append(test_case)  # Return dict, not JSON string
                except json.JSONDecodeError:
                    continue
            
            return test_cases
        else:
            # Parse assert statements
            assertions = []
            lines = response.strip().split('\n')
            
            for line in lines:
                line = line.strip()
                # Remove numbering if present (e.g., "1. assert ..." -> "assert ...")
                if '. assert ' in line:
                    line = 'assert ' + line.split('. assert ', 1)[1]
                elif line.startswith('assert '):
                    pass  # Already in correct format
                else:
                    continue  # Skip non-assertion lines
                
                if line.startswith('assert '):
                    assertions.append(line)
            
            return assertions
    
    def _parse_multiple_assertions_fallback(self, response: str, expected_count: int, dataset_name: str = '') -> List:
        """Fallback parser when primary parsing fails.
        Returns List[Dict] for LiveCodeBench, List[str] for other datasets.
        """
        if not response:
            return []
        
        if dataset_name == 'livecodebench':
            # Try to extract JSON test cases and return as dicts
            test_cases = []
            
            # Split by common separators
            parts = response.replace('\n\n', '\n').split('\n')
            
            for part in parts:
                part = part.strip()
                if '{' in part and '}' in part:
                    # Try to find JSON objects in the text
                    start = part.find('{')
                    end = part.rfind('}') + 1
                    if start >= 0 and end > start:
                        json_candidate = part[start:end]
                        try:
                            test_case = json.loads(json_candidate)
                            if isinstance(test_case, dict) and 'input' in test_case and 'output' in test_case:
                                test_cases.append(test_case)  # Return dict, not JSON string
                        except json.JSONDecodeError:
                            continue
            
            # If we still don't have enough, generate placeholder test cases
            while len(test_cases) < expected_count:
                test_cases.append({"input": "", "output": ""})
            
            return test_cases[:expected_count]
        else:
            # Try to extract any assert statements, even if not numbered properly
            assertions = []
            
            # Split by common separators
            parts = response.replace('\n\n', '\n').split('\n')
            
            for part in parts:
                part = part.strip()
                # Try to find assert statements in various formats
                if 'assert' in part.lower():
                    # Extract the assertion part
                    if 'assert ' in part:
                        start = part.find('assert ')
                        assertion = part[start:]
                        # Clean up any trailing comments or explanations
                        if '#' in assertion:
                            assertion = assertion.split('#')[0].strip()
                        if assertion.startswith('assert '):
                            assertions.append(assertion)
            
            # If we still don't have enough, generate placeholder assertions
            while len(assertions) < expected_count:
                assertions.append("assert True  # Could not parse fix from response")
            
            return assertions[:expected_count]
    
    def _create_fix_prompt(self, task_description: str, function_name: str,
                          failed_oracle: str, error: str, 
                          candidate_code: str, passed_oracles: List[str],
                          iteration: int) -> str:
        """Create prompt for fixing a failed oracle."""
        
        passed_examples = ""
        if passed_oracles and len(passed_oracles) > 0:
            # Show up to 3 examples of passing oracles
            examples = passed_oracles[:min(3, len(passed_oracles))]
            passed_examples = "\n".join([f"✓ {oracle}" for oracle in examples])
            passed_examples = f"\n\nExamples of CORRECT oracles that passed:\n{passed_examples}"
        
        iteration_note = ""
        if iteration > 0:
            iteration_note = f"\n\nNote: This is refinement attempt {iteration + 1}. The previous fix still failed."
        
        return f"""Fix the following failed test oracle.

Task Description:
{task_description}

Function Implementation Being Tested:
```python
{candidate_code}
```

Failed Oracle:
{failed_oracle}

Error Message:
{error}{passed_examples}{iteration_note}

Analyze why the oracle failed and provide a CORRECTED version.
Consider:
1. The expected output type and format
2. The actual function behavior based on the implementation
3. Edge cases and boundary conditions
4. Common assertion patterns that work

Provide ONLY the corrected assertion statement:"""
    
    def _extract_assertion(self, response: str) -> str:
        """Extract assertion from LLM response."""
        if not response:
            return None
            
        lines = response.strip().split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith('assert '):
                return line
        
        # If no assert found, try to construct one from the response
        response_stripped = response.strip()
        if response_stripped and not response_stripped.startswith('assert '):
            # Only add 'assert' if it looks like an assertion expression
            if '==' in response_stripped or '!=' in response_stripped or '<' in response_stripped or '>' in response_stripped:
                return f"assert {response_stripped}"
        
        return None

class NexusOraclePredictor:
    """Oracle Nexus with empirical validation, adversarial testing, and adaptive routing."""
    
    def __init__(self, gpt_config=None, num_panelists: int = 4, llm_interface=None, self_refine_iterations: int = 3):
        """Set up configuration, the LLM interface and the helper agents.

        Args:
            gpt_config: GPT configuration; a default ``GPTConfig()`` is used
                when this is falsy.
            num_panelists: Number of panelist prompts issued per task.
            llm_interface: Pre-initialized LLM interface to reuse; when
                ``None`` a new ``GPTInterface`` is built from the configuration.
            self_refine_iterations: Maximum number of iterations handed to
                the self-refine agent (default: 3).
        """
        config = gpt_config or GPTConfig()
        self.gpt_config = config

        # Only build a fresh interface when the caller did not supply one.
        if llm_interface is None:
            llm_interface = GPTInterface(config)
        self.llm = llm_interface

        self.num_panelists = num_panelists
        self.self_refine_iterations = self_refine_iterations

        # Sandbox agent shares the worker budget configured for the LLM client.
        sandbox_agent = CodeExecutionSandboxAgent(timeout=5.0, max_workers=config.max_workers)
        self.cesa = sandbox_agent
        # Self-refine agent drives the LLM and re-checks through the sandbox.
        self.sra = SelfRefineAgent(self.llm, sandbox_agent, max_iterations=self_refine_iterations)

        # Batch self-refinement of failed oracles is switched on by default.
        self.enable_batch_self_refine = True
    
    async def predict_assertions_for_tasks_batch(self, tasks_data: List[Dict[str, Any]]) -> List[List[str]]:
        """Predict oracles for a batch of tasks.

        Every task is sent through the complete pipeline implemented by
        ``_process_thorough_path``; its result is returned unchanged.
        """
        return await self._process_thorough_path(tasks_data)
    
    async def _process_thorough_path(self, tasks_data: List[Dict[str, Any]]) -> List[List[str]]:
        """Full pipeline processing for all tasks."""
        
        # Generate initial tentative oracles
        tentative_oracles = await self._generate_tentative_oracles(tasks_data)
        
        # Phase 1: Requirement Engineering
        requirements = await self._run_requirement_engineer(tasks_data)
        
        # Phase 2: Panelist Analysis
        panelist_outputs = await self._run_panelists(tasks_data, tentative_oracles, requirements)
        
        # Phase 3: Interpreter Summarization
        interpreted_outputs = await self._run_interpreters(panelist_outputs, tasks_data, tentative_oracles)
        
        # Phase 4: Initial Curator Judgment
        initial_assertions = await self._run_curator(interpreted_outputs, tasks_data, tentative_oracles)
        
        # Phase 5: Validation with Self-Refinement
        validated_assertions = await self._run_empirical_validation(tasks_data, initial_assertions)
        
        return validated_assertions
    
    # Reuse existing methods from v1 with same signatures
    async def _generate_tentative_oracles(self, tasks_data: List[Dict[str, Any]]) -> List[List[str]]:
        """Generate simple initial oracles (reused from v1)."""
        all_prompts = []
        all_system_prompts = []
        
        for task_data in tasks_data:
            task_description = task_data['task_description']
            test_inputs = task_data['test_inputs']
            dataset_name = task_data.get('dataset_name', '')
            
            # Create appropriate prompt based on dataset
            if dataset_name == 'livecodebench':
                # LiveCodeBench format: JSON test cases with input/output pairs
                test_inputs_formatted = "\n".join([f"{i+1}. {test_input}" for i, test_input in enumerate(test_inputs)])
                
                # Create expected format examples
                test_examples = "\n".join([f'{{"input": "{test_input}", "output": "<expected_output_{i+1}>"}}' 
                                           for i, test_input in enumerate(test_inputs)])
                
                tentative_prompt = f"""You are a test oracle that generates expected outputs for code functions.

Task Description:
{task_description}

Generate expected outputs for ALL the following test inputs:
{test_inputs_formatted}

You MUST generate EXACTLY {len(test_inputs)} test cases, one for each input above.

Return test cases in JSON format within a Python code block, one per line:

```python
{test_examples}
```

Generate the test cases now:"""
                system_prompt = "You are generating initial test oracles. Be quick but may not be perfect. Generate ONLY the Python code block with JSON test cases."
            else:
                # Original format for other datasets
                test_inputs_formatted = "\n".join([f"{i+1}. {test_input}" for i, test_input in enumerate(test_inputs)])

                tentative_prompt = f"""Generate test assertions for the following function:

Task Description:
{task_description}

Test inputs:
{test_inputs_formatted}

You MUST provide EXACTLY {len(test_inputs)} assertions:"""
                system_prompt = "You are generating initial test oracles. Be quick but may not be perfect."
            
            all_prompts.append(tentative_prompt)
            all_system_prompts.append(system_prompt)
        
        responses = self.llm.generate_batch(all_prompts, all_system_prompts)
        all_tentative_oracles = []
        for response, task_data in zip(responses, tasks_data):
            test_inputs = task_data['test_inputs']
            function_name = task_data.get('function_name', 'func')
            dataset_name = task_data.get('dataset_name', '')
            tentative_assertions = self._parse_multiple_assertions(response, test_inputs, function_name, dataset_name)
            all_tentative_oracles.append(tentative_assertions)
        
        return all_tentative_oracles
    
    async def _run_requirement_engineer(self, tasks_data: List[Dict[str, Any]]) -> List[str]:
        """Phase 1: Extract formal requirements (reused from v1)."""
        prompts = []
        system_prompts = []
        
        for task_data in tasks_data:
            task_description = task_data['task_description']
            
            req_prompt = f"""Analyze the following function specification and extract key functional requirements:
{task_description}
"""
            
            prompts.append(req_prompt)
            system_prompts.append("You are an expert software engineer and requirement analyst. Extract requirements and generate clear specifications in predicate logic when possible.")

        return self.llm.generate_batch(prompts, system_prompts)
    
    async def _run_panelists(self, tasks_data: List[Dict[str, Any]], 
                            tentative_oracles: List[List[str]], 
                            requirements: List[str]) -> List[List[str]]:
        """Phase 2: Multiple panelists analyze tentative oracles (reused from v1)."""
        all_prompts = []
        all_system_prompts = []
        prompt_mapping = []
        
        for task_idx, task_data in enumerate(tasks_data):
            task_description = task_data['task_description']
            test_inputs = task_data['test_inputs']
            dataset_name = task_data.get('dataset_name', '')
            tentative_formatted = "\n".join([f"{i+1}. {oracle}" for i, oracle in enumerate(tentative_oracles[task_idx])])
            test_inputs_formatted = "\n".join([f"{i+1}. {test_input}" for i, test_input in enumerate(test_inputs)])
            
            for panelist_idx in range(self.num_panelists):
                # Each panelist has a slightly different perspective (matching NexusOraclePredictor)
                panelist_roles = [
                    ("SPECIFICATION EXPERT", "Focus on adherence to documented specifications and requirements"),
                    ("EDGE CASE SPECIALIST", "Focus on boundary conditions, corner cases, and error scenarios"),
                    ("FUNCTIONAL VALIDATOR", "Focus on core functionality and expected input-output relationships"),
                    ("ALGORITHMIC ORACLE", "Focus on step-by-step algorithm execution and correctness")
                ]
                
                role_name, role_focus = panelist_roles[panelist_idx % len(panelist_roles)]
                

                format_instruction = f"""Output your final {len(test_inputs)} test assertions in this exact format:

```python
assert function_name(input1) == expected_output1
assert function_name(input2) == expected_output2
```
Each assertion on one line, no additional code."""
                system_message = "You are a Senior Software Engineer specializing in testing. Analyze the test oracles generated by test generation tools, which may not be perfect. Identify incorrect test oracles and provide corrected versions based on the requirements."
            
                panelist_prompt = f"""You are analyzing test oracles generated by an automated tool. Your role: {role_name} - {role_focus}

Description: 
{task_description}

Requirements: 
{requirements[task_idx]}

Test Cases and Current Oracles:
{tentative_formatted}

Test Inputs:
{test_inputs_formatted}

{format_instruction}
"""
                
                all_prompts.append(panelist_prompt)
                all_system_prompts.append(system_message)
                prompt_mapping.append((task_idx, panelist_idx))
        
        responses = self.llm.generate_batch(all_prompts, all_system_prompts)
        panelist_outputs = [['' for _ in range(self.num_panelists)] for _ in tasks_data]
        for response, (task_idx, panelist_idx) in zip(responses, prompt_mapping):
            panelist_outputs[task_idx][panelist_idx] = response
            
        return panelist_outputs
    
    async def _run_interpreters(self, panelist_outputs: List[List[str]], 
                               tasks_data: List[Dict[str, Any]], 
                               tentative_oracles: List[List[str]]) -> List[List[str]]:
        """Phase 3: Interpreters summarize panelist discussions (reused from v1)."""
        all_prompts = []
        all_system_prompts = []
        prompt_mapping = []
        
        for task_idx, panelist_responses in enumerate(panelist_outputs):
            for panelist_idx, panelist_output in enumerate(panelist_responses):
                
                test_code_formatted = "\n".join([f"{i+1}. {oracle}" for i, oracle in enumerate(tentative_oracles[task_idx])])
                
                interpreter_prompt = f"""A software tester has analyzed test oracles and provided detailed thoughts. The tester is excellent but lacks confidence and tends to overthink. Your job is to extract the key insights and conclusions.

Tester's Thoughts:
{panelist_output}

Test Code Being Analyzed:
{test_code_formatted}
"""

                all_prompts.append(interpreter_prompt)
                all_system_prompts.append("You work with an excellent software tester who is trying to check if test cases have correct test oracles. The tester always gets correct oracles but lacks confidence and overthinks. Your task is to summarize the tester's thoughts and extract the correct oracles from the analysis.")
                prompt_mapping.append((task_idx, panelist_idx))
        
        responses = self.llm.generate_batch(all_prompts, all_system_prompts)
        interpreted_outputs = [['' for _ in range(self.num_panelists)] for _ in tasks_data]
        for response, (task_idx, panelist_idx) in zip(responses, prompt_mapping):
            interpreted_outputs[task_idx][panelist_idx] = response
            
        return interpreted_outputs
    
    async def _run_curator(self, interpreted_outputs: List[List[str]], 
                          tasks_data: List[Dict[str, Any]], 
                          tentative_oracles: List[List[str]]) -> List[List[str]]:
        """Phase 4: Curator makes initial judgment (reused from v1)."""
        curator_prompts = []
        system_prompts = []
        
        for task_idx, task_data in enumerate(tasks_data):
            task_description = task_data['task_description']
            test_inputs = task_data['test_inputs']
            
            panel_discussion = ""
            for panelist_idx, interpreted_summary in enumerate(interpreted_outputs[task_idx]):
                panel_discussion += f"\n--- Team Member {panelist_idx + 1} Analysis ---\n{interpreted_summary}\n"
            
            test_inputs_formatted = "\n".join([f"{i+1}. {test_input}" for i, test_input in enumerate(test_inputs)])
            tentative_formatted = "\n".join([f"{i+1}. {oracle}" for i, oracle in enumerate(tentative_oracles[task_idx])])
            
            dataset_name = task_data.get('dataset_name', '')
            

            format_instruction = f"""Output your final {len(test_inputs)} test assertions in this exact format:

```python
assert function_name(input1) == expected_output1
assert function_name(input2) == expected_output2
```
Each assertion on one line, no additional code."""
            system_message = "You are a senior software engineer managing a team of software testers. Your team has analyzed test cases and provided analysis reports. Summarize their analysis and provide final judgment on the test oracles, making corrections where necessary."
        
            curator_prompt = f"""You are managing a team of software testers analyzing test cases generated by a competitor. Three team members have analyzed the test oracles and provided their reports.

Task Description:
{task_description}

Current Test Oracles:
{tentative_formatted}

Test Inputs:
{test_inputs_formatted}

Team Members' Analysis Reports:
{panel_discussion}

{format_instruction}"""

            curator_prompts.append(curator_prompt)
            system_prompts.append(system_message)

        responses = self.llm.generate_batch(curator_prompts, system_prompts)
        all_final_assertions = []
        for task_idx, (response, task_data) in enumerate(zip(responses, tasks_data)):
            test_inputs = task_data['test_inputs']
            function_name = task_data.get('function_name', 'func')
            dataset_name = task_data.get('dataset_name', '')
            final_assertions = self._parse_multiple_assertions(response, test_inputs, function_name, dataset_name)
            all_final_assertions.append(final_assertions)
        
        return all_final_assertions

    async def _run_empirical_validation(self, tasks_data: List[Dict[str, Any]], 
                                       assertions: List[List[str]]) -> List[List[str]]:
        """NEW Phase 5: Empirical validation using CESA with GLOBAL parallel processing."""
        import time
        phase5_start = time.time()
        print(f"Phase 5: Starting OPTIMIZED empirical validation at {time.strftime('%H:%M:%S')}")
        print(f"Phase 5: Processing {len(tasks_data)} tasks with {sum(len(a) for a in assertions)} total assertions")
        
        # Step 1: Generate candidate implementations for all tasks (in parallel)
        step1_start = time.time()
        print("Phase 5 Step 1: Generating candidate implementations for empirical validation")
        candidate_codes = await self._generate_candidate_implementations_batch(tasks_data)
        step1_time = time.time() - step1_start
        print(f"Phase 5 Step 1: Completed in {step1_time:.2f}s - Generated {len([c for c in candidate_codes if c])} valid implementations")
        
        # Step 2: GLOBAL parallel validation - collect ALL validation tasks first
        step2_start = time.time()
        print(f"Phase 5 Step 2: Starting GLOBAL parallel validation at {time.strftime('%H:%M:%S')}")
        print(f"Phase 5: Using ProcessPoolExecutor with {self.cesa.max_workers} workers for CPU-intensive validation")
        
        # Collect all validation tasks across all tasks
        all_validation_tasks = []
        task_mapping = []  # Maps (global_idx) -> (task_idx, assertion_idx)
        global_idx = 0
        
        for task_idx, (task_data, task_assertions, candidate_code) in enumerate(zip(tasks_data, assertions, candidate_codes)):
            if not candidate_code:
                print(f"Phase 5: Task {task_idx+1} - No code generated, skipping {len(task_assertions)} assertions")
                continue
            
            dataset_name = task_data.get('dataset_name', None)
            for assertion_idx, assertion in enumerate(task_assertions):
                # Handle both dict (LiveCodeBench) and string (other datasets) formats
                if dataset_name == 'livecodebench' and isinstance(assertion, dict):
                    # For LiveCodeBench, skip validation if no output
                    if assertion and assertion.get('output'):
                        # Add to global validation queue (pass the dict)
                        all_validation_tasks.append((
                            global_idx, assertion, candidate_code, 
                            task_data['task_description'], self.cesa.timeout, 
                            self.cesa.sandbox_available, dataset_name
                        ))
                        task_mapping.append((task_idx, assertion_idx))
                        global_idx += 1
                elif isinstance(assertion, str) and assertion and assertion.strip():
                    # For other datasets with string assertions
                    all_validation_tasks.append((
                        global_idx, assertion, candidate_code, 
                        task_data['task_description'], self.cesa.timeout, 
                        self.cesa.sandbox_available, dataset_name
                    ))
                    task_mapping.append((task_idx, assertion_idx))
                    global_idx += 1
        
        print(f"Phase 5: Collected {len(all_validation_tasks)} total validation tasks from {len(tasks_data)} tasks")
        
        # Execute ALL validations in parallel using ProcessPoolExecutor
        validation_results = {}
        if all_validation_tasks:
            try:
                print(f"Phase 5: Starting ProcessPoolExecutor with {self.cesa.max_workers} workers for {len(all_validation_tasks)} validations")
                with ProcessPoolExecutor(max_workers=self.cesa.max_workers) as executor:
                    # Submit all validation tasks at once
                    future_to_idx = {}
                    for task in all_validation_tasks:
                        future = executor.submit(_validate_single_oracle_worker, *task)
                        future_to_idx[future] = task[0]  # global_idx
                    
                    # Collect results with progress bar
                    from tqdm import tqdm
                    completed = 0
                    with tqdm(total=len(all_validation_tasks), desc="Validating oracles globally") as pbar:
                        for future in as_completed(future_to_idx):
                            global_idx = future_to_idx[future]
                            try:
                                result = future.result(timeout=self.cesa.timeout + 2)
                                validation_results[global_idx] = result
                            except Exception as e:
                                validation_results[global_idx] = {
                                    'valid': False,
                                    'error': f'Validation error: {str(e)}',
                                    'output': '',
                                    'confidence': 0.0
                                }
                            completed += 1
                            pbar.update(1)
                            if completed % 100 == 0:
                                print(f"Phase 5: Completed {completed}/{len(all_validation_tasks)} validations")
                
                print(f"Phase 5: All {len(all_validation_tasks)} validations completed")
            except Exception as e:
                print(f"Phase 5: Error in global parallel validation: {e}")
                # Fallback to empty results
                validation_results = {}
        
        # Step 3: Reorganize results by task and collect failed oracles
        validated_assertions = []
        all_failed_oracles = []  # Batch collection of failed oracles
        
        for task_idx, (task_data, task_assertions, candidate_code) in enumerate(zip(tasks_data, assertions, candidate_codes)):
            if not candidate_code:
                # No code for this task, keep original assertions
                validated_assertions.append(task_assertions)
                continue
            
            # Process this task's validation results
            refined_assertions = []
            valid_count = 0
            failed_count = 0
            task_failed_oracles = []
            
            for assertion_idx, assertion in enumerate(task_assertions):
                # Find the global index for this assertion
                global_idx_for_assertion = None
                for idx, (t_idx, a_idx) in enumerate(task_mapping):
                    if t_idx == task_idx and a_idx == assertion_idx:
                        global_idx_for_assertion = idx
                        break
                
                if global_idx_for_assertion is not None and global_idx_for_assertion in validation_results:
                    validation = validation_results[global_idx_for_assertion]
                    if validation['valid']:
                        refined_assertions.append(assertion)
                        valid_count += 1
                    else:
                        # Collect failed oracle for batch processing later
                        failed_count += 1
                        task_failed_oracles.append({
                            'task_idx': task_idx,
                            'assertion': assertion,
                            'error': validation.get('error', 'Unknown error'),
                            'candidate_code': candidate_code,
                            'task_data': task_data
                        })
                        # Still keep it for now (will be refined in batch later)
                        refined_assertions.append(assertion)
                else:
                    # No validation result, keep original
                    refined_assertions.append(assertion)
            
            # Print task summary
            success_rate = (valid_count / len(task_assertions) * 100) if task_assertions else 0
            print(f"Phase 5: Task {task_idx+1}/{len(tasks_data)} - ✅ {valid_count}/{len(task_assertions)} passed ({success_rate:.1f}%)")
            
            validated_assertions.append(refined_assertions)
            if task_failed_oracles:
                all_failed_oracles.extend(task_failed_oracles)
        
        # Step 4: Batch process all failed oracles using SelfRefineAgent
        if all_failed_oracles:
            print(f"\nPhase 5: Collected {len(all_failed_oracles)} failed oracles")
            
            # Self-refine is enabled by default as it's the core contribution
            if self.enable_batch_self_refine and self.sra:
                print(f"Phase 5: Starting BATCH self-refine for {len(all_failed_oracles)} failed oracles using SelfRefineAgent")
                refine_start = time.time()
                
                # Group failed oracles by task for efficient batch processing
                failed_by_task = {}
                for failed_oracle in all_failed_oracles:
                    task_idx = failed_oracle['task_idx']
                    if task_idx not in failed_by_task:
                        failed_by_task[task_idx] = []
                    failed_by_task[task_idx].append(failed_oracle)
                
                # Process each task's failed oracles through SelfRefineAgent
                refined_oracles_by_task = {}
                
                # Prepare tasks for parallel processing with new method
                tasks_with_failures = []
                
                for task_idx, task_failed_oracles in failed_by_task.items():
                    if len(task_failed_oracles) > 0:
                        # Extract data for this task
                        task_data = task_failed_oracles[0]['task_data']
                        candidate_code = task_failed_oracles[0]['candidate_code']
                        failed_assertions = [(fo['assertion'], fo['error']) for fo in task_failed_oracles]
                        
                        # Get passed oracles for context
                        passed_oracles = []
                        for assertion_idx, assertion in enumerate(assertions[task_idx]):
                            # Check if this oracle passed
                            was_failed = any(
                                fo['assertion'] == assertion and fo['task_idx'] == task_idx
                                for fo in all_failed_oracles
                            )
                            if not was_failed:
                                passed_oracles.append(assertion)
                        
                        tasks_with_failures.append({
                            'task_idx': task_idx,
                            'task_data': task_data,
                            'failed_oracles': failed_assertions,
                            'candidate_code': candidate_code,
                            'passed_oracles': passed_oracles if passed_oracles else None
                        })
                
                # Use the new parallel refinement method
                print(f"Phase 5: Using SelfRefineAgent's PARALLEL refinement for {len(tasks_with_failures)} tasks")
                refined_oracles_by_task = await self.sra.refine_multiple_tasks_parallel(tasks_with_failures)
                
                # Process results
                total_refined = sum(len(oracles) for oracles in refined_oracles_by_task.values())
                
                # Update validated_assertions with refined oracles
                for task_idx, refined in refined_oracles_by_task.items():
                    if task_idx < len(validated_assertions) and refined:
                        # Replace failed oracles with refined ones
                        original = validated_assertions[task_idx]
                        # Keep passed oracles and replace failed ones
                        final_assertions = []
                        refined_idx = 0
                        
                        for orig_assertion in original:
                            # Check if this was a failed oracle
                            was_failed = any(
                                fo['assertion'] == orig_assertion and fo['task_idx'] == task_idx
                                for fo in all_failed_oracles
                            )
                            if was_failed and refined_idx < len(refined):
                                final_assertions.append(refined[refined_idx])
                                refined_idx += 1
                            else:
                                final_assertions.append(orig_assertion)
                        
                        validated_assertions[task_idx] = final_assertions
                        
        return validated_assertions
    
    async def _generate_candidate_implementations_batch(self, tasks_data: List[Dict[str, Any]]) -> List[str]:
        """Generate candidate implementations for multiple tasks in parallel."""
        all_prompts = []
        all_system_prompts = []
        
        for task_data in tasks_data:
            task_description = task_data['task_description']
            function_name = task_data.get('function_name', 'func')
            test_inputs = task_data['test_inputs']
            dataset_name = task_data.get('dataset_name', '')
            
            # Format test inputs as examples
            test_examples = "\n".join([f"# Example: {inp}" for inp in test_inputs[:3]])
            
            # Special handling for LiveCodeBench dataset
            if dataset_name == 'livecodebench':
                code_prompt = f"""Generate a complete Python program that reads from stdin and writes to stdout.

Task Description:
{task_description}

Example Test Cases:
{test_examples}

Requirements:
1. Read input from stdin
2. Process the input according to the task description
3. Write the output to stdout
4. Handle all test cases correctly
5. Use standard input/output (no function definitions needed)

Provide ONLY the complete Python program (no explanations):"""
                system_prompt = "You are an expert Python programmer. Generate a complete stdin/stdout program for competitive programming problems."
            # Special handling for TestEval dataset
            elif dataset_name == 'testeval':
                code_prompt = f"""Generate a Python Solution class implementation based on the following specification:

Task Description:
{task_description}

Example Usage:
{test_examples}

Requirements:
1. MUST implement a class named 'Solution' (not a function)
2. The Solution class should contain the required method(s)
3. Follow the exact method signature from the examples
4. Handle edge cases appropriately
5. Return ONLY the Solution class implementation, no additional code

Provide ONLY the Python Solution class implementation:"""
                system_prompt = "You are an expert Python programmer. Generate a complete Solution class implementation for TestEval problems."
            else:
                # Standard function generation for other datasets
                code_prompt = f"""Generate a Python function implementation based on the following specification:

Task Description:
{task_description}

Function Name: {function_name}

Example Usage:
{test_examples}

Requirements:
1. Implement ONLY the function, no additional code
2. Function should be complete and runnable
3. Handle edge cases appropriately
4. Follow the exact function signature from the examples

Provide ONLY the Python function implementation:"""
                system_prompt = "You are an expert Python programmer. Generate clean, correct, and efficient function implementations based on specifications."

            all_prompts.append(code_prompt)
            all_system_prompts.append(system_prompt)
        
        # Generate all implementations in parallel using the existing LLM batch processing
        llm_start = time.time()
        print(f"Phase 5: Generating {len(tasks_data)} candidate implementations in parallel")
        responses = self.llm.generate_batch(all_prompts, all_system_prompts)
        llm_time = time.time() - llm_start
        print(f"Phase 5: LLM batch generation completed in {llm_time:.2f}s")
        
        # Extract function code from responses
        candidate_codes = []
        for response, task_data in zip(responses, tasks_data):
            function_name = task_data.get('function_name', 'func')
            dataset_name = task_data.get('dataset_name', '')
            
            # For LiveCodeBench, extract the entire program (not just a function)
            if dataset_name == 'livecodebench':
                candidate_code = self._extract_program_code(response)
            else:
                candidate_code = self._extract_function_code(response, function_name)
            
            candidate_codes.append(candidate_code)
        
        return candidate_codes
    
    def _extract_program_code(self, response: str) -> Optional[str]:
        """Extract complete program code from LLM response (for LiveCodeBench).

        Selection order:

        1. The first fenced block tagged as Python (``python``, ``python3``
           or ``py``, matched case-insensitively).
        2. Otherwise the first fenced block of any kind, so that replies
           using bare fences do not leak the fence markers into the program.
        3. Otherwise the whole stripped response.

        Args:
            response: Raw text returned by the model; may be empty or None.

        Returns:
            The stripped program text, or None when the response (or the
            selected code block) is empty.
        """
        if not response:
            return None

        response = response.strip()

        # Prefer an explicitly Python-tagged fence. The \b keeps tags such as
        # "python3" from spilling their suffix into the captured code.
        fence = re.search(
            r'```(?:python3?|py)\b\s*(.*?)\s*```',
            response,
            re.DOTALL | re.IGNORECASE,
        )
        if fence is None:
            # Fall back to any fence. A language tag is only consumed when it
            # is followed by a newline, so "```print(1)```" keeps its content.
            fence = re.search(
                r'```(?:[\w+-]*[ \t]*\n)?(.*?)```',
                response,
                re.DOTALL,
            )

        # If no code block, use the entire response
        code_content = fence.group(1).strip() if fence else response

        # For LiveCodeBench, we want the entire program; empty means "nothing usable".
        return code_content or None
    
    def _extract_function_code(self, response: str, function_name: str) -> Optional[str]:
        """Extract function (or ``Solution`` class) code from LLM response.

        The code is taken from the first Python-tagged fenced block
        (``python``, ``python3`` or ``py``); failing that, from the first
        fenced block of any kind that contains ``def `` or ``class ``;
        failing that, from the whole stripped response.

        If the code mentions ``class Solution``, every non-empty line before
        the class plus the class body is returned. Otherwise collection
        starts at the first ``def`` line and stops at the first non-comment
        line that is not indented deeper than that ``def``. A further ``def``
        at the same or a shallower indent continues the collection, whereas a
        nested ``def`` is treated as ordinary body text so the enclosing
        function is not cut short.

        Args:
            response: Raw text returned by the model; may be empty or None.
            function_name: Name of the target function. Any ``def`` line
                starts collection, so this name is not used for filtering.

        Returns:
            The extracted source text, the whole code content when no ``def``
            line is found but it contains ``def `` or ``class ``, or None
            when nothing usable is present.
        """
        if not response:
            return None

        response = response.strip()

        # Try to extract from a Python-tagged code block first.
        fence = re.search(
            r'```(?:python3?|py)\b\s*(.*?)\s*```',
            response,
            re.DOTALL | re.IGNORECASE,
        )
        if fence is None:
            # Accept a bare/other-tagged fence only if it looks like code, so
            # an unrelated fenced snippet does not hide an unfenced function.
            fence = next(
                (
                    m for m in re.finditer(
                        r'```(?:[\w+-]*[ \t]*\n)?(.*?)```', response, re.DOTALL
                    )
                    if 'def ' in m.group(1) or 'class ' in m.group(1)
                ),
                None,
            )

        # If no code block, use the entire response
        code_content = fence.group(1).strip() if fence else response
        if not code_content:
            return None

        lines = code_content.split('\n')

        # Check if this is a class definition (common for TestEval)
        if 'class Solution' in code_content:
            class_lines = []
            in_class = False
            class_indent = 0

            for line in lines:
                stripped = line.strip()
                if stripped.startswith('class Solution'):
                    in_class = True
                    class_indent = len(line) - len(line.lstrip())
                    class_lines.append(line)
                elif in_class:
                    current_indent = len(line) - len(line.lstrip())
                    # Continue until we hit a non-indented line (except empty lines)
                    if stripped and current_indent <= class_indent and not stripped.startswith('#'):
                        break
                    class_lines.append(line)
                elif stripped:
                    # Include any imports or helper functions before the class
                    class_lines.append(line)

            if class_lines:
                return '\n'.join(class_lines)

        # Look for function definition (for non-class based code)
        function_lines = []
        in_function = False
        indent_level = 0

        for line in lines:
            stripped = line.strip()
            current_indent = len(line) - len(line.lstrip())

            # A `def` opens (or continues with a sibling) only when it is not
            # nested inside the function currently being collected; resetting
            # the indent on a nested def would truncate the outer function.
            if stripped.startswith('def ') and (not in_function or current_indent <= indent_level):
                in_function = True
                indent_level = current_indent
                function_lines.append(line)
            elif in_function:
                # If we hit a non-indented line (except empty lines), function is complete
                if stripped and current_indent <= indent_level and not stripped.startswith('#'):
                    break
                function_lines.append(line)

        if function_lines:
            return '\n'.join(function_lines)

        # Fallback: return entire code content if it looks like valid Python
        if 'def ' in code_content or 'class ' in code_content:
            return code_content
        return None
    
    def _parse_multiple_assertions(self, response: str, test_inputs: List[str], function_name: str = 'func', dataset_name: Optional[str] = None) -> List[Any]:
        """Parse response to extract one expected result per test input (reused from v1).

        Only the first ```` ```python ```` block of the response is examined;
        without such a block the whole response is scanned line by line.

        LiveCodeBench (``dataset_name == 'livecodebench'``):
            Each line is parsed as JSON. Dicts carrying both ``"input"`` and
            ``"output"`` keys are test cases; a JSON array on one line
            contributes every such dict it contains. Lines that are not valid
            JSON are searched for embedded flat ``{... "input" ... "output" ...}``
            objects. Each test input is then paired with the first case whose
            input equals it (raw, or with newlines/carriage returns written as
            ``\\n``/``\\r``), else with the case at the same position, else
            with an empty output.

        Other datasets:
            Lines starting with ``assert `` are collected. With exactly one
            assertion per test input they are paired by position. Otherwise
            each input gets the first assertion containing it as a substring,
            else the assertion at the same position, else a placeholder
            ``assert <input> == None  # Failed to resolve``.

        Args:
            response: Raw model output; None or empty is treated as no output.
            test_inputs: Test inputs, one result is produced per entry.
            function_name: Function under test. Kept for interface
                compatibility; pairing does not depend on it.
            dataset_name: Dataset identifier selecting the parsing mode.

        Returns:
            A list with ``len(test_inputs)`` entries: dicts of the form
            ``{"input": ..., "output": ...}`` for LiveCodeBench, assertion
            strings otherwise.
        """
        # A failed generation may hand over None; treat it like an empty reply
        # instead of raising AttributeError on .strip().
        response = (response or '').strip()

        # First try to extract from Python code block
        code_match = re.search(r'```python\s*(.*?)\s*```', response, re.DOTALL)
        if code_match:
            lines = code_match.group(1).strip().split('\n')
        else:
            lines = response.split('\n')

        if dataset_name == 'livecodebench':
            def is_test_case(obj: Any) -> bool:
                # A usable test case is a dict exposing both required keys.
                return isinstance(obj, dict) and 'input' in obj and 'output' in obj

            def unescape(output: Any) -> Any:
                # Turn a string holding literal "\n" escapes (and no real
                # newline) into its unescaped form; leave anything else alone.
                if not (isinstance(output, str) and '\\n' in output and '\n' not in output):
                    return output
                try:
                    # latin-1 + backslashreplace keeps non-ASCII characters
                    # intact; a plain UTF-8 encode would be re-read as Latin-1
                    # by the unicode_escape codec and produce mojibake.
                    return output.encode('latin-1', 'backslashreplace').decode('unicode_escape')
                except UnicodeDecodeError:
                    return output  # Keep original if decode fails

            embedded_json = re.compile(r'\{[^{}]*"input"[^{}]*"output"[^{}]*\}')

            # Extract JSON objects from lines
            test_cases = []
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                try:
                    parsed = json.loads(line)
                except (ValueError, RecursionError):
                    # Not pure JSON: look for flat objects embedded in the line.
                    for fragment in embedded_json.findall(line):
                        try:
                            candidate = json.loads(fragment)
                        except (ValueError, RecursionError):
                            continue
                        if is_test_case(candidate):
                            test_cases.append(candidate)
                    continue

                if is_test_case(parsed):
                    test_cases.append(parsed)
                elif isinstance(parsed, list):
                    # The model emitted all cases as one JSON array.
                    test_cases.extend(item for item in parsed if is_test_case(item))

            # Match test cases with inputs; exactly one entry per input.
            matched_test_cases = []
            for position, test_input in enumerate(test_inputs):
                # Compare both raw input and with newlines normalized
                input_normalized = test_input.replace('\n', '\\n').replace('\r', '\\r')
                hit = next(
                    (
                        case for case in test_cases
                        if case['input'] == test_input or case['input'] == input_normalized
                    ),
                    None,
                )
                if hit is None and position < len(test_cases):
                    # If no exact match, use the test case at the same position
                    hit = test_cases[position]

                matched_test_cases.append({
                    "input": test_input,  # Use the original input format
                    "output": unescape(hit['output']) if hit is not None else "",
                })

            return matched_test_cases

        # Original logic for non-LiveCodeBench datasets
        assertions = [
            stripped for stripped in (line.strip() for line in lines)
            if stripped.startswith('assert ')
        ]

        # One assertion per input: pairing is purely positional.
        if len(assertions) == len(test_inputs):
            return list(assertions)

        matched_assertions = []
        for i, test_input in enumerate(test_inputs):
            hit = next((a for a in assertions if test_input in a), None)
            if hit is None and i < len(assertions):
                hit = assertions[i]
            if hit is None:
                hit = f"assert {test_input} == None  # Failed to resolve"
            matched_assertions.append(hit)

        return matched_assertions
