#!/usr/bin/env python3
"""
Batch Question Answering Pipeline

This pipeline optimizes processing multiple questions by:
1. Clustering questions using embeddings and K-means (cluster size: 4)
2. Selecting representative questions (closest to cluster centers)
3. Running full qa_main pipeline on representatives only
4. Reusing context for remaining questions in each cluster

Benefits of cluster size 4:
- Preserves question intent better than larger clusters
- Still provides significant efficiency gains
- Balances quality and performance

Usage:
  python batch_qa_pipeline.py --questions questions.txt --output results.json
  python batch_qa_pipeline.py --questions "Q1,Q2,Q3" --cluster-size 4
"""

import os
import sys
import json
import asyncio
import argparse
import time
import numpy as np
import pandas as pd
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from pathlib import Path
import shutil
import gc
import psutil
import atexit

# Put this file's own directory on the import path so the sibling modules
# imported just below (utils, qa_main) can be found
sys.path.append(os.path.dirname(__file__))

from utils.initial_iteration.embeddings import build_embeddings_parquet
from qa_main import run_qa
from utils.context_retrievers import DataRetrievalContextRetriever

# Environment variables
# Read once at import time; os.getenv yields None for a variable that is unset
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL")
DEFAULT_EMBEDDING_MODEL = os.getenv("DEFAULT_EMBEDDING_MODEL")
# Scratch directory located next to this file; BatchQAPipeline uses it as temp_dir
TEMP_FILES_DIR = os.path.join(os.path.dirname(__file__), "temp_files")

@dataclass
class QuestionCluster:
    """Represents a cluster of questions

    Built by BatchQAPipeline.cluster_questions: one instance per K-means
    cluster, holding the member questions and the member picked as the
    cluster's representative.
    """
    # K-means label shared by every question in this cluster
    cluster_id: int
    # Text of the member whose embedding lies closest to the cluster center
    representative_question: str
    # Position of the representative in the full (global) question list
    representative_index: int
    # Texts of all member questions, in their original list order
    all_questions: List[str]
    # Positions of the members in the full question list, parallel to all_questions
    all_indices: List[int]
    # Rows of the embedding matrix belonging to the members, parallel to all_indices
    embeddings: np.ndarray
    # Center reported by K-means for this cluster
    cluster_center: np.ndarray

@dataclass
class BatchResult:
    """Results from batch processing

    NOTE(review): the code that fills this in is outside the reviewed part of
    the file, so the field notes below are assumptions based on the names.
    """
    # presumably the number of questions submitted to the batch - confirm
    total_questions: int
    # presumably the number of clusters the questions were grouped into - confirm
    total_clusters: int
    # presumably the count of questions answered without error - confirm
    successful_answers: int
    # presumably the count of questions whose answering failed - confirm
    failed_answers: int
    # presumably wall-clock duration of the run; unit looks like seconds - TODO confirm
    total_time: float
    # one dict per question; schema is defined by the producer - verify against caller
    results: List[Dict[str, Any]]

