#!/usr/bin/env python3
"""
LLM Code Selector with Fixed Event Loop Management and Enhanced Retry Logic

This module provides LLM-based code selection with robust error handling,
proper asyncio event loop management, and comprehensive retry mechanisms.
"""

import asyncio
import aiohttp
import json
import re
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import concurrent.futures
import os
import threading
import logging
from tqdm import tqdm
import random

# Set up logging
# NOTE(review): basicConfig() runs at import time, so merely importing this
# module configures the root logger at INFO if it has no handlers yet -
# confirm this is intended when the module is used as a library.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)

@dataclass
class CodeSelectionRequest:
    """Request for code selection

    One instance describes a single datapoint for which the LLM should pick
    codes out of a candidate pool.
    """
    # Identifier copied onto the matching result; the retry helper also uses
    # it to pair results with requests, and logs show its first 50 characters.
    datapoint_id: str
    # Raw datapoint text; it is sanitized and embedded in the prompt.
    datapoint_text: str
    # Passed along to prompt building, which currently does not include these
    # codes in the prompt text.
    previous_codes: List[str]
    # Pool of codes from which a sample is drawn and listed in the prompt.
    candidate_codes: List[str]

@dataclass
class CodeSelectionResult:
    """Result of code selection

    Produced once per request, both for successful and for failed attempts.
    """
    # Identifier of the request this result belongs to.
    datapoint_id: str
    # Codes parsed from the model response (truncated to 20 entries by the
    # request code); empty when the request failed.
    selected_codes: List[str]
    # Elapsed seconds measured with time.time(); 0.0 on placeholder results
    # that are built without a timed request.
    processing_time: float
    # True when codes were parsed from a response, False otherwise.
    success: bool
    # Human-readable failure description; left as None on success.
    error_message: Optional[str] = None

