#!/usr/bin/env python3
"""
Build a mixed text corpus from heterogeneous inputs (txt/json/jsonl/csv/md),
chunk into smaller pieces, randomly sample 5–10 chunks to generate initial
multi-level codebooks with respect to a question (few-shot seed), then apply
few-shot prompting to the rest to generate codebooks, flatten all tags, and
save the final corpus as Parquet in temp_files.

Outputs:
- temp_files/corpus.parquet (columns: source_path, chunk_index, level, tag)

Usage examples:
  python build_corpus.py \
    --question "How to improve productivity?" \
    --chunk-size 128 --overlap 24 --sample-size 0

If --input is omitted, the first valid data file under the data directory is used automatically.
"""

import os
import json
import csv
import argparse
import random
import time
import pandas as pd
from typing import List, Dict, Iterable, Tuple, Optional

# Increase CSV field size limit to handle large text fields
csv.field_size_limit(1000000)  # 1MB field size limit

try:
    import pyarrow as pa
    import pyarrow.parquet as pq
except Exception as e:
    print("❌ pyarrow is required. Please install pyarrow.")
    raise

# Load .env for environment variables if present
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass

# Filesystem layout, resolved relative to this module's own directory.
THIS_DIR = os.path.abspath(os.path.dirname(__file__))
# Output directory for generated artifacts, two levels above this module.
TEMP_FILES_DIR = os.path.abspath(os.path.join(THIS_DIR, "..", "..", "temp_files"))
# Input data directory, two levels above this module.
# NOTE(review): unlike TEMP_FILES_DIR this path is not normalised with abspath,
# so it keeps its literal ".." segments.
DATA_DIR = os.path.join(THIS_DIR, "..", "..", "data")
# Default output path of the final corpus.
DEFAULT_CORPUS = os.path.join(TEMP_FILES_DIR, "corpus.parquet")
DEFAULT_SAMPLE = os.path.join(TEMP_FILES_DIR, "corpus_sample.parquet")  # unused in new flow, kept for CLI compat

# File extensions (lower-case, with leading dot) that the ingestion code accepts.
SUPPORTED_EXTS = {".txt", ".md", ".json", ".jsonl", ".csv"}

# Import existing pipeline components for reuse
# These will be imported when needed to avoid import issues
# from .embeddings import build_embeddings_parquet
# from .cluster import cluster_fast
# from .cosine_sim import CosineSimilarity
# from .nli_classify import classify_similarities_optimized, NLIClassifier
# from .topological_graph import build_topological_graph_fast
# from .flip_label_processing import process_nli_results_for_conflict_detection
# from .conflict_relationship_detection.conflict_detection_resolver import detect_and_resolve_conflicts_advanced
# from .temp_build_corpus import TempBuildCorpus, RefinementContext, AsyncVLLMClient, VLLM_BASE_URL
# from .prompts import (
#     get_prompt, fill_template, STRATEGY_1_DIRECT_PROMPT, 
#     STRATEGY_2_LOW_LEVEL_PROMPT, STRATEGY_2_MEDIUM_LEVEL_PROMPT, STRATEGY_2_HIGH_LEVEL_PROMPT,
#     STRICT_JSON_SUFFIX, INITIAL_CODE_GENERATION_PROMPT, CODEBOOK_GENERATION_PROMPT
# )

# --------------------------- vLLM async clients -----------------------------
import asyncio
import aiohttp

# Use separate URLs for embeddings and chat completions.
# All values below are read from the environment once, at import time; an
# unset variable yields None for the URL/model settings.
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL")
VLLM_TEXT_URL = os.getenv("VLLM_QWEN_32B_URL")  # Default chat completion server
# Generic base URL: prefers the chat server and falls back to the embedding server.
VLLM_BASE_URL = VLLM_TEXT_URL or VLLM_EMBEDDING_URL

# Use correct models for each purpose
DEFAULT_TEXT_MODEL = os.getenv("VLLM_QWEN_32B_MODEL")  # Chat completion model
DEFAULT_EMBEDDING_MODEL = os.getenv("DEFAULT_EMBEDDING_MODEL")  # Embedding model
# Concurrency cap (default 128); raises ValueError at import if the variable is not an integer.
MAX_CONCURRENCY = int(os.getenv("VLLM_MAX_CONCURRENCY", "128"))
# Per-request timeout in seconds (default 120) passed to every HTTP call.
REQUEST_TIMEOUT = int(os.getenv("VLLM_TIMEOUT", "120"))