class BatchQAPipeline:
    """Main pipeline for batch question answering"""
    
    def __init__(self, cluster_size: int = 4, strategy: str = "dynamic",
                 corpus_strategy: str = "strategy_1", use_thinking: bool = False, model: str = "32B",
                 use_codebooks: bool = False, codebooks_file: str = None, cluster_ids: Optional[List[int]] = None,
                 max_codebooks_per_cluster: Optional[int] = None, batch_size: int = 10000,
                 save_detailed_components: bool = False):
        """Record the run configuration and prepare the working directory.

        Args:
            cluster_size: Target number of questions per cluster;
                cluster_questions derives the K-means cluster count from it.
            strategy: Stored as-is; consumed outside the code shown here.
            corpus_strategy: Stored as-is; consumed outside the code shown here.
            use_thinking: Stored as-is; consumed outside the code shown here.
            model: Stored as-is; consumed outside the code shown here.
            use_codebooks: Stored as-is; presumably switches question loading
                to the codebooks path - confirm against the caller.
            codebooks_file: Path to a codebooks JSON file, or to a directory
                of ``cluster_<id>.json`` files.
            cluster_ids: Optional subset of cluster ids to read from the
                codebooks; a falsy value means "all".
            max_codebooks_per_cluster: Stored as-is; not read in the code shown here.
            batch_size: Stored as-is; not read in the code shown here.
            save_detailed_components: Stored as-is; not read in the code shown here.

        Side effects: creates the temp directory if needed and registers
        ``_cleanup`` to run at interpreter exit.
        """
        # Working directory shared by every helper that writes intermediate files
        self.temp_dir = TEMP_FILES_DIR
        os.makedirs(self.temp_dir, exist_ok=True)

        # Core pipeline settings
        self.cluster_size = cluster_size
        self.strategy = strategy
        self.corpus_strategy = corpus_strategy
        self.use_thinking = use_thinking
        self.model = model

        # Where questions come from when codebooks are used
        self.use_codebooks = use_codebooks
        self.codebooks_file = codebooks_file
        self.cluster_ids = cluster_ids
        self.max_codebooks_per_cluster = max_codebooks_per_cluster
        self.batch_size = batch_size

        # Reporting detail
        self.save_detailed_components = save_detailed_components

        # Run the cleanup hook when the interpreter shuts down
        atexit.register(self._cleanup)
    
    def _cleanup(self):
        """Clean up resources to prevent memory leaks"""
        gc.collect()
        print("🧹 BatchQAPipeline cleanup completed")
    
    def _save_cluster_graph_statistics(self, cluster_id: int, schema_result: Dict[str, Any]):
        """Save detailed graph statistics for a specific cluster
        
        FIXED: Now correctly reads node count from hierarchy.parquet and component count 
        from component analysis, instead of using potentially incorrect values from schema_result.
        """
        try:
            # Create cluster-specific directory
            cluster_stats_dir = os.path.join(self.temp_dir, "cluster_graph_stats")
            os.makedirs(cluster_stats_dir, exist_ok=True)
            
            # Initialize graph statistics
            graph_stats = {
                'cluster_id': cluster_id,
                'timestamp': time.time(),
                'num_graph_nodes': 0,
                'num_graph_edges': 0,
                'graph_density': 0.0,
                'max_degree': 0,
                'avg_degree': 0.0,
                'num_components': 0,
                'largest_component_size': 0,
                'is_acyclic': False,
                'is_directed': True
            }
            
            # Try to load actual graph statistics from the topological graph directory
            topological_graph_dir = os.path.join(self.temp_dir, "topological_graph")
            graph_analysis_file = os.path.join(topological_graph_dir, "graph_analysis.parquet")
            
            if os.path.exists(graph_analysis_file):
                try:
                    import pandas as pd
                    graph_df = pd.read_parquet(graph_analysis_file)
                    
                    if not graph_df.empty:
                        # Extract basic graph statistics from graph_analysis.parquet
                        graph_stats.update({
                            'num_graph_nodes': int(graph_df.iloc[0]['num_nodes']),
                            'num_graph_edges': int(graph_df.iloc[0]['num_edges']),
                            'graph_density': float(graph_df.iloc[0]['density']),
                            'max_degree': int(graph_df.iloc[0]['max_degree']),
                            'avg_degree': float(graph_df.iloc[0]['avg_degree']),
                            'is_acyclic': bool(graph_df.iloc[0]['is_acyclic']),
                            'is_directed': bool(graph_df.iloc[0]['is_directed'])
                        })
                        
                        print(f"   📊 Loaded basic graph statistics: {graph_stats['num_graph_nodes']} nodes, {graph_stats['num_graph_edges']} edges")
                    
                except Exception as e:
                    print(f"   ⚠️ Could not load graph analysis file: {e}")
            
            # Load correct component statistics from existing component reports
            component_reports_dir = os.path.join(topological_graph_dir, "component_reports")
            components_summary_file = os.path.join(component_reports_dir, "components_summary.csv")
            
            # Use hierarchy file for final graph statistics (after merging)
            hierarchy_file = os.path.join(topological_graph_dir, "hierarchy.parquet")
            if os.path.exists(hierarchy_file):
                try:
                    hierarchy_df = pd.read_parquet(hierarchy_file)
                    final_node_count = len(hierarchy_df)
                    
                    # Update with final graph statistics
                    graph_stats.update({
                        'num_graph_nodes': final_node_count,
                        'num_components': 0,  # Will be calculated from component analysis
                        'largest_component_size': 0  # Will be calculated properly in component analysis
                    })
                    
                    print(f"   📊 Using final graph statistics: {final_node_count} nodes (after merging)")
                    
                except Exception as e:
                    print(f"   ⚠️ Could not load hierarchy file: {e}")
            
            # Note: Component analysis from component_reports is from original graph (before merging)
            # We'll use the graph_analysis.parquet component count for the final graph
            if os.path.exists(components_summary_file):
                try:
                    comp_df = pd.read_csv(components_summary_file)
                    print(f"   📊 Note: Component analysis shows {len(comp_df)} components from original graph ({comp_df['size'].sum()} nodes)")
                    print(f"   📊 This is from before merging - final graph has {graph_stats['num_graph_nodes']} nodes")
                    
                except Exception as e:
                    print(f"   ⚠️ Could not load component statistics: {e}")
            else:
                print(f"   ⚠️ Component analysis not found - will be generated in next run")
            
            # Save cluster graph statistics
            stats_file = os.path.join(cluster_stats_dir, f"cluster_{cluster_id}_graph_stats.json")
            with open(stats_file, 'w') as f:
                json.dump(graph_stats, f, indent=2)
            
            # Try to extract and save components if available
            if os.path.exists(topological_graph_dir):
                try:
                    import pandas as pd
                    
                    # Use existing component analysis from topological graph pipeline
                    component_reports_dir = os.path.join(topological_graph_dir, "component_reports")
                    components_summary_file = os.path.join(component_reports_dir, "components_summary.csv")
                    
                    if os.path.exists(components_summary_file):
                        # Load existing component analysis
                        comp_df = pd.read_csv(components_summary_file)
                        
                        # Extract component information
                        total_components = len(comp_df)
                        total_nodes = comp_df['size'].sum()
                        largest_component_size = comp_df['size'].max()
                        isolated_nodes = len(comp_df[comp_df['size'] == 1])
                        
                        # Update graph stats with correct component information
                        graph_stats.update({
                            'num_components': total_components,
                            'largest_component_size': largest_component_size
                        })
                        
                        # Save component details
                        components_data = {
                            'cluster_id': cluster_id,
                            'weakly_connected_components': total_components,
                            'strongly_connected_components': total_components,  # Same for undirected analysis
                            'largest_wcc_size': largest_component_size,
                            'largest_scc_size': largest_component_size,
                            'component_details': {
                                'wcc_sizes': comp_df['size'].tolist(),
                                'scc_sizes': comp_df['size'].tolist(),
                                'isolated_nodes': isolated_nodes
                            }
                        }
                        
                        # Save components analysis
                        components_file = os.path.join(cluster_stats_dir, f"cluster_{cluster_id}_components.json")
                        with open(components_file, 'w') as f:
                            json.dump(components_data, f, indent=2)
                        
                        # Always save detailed component reports for this cluster
                        self._save_detailed_component_reports_from_existing(cluster_id, component_reports_dir)
                        
                        print(f"   📊 Saved detailed graph statistics for cluster {cluster_id}")
                        print(f"      📁 Graph stats: {stats_file}")
                        print(f"      📁 Components: {components_file}")
                        print(f"      📊 Actual components: {total_components}, nodes: {total_nodes}, largest: {largest_component_size}")
                        
                    else:
                        print(f"   ⚠️ No existing component analysis found for cluster {cluster_id}")
                        
                except Exception as e:
                    print(f"   ⚠️ Could not extract components for cluster {cluster_id}: {e}")
            
        except Exception as e:
            print(f"   ⚠️ Error saving graph statistics for cluster {cluster_id}: {e}")
    
    def _save_detailed_component_reports(self, cluster_id: int, wccs: list, sccs: list):
        """Save detailed component reports for a specific cluster"""
        try:
            # Create cluster-specific component reports directory
            cluster_component_dir = os.path.join(self.temp_dir, f"cluster_{cluster_id}_component_reports")
            
            # Remove existing directory if it exists (overwrite for each cluster)
            import shutil
            if os.path.exists(cluster_component_dir):
                shutil.rmtree(cluster_component_dir)
            
            os.makedirs(cluster_component_dir, exist_ok=True)
            
            # Save component summary
            component_summary = []
            for i, wcc in enumerate(wccs):
                component_summary.append({
                    'component_id': i,
                    'size': len(wcc),
                    'type': 'wcc'  # weakly connected component
                })
            
            # Sort by size (largest first)
            component_summary.sort(key=lambda x: x['size'], reverse=True)
            
            summary_df = pd.DataFrame(component_summary)
            summary_file = os.path.join(cluster_component_dir, "components_summary.csv")
            summary_df.to_csv(summary_file, index=False)
            
            # Save detailed component files
            for i, wcc in enumerate(wccs):
                component_file = os.path.join(cluster_component_dir, f"component_{i}.txt")
                with open(component_file, 'w') as f:
                    f.write(f"Component {i} (Size: {len(wcc)})\n")
                    f.write("=" * 50 + "\n\n")
                    
                    # Sort nodes alphabetically for readability
                    sorted_nodes = sorted(list(wcc))
                    for j, node in enumerate(sorted_nodes, 1):
                        f.write(f"{j:3d}. {node}\n")
                    
                    f.write(f"\nTotal nodes in component: {len(wcc)}\n")
            
            # Save component statistics
            stats = {
                'cluster_id': cluster_id,
                'total_components': len(wccs),
                'total_nodes': sum(len(comp) for comp in wccs),
                'largest_component_size': max(len(comp) for comp in wccs) if wccs else 0,
                'isolated_nodes': len([comp for comp in wccs if len(comp) == 1]),
                'component_size_distribution': {
                    'size_1': len([comp for comp in wccs if len(comp) == 1]),
                    'size_2': len([comp for comp in wccs if len(comp) == 2]),
                    'size_3_10': len([comp for comp in wccs if 3 <= len(comp) <= 10]),
                    'size_11_50': len([comp for comp in wccs if 11 <= len(comp) <= 50]),
                    'size_51_100': len([comp for comp in wccs if 51 <= len(comp) <= 100]),
                    'size_100_plus': len([comp for comp in wccs if len(comp) > 100])
                }
            }
            
            stats_file = os.path.join(cluster_component_dir, "component_statistics.json")
            with open(stats_file, 'w') as f:
                json.dump(stats, f, indent=2)
            
            print(f"      📁 Detailed component reports: {cluster_component_dir}")
            print(f"         - {len(wccs)} components saved")
            print(f"         - Largest component: {stats['largest_component_size']} nodes")
            print(f"         - Isolated nodes: {stats['isolated_nodes']}")
            
        except Exception as e:
            print(f"   ⚠️ Could not save detailed component reports for cluster {cluster_id}: {e}")
    
    def _save_detailed_component_reports_from_existing(self, cluster_id: int, component_reports_dir: str):
        """Save detailed component reports from an existing component_reports directory"""
        try:
            # Create cluster-specific component reports directory
            cluster_component_dir = os.path.join(self.temp_dir, f"cluster_{cluster_id}_component_reports")
            
            # Remove existing directory if it exists (overwrite for each cluster)
            if os.path.exists(cluster_component_dir):
                shutil.rmtree(cluster_component_dir)
            
            os.makedirs(cluster_component_dir, exist_ok=True)
            
            # Save component summary
            component_summary = []
            components_summary_file = os.path.join(component_reports_dir, "components_summary.csv")
            if os.path.exists(components_summary_file):
                comp_df = pd.read_csv(components_summary_file)
                for i, row in comp_df.iterrows():
                    component_summary.append({
                        'component_id': i,
                        'size': int(row['size']),
                        'type': 'wcc' # Assuming all are weakly connected components for this pipeline
                    })
            
            # Sort by size (largest first)
            component_summary.sort(key=lambda x: x['size'], reverse=True)
            
            summary_df = pd.DataFrame(component_summary)
            summary_file = os.path.join(cluster_component_dir, "components_summary.csv")
            summary_df.to_csv(summary_file, index=False)
            
            # Save detailed component files
            for i, row in summary_df.iterrows():
                component_file = os.path.join(cluster_component_dir, f"component_{i}.txt")
                
                # Try to copy from existing component file
                existing_component_file = os.path.join(component_reports_dir, f"component_{int(row['component_id'])}.txt")
                if os.path.exists(existing_component_file):
                    # Copy the existing file
                    import shutil
                    shutil.copy2(existing_component_file, component_file)
                else:
                    # Fallback: create a basic file
                    with open(component_file, 'w') as f:
                        f.write(f"Component {i} (Size: {int(row['size'])})\n")
                        f.write("=" * 50 + "\n\n")
                        f.write(f"Component content not available in existing reports.\n")
                        f.write(f"Total nodes in component: {int(row['size'])}\n")
            
            # Save component statistics
            stats = {
                'cluster_id': cluster_id,
                'total_components': len(component_summary),
                'total_nodes': sum(int(row['size']) for row in summary_df.iterrows()),
                'largest_component_size': max(int(row['size']) for row in summary_df.iterrows()) if component_summary else 0,
                'isolated_nodes': len([row for row in summary_df.iterrows() if int(row['size']) == 1]),
                'component_size_distribution': {
                    'size_1': len([row for row in summary_df.iterrows() if int(row['size']) == 1]),
                    'size_2': len([row for row in summary_df.iterrows() if int(row['size']) == 2]),
                    'size_3_10': len([row for row in summary_df.iterrows() if 3 <= int(row['size']) <= 10]),
                    'size_11_50': len([row for row in summary_df.iterrows() if 11 <= int(row['size']) <= 50]),
                    'size_51_100': len([row for row in summary_df.iterrows() if 51 <= int(row['size']) <= 100]),
                    'size_100_plus': len([row for row in summary_df.iterrows() if int(row['size']) > 100])
                }
            }
            
            stats_file = os.path.join(cluster_component_dir, "component_statistics.json")
            with open(stats_file, 'w') as f:
                json.dump(stats, f, indent=2)
            
            print(f"      📁 Detailed component reports: {cluster_component_dir}")
            print(f"         - {len(component_summary)} components saved")
            print(f"         - Largest component: {stats['largest_component_size']} nodes")
            print(f"         - Isolated nodes: {stats['isolated_nodes']}")
            
        except Exception as e:
            print(f"   ⚠️ Could not save detailed component reports for cluster {cluster_id} from existing: {e}")
    
    def _save_combined_graph_statistics(self):
        """Save combined graph statistics across all clusters"""
        try:
            cluster_stats_dir = os.path.join(self.temp_dir, "cluster_graph_stats")
            if not os.path.exists(cluster_stats_dir):
                return
            
            # Collect all cluster statistics
            all_stats = []
            all_components = []
            
            for stats_file in os.listdir(cluster_stats_dir):
                if stats_file.endswith('_graph_stats.json'):
                    cluster_id = int(stats_file.split('_')[1])
                    with open(os.path.join(cluster_stats_dir, stats_file), 'r') as f:
                        stats = json.load(f)
                        all_stats.append(stats)
                
                if stats_file.endswith('_components.json'):
                    cluster_id = int(stats_file.split('_')[1])
                    with open(os.path.join(cluster_stats_dir, stats_file), 'r') as f:
                        components = json.load(f)
                        all_components.append(components)
            
            # Calculate summary statistics
            if all_stats:
                summary = {
                    'total_clusters': len(all_stats),
                    'avg_nodes_per_cluster': sum(s['num_graph_nodes'] for s in all_stats) / len(all_stats),
                    'avg_edges_per_cluster': sum(s['num_graph_edges'] for s in all_stats) / len(all_stats),
                    'avg_density': sum(s['graph_density'] for s in all_stats) / len(all_stats),
                    'avg_components': sum(s['num_components'] for s in all_stats) / len(all_stats),
                    'total_nodes': sum(s['num_graph_nodes'] for s in all_stats),
                    'total_edges': sum(s['num_graph_edges'] for s in all_stats),
                    'cluster_details': all_stats
                }
                
                # Save summary
                summary_file = os.path.join(cluster_stats_dir, "combined_graph_summary.json")
                with open(summary_file, 'w') as f:
                    json.dump(summary, f, indent=2)
                
                print(f"📊 Combined graph statistics saved to: {summary_file}")
                print(f"   Total clusters: {summary['total_clusters']}")
                print(f"   Average nodes per cluster: {summary['avg_nodes_per_cluster']:.1f}")
                print(f"   Average components per cluster: {summary['avg_components']:.1f}")
            
        except Exception as e:
            print(f"⚠️ Error saving combined graph statistics: {e}")
    
    def _extract_questions_from_codebooks(self) -> List[str]:
        """Extract questions from codebooks - handles both combined file and individual cluster files"""
        all_questions = []
        
        # Check if codebooks_file is a directory (for multiple cluster files) or single file
        if self.codebooks_file and os.path.isdir(self.codebooks_file):
            # Handle directory with multiple cluster files
            cluster_dir = self.codebooks_file
            target_clusters = self.cluster_ids if self.cluster_ids else []
            
            # If no specific clusters, find all cluster files
            if not target_clusters:
                import glob
                cluster_files = glob.glob(os.path.join(cluster_dir, "cluster_*.json"))
                target_clusters = [int(os.path.basename(f).split('_')[1].split('.')[0]) for f in cluster_files]
            
            for cluster_id in target_clusters:
                cluster_file = os.path.join(cluster_dir, f"cluster_{cluster_id}.json")
                if os.path.exists(cluster_file):
                    questions = self._extract_questions_from_single_file(cluster_file, cluster_id)
                    all_questions.extend(questions)
                else:
                    print(f"   ⚠️ Cluster file not found: cluster_{cluster_id}.json")
                    
        elif self.codebooks_file and os.path.isfile(self.codebooks_file):
            # Handle single file (either individual cluster or combined format)
            if "cluster_" in os.path.basename(self.codebooks_file):
                # Individual cluster file
                cluster_id = int(os.path.basename(self.codebooks_file).split('_')[1].split('.')[0])
                questions = self._extract_questions_from_single_file(self.codebooks_file, cluster_id)
                all_questions.extend(questions)
            else:
                # Combined format
                all_questions = self._extract_questions_from_combined_file()
        else:
            print(f"❌ Codebooks path not found: {self.codebooks_file}")
            return []
        
        print(f"📖 Extracted {len(all_questions)} total questions from codebooks")
        return all_questions
    
    def _extract_questions_from_single_file(self, file_path: str, cluster_id: int) -> List[str]:
        """Extract questions from a single cluster file"""
        try:
            with open(file_path, 'r') as f:
                data = json.load(f)
            
            # Handle individual cluster file format
            cluster_data = data.get('cluster', {})
            questions = cluster_data.get('all_questions', [])
            print(f"   📋 Cluster {cluster_id}: {len(questions)} questions")
            return questions
            
        except Exception as e:
            print(f"❌ Error reading cluster file {file_path}: {e}")
            return []
    
    def _extract_questions_from_combined_file(self) -> List[str]:
        """Extract questions from combined codebooks file"""
        try:
            with open(self.codebooks_file, 'r') as f:
                data = json.load(f)
            
            all_questions = []
            clusters = data.get('clusters', {})
            
            # Filter by cluster IDs if specified
            target_clusters = self.cluster_ids if self.cluster_ids else clusters.keys()
            
            for cluster_id in target_clusters:
                cluster_id_str = str(cluster_id)
                if cluster_id_str in clusters:
                    cluster = clusters[cluster_id_str]
                    questions = cluster.get('all_questions', [])
                    all_questions.extend(questions)
                    print(f"   📋 Cluster {cluster_id}: {len(questions)} questions")
                else:
                    print(f"   ⚠️ Cluster {cluster_id} not found in codebooks file")
            
            return all_questions
            
        except Exception as e:
            print(f"❌ Error reading combined codebooks file: {e}")
            return []
    
    def _get_memory_usage(self):
        """Report the current process's memory footprint via psutil.

        Returns:
            A dict with ``rss`` and ``vms`` converted from bytes to MB and
            ``percent`` as returned by ``Process.memory_percent()``.
        """
        current_process = psutil.Process()
        snapshot = current_process.memory_info()

        resident_mb = snapshot.rss / 1024 / 1024
        virtual_mb = snapshot.vms / 1024 / 1024
        share_of_total = current_process.memory_percent()

        return {'rss': resident_mb, 'vms': virtual_mb, 'percent': share_of_total}
    
    def _log_memory_usage(self, stage: str):
        """Log memory usage at different stages"""
        memory = self._get_memory_usage()
        print(f"💾 Memory usage at {stage}: {memory['rss']:.1f}MB RSS, {memory['percent']:.1f}%")
    
    def _load_or_create_chunks_cache(self, chunk_size: int = 256, overlap: int = 50) -> Dict[int, List[str]]:
        """Load pre-chunked data from cache or create it if not exists"""
        cache_path = os.path.join(self.temp_dir, "aliabdaal_chunks_cache.json")
        original_data_path = "data/aliabdaal_500.csv"
        
        # Check if cache exists and is newer than original data
        if os.path.exists(cache_path) and os.path.exists(original_data_path):
            cache_mtime = os.path.getmtime(cache_path)
            data_mtime = os.path.getmtime(original_data_path)
            
            if cache_mtime > data_mtime:
                print(f"   📁 Loading from cache: {cache_path}")
                try:
                    with open(cache_path, 'r') as f:
                        cached_data = json.load(f)
                    
                    # Convert string keys back to int
                    chunked_data = {int(k): v for k, v in cached_data.items()}
                    print(f"   ✅ Cache loaded successfully")
                    return chunked_data
                except Exception as e:
                    print(f"   ⚠️ Error loading cache: {e}")
                    print(f"   🔄 Will recreate cache...")
        
        # Create new cache
        print(f"   🔄 Creating chunks cache (chunk_size={chunk_size}, overlap={overlap})...")
        chunked_data = self._prechunk_aliabdaal_data(chunk_size, overlap)
        
        # Save to cache
        try:
            # Convert int keys to string for JSON serialization
            cache_data = {str(k): v for k, v in chunked_data.items()}
            with open(cache_path, 'w') as f:
                json.dump(cache_data, f, indent=2)
            print(f"   💾 Cache saved to: {cache_path}")
        except Exception as e:
            print(f"   ⚠️ Error saving cache: {e}")
        
        return chunked_data
    
    def _prechunk_aliabdaal_data(self, chunk_size: int = 256, overlap: int = 50) -> Dict[int, List[str]]:
        """Pre-chunk all Ali Abdaal transcripts for reuse across clusters (word-based like build_corpus)"""
        try:
            import pandas as pd
            original_df = pd.read_csv('data/aliabdaal_500.csv')
            
            chunked_data = {}
            total_chunks = 0
            
            for idx, row in original_df.iterrows():
                text = str(row['text'])
                chunks = []
                
                # Use word-based chunking like build_corpus
                if not text:
                    chunked_data[idx] = []
                    continue
                    
                tokens = text.split()
                if chunk_size <= 0:
                    chunks = [text]
                else:
                    step = max(1, chunk_size - max(0, overlap))
                    for start in range(0, len(tokens), step):
                        end = min(start + chunk_size, len(tokens))
                        if start >= end:
                            break
                        chunk = " ".join(tokens[start:end]).strip()
                        if chunk:
                            chunks.append(chunk)
                        if end == len(tokens):
                            break
                
                chunked_data[idx] = chunks
                total_chunks += len(chunks)
            
            print(f"   📊 Created {total_chunks} chunks from {len(original_df)} transcripts (word-based)")
            return chunked_data
            
        except Exception as e:
            print(f"   ❌ Error pre-chunking data: {e}")
            return {}
        
    async def embed_questions(self, questions: List[str]) -> np.ndarray:
        """Embed the given questions and return the resulting embeddings.

        The questions are placed in a two-column frame (``text``, ``id``) and
        handed to ``build_embeddings_parquet``, which is asked to write
        ``question_embeddings.parquet`` into ``self.temp_dir``. Memory usage
        is logged before and after.

        Args:
            questions: Question texts; ``id`` is their position in this list.

        Returns:
            The second element of the 4-tuple returned by
            ``build_embeddings_parquet`` - presumably one embedding row per
            question, TODO confirm against that helper.
        """
        self._log_memory_usage("before embedding")
        question_count = len(questions)
        print(f"🧠 Embedding {question_count} questions...")

        # Input frame in the shape the embedding helper is called with
        question_frame = pd.DataFrame({
            'text': questions,
            'id': range(question_count)
        })

        parquet_target = os.path.join(self.temp_dir, "question_embeddings.parquet")
        _written_path, vectors, _, _ = await build_embeddings_parquet(
            corpus_df=question_frame,
            output_parquet=parquet_target
        )

        # The frame is no longer needed; drop it before the next stage
        del question_frame
        gc.collect()

        self._log_memory_usage("after embedding")
        print(f"✅ Embedded {question_count} questions")
        return vectors
    
    def cluster_questions(self, questions: List[str], embeddings: np.ndarray) -> List[QuestionCluster]:
        """Cluster questions using K-means and select representatives"""
        from sklearn.cluster import KMeans
        
        n_questions = len(questions)
        n_clusters = max(1, n_questions // self.cluster_size)
        
        print(f"🎯 Clustering {n_questions} questions into {n_clusters} clusters...")
        
        # Perform K-means clustering
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)
        
        # Group questions by cluster
        clusters = []
        for cluster_id in range(n_clusters):
            # Find questions in this cluster
            cluster_indices = [i for i, label in enumerate(cluster_labels) if label == cluster_id]
            cluster_questions = [questions[i] for i in cluster_indices]
            cluster_embeddings = embeddings[cluster_indices]
            
            # Find representative (closest to cluster center)
            cluster_center = kmeans.cluster_centers_[cluster_id]
            distances = np.linalg.norm(cluster_embeddings - cluster_center, axis=1)
            representative_idx = np.argmin(distances)
            representative_question = cluster_questions[representative_idx]
            representative_global_idx = cluster_indices[representative_idx]
            
            cluster = QuestionCluster(
                cluster_id=cluster_id,
                representative_question=representative_question,
                representative_index=representative_global_idx,
                all_questions=cluster_questions,
                all_indices=cluster_indices,
                embeddings=cluster_embeddings,
                cluster_center=cluster_center
            )
            clusters.append(cluster)
            
            print(f"   📦 Cluster {cluster_id}: {len(cluster_questions)} questions")
            print(f"      Representative: {representative_question[:50]}...")
        
        return clusters
    
    async def process_single_cluster_with_hrp(self, cluster: QuestionCluster, cluster_codebook_file: str, chunked_data_cache: Dict[int, List[str]] = None) -> List[Dict[str, Any]]:
        """
        New method: Process a cluster by running schema induction once, then individual HRP for each question
        
        Flow:
        1. Run schema induction pipeline once on cluster codebooks
        2. For each question (including representative):
           - Run HRP to get question-specific context
           - Answer the question with that context
           - Complete before moving to next question
        """
        print(f"\n📋 Processing Cluster {cluster.cluster_id}: {len(cluster.all_questions)} questions")
        print(f"   🧠 Step 1: Running schema induction on cluster codebooks...")
        
        results = []
        
        try:
            # Step 1: Run schema induction pipeline once for this cluster
            from utils.initial_iteration.schema_induction_pipeline import LLMSchemaInductionPipeline
            
            # Use the representative question for schema induction with this cluster's specific codebook file
            schema_pipeline = LLMSchemaInductionPipeline(
                question=cluster.representative_question,  # Use representative for initial schema
                chunk_size=256,
                overlap=50,
                max_iterations=1,
                min_frequency=1,  # Set to 1 to disable frequency-based merging (like before)
                min_frequency_ratio=0.5,
                strategy=self.corpus_strategy,
                model=self.model,
                use_codebooks=self.use_codebooks,
                codebooks_file=cluster_codebook_file,  # Use specific cluster file
                cluster_ids=[cluster.cluster_id],  # Only this cluster's codebooks
                max_codebooks_per_cluster=self.max_codebooks_per_cluster,
                batch_size=self.batch_size,
                chunked_data_cache=chunked_data_cache  # Pass the pre-chunked data cache
            )
            
            # Run the schema induction pipeline (embeddings, clustering, NLI, conflict detection, topological graph)
            schema_result = await schema_pipeline.run_pipeline()
            
            # Save detailed graph statistics for this cluster
            self._save_cluster_graph_statistics(cluster.cluster_id, schema_result)
            
            # Handle schema result - it returns a dictionary with summary information
            final_codes = schema_result.get('num_unique_codes', 0)
            nodes = schema_result.get('num_graph_nodes', 0)
            total_time = schema_result.get('total_time', 0)
            
            print(f"   ✅ Schema induction completed: {final_codes} codes, {nodes} nodes (Time: {total_time:.1f}s)")
            
            # Step 2: For each question, run individual HRP and answer
            print(f"   🔍 Step 2: Processing {len(cluster.all_questions)} questions individually...")
            
            for i, question in enumerate(cluster.all_questions):
                question_idx = cluster.all_indices[i]
                is_representative = (i == cluster.representative_index)
                
                print(f"   Question {i+1}/{len(cluster.all_questions)}: {question[:80]}...")
                
                try:
                    # Step 2a: Run HRP for this specific question
                    from utils.context_retrievers import DataRetrievalContextRetriever
                    
                    # Create context retriever using the schema induction results
                    context_retriever = DataRetrievalContextRetriever(
                        os.path.join(self.temp_dir, "embeddings.parquet"),
                        os.path.join(self.temp_dir, "topological_graph"),
                        top_k=30
                                        )
                    
                    # Get question-specific context via HRP
                    hrp_result = await context_retriever.retrieve_context(
                        question=question,
                        strategy=self.strategy
                    )
                    
                    # Extract chunks from the HRP result
                    hrp_context = hrp_result.get('chunks', [])
                    
                    print(f"     📊 HRP context: {len(hrp_context)} chunks (avg score: {sum(c.get('score', 0) for c in hrp_context)/len(hrp_context):.3f})" if hrp_context else "     📊 HRP context: 0 chunks")
                    
                    # Save top 20 chunks to data_retrieved folder (accumulate for cluster)
                    self.save_retrieved_chunks(hrp_context, cluster.cluster_id, question_idx)
                    
                    # Step 2b: Answer this specific question with its tailored context
                    from utils.question_answerer import QuestionAnswerer
                    
                    # Configure model settings
                    if self.model == "32B":
                        model_config = {
                            'base_url': os.getenv("VLLM_QWEN_32B_URL"),
                            'model_name': os.getenv("VLLM_QWEN_32B_MODEL"),
                            'max_tokens': 2048,
                            'temperature': 0.7
                        }
                    elif self.model == "30B-A3B":
                        model_config = {
                            'base_url': os.getenv("VLLM_QWEN_A3B_URL"),
                            'model_name': os.getenv("VLLM_QWEN_A3B_MODEL"),
                            'max_tokens': 2048,
                            'temperature': 0.7
                        }
                    else:
                        raise ValueError(f"Unknown model: {self.model}")
                    
                    # Create question answerer
                    async with QuestionAnswerer(model_config=model_config) as answerer:
                        # Generate answer using question-specific context
                        result_dict = await answerer.answer_question_with_chunks(
                            question=question,
                            chunks=hrp_context,
                            strategy=self.strategy,
                            max_chars=25000,
                            top_k=30
                        )
                        answer = result_dict.get('answer', 'Failed to generate answer')
                    
                    # Store result
                    result = {
                        'question_index': question_idx,
                        'question': question,
                        'cluster_id': cluster.cluster_id,
                        'is_representative': is_representative,
                        'answer': answer,
                        'context_chunks': len(hrp_context),
                        'context_score': sum(c.get('score', 0) for c in hrp_context)/len(hrp_context) if hrp_context else 0,
                        'processing_method': 'individual_hrp'
                    }
                    
                    results.append(result)
                    print(f"     ✅ Question answered (context: {len(hrp_context)} chunks)")
                    
                except Exception as e:
                    print(f"     ❌ Error processing question: {e}")
                    # Still add error result to maintain question tracking
                    result = {
                        'question_index': question_idx,
                        'question': question,
                        'cluster_id': cluster.cluster_id,
                        'is_representative': is_representative,
                        'answer': f"Error: {str(e)}",
                        'context_chunks': 0,
                        'context_score': 0,
                        'processing_method': 'individual_hrp_failed'
                    }
                    results.append(result)
            
            print(f"   ✅ Cluster {cluster.cluster_id} completed: {len(results)} questions processed")
            
            # Save cluster results immediately
            self.save_cluster_results(results, cluster.cluster_id)
            
            return results
            
        except Exception as e:
            print(f"   ❌ Error in schema induction for cluster {cluster.cluster_id}: {e}")
            # Return error results for all questions
            error_results = []
            for i, question in enumerate(cluster.all_questions):
                question_idx = cluster.all_indices[i]
                is_representative = (i == cluster.representative_index)
                
                error_results.append({
                    'question_index': question_idx,
                    'question': question,
                    'cluster_id': cluster.cluster_id,
                    'is_representative': is_representative,
                    'answer': f"Cluster processing error: {str(e)}",
                    'context_chunks': 0,
                    'context_score': 0,
                    'processing_method': 'cluster_failed'
                })
            
            return error_results

    async def run_batch_pipeline(self, questions: List[str]) -> BatchResult:
        """Run the whole batch pipeline and return the aggregated result.

        Steps:
        0. When ``self.use_codebooks`` is set, load the pre-chunked data cache
           (and, if no questions were given, extract them from the codebooks).
        1. Embed all questions.
        2. Build clusters: a single cluster reusing ``self.cluster_ids[0]``
           when codebooks and cluster ids are configured, K-means otherwise.
        3. Process every cluster via ``process_single_cluster_with_hrp``.
        4. Print a summary and save the results and combined graph statistics.

        Args:
            questions: Questions to answer. May be empty/None when
                ``self.use_codebooks`` is set.

        Returns:
            A ``BatchResult``. If there are no questions at all, an empty
            ``BatchResult`` is returned and nothing is saved.
        """
        start_time = time.time()

        # Extract questions from codebooks if needed
        if self.use_codebooks and not questions:
            questions = self._extract_questions_from_codebooks()

        if not questions:
            # Without this guard the code below fails on questions[0] or on the
            # per-question averages (division by zero).
            print("⚠️ No questions to process - returning an empty batch result")
            return BatchResult(
                total_questions=0,
                total_clusters=0,
                successful_answers=0,
                failed_answers=0,
                total_time=time.time() - start_time,
                results=[]
            )

        print("🚀 Starting Batch QA Pipeline (Individual HRP Mode)")
        print(f"   Questions: {len(questions)}")
        print(f"   Cluster size: {self.cluster_size}")
        print(f"   Strategy: {self.strategy}")
        print(f"   Model: {self.model}")
        print(f"   Use codebooks: {self.use_codebooks}")
        if self.use_codebooks:
            print(f"   Codebooks file: {self.codebooks_file}")
            print(f"   Target clusters: {self.cluster_ids if self.cluster_ids else 'all'}")
        print("=" * 60)

        memory_info = self._get_memory_usage()
        print(f"💾 Memory usage at pipeline start: {memory_info['rss']:.1f}MB RSS, {memory_info['percent']:.1f}%")

        # Step 0: Load or create pre-chunked Ali Abdaal data cache (when using codebooks)
        chunked_data_cache = None
        if self.use_codebooks:
            print(f"\n📚 Loading Ali Abdaal chunks cache...")
            chunked_data_cache = self._load_or_create_chunks_cache()
            print(f"✅ Loaded {sum(len(chunks) for chunks in chunked_data_cache.values())} total chunks from {len(chunked_data_cache)} transcripts")

        # Step 1: Embed questions
        embeddings = await self.embed_questions(questions)

        # Step 1.5: Handle clustering based on mode
        if self.use_codebooks and self.cluster_ids:
            # When using pre-generated codebooks with specific cluster IDs,
            # don't re-cluster questions - use the original cluster structure
            print(f"📋 Using original cluster structure from codebooks (no re-clustering)")

            # Use the first question as representative (or we could pick the most central one)
            representative_idx = 0

            # Create a single cluster with all questions, using the original cluster ID.
            # NOTE(review): only the first entry of self.cluster_ids is used here.
            clusters = [QuestionCluster(
                cluster_id=self.cluster_ids[0],  # Use the original cluster ID from the file
                representative_question=questions[representative_idx],
                representative_index=representative_idx,
                all_questions=questions,
                all_indices=list(range(len(questions))),
                embeddings=embeddings,
                cluster_center=embeddings[representative_idx]
            )]
        else:
            # Normal clustering for questions
            clusters = self.cluster_questions(questions, embeddings)

        # Step 2: Process each cluster with individual HRP approach
        print(f"\n🚀 Processing {len(clusters)} clusters with individual HRP...")

        all_results = []
        successful_clusters = 0
        failed_clusters = 0

        # Process each cluster separately with its own schema induction
        for cluster in clusters:
            print(f"\n📋 Processing Cluster {cluster.cluster_id}: {len(cluster.all_questions)} questions")

            # Load codebooks for this specific cluster
            cluster_codebook_file = None
            if self.use_codebooks:
                if os.path.isdir(self.codebooks_file):
                    # Directory with multiple cluster files. Checked first so a
                    # directory whose own name contains "cluster_" is not
                    # mistaken for a single cluster file.
                    cluster_codebook_file = os.path.join(self.codebooks_file, f"cluster_{cluster.cluster_id}.json")
                else:
                    # Single cluster file or combined file - use it as given
                    cluster_codebook_file = self.codebooks_file

                # Check if cluster file exists
                if not os.path.exists(cluster_codebook_file):
                    print(f"   ⚠️ Codebook file not found: {cluster_codebook_file}")
                    # Create error results for all questions in this cluster
                    cluster_results = [
                        {
                            'question_index': question_idx,
                            'question': question,
                            'cluster_id': cluster.cluster_id,
                            # representative_index is a global index, so compare
                            # it with the global question index.
                            'is_representative': bool(question_idx == cluster.representative_index),
                            'answer': f"Error: Codebook file not found for cluster {cluster.cluster_id}",
                            'context_chunks': 0,
                            'context_score': 0,
                            'processing_method': 'cluster_file_missing'
                        }
                        for question_idx, question in zip(cluster.all_indices, cluster.all_questions)
                    ]
                    all_results.extend(cluster_results)
                    failed_clusters += 1
                    continue

            # Process this cluster with its individual codebook file
            cluster_results = await self.process_single_cluster_with_hrp(cluster, cluster_codebook_file, chunked_data_cache)
            all_results.extend(cluster_results)

            # Count success/failure
            if any(r.get('processing_method', '').endswith('_failed') for r in cluster_results):
                failed_clusters += 1
            else:
                successful_clusters += 1

        total_time = time.time() - start_time

        # Failed answers are recognised by their "Error" prefix
        failed_answers = sum(1 for r in all_results if r.get('answer', '').startswith('Error'))

        # Create batch result
        batch_result = BatchResult(
            total_questions=len(questions),
            total_clusters=len(clusters),
            successful_answers=len(all_results) - failed_answers,
            failed_answers=failed_answers,
            total_time=total_time,
            results=all_results
        )

        # Print summary
        print("\n" + "="*60)
        print("📊 BATCH QA PIPELINE RESULTS (Individual HRP)")
        print("="*60)
        print(f"Total Questions: {batch_result.total_questions}")
        print(f"Total Clusters: {batch_result.total_clusters}")
        print(f"Successful Clusters: {successful_clusters}")
        print(f"Failed Clusters: {failed_clusters}")
        print(f"Successful Answers: {batch_result.successful_answers}")
        print(f"Failed Answers: {batch_result.failed_answers}")
        print(f"Total Processing Time: {batch_result.total_time:.2f}s")
        print(f"Average Time per Cluster: {batch_result.total_time/max(1, len(clusters)):.2f}s")
        print(f"Average Time per Question: {batch_result.total_time/len(questions):.2f}s")

        # Print sample answers
        print("\n📋 SAMPLE ANSWERS:")
        print("="*60)

        for i, result in enumerate(all_results[:3]):  # Show first 3 answers
            cluster_marker = "Representative" if result['is_representative'] else "Non-representative"
            print(f"\nCluster {result['cluster_id']} - Question{i+1}: {result['question'][:80]}...")
            print(f"Response: {result['answer'][:200]}...")
            print(f"{cluster_marker}: {result['is_representative']}")
            print(f"Context: {result.get('context_chunks', 0)} chunks (score: {result.get('context_score', 0):.3f})")

        print(f"   ✅ Completed in {total_time:.2f}s")
        print(f"   📁 Context files: ['embeddings', 'topological_graph', 'corpus']")

        # Save results to file automatically
        self.save_results_to_file(batch_result)

        # Save combined graph statistics across all clusters
        self._save_combined_graph_statistics()

        return batch_result

    def save_cluster_results(self, cluster_results: List[Dict[str, Any]], cluster_id: int) -> str:
        """Write one cluster's results to ``qa_response/qa_response_<id>.json``.

        The file lives under ``self.temp_dir`` and contains a ``metadata``
        section (counts, timestamp, pipeline configuration) followed by the
        raw result dicts. An answer counts as failed when it starts with
        ``"Error"``.

        Args:
            cluster_results: Result dicts belonging to the cluster.
            cluster_id: Identifier used in the file name and metadata.

        Returns:
            Path of the written JSON file.
        """
        response_dir = os.path.join(self.temp_dir, "qa_response")
        os.makedirs(response_dir, exist_ok=True)

        file_name = f"qa_response_{cluster_id}.json"
        output_path = os.path.join(response_dir, file_name)

        # Every entry is either failed ("Error" prefix) or successful.
        failed_count = sum(
            1 for entry in cluster_results
            if entry.get('answer', '').startswith('Error')
        )
        succeeded_count = len(cluster_results) - failed_count

        metadata = {
            'cluster_id': cluster_id,
            'total_questions': len(cluster_results),
            'successful_answers': succeeded_count,
            'failed_answers': failed_count,
            'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
            'pipeline_config': {
                'strategy': self.strategy,
                'model': self.model,
                'use_codebooks': self.use_codebooks,
                'max_codebooks_per_cluster': self.max_codebooks_per_cluster,
                'batch_size': self.batch_size,
            },
        }
        payload = {
            'metadata': metadata,
            'questions_and_answers': cluster_results,
        }

        with open(output_path, 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

        print(f"     💾 Cluster {cluster_id} results saved to: {file_name}")
        return output_path

    def save_retrieved_chunks(self, chunks: List[Dict[str, Any]], cluster_id: int, question_idx: int, max_chunks: int = 30) -> str:
        """Save the top retrieved chunks of one question into the cluster's file.

        All questions of a cluster share ``data_retrieved/data_cluster_<id>.json``
        under ``self.temp_dir``; each call adds or replaces the entry
        ``question_<question_idx>`` and refreshes the metadata totals.

        Args:
            chunks: Retrieved chunks for the question; must be JSON-serializable.
            cluster_id: Cluster the question belongs to (selects the file).
            question_idx: Index of the question (selects the entry).
            max_chunks: Maximum number of leading chunks stored (default 30).

        Returns:
            Path of the written JSON file.
        """
        data_retrieved_dir = os.path.join(self.temp_dir, "data_retrieved")
        os.makedirs(data_retrieved_dir, exist_ok=True)

        # Keep only the leading chunks (slicing already copes with short lists)
        top_chunks = chunks[:max_chunks]

        # File path for this cluster
        output_path = os.path.join(data_retrieved_dir, f"data_cluster_{cluster_id}.json")

        # Load existing data if the file exists and is readable
        cluster_data = None
        if os.path.exists(output_path):
            try:
                with open(output_path, 'r', encoding='utf-8') as f:
                    cluster_data = json.load(f)
            except json.JSONDecodeError as e:
                # A truncated file (e.g. from an interrupted run) should not
                # abort the current question; start over with a fresh structure.
                print(f"     ⚠️ Could not read existing data_cluster_{cluster_id}.json ({e}); recreating it")

        if cluster_data is None:
            cluster_data = {
                'metadata': {
                    'cluster_id': cluster_id,
                    'total_questions': 0,
                    'total_chunks_retrieved': 0,
                    'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                    'pipeline_config': {
                        'strategy': self.strategy,
                        'model': self.model
                    }
                },
                'questions': {}
            }

        # Add (or replace) this question's chunks in the cluster data
        stored_questions = cluster_data['questions']
        stored_questions[f'question_{question_idx}'] = {
            'question_index': question_idx,
            'total_chunks_retrieved': len(chunks),
            'top_chunks_saved': len(top_chunks),
            'chunks': top_chunks  # Each chunk should now have 'content', 'score', and 'chunk_id'
        }

        # Update metadata. The chunk total is recomputed from the stored entries
        # rather than incremented, so saving the same question again (e.g. on a
        # re-run) does not count its chunks twice.
        cluster_data['metadata']['total_questions'] = len(stored_questions)
        cluster_data['metadata']['total_chunks_retrieved'] = sum(
            entry.get('total_chunks_retrieved', 0) for entry in stored_questions.values()
        )

        # Save updated cluster data
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(cluster_data, f, indent=2, ensure_ascii=False)

        print(f"     📁 Top {len(top_chunks)} chunks for question {question_idx} added to: data_cluster_{cluster_id}.json")
        return output_path

    def save_results_to_file(self, batch_result: BatchResult, output_path: str = None) -> str:
        """Save batch results organized by cluster to qa_response folder"""
        qa_response_dir = os.path.join(self.temp_dir, "qa_response")
        os.makedirs(qa_response_dir, exist_ok=True)
        
        if output_path is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_path = os.path.join(qa_response_dir, f"batch_qa_summary_{timestamp}.json")
        
        # Organize results by cluster
        clusters_data = {}
        for result in batch_result.results:
            cluster_id = result['cluster_id']
            if cluster_id not in clusters_data:
                clusters_data[cluster_id] = []
            clusters_data[cluster_id].append(result)
        
        # Save individual cluster files
        cluster_files = []
        for cluster_id, cluster_results in clusters_data.items():
            cluster_file = self.save_cluster_results(cluster_results, cluster_id)
            cluster_files.append(f"qa_response_{cluster_id}.json")
        
        # Save summary file
        summary_data = {
            'metadata': {
                'total_questions': batch_result.total_questions,
                'total_clusters': batch_result.total_clusters,
                'successful_answers': batch_result.successful_answers,
                'failed_answers': batch_result.failed_answers,
                'total_time': batch_result.total_time,
                'average_time_per_question': batch_result.total_time / batch_result.total_questions,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'pipeline_config': {
                    'cluster_size': self.cluster_size,
                    'strategy': self.strategy,
                    'model': self.model,
                    'use_codebooks': self.use_codebooks,
                    'codebooks_file': self.codebooks_file,
                    'cluster_ids': self.cluster_ids,
                    'max_codebooks_per_cluster': self.max_codebooks_per_cluster,
                    'batch_size': self.batch_size
                }
            },
            'cluster_files': cluster_files,
            'clusters_summary': {
                str(cluster_id): {
                    'total_questions': len(results),
                    'successful_answers': len([r for r in results if not r.get('answer', '').startswith('Error')]),
                    'failed_answers': len([r for r in results if r.get('answer', '').startswith('Error')]),
                    'representative_question': next((r['question'] for r in results if r.get('is_representative')), 'Unknown')
                }
                for cluster_id, results in clusters_data.items()
            }
        }
        
        # Save summary file
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(summary_data, f, indent=2, ensure_ascii=False)
        
        print(f"💾 Batch summary saved to: {os.path.basename(output_path)}")
        print(f"📁 Individual cluster files saved in: qa_response/")
        for cluster_file in cluster_files:
            print(f"   - {cluster_file}")
        
        return output_path