class LLMCodeSelector:
    """
    LLM-based code selector with robust error handling and retry logic
    """
    
    def __init__(self, model_url: str = None, max_concurrency: int = 64, timeout: int = 120, max_retries: int = 3):
        """
        Initialize the LLM code selector
        
        Args:
            model_url: URL of the VLLM model server
            max_concurrency: Maximum concurrent requests
            timeout: Request timeout in seconds
            max_retries: Maximum number of retry attempts
        """
        # Load LLM URL from .env if not provided
        if model_url is None:
            try:
                from dotenv import load_dotenv
                load_dotenv('../../.env')
                model_url = os.getenv('VLLM_QWEN_32B_URL')
                logger.info(f'Using LLM URL from .env: {model_url}')
            except Exception as e:
                logger.warning(f'Failed to load .env')
        
        self.model_url = model_url.rstrip("/")
        self.max_concurrency = max_concurrency
        self.timeout = timeout
        self.max_retries = max_retries
        self._semaphore = None  # Will be created in the correct event loop
        
        # Load model name from environment variable like build_corpus.py does
        try:
            from dotenv import load_dotenv
            load_dotenv("../../.env")
            self.model_name = os.getenv("VLLM_QWEN_32B_MODEL", "Qwen/Qwen3-32B")
            logger.info(f"Using model from .env: {self.model_name}")
        except Exception as e:
            logger.warning(f"Failed to load model from .env, using default: {e}")
            self.model_name = "Qwen/Qwen3-32B"
        
        # Thread-local storage for event loops
        self._local = threading.local()
    
    def _get_or_create_loop(self):
        """Get or create a new event loop for the current thread"""
        if not hasattr(self._local, 'loop') or self._local.loop.is_closed():
            self._local.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._local.loop)
        return self._local.loop
    
    def _sanitize_text(self, text: str) -> str:
        """Sanitize text to prevent HTTP 400 errors from problematic characters"""
        # Replace << and >> patterns that cause server issues
        sanitized = text.replace('<<', '[[').replace('>>', ']]')
        
        # Replace other problematic characters that cause HTTP 400 errors
        sanitized = sanitized.replace('"', '"').replace('"', '"')
        sanitized = sanitized.replace('—', '-').replace('–', '-')
        sanitized = sanitized.replace('…', '...').replace('•', '*')
        
        # Remove non-ASCII characters that might cause issues
        import re
        sanitized = re.sub(r'[^ -~]', ' ', sanitized)
        sanitized = re.sub(r'\s+', ' ', sanitized).strip()
        return sanitized
    
    
    def _approx_tokens(self, text: str) -> int:
        """Rough token estimate (~4 chars per token)."""
        if not text:
            return 0
        return max(1, len(text) // 4)

    def _enforce_context_budget(self, datapoint_text: str, prev_codes: List[str], cand_codes: List[str], budget_tokens: int = 32000):
        """Context budget enforcement with aggressive candidate sampling using FULL datapoint text."""
        full_dp = self._sanitize_text(datapoint_text)
        dp_chars = len(full_dp)
        dp_tokens = max(1, len(full_dp) // 4)
        available_tokens = max(0, budget_tokens - dp_tokens - 200)
        max_candidates = min(len(cand_codes or []), max(10, available_tokens // 20))

        seed = abs(hash(full_dp)) % (2**32)
        rng = random.Random(seed)
        limited_cand = rng.sample(cand_codes or [], min(max_candidates, len(cand_codes or [])))

        candidate_codes_text = "\n".join([f"{code}" for code in limited_cand])
        prompt_core = f"Datapoint:\n{full_dp}\n\nCandidate codes to choose from:\n{candidate_codes_text}\n"
        approx_tokens = self._approx_tokens(prompt_core) + 100

        logger.info(f"AGGRESSIVE: ~{approx_tokens} tokens (dp_chars={dp_chars}, cand={len(limited_cand)}, filtered=False)")
        return prompt_core, [], limited_cand, full_dp, approx_tokens

    def _format_prompt(self, request: CodeSelectionRequest) -> str:
        """Format the prompt for code selection using context budget enforcement."""
        # Enforce context budget (~32k tokens default)
        prompt_core, _, _, _, _ = self._enforce_context_budget(
            request.datapoint_text,
            request.previous_codes,
            request.candidate_codes,
            budget_tokens=32000,
        )

        prompt = (
            "You are a code selection expert. Given a datapoint and candidate codes, select the most relevant codes.\n\n"
            + prompt_core +
            "CRITICAL: Do NOT use thinking mode, step-by-step reasoning, or any hidden analysis.\n"
            "STRICT OUTPUT RULES (read carefully):\n"
            "- Respond with ONLY a JSON array of strings (no keys, no objects).\n"
            "- Do NOT include any explanations, notes, thoughts, or tags.\n"
            "- Do NOT include <think> or any hidden reasoning.\n"
            "- Do NOT wrap in code fences.\n"
            "- Output must begin with '[' and end with ']'.\n"
            "- Select as many DISTINCT codes as needed to comprehensively cover the datapoint's key themes, up to a maximum of 20.\n"
            "- Prefer breadth over redundancy; avoid near-duplicate codes.\n"
            "- Each code MUST be a clear, unambiguous phrase of 5–15 words (no single words, avoid vague terms).\n"
            "- Aim to select 10–20 codes if relevant; at least 5, and at most 20.\n\n"
            'Example valid output: ["Systematic delegation to maximize time value", "Building personal authority via consistent public expertise sharing"]'
        )
        return prompt

    def _extract_json_from_response(self, content: str) -> List[str]:
        """Extract JSON array from LLM response, handling thinking mode and various formats"""
        if not content:
            return []
        
        logger.warning(f"🔍 Extracting JSON from response of length {len(content)}")
        
        # Strategy 1: Remove <think> tags and extract JSON from cleaned content
        cleaned_content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
        logger.warning(f"🔍 Cleaned content length: {len(cleaned_content)}")
        
        # Try to find JSON array in cleaned content
        json_match = re.search(r'\[\s*"[^"]*"(?:\s*,\s*"[^"]*")*\s*\]', cleaned_content, re.DOTALL)
        if json_match:
            logger.warning(f"🔍 Found JSON match in cleaned content")
            try:
                selected_codes = json.loads(json_match.group())
                if isinstance(selected_codes, list) and all(isinstance(item, str) for item in selected_codes):
                    logger.warning(f"🔍 Successfully extracted {len(selected_codes)} codes from cleaned content")
                    return selected_codes
            except json.JSONDecodeError as e:
                logger.warning(f"🔍 JSON parsing failed in cleaned content: {e}")
                pass
        
        # Strategy 2: Look for JSON array anywhere in the original response
        json_match = re.search(r'\[\s*"[^"]*"(?:\s*,\s*"[^"]*")*\s*\]', content, re.DOTALL)
        if json_match:
            try:
                selected_codes = json.loads(json_match.group())
                if isinstance(selected_codes, list) and all(isinstance(item, str) for item in selected_codes):
                    return selected_codes
            except json.JSONDecodeError:
                pass
        
        # Strategy 3: Extract individual quoted strings as fallback
        string_matches = re.findall(r'"[^"]*"', content)
        if string_matches:
            # Remove quotes and return as list
            return [match.strip('"') for match in string_matches if match.strip('"')]
        
        return []

    
    async def _make_llm_request(self, session: aiohttp.ClientSession, request: CodeSelectionRequest) -> CodeSelectionResult:
        """
        Make a single LLM request for code selection with enhanced retry logic

        Up to ``self.max_retries`` attempts are made. HTTP errors, timeouts,
        empty completions, unparseable completions and unexpected exceptions
        are retried with a capped exponential backoff. Failures are returned
        as unsuccessful results rather than raised.

        Args:
            session: aiohttp session
            request: Code selection request

        Returns:
            Code selection result. On success ``selected_codes`` holds at most
            20 strings; on failure it is empty and ``error_message`` is set.
        """
        # asyncio synchronization primitives belong to a single event loop.
        # Batches run on a fresh loop per call, so rebuild the semaphore when
        # the running loop changes instead of reusing one from a closed loop.
        running_loop = asyncio.get_running_loop()
        if self._semaphore is None or getattr(self, "_semaphore_loop", None) is not running_loop:
            self._semaphore = asyncio.Semaphore(self.max_concurrency)
            self._semaphore_loop = running_loop

        async with self._semaphore:
            start_time = time.time()
            short_id = request.datapoint_id[:50]

            def build_result(codes: List[str], error: Optional[str] = None) -> CodeSelectionResult:
                """Create a result stamped with the time elapsed so far."""
                return CodeSelectionResult(
                    datapoint_id=request.datapoint_id,
                    selected_codes=codes,
                    processing_time=time.time() - start_time,
                    success=error is None,
                    error_message=error
                )

            formatted_prompt = self._format_prompt(request)

            payload = {
                "model": self.model_name,  # Use environment variable instead of hardcoded
                "messages": [
                    {
                        "role": "user",
                        "content": formatted_prompt
                    }
                ],
                "temperature": 0.0,
                "max_tokens": 1024,
                "stop": ["```json", "```"],  # Stop on code fences; allow think, we strip it later
                # NOTE(review): confirm the server honours this key; some
                # servers expect thinking to be disabled through chat-template
                # arguments instead.
                "thinking_mode": False,  # Explicitly disable thinking mode
                # Remove "stream": False - this might cause 400 errors
            }

            last_exception = None

            for attempt in range(self.max_retries):
                can_retry = attempt < self.max_retries - 1
                try:
                    logger.debug(f"Making LLM request for {short_id}... (attempt {attempt + 1}/{self.max_retries})")

                    async with session.post(
                        self.model_url + "/v1/chat/completions",
                        json=payload,
                        timeout=aiohttp.ClientTimeout(total=self.timeout)
                    ) as response:
                        if response.status != 200:
                            error_text = await response.text()
                            last_exception = Exception(f"HTTP {response.status}: {error_text[:200]}")

                            if can_retry:
                                wait_time = min(2 ** attempt, 10)  # Cap at 10 seconds
                                logger.warning(f"HTTP {response.status} for {short_id}..., retrying in {wait_time}s (attempt {attempt + 1}/{self.max_retries})")
                                await asyncio.sleep(wait_time)
                                continue

                            return build_result([], str(last_exception))

                        result = await response.json()
                        # Tolerate a missing or empty "choices" list and a null
                        # message/content: all are handled as empty content.
                        choices = result.get("choices") or [{}]
                        content = (choices[0].get("message") or {}).get("content") or ""

                        # Handle empty content as retryable case
                        if not content.strip():
                            logger.warning(f"⚠️  Empty content for {short_id}... prompt_len={len(formatted_prompt)}, cand={len(request.candidate_codes)}")
                            if can_retry:
                                wait_time = min(2 ** attempt, 10)
                                logger.warning(f"Empty content, retrying in {wait_time}s (attempt {attempt + 1}/{self.max_retries})")
                                await asyncio.sleep(wait_time)
                                continue
                            return build_result([], "Empty content from model")

                        # Parse JSON response with detailed debugging
                        logger.warning(f"🔍 Parsing response for {short_id}...")
                        logger.debug(f"   Response length: {len(content)} characters")
                        logger.debug(f"   Response preview: {content[:200]}...")

                        # None means the response is not usable as a plain JSON
                        # array and the fallback extraction has to run.
                        selected_codes = None
                        try:
                            decoded = json.loads(content)
                        except json.JSONDecodeError as e:
                            logger.warning(f"⚠️  JSON parsing failed for {short_id}...: {e}")
                            logger.warning(f"   Raw response: {content[:500]}...")
                        else:
                            if isinstance(decoded, list):
                                # Keep only string items so the result matches List[str].
                                selected_codes = [code for code in decoded if isinstance(code, str)]
                            else:
                                logger.warning(f"⚠️  Response is not a list: {type(decoded)}")

                        if selected_codes is not None:
                            # Ensure we don't exceed 20 codes
                            selected_codes = selected_codes[:20]
                            logger.warning(f"✅ Successfully parsed {len(selected_codes)} codes for {short_id}...")
                            return build_result(selected_codes)

                        # Try to extract JSON from thinking mode response
                        logger.warning(f"🔍 Calling _extract_json_from_response for {short_id}...")
                        extracted = self._extract_json_from_response(content)
                        logger.warning(f"🔍 Extraction result: {len(extracted) if extracted else 0} codes")
                        if extracted:
                            extracted = extracted[:20]  # Limit to 20
                            logger.warning(f"✅ Successfully extracted {len(extracted)} codes from thinking mode for {short_id}...")
                            return build_result(extracted)

                        logger.warning(f"⚠️  All JSON extraction strategies failed for {short_id}...")

                        # JSON parsing failed - treat as retryable error
                        if can_retry:
                            wait_time = min(2 ** attempt, 10)
                            logger.warning(f"JSON parsing failed for {short_id}... (response: {content[:100]}), retrying in {wait_time}s (attempt {attempt + 1}/{self.max_retries})")
                            await asyncio.sleep(wait_time)
                            continue

                        # Final attempt failed
                        return build_result(
                            [],
                            f"Failed to parse JSON from response after {self.max_retries} attempts: {content[:200]}"
                        )

                except asyncio.TimeoutError as e:
                    last_exception = e
                    if can_retry:
                        wait_time = min(2 ** attempt, 10)
                        logger.warning(f"Timeout for {short_id}..., retrying in {wait_time}s (attempt {attempt + 1}/{self.max_retries})")
                        await asyncio.sleep(wait_time)
                        continue

                    return build_result([], "Request timeout")

                except Exception as e:
                    last_exception = e
                    # Check if it's a connection error
                    is_connection_error = 'Cannot connect to host' in str(e) or 'Connection refused' in str(e)
                    if is_connection_error:
                        logger.error(f"Connection error for {short_id}...: {str(e)[:100]}")

                    if can_retry:
                        if is_connection_error:
                            wait_time = min(5 ** attempt, 30)  # Longer wait for connection issues
                            logger.warning(f"Connection failed, waiting {wait_time}s before retry (attempt {attempt + 1}/{self.max_retries})")
                        else:
                            wait_time = min(2 ** attempt, 10)
                            logger.warning(f"Exception for {short_id}...: {str(e)[:100]}, retrying in {wait_time}s (attempt {attempt + 1}/{self.max_retries})")
                        await asyncio.sleep(wait_time)
                        continue

                    return build_result([], f"Unexpected error: {str(e)}")

            # Only reached when the loop body never ran (max_retries <= 0):
            # every attempt above either continues or returns.
            logger.error(f"All retry attempts failed for {short_id}...: {str(last_exception)}")
            return build_result([], f"All retry attempts failed: {str(last_exception)}")
    
    async def select_codes_batch(self, requests: List[CodeSelectionRequest]) -> List[CodeSelectionResult]:
        """
        Process a batch of code selection requests with intelligent batching for large datasets
        
        Args:
            requests: List of code selection requests
            
        Returns:
            List of code selection results
        """
        if not requests:
            return []
        
        logger.info(f"Processing {len(requests)} requests with max_concurrency={self.max_concurrency}")
        
        # Intelligent batching for large datasets
        if len(requests) > 50:
            batch_size = min(32, len(requests) // 4)  # Smaller batches for large datasets
            logger.info(f"Large dataset detected ({len(requests)} requests), using batch size: {batch_size}")
            return await self._process_large_batch(requests, batch_size)
        else:
            # Use original method for small datasets
            return await self._process_small_batch(requests)
    
    async def _process_small_batch(self, requests: List[CodeSelectionRequest]) -> List[CodeSelectionResult]:
        """Send every request in the batch concurrently over one HTTP session.

        A dedicated connector and session are created for the batch and closed
        when it finishes. A request whose coroutine raises, or is cancelled,
        becomes an unsuccessful result, so the returned list always has one
        ``CodeSelectionResult`` per request, in request order.

        Args:
            requests: Requests to dispatch together.

        Returns:
            One result per request.
        """
        # Create a new connector for this batch
        connector = aiohttp.TCPConnector(
            limit=self.max_concurrency * 2,
            limit_per_host=self.max_concurrency,
            ttl_dns_cache=300,  # DNS cache for 5 minutes
            use_dns_cache=True
        )
        timeout = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
            tasks = [self._make_llm_request(session, request) for request in requests]

            # return_exceptions=True keeps one failure from aborting the batch.
            results = await asyncio.gather(*tasks, return_exceptions=True)

        processed_results = []
        for i, (request, result) in enumerate(zip(requests, results)):
            # CancelledError derives from BaseException (Python 3.8+), so test
            # against BaseException to keep raw exceptions out of the output.
            if isinstance(result, BaseException):
                logger.error(f"Request {i} failed with exception: {result}")
                processed_results.append(CodeSelectionResult(
                    datapoint_id=request.datapoint_id,
                    selected_codes=[],
                    processing_time=0.0,
                    success=False,
                    error_message=str(result)
                ))
            else:
                processed_results.append(result)

        return processed_results
            
    async def _process_large_batch(self, requests: List[CodeSelectionRequest], batch_size: int) -> List[CodeSelectionResult]:
        """Process large datasets in smaller batches to prevent server overload"""
        all_results = []
        total_batches = (len(requests) + batch_size - 1) // batch_size
        
        logger.info(f"Processing {len(requests)} requests in {total_batches} batches of {batch_size}")
        
        for i in range(0, len(requests), batch_size):
            batch = requests[i:i + batch_size]
            batch_num = (i // batch_size) + 1
            
            logger.info(f"  🔄 Processing batch {batch_num}/{total_batches} ({len(batch)} requests)...")
            
            # Process this batch
            batch_results = await self._process_small_batch(batch)
            all_results.extend(batch_results)
            
            # Small delay between batches to prevent server overload
            if batch_num < total_batches:
                await asyncio.sleep(1.0)  # 1 second delay between batches
        
        logger.info(f"✅ Completed all {total_batches} batches")
        return all_results
    
    def select_codes_batch_sync(self, requests: List[CodeSelectionRequest]) -> List[CodeSelectionResult]:
        """
        Synchronous wrapper for batch code selection with robust event loop management
        
        Args:
            requests: List of code selection requests
            
        Returns:
            List of code selection results
        """
        if not requests:
            return []
        
        logger.info(f"Starting synchronous batch processing of {len(requests)} requests")
        
        try:
            # Try to get the current event loop
            loop = asyncio.get_running_loop()
            logger.warning("🔍 Running in existing event loop, using ThreadPoolExecutor")
            
            # We're in an async context, use ThreadPoolExecutor to run in a separate thread
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(self._run_in_new_loop, requests)
                result = future.result(timeout=self.timeout * 10)  # Much longer timeout for large batches
                logger.warning(f"🔍 ThreadPoolExecutor completed successfully: {len(result)} results")
                return result
                
        except RuntimeError:
            # No event loop is running, we can create a new one
            logger.warning("🔍 No event loop running, creating new one")
            result = self._run_in_new_loop(requests)
            logger.warning(f"🔍 New event loop completed successfully: {len(result)} results")
            return result
        except Exception as e:
            logger.error(f"🔍 Exception in select_codes_batch_sync: {e}")
            import traceback
            traceback.print_exc()
            return []
    def _run_in_new_loop(self, requests: List[CodeSelectionRequest]) -> List[CodeSelectionResult]:
        """
        Run the batch processing in a new event loop
        
        Args:
            requests: List of code selection requests
            
        Returns:
            List of code selection results
        """
        try:
            # Create a new event loop for this thread
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            
            try:
                return loop.run_until_complete(self.select_codes_batch(requests))
            finally:
                loop.close()
                
        except Exception as e:
            logger.error(f"Error in _run_in_new_loop: {str(e)}")
            # Return failed results for all requests
            return [
                CodeSelectionResult(
                    datapoint_id=request.datapoint_id,
                    selected_codes=[],
                    processing_time=0.0,
                    success=False,
                    error_message=f"Event loop error: {str(e)}"
                )
                for request in requests
            ]

    def _process_llm_requests_with_retries(self, requests: List[CodeSelectionRequest], max_retries: int = 3) -> List[CodeSelectionResult]:
        """Process LLM requests with retries and exponential backoff, retrying only failures."""
        if not requests:
            return []

        pending: List[CodeSelectionRequest] = list(requests)
        final_results: Dict[str, CodeSelectionResult] = {}

        backoff_seconds = 1.0
        for attempt in range(1, max_retries + 1):
            try:
                batch_results = self.select_codes_batch_sync(pending)
            except Exception as e:
                logger.error(f"Batch attempt {attempt}/{max_retries} failed with exception: {e}")
                batch_results = []

            # Partition successes and failures
            failures: List[CodeSelectionRequest] = []
            seen_ids: set = set()

            # Map request by id for quick lookup
            req_by_id = {r.datapoint_id: r for r in pending}

            for res in batch_results:
                seen_ids.add(res.datapoint_id)
                if res.success:
                    final_results[res.datapoint_id] = res
                else:
                    # Keep for retry
                    if res.datapoint_id in req_by_id:
                        failures.append(req_by_id[res.datapoint_id])

            # Any request that didn't return a result at all -> treat as failure
            for req in pending:
                if req.datapoint_id not in seen_ids:
                    failures.append(req)

            if not failures:
                logger.info(f"All requests succeeded on attempt {attempt}")
                break

            if attempt < max_retries:
                logger.warning(f"{len(failures)} requests failed on attempt {attempt}. Retrying after {backoff_seconds:.1f}s...")
                time.sleep(backoff_seconds)
                backoff_seconds = min(backoff_seconds * 2, 8.0)
                pending = failures
            else:
                logger.error(f"Exhausted retries. {len(failures)} requests failed after {max_retries} attempts.")
                # Record placeholder failure results for remaining
                for req in failures:
                    if req.datapoint_id not in final_results:
                        final_results[req.datapoint_id] = CodeSelectionResult(
                            datapoint_id=req.datapoint_id,
                            selected_codes=[],
                            processing_time=0.0,
                            success=False,
                            error_message=f"Failed after {max_retries} attempts"
                        )

        # Return results in original order
        ordered: List[CodeSelectionResult] = []
        for req in requests:
            if req.datapoint_id in final_results:
                ordered.append(final_results[req.datapoint_id])
        return ordered


class RefinementPipeline:
    """
    Main refinement pipeline that combines datapoint retrieval with LLM code selection.

    For each datapoint the pipeline asks the (optional) datapoint retriever for
    candidate codes, wraps them in a ``CodeSelectionRequest`` and sends the
    requests of a chunk to ``LLMCodeSelector`` with retries. Results are plain
    dicts with the keys ``datapoint_id``, ``selected_codes``,
    ``processing_time``, ``success``, ``error_message`` and ``datapoint_text``.
    """

    def __init__(self,
                 embeddings_path: str,
                 mapping_dir: str,
                 cliques_dir: str,
                 model_url: str = None,
                 max_concurrency: int = 64):
        """
        Initialize the refinement pipeline

        Args:
            embeddings_path: Path to embeddings.parquet
            mapping_dir: Path to datapoint_code_mapping directory
            cliques_dir: Path to cliques directory
            model_url: URL for the VLLM model
            max_concurrency: Maximum concurrent LLM requests
        """
        try:
            # Imported lazily so the pipeline can still be constructed when the
            # retrieval package is not installed.
            from retrieval.datapoint_retrieval import DatapointRetriever

            self.datapoint_retriever = DatapointRetriever(
                embeddings_path=embeddings_path,
                mapping_dir=mapping_dir,
                cliques_dir=cliques_dir,
                sample_size=20,
                total_codes_per_original=20,  # 20 codes per original code
                max_workers=4
            )
        except ImportError as e:
            logger.warning(f"Failed to import DatapointRetriever: {e}")
            logger.info("Using mock datapoint retriever for testing")
            # With no retriever, every request is built with an empty
            # candidate list (see _process_datapoint_chunk).
            self.datapoint_retriever = None

        self.llm_selector = LLMCodeSelector(
            model_url=model_url,
            max_concurrency=max_concurrency
        )
        self.model_url = model_url

    def process_datapoints_batch(self,
                                datapoint_ids: List[str],
                                previous_codes: Optional[List[str]] = None,
                                corpus_path: Optional[str] = None,
                                chunk_size: int = 100) -> List[Dict[str, Any]]:
        """
        Process a batch of datapoints through the full refinement pipeline

        Args:
            datapoint_ids: List of datapoint IDs to process
            previous_codes: Previous codebook codes (optional)
            corpus_path: Path to corpus for getting datapoint text (optional)
            chunk_size: Number of datapoints per chunk (default 100)

        Returns:
            List of results with selected codes for each datapoint

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if not datapoint_ids:
            logger.warning("No datapoint IDs provided")
            return []

        # A zero step makes range() fail with an opaque message and a negative
        # step silently yields no chunks at all, so reject both up front.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        if previous_codes is None:
            previous_codes = []

        logger.info(f"Processing {len(datapoint_ids)} datapoints in chunks of {chunk_size}...")
        start_time = time.time()

        datapoint_texts = self._load_datapoint_texts(corpus_path)

        # Process datapoints in chunks
        all_results: List[Dict[str, Any]] = []
        total_chunks = (len(datapoint_ids) + chunk_size - 1) // chunk_size
        for chunk_num, offset in enumerate(range(0, len(datapoint_ids), chunk_size), start=1):
            chunk_datapoints = datapoint_ids[offset:offset + chunk_size]
            logger.info(f"📦 Processing chunk {chunk_num}/{total_chunks} ({len(chunk_datapoints)} datapoints)")
            chunk_start = time.time()

            chunk_results = self._process_datapoint_chunk(
                chunk_datapoints,
                previous_codes,
                datapoint_texts
            )
            all_results.extend(chunk_results)
            logger.info(f"   ✅ Chunk {chunk_num} completed in {time.time() - chunk_start:.2f}s: {len(chunk_results)} results")

        logger.info(f"Completed {len(datapoint_ids)} datapoints in {time.time() - start_time:.2f}s")
        return all_results

    def _load_datapoint_texts(self, corpus_path: Optional[str]) -> Dict[str, str]:
        """
        Load the datapoint-id -> text mapping from a parquet corpus.

        Loading is best-effort: a missing path, an unreadable file or an
        unrecognized schema all yield an empty mapping (with a warning), in
        which case callers fall back to using the datapoint ID as its text.

        Args:
            corpus_path: Path to the corpus parquet file, or None/empty to skip.

        Returns:
            Mapping from datapoint key to datapoint text (possibly empty).
        """
        datapoint_texts: Dict[str, str] = {}
        if not corpus_path:
            return datapoint_texts
        if not os.path.exists(corpus_path):
            logger.warning(f"Corpus path does not exist: {corpus_path}")
            return datapoint_texts

        try:
            import pandas as pd
            corpus_df = pd.read_parquet(corpus_path)
            if 'datapoint' in corpus_df.columns and 'datapoint_text' in corpus_df.columns:
                datapoint_texts = dict(zip(corpus_df['datapoint'], corpus_df['datapoint_text']))
                logger.info(f"Loaded {len(datapoint_texts)} datapoint texts from corpus")
            elif 'chunk_text' in corpus_df.columns:
                # Use chunk_text as datapoint text: the text is both key and value.
                datapoint_texts = {text: text for text in corpus_df['chunk_text']}
                logger.info(f"Loaded {len(datapoint_texts)} datapoint texts from corpus")
            else:
                logger.warning(f"Corpus {corpus_path} has no recognized text columns")
        except Exception as e:
            logger.warning(f"Failed to load corpus: {e}")

        return datapoint_texts

    def save_results(self, results: List[Dict[str, Any]], output_path: str):
        """
        Save refinement results to parquet file

        Only successful results with at least one selected code are written;
        each selected code becomes one row with the columns ``datapoint``,
        ``code`` and ``datapoint_text``.

        Args:
            results: List of refinement results
            output_path: Path to save the results
        """
        import pandas as pd

        # Create output directory if it doesn't exist. A bare filename has an
        # empty dirname, which os.makedirs rejects, so only create when set.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Flatten results to one row per (datapoint, code) pair.
        rows = [
            {
                'datapoint': result['datapoint_id'],
                'code': code,
                'datapoint_text': result.get('datapoint_text', ''),
            }
            for result in results
            if result.get('success') and result.get('selected_codes')
            for code in result['selected_codes']
        ]

        # Explicit columns keep the schema stable even when there are no rows.
        df = pd.DataFrame(rows, columns=['datapoint', 'code', 'datapoint_text'])
        df.to_parquet(output_path, index=False)
        logger.info(f"Saved {len(df)} refined codes to {output_path}")

    def _process_datapoint_chunk(self,
                                 chunk_datapoints: List[str],
                                 previous_codes: List[str],
                                 datapoint_texts: Dict[str, str]) -> List[Dict[str, Any]]:
        """
        Process a chunk of datapoints end-to-end and return results.

        Args:
            chunk_datapoints: Datapoint IDs belonging to this chunk.
            previous_codes: Previous codebook codes passed along with every request.
            datapoint_texts: Mapping from datapoint ID to its text.

        Returns:
            One result dict per selection result. If the selector call itself
            fails, every datapoint without a result gets a failure entry
            (``success=False``) instead of being dropped.
        """
        results: List[Dict[str, Any]] = []
        requests: List[CodeSelectionRequest] = []

        # Build requests for this chunk
        for datapoint_id in chunk_datapoints:
            # Resolve datapoint text; fall back to the ID itself if the corpus
            # mapping has no (non-empty) entry.
            dp_text = datapoint_texts.get(datapoint_id) or str(datapoint_id)

            # Retrieval is best-effort: on failure the request is still sent
            # with an empty candidate list.
            candidate_codes: List[str] = []
            try:
                if getattr(self, 'datapoint_retriever', None) is not None:
                    retrieval = self.datapoint_retriever.retrieve_for_datapoint(datapoint_id, use_parallel=True)
                    candidate_codes = retrieval.get('candidate_codes', []) or []
            except Exception as e:
                logger.warning(f"Candidate retrieval failed for {datapoint_id}: {e}")

            requests.append(CodeSelectionRequest(
                datapoint_id=datapoint_id,
                datapoint_text=dp_text,
                previous_codes=previous_codes or [],
                candidate_codes=candidate_codes,
            ))

        # Create mapping from datapoint_id to datapoint_text
        id_to_text = {req.datapoint_id: req.datapoint_text for req in requests}

        # Execute requests
        try:
            selection_results = self.llm_selector._process_llm_requests_with_retries(
                requests,
                max_retries=self.llm_selector.max_retries
            )
            for r in selection_results:
                results.append({
                    'datapoint_id': r.datapoint_id,
                    'selected_codes': r.selected_codes,
                    'processing_time': r.processing_time,
                    'success': r.success,
                    'error_message': r.error_message,
                    'datapoint_text': id_to_text.get(r.datapoint_id, ''),
                })
        except Exception as e:
            logger.error(f"Chunk processing failure: {e}")
            # Do not lose datapoints silently: report a failure for every
            # request that has no result yet.
            reported = {entry['datapoint_id'] for entry in results}
            for req in requests:
                if req.datapoint_id not in reported:
                    results.append({
                        'datapoint_id': req.datapoint_id,
                        'selected_codes': [],
                        'processing_time': 0.0,
                        'success': False,
                        'error_message': f"Chunk processing failure: {e}",
                        'datapoint_text': req.datapoint_text,
                    })

        return results


if __name__ == "__main__":
    # Manual smoke test: push a single request through the synchronous batch API.
    import os
    import sys

    # Make modules living next to this file importable.
    script_dir = os.path.dirname(__file__)
    sys.path.append(script_dir)

    # Selector under test; the model URL is resolved by the constructor.
    smoke_selector = LLMCodeSelector(max_concurrency=64, timeout=120, max_retries=3)

    # One minimal request with three candidate codes and no previous codes.
    smoke_request = CodeSelectionRequest(
        datapoint_id="test_datapoint",
        datapoint_text="This is a test datapoint for code selection.",
        previous_codes=[],
        candidate_codes=["test_code_1", "test_code_2", "test_code_3"],
    )

    print("Testing fixed LLM code selector...")
    smoke_results = smoke_selector.select_codes_batch_sync([smoke_request])

    for outcome in smoke_results:
        print(f"Result: {outcome}")