#!/usr/bin/env python3
"""
Codebook Generation Pipeline

This pipeline follows the batch_qa_pipeline pattern:
1. Clusters questions from eval_questions.csv into groups (default: 5 questions per cluster)
2. Picks the question closest to cluster centroid as representative
3. Runs code generation for ALL chunks for each representative question
4. Saves codebooks by cluster number with question-to-cluster mapping

This processes all chunks for representative questions while tracking cluster assignments.

Usage:
  python codebook_generation_pipeline.py --output codebooks_by_cluster.json
  python codebook_generation_pipeline.py --chunk-size 256 --overlap 50 --cluster-size 5
"""

import os
import sys
import json
import asyncio
import argparse
import time
import numpy as np
import pandas as pd
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime
import shutil
import gc
import psutil
import atexit
from datetime import datetime

# Add current directory to path for imports
sys.path.append(os.path.dirname(__file__))

from utils.initial_iteration.embeddings import build_embeddings_parquet
from utils.initial_iteration.build_corpus import find_first_valid_data_file, build_text_chunks, async_main

# Environment variables
# Embedding-service settings read from the process environment; each is None
# when the corresponding variable is unset.
# NOTE(review): neither constant is referenced in the part of the file visible
# here — presumably the embedding helpers read the same variables; confirm.
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL")
DEFAULT_EMBEDDING_MODEL = os.getenv("DEFAULT_EMBEDDING_MODEL")
# Path of the "temp_files" directory located next to this module.
TEMP_FILES_DIR = os.path.join(os.path.dirname(__file__), "temp_files")

@dataclass
class QuestionCluster:
    """A group of questions plus the single question chosen to represent the group."""
    # Cluster label (the K-means label; 0 for the single-cluster fallback).
    cluster_id: int
    # Text of the member question that stands in for the whole cluster.
    representative_question: str
    # Position of the representative question in the full question list.
    representative_index: int
    # Texts of every question assigned to this cluster.
    all_questions: List[str]
    # Positions of those questions in the full question list (parallel to all_questions).
    all_indices: List[int]
    # Embedding rows of the member questions (placeholder zeros in the fallback path).
    embeddings: np.ndarray
    # Cluster centroid the representative was measured against.
    cluster_center: np.ndarray

@dataclass
class ClusterCodebook:
    """Codebook entries generated for one cluster's representative question.

    ``representative_codebooks`` holds one dict per tagged corpus row (keys
    ``tag``, ``chunk_index``, ``data_chunk_index`` and ``level``), not plain
    strings; it is empty when generation failed for the cluster.
    """
    # Cluster label this record belongs to.
    cluster_id: int
    # Number of questions assigned to the cluster.
    cluster_size: int
    # Question the codebooks were generated for.
    representative_question: str
    # Codebook entry dicts produced for the representative question.
    representative_codebooks: List[Dict[str, Any]]
    # Every question in the cluster, including the representative.
    all_questions: List[str]

@dataclass
class CodebookResult:
    """Aggregate outcome of a codebook generation run.

    NOTE(review): nothing in the visible part of this file constructs this
    class; the field descriptions below are inferred from the names — confirm
    against the actual producer before relying on them.
    """
    # Presumably the number of questions that were clustered.
    total_questions: int
    # Presumably the number of clusters that were formed.
    total_clusters: int
    # Presumably wall-clock duration of the run (units not shown here).
    processing_time: float
    # Per-cluster codebook records keyed by cluster id.
    cluster_codebooks: Dict[int, ClusterCodebook]
    # Free-form summary data.
    summary: Dict[str, Any]