def load_questions_from_file(filepath: str) -> List[str]:
    """Load questions from a CSV file or a plain text file.

    A path ending in ``.csv`` (case-insensitive) is parsed with the ``csv``
    module and the first column of every non-empty row is taken as the
    question. Any other file is read as text with one question per line.
    Surrounding whitespace (and, for CSV, surrounding double quotes) is
    stripped and empty entries are dropped.

    Files are decoded as ``utf-8-sig`` so that a UTF-8 byte-order mark, as
    written by some spreadsheet exports, does not end up in the first
    question; files without a BOM are read exactly as plain UTF-8.

    Args:
        filepath: Path to the questions file.

    Returns:
        The questions in file order.
    """
    if filepath.lower().endswith('.csv'):
        import csv

        questions = []
        # newline='' lets the csv module handle line endings itself,
        # including newlines embedded in quoted fields.
        with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:
            for row in csv.reader(f):
                if not row:  # Skip empty rows
                    continue
                # Remove quotes if present and strip whitespace
                question = row[0].strip().strip('"')
                if question:  # Skip empty questions
                    questions.append(question)
        return questions

    # Handle as text file (one question per line), skipping blank lines
    with open(filepath, 'r', encoding='utf-8-sig') as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]

def load_questions_from_string(questions_str: str) -> List[str]:
    """Split a comma-separated string into individual questions.

    The string is parsed with the ``csv`` module so that a question wrapped
    in double quotes may itself contain commas. ``skipinitialspace=True`` is
    required for that to work in the common ``Q1, "Q2, with comma"`` form:
    without it, the space after the delimiter makes the parser treat the
    opening quote as ordinary text and the quoted question is split apart.

    Args:
        questions_str: Comma-separated questions.

    Returns:
        The non-empty questions, stripped of surrounding whitespace and
        double quotes, in input order.
    """
    import csv
    from io import StringIO

    reader = csv.reader(StringIO(questions_str), skipinitialspace=True)
    questions = []
    for row in reader:
        for field in row:
            question = field.strip().strip('"')
            if question:
                questions.append(question)
    return questions