class AsyncVLLMClient:
    """Small async HTTP client for a vLLM server.

    Embedding requests are posted to ``<embedding_url>/v1/embeddings`` and chat
    requests to ``<chat_url>/v1/chat/completions``. The two URLs come from the
    module-level ``VLLM_EMBEDDING_URL`` / ``VLLM_TEXT_URL`` settings when they
    are set, and otherwise fall back to ``base_url``.

    Use the client as an async context manager so the ``aiohttp`` session is
    created on entry and closed on exit. Request methods raise ``RuntimeError``
    when called without an open session; every other failure is printed and
    reported as ``None``.
    """

    def __init__(self, base_url: str, timeout: int = 120):
        self.base_url = base_url.rstrip("/")
        # NOTE(review): kept for interface compatibility; requests still use the
        # module-level REQUEST_TIMEOUT exactly as before.
        self.timeout = timeout
        self.session: Optional[aiohttp.ClientSession] = None
        # Use separate URLs for embeddings and chat completions
        self.embedding_url = VLLM_EMBEDDING_URL.rstrip("/") if VLLM_EMBEDDING_URL else self.base_url
        self.chat_url = VLLM_TEXT_URL.rstrip("/") if VLLM_TEXT_URL else self.base_url

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    def _require_session(self) -> "aiohttp.ClientSession":
        """Return the open session or raise if the client was not entered."""
        if not self.session:
            raise RuntimeError("Client session not initialized")
        return self.session

    @staticmethod
    def _request_timeout() -> "aiohttp.ClientTimeout":
        """Build the per-request timeout as a ClientTimeout (total seconds)."""
        # Passing a bare number as ``timeout`` is only kept by aiohttp for
        # backward compatibility; ClientTimeout(total=...) is the equivalent.
        return aiohttp.ClientTimeout(total=REQUEST_TIMEOUT)

    async def _post_json(self, url: str, payload: Dict, label: str) -> Optional[Dict]:
        """POST ``payload`` as JSON and return the decoded body, or None on failure.

        ``label`` only selects the wording of the printed warning.
        """
        session = self._require_session()
        try:
            async with session.post(url, json=payload, timeout=self._request_timeout()) as resp:
                if resp.status == 200:
                    return await resp.json()
                txt = await resp.text()
                print(f"⚠️ {label} HTTP {resp.status}: {txt}")
                return None
        except Exception as e:
            print(f"⚠️ {label} request failed: {e}")
            return None

    async def embeddings(self, model: str, inputs: List[str]) -> Optional[Dict]:
        """Request embeddings for ``inputs``; returns the JSON body or None."""
        payload = {"model": model, "input": inputs}
        return await self._post_json(f"{self.embedding_url}/v1/embeddings", payload, "embeddings")

    async def chat_completion(self, model: str, messages: List[Dict[str, str]], temperature: float = 0.2, max_tokens: int = 1024) -> Optional[Dict]:
        """Request a single chat completion; returns the JSON body or None."""
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        return await self._post_json(f"{self.chat_url}/v1/chat/completions", payload, "chat")

    async def chat_completion_batch(self, model: str, messages_list: List[List[Dict[str, str]]], 
                                  temperature: float = 0.2, max_tokens: int = 1024) -> List[Optional[Dict]]:
        """Batch multiple chat completions for better throughput and KV cache utilization

        Tries ``/v1/chat/completions/batch`` first. If that endpoint is missing,
        fails, or answers with anything other than a list with one entry per
        conversation, the conversations are sent individually instead.
        """
        session = self._require_session()

        # Create batch payload
        batch_payload = [
            {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
            for messages in messages_list
        ]

        try:
            async with session.post(f"{self.chat_url}/v1/chat/completions/batch", json=batch_payload, timeout=self._request_timeout()) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    # Only trust a response that lines up one-to-one with the requests.
                    if isinstance(data, list) and len(data) == len(messages_list):
                        return data
                    print("⚠️ Batch endpoint returned an unexpected payload, falling back to individual calls")
                else:
                    print(f"⚠️ Batch endpoint not available, falling back to individual calls")
        except Exception as e:
            print(f"⚠️ Batch chat request failed, falling back to individual calls: {e}")

        # The fallback runs outside the try block so a failure inside it cannot
        # trigger a second, duplicate round of individual requests.
        return await self._chat_completion_individual_batch(model, messages_list, temperature, max_tokens)
    
    async def _chat_completion_individual_batch(self, model: str, messages_list: List[List[Dict[str, str]]], 
                                              temperature: float = 0.2, max_tokens: int = 1024) -> List[Optional[Dict]]:
        """Fallback: process batch as individual requests with concurrency

        Results are returned in the order of ``messages_list``; any request
        that ended in an exception is reported as ``None``.
        """
        tasks = [
            self.chat_completion(model, messages, temperature, max_tokens)
            for messages in messages_list
        ]
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        results: List[Optional[Dict]] = []
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                # Keep the declared List[Optional[Dict]] contract instead of
                # leaking exception objects to callers that test for None.
                print(f"⚠️ chat request failed: {outcome}")
                results.append(None)
            else:
                results.append(outcome)
        return results


# --------------------------- Question-Chunk Similarity Filtering -----------------------------

async def filter_chunks_by_similarity(
    client: AsyncVLLMClient, 
    question: str, 
    chunks: List[Tuple[str, int, str]], 
    similarity_threshold: float = 0.6
) -> List[Tuple[str, int, str]]:
    """
    Filter chunks based on cosine similarity with the question.

    The question and all chunk texts are embedded in batches of 50 through
    ``client.embeddings``. Chunks whose embedding could not be obtained get a
    zero vector (similarity 0.0). If the question itself could not be
    embedded, or the number of embeddings does not match the number of texts,
    no filtering is possible and the input chunks are returned unchanged.
    
    Args:
        client: AsyncVLLMClient for embeddings
        question: The research question
        chunks: List of (source_path, chunk_idx, text) tuples
        similarity_threshold: Minimum similarity to keep chunks (default 0.6)
    
    Returns:
        Filtered list of chunks with similarity >= threshold
    """
    import numpy as np
    print(f"🔍 Filtering {len(chunks)} chunks by similarity with question (threshold: {similarity_threshold:.1%})")

    if not chunks:
        # Nothing to score; this also keeps the percentage report below from
        # dividing by zero.
        return []

    # Prepare texts for embedding: question + all chunks
    all_texts = [question] + [chunk_text for _, _, chunk_text in chunks]

    # Get embeddings for all texts
    print("📊 Generating embeddings for question and chunks...")

    # Use embedding model
    embedding_model = DEFAULT_EMBEDDING_MODEL or "text-embedding-ada-002"

    # Batch embedding requests for efficiency
    batch_size = 50  # Process in batches to avoid memory issues
    total_batches = (len(all_texts) + batch_size - 1) // batch_size
    # None marks a text whose embedding is missing; it is replaced by a zero
    # vector once the real embedding dimension is known.
    all_embeddings: List[Optional[List[float]]] = []

    for batch_no, start in enumerate(range(0, len(all_texts), batch_size), start=1):
        batch_texts = all_texts[start:start + batch_size]
        batch_embeddings = None

        try:
            response = await client.embeddings(embedding_model, batch_texts)
            if response and 'data' in response:
                batch_embeddings = [item['embedding'] for item in response['data']]
                print(f"  ✅ Embedded batch {batch_no}/{total_batches}")
            else:
                print(f"  ❌ Failed to get embeddings for batch {batch_no}")
        except Exception as e:
            print(f"  ⚠️ Error getting embeddings for batch {batch_no}: {e}")
            batch_embeddings = None

        if batch_embeddings is None:
            batch_embeddings = [None] * len(batch_texts)
        all_embeddings.extend(batch_embeddings)

    if len(all_embeddings) != len(all_texts):
        print(f"⚠️ Warning: Expected {len(all_texts)} embeddings, got {len(all_embeddings)}")
        return chunks  # Return all chunks if embedding failed

    question_vector = all_embeddings[0]
    if question_vector is None:
        # Without a question embedding every similarity is undefined, so
        # filtering would silently discard the whole corpus.
        print("⚠️ Warning: question embedding unavailable; returning all chunks unfiltered")
        return chunks

    # Size the zero-vector fallback from the real embedding dimension rather
    # than a hard-coded guess, so the matrix is always rectangular.
    embedding_dim = len(question_vector)
    missing = sum(1 for vector in all_embeddings if vector is None)
    if missing:
        print(f"⚠️ Warning: {missing} chunk embedding(s) unavailable; treating them as similarity 0.0")
    embeddings = np.array(
        [vector if vector is not None else [0.0] * embedding_dim for vector in all_embeddings],
        dtype=float,
    )

    # Extract question embedding (first one)
    question_embedding = embeddings[0:1]  # Shape: (1, embedding_dim)
    chunk_embeddings = embeddings[1:]     # Shape: (num_chunks, embedding_dim)

    def _unit_rows(matrix):
        """Scale each row to unit length; zero rows are left as zeros."""
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        # Dividing a zero row by 1 keeps it zero instead of producing NaN.
        return matrix / np.where(norms == 0, 1.0, norms)

    # Compute cosine similarities between the question and each chunk
    similarities = np.dot(_unit_rows(chunk_embeddings), _unit_rows(question_embedding).T).flatten()

    # Filter chunks based on similarity threshold
    filtered_chunks = []
    filtered_similarities = []

    for chunk, similarity in zip(chunks, similarities):
        if similarity >= similarity_threshold:
            filtered_chunks.append(chunk)
            filtered_similarities.append(similarity)

    # Report results
    print(f"✅ Similarity filtering complete:")
    print(f"   📊 Original chunks: {len(chunks)}")
    print(f"   ✅ Kept chunks: {len(filtered_chunks)} ({len(filtered_chunks)/len(chunks)*100:.1f}%)")
    print(f"   ❌ Filtered out: {len(chunks) - len(filtered_chunks)} ({(len(chunks) - len(filtered_chunks))/len(chunks)*100:.1f}%)")

    if filtered_similarities:
        print(f"   📈 Similarity range: {min(filtered_similarities):.3f} - {max(filtered_similarities):.3f}")
        print(f"   📊 Average similarity: {np.mean(filtered_similarities):.3f}")

    return filtered_chunks



# --------------------------- Codebook quality evaluation -----------------------------

def evaluate_codebook_quality(codebook: Dict[str, List[str]]) -> Dict[str, float]:
    """Score a single codebook (level name -> list of tags).

    Returns a dict with structure completeness, tag counts and per-level
    ratios, tag-length mean and standard deviation, and a weighted
    ``overall_quality_score`` (30% structure, 40% distribution, 30% tag
    volume capped at 20 tags).
    """
    required_levels = ('low-level', 'mid-level', 'high-level')
    is_complete = all(name in codebook for name in required_levels)
    tag_total = sum(len(group) for group in codebook.values())

    metrics = {
        'structure_completeness': 1.0 if is_complete else 0.0,
        'total_tags': tag_total,
    }

    # Level ratios are only meaningful for a complete, non-empty codebook.
    if is_complete and tag_total > 0:
        n_low = len(codebook.get('low-level', []))
        n_mid = len(codebook.get('mid-level', []))
        n_high = len(codebook.get('high-level', []))
        metrics['low_level_ratio'] = n_low / tag_total
        metrics['mid_level_ratio'] = n_mid / tag_total
        metrics['high_level_ratio'] = n_high / tag_total
        # Weighted towards low-level tags (0.5 / 0.3 / 0.2).
        metrics['distribution_score'] = (n_low * 0.5 + n_mid * 0.3 + n_high * 0.2) / tag_total
    else:
        for key in ('low_level_ratio', 'mid_level_ratio', 'high_level_ratio', 'distribution_score'):
            metrics[key] = 0.0

    # Tag length statistics across every level.
    flat_tags = [tag for group in codebook.values() for tag in group]
    if flat_tags:
        sizes = [len(tag) for tag in flat_tags]
        mean_size = sum(sizes) / len(sizes)
        metrics['avg_tag_length'] = mean_size
        metrics['tag_length_std'] = (sum((l - mean_size) ** 2 for l in sizes) / len(sizes)) ** 0.5
    else:
        metrics['avg_tag_length'] = 0.0
        metrics['tag_length_std'] = 0.0

    # Weighted blend; tag volume saturates at 20 tags.
    metrics['overall_quality_score'] = (
        metrics['structure_completeness'] * 0.3 +
        metrics['distribution_score'] * 0.4 +
        min(metrics['total_tags'] / 20.0, 1.0) * 0.3
    )

    return metrics

def evaluate_corpus_quality(all_records: List[Tuple[str, int, str, str]]) -> Dict[str, float]:
    """Summarise the flattened corpus.

    Each record is indexed as ``record[2]`` = level and ``record[3]`` = tag.
    The result reports tag counts, diversity (unique / total), average tag
    length, a level-distribution score and a weighted ``overall_quality``
    (30% diversity, 40% distribution, 30% volume capped at 1000 tags).
    An empty input yields a minimal all-zero summary.
    """
    if not all_records:
        return {'overall_quality': 0.0, 'total_tags': 0, 'unique_tags': 0}

    # Tag-level statistics.
    tags = [rec[3] for rec in all_records]
    tag_count = len(tags)
    distinct_count = len(set(tags))
    diversity = distinct_count / tag_count if tag_count > 0 else 0.0

    sizes = [len(tag) for tag in tags]
    mean_size = sum(sizes) / len(sizes) if sizes else 0.0

    # Tally how many records fall under each level, in first-seen order.
    per_level = {}
    for level in [rec[2] for rec in all_records]:
        if level in per_level:
            per_level[level] += 1
        else:
            per_level[level] = 1

    level_total = sum(per_level.values())
    if level_total > 0:
        share_low = per_level.get('low-level', 0) / level_total
        share_mid = per_level.get('mid-level', 0) / level_total
        share_high = per_level.get('high-level', 0) / level_total
        # Weighted towards low-level tags (0.5 / 0.3 / 0.2).
        level_score = share_low * 0.5 + share_mid * 0.3 + share_high * 0.2
    else:
        level_score = 0.0

    # Volume term saturates at 1000 tags.
    quality = (
        diversity * 0.3 +
        level_score * 0.4 +
        min(tag_count / 1000.0, 1.0) * 0.3
    )

    summary = {}
    summary['overall_quality'] = quality
    summary['total_tags'] = tag_count
    summary['unique_tags'] = distinct_count
    summary['tag_diversity'] = diversity
    summary['avg_tag_length'] = mean_size
    summary['distribution_score'] = level_score
    summary['level_distribution'] = per_level
    return summary

# --------------------------- Chunk size testing -----------------------------

async def test_chunk_sizes(client: AsyncVLLMClient, input_path: str, question: str, chunk_sizes: List[int] = [256], overlap_ratio: float = 0.2) -> Dict[int, Dict]:
    """Test different chunk sizes and evaluate quality and performance

    For every size in ``chunk_sizes`` the input is re-chunked (overlap =
    ``chunk_size * overlap_ratio``), the first 10 chunks are sent through
    ``generate_codebook_for_chunk_base`` with up to 64 concurrent requests,
    and timing, success and codebook-quality figures are collected.

    Args:
        client: Open AsyncVLLMClient used for the generation requests.
        input_path: File or directory handed to ``build_text_chunks``.
        question: Research question inserted into the prompt.
        chunk_sizes: Chunk sizes to try. The default list is never mutated.
        overlap_ratio: Fraction of the chunk size used as overlap.

    Returns:
        Mapping of chunk size to its metrics dict. Sizes that produce no
        chunks are skipped and do not appear in the result.
    """
    # Import required modules
    import sys
    import os
    module_dir = os.path.dirname(__file__)
    if module_dir not in sys.path:
        # Avoid growing sys.path on every call.
        sys.path.append(module_dir)
    from .prompts import fill_template, STRATEGY_1_DIRECT_PROMPT
    
    results = {}

    # The prompt depends only on the question, so build it once.
    initial_tmpl = STRATEGY_1_DIRECT_PROMPT
    initial_prompt = fill_template(initial_tmpl, {"QUESTION": question}) if "{QUESTION}" in initial_tmpl else (f"Question: {question}\n\n" + initial_tmpl)
    
    for chunk_size in chunk_sizes:
        overlap = int(chunk_size * overlap_ratio)
        print(f"\n🧪 Testing chunk size: {chunk_size} (overlap: {overlap})")
        
        # Build chunks with this size
        chunks = build_text_chunks(input_path, chunk_size, overlap, seed=42)
        print(f"  📊 Generated {len(chunks)} chunks")
        
        if len(chunks) == 0:
            print(f"  ⚠️ No chunks generated for size {chunk_size}")
            continue
        
        # Take a sample for testing (first 10 chunks)
        test_chunks = chunks[:10]
        
        # Test with optimal concurrency (64)
        semaphore = asyncio.Semaphore(64)
        
        start_time = time.time()
        tasks = [
            generate_codebook_for_chunk_base(client, initial_prompt, chunk_text, semaphore, "strategy_1") 
            for (_, _, chunk_text) in test_chunks
        ]
        results_list = await asyncio.gather(*tasks)
        end_time = time.time()
        
        # Calculate performance metrics
        successful = sum(1 for r in results_list if r is not None)
        total_time = end_time - start_time
        throughput = successful / total_time if total_time > 0 else 0
        success_rate = successful / len(test_chunks) * 100
        
        # Calculate quality metrics
        quality_scores = []
        total_tags = 0
        unique_tags = set()
        
        for codebook in results_list:
            if codebook:
                quality = evaluate_codebook_quality(codebook)
                quality_scores.append(quality['overall_quality_score'])
                
                # Count tags
                for tags in codebook.values():
                    total_tags += len(tags)
                    unique_tags.update(tags)
        
        avg_quality = sum(quality_scores) / len(quality_scores) if quality_scores else 0.0
        # Guarded once and reused below: with no tags at all (e.g. every
        # request failed) the unguarded division used to raise ZeroDivisionError.
        tag_diversity = len(unique_tags) / total_tags if total_tags > 0 else 0.0
        
        results[chunk_size] = {
            'total_chunks': len(test_chunks),
            'successful_chunks': successful,
            'success_rate': success_rate,
            'total_time': total_time,
            'throughput': throughput,
            'avg_time_per_chunk': total_time / len(test_chunks) if len(test_chunks) > 0 else 0,
            'avg_quality_score': avg_quality,
            'total_tags': total_tags,
            'unique_tags': len(unique_tags),
            'tag_diversity': tag_diversity,
            'avg_tags_per_chunk': total_tags / successful if successful > 0 else 0.0
        }
        
        print(f"  ⏱️ Time: {total_time:.2f}s")
        print(f"  ✅ Success: {successful}/{len(test_chunks)} ({success_rate:.1f}%)")
        print(f"  ⚡ Throughput: {throughput:.2f} chunks/second")
        print(f"  🎯 Quality Score: {avg_quality:.3f}")
        print(f"  📝 Tags: {total_tags} total, {len(unique_tags)} unique ({tag_diversity*100:.1f}% diversity)")
        print(f"  📊 Avg tags per chunk: {total_tags/successful:.1f}" if successful > 0 else "  📊 Avg tags per chunk: 0.0")
    
    return results

def print_chunk_size_analysis(results: Dict[int, Dict]):
    """Print a formatted analysis of chunk size test results

    Each chunk size gets a balanced score of 40% relative throughput
    (throughput divided by the best throughput in ``results``) and 60%
    average quality score.

    Args:
        results: Mapping of chunk size to a metrics dict containing at least
            ``total_time``, ``success_rate``, ``throughput``,
            ``avg_quality_score``, ``avg_tags_per_chunk`` and ``tag_diversity``.

    Returns:
        The chunk size with the highest balanced score, or 0 when ``results``
        is empty or no score is positive.
    """
    print("\n" + "="*80)
    print("📊 CHUNK SIZE PERFORMANCE & QUALITY ANALYSIS")
    print("="*80)
    
    # Find best performing chunk size (balanced score)
    best_score = 0
    best_chunk_size = 0

    # The peak throughput is the same for every row, so compute it once
    # instead of rescanning all results twice per row.
    peak_throughput = max((r['throughput'] for r in results.values()), default=0)
    
    print(f"{'Chunk Size':<12} {'Time (s)':<10} {'Success Rate':<12} {'Throughput':<12} {'Quality':<10} {'Tags/Chunk':<12} {'Diversity':<10}")
    print("-" * 90)
    
    for chunk_size in sorted(results.keys()):
        result = results[chunk_size]
        
        # Calculate balanced score (performance + quality)
        performance_score = result['throughput'] / peak_throughput if peak_throughput > 0 else 0
        quality_score = result['avg_quality_score']
        balanced_score = (performance_score * 0.4 + quality_score * 0.6)
        
        print(f"{chunk_size:<12} {result['total_time']:<10.2f} {result['success_rate']:<12.1f}% {result['throughput']:<12.2f} {result['avg_quality_score']:<10.3f} {result['avg_tags_per_chunk']:<12.1f} {result['tag_diversity']:<10.1%}")
        
        # Strict comparison: on a tie the smaller chunk size wins.
        if balanced_score > best_score:
            best_score = balanced_score
            best_chunk_size = chunk_size
    
    print("-" * 90)
    print(f"🏆 Best balanced performance: {best_chunk_size} chunk size (score: {best_score:.3f})")
    
    return best_chunk_size

# --------------------------- Concurrency testing -----------------------------

async def test_concurrency_levels(client: AsyncVLLMClient, test_chunks: List[Tuple[str, int, str]], prompt: str, max_concurrency: int = 128) -> Dict[int, Dict]:
    """Test different concurrency levels and measure performance

    The first 20 chunks are sent through ``generate_codebook_for_chunk_base``
    once per concurrency level (8, 16, 32, 64, 128, limited to levels not
    above ``max_concurrency``), and timing and success figures are recorded.

    Args:
        client: Open AsyncVLLMClient used for the generation requests.
        test_chunks: (source_path, chunk_idx, text) tuples to draw the sample from.
        prompt: Prompt passed with every chunk.
        max_concurrency: Highest concurrency level to try.

    Returns:
        Mapping of concurrency level to its metrics dict; empty when there
        are no chunks to test or no level fits under ``max_concurrency``.
    """
    results = {}

    # Use a subset of chunks for testing (first 20); identical for every level.
    test_subset = test_chunks[:20]
    if not test_subset:
        # Without chunks every rate below would divide by zero.
        print("⚠️ No chunks available for concurrency testing")
        return results
    
    # Test concurrency levels: 8, 16, 32, 64, 128
    concurrency_levels = [level for level in (8, 16, 32, 64, 128) if level <= max_concurrency]
    
    for concurrency in concurrency_levels:
        print(f"\n🧪 Testing concurrency level: {concurrency}")
        semaphore = asyncio.Semaphore(concurrency)
        
        start_time = time.time()
        tasks = [
            generate_codebook_for_chunk_base(client, prompt, chunk_text, semaphore, "strategy_1") 
            for (_, _, chunk_text) in test_subset
        ]
        results_list = await asyncio.gather(*tasks)
        end_time = time.time()
        
        successful = sum(1 for r in results_list if r is not None)
        total_time = end_time - start_time
        throughput = successful / total_time if total_time > 0 else 0
        success_rate = successful / len(test_subset) * 100
        avg_time_per_chunk = total_time / len(test_subset)
        
        results[concurrency] = {
            'total_time': total_time,
            'successful': successful,
            'total_chunks': len(test_subset),
            'success_rate': success_rate,
            'throughput': throughput,
            'avg_time_per_chunk': avg_time_per_chunk
        }
        
        print(f"  ⏱️ Time: {total_time:.2f}s")
        print(f"  ✅ Success: {successful}/{len(test_subset)} ({success_rate:.1f}%)")
        print(f"  ⚡ Throughput: {throughput:.2f} chunks/second")
        print(f"  🎯 Avg time per chunk: {avg_time_per_chunk:.2f}s")
    
    return results

def print_concurrency_analysis(results: Dict[int, Dict]):
    """Print a table of concurrency test results and return the best level.

    The best level is the one with the highest ``throughput``; on a tie the
    lowest level wins. Returns 0 when ``results`` is empty or no throughput
    is positive. With more than one level, the throughput change from the
    lowest to the highest level is printed as well.
    """
    banner = "=" * 80
    rule = "-" * 70

    print("\n" + banner)
    print("🎯 CONCURRENCY PERFORMANCE ANALYSIS")
    print(banner)

    print(f"{'Concurrency':<12} {'Time (s)':<10} {'Success Rate':<12} {'Throughput':<12} {'Avg/Chunk':<12}")
    print(rule)

    ordered_levels = sorted(results)
    top_level, top_rate = 0, 0
    for level in ordered_levels:
        stats = results[level]
        print(f"{level:<12} {stats['total_time']:<10.2f} {stats['success_rate']:<12.1f}% {stats['throughput']:<12.2f} {stats['avg_time_per_chunk']:<12.2f}")

        rate = stats['throughput']
        if rate > top_rate:
            top_rate, top_level = rate, level

    print(rule)
    print(f"🏆 Best performance: {top_level} concurrent requests ({top_rate:.2f} chunks/second)")

    # Relative gain between the two extreme concurrency levels.
    if len(results) > 1:
        lowest, highest = ordered_levels[0], ordered_levels[-1]
        base_rate = results[lowest]['throughput']
        peak_rate = results[highest]['throughput']
        if base_rate > 0:
            gain = ((peak_rate - base_rate) / base_rate) * 100
            print(f"📈 Improvement from {lowest} to {highest}: {gain:.1f}%")

    return top_level

# --------------------------- Prompt utilities ------------------------------

# Function to fill template with variables
def fill_template(template: str, mapping: Dict[str, str]) -> str:
    """Replace every ``{KEY}`` placeholder in ``template`` with ``mapping[KEY]``.

    Substitutions are applied one key at a time in mapping order, so a value
    inserted by an earlier key can itself be rewritten by a later key.
    """
    rendered = template
    for placeholder in mapping:
        rendered = rendered.replace("{" + placeholder + "}", mapping[placeholder])
    return rendered

# Text that asks the model for a bare JSON object with the three codebook
# levels ("low-level", "mid-level", "high-level") and nothing else.
# NOTE(review): not referenced in the visible part of this file; presumably
# appended to prompts elsewhere — confirm against callers.
STRICT_JSON_SUFFIX = (
    "\n\nIMPORTANT: Respond with JSON ONLY. Do not include explanations, prose, or code fences. "
    'Return exactly this schema: {"low-level": [...], "mid-level": [...], "high-level": [...]}'
)

# --------------------------- File ingestion --------------------------------

def iter_paths(input_path: str) -> Iterable[str]:
    """Yield the supported input files found at ``input_path``.

    A directory is walked recursively and every file whose lower-cased
    extension is in ``SUPPORTED_EXTS`` is yielded. Any other path is yielded
    as-is when its extension is supported, and nothing is yielded otherwise.
    """
    def _is_supported(name: str) -> bool:
        return os.path.splitext(name)[1].lower() in SUPPORTED_EXTS

    if not os.path.isdir(input_path):
        if _is_supported(input_path):
            yield input_path
        return

    for folder, _subdirs, filenames in os.walk(input_path):
        for filename in filenames:
            if _is_supported(filename):
                yield os.path.join(folder, filename)


def read_txt_like(path: str) -> List[str]:
    """Read a plain-text file and return its whole content as a one-item list.

    The file is decoded as UTF-8 with undecodable bytes dropped. Any failure
    is printed and reported as an empty list.
    """
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as source:
            return [source.read()]
    except Exception as err:
        print(f"⚠️ Failed to read text file {path}: {err}")
    return []


def read_csv_file(path: str) -> List[str]:
    """Read a CSV file and return one text item per non-empty data row.

    The first line is treated as the header. For each following row the
    stripped, non-empty cell values are joined with ``" \\n"``. Rows with no
    usable cells are skipped. Any failure is printed; rows collected before
    the failure are still returned.
    """
    rows: List[str] = []
    try:
        with open(path, "r", encoding="utf-8", errors="ignore", newline="") as f:
            for row in csv.DictReader(f):
                parts = []
                for value in row.values():
                    # A row with more fields than the header stores the surplus
                    # as a *list* under the None key; flatten it so each cell is
                    # added as text rather than as a Python list repr.
                    cells = value if isinstance(value, list) else [value]
                    for cell in cells:
                        if cell is None:
                            # Missing trailing fields are filled with None.
                            continue
                        cell_text = str(cell).strip()
                        if cell_text:
                            parts.append(cell_text)
                if parts:
                    rows.append(" \n".join(parts))
    except Exception as e:
        print(f"⚠️ Failed to read csv {path}: {e}")
    return rows


def read_json_file(path: str) -> List[str]:
    """Read a ``.json`` or ``.jsonl`` file and return its text items.

    JSON Lines files (extension matched case-insensitively) are parsed line
    by line; blank lines are ignored and lines that cannot be processed are
    skipped, with a single summary warning. Any other file is parsed as one
    JSON document; a top-level list contributes items per element. Text is
    extracted with ``extract_strings_from_json``. A file-level failure is
    printed and whatever was collected so far is returned.
    """
    items: List[str] = []
    try:
        # File discovery lower-cases extensions, so "DATA.JSONL" reaches this
        # function too and must be treated as JSON Lines as well.
        if path.lower().endswith(".jsonl"):
            skipped = 0
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                        items.extend(extract_strings_from_json(obj))
                    except Exception:
                        # Best effort: keep going, but remember the loss.
                        skipped += 1
            if skipped:
                print(f"⚠️ Skipped {skipped} unparseable line(s) in {path}")
        else:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                data = json.load(f)
            if isinstance(data, list):
                for obj in data:
                    items.extend(extract_strings_from_json(obj))
            else:
                items.extend(extract_strings_from_json(data))
    except Exception as e:
        print(f"⚠️ Failed to read json {path}: {e}")
    return items


def extract_strings_from_json(obj) -> List[str]:
    """Flatten a decoded JSON value into at most one text item.

    * ``None`` -> ``[]``
    * string -> the stripped string, or ``[]`` when blank
    * dict / list / tuple -> scalar members (str, int, float) are stringified,
      nested containers are flattened recursively and joined with a space,
      and all parts are joined with ``" \\n"``; ``None`` members are skipped
    * anything else -> its stripped ``str()`` form
    """
    if obj is None:
        return []

    if isinstance(obj, str):
        trimmed = obj.strip()
        return [trimmed] if trimmed else []

    if isinstance(obj, dict):
        pieces: List[str] = []
        for value in obj.values():
            if value is None:
                continue
            if isinstance(value, (str, int, float)):
                pieces.append(str(value))
            elif isinstance(value, (list, tuple, dict)):
                # Only containers are descended into; other types are ignored.
                inner = extract_strings_from_json(value)
                if inner:
                    pieces.append(" ".join(inner))
        return [" \n".join(pieces)] if pieces else []

    if isinstance(obj, (list, tuple)):
        pieces = []
        for element in obj:
            if element is None:
                continue
            if isinstance(element, (str, int, float)):
                pieces.append(str(element))
                continue
            inner = extract_strings_from_json(element)
            if inner:
                pieces.append(" ".join(inner))
        return [" \n".join(pieces)] if pieces else []

    rendered = str(obj).strip()
    return [rendered] if rendered else []

# --------------------------- Chunking --------------------------------------

def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split ``text`` into windows of whitespace-separated tokens.

    Each chunk holds up to ``chunk_size`` tokens joined by single spaces, and
    consecutive chunks start ``chunk_size - overlap`` tokens apart (at least
    one; a negative overlap counts as zero). Empty text gives ``[]`` and a
    non-positive ``chunk_size`` returns the text unsplit.
    """
    if not text:
        return []
    if chunk_size <= 0:
        return [text]

    words = text.split()
    word_count = len(words)
    stride = max(1, chunk_size - max(0, overlap))

    pieces: List[str] = []
    cursor = 0
    while cursor < word_count:
        stop = min(cursor + chunk_size, word_count)
        piece = " ".join(words[cursor:stop]).strip()
        if piece:
            pieces.append(piece)
        if stop == word_count:
            # The window reached the end; later windows would only repeat the tail.
            break
        cursor += stride
    return pieces


def build_text_chunks(input_path: str, chunk_size: int, overlap: int, seed: int) -> List[Tuple[str, int, str]]:
    """Read every supported file under ``input_path`` and return shuffled chunks.

    Each entry is ``(source_path, chunk_idx, text)``, where ``chunk_idx``
    counts chunks within one source file across all of its text items. The
    global ``random`` generator is seeded with ``seed`` before the final
    shuffle.
    """
    random.seed(seed)

    # Extension -> reader dispatch.
    readers = {
        ".txt": read_txt_like,
        ".md": read_txt_like,
        ".csv": read_csv_file,
        ".json": read_json_file,
        ".jsonl": read_json_file,
    }

    corpus: List[Tuple[str, int, str]] = []  # (source_path, chunk_idx, text)
    for source_path in iter_paths(input_path):
        reader = readers.get(os.path.splitext(source_path)[1].lower())
        if reader is None:
            continue
        position = 0
        for document in reader(source_path):
            for piece in chunk_text(document, chunk_size=chunk_size, overlap=overlap):
                corpus.append((source_path, position, piece))
                position += 1

    random.shuffle(corpus)
    return corpus

# --------------------------- Codebook helpers ------------------------------

def pick_seed_samples(corpus: List[Tuple[str, int, str]], low: int = 5, high: int = 10, seed: int = 42) -> List[Tuple[str, int, str]]:
    """Randomly pick between ``low`` and ``high`` chunks to seed the codebook.

    The global ``random`` generator is seeded with ``seed``, so the same
    corpus and seed give the same sample. The sample size never exceeds the
    corpus size: a corpus smaller than ``low`` is returned in full (in random
    order) instead of raising.

    Args:
        corpus: (source_path, chunk_idx, text) tuples to sample from.
        low: Smallest sample size to draw.
        high: Largest sample size to draw.
        seed: Seed for the random generator.

    Returns:
        The sampled chunks, or ``[]`` for an empty corpus.
    """
    if not corpus:
        return []
    random.seed(seed)
    upper = min(high, len(corpus))
    # randint(low, upper) raises ValueError when low > upper, which used to
    # happen for any corpus with fewer than ``low`` chunks.
    lower = min(low, upper)
    k = random.randint(lower, upper)
    return random.sample(corpus, k)




def parse_codebook_json(text: str) -> Optional[Dict[str, List[str]]]:
    """Extract a codebook from a model response that should contain JSON.

    The response is tried as JSON as-is first. If that fails, the outermost
    ``{...}`` span is tried as a dictionary, then the outermost ``[...]``
    span as a list. Markdown code fences are stripped along the way.

    Args:
        text: Raw response text. ``None`` or any non-string value is treated
            as "no parseable codebook" instead of raising.

    Returns:
        The parsed dictionary unchanged, or, for a JSON list, a dictionary
        with the list under ``"low-level"`` and empty ``"mid-level"`` /
        ``"high-level"`` lists. ``None`` when nothing usable is found.
    """
    if not isinstance(text, str):
        # A response may carry a null content field; there is nothing to parse.
        return None

    def _as_codebook(obj):
        # Lists are the flat format: every code is filed under low-level.
        if isinstance(obj, list):
            return {"low-level": obj, "mid-level": [], "high-level": []}
        if isinstance(obj, dict):
            return obj
        return None

    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the surrounding backticks. A language tag such as "json" may
        # remain, in which case the substring fallback below handles it.
        text = stripped.strip("`")

    # Fast path: the whole text is valid JSON.
    try:
        return _as_codebook(json.loads(text))
    except (ValueError, RecursionError):
        pass

    # Fallback: pull out the outermost object, then the outermost array.
    for opener, closer, wanted in (("{", "}", dict), ("[", "]", list)):
        start = text.find(opener)
        end = text.rfind(closer)
        if not 0 <= start < end:
            continue
        candidate = text[start:end + 1].replace("```json", "").replace("```", "").strip()
        try:
            obj = json.loads(candidate)
        except (ValueError, RecursionError):
            continue
        if isinstance(obj, wanted):
            return _as_codebook(obj)

    return None


def flatten_codebook(source_path: str, chunk_idx: int, codebook: Dict[str, List[str]], chunk_text: str = "", question_id: str = "") -> List[Tuple[str, int, str, str, str]]:
    """Flatten a multi-level codebook into one record per tag.

    Args:
        source_path: Path of the file the chunk came from.
        chunk_idx: Index of the chunk within its source.
        codebook: Mapping from level name to tags. Both hyphenated
            (``"low-level"``) and underscored (``"low_level"``) keys are
            accepted.
        chunk_text: Text of the chunk, copied onto every record.
        question_id: When non-empty, appended to ``source_path`` as
            ``"<source_path>_<question_id>"``.

    Returns:
        ``(source_path, chunk_index, level, tag, chunk_text)`` tuples ordered
        low-level, mid-level, high-level. ``None`` tags are skipped and every
        other tag is converted with ``str``.
    """
    # The label is identical for every record, so build it once.
    labelled_source = f"{source_path}_{question_id}" if question_id else source_path

    records: List[Tuple[str, int, str, str, str]] = []
    for level_key in ("low-level", "mid-level", "high-level"):
        tags = codebook.get(level_key) or codebook.get(level_key.replace("-", "_")) or []
        if isinstance(tags, str):
            # A level given as a bare string is one tag; iterating it would
            # emit one record per character.
            tags = [tags]
        records.extend(
            (labelled_source, chunk_idx, level_key, str(tag), chunk_text)
            for tag in tags
            if tag is not None
        )
    return records

# --------------------------- Defaults --------------------------------------

def find_first_valid_data_file() -> Optional[str]:
    """Return the first file under ``DATA_DIR`` with a supported extension.

    The directory tree is walked with ``os.walk`` and the first match in walk
    order wins. Returns ``None`` when ``DATA_DIR`` is not a directory or no
    file carries an extension listed in ``SUPPORTED_EXTS``.
    """
    if not os.path.isdir(DATA_DIR):
        return None

    matches = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(DATA_DIR)
        for filename in filenames
        if os.path.splitext(filename)[1].lower() in SUPPORTED_EXTS
    )
    # The generator is lazy, so the walk stops at the first hit.
    return next(matches, None)

# --------------------------- Output ----------------------------------------

def write_parquet(records: List[Tuple[str, int, str, str, str]], out_path: str) -> None:
    """Write flattened codebook records to a Parquet file.

    Args:
        records: ``(source_path, chunk_index, level, tag, chunk_text)`` tuples.
        out_path: Destination file. Its parent directory is created when the
            path has one; a bare filename is written to the current directory.
    """
    parent_dir = os.path.dirname(out_path)
    if parent_dir:
        # os.makedirs("") raises FileNotFoundError, so only create a parent
        # when the path actually names one.
        os.makedirs(parent_dir, exist_ok=True)

    # Column name and Arrow type, in tuple-position order.
    schema = (
        ("source_path", pa.string()),
        ("chunk_index", pa.int32()),
        ("level", pa.string()),
        ("tag", pa.string()),
        ("chunk_text", pa.string()),
    )
    arrays = [
        pa.array([record[position] for record in records], type=arrow_type)
        for position, (_, arrow_type) in enumerate(schema)
    ]
    table = pa.Table.from_arrays(arrays, names=[name for name, _ in schema])
    pq.write_table(table, out_path)

# --------------------------- Async processing ------------------------------

async def generate_codebook_for_chunk_base(client: AsyncVLLMClient, prompt: str, chunk_text: str, semaphore: asyncio.Semaphore, 
                                          strategy: str = "strategy_1") -> Optional[Dict[str, List[str]]]:
    """
    Dispatch codebook generation to the implementation named by *strategy*.

    Args:
        client: AsyncVLLMClient instance
        prompt: The prompt template to use
        chunk_text: The text chunk to process
        semaphore: Semaphore for concurrency control
        strategy: One of "strategy_1", "strategy_2", "strategy_3"

    Returns:
        Whatever the selected strategy function returns (a parsed codebook
        dictionary or None).

    Raises:
        ValueError: If *strategy* is not one of the known names.
    """
    routes = (
        ("strategy_1", generate_codebook_for_chunk_1),
        ("strategy_2", generate_codebook_for_chunk_2),
        ("strategy_3", generate_codebook_for_chunk_3),
    )
    for strategy_name, handler in routes:
        if strategy == strategy_name:
            return await handler(client, prompt, chunk_text, semaphore)
    raise ValueError(f"Unknown strategy: {strategy}")


async def generate_codebook_for_chunk_1(client: AsyncVLLMClient, prompt: str, chunk_text: str, semaphore: asyncio.Semaphore) -> Optional[Dict[str, List[str]]]:
    """
    Strategy 1: single chat request per chunk with one stricter retry.

    - Substitutes the chunk into the prompt's "{chunk_text}" placeholder
    - First attempt sends the prompt as-is; the second appends
      STRICT_JSON_SUFFIX
    - Temperature 0.0, max_tokens 256

    Args:
        client: AsyncVLLMClient instance used for the chat request
        prompt: Prompt text, optionally containing "{chunk_text}"
        chunk_text: The text chunk to insert into the prompt
        semaphore: Semaphore held only while a request is in flight

    Returns:
        The first non-empty codebook parsed from a response, or None when
        both attempts fail to produce one.
    """
    import sys

    # Register the module directory once; appending on every call would grow
    # sys.path by one duplicate entry per processed chunk.
    module_dir = os.path.dirname(__file__)
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    from .prompts import STRICT_JSON_SUFFIX

    filled_prompt = prompt.replace("{chunk_text}", chunk_text)
    # Use client's default model if available, otherwise fall back to DEFAULT_TEXT_MODEL
    model_to_use = getattr(client, 'default_text_model', DEFAULT_TEXT_MODEL)

    # Up to 2 attempts (second attempt adds the strict JSON instruction)
    for attempt_prompt in (filled_prompt, filled_prompt + STRICT_JSON_SUFFIX):
        messages = [{"role": "user", "content": attempt_prompt}]
        async with semaphore:
            resp = await client.chat_completion(model_to_use, messages, temperature=0.0, max_tokens=256)

        if not resp:
            continue
        try:
            content = resp["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Response does not have the expected chat-completion shape.
            continue
        parsed = parse_codebook_json(content)
        if parsed:
            return parsed
    return None


async def generate_codebook_for_chunk_2(client: AsyncVLLMClient, prompt: str, chunk_text: str, semaphore: asyncio.Semaphore) -> Optional[Dict[str, List[str]]]:
    """
    Strategy 2 placeholder for the per-chunk interface.

    The hierarchical strategy (low-level codes per chunk, then clustered into
    medium-level and high-level codes) needs every chunk at once, so it
    cannot be served one chunk at a time. This function therefore performs
    no work and always returns None; callers are expected to treat that as
    "handle strategy 2 separately".
    """
    # Intentionally a no-op: clustering requires access to all chunks.
    return None


def _strategy2_codes_at_level(result: Optional[Dict[str, List[str]]], level: str) -> List[str]:
    """Return the non-null codes a parsed codebook stores under *level* (hyphen or underscore key)."""
    if not isinstance(result, dict):
        return []
    value = result.get(level) or result.get(level.replace("-", "_"))
    if not value:
        return []
    if isinstance(value, str):
        # A bare string is one code, not a sequence of characters.
        return [value]
    return [code for code in value if code is not None]


def _strategy2_extend_unique(target: List[str], items: Iterable[str]) -> None:
    """Append each item to *target* unless an equal item is already present (order-preserving)."""
    for item in items:
        if item not in target:
            target.append(item)


async def generate_codebook_hierarchical_strategy_2(client: AsyncVLLMClient, chunks: List[str], question: str, semaphore: asyncio.Semaphore) -> List[Dict[str, List[str]]]:
    """
    Strategy 2: Hierarchical code generation with relationship tracking.

    Process:
    1. Generate low-level codes for each chunk and remember which chunk
       produced which codes
    2. Cluster all low-level codes, K = number_of_low_level_codes / 3
    3. Generate medium-level codes per cluster, remembering the cluster each
       one summarises
    4. Cluster all medium-level codes, K = number_of_medium_level_codes / 5
    5. Generate high-level codes per cluster, remembering the cluster each
       one summarises
    6. For every chunk, follow low code -> low cluster -> medium codes ->
       medium cluster -> high codes

    Args:
        client: AsyncVLLMClient instance used for all chat requests
        chunks: Text chunks to code, in output order
        question: Research question inserted into every prompt
        semaphore: Semaphore for concurrency control

    Returns:
        One codebook per chunk (keys "low-level", "mid-level", "high-level",
        at most 5 codes each), in the same order as *chunks*. A chunk with
        no codes at some level falls back to the first 5 codes generated
        overall for that level. Returns an empty list when no low-level or
        no medium-level codes could be generated.
    """
    import sys

    # Register the module directory once instead of on every call.
    module_dir = os.path.dirname(__file__)
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    from .embeddings import build_embeddings_parquet
    from .prompts import (
        fill_template, STRATEGY_2_LOW_LEVEL_PROMPT,
        STRATEGY_2_MEDIUM_LEVEL_PROMPT, STRATEGY_2_HIGH_LEVEL_PROMPT
    )
    from sklearn.cluster import KMeans

    def cluster_codes(codes: List[str], level: str, k: int, filename: str) -> List[int]:
        """Embed *codes* and return one KMeans cluster label per code."""
        frame = pd.DataFrame({
            'tag': codes,
            'level': [level] * len(codes),
            'chunk_text': [''] * len(codes)  # Empty for embedding purposes
        })
        # NOTE(review): assumes the returned embeddings have one row per
        # input code, in input order -- confirm against build_embeddings_parquet.
        _, embeddings, _, _ = build_embeddings_parquet(
            corpus_df=frame,
            output_parquet=os.path.join(TEMP_FILES_DIR, filename)
        )
        labels = KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(embeddings)
        return [int(label) for label in labels]

    def group_by_label(codes: List[str], labels: List[int]) -> Dict[int, List[str]]:
        """Group codes by cluster label, keeping first-appearance order."""
        groups: Dict[int, List[str]] = {}
        for code, label in zip(codes, labels):
            groups.setdefault(label, []).append(code)
        return groups

    async def codes_per_cluster(clusters: Dict[int, List[str]], template: str, level: str) -> Dict[int, List[str]]:
        """Ask the model to summarise each cluster; map cluster label -> generated codes."""
        labels = list(clusters)
        tasks = []
        for label in labels:
            cluster_text = "\n".join([f"- {code}" for code in clusters[label]])
            cluster_prompt = fill_template(template, {"question": question, "cluster_codes": cluster_text})
            tasks.append(generate_codebook_for_chunk_1(client, cluster_prompt, "", semaphore))
        results = await asyncio.gather(*tasks)
        # Keyed by label so a failed request leaves that cluster empty
        # instead of shifting every later cluster onto the wrong codes.
        return {
            label: _strategy2_codes_at_level(result, level)
            for label, result in zip(labels, results)
        }

    print("🔄 Strategy 2: Hierarchical code generation...")

    # Step 1: Generate low-level codes for each chunk
    print("   📝 Step 1: Generating low-level codes...")

    low_level_tasks = []
    for chunk in chunks:
        low_level_prompt = fill_template(STRATEGY_2_LOW_LEVEL_PROMPT, {"question": question, "chunk_text": chunk})
        low_level_tasks.append(generate_codebook_for_chunk_1(client, low_level_prompt, chunk, semaphore))

    low_level_results = await asyncio.gather(*low_level_tasks)

    # Keep the per-chunk lists: the model may return more or fewer than 5
    # codes (or none), so fixed-size slicing of the flat list would misalign.
    low_codes_per_chunk = [_strategy2_codes_at_level(result, "low-level") for result in low_level_results]
    all_low_level_codes: List[str] = []
    for chunk_codes in low_codes_per_chunk:
        all_low_level_codes.extend(chunk_codes)

    print(f"   ✅ Generated {len(all_low_level_codes)} low-level codes from {len(chunks)} chunks")

    if len(all_low_level_codes) == 0:
        print("   ❌ No low-level codes generated")
        return []

    # Step 2: Cluster low-level codes and generate medium-level codes
    print("   🎯 Step 2: Clustering low-level codes and generating medium-level codes...")

    k_low = max(1, len(all_low_level_codes) // 3)
    print(f"   📊 Clustering {len(all_low_level_codes)} low-level codes into {k_low} clusters")

    low_labels = cluster_codes(all_low_level_codes, 'low-level', k_low, "low_level_embeddings.parquet")
    low_level_clusters = group_by_label(all_low_level_codes, low_labels)
    medium_by_low_cluster = await codes_per_cluster(low_level_clusters, STRATEGY_2_MEDIUM_LEVEL_PROMPT, "mid-level")

    # Flatten medium-level codes while recording, per low-level cluster,
    # which positions of the flat list it produced.
    all_medium_level_codes: List[str] = []
    medium_indices_by_low_cluster: Dict[int, List[int]] = {}
    for label, cluster_medium_codes in medium_by_low_cluster.items():
        first_index = len(all_medium_level_codes)
        all_medium_level_codes.extend(cluster_medium_codes)
        medium_indices_by_low_cluster[label] = list(range(first_index, len(all_medium_level_codes)))

    print(f"   ✅ Generated {len(all_medium_level_codes)} medium-level codes from {k_low} clusters")

    if len(all_medium_level_codes) == 0:
        print("   ❌ No medium-level codes generated")
        return []

    # Step 3: Cluster medium-level codes and generate high-level codes
    print("   🎯 Step 3: Clustering medium-level codes and generating high-level codes...")

    k_medium = max(1, len(all_medium_level_codes) // 5)
    print(f"   📊 Clustering {len(all_medium_level_codes)} medium-level codes into {k_medium} clusters")

    # Separate variable from low_labels: both label sets are needed in step 4.
    medium_labels = cluster_codes(all_medium_level_codes, 'mid-level', k_medium, "medium_level_embeddings.parquet")
    medium_level_clusters = group_by_label(all_medium_level_codes, medium_labels)
    high_by_medium_cluster = await codes_per_cluster(medium_level_clusters, STRATEGY_2_HIGH_LEVEL_PROMPT, "high-level")

    all_high_level_codes: List[str] = []
    for cluster_high_codes in high_by_medium_cluster.values():
        all_high_level_codes.extend(cluster_high_codes)

    print(f"   ✅ Generated {len(all_high_level_codes)} high-level codes from {k_medium} clusters")

    # Step 4: Track relationships and distribute codes back to chunks
    print("   📊 Step 4: Tracking relationships and distributing codes back to chunks...")

    final_codebooks = []
    low_offset = 0  # Position of the current chunk's first code in all_low_level_codes
    for chunk_idx in tqdm(range(len(chunks)), total=len(chunks), desc="Generating candidate codes"):
        low_codes = low_codes_per_chunk[chunk_idx]

        # Medium-level codes: those generated for the clusters this chunk's
        # low-level codes fell into.
        medium_indices: List[int] = []
        for position in range(len(low_codes)):
            global_low_idx = low_offset + position
            if global_low_idx < len(low_labels):
                _strategy2_extend_unique(
                    medium_indices,
                    medium_indices_by_low_cluster.get(low_labels[global_low_idx], [])
                )
        low_offset += len(low_codes)

        medium_codes: List[str] = []
        _strategy2_extend_unique(medium_codes, (all_medium_level_codes[idx] for idx in medium_indices))

        # High-level codes: those generated for the clusters the related
        # medium-level codes fell into (looked up by global medium index).
        high_codes: List[str] = []
        for medium_idx in medium_indices:
            if medium_idx < len(medium_labels):
                _strategy2_extend_unique(
                    high_codes,
                    high_by_medium_cluster.get(medium_labels[medium_idx], [])
                )

        # Ensure we have at least some codes at each level
        if not low_codes:
            low_codes = all_low_level_codes[:5]
        if not medium_codes:
            medium_codes = all_medium_level_codes[:5]
        if not high_codes:
            high_codes = all_high_level_codes[:5]

        # Limit to 5 codes per level for consistency
        final_codebooks.append({
            "low-level": low_codes[:5],
            "mid-level": medium_codes[:5],
            "high-level": high_codes[:5]
        })

    print(f"   ✅ Strategy 2 completed: {len(final_codebooks)} codebooks generated")
    print(f"      - Low-level codes: {len(all_low_level_codes)}")
    print(f"      - Medium-level codes: {len(all_medium_level_codes)}")
    print(f"      - High-level codes: {len(all_high_level_codes)}")

    return final_codebooks


async def generate_codebook_for_chunk_3(client: AsyncVLLMClient, prompt: str, chunk_text: str, semaphore: asyncio.Semaphore) -> Optional[Dict[str, List[str]]]:
    """
    Strategy 3: reserved slot for a future generation approach.

    No dedicated logic exists yet, so every call is forwarded unchanged to
    generate_codebook_for_chunk_1 and its result is returned as-is.
    """
    # TODO: Implement specific strategy 3 logic
    codebook = await generate_codebook_for_chunk_1(client, prompt, chunk_text, semaphore)
    return codebook

async def generate_codebook_two_stage(client: AsyncVLLMClient, question: str, chunk_text: str, 
                                    semaphore: asyncio.Semaphore) -> Optional[Dict[str, List[str]]]:
    """
    Two-stage codebook generation: draft tags, then ask the model to refine them.

    Stage 1: Generate 30 tags across three levels (temperature 0.1)
    Stage 2: Send the stage 1 output back for refinement (temperature 0.0)

    The semaphore is held across both requests.

    Args:
        client: AsyncVLLMClient instance used for both chat requests
        question: Research question placed at the top of the stage 1 prompt
        chunk_text: Text chunk to tag
        semaphore: Semaphore for concurrency control

    Returns:
        The refined codebook when stage 2 yields a parseable one; otherwise
        the stage 1 codebook. None when stage 1 produces no usable codebook.
    """
    # Stage 1: Quick generation
    stage1_prompt = f"""Question: {question}

Generate 30 tags for this text chunk quickly:
- 10 low-level (specific details)
- 10 mid-level (themes)  
- 10 high-level (patterns)

Text: {chunk_text}

JSON: {{"low-level": [...], "mid-level": [...], "high-level": [...]}}"""

    async with semaphore:
        model_to_use = getattr(client, 'default_text_model', DEFAULT_TEXT_MODEL)

        stage1_messages = [{"role": "user", "content": stage1_prompt}]
        stage1_resp = await client.chat_completion(model_to_use, stage1_messages, temperature=0.1, max_tokens=512)

        if not stage1_resp:
            return None

        try:
            stage1_content = stage1_resp["choices"][0]["message"]["content"]
            stage1_parsed = parse_codebook_json(stage1_content)
        except Exception as e:
            print(f"      ⚠️ Two-stage processing failed: {e}")
            return None

        if not stage1_parsed:
            return None

        # Stage 2: Quality refinement. This step is best-effort: a failure
        # here must not throw away the valid stage 1 result.
        stage2_prompt = f"""Refine these tags for better quality:

Original: {stage1_content}

Improve: descriptive, ≤5 words, no redundancy, proper abstraction levels.

JSON: {{"low-level": [...], "mid-level": [...], "high-level": [...]}}"""

        try:
            stage2_messages = [{"role": "user", "content": stage2_prompt}]
            stage2_resp = await client.chat_completion(model_to_use, stage2_messages, temperature=0.0, max_tokens=512)
            if not stage2_resp:
                return stage1_parsed
            stage2_content = stage2_resp["choices"][0]["message"]["content"]
            stage2_parsed = parse_codebook_json(stage2_content)
        except Exception as e:
            print(f"      ⚠️ Two-stage refinement failed, keeping stage 1 tags: {e}")
            return stage1_parsed

        return stage2_parsed if stage2_parsed else stage1_parsed

async def generate_codebook_adaptive(client: AsyncVLLMClient, question: str, chunk_text: str, 
                                   semaphore: asyncio.Semaphore, quality_mode: str = "balanced") -> Optional[Dict[str, List[str]]]:
    """
    Adaptive codebook generation that trades prompt detail against speed.

    Args:
        client: AsyncVLLMClient instance used for the chat request
        question: Research question inserted into the prompt
        chunk_text: Text chunk to tag
        semaphore: Semaphore held while the request is in flight
        quality_mode: "fast" (short prompt, first 500 characters of the
            chunk, temperature 0.2, 256 tokens), "high_quality"
            (STRATEGY_1_DIRECT_PROMPT with the chunk inserted, temperature
            0.0, 1024 tokens); any other value is treated as "balanced"
            (STRATEGY_1_DIRECT_PROMPT, temperature 0.1, 512 tokens)

    Returns:
        The parsed codebook, or None when the request returns nothing or the
        response cannot be read.
    """
    if quality_mode == "fast":
        # Fast mode: minimal prompt, quick generation
        prompt = f"""Q: {question}
Generate 30 tags for: {chunk_text[:500]}...
{{"low-level": [...], "mid-level": [...], "high-level": [...]}}"""
        temperature, max_tokens = 0.2, 256

    elif quality_mode == "high_quality":
        # High quality: full detailed prompt
        from .prompts import fill_template, STRATEGY_1_DIRECT_PROMPT
        prompt = fill_template(STRATEGY_1_DIRECT_PROMPT, {"QUESTION": question})
        prompt = prompt.replace("Read the supplied document in full.", "Read the supplied document in full.\n\nText chunk:\n{chunk_text}")
        # Plain substitution instead of str.format: the filled prompt can
        # contain literal braces (JSON examples, the question itself), which
        # str.format would try to interpret as fields and raise on.
        prompt = prompt.replace("{chunk_text}", chunk_text)
        temperature, max_tokens = 0.0, 1024

    else:  # balanced
        # Balanced: optimized prompt
        from .prompts import fill_template, STRATEGY_1_DIRECT_PROMPT
        prompt = fill_template(STRATEGY_1_DIRECT_PROMPT, {"QUESTION": question, "chunk_text": chunk_text})
        temperature, max_tokens = 0.1, 512

    async with semaphore:
        model_to_use = getattr(client, 'default_text_model', DEFAULT_TEXT_MODEL)
        messages = [{"role": "user", "content": prompt}]
        resp = await client.chat_completion(model_to_use, messages, temperature=temperature, max_tokens=max_tokens)

        if not resp:
            return None

        try:
            content = resp["choices"][0]["message"]["content"]
            return parse_codebook_json(content)
        except Exception as e:
            print(f"      ⚠️ Adaptive processing failed: {e}")
            return None

async def generate_codebook_batch_optimized(client: AsyncVLLMClient, prompt_prefix: str, chunk_texts: List[str], 
                                          semaphore: asyncio.Semaphore, batch_size: int = 10) -> List[Optional[Dict[str, List[str]]]]:
    """
    Generate codebooks for many chunks, one batch of concurrent requests at a time.

    Each prompt is prompt_prefix + chunk text + STRICT_JSON_SUFFIX. All
    requests of a batch are awaited together; the next batch starts only
    after the previous one has finished. Timing and success counts are
    printed per batch.

    Args:
        client: AsyncVLLMClient instance
        prompt_prefix: Prompt text shared by every request
        chunk_texts: Chunk texts to process
        semaphore: Semaphore passed to each request for concurrency control
        batch_size: Number of chunks submitted together

    Returns:
        One entry per chunk, in input order: the parsed codebook, or None
        for a request that failed or whose response could not be read.
    """
    from .prompts import STRICT_JSON_SUFFIX

    collected: List[Optional[Dict[str, List[str]]]] = []

    for batch_no, offset in enumerate(range(0, len(chunk_texts), batch_size), start=1):
        current = chunk_texts[offset:offset + batch_size]
        started_at = time.time()

        print(f"   📦 Processing batch {batch_no}/{(len(chunk_texts) + batch_size - 1)//batch_size} ({len(current)} chunks)")

        # One coroutine per chunk; the semaphore inside each call bounds how
        # many are in flight at once.
        pending = [
            generate_single_codebook_with_semaphore(
                client,
                [{"role": "user", "content": prompt_prefix + text + STRICT_JSON_SUFFIX}],
                semaphore,
            )
            for text in current
        ]

        # return_exceptions=True keeps one failed request from cancelling
        # the rest of the batch.
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        batch_parsed: List[Optional[Dict[str, List[str]]]] = []
        for request_no, outcome in enumerate(outcomes):
            codebook = None
            if isinstance(outcome, Exception):
                print(f"      ⚠️ Request {request_no} failed: {outcome}")
            elif outcome and isinstance(outcome, dict):
                try:
                    content = outcome["choices"][0]["message"]["content"]
                    codebook = parse_codebook_json(content)
                except Exception as e:
                    print(f"      ⚠️ Parse failed for request {request_no}: {e}")
                    codebook = None
            else:
                print(f"      ⚠️ Request {request_no} returned None or invalid response: {type(outcome)}")
            batch_parsed.append(codebook)

        collected.extend(batch_parsed)

        elapsed = time.time() - started_at
        succeeded = sum(1 for item in batch_parsed if item is not None)
        rate = succeeded / elapsed if elapsed > 0 else 0

        print(f"      ✅ Batch completed in {elapsed:.2f}s: {succeeded}/{len(current)} successful")
        print(f"      ⚡ Throughput: {rate:.2f} chunks/second")

    return collected

async def generate_single_codebook_with_semaphore(client: AsyncVLLMClient, messages: List[Dict[str, str]], semaphore: asyncio.Semaphore) -> Optional[Dict]:
    """Run generate_single_codebook while holding one slot of *semaphore*."""
    # Explicit acquire/release; the slot is returned even if the request raises.
    await semaphore.acquire()
    try:
        response = await generate_single_codebook(client, messages)
    finally:
        semaphore.release()
    return response

async def generate_single_codebook(client: AsyncVLLMClient, messages: List[Dict[str, str]]) -> Optional[Dict]:
    """Send one chat request, retrying up to 3 times, and return the raw response.

    Args:
        client: AsyncVLLMClient instance; its ``default_text_model`` attribute
            is used when present, otherwise DEFAULT_TEXT_MODEL.
        messages: Chat messages to send (temperature 0.0, max_tokens 1024).

    Returns:
        The first truthy response returned by the client.

    Raises:
        RuntimeError: When all 3 attempts raise or return an empty response.
            The last underlying exception, if any, is attached as the cause.
    """
    import asyncio
    model_to_use = getattr(client, 'default_text_model', DEFAULT_TEXT_MODEL)

    max_attempts = 3
    last_error = None
    for attempt in range(max_attempts):
        try:
            resp = await client.chat_completion(model_to_use, messages, temperature=0.0, max_tokens=1024)
        except Exception as e:
            last_error = e
            if attempt < max_attempts - 1:  # Only log and wait when a retry follows
                print(f"      ⚠️ Attempt {attempt + 1} failed: {e}")

                error_msg = str(e).lower()
                if any(term in error_msg for term in ['disconnect', 'connection', 'timeout', 'server']):
                    # Growing wait for connection issues: 2s after the first
                    # failure, 3s after the second.
                    wait_time = (2 ** attempt) + 1
                    print(f"      ⏱️ Connection issue detected, waiting {wait_time}s before retry...")
                    await asyncio.sleep(wait_time)
                else:
                    # Short wait for other errors
                    await asyncio.sleep(0.5)
            continue
        if resp:
            return resp
        # An empty response is retried immediately, without a wait.

    print("      ❌ All 3 attempts failed for chunk")
    # Chain the last exception so the root cause shows up in the traceback.
    raise RuntimeError("All 3 attempts failed for chunk") from last_error



# --------------------------- Main ------------------------------------------

def _resolve_vllm_endpoint(model: str) -> Tuple[str, Optional[str]]:
    """Look up the vLLM base URL and model name for *model* in the environment.

    Returns:
        ``(url, model_name)``; ``model_name`` is ``None`` when its variable is unset.

    Raises:
        ValueError: If *model* is not a known key, or its URL variable is unset/empty.
    """
    env_names = {
        "32B": ("VLLM_QWEN_32B_URL", "VLLM_QWEN_32B_MODEL"),
        "30B-A3B": ("VLLM_QWEN_A3B_URL", "VLLM_QWEN_A3B_MODEL"),
    }
    if model not in env_names:
        raise ValueError(f"Unknown model: {model}")
    url_var, model_var = env_names[model]
    url = os.getenv(url_var)
    if not url:
        # Name the variable that is actually read, so the hint is actionable.
        raise ValueError(f"VLLM URL not set for model {model}. Set {url_var}.")
    return url, os.getenv(model_var)


def _chunk_prompt(direct_prompt: str, chunk_text: str) -> str:
    """Append one chunk's text to the shared codebook-generation prompt."""
    return direct_prompt + "\n\nText chunk:\n" + chunk_text


async def _chat_with_limit(client, semaphore, model_name, messages, temperature):
    """Call ``client.chat_completion`` while holding one slot of *semaphore*."""
    async with semaphore:
        return await client.chat_completion(model_name, messages, temperature=temperature, max_tokens=1024)


async def _explicit_json_retry(client, semaphore, model_name, chunk_prompt, failed_idx) -> Optional[Dict]:
    """Last-resort request that adds a JSON-only instruction; returns the parsed codebook or None."""
    print(f"            🔄 Retrying chunk {failed_idx} with explicit JSON instruction...")
    explicit_json_prompt = chunk_prompt + "\n\nCRITICAL: Respond with ONLY valid JSON. NO thinking, NO reasoning, NO explanations. ONLY the JSON object."
    explicit_messages = [{"role": "user", "content": explicit_json_prompt}]
    try:
        explicit_result = await _chat_with_limit(client, semaphore, model_name, explicit_messages, 0.0)
        if not (explicit_result and isinstance(explicit_result, dict)):
            return None
        explicit_content = explicit_result["choices"][0]["message"]["content"]
        explicit_parsed = parse_codebook_json(explicit_content)
    except Exception as e:
        # Best-effort path: a failed request or malformed response only means
        # this chunk stays failed.
        print(f"            ❌ Chunk {failed_idx} explicit JSON retry failed: {e}")
        return None
    if explicit_parsed is None:
        print(f"            ❌ Chunk {failed_idx} explicit JSON retry also failed")
    return explicit_parsed


async def _retry_failed_chunks(client, semaphore, direct_prompt, batch_chunks, batch_codebooks, failed_indices, max_chunk_retries=2) -> List[int]:
    """Re-request the failed chunks of a batch, filling *batch_codebooks* in place.

    Each round waits (3s, then 5s), re-sends every still-failed chunk at
    ``temperature=0.2`` and, when a response parses to None, makes one further
    request with an explicit JSON-only instruction.

    Returns:
        The batch-local indices that are still without a codebook.
    """
    retry_model = getattr(client, 'default_text_model', DEFAULT_TEXT_MODEL)
    retry_success_count = 0

    for retry_attempt in range(max_chunk_retries):
        if not failed_indices:
            break

        print(f"         🔄 Chunk retry attempt {retry_attempt + 1}/{max_chunk_retries} for {len(failed_indices)} chunks...")
        current_failed, failed_indices = failed_indices, []

        # Wait before retry to avoid rate limiting
        await asyncio.sleep(3 + (retry_attempt * 2))  # 3s, 5s

        # One prompt per failed chunk, kept alongside its index so that every
        # follow-up request uses the text of the chunk it belongs to.
        prompts = []
        for failed_idx in current_failed:
            (_, _, ctext) = batch_chunks[failed_idx]
            prompts.append(_chunk_prompt(direct_prompt, ctext))

        retry_results = await asyncio.gather(
            *(
                _chat_with_limit(client, semaphore, retry_model, [{"role": "user", "content": prompt}], 0.2)
                for prompt in prompts
            ),
            return_exceptions=True,
        )

        for failed_idx, prompt, result in zip(current_failed, prompts, retry_results):
            if isinstance(result, Exception):
                print(f"            ❌ Chunk {failed_idx} retry failed: {result}")
                failed_indices.append(failed_idx)
                continue
            if not (result and isinstance(result, dict)):
                print(f"            ❌ Chunk {failed_idx} retry failed: no response")
                failed_indices.append(failed_idx)
                continue

            try:
                content = result["choices"][0]["message"]["content"]
                parsed = parse_codebook_json(content)
                if parsed is None:
                    print(f"            ❌ Chunk {failed_idx} retry parse returned None")
                    print(f"            🔍 LLM response preview: {content[:200]}...")
            except Exception as e:
                print(f"            ❌ Chunk {failed_idx} retry parse failed: {e}")
                failed_indices.append(failed_idx)
                continue

            success_label = "retry"
            if parsed is None:
                parsed = await _explicit_json_retry(client, semaphore, retry_model, prompt, failed_idx)
                success_label = "explicit JSON retry"
            if parsed is None:
                failed_indices.append(failed_idx)
                continue

            batch_codebooks[failed_idx] = parsed
            retry_success_count += 1
            print(f"            ✅ Chunk {failed_idx} {success_label} successful (total retry successes: {retry_success_count})")

        if failed_indices:
            print(f"         ⚠️ {len(failed_indices)} chunks still failed after retry attempt {retry_attempt + 1}")
        else:
            print(f"         🎉 All chunks succeeded after retry attempt {retry_attempt + 1}")
            break

    return failed_indices


async def _generate_batch_codebooks(client, semaphore, direct_prompt, batch_chunks) -> Tuple[List[Optional[Dict]], List[int]]:
    """Generate one codebook per chunk of a batch, retrying the failures.

    Returns:
        ``(batch_codebooks, failed_indices)``: the codebooks are aligned with
        *batch_chunks* (None where generation failed) and ``failed_indices``
        lists the batch-local positions that are still None.
    """
    batch_tasks = [
        generate_single_codebook_with_semaphore(
            client,
            [{"role": "user", "content": _chunk_prompt(direct_prompt, ctext)}],
            semaphore,
        )
        for (_, _, ctext) in batch_chunks
    ]
    batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

    batch_codebooks: List[Optional[Dict]] = []
    for j, result in enumerate(batch_results):
        parsed = None
        if isinstance(result, Exception):
            print(f"      ⚠️ Request {j} failed: {result}")
        elif result and isinstance(result, dict):
            try:
                content = result["choices"][0]["message"]["content"]
                parsed = parse_codebook_json(content)
            except Exception as e:
                print(f"      ⚠️ Parse failed for request {j}: {e}")
            else:
                if parsed is None:
                    print(f"      ⚠️ Request {j} returned no parsable codebook")
        else:
            print(f"      ⚠️ Request {j} returned None (failed after retries)")
        batch_codebooks.append(parsed)

    # Every None entry is a failure, whether the request raised, the response
    # was malformed, or the parser returned None; all of them get retried.
    failed_indices = [idx for idx, cb in enumerate(batch_codebooks) if cb is None]
    print(f"      📊 Failed indices: {failed_indices}")  # Debug output
    if failed_indices:
        print(f"      🔄 Retrying {len(failed_indices)} failed chunks...")
        failed_indices = await _retry_failed_chunks(
            client, semaphore, direct_prompt, batch_chunks, batch_codebooks, failed_indices
        )
        if failed_indices:
            print(f"         ⚠️ {len(failed_indices)} chunks failed after all retry attempts")
        else:
            print(f"         ✅ All retries completed successfully")

    return batch_codebooks, failed_indices


async def async_main(args) -> pd.DataFrame:
    """Build the codebook corpus for ``args.question`` and write it as Parquet.

    Steps: resolve the input file and the vLLM endpoint, optionally benchmark
    chunk sizes and concurrency levels, chunk the input, optionally filter the
    chunks by similarity to the question, generate one codebook per chunk
    (hierarchically for ``strategy_2``, otherwise in batches of 100 with
    retries), flatten the codebooks into records and write them to
    ``args.out_corpus``.

    Args:
        args: Parsed CLI namespace. Reads ``input``, ``question``, ``model``,
            ``chunk_size``, ``overlap``, ``seed``, ``similarity_threshold``,
            ``test_chunk_sizes``, ``chunk_sizes``, ``overlap_ratio``,
            ``test_concurrency``, ``max_test_concurrency``, ``strategy`` and
            ``out_corpus``.

    Returns:
        DataFrame of the written records with columns ``source_path``,
        ``chunk_index``, ``level``, ``tag`` and ``chunk_text``.

    Raises:
        FileNotFoundError: If no input is given and none is found.
        ValueError: If the model is unknown or its URL variable is unset.
        RuntimeError: If no chunks are produced or no codebook is generated.
    """
    start_time = time.time()

    # Prompt templates are imported lazily from the sibling module.
    import sys
    sys.path.append(os.path.dirname(__file__))
    from .prompts import fill_template, STRATEGY_1_DIRECT_PROMPT

    # Resolve input
    input_path = args.input or find_first_valid_data_file()
    if not input_path:
        raise FileNotFoundError(f"No valid data files found under {DATA_DIR}")

    model = getattr(args, 'model', '32B')
    print(f"📥 Input: {os.path.basename(input_path)}")
    print(f"❓ Question: {args.question}")
    print(f"🤖 Model: {model}")

    vllm_text_url, vllm_text_model = _resolve_vllm_endpoint(model)

    # Create client with model-specific configuration
    async with AsyncVLLMClient(vllm_text_url, timeout=REQUEST_TIMEOUT) as client:
        # Override the default model for this client
        client.default_text_model = vllm_text_model

        # If testing chunk sizes, run the test first and adopt its winner
        if args.test_chunk_sizes:
            print(f"\n🚀 Starting chunk size testing (sizes: {args.chunk_sizes})")
            chunk_size_results = await test_chunk_sizes(client, input_path, args.question, args.chunk_sizes, args.overlap_ratio)
            chunk_size = print_chunk_size_analysis(chunk_size_results)
            overlap = int(chunk_size * args.overlap_ratio)
            print(f"\n🎯 Using optimal chunk size: {chunk_size} (overlap: {overlap})")
        else:
            chunk_size = args.chunk_size
            overlap = args.overlap

        # Build chunks with the selected size
        chunks = build_text_chunks(input_path, chunk_size, overlap, args.seed)
        print(f"✅ Built {len(chunks)} chunks (size: {chunk_size}, overlap: {overlap})")

        # Filter chunks by similarity with question
        if hasattr(args, 'similarity_threshold') and args.similarity_threshold > 0:
            chunks = await filter_chunks_by_similarity(
                client, args.question, chunks, args.similarity_threshold
            )

        if not chunks:
            raise RuntimeError("No chunks produced from input")

        # Shared prompt; each chunk's text is appended per request.
        direct_prompt = fill_template(STRATEGY_1_DIRECT_PROMPT, {"QUESTION": args.question})

        # If testing concurrency, run the test first and adopt its winner
        if args.test_concurrency:
            print(f"\n🚀 Starting concurrency testing (max: {args.max_test_concurrency})")
            # The concurrency test receives a template with a {chunk_text} placeholder.
            concurrency_test_prompt = direct_prompt.replace("Read the supplied document in full.", "Read the supplied document in full.\n\nText chunk:\n{chunk_text}")
            concurrency_results = await test_concurrency_levels(client, chunks, concurrency_test_prompt, args.max_test_concurrency)
            concurrency = print_concurrency_analysis(concurrency_results)
            print(f"\n🎯 Using optimal concurrency: {concurrency}")
        else:
            # Use moderate concurrency to avoid server disconnections
            concurrency = 128
        semaphore = asyncio.Semaphore(concurrency)

        print(f"🧪 Generating codebooks for all {len(chunks)} chunks with concurrency={concurrency}")
        all_start = time.time()

        if args.strategy == "strategy_2":
            print("🔄 Strategy 2: Using hierarchical code generation approach...")

            chunk_texts = [ctext for (_, _, ctext) in chunks]
            all_codebooks = await generate_codebook_hierarchical_strategy_2(
                client, chunk_texts, args.question, semaphore
            )

            # Keep codebooks aligned one-to-one with chunks: pad with empty
            # codebooks when short, truncate when long.
            if len(all_codebooks) != len(chunks):
                print(f"   ⚠️ Warning: Generated {len(all_codebooks)} codebooks for {len(chunks)} chunks")
                missing = len(chunks) - len(all_codebooks)
                all_codebooks.extend(
                    {"low-level": [], "mid-level": [], "high-level": []} for _ in range(missing)
                )
                all_codebooks = all_codebooks[:len(chunks)]
        else:
            print(f"🚀 Processing {len(chunks)} chunks with optimized batching and KV cache...")

            all_codebooks = []
            batch_size = 100
            total_batches = (len(chunks) + batch_size - 1) // batch_size

            for batch_number, batch_offset in enumerate(range(0, len(chunks), batch_size), start=1):
                batch_chunks = chunks[batch_offset:batch_offset + batch_size]
                batch_start = time.time()

                print(f"   📦 Processing batch {batch_number}/{total_batches} ({len(batch_chunks)} chunks)")

                batch_codebooks, still_failed = await _generate_batch_codebooks(
                    client, semaphore, direct_prompt, batch_chunks
                )

                succeeded_positions = [i for i, cb in enumerate(batch_codebooks) if cb is not None]
                missing_positions = [i for i, cb in enumerate(batch_codebooks) if cb is None]
                successful = len(succeeded_positions)
                failed = len(missing_positions)

                print(f"      🔍 FINAL VERIFICATION: {successful}/{len(batch_chunks)} successful after all retries")
                print(f"      🔍 MISSING CHUNKS: {missing_positions}")
                print(f"      🔍 SUCCESSFUL CHUNKS: {succeeded_positions}")

                successful_codes = sum(1 for cb in batch_codebooks if cb is not None and isinstance(cb, dict))
                print(f"      📊 VERIFICATION: {successful_codes} successful codebooks ready to add to overall results")

                all_codebooks.extend(batch_codebooks)

                # Report timing
                batch_time = time.time() - batch_start
                throughput = successful / batch_time if batch_time > 0 else 0

                print(f"      ✅ Batch completed in {batch_time:.2f}s: {successful}/{len(batch_chunks)} successful ({failed} failed)")
                print(f"      🔍 Batch codebooks status: {succeeded_positions} successful, {missing_positions} failed")

                # Consistency check: every chunk not reported as still failed
                # must have a codebook.
                if successful == len(batch_chunks) - len(still_failed):
                    print(f"      ✅ VERIFIED: Retry successes are being counted correctly")
                else:
                    print(f"      ⚠️ WARNING: Retry successes may not be counted correctly")
                print(f"      ⚡ Throughput: {throughput:.2f} chunks/second")

        all_time = time.time() - all_start
        # Failed chunks stay in all_codebooks as None, so count only real codebooks.
        total_successful = sum(1 for cb in all_codebooks if cb is not None)
        print(f"⏱️ All generation took {all_time:.2f}s ({total_successful}/{len(chunks)} successful)")
        print(f"🎯 Using strategy: {args.strategy}")
        print(f"🔍 FINAL TOTAL VERIFICATION: {total_successful}/{len(chunks)} total successful chunks in final results")
        print(f"📊 FINAL CODEBOOKS: {len(all_codebooks)} total codebooks (including retry successes)")

        if total_successful == 0:
            raise RuntimeError("Failed to generate any codebooks")

    # Flatten the codebook of every successful chunk into records
    all_records: List[Tuple[str, int, str, str, str]] = []
    for (spath, sidx, ctext), cb in zip(chunks, all_codebooks):
        if cb is not None:
            all_records.extend(flatten_codebook(spath, sidx, cb, ctext))

    # Evaluate corpus quality
    corpus_quality = evaluate_corpus_quality(all_records)
    print(f"\n📊 Corpus Quality Analysis:")
    print(f"  🎯 Overall Quality Score: {corpus_quality['overall_quality']:.3f}")
    print(f"  📝 Total Tags: {corpus_quality['total_tags']}")
    print(f"  🎨 Unique Tags: {corpus_quality['unique_tags']}")
    print(f"  🌈 Tag Diversity: {corpus_quality['tag_diversity']:.1%}")
    print(f"  📊 Level Distribution: {corpus_quality['level_distribution']}")

    # Write final corpus
    write_parquet(all_records, args.out_corpus)

    # Create DataFrame for return
    corpus_df = pd.DataFrame(all_records, columns=['source_path', 'chunk_index', 'level', 'tag', 'chunk_text'])

    total_time = time.time() - start_time
    total_chunks = len(chunks)

    print(f"📝 Wrote corpus with {len(all_records)} tags to {args.out_corpus}")
    print(f"⏱️ Total time: {total_time:.2f}s")
    print(f"🚀 Performance: {total_successful}/{total_chunks} chunks processed ({total_successful/total_chunks*100:.1f}% success rate)")
    print(f"⚡ Throughput: {total_successful/total_time:.2f} chunks/second")
    print(f"🎯 Concurrency: {concurrency} concurrent requests")
    print(f"📏 Chunk Size: {chunk_size} tokens")
    print("🚀 Done.")

    return corpus_df


def main():
    """Parse the command line and run ``async_main`` to completion.

    Options are declared as a table of ``(flag, add_argument kwargs)`` pairs and
    registered in order, so ``--help`` lists them in the order written here.
    """
    cli = argparse.ArgumentParser(description="Build codebook corpus using few-shot prompting from sampled chunks (async with embeddings pre-flight per chunk).")

    option_table = [
        ("--input", {"default": None, "help": "Input file or directory (default: first valid file in data)"}),
        ("--question", {"required": True, "help": "Question to condition initial code generation on"}),
        ("--chunk-size", {"type": int, "default": 2048, "help": "Chunk size in tokens (default 2048)"}),
        ("--overlap", {"type": int, "default": 200, "help": "Token overlap between chunks (default 200)"}),
        ("--similarity-threshold", {"type": float, "default": 0.3, "help": "Minimum cosine similarity threshold to keep chunks (default 0.3)"}),
        ("--seed-low", {"type": int, "default": 5, "help": "Lower bound for number of seed chunks (default 5)"}),
        ("--seed-high", {"type": int, "default": 10, "help": "Upper bound for number of seed chunks (default 10)"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed for sampling"}),
        ("--out-corpus", {"default": DEFAULT_CORPUS, "help": f"Output corpus Parquet (default {DEFAULT_CORPUS})"}),
        ("--test-concurrency", {"action": "store_true", "help": "Test different concurrency levels and find optimal performance"}),
        ("--max-test-concurrency", {"type": int, "default": 128, "help": "Maximum concurrency to test (default 128)"}),
        ("--test-chunk-sizes", {"action": "store_true", "help": "Test different chunk sizes and evaluate quality"}),
        ("--chunk-sizes", {"nargs": '+', "type": int, "default": [256], "help": "Chunk sizes to test (default: 256)"}),
        ("--overlap-ratio", {"type": float, "default": 0.2, "help": "Overlap ratio as fraction of chunk size (default 0.2)"}),
        ("--strategy", {"type": str, "default": "strategy_1", "help": "Codebook generation strategy to use (default: strategy_1)"}),
        ("--model", {"type": str, "default": "32B", "help": "Model to use for codebook generation (default: 32B)"}),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    parsed_args = cli.parse_args()
    asyncio.run(async_main(parsed_args))


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 