class CodebookGenerationPipeline:
    """Main pipeline for codebook generation following batch_qa pattern with questions"""
    
    def __init__(self, chunk_size: int = 256, overlap: int = 50, 
                 cluster_size: int = 5, strategy: str = "strategy_1", model: str = "32B",
                 temp_dir: str = "temp_files", output_file: str = "temp_files/generated_codebooks.json"):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.cluster_size = cluster_size
        self.strategy = strategy
        self.model = model
        self.temp_dir = temp_dir
        self.output_file = output_file
        os.makedirs(self.temp_dir, exist_ok=True)
        
        # Register cleanup function
        atexit.register(self._cleanup)
    
    def _cleanup(self):
        """Run a garbage-collection pass and print a completion message.

        Registered with ``atexit`` in ``__init__``, so it is invoked at
        interpreter shutdown. It does not delete any files or directories.
        """
        gc.collect()
        print("🧹 CodebookGenerationPipeline cleanup completed")
    
    def _get_memory_usage(self):
        """Return the current process's memory figures.

        Returns:
            Dict with ``rss`` and ``vms`` (both converted from bytes to MB)
            and ``percent`` as reported by ``psutil``.
        """
        proc = psutil.Process()
        snapshot = proc.memory_info()
        rss_mb = snapshot.rss / 1024 / 1024
        vms_mb = snapshot.vms / 1024 / 1024
        return dict(rss=rss_mb, vms=vms_mb, percent=proc.memory_percent())
    
    def _log_memory_usage(self, stage: str):
        """Log memory usage at different stages"""
        memory = self._get_memory_usage()
        print(f"💾 Memory usage at {stage}: {memory['rss']:.1f}MB RSS, {memory['percent']:.1f}%")
    
    def load_questions_from_csv(self, csv_path: str) -> List[str]:
        """Load questions from the first column of a CSV file.

        Blank rows and rows whose first column is empty after trimming are
        skipped. Surrounding whitespace and leftover double quotes are removed
        from each question. Every row is treated as data (no header handling).

        Args:
            csv_path: Path of the CSV file to read.

        Returns:
            The questions in file order.

        Raises:
            FileNotFoundError: If ``csv_path`` does not exist.
        """
        import csv  # local import: the module header does not import csv

        print(f"📖 Loading questions from {csv_path}...")

        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Questions file not found: {csv_path}")

        questions = []
        # newline='' lets the csv module handle line endings inside quoted fields
        # itself; utf-8-sig drops a leading BOM (common in spreadsheet exports)
        # that would otherwise be glued onto the first question and stop the
        # reader from recognising that row's opening quote.
        with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
            for row in csv.reader(f):
                if not row:  # Skip empty rows
                    continue
                question = row[0].strip().strip('"')
                if question:  # Skip empty questions
                    questions.append(question)

        print(f"   ✅ Loaded {len(questions)} questions")
        return questions
    
    async def embed_questions(self, questions: List[str]) -> np.ndarray:
        """Embed every question and return the embeddings.

        The questions are wrapped in a DataFrame with ``text`` and ``id``
        columns and handed to ``build_embeddings_parquet``, which is asked to
        write to ``question_embeddings.parquet`` inside ``self.temp_dir``.
        Memory usage is logged before and after the call.

        Args:
            questions: Question texts to embed.

        Returns:
            The second element of the tuple returned by
            ``build_embeddings_parquet``.
        """
        self._log_memory_usage("before embedding")
        print(f"🧠 Embedding {len(questions)} questions...")

        # One row per question; ids are simply the list positions.
        corpus = pd.DataFrame({'text': questions, 'id': range(len(questions))})
        parquet_target = os.path.join(self.temp_dir, "question_embeddings.parquet")

        _, vectors, _, _ = await build_embeddings_parquet(
            corpus_df=corpus,
            output_parquet=parquet_target,
        )

        # Drop the frame right away so it does not linger during later steps.
        del corpus
        gc.collect()

        self._log_memory_usage("after embedding")
        print(f"✅ Embedded {len(questions)} questions")
        return vectors
    
    def cluster_questions(self, questions: List[str], embeddings: np.ndarray) -> List[QuestionCluster]:
        """Cluster questions with K-means and pick one representative per cluster.

        The number of clusters is ``len(questions) // self.cluster_size`` (at
        least 1). The representative of a cluster is the member whose embedding
        has the smallest Euclidean distance to the cluster centroid.

        K-means can leave a label without members (for example when many
        embeddings are identical); such clusters are skipped instead of
        crashing, so the returned cluster ids may have gaps.

        Args:
            questions: Question texts, aligned row-by-row with ``embeddings``.
            embeddings: One embedding row per question.

        Returns:
            One ``QuestionCluster`` per non-empty cluster, ordered by cluster id.

        Raises:
            ValueError: If ``questions`` is empty.
        """
        from sklearn.cluster import KMeans

        n_questions = len(questions)
        if n_questions == 0:
            raise ValueError("Cannot cluster an empty list of questions")

        n_clusters = max(1, n_questions // self.cluster_size)

        print(f"🎯 Clustering {n_questions} questions into {n_clusters} clusters (size ~{self.cluster_size})...")

        # Perform K-means clustering (fixed seed keeps the assignment reproducible)
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)

        clusters = []
        for cluster_id in range(n_clusters):
            # Positions of the questions carrying this label, as plain Python ints
            cluster_indices = np.flatnonzero(cluster_labels == cluster_id).tolist()
            if not cluster_indices:
                # argmin over an empty distance array would raise; nothing to represent here
                print(f"   ⚠️  Cluster {cluster_id} is empty, skipping")
                continue

            cluster_questions = [questions[i] for i in cluster_indices]
            cluster_embeddings = embeddings[cluster_indices]

            # Find representative (closest to cluster center)
            cluster_center = kmeans.cluster_centers_[cluster_id]
            distances = np.linalg.norm(cluster_embeddings - cluster_center, axis=1)
            representative_idx = int(np.argmin(distances))
            representative_question = cluster_questions[representative_idx]

            clusters.append(QuestionCluster(
                cluster_id=cluster_id,
                representative_question=representative_question,
                representative_index=cluster_indices[representative_idx],
                all_questions=cluster_questions,
                all_indices=cluster_indices,
                embeddings=cluster_embeddings,
                cluster_center=cluster_center
            ))

            print(f"   📦 Cluster {cluster_id}: {len(cluster_questions)} questions")
            print(f"      Representative: {representative_question[:50]}...")

        return clusters
    
    async def _generate_codebooks_for_question(self, question: str) -> Dict[str, Any]:
        """Generate codebook entries for one representative question.

        Builds an argument object for ``async_main`` (from ``build_corpus``),
        awaits it, and converts the tagged rows of the returned DataFrame into
        codebook entries. A per-question scratch directory under
        ``self.temp_dir`` is created before the call and removed afterwards,
        whether or not the call succeeded.

        Args:
            question: The representative question to generate codebooks for.

        Returns:
            ``{'codebooks': [...], 'chunks': [...]}`` where each codebook entry's
            ``chunk_index`` points into ``chunks``. Both lists are empty when the
            corpus has no ``tag`` column.

        Raises:
            Exception: Any error raised while building or reading the corpus is
                logged and re-raised unchanged.
        """
        # Create temporary directory for this question
        question_temp_dir = os.path.join(self.temp_dir, f"question_{hash(question) % 10000}")
        os.makedirs(question_temp_dir, exist_ok=True)

        try:
            # Capture pipeline attributes so the nested class body can see them
            chunk_size = self.chunk_size
            overlap = self.overlap
            strategy = self.strategy
            # Use "32B" for build_corpus which will use VLLM_QWEN_32B_URL and VLLM_QWEN_32B_MODEL
            # NOTE(review): self.model is not consulted here; the label is fixed on purpose
            # according to the original comment — confirm this is still intended.
            llm_model = "32B"

            class MockArgs:
                """Stand-in for the argparse namespace that async_main expects."""

                def __init__(self, q):
                    self.question = q
                    self.input = None  # Use default data
                    self.chunk_size = chunk_size
                    self.overlap = overlap
                    self.concurrency = 64  # Increased to 64 for higher throughput
                    self.strategy = strategy
                    self.model = llm_model  # Use LLM model for codebook generation
                    self.test_chunk_sizes = False
                    self.chunk_sizes = [chunk_size]
                    self.overlap_ratio = 0.2
                    self.seed_low = 5
                    self.seed_high = 10
                    self.seed = 42
                    self.out_corpus = os.path.join(question_temp_dir, "corpus.parquet")
                    self.test_concurrency = False
                    self.max_test_concurrency = 128

            # Build corpus for this representative question
            print(f"      🔄 Building corpus for question...")
            corpus_df = await async_main(MockArgs(question))

            if 'tag' not in corpus_df.columns:
                print(f"      ⚠️  No 'tag' column found in corpus, returning empty codebooks")
                return {'codebooks': [], 'chunks': []}

            codebooks, chunks_list = self._extract_codebook_entries(corpus_df)
            print(f"      📝 Generated {len(codebooks)} codebooks from {len(chunks_list)} unique chunks (total {len(corpus_df)} processed)")
            return {'codebooks': codebooks, 'chunks': chunks_list}

        except Exception as e:
            print(f"      ❌ Error in codebook generation: {str(e)}")
            raise  # bare raise keeps the original traceback intact

        finally:
            # Best-effort cleanup: a failure to delete scratch files must not
            # replace the real error (or a successful result) with a cleanup error.
            shutil.rmtree(question_temp_dir, ignore_errors=True)

    @staticmethod
    def _extract_codebook_entries(corpus_df: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Turn tagged corpus rows into codebook entries plus a de-duplicated chunk list."""
        codebooks = []
        chunks_list = []  # Unique chunk texts, in first-seen order
        chunk_to_index = {}  # Chunk text -> its position in chunks_list

        for _, row in corpus_df.iterrows():
            tag = row['tag']
            # Skip empty/null tags; str() keeps a stray non-string tag from raising on .strip()
            if pd.isna(tag) or not str(tag).strip():
                continue

            chunk_text = row.get('chunk_text', '')
            if chunk_text not in chunk_to_index:
                chunk_to_index[chunk_text] = len(chunks_list)
                chunks_list.append(chunk_text)

            # A missing or NaN source index falls back to 0 instead of failing in int()
            raw_chunk_index = row.get('chunk_index', 0)
            data_chunk_index = 0 if pd.isna(raw_chunk_index) else int(raw_chunk_index)

            codebooks.append({
                'tag': tag,
                'chunk_index': chunk_to_index[chunk_text],  # Reference to chunks_list
                'data_chunk_index': data_chunk_index,  # Original chunk index from data file
                'level': row.get('level', '')
            })

        return codebooks, chunks_list

    async def generate_codebooks_for_representative_questions(self, clusters: List[QuestionCluster]) -> Dict[str, Any]:
        """Generate codebooks for representative questions from each cluster"""
        print(f"🔧 Generating codebooks for {len(clusters)} representative questions...")
        
        cluster_codebooks = {}
        failed_clusters = []  # Track failed clusters for retry
        global_chunks = []  # Global list of all unique chunks across clusters
        global_chunk_to_index = {}  # Global mapping of chunk text to index
        
        for i, cluster in enumerate(clusters):
            cluster_id = cluster.cluster_id
            representative_question = cluster.representative_question
            
            print(f"   🔧 Processing cluster {cluster_id} ({i+1}/{len(clusters)}): {representative_question[:50]}...")
            
            try:
                # Generate codebooks for this representative question
                # This will process ALL chunks from the data directory
                codebooks_result = await self._generate_codebooks_for_question(representative_question)
                
                # Update global chunks list and remap indices
                cluster_codebooks_remapped = []
                for codebook in codebooks_result['codebooks']:
                    chunk_text = codebooks_result['chunks'][codebook['chunk_index']]
                    
                    # Add to global chunks if not already present
                    if chunk_text not in global_chunk_to_index:
                        global_chunk_to_index[chunk_text] = len(global_chunks)
                        global_chunks.append(chunk_text)
                    
                    # Remap to global index
                    codebook_remapped = codebook.copy()
                    codebook_remapped['chunk_index'] = global_chunk_to_index[chunk_text]
                    cluster_codebooks_remapped.append(codebook_remapped)
                
                cluster_codebooks[cluster_id] = cluster_codebooks_remapped
                print(f"   ✅ Cluster {cluster_id}: Generated {len(cluster_codebooks_remapped)} codebooks from {len(codebooks_result['chunks'])} chunks")
                
                # On first successful cluster, populate global_chunks since all clusters process the same chunks
                if len(global_chunks) == 0:  # First successful cluster
                    global_chunks.extend(codebooks_result['chunks'])
                    print(f"   📦 Populated global chunks list with {len(global_chunks)} unique chunks")
                    # Update the JSON file with global chunks
                    self._update_global_chunks_in_file(global_chunks)
                
                # Save this cluster incrementally
                self.save_cluster_incrementally(cluster, cluster_codebooks_remapped, status="success")
                
            except Exception as e:
                print(f"   ⚠️  Cluster {cluster_id} failed: {str(e)}")
                print(f"   💾 Saving cluster {cluster_id} with empty codebooks for retry later...")
                # Save failed cluster with empty codebooks
                cluster_codebooks[cluster_id] = []
                failed_clusters.append(cluster)
                
                # Save failed cluster incrementally
                self.save_cluster_incrementally(cluster, [], status="failed")
                continue
        
        # Retry failed clusters (up to 3 attempts)
        if failed_clusters:
            print(f"\n🔄 Retrying {len(failed_clusters)} failed clusters (up to 3 attempts)...")
            retry_successful = []
            max_retries = 3
            
            for attempt in range(max_retries):
                if not failed_clusters:
                    break
                    
                print(f"   🔄 Retry attempt {attempt + 1}/{max_retries} for {len(failed_clusters)} clusters...")
                current_failed = failed_clusters.copy()
                failed_clusters = []  # Reset for next attempt
                
                for i, cluster in enumerate(current_failed):
                    cluster_id = cluster.cluster_id
                    representative_question = cluster.representative_question
                    
                    print(f"      🔄 Retry {attempt + 1}.{i+1}: Cluster {cluster_id} - {representative_question[:50]}...")
                    
                    try:
                        # Wait a bit before retry to avoid rate limiting (longer wait for later attempts)
                        wait_time = 2 + (attempt * 3)  # 2s, 5s, 8s
                        await asyncio.sleep(wait_time)
                        
                        # Generate codebooks for this representative question
                        codebooks_result = await self._generate_codebooks_for_question(representative_question)
                        
                        # Update global chunks list and remap indices
                        cluster_codebooks_remapped = []
                        for codebook in codebooks_result['codebooks']:
                            chunk_text = codebooks_result['chunks'][codebook['chunk_index']]
                            
                            # Add to global chunks if not already present
                            if chunk_text not in global_chunk_to_index:
                                global_chunk_to_index[chunk_text] = len(global_chunks)
                                global_chunks.append(chunk_text)
                            
                            # Remap to global index
                            codebook_remapped = codebook.copy()
                            codebook_remapped['chunk_index'] = global_chunk_to_index[chunk_text]
                            cluster_codebooks_remapped.append(codebook_remapped)
                        
                        # Update only the representative_codebooks for this specific cluster
                        # All other cluster info (representative_question, all_questions, etc.) is preserved
                        cluster_codebooks[cluster_id] = cluster_codebooks_remapped
                        retry_successful.append(cluster_id)
                        print(f"      ✅ Retry {attempt + 1} successful for cluster {cluster_id}: Generated {len(cluster_codebooks_remapped)} codebooks from {len(codebooks_result['chunks'])} chunks")
                        
                        # Save successful retry incrementally
                        self.save_cluster_incrementally(cluster, cluster_codebooks_remapped, status=f"success_retry_{attempt + 1}")
                        
                    except Exception as e:
                        print(f"      ❌ Retry {attempt + 1} failed for cluster {cluster_id}: {str(e)}")
                        # Add back to failed list for next attempt (unless this was the last attempt)
                        if attempt < max_retries - 1:
                            failed_clusters.append(cluster)
                        else:
                            print(f"      💾 Final failure for cluster {cluster_id} - keeping empty codebooks")
                            # Keep empty codebooks for this cluster (already set to [] in first pass)
                            # Failed status was already saved in the initial pass
                
                if failed_clusters:
                    print(f"   ⚠️  {len(failed_clusters)} clusters still failed after attempt {attempt + 1}")
                else:
                    print(f"   🎉 All clusters succeeded after attempt {attempt + 1}")
                    break
            
            if retry_successful:
                print(f"   🎉 Retry summary: {len(retry_successful)} clusters recovered across {max_retries} attempts")
            if failed_clusters:
                print(f"   ⚠️  {len(failed_clusters)} clusters failed after all {max_retries} attempts - will have empty codebooks")
        
        return {
            'cluster_codebooks': cluster_codebooks,
            'global_chunks': global_chunks
        }
    
    def organize_cluster_results(self, clusters: List[QuestionCluster], 
                               cluster_codebooks: Dict[int, List[Dict[str, Any]]]) -> Dict[int, ClusterCodebook]:
        """Pair every cluster with its generated codebooks.

        Args:
            clusters: The question clusters.
            cluster_codebooks: Codebook entries keyed by cluster id; a cluster
                without an entry gets an empty list.

        Returns:
            One ``ClusterCodebook`` per cluster, keyed by cluster id.
        """
        print("📦 Organizing cluster results...")

        organized = {}
        for item in clusters:
            key = item.cluster_id
            members = item.all_questions
            entries = cluster_codebooks.get(key, [])

            organized[key] = ClusterCodebook(
                cluster_id=key,
                cluster_size=len(members),
                representative_question=item.representative_question,
                representative_codebooks=entries,
                all_questions=members,
            )

            print(f"   📦 Cluster {key}: {len(members)} questions, {len(entries)} codebooks")

        return organized
    
    async def run_codebook_pipeline(self) -> Dict[str, Any]:
        """Run the complete codebook generation pipeline"""
        print(f"🚀 Starting Codebook Generation Pipeline")
        print(f"   📊 Strategy: {self.strategy}")
        print(f"   🤖 Model: {self.model}")
        print(f"   📝 Chunk size: {self.chunk_size}")
        print(f"   🔗 Overlap: {self.overlap}")
        print(f"   📁 Temp directory: {self.temp_dir}")
        
        # Initialize variables
        questions = []
        question_embeddings = None
        clusters = []
        cluster_codebooks = {}
        cluster_results = {}
        
        try:
            # Step 1: Load questions from CSV
            print(f"\n📋 Step 1: Loading questions from eval_questions.csv...")
            try:
                questions = self.load_questions_from_csv("eval_questions.csv")
                print(f"   ✅ Loaded {len(questions)} questions")
            except Exception as e:
                print(f"   ❌ Failed to load questions: {str(e)}")
                return {"error": f"Question loading failed: {str(e)}"}
            
            # Step 2: Generate embeddings for questions
            print(f"\n🔍 Step 2: Generating embeddings for questions...")
            try:
                question_embeddings = await self.embed_questions(questions)
                print(f"   ✅ Generated embeddings for {len(question_embeddings)} questions")
            except Exception as e:
                print(f"   ❌ Failed to generate embeddings: {str(e)}")
                print(f"   ⚠️  Continuing with empty embeddings - all clusters will fail")
                question_embeddings = np.zeros((len(questions), 768))  # Default embedding size
            
            # Step 3: Cluster questions
            print(f"\n🎯 Step 3: Clustering questions...")
            try:
                clusters = self.cluster_questions(questions, question_embeddings)
                print(f"   ✅ Created {len(clusters)} clusters")
            except Exception as e:
                print(f"   ❌ Failed to cluster questions: {str(e)}")
                print(f"   ⚠️  Creating single cluster with all questions")
                # Create a single cluster with all questions as fallback
                # Create dummy embeddings and cluster center for fallback
                dummy_embeddings = np.zeros((len(questions), 768))  # Default embedding size
                dummy_cluster_center = np.zeros(768)
                
                clusters = [QuestionCluster(
                    cluster_id=0,
                    representative_question=questions[0] if questions else "No question",
                    representative_index=0,
                    all_questions=questions,
                    all_indices=list(range(len(questions))),
                    embeddings=dummy_embeddings,
                    cluster_center=dummy_cluster_center
                )]
            
            # Step 4: Generate codebooks for representative questions
            print(f"\n🔧 Step 4: Generating codebooks for representative questions...")
            try:
                result = await self.generate_codebooks_for_representative_questions(clusters)
                cluster_codebooks = result['cluster_codebooks']
                global_chunks = result['global_chunks']
                print(f"   ✅ Generated codebooks for {len(cluster_codebooks)} clusters")
            except Exception as e:
                print(f"   ❌ Failed to generate codebooks: {str(e)}")
                print(f"   ⚠️  Creating empty codebooks for all clusters")
                # Create empty codebooks for all clusters
                cluster_codebooks = {cluster.cluster_id: [] for cluster in clusters}
                global_chunks = [] # Ensure global_chunks is empty if generation fails
            
            # Step 5: Organize results
            print(f"\n📦 Step 5: Organizing results...")
            try:
                cluster_results = self.organize_cluster_results(clusters, cluster_codebooks)
                print(f"   ✅ Organized results for {len(cluster_results)} clusters")
            except Exception as e:
                print(f"   ❌ Failed to organize results: {str(e)}")
                print(f"   ⚠️  Creating basic cluster results")
                # Create basic cluster results
                cluster_results = {}
                for cluster in clusters:
                    cluster_results[cluster.cluster_id] = ClusterCodebook(
                        cluster_id=cluster.cluster_id,
                        cluster_size=len(cluster.all_questions),
                        representative_question=cluster.representative_question,
                        representative_codebooks=cluster_codebooks.get(cluster.cluster_id, []),
                        all_questions=cluster.all_questions
                    )
            
            # Step 6: Save results
            print(f"\n💾 Step 6: Saving results...")
            try:
                self.save_codebooks_by_cluster(cluster_results, global_chunks)
                print(f"   ✅ Results saved to {self.output_file}")
            except Exception as e:
                print(f"   ❌ Failed to save results: {str(e)}")
                print(f"   ⚠️  Pipeline completed but results could not be saved")
            
            # Calculate final statistics
            total_codebooks = sum(len(codebooks) for codebooks in cluster_codebooks.values())
            total_questions = sum(len(cluster.all_questions) for cluster in clusters)
            failed_clusters_count = sum(1 for codebooks in cluster_codebooks.values() if len(codebooks) == 0)
            successful_clusters_count = len(clusters) - failed_clusters_count
            
            print(f"\n🎉 Pipeline completed successfully!")
            print(f"   📊 Total clusters: {len(clusters)}")
            print(f"   ✅ Successful clusters: {successful_clusters_count}")
            print(f"   ❌ Failed clusters: {failed_clusters_count}")
            print(f"   📝 Total codebooks generated: {total_codebooks}")
            print(f"   ❓ Total questions processed: {total_questions}")
            print(f"   📁 Results saved to: {self.output_file}")
            
            if failed_clusters_count > 0:
                print(f"   ⚠️  Note: {failed_clusters_count} clusters have empty codebooks due to processing errors")
            
            return {
                "success": True,
                "total_clusters": len(clusters),
                "successful_clusters": successful_clusters_count,
                "failed_clusters": failed_clusters_count,
                "total_codebooks": total_codebooks,
                "total_questions": total_questions,
                "output_file": self.output_file
            }
            
        except Exception as e:
            print(f"\n💥 Pipeline failed with critical error: {str(e)}")
            print(f"   ⚠️  This should not happen - all steps have error handling")
            return {"error": f"Critical pipeline failure: {str(e)}"}

    def save_cluster_incrementally(self, cluster: QuestionCluster, codebooks: List[Dict[str, Any]], status: str = "success"):
        """
        Save a single cluster's results to separate file in generated_code_books folder
        Each cluster gets its own JSON file: generated_code_books/cluster_X.json
        """
        # Create cluster result
        cluster_result = ClusterCodebook(
            cluster_id=cluster.cluster_id,
            cluster_size=len(cluster.all_questions),
            representative_question=cluster.representative_question,
            representative_codebooks=codebooks,
            all_questions=cluster.all_questions
        )
        
        # Create generated_code_books directory
        codebooks_dir = os.path.join(self.temp_dir, "generated_code_books")
        os.makedirs(codebooks_dir, exist_ok=True)
        
        # Individual cluster file
        cluster_file = os.path.join(codebooks_dir, f"cluster_{cluster.cluster_id}.json")
        
        # Create cluster data structure
        cluster_data = {
            "metadata": {
                "cluster_id": cluster_result.cluster_id,
                "cluster_size": cluster_result.cluster_size,
    
                "codebooks_count": len(cluster_result.representative_codebooks),
                "status": status,
                "created": datetime.now().isoformat(),
                "last_updated": datetime.now().isoformat()
            },
            "cluster": {
                "cluster_id": cluster_result.cluster_id,
                "cluster_size": cluster_result.cluster_size,
                "representative_question": cluster_result.representative_question,
                "representative_codebooks": cluster_result.representative_codebooks,
                "all_questions": cluster_result.all_questions,
                "status": status
            }
        }
        
        # Save individual cluster file
        with open(cluster_file, 'w') as f:
            json.dump(cluster_data, f, indent=2, ensure_ascii=False)
        
        # Also maintain the combined file for compatibility
        self._update_combined_file(cluster, codebooks, status)
        
        print(f"   💾 Saved cluster {cluster.cluster_id} to {cluster_file} ({status})")

    def _update_combined_file(self, cluster: QuestionCluster, codebooks: List[Dict[str, Any]], status: str = "success"):
        """Update the combined file for compatibility"""
        # Load existing results if file exists
        try:
            if os.path.exists(self.output_file):
                with open(self.output_file, 'r') as f:
                    existing_data = json.load(f)
            else:
                existing_data = {
                    "metadata": {
                        "total_clusters": 0,
                        "successful_clusters": 0,
                        "failed_clusters": 0,
                        "total_codebooks_generated": 0,
                        "total_chunks": 0,
                        "last_updated": datetime.now().isoformat(),
                        "format": "codebooks_by_cluster_with_chunk_index"
                    },
                    "global_chunks": [],
                    "clusters": {}
                }
        except (json.JSONDecodeError, FileNotFoundError):
            existing_data = {
                "metadata": {
                    "total_clusters": 0,
                    "successful_clusters": 0,
                    "failed_clusters": 0,
                    "total_codebooks_generated": 0,
                    "total_chunks": 0,
                    "last_updated": datetime.now().isoformat(),
                    "format": "codebooks_by_cluster_with_chunk_index"
                },
                "global_chunks": [],
                "clusters": {}
            }
        
        # Update cluster data
        existing_data["clusters"][str(cluster.cluster_id)] = {
            "cluster_id": cluster.cluster_id,
            "cluster_size": len(cluster.all_questions),
            "representative_question": cluster.representative_question,
            "representative_codebooks": codebooks,
            "all_questions": cluster.all_questions,
            "status": status,
            "last_updated": datetime.now().isoformat()
        }
        
        # Update metadata
        existing_data["metadata"]["total_clusters"] = len(existing_data["clusters"])
        existing_data["metadata"]["successful_clusters"] = sum(1 for c in existing_data["clusters"].values() if c.get("status") == "success")
        existing_data["metadata"]["failed_clusters"] = sum(1 for c in existing_data["clusters"].values() if c.get("status") == "failed")
        existing_data["metadata"]["total_codebooks_generated"] = sum(len(c.get("representative_codebooks", [])) for c in existing_data["clusters"].values())
        existing_data["metadata"]["last_updated"] = datetime.now().isoformat()
        
        # Save combined file
        os.makedirs(os.path.dirname(self.output_file), exist_ok=True)
        with open(self.output_file, 'w') as f:
            json.dump(existing_data, f, indent=2, ensure_ascii=False)

    def save_codebooks_by_cluster(self, cluster_codebooks: Dict[int, ClusterCodebook], global_chunks: List[str]):
        """Save codebooks organized by cluster number"""
        print(f"💾 Saving codebooks by cluster to: {self.output_file}")
        
        # Calculate statistics
        total_clusters = len(cluster_codebooks)
        failed_clusters = sum(1 for cluster_data in cluster_codebooks.values() if len(cluster_data.representative_codebooks) == 0)
        successful_clusters = total_clusters - failed_clusters
        total_codebooks = sum(len(cluster_data.representative_codebooks) for cluster_data in cluster_codebooks.values())
        
        # Convert to serializable format
        output_data = {
            'metadata': {
                'total_clusters': total_clusters,
                'successful_clusters': successful_clusters,
                'failed_clusters': failed_clusters,
                'total_codebooks_generated': total_codebooks,
                'total_chunks': len(global_chunks),
                'timestamp': datetime.now().isoformat(),
                'format': 'codebooks_by_cluster_with_chunk_index'
            },
            'global_chunks': global_chunks,  # Separate list of all unique chunks
            'clusters': {}
        }
        
        for cluster_id, cluster_data in cluster_codebooks.items():
            output_data['clusters'][str(cluster_id)] = {
                'cluster_id': cluster_data.cluster_id,
                'cluster_size': cluster_data.cluster_size,
                'representative_question': cluster_data.representative_question,
                'representative_codebooks': cluster_data.representative_codebooks,
                'all_questions': cluster_data.all_questions,
                'num_codebooks': len(cluster_data.representative_codebooks),
                'status': 'success' if len(cluster_data.representative_codebooks) > 0 else 'failed'
            }
        
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(self.output_file), exist_ok=True)
        
        # Save to file
        with open(self.output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        
        print(f"✅ Saved {total_clusters} clusters to {self.output_file}")
        print(f"   ✅ Successful: {successful_clusters}, ❌ Failed: {failed_clusters}, 📝 Total codebooks: {total_codebooks}, 📦 Total chunks: {len(global_chunks)}")

    def _update_global_chunks_in_file(self, global_chunks: List[str]):
        """Updates the global_chunks list in the output JSON file."""
        try:
            with open(self.output_file, 'r') as f:
                existing_data = json.load(f)
        except (json.JSONDecodeError, FileNotFoundError):
            print(f"⚠️ Could not load existing data from {self.output_file} to update global_chunks.")
            return

        existing_data["global_chunks"] = global_chunks
        existing_data["metadata"]["total_chunks"] = len(global_chunks)
        existing_data["metadata"]["last_updated"] = datetime.now().isoformat()

        os.makedirs(os.path.dirname(self.output_file), exist_ok=True)
        with open(self.output_file, 'w') as f:
            json.dump(existing_data, f, indent=2, ensure_ascii=False)
        print(f"   📦 Updated global_chunks in {self.output_file} to {len(global_chunks)} unique chunks.")

async def main():
    """Parse CLI arguments, run the codebook pipeline, save and print the results.

    Builds a ``CodebookGenerationPipeline`` from the chunking, clustering,
    strategy and model options, awaits ``run_codebook_pipeline()``, writes the
    combined output via ``save_codebooks_by_cluster`` and prints a summary
    followed by per-cluster details.
    """
    parser = argparse.ArgumentParser(
        description="Codebook Generation Pipeline (Batch QA Pattern with Questions)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate codebooks with default settings (5 questions per cluster)
  python codebook_generation_pipeline.py --output codebooks_by_cluster.json
  
  # Custom cluster size
  python codebook_generation_pipeline.py --cluster-size 10 --output codebooks.json
  
  # Custom chunk size and overlap
  python codebook_generation_pipeline.py --chunk-size 512 --overlap 100 --output codebooks.json
  
  # Different strategy
  python codebook_generation_pipeline.py --strategy strategy_2 --output codebooks.json
        """
    )
    
    parser.add_argument("--output", required=True, 
                       help="Output JSON file path for codebooks by cluster")
    parser.add_argument("--chunk-size", type=int, default=256,
                       help="Chunk size for text processing (default: 256)")
    parser.add_argument("--overlap", type=int, default=50,
                       help="Overlap between chunks (default: 50)")
    parser.add_argument("--cluster-size", type=int, default=5,
                       help="Number of questions per cluster (default: 5)")
    parser.add_argument("--strategy", choices=["strategy_1", "strategy_2"], default="strategy_1",
                       help="Corpus generation strategy (default: strategy_1)")
    parser.add_argument("--model", choices=["32B", "30B-A3B"], default="32B",
                       help="Model to use for processing (default: 32B)")
    
    args = parser.parse_args()
    
    # Create and run pipeline
    # NOTE(review): --output is parsed but never handed to the pipeline; the
    # constructor is not visible here, so confirm how output_file is meant to be set.
    pipeline = CodebookGenerationPipeline(
        chunk_size=args.chunk_size,
        overlap=args.overlap,
        cluster_size=args.cluster_size,
        strategy=args.strategy,
        model=args.model
    )
    
    result = await pipeline.run_codebook_pipeline()
    
    # save_codebooks_by_cluster needs the chunk list as well. Take it from the
    # result when it carries one (presumably it may; verify against the result
    # type), otherwise reuse the list already written to the combined file so
    # the final save does not wipe it.
    saved_path = getattr(pipeline, "output_file", args.output)
    global_chunks = getattr(result, "global_chunks", None)
    if global_chunks is None:
        global_chunks = []
        try:
            with open(saved_path, 'r', encoding='utf-8') as f:
                previous = json.load(f)
        except (OSError, ValueError):
            previous = None
        if isinstance(previous, dict) and isinstance(previous.get("global_chunks"), list):
            global_chunks = previous["global_chunks"]
    
    # Save results
    pipeline.save_codebooks_by_cluster(result.cluster_codebooks, global_chunks)
    
    # Display results
    print("\n" + "=" * 60)
    print("📊 CODEBOOK GENERATION RESULTS")
    print("=" * 60)
    print(f"Total Questions: {result.total_questions}")
    print(f"Total Clusters: {result.total_clusters}")
    print(f"Total Codebooks Generated: {result.summary['total_codebooks_generated']}")
    print(f"Total Processing Time: {result.processing_time:.2f}s")
    print(f"Average Time per Cluster: {result.summary['avg_time_per_cluster']:.2f}s")
    print(f"Average Questions per Cluster: {result.summary['avg_questions_per_cluster']:.1f}")
    print(f"Average Codebooks per Cluster: {result.summary['avg_codebooks_per_cluster']:.1f}")
    
    # Display cluster details
    print(f"\n📋 CLUSTER DETAILS:")
    print("=" * 60)
    for cluster_id, cluster_data in result.cluster_codebooks.items():
        codebooks = cluster_data.representative_codebooks
        print(f"Cluster {cluster_id}: {cluster_data.cluster_size} questions, {len(codebooks)} codebooks")
        print(f"   Representative: {cluster_data.representative_question[:80]}...")
        # str() keeps the join working when codebooks are dicts rather than strings.
        if len(codebooks) <= 5:
            print(f"   Codebooks: {', '.join(str(cb) for cb in codebooks)}")
        else:
            print(f"   Codebooks: {len(codebooks)} codebooks (showing first 5: {', '.join(str(cb) for cb in codebooks[:5])}...)")
        print()
    
    # Report the path the pipeline actually wrote to, which may differ from --output.
    print(f"\n💾 Results saved to: {saved_path}")

# Script entry point: run the async main() to completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main()) 