async def main():
    """Command-line entry point for the batch QA pipeline.

    Parses the CLI flags, loads the questions (from a file when
    ``--questions`` names an existing path, otherwise from the string itself),
    runs ``BatchQAPipeline.run_batch_pipeline`` and prints a summary followed
    by up to three sample answers. When ``--output`` is given, the aggregate
    result is also written as JSON.

    Returns:
        None. Invalid or missing input is reported on stdout and the
        coroutine simply returns.
    """
    usage_examples = """
Examples:
  # Process questions from file
  python batch_qa_pipeline.py --questions questions.txt --output results.json
  
  # Process questions from string
  python batch_qa_pipeline.py --questions "Q1,Q2,Q3" --cluster-size 5
  
  # With thinking mode
  python batch_qa_pipeline.py --questions questions.txt --thinking --output results.json
        """
    cli = argparse.ArgumentParser(
        description="Batch Question Answering Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # --- Input / output and core pipeline options ---
    cli.add_argument(
        "--questions",
        required=False,
        help="Questions file path (CSV or text) or comma-separated questions string (not required when using --use-codebooks)",
    )
    cli.add_argument("--output", help="Output JSON file path")
    cli.add_argument(
        "--cluster-size",
        type=int,
        default=5,
        help="Target number of questions per cluster (default: 5)",
    )
    cli.add_argument(
        "--strategy",
        choices=["fixed", "dynamic"],
        default="dynamic",
        help="Chunk strategy (default: dynamic)",
    )
    cli.add_argument(
        "--corpus-strategy",
        choices=["strategy_1", "strategy_2"],
        default="strategy_1",
        help="Corpus generation strategy (default: strategy_1)",
    )
    cli.add_argument("--thinking", action="store_true", help="Enable thinking mode")
    cli.add_argument(
        "--model",
        choices=["32B", "30B-A3B"],
        default="32B",
        help="Model to use for processing (default: 32B)",
    )

    # --- Codebook loading options ---
    cli.add_argument(
        "--use-codebooks",
        action="store_true",
        help="Load codebooks from JSON instead of running build_corpus",
    )
    cli.add_argument(
        "--codebooks-file",
        default="temp_files/generated_codebooks.json",
        help="Path to codebooks JSON file (default: temp_files/generated_codebooks.json)",
    )
    cli.add_argument(
        "--cluster-ids",
        type=int,
        nargs="+",
        metavar="ID",
        help="Specific cluster IDs to process (space-separated, e.g., --cluster-ids 0 1 2)",
    )

    # --- Large cluster optimization options ---
    cli.add_argument(
        "--max-codebooks-per-cluster",
        type=int,
        default=None,
        help="Maximum codebooks to use per cluster (for large clusters, e.g., --max-codebooks-per-cluster 10000)",
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=10000,
        help="Batch size for processing large datasets (default: 10000)",
    )

    opts = cli.parse_args()

    # Guard: a question source is mandatory unless codebooks supply the questions.
    if not (opts.use_codebooks or opts.questions):
        print("❌ Error: --questions is required when not using --use-codebooks")
        cli.print_help()
        return

    # Resolve the question list. An existing path wins over treating the
    # argument as a literal comma-separated string.
    question_list = []
    if opts.questions:
        if os.path.exists(opts.questions):
            question_list = load_questions_from_file(opts.questions)
            print(f"📄 Loaded {len(question_list)} questions from file: {opts.questions}")
        else:
            question_list = load_questions_from_string(opts.questions)
            print(f"📝 Loaded {len(question_list)} questions from string")
    else:
        # Only reachable with --use-codebooks (the guard above returned
        # otherwise); the list is left empty for the pipeline.
        print(f"📖 Using codebooks mode - questions will be extracted from {opts.codebooks_file}")

    if not (question_list or opts.use_codebooks):
        print("❌ No questions provided")
        return

    # Build the pipeline from the parsed options and run it.
    qa_pipeline = BatchQAPipeline(
        cluster_size=opts.cluster_size,
        strategy=opts.strategy,
        corpus_strategy=opts.corpus_strategy,
        use_thinking=opts.thinking,
        model=opts.model,
        use_codebooks=opts.use_codebooks,
        codebooks_file=opts.codebooks_file,
        cluster_ids=opts.cluster_ids,
        max_codebooks_per_cluster=opts.max_codebooks_per_cluster,
        batch_size=opts.batch_size,
    )
    result = await qa_pipeline.run_batch_pipeline(question_list)

    # Summary banner
    banner = "=" * 60
    print("\n" + banner)
    print("📊 BATCH QA PIPELINE RESULTS")
    print(banner)
    print(f"Total Questions: {result.total_questions}")
    print(f"Total Clusters: {result.total_clusters}")
    print(f"Successful Answers: {result.successful_answers}")
    print(f"Failed Answers: {result.failed_answers}")
    print(f"Total Processing Time: {result.total_time:.2f}s")
    if result.total_clusters > 0:
        print(f"Average Time per Cluster: {result.total_time/result.total_clusters:.2f}s")
    else:
        # Avoid dividing by zero when nothing was clustered.
        print("Average Time per Cluster: N/A")

    # Optional JSON dump of the aggregate result
    if opts.output:
        # Paths starting with "/" or "./" are used verbatim; anything else is
        # placed under TEMP_FILES_DIR.
        if opts.output.startswith(('/', './')):
            destination = opts.output
        else:
            destination = os.path.join(TEMP_FILES_DIR, opts.output)

        os.makedirs(os.path.dirname(destination), exist_ok=True)

        with open(destination, 'w', encoding='utf-8') as out_file:
            payload = {
                'total_questions': result.total_questions,
                'total_clusters': result.total_clusters,
                'successful_answers': result.successful_answers,
                'failed_answers': result.failed_answers,
                'total_time': result.total_time,
                'results': result.results,
            }
            json.dump(payload, out_file, indent=2, ensure_ascii=False)

        print(f"\n💾 Results saved to: {destination}")

    # Preview of the first three per-question results
    print("\n📋 SAMPLE ANSWERS:")
    print(banner)
    for position, entry in enumerate(result.results[:3], start=1):
        print(f"\nCluster {entry['cluster_id']} - Question{position}: {entry['question'][:80]}...")
        print(f"Response: {entry['answer'][:100]}...")
        print(f"Representative: {entry.get('is_representative', False)}")
        print(f"Context: {entry.get('context_chunks', 0)} chunks (score: {entry.get('context_score', 0):.3f})")

# Script entry point: when executed directly (not imported), run the async
# main() coroutine to completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main()) 