import os
import json
import time
import argparse
import threading
from datetime import datetime
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Dict, List, Set
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from agent_utils import read_jsonl, setup_logger, get_consistency_prompt, get_input_prompt, parse_tool_call, format_tool_results_as_user_message, parse_consistency_response, batch_generate, batch_syntax_check


@dataclass
class SampleState:
    """String constants naming the stage a sample is currently in.

    None of the attributes is annotated, so they are ordinary class-level
    strings rather than dataclass fields. The values are used as queue keys
    and are compared directly against state strings.
    """

    # Waiting for the next LLM generation round.
    NEED_INFERENCE = "need_inference"
    # Waiting for the syntax-check stage.
    NEED_SYNTAX_CHECK = "need_syntax_check"
    # Waiting for the QWQ consistency-check stage.
    NEED_QWQ_CHECK = "need_qwq_check"
    # Waiting for the QWen3 consistency-check stage.
    NEED_QWEN3_CHECK = "need_qwen3_check"
    # Terminal state: the sample finished successfully.
    COMPLETED = "completed"
    # Terminal state: the sample was given up on.
    FAILED = "failed"

def load_model_on_gpus(model_path, gpu_ids, model_name):
    """Build a vLLM ``LLM`` engine restricted to the given GPUs.

    Sets ``CUDA_VISIBLE_DEVICES`` for the current process to the
    comma-joined ``gpu_ids`` and uses one tensor-parallel shard per listed
    GPU. ``model_name`` is only used in the progress message.

    Returns the constructed ``LLM`` instance.
    """
    visible_devices = ",".join(str(gpu_id) for gpu_id in gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    print(f"Loading {model_name} on GPUs {gpu_ids}...")

    engine_options = {
        "model": model_path,
        "tensor_parallel_size": len(gpu_ids),
        "gpu_memory_utilization": 0.9,
        "dtype": "bfloat16",
        "swap_space": 16,
        "disable_custom_all_reduce": True,
        "seed": 12,
    }
    return LLM(**engine_options)

# Incremental saver: appends finished samples to the output file as they complete
class IncrementalSaver:
    """Append finished samples to a JSON Lines file as they become available.

    Each call to ``save_completed_samples`` writes one JSON object per sample
    in append mode. Writes are serialised with ``save_lock`` so a single
    instance can be shared between threads.
    """

    def __init__(self, output_path: str, logger):
        """Remember the target file and make sure its directory exists.

        Args:
            output_path: File the records are appended to.
            logger: Object exposing ``info`` and ``error`` methods.
        """
        self.output_path = output_path
        self.logger = logger
        self.save_lock = threading.Lock()

        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is named.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def save_completed_samples(self, samples: List[dict]):
        """Append one JSON line per sample to the output file.

        Does nothing for an empty list. Any error while building or writing
        the records is logged and swallowed (best-effort save), so the
        caller's processing loop is never interrupted by a failed save.
        """
        if not samples:
            return

        with self.save_lock:
            try:
                # Serialise everything before touching the file so that a
                # bad record cannot leave a partially written batch behind.
                lines = [
                    json.dumps(
                        {
                            'index': sample['index'],
                            'data': sample['data'],
                            'prompts': sample['prompts'],
                            'responses': sample.get('responses', []),
                            'tools': sample['tools'],
                            'completed': sample.get('completed', False),
                            'failed': sample.get('failed', False),
                            'iterations': sample['iteration_count'],
                        },
                        ensure_ascii=False,
                    ) + '\n'
                    for sample in samples
                ]

                with open(self.output_path, 'a', encoding='utf-8') as f:
                    f.writelines(lines)

                self.logger.info(f"✓ Incrementally saved {len(samples)} completed samples to {self.output_path}")

            except Exception as e:
                self.logger.error(f"✗ Incremental save failed: {e}")

class UnifiedDataPool:
    """Lock-protected pool that tracks every sample by pipeline state.

    Samples are plain dicts. Each state has its own deque in
    ``samples_by_state``; a sample handed to a worker stays in its queue with
    ``sample['processing']`` set to True until the worker reports the new
    state through ``batch_update_sample_states``. Finished samples are moved
    to ``completed_samples`` / ``failed_samples``.
    """

    def __init__(self, tokenizer, max_batch_size=256, max_prompt_length=40000):
        """Create an empty pool.

        Args:
            tokenizer: Object with ``encode(text, add_special_tokens=True)``;
                used only to count prompt tokens.
            max_batch_size: Default number of samples handed out per request.
            max_prompt_length: Maximum prompt length in tokens; longer
                prompts are marked as failed.
        """
        self.max_batch_size = max_batch_size
        self.max_prompt_length = max_prompt_length
        self.lock = threading.RLock()

        # Sample queues grouped by state
        self.samples_by_state = defaultdict(deque)

        self.tokenizer = tokenizer

        # Sample index mapping (for fast lookup)
        self.sample_index_map = {}

        # Samples currently handed out to a worker: {sample index: sample}
        self.processing_samples = {}

        # Completed and failed samples
        self.completed_samples = []
        self.failed_samples = []

        # Statistics
        self.stats = {
            'total_samples': 0,
            'completed_samples': 0,
            'failed_samples': 0,
            'processing_counts': defaultdict(int),
            'length_failed_samples': 0,
        }

        # Held for the duration of a batch update; readers briefly acquire it
        # so they do not take samples in the middle of an update.
        self.batch_update_lock = threading.Lock()

    def _count_prompt_tokens(self, prompt: str):
        """Return the token count of ``prompt``, or None if tokenization fails.

        A failure is reported on stdout and never raised, so callers holding
        the pool locks cannot be interrupted by a tokenizer error.
        """
        try:
            return len(self.tokenizer.encode(prompt, add_special_tokens=True))
        except Exception as e:
            print(f"Warning: Failed to tokenize prompt for length check: {e}")
            return None

    def _check_prompt_length(self, prompt: str) -> bool:
        """Return True when ``prompt`` has at most ``max_prompt_length`` tokens.

        A prompt that cannot be tokenized is treated as too long
        (conservative handling).
        """
        token_count = self._count_prompt_tokens(prompt)
        return token_count is not None and token_count <= self.max_prompt_length

    def _length_failure_reason(self, prompt: str, label: str):
        """Return a failure message if ``prompt`` must be rejected, else None.

        The prompt is tokenized exactly once. ``label`` is the leading words
        of the message (for example ``"Initial prompt"``).
        """
        token_count = self._count_prompt_tokens(prompt)
        if token_count is None:
            return f"{label} could not be tokenized for length check"
        if token_count > self.max_prompt_length:
            return f"{label} token count ({token_count}) exceeds maximum ({self.max_prompt_length})"
        return None

    def add_samples(self, samples: List[dict]):
        """Register new samples, queueing them for inference.

        A sample whose latest prompt is over the token limit (or cannot be
        tokenized) is marked failed immediately and never queued.
        """
        with self.lock:
            for sample in samples:
                sample['state'] = SampleState.NEED_INFERENCE
                sample['processing'] = False

                initial_prompt = sample['prompts'][-1] if sample['prompts'] else ""
                failure_reason = self._length_failure_reason(initial_prompt, "Initial prompt")
                if failure_reason is not None:
                    sample['failed'] = True
                    sample['failure_reason'] = failure_reason
                    self.failed_samples.append(sample)
                    self.stats['failed_samples'] += 1
                    self.stats['length_failed_samples'] += 1
                else:
                    self.samples_by_state[SampleState.NEED_INFERENCE].append(sample)
                    self.sample_index_map[sample['index']] = sample

                self.stats['total_samples'] += 1

    def get_samples_for_processing(self, state: str, max_count: int = None) -> List[dict]:
        """Hand out up to ``max_count`` idle samples in ``state``, newest first.

        Returned samples are flagged ``processing`` and recorded in
        ``processing_samples`` but remain in their queue. ``max_count``
        defaults to ``max_batch_size``.
        """
        if max_count is None:
            max_count = self.max_batch_size

        # Wait for any in-flight batch update to complete
        with self.batch_update_lock:
            pass

        with self.lock:
            available_samples = []
            queue = self.samples_by_state[state]

            # Walk from the tail (LIFO) without mutating the queue; samples
            # already being processed are skipped.
            for sample in reversed(queue):
                if len(available_samples) >= max_count:
                    break
                if sample.get('processing', False):
                    continue

                sample['processing'] = True
                available_samples.append(sample)
                self.processing_samples[sample['index']] = sample
                self.stats['processing_counts'][state] += 1

            return available_samples

    def batch_update_sample_states(self, updates: List[tuple]):
        """Apply a batch of state transitions atomically.

        Args:
            updates: ``(sample, new_state, remove_processing)`` tuples. When
                ``remove_processing`` is true and the sample is flagged as
                processing, the flag is cleared and the sample is removed
                from its old queue.

        A transition to NEED_INFERENCE is turned into FAILED when the
        sample's latest prompt is over the token limit or cannot be
        tokenized.
        """
        with self.batch_update_lock:
            with self.lock:
                completed_samples = []
                failed_samples = []
                length_failed_count = 0

                for sample, new_state, remove_processing in updates:
                    old_state = sample.get('state')

                    if remove_processing and sample.get('processing'):
                        sample['processing'] = False
                        self.processing_samples.pop(sample['index'], None)

                        if old_state:
                            self.stats['processing_counts'][old_state] -= 1

                        # Remove the sample from its old state queue
                        if old_state and old_state in self.samples_by_state:
                            queue = self.samples_by_state[old_state]
                            for i, s in enumerate(queue):
                                if s['index'] == sample['index']:
                                    del queue[i]
                                    break

                    if new_state == SampleState.NEED_INFERENCE:
                        current_prompt = sample['prompts'][-1] if sample['prompts'] else ""
                        failure_reason = self._length_failure_reason(current_prompt, "Prompt")
                        if failure_reason is not None:
                            # Cannot be sent to the model again: force failure
                            new_state = SampleState.FAILED
                            sample['failed'] = True
                            sample['failure_reason'] = failure_reason
                            length_failed_count += 1

                    if new_state == SampleState.COMPLETED:
                        self.sample_index_map.pop(sample['index'], None)
                        completed_samples.append(sample)
                        self.stats['completed_samples'] += 1
                    elif new_state == SampleState.FAILED:
                        self.sample_index_map.pop(sample['index'], None)
                        failed_samples.append(sample)
                        self.stats['failed_samples'] += 1
                    else:
                        # Appended at the tail, so it is picked up first (LIFO)
                        sample['state'] = new_state
                        self.samples_by_state[new_state].append(sample)

                self.completed_samples.extend(completed_samples)
                self.failed_samples.extend(failed_samples)
                self.stats['length_failed_samples'] += length_failed_count

    def update_sample_state(self, sample: dict, new_state: str, remove_processing=True):
        """Single-sample convenience wrapper around the batch update."""
        self.batch_update_sample_states([(sample, new_state, remove_processing)])

    def has_processing_samples(self) -> bool:
        """Return True if any sample is currently handed out to a worker."""
        with self.lock:
            return len(self.processing_samples) > 0

    def get_processing_sample_count(self) -> int:
        """Return the number of samples currently handed out to workers."""
        with self.lock:
            return len(self.processing_samples)

    def get_pool_status(self) -> Dict:
        """Return a snapshot dict of counters and per-state queue sizes."""
        with self.lock:
            finished_samples = self.stats['completed_samples'] + self.stats['failed_samples']

            # Per state: everything queued vs. only the idle (not processing) part
            samples_by_state_total = {}
            samples_by_state_available = {}
            for state, queue in self.samples_by_state.items():
                samples_by_state_total[state] = len(queue)
                samples_by_state_available[state] = sum(
                    1 for sample in queue if not sample.get('processing', False)
                )

            return {
                'total_samples': self.stats['total_samples'],
                'completed_samples': self.stats['completed_samples'],
                'failed_samples': self.stats['failed_samples'],
                'length_failed_samples': self.stats['length_failed_samples'],
                'finished_samples': finished_samples,
                'remaining_samples': self.stats['total_samples'] - finished_samples,
                'samples_by_state_total': samples_by_state_total,
                'samples_by_state_available': samples_by_state_available,
                'processing_counts': dict(self.stats['processing_counts']),
                'processing_samples_count': len(self.processing_samples),
            }

    def has_work(self) -> bool:
        """Return True while anything is queued, in flight, or unfinished."""
        with self.lock:
            has_queued_samples = any(len(queue) > 0 for queue in self.samples_by_state.values())
            has_processing_samples = len(self.processing_samples) > 0

            finished_samples = self.stats['completed_samples'] + self.stats['failed_samples']
            remaining_samples = self.stats['total_samples'] - finished_samples

            return has_queued_samples or has_processing_samples or remaining_samples > 0

    def get_all_finished_samples(self) -> List[dict]:
        """Return a new list of all completed samples followed by all failed ones."""
        with self.lock:
            return self.completed_samples.copy() + self.failed_samples.copy()

class AsyncStageWorker:
    """Base class for a pipeline stage that polls for work on its own thread.

    Subclasses provide ``get_samples_to_process`` (fetch a batch) and
    ``process_samples`` (handle it); the base class owns the thread and the
    polling loop.
    """

    def __init__(self, stage_name: str, data_pool: UnifiedDataPool, logger,
                 check_interval: float = 0.1):
        self.stage_name = stage_name
        self.data_pool = data_pool
        self.logger = logger
        # Seconds to sleep when a poll returns no samples.
        self.check_interval = check_interval
        self.running = False
        self.worker_thread = None

    def start(self):
        """Launch the polling loop in a new thread."""
        self.running = True
        thread_name = f"{self.stage_name}_worker"
        self.worker_thread = threading.Thread(target=self._work_loop, name=thread_name)
        self.worker_thread.start()
        self.logger.info(f"[{self.stage_name}] Worker thread started")

    def stop(self):
        """Ask the loop to exit and wait for the thread, if one was started."""
        self.running = False
        if self.worker_thread:
            self.worker_thread.join()
        self.logger.info(f"[{self.stage_name}] Worker thread stopped")

    def _work_loop(self):
        """Poll for batches until ``running`` is cleared.

        Any exception from fetching or processing is logged and followed by
        a one-second pause; the loop then continues.
        """
        while self.running:
            try:
                batch = self.get_samples_to_process()

                if not batch:
                    # Nothing to do right now: back off briefly
                    time.sleep(self.check_interval)
                    continue

                self.logger.info(f"[{self.stage_name}] Starting to process {len(batch)} samples")
                started_at = time.time()

                self.process_samples(batch)

                elapsed = time.time() - started_at
                self.logger.info(f"[{self.stage_name}] Completed processing, time taken: {elapsed:.2f}s")

            except Exception as e:
                self.logger.error(f"[{self.stage_name}] Processing error: {e}")
                time.sleep(1)  # Wait longer on error

    def get_samples_to_process(self) -> List[dict]:
        """Return the next batch to handle (must be overridden)."""
        raise NotImplementedError

    def process_samples(self, samples: List[dict]):
        """Handle one batch of samples (must be overridden)."""
        raise NotImplementedError

class LLMInferenceWorker(AsyncStageWorker):
    """Stage worker that runs LLM generation for samples awaiting inference.

    After each generation the response is parsed for a tool call, and the
    sample is routed to the syntax check, to the QWQ consistency check, back
    to inference with an error message, or to the failed state.
    """

    def __init__(self, llm, sampling_params, data_pool, logger, max_iterations=8, batch_size=256, name = 'LLM_Inference 1'):
        # Use the caller-supplied name as the stage name so that several
        # inference workers sharing one pool get distinct log prefixes and
        # thread names.
        super().__init__(name, data_pool, logger, check_interval=0.05)
        self.llm = llm
        self.sampling_params = sampling_params
        self.max_iterations = max_iterations
        self.batch_size = batch_size

    def get_samples_to_process(self) -> List[dict]:
        """Fetch up to ``batch_size`` samples waiting for inference."""
        return self.data_pool.get_samples_for_processing(SampleState.NEED_INFERENCE, max_count=self.batch_size)

    def process_samples(self, samples: List[dict]):
        """Generate one response per sample and route each by its tool call."""
        if not samples:
            return

        # Batch inference on the most recent prompt of every sample
        prompts = [sample['prompts'][-1] for sample in samples]
        responses = batch_generate(self.llm, prompts, self.sampling_params)

        updates = []
        for sample, response in zip(samples, responses):
            tool_call = parse_tool_call(response)
            sample['chat_count'] += 1
            sample['last_response'] = response
            sample['last_tool_call'] = tool_call

            # Keep every round's raw response
            sample['responses'].append(response)

            if sample['chat_count'] > 2 * self.max_iterations:
                # Hard cap on the number of generation rounds
                sample['failed'] = True
                next_state = SampleState.FAILED
            else:
                # NOTE(review): parse_tool_call's contract is not visible
                # here; a falsy result (None or {}) is treated as "no tool
                # call" instead of raising on .get.
                tool_name = tool_call.get('name', '') if tool_call else ''
                if tool_name == 'syntax_check':
                    next_state = SampleState.NEED_SYNTAX_CHECK
                elif tool_name == 'consistency_check':
                    next_state = SampleState.NEED_QWQ_CHECK
                else:
                    # Invalid tool call: retry inference or fail
                    next_state = self._handle_invalid_tool_call(sample)

            updates.append((sample, next_state, True))

        # Apply all transitions in one locked batch
        self.data_pool.batch_update_sample_states(updates)

    def _handle_invalid_tool_call(self, sample) -> str:
        """Count an invalid tool call as an iteration and pick the next state.

        Returns FAILED once ``iteration_count`` reaches ``max_iterations``.
        Otherwise appends a new prompt carrying an error message and returns
        NEED_INFERENCE.
        """
        sample['iteration_count'] += 1

        if sample['iteration_count'] >= self.max_iterations:
            sample['failed'] = True
            return SampleState.FAILED

        # Feed the error back to the model in the next prompt
        tool_feedback = format_tool_results_as_user_message('', {"error": "No valid tool call"})
        next_prompt = sample['prompts'][-1] + sample['last_response'] + '\n\n' + tool_feedback + '\n'
        sample['prompts'].append(next_prompt)
        sample['tools'].append({"error": "No valid tool call"})

        # Clean temporary data
        sample.pop('last_response', None)
        sample.pop('last_tool_call', None)

        return SampleState.NEED_INFERENCE
    
class SyntaxCheckWorker(AsyncStageWorker):
    """Stage worker that syntax-checks the Lean 4 code from the last tool call."""

    def __init__(self, data_pool, logger, max_iterations=8, batch_size=256):
        super().__init__("Syntax_Check", data_pool, logger, check_interval=0.1)
        self.max_iterations = max_iterations
        self.batch_size = batch_size

    def get_samples_to_process(self) -> List[dict]:
        """Fetch up to ``batch_size`` samples waiting for a syntax check."""
        return self.data_pool.get_samples_for_processing(
            SampleState.NEED_SYNTAX_CHECK, max_count=self.batch_size
        )

    def process_samples(self, samples: List[dict]):
        """Check every sample's code in one batch and route each sample."""
        if not samples:
            return

        # Code to check comes from the arguments of the stored tool call;
        # a missing call or argument yields an empty string.
        code_batch = [
            sample.get('last_tool_call', {}).get('arguments', {}).get('lean4_code', '')
            for sample in samples
        ]
        check_results = batch_syntax_check(code_batch)

        transitions = []
        for sample, outcome in zip(samples, check_results):
            sample['tools'].append(outcome)
            upcoming_state = self._determine_next_state(sample, outcome)
            transitions.append((sample, upcoming_state, True))

        # Apply all transitions in one locked batch
        self.data_pool.batch_update_sample_states(transitions)

    def _determine_next_state(self, sample, result) -> str:
        """Pick the state that follows a syntax-check result.

        A result without a truthy ``'pass'`` increments ``iteration_count``.
        Returns FAILED once the count reaches ``max_iterations``. Otherwise
        appends a new prompt carrying the tool output and returns
        NEED_INFERENCE.
        """
        passed = result.get('pass')
        if not passed:
            sample['iteration_count'] += 1

        out_of_budget = sample['iteration_count'] >= self.max_iterations
        if out_of_budget:
            sample['failed'] = True
            return SampleState.FAILED

        # New prompt = previous prompt + model response + tool feedback
        previous_response = sample.get('last_response', '')
        called_tool = sample.get('last_tool_call', {})
        feedback = format_tool_results_as_user_message(called_tool.get('name', ''), result)
        sample['prompts'].append(
            sample['prompts'][-1] + previous_response + '\n\n' + feedback + '\n'
        )

        # Drop per-round scratch entries
        for scratch_key in ('last_response', 'last_tool_call'):
            sample.pop(scratch_key, None)

        return SampleState.NEED_INFERENCE

class QWQConsistencyWorker(AsyncStageWorker):
    """Stage worker that runs the first (QWQ) consistency check.

    Samples that pass move on to the QWen3 check. Samples that do not pass
    get the result fed back for another inference round, or fail once the
    iteration budget is used up.
    """

    def __init__(self, qwq_model, tokenizer, consistency_params, data_pool, logger, max_iterations=8, batch_size=256):
        super().__init__("QWQ_Consistency", data_pool, logger, check_interval=0.1)
        self.qwq_model = qwq_model
        self.tokenizer = tokenizer
        self.consistency_params = consistency_params
        self.max_iterations = max_iterations
        self.batch_size = batch_size

    def get_samples_to_process(self) -> List[dict]:
        """Fetch up to ``batch_size`` samples waiting for the QWQ check."""
        return self.data_pool.get_samples_for_processing(
            SampleState.NEED_QWQ_CHECK, max_count=self.batch_size
        )

    def process_samples(self, samples: List[dict]):
        """Run the QWQ check on a batch and route every sample."""
        if not samples:
            return

        # One consistency prompt per sample: informal statement vs. Lean 4 code
        check_prompts = []
        for sample in samples:
            called_tool = sample.get('last_tool_call', {})
            statement = sample['data']['informal_statement']
            code = called_tool.get('arguments', {}).get('lean4_code', '')
            check_prompts.append(get_consistency_prompt(self.tokenizer, statement, code))

        raw_outputs = batch_generate(self.qwq_model, check_prompts, self.consistency_params)

        transitions = []
        for sample, raw_output in zip(samples, raw_outputs):
            verdict = parse_consistency_response(raw_output)
            sample['qwq_result'] = verdict

            if not verdict.get('pass'):
                # Not passed: record the result, then retry or fail
                sample['tools'].append(verdict)
                upcoming_state = self._determine_next_state(sample, verdict)
            else:
                # Passed: hand over to the QWen3 check
                upcoming_state = SampleState.NEED_QWEN3_CHECK

            transitions.append((sample, upcoming_state, True))

        # Apply all transitions in one locked batch
        self.data_pool.batch_update_sample_states(transitions)

    def _determine_next_state(self, sample, result) -> str:
        """Pick the state that follows a QWQ result that did not pass.

        Always increments ``iteration_count``. Returns FAILED once the count
        reaches ``max_iterations``. Otherwise appends a new prompt carrying
        the check result and returns NEED_INFERENCE.
        """
        sample['iteration_count'] += 1

        out_of_budget = sample['iteration_count'] >= self.max_iterations
        if out_of_budget:
            sample['failed'] = True
            return SampleState.FAILED

        # New prompt = previous prompt + model response + tool feedback
        previous_response = sample.get('last_response', '')
        called_tool = sample.get('last_tool_call', {})
        feedback = format_tool_results_as_user_message(called_tool.get('name', ''), result)
        sample['prompts'].append(
            sample['prompts'][-1] + previous_response + '\n\n' + feedback + '\n'
        )

        # Drop per-round scratch entries
        for scratch_key in ('last_response', 'last_tool_call', 'qwq_result'):
            sample.pop(scratch_key, None)

        return SampleState.NEED_INFERENCE

class QWen3ConsistencyWorker(AsyncStageWorker):
    """Stage worker that runs the second (QWen3) consistency check.

    A passing sample is completed and, when a saver was supplied, written
    out right after the state update. A sample that does not pass is sent
    back to inference or failed once the iteration budget is used up.
    """

    def __init__(self, qwen3_model, tokenizer, consistency_params, data_pool, logger, max_iterations=8, batch_size=256, saver=None):
        super().__init__("QWen3_Consistency", data_pool, logger, check_interval=0.1)
        self.qwen3_model = qwen3_model
        self.tokenizer = tokenizer
        self.consistency_params = consistency_params
        self.max_iterations = max_iterations
        self.batch_size = batch_size
        # Optional saver used to persist samples as soon as they complete
        self.saver = saver

    def get_samples_to_process(self) -> List[dict]:
        """Fetch up to ``batch_size`` samples waiting for the QWen3 check."""
        return self.data_pool.get_samples_for_processing(
            SampleState.NEED_QWEN3_CHECK, max_count=self.batch_size
        )

    def process_samples(self, samples: List[dict]):
        """Run the QWen3 check on a batch, route every sample, save completions."""
        if not samples:
            return

        # One consistency prompt per sample: informal statement vs. Lean 4 code
        check_prompts = []
        for sample in samples:
            called_tool = sample.get('last_tool_call', {})
            statement = sample['data']['informal_statement']
            code = called_tool.get('arguments', {}).get('lean4_code', '')
            check_prompts.append(get_consistency_prompt(self.tokenizer, statement, code))

        raw_outputs = batch_generate(self.qwen3_model, check_prompts, self.consistency_params)

        transitions = []
        newly_completed = []
        for sample, raw_output in zip(samples, raw_outputs):
            verdict = parse_consistency_response(raw_output)
            sample['tools'].append(verdict)

            upcoming_state = self._determine_next_state(sample, verdict)
            if upcoming_state == SampleState.COMPLETED:
                newly_completed.append(sample)

            transitions.append((sample, upcoming_state, True))

        # Apply all transitions in one locked batch
        self.data_pool.batch_update_sample_states(transitions)

        # Persist completions only after the pool has been updated
        if newly_completed and self.saver:
            self.saver.save_completed_samples(newly_completed)

    def _determine_next_state(self, sample, result) -> str:
        """Pick the state that follows a QWen3 result.

        A truthy ``'pass'`` marks the sample completed. Otherwise
        ``iteration_count`` is incremented; the sample fails once the count
        reaches ``max_iterations``, else a new prompt carrying the check
        result is appended and NEED_INFERENCE is returned.
        """
        if result.get('pass'):
            sample['completed'] = True
            return SampleState.COMPLETED

        # Only a check that did not pass costs an iteration
        sample['iteration_count'] += 1

        out_of_budget = sample['iteration_count'] >= self.max_iterations
        if out_of_budget:
            sample['failed'] = True
            return SampleState.FAILED

        # New prompt = previous prompt + model response + tool feedback
        previous_response = sample.get('last_response', '')
        called_tool = sample.get('last_tool_call', {})
        feedback = format_tool_results_as_user_message(called_tool.get('name', ''), result)
        sample['prompts'].append(
            sample['prompts'][-1] + previous_response + '\n\n' + feedback + '\n'
        )

        # Drop per-round scratch entries
        for scratch_key in ('last_response', 'last_tool_call', 'qwq_result'):
            sample.pop(scratch_key, None)

        return SampleState.NEED_INFERENCE

class EventDrivenPipelineManager:
    """Event-driven pipeline manager.

    Owns the shared data pool, the stage workers and a background thread
    that periodically logs the pool status.

    Class attributes (override on a subclass or instance to tune timing):
        MONITOR_INTERVAL_SECONDS: Pause between two progress reports.
        POLL_INTERVAL_SECONDS: Pause between two ``has_work`` polls.
        DRAIN_GRACE_SECONDS: Extra wait after the pool reports no work.
    """

    MONITOR_INTERVAL_SECONDS = 60
    POLL_INTERVAL_SECONDS = 5
    DRAIN_GRACE_SECONDS = 10

    def __init__(self, llm1, llm2, qwq_model, qwen3_model, tokenizer,
                 sampling_params, consistency_params, logger,
                 max_batch_size=256, max_iterations=8, output_path=None, max_prompt_length=40000):
        """Build the data pool, the optional incremental saver and all workers.

        Args:
            llm1: Model handed to the first inference worker.
            llm2: Model handed to the second inference worker.
            qwq_model: Model handed to the QWQ consistency worker.
            qwen3_model: Model handed to the QWen3 consistency worker.
            tokenizer: Tokenizer shared by the data pool and both
                consistency workers.
            sampling_params: Sampling parameters for the inference workers.
            consistency_params: Sampling parameters for the consistency workers.
            logger: Logger used by the manager and passed to every worker.
            max_batch_size: Batch size limit passed to the pool and workers.
            max_iterations: Iteration limit passed to the workers.
            output_path: When truthy, an ``IncrementalSaver`` writing to this
                path is created; otherwise no saver is used.
            max_prompt_length: Prompt length limit passed to the data pool.
        """
        # Create unified data pool, pass length limit
        self.data_pool = UnifiedDataPool(tokenizer, max_batch_size, max_prompt_length)
        self.logger = logger
        self.max_iterations = max_iterations

        # Create incremental saver
        self.saver = IncrementalSaver(output_path, logger) if output_path else None

        # Create stage workers
        self.workers = [
            LLMInferenceWorker(llm1, sampling_params, self.data_pool, logger, max_iterations, max_batch_size, name='LLM_Inference 1'),
            LLMInferenceWorker(llm2, sampling_params, self.data_pool, logger, max_iterations, max_batch_size, name='LLM_Inference 2'),
            SyntaxCheckWorker(self.data_pool, logger, max_iterations, max_batch_size),
            QWQConsistencyWorker(qwq_model, tokenizer, consistency_params, self.data_pool, logger, max_iterations, max_batch_size),
            QWen3ConsistencyWorker(qwen3_model, tokenizer, consistency_params, self.data_pool, logger, max_iterations, max_batch_size, self.saver)
        ]

        # Monitor thread
        self.monitor_thread = None
        self.running = False
        # Set by stop_pipeline() so the monitor wakes up immediately instead
        # of finishing a full reporting interval before noticing the stop.
        self._stop_event = threading.Event()

    def start_pipeline(self, samples):
        """Run the pipeline over ``samples`` and block until it is done.

        Adds the samples to the data pool, starts every worker and the
        progress monitor, waits for the pool to drain, then stops everything.

        Args:
            samples: Sample dicts to process.

        Returns:
            Whatever ``data_pool.get_all_finished_samples()`` returns.
        """
        self.logger.info(f"Starting event-driven pipeline, sample count: {len(samples)}")
        self.data_pool.add_samples(samples)

        # Start all workers
        for worker in self.workers:
            worker.start()

        # Start monitor thread. It only logs, so it is a daemon: a crash in
        # the main thread must not leave the process alive just for logging.
        self._stop_event.clear()
        self.running = True
        self.monitor_thread = threading.Thread(target=self._monitor_progress,
                                               name="progress_monitor",
                                               daemon=True)
        self.monitor_thread.start()

        # Wait for all samples to complete
        self._wait_for_completion()

        # Stop all workers
        self.stop_pipeline()

        return self.data_pool.get_all_finished_samples()

    def _monitor_progress(self):
        """Log the pool status every ``MONITOR_INTERVAL_SECONDS`` until stopped."""
        while self.running:
            status = self.data_pool.get_pool_status()
            self.logger.info(f"Progress monitor - Total samples: {status['total_samples']}, "
                        f"Completed: {status['completed_samples']}, "
                        f"Failed: {status['failed_samples']}, "
                        f"Length limit failures: {status['length_failed_samples']}, "  # Add length failure statistics
                        f"Remaining: {status['remaining_samples']}")
            self.logger.info(f"Total samples by state: {status['samples_by_state_total']}")
            self.logger.info(f"Available samples by state: {status['samples_by_state_available']}")
            self.logger.info(f"Processing sample counts: {status['processing_counts']}")

            # Interruptible sleep: wait() returns True as soon as the stop
            # event is set, so shutdown is not delayed by the interval.
            if self._stop_event.wait(self.MONITOR_INTERVAL_SECONDS):
                break

    def _wait_for_completion(self):
        """Block until the data pool reports no remaining work."""
        while self.data_pool.has_work():
            time.sleep(self.POLL_INTERVAL_SECONDS)

        # Additional wait to ensure all processing samples complete
        time.sleep(self.DRAIN_GRACE_SECONDS)

    def stop_pipeline(self):
        """Stop the monitor and every worker, then wait for the monitor thread."""
        self.running = False
        # Wake the monitor right away if it is in its interval wait.
        self._stop_event.set()

        # Stop all workers
        for worker in self.workers:
            worker.stop()

        # Stop monitor thread
        if self.monitor_thread:
            self.monitor_thread.join()

        self.logger.info("Event-driven pipeline stopped")
        
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the four models, run
    # every input record through the event-driven pipeline and log summary
    # statistics.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--log_dir", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--max_length", type=int, default=2048*20)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--max_iterations", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--max_prompt_length", type=int, default=40000)  # Add maximum prompt length parameter
    args = parser.parse_args()

    # Setup logger
    logger, log_path = setup_logger(args.log_dir)
    logger.info("=" * 60)
    logger.info("Starting event-driven asynchronous pipeline inference task")
    logger.info(f"Input path: {args.input_path}")
    logger.info(f"Output path: {args.output_path}")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Max iterations: {args.max_iterations}")
    logger.info("=" * 60)

    # Initialize models
    try:
        logger.info("Initializing models...")
        # NOTE(review): the two checkpoint paths were previously crossed
        # (qwen3_model loaded QWQ-32B and qwq_model loaded Qwen3-32B), so each
        # consistency worker ran the other stage's model. Each variable now
        # loads the checkpoint it is named after.
        qwen3_model = load_model_on_gpus('xxx/Qwen3-32B', [0], "qwen3")
        qwq_model = load_model_on_gpus('xxx/QWQ-32B', [1], "qwq")
        llm1 = load_model_on_gpus(args.model, [2], "llm1")
        llm2 = load_model_on_gpus(args.model, [3], "llm2")
        tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
        logger.info("✓ All models initialized successfully")
    except Exception as e:
        logger.error(f"✗ Model initialization failed: {e}")
        raise

    # Set sampling parameters
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_length,
        n=1,
        stop='</tool_calls>',
        include_stop_str_in_output=True,
    )

    consistency_params = SamplingParams(
        temperature=0.0,
        # top_p=0.95,
        max_tokens=16000,
        n=1,
    )

    # Load data
    data = read_jsonl(args.input_path)
    logger.info(f"Successfully loaded data, total {len(data)} records")

    # Initialize samples: one state dict per input record
    samples = [
        {
            'index': i,
            'data': d,
            'prompts': [get_input_prompt(tokenizer, d['informal_statement'])],
            'responses': [],  # Add responses list to save each round's response
            'tools': [],
            'completed': False,
            'failed': False,
            'iteration_count': 0,
            'chat_count': 0,
        }
        for i, d in enumerate(data)
    ]

    # Create event-driven pipeline manager (pass output path for incremental saving)
    pipeline_manager = EventDrivenPipelineManager(
        llm1, llm2, qwq_model, qwen3_model, tokenizer,
        sampling_params, consistency_params, logger,
        args.batch_size, args.max_iterations, args.output_path, args.max_prompt_length
    )

    # Run pipeline
    start_time = datetime.now()
    try:
        finished_samples = pipeline_manager.start_pipeline(samples)
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, stopping pipeline...")
        pipeline_manager.stop_pipeline()
        raise

    end_time = datetime.now()

    # No final write happens here: the pipeline manager is given the output
    # path and saves incrementally while it runs.

    # Statistics (add length failure statistics)
    total_time = (end_time - start_time).total_seconds()
    completed_count = sum(1 for sample in finished_samples if sample.get('completed', False))
    failed_count = sum(1 for sample in finished_samples if sample.get('failed', False))
    # `or ''` tolerates a sample whose failure_reason is present but None.
    length_failure_prefixes = ('Prompt length', 'Initial prompt length')
    length_failed_count = sum(
        1 for sample in finished_samples
        if (sample.get('failure_reason') or '').startswith(length_failure_prefixes)
    )
    avg_iterations = sum(sample['iteration_count'] for sample in finished_samples) / len(finished_samples) if finished_samples else 0
    success_rate = (completed_count / len(samples)) * 100 if samples else 0
    # Guard against an empty input file, like the other averages above.
    avg_time_per_sample = total_time / len(samples) if samples else 0

    logger.info(f"\n{'='*60}")
    logger.info("Event-driven asynchronous pipeline processing completed!")
    logger.info(f"Total processing time: {total_time:.2f} seconds")
    logger.info(f"Total processed samples: {len(samples)}")
    logger.info(f"Successfully completed samples: {completed_count}")
    logger.info(f"Failed samples: {failed_count}")
    logger.info(f"Length limit failed samples: {length_failed_count}")  # Add length failure statistics
    logger.info(f"Success rate: {success_rate:.2f}%")
    logger.info(f"Average iterations: {avg_iterations:.2f}")
    logger.info(f"Average time per sample: {avg_time_per_sample:.2f} seconds")
    logger.info(f"Output file: {args.output_path}")
    logger.info("=" * 60)
