#!/usr/bin/env python3
"""
Multi-Iteration Schema Induction Pipeline - With Real Data Integration

This pipeline supports multiple iterations of schema induction with refinement:
- Iteration 1: Build Corpus + Schema Induction (no refinement)
- Iterations 2 to N: Refine Retrieval + Schema Induction (no build corpus)

Key features:
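- Uses real data processing; a one-row placeholder corpus is substituted only if build_corpus fails
- Integrates with build_corpus for real LLM-generated codebooks
- Uses RefinementPipeline from refine_iteration/llm_code_selector.py for refine retrieval, with an importance-score fallback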
- Iteration-specific temp folders for clean separation
- Configurable number of iterations (default: 2)
- Simplified workflow with clear separation of concerns
- Modular design for easy extension
"""

import os
import sys
import time
import shutil
import pandas as pd
from typing import List, Dict, Any, Optional
from pathlib import Path
from datetime import datetime

class MultiIterationSchemaInduction:
    """
    Multi-iteration schema induction pipeline with real data processing
    """
    
    def __init__(self, 
                 base_temp_dir: Optional[str] = "temp_files",
                 max_iterations: int = 2,
                 model_url: Optional[str] = None,
                 chunk_size: int = 2048,
                 overlap: int = 200,
                 strategy: str = "strategy_1",
                 model: str = "32B",
                 similarity_threshold: float = 0.0,
                 custom_data_path: Optional[str] = None):
        """
        Initialize the multi-iteration pipeline.

        No directories are created here; only the per-iteration paths are
        computed (``setup_iteration_directory`` creates them later).

        Args:
            base_temp_dir: Base directory for temporary files. If ``None``,
                ``temp_files`` inside the parent of this script's directory
                is used.
            max_iterations: Number of iterations (default: 2). One directory
                path ``iteration_XX`` is precomputed for each of 1..max_iterations.
            model_url: URL for the VLLM model. When falsy, the refinement step
                falls back to the ``VLLM_QWEN_32B_URL`` environment variable.
            chunk_size: Size of text chunks for processing.
            overlap: Overlap between chunks.
            strategy: Strategy for code generation.
            model: Model to use for LLM processing.
            similarity_threshold: Similarity threshold forwarded to build_corpus.
            custom_data_path: Optional custom data file path.
        """
        # Resolve a None base directory next to this script's parent directory
        # (the original layout note describes this as main_pipeline/utils ->
        # main_pipeline/temp_files).
        if base_temp_dir is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            main_pipeline_dir = os.path.dirname(script_dir)
            base_temp_dir = os.path.join(main_pipeline_dir, "temp_files")
        self.base_temp_dir = base_temp_dir
        self.max_iterations = max_iterations
        self.model_url = model_url
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.strategy = strategy
        self.model = model
        self.similarity_threshold = similarity_threshold
        self.custom_data_path = custom_data_path

        # 1-based iteration number -> zero-padded iteration directory path.
        self.iteration_dirs = {
            i: os.path.join(base_temp_dir, f"iteration_{i:02d}")
            for i in range(1, max_iterations + 1)
        }

        # Placeholder for a lazily created refinement pipeline.
        self.refinement_pipeline = None
    
    def cleanup_all_iteration_folders(self):
        """
        Delete every ``iteration_*`` subdirectory of ``self.base_temp_dir``.

        Plain files whose names start with ``iteration_`` are kept. If the base
        directory does not exist, nothing is printed and nothing is removed.
        """
        base = self.base_temp_dir
        if os.path.exists(base):
            # listdir() is materialized up front, so deleting while iterating is safe.
            stale_dirs = (
                name for name in os.listdir(base)
                if name.startswith("iteration_")
                and os.path.isdir(os.path.join(base, name))
            )
            for name in stale_dirs:
                print(f"   🗑️  Removing {name}")
                shutil.rmtree(os.path.join(base, name))

            print(f"   ✅ Cleaned up all iteration folders in {self.base_temp_dir}")
        
    def setup_iteration_directory(self, iteration: int) -> str:
        """
        Create a fresh, empty working directory for one iteration.

        Any existing directory for the iteration is removed first. The directory
        is then recreated together with its standard subfolders.

        Args:
            iteration: 1-based iteration number. It must be a key of
                ``self.iteration_dirs``; otherwise ``KeyError`` is raised.

        Returns:
            Path to the (re)created iteration directory.
        """
        target = self.iteration_dirs[iteration]

        # Start from a clean slate so stale outputs from an earlier run cannot
        # leak into this iteration.
        if os.path.exists(target):
            shutil.rmtree(target)
        os.makedirs(target, exist_ok=True)

        for name in ("build_corpus", "schema_induction", "refinement", "embeddings", "clusters"):
            os.makedirs(os.path.join(target, name), exist_ok=True)

        print(f"📁 Setup iteration {iteration} directory: {target}")
        return target
    
    async def run_iteration_1(self, question: str) -> Dict[str, Any]:
        """
        Execute the first iteration: corpus construction, then schema induction.

        No refinement happens here. Later iterations read the artifacts this
        iteration leaves in its directory.

        Args:
            question: Research question passed through to corpus building.

        Returns:
            Dict with ``iteration`` (always 1), ``corpus_result``,
            ``schema_result`` and ``iteration_dir``.
        """
        print(f"\n🔄 Starting Iteration 1: Build Corpus + Schema Induction")
        print("=" * 80)

        workdir = self.setup_iteration_directory(1)

        print("\n📚 Step 1: Building Corpus...")
        corpus = await self.build_corpus_iteration_1(workdir, question)

        print("\n🧠 Step 2: Schema Induction...")
        schema = await self.run_schema_induction_iteration_1(workdir, corpus["corpus_df"])

        print(f"\n✅ Iteration 1 completed successfully!")
        return dict(
            iteration=1,
            corpus_result=corpus,
            schema_result=schema,
            iteration_dir=workdir,
        )
    
    async def build_corpus_iteration_1(self, iteration_dir: str, question: str) -> Dict[str, Any]:
        """
        Build the iteration-1 corpus with the project's ``build_corpus`` pipeline.

        ``build_corpus.async_main`` is called with an attribute-bag arguments
        object assembled from this pipeline's settings. The target corpus path
        is passed as ``out_corpus``. If importing or running ``build_corpus``
        fails for any reason, the traceback is printed. A one-row placeholder
        corpus is then written to the corpus path so later steps can still run.

        Args:
            iteration_dir: Directory for this iteration. The corpus path is
                ``<iteration_dir>/build_corpus/corpus.parquet``.
            question: Research question forwarded to ``build_corpus``.

        Returns:
            Dict with ``corpus_path``, ``chunks_count``, ``build_corpus_dir``
            and ``corpus_df``.
        """
        build_corpus_dir = os.path.join(iteration_dir, "build_corpus")
        corpus_path = os.path.join(build_corpus_dir, "corpus.parquet")

        # Stand-in for the argparse namespace that build_corpus.async_main reads.
        class MockArgs:
            def __init__(self, question, custom_data_path, chunk_size, overlap, strategy, model, corpus_path, similarity_threshold):
                self.question = question
                self.input = custom_data_path
                self.chunk_size = chunk_size
                self.overlap = overlap
                self.concurrency = 32
                self.strategy = strategy
                self.model = model
                self.similarity_threshold = similarity_threshold
                self.test_chunk_sizes = False
                self.chunk_sizes = [256]
                self.overlap_ratio = 0.2
                self.seed_low = 5
                self.seed_high = 10
                self.seed = 42
                self.out_corpus = corpus_path
                self.test_concurrency = False
                self.max_test_concurrency = 128

        mock_args = MockArgs(question, self.custom_data_path, self.chunk_size, self.overlap,
                             self.strategy, self.model, corpus_path, self.similarity_threshold)

        try:
            from .initial_iteration.build_corpus import async_main

            print(f"   📝 Using build_corpus with question: {question}")
            if self.custom_data_path:
                print(f"   📚 Custom data path: {self.custom_data_path}")

            corpus_df = await async_main(mock_args)
            # len() also guards against async_main returning a non-sized result.
            print(f"   ✅ Built actual corpus: {len(corpus_df)} records")
        except Exception as e:
            # Deliberate best-effort: keep the pipeline alive with a placeholder
            # corpus, but surface the full traceback so the real failure is visible.
            print(f"   ❌ Error building corpus: {e}")
            import traceback
            traceback.print_exc()

            corpus_df = pd.DataFrame([{
                'source_path': "fallback",
                'chunk_index': 0,
                'level': "low-level",
                'tag': "fallback_code",
                'chunk_text': "Fallback text for testing"
            }])
            # setup_iteration_directory normally creates build_corpus/, but make
            # sure the fallback write cannot fail on a missing directory.
            os.makedirs(build_corpus_dir, exist_ok=True)
            corpus_df.to_parquet(corpus_path, index=False)

        return {
            "corpus_path": corpus_path,
            "chunks_count": len(corpus_df),
            "build_corpus_dir": build_corpus_dir,
            "corpus_df": corpus_df
        }
    
    async def run_schema_induction_iteration_1(self, iteration_dir: str, corpus_df: pd.DataFrame) -> Dict[str, Any]:
        """
        Induce a schema from the iteration-1 corpus.

        Args:
            iteration_dir: Iteration directory. It is used both as the
                pipeline's temp-files directory and as its iteration directory.
            corpus_df: Corpus produced by ``build_corpus_iteration_1``.

        Returns:
            Status dict that wraps the raw ``run_pipeline`` output under ``result``.
        """
        from .initial_iteration.schema_induction_pipeline import SchemaInductionPipeline

        pipeline_kwargs = dict(
            min_frequency=1,
            min_frequency_ratio=0.6,
            temp_files_dir=iteration_dir,
            iteration_dir=iteration_dir,
        )
        induced = await SchemaInductionPipeline(**pipeline_kwargs).run_pipeline(corpus_df)

        return {
            "iteration": 1,
            "status": "schema_induction_completed",
            "iteration_dir": iteration_dir,
            "result": induced,
        }
    
    async def run_subsequent_iterations(self, start_iteration: int, end_iteration: int) -> List[Dict[str, Any]]:
        """
        Run iterations ``start_iteration``..``end_iteration`` (inclusive).

        Each iteration refines the previous iteration's codes and then re-runs
        schema induction on the refined corpus; no corpus is rebuilt. After the
        last iteration, its output is persisted via ``save_final_corpus``.

        Args:
            start_iteration: First iteration number to run.
            end_iteration: Last iteration number to run (inclusive).

        Returns:
            One result dict per iteration, in order. The list is empty (and
            nothing is saved) when the range is empty.
        """
        outcomes: List[Dict[str, Any]] = []

        for current in range(start_iteration, end_iteration + 1):
            print(f"\n🔄 Starting Iteration {current}: Refine Retrieval + Schema Induction")
            print("=" * 80)

            workdir = self.setup_iteration_directory(current)

            print(f"\n🔍 Step 1: Refine Retrieval...")
            refined = await self.run_refinement_iteration(current, workdir)

            print(f"\n🧠 Step 2: Schema Induction...")
            induced = await self.run_schema_induction_from_refinement(current, workdir, refined)

            outcomes.append(dict(
                iteration=current,
                refinement_result=refined,
                schema_result=induced,
                iteration_dir=workdir,
            ))
            print(f"\n✅ Iteration {current} completed successfully!")

        # A non-empty range always ends on end_iteration, so the last entry is
        # the one whose output becomes the final corpus.
        if outcomes:
            final = outcomes[-1]
            final_iteration = final["iteration"]
            print(f"\n💾 Saving final corpus for iteration {final_iteration}...")
            await self.save_final_corpus(final["iteration_dir"], final["schema_result"])

        return outcomes
    
    async def run_refinement_iteration(self, iteration: int, iteration_dir: str) -> Dict[str, Any]:
        """
        Refine the previous iteration's codes with ``RefinementPipeline``.

        This method:

        - Reads the previous iteration's embeddings, datapoint-code mapping
          directory, cliques directory and corpus.
        - Asks ``RefinementPipeline`` to select codes for every unique chunk text.
        - Writes the raw results to ``refinement/refined_codes.parquet``.
        - Writes a schema-induction-ready corpus to ``build_corpus/corpus.parquet``
          inside ``iteration_dir``.

        Args:
            iteration: Iteration number. ``iteration - 1`` must be a key of
                ``self.iteration_dirs``.
            iteration_dir: Directory for this iteration.

        Returns:
            Dict with ``datapoints_processed``, ``successful_datapoints``,
            ``output_path``, ``total_pairs``, ``refinement_data``,
            ``avg_codes_per_datapoint`` and ``corpus_path``. If anything fails
            after the input checks, the traceback is printed and the result of
            ``_fallback_refinement_iteration`` is returned instead.

        Raises:
            FileNotFoundError: If a required input from the previous iteration
                is missing.
        """
        refinement_dir = os.path.join(iteration_dir, "refinement")

        # All inputs come from the previous iteration's directory.
        prev_iteration_dir = self.iteration_dirs[iteration - 1]
        embeddings_path = os.path.join(prev_iteration_dir, "embeddings", "embeddings.parquet")
        mapping_dir = os.path.join(prev_iteration_dir, "topologically_sorted_graph", "datapoint_code_mapping")
        cliques_dir = os.path.join(prev_iteration_dir, "topologically_sorted_graph", "cliques")
        corpus_path = os.path.join(prev_iteration_dir, "build_corpus", "corpus.parquet")

        # Fail fast, before any model work, if the previous iteration is incomplete.
        for file_path in (embeddings_path, corpus_path):
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Required file not found: {file_path}")
        for dir_path in (mapping_dir, cliques_dir):
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Required directory not found: {dir_path}")

        print(f"   📊 Using data from previous iteration: {prev_iteration_dir}")
        print(f"   📚 Embeddings: {embeddings_path}")
        print(f"   📋 Mapping dir: {mapping_dir}")
        print(f"   🔗 Cliques dir: {cliques_dir}")

        try:
            # NOTE(review): presumably llm_code_selector's own imports need
            # refine_iteration/ on sys.path; confirm before removing this.
            refine_iteration_dir = os.path.join(os.path.dirname(__file__), 'refine_iteration')
            if refine_iteration_dir not in sys.path:
                sys.path.insert(0, refine_iteration_dir)

            from .refine_iteration.llm_code_selector import RefinementPipeline

            refinement_pipeline = RefinementPipeline(
                embeddings_path=embeddings_path,
                mapping_dir=mapping_dir,
                cliques_dir=cliques_dir,
                model_url=self.model_url or os.environ.get('VLLM_QWEN_32B_URL'),
                max_concurrency=64  # Increased for speed
            )

            corpus_df = pd.read_parquet(corpus_path)
            # Datapoints are identified by their unique chunk text.
            unique_datapoints = corpus_df["chunk_text"].unique().tolist()
            print(f"   📝 Found {len(unique_datapoints)} unique datapoints to refine")

            # The previous iteration's tags are passed as context codes.
            previous_codes = corpus_df['tag'].unique().tolist()
            print(f"   📊 Using {len(previous_codes)} previous codes as context")

            print(f"   🚀 Running RefinementPipeline...")
            # NOTE(review): called without await; if it is synchronous it blocks
            # the event loop for the whole batch.
            refinement_results = refinement_pipeline.process_datapoints_batch(
                datapoint_ids=unique_datapoints,
                previous_codes=previous_codes,
                corpus_path=corpus_path
            )

            os.makedirs(refinement_dir, exist_ok=True)
            output_path = os.path.join(refinement_dir, "refined_codes.parquet")
            refinement_pipeline.save_results(refinement_results, output_path)

            # One corpus row per (datapoint, selected code) pair of successful results.
            corpus_rows = []
            for result in refinement_results:
                if not (result['success'] and result['selected_codes']):
                    continue
                for code in result['selected_codes']:
                    corpus_rows.append({
                        "source_path": f"refinement_{iteration}",
                        "chunk_index": len(corpus_rows),
                        "level": "low-level",  # Default level for refined codes
                        "tag": code,
                        "chunk_text": result['datapoint_text']
                    })
            total_selected_codes = len(corpus_rows)
            if not corpus_rows:
                print(f"   ⚠️  RefinementPipeline selected no codes; the refined corpus is empty")

            # Explicit columns keep the schema (notably 'tag') even when no codes
            # were selected, so downstream deduplication on 'tag' still works.
            corpus_df = pd.DataFrame(
                corpus_rows,
                columns=["source_path", "chunk_index", "level", "tag", "chunk_text"],
            )
            corpus_build_dir = os.path.join(iteration_dir, "build_corpus")
            os.makedirs(corpus_build_dir, exist_ok=True)
            corpus_path_output = os.path.join(corpus_build_dir, "corpus.parquet")
            corpus_df.to_parquet(corpus_path_output, index=False)
            print(f"   💾 Saved corpus file: {corpus_path_output}")

            successful_datapoints = sum(1 for r in refinement_results if r['success'])
            avg_codes_per_datapoint = total_selected_codes / len(unique_datapoints) if unique_datapoints else 0

            print(f"   📊 Refinement completed:")
            print(f"      • Datapoints processed: {len(unique_datapoints)}")
            print(f"      • Successful datapoints: {successful_datapoints}")
            print(f"      • Total selected codes: {total_selected_codes}")
            print(f"      • Average codes per datapoint: {avg_codes_per_datapoint:.1f}")

            return {
                "datapoints_processed": len(unique_datapoints),
                "successful_datapoints": successful_datapoints,
                "output_path": output_path,
                "total_pairs": total_selected_codes,
                "refinement_data": refinement_results,
                "avg_codes_per_datapoint": avg_codes_per_datapoint,
                "corpus_path": corpus_path_output
            }

        except Exception as e:
            print(f"   ❌ Error in refinement pipeline: {e}")
            import traceback
            traceback.print_exc()

            print(f"   🔄 Falling back to simple importance-based selection...")
            return self._fallback_refinement_iteration(iteration, iteration_dir)
    
    def _fallback_refinement_iteration(self, iteration: int, iteration_dir: str) -> Dict[str, Any]:
        """
        Fallback refinement: keep each datapoint's top codes by an importance score.

        Reads ``topologically_sorted_graph/code_datapoints_enhanced.parquet``
        from the previous iteration. Within each datapoint, codes are ranked by
        ``0.7 * incoming_edges + 0.3 * merge_score`` and at most 15 are kept.
        Writes ``refinement/refined_codes.parquet`` and a schema-induction
        corpus ``build_corpus/corpus.parquet`` into ``iteration_dir``.

        Args:
            iteration: Iteration number. ``iteration - 1`` must be a key of
                ``self.iteration_dirs``.
            iteration_dir: Directory for this iteration.

        Returns:
            Dict with ``datapoints_processed``, ``output_path``, ``total_pairs``,
            ``refinement_data`` (DataFrame), ``avg_codes_per_datapoint`` (0 when
            there are no datapoints) and ``corpus_path``.

        Raises:
            FileNotFoundError: If the previous iteration's mapping file is missing.
        """
        refinement_dir = os.path.join(iteration_dir, "refinement")

        prev_iteration_dir = self.iteration_dirs[iteration - 1]
        prev_mapping_path = os.path.join(prev_iteration_dir, "topologically_sorted_graph", "code_datapoints_enhanced.parquet")

        if not os.path.exists(prev_mapping_path):
            raise FileNotFoundError(f"Previous iteration data not found: {prev_mapping_path}")

        print(f"   📊 Loading data from previous iteration: {prev_mapping_path}")

        prev_mapping_df = pd.read_parquet(prev_mapping_path)
        print(f"   📋 Loaded {len(prev_mapping_df)} code-datapoint mappings")

        unique_datapoints = prev_mapping_df['datapoint'].unique()
        n_datapoints = len(unique_datapoints)
        print(f"   📝 Found {n_datapoints} unique datapoints")

        # Importance score: incoming_edges weighted 0.7, merge_score weighted 0.3.
        prev_mapping_df['importance_score'] = (
            prev_mapping_df['incoming_edges'] * 0.7 +
            prev_mapping_df['merge_score'] * 0.3
        )

        refinement_data = []
        # groupby(sort=False) visits datapoints in first-appearance order (the
        # same order as unique()) in one pass, instead of re-filtering the whole
        # frame for every datapoint.
        for datapoint, datapoint_codes in prev_mapping_df.groupby('datapoint', sort=False):
            top_codes = datapoint_codes.sort_values('importance_score', ascending=False).head(15)
            for _, code_row in top_codes.iterrows():
                refinement_data.append({
                    'datapoint': datapoint,
                    'code': code_row['code'],
                    'datapoint_text': datapoint,  # Use the full datapoint text
                    'score': code_row['importance_score'],
                    'incoming_edges': code_row['incoming_edges'],
                    'merge_score': code_row['merge_score'],
                    'level': code_row['level']
                })

        # Explicit columns keep the schema intact when nothing was selected.
        refinement_df = pd.DataFrame(
            refinement_data,
            columns=['datapoint', 'code', 'datapoint_text', 'score', 'incoming_edges', 'merge_score', 'level'],
        )
        os.makedirs(refinement_dir, exist_ok=True)
        output_path = os.path.join(refinement_dir, "refined_codes.parquet")
        refinement_df.to_parquet(output_path, index=False)

        # Guard against an empty mapping file (previously a ZeroDivisionError).
        avg_codes_per_datapoint = len(refinement_data) / n_datapoints if n_datapoints else 0.0

        print(f"   📊 Created refinement data: {len(refinement_data)} datapoint-code pairs")
        print(f"   📈 Average codes per datapoint: {avg_codes_per_datapoint:.1f}")

        corpus_df = pd.DataFrame(
            [
                {
                    "source_path": f"refinement_{iteration}",
                    "chunk_index": idx,
                    "level": row.get("level", "low-level"),
                    "tag": row["code"],
                    "chunk_text": row["datapoint_text"]
                }
                for idx, (_, row) in enumerate(refinement_df.iterrows())
            ],
            columns=["source_path", "chunk_index", "level", "tag", "chunk_text"],
        )
        corpus_build_dir = os.path.join(iteration_dir, "build_corpus")
        os.makedirs(corpus_build_dir, exist_ok=True)
        corpus_path = os.path.join(corpus_build_dir, "corpus.parquet")
        corpus_df.to_parquet(corpus_path, index=False)
        print(f"   💾 Saved corpus file: {corpus_path}")

        return {
            "datapoints_processed": n_datapoints,
            "output_path": output_path,
            "total_pairs": len(refinement_data),
            "refinement_data": refinement_df,
            "avg_codes_per_datapoint": avg_codes_per_datapoint,
            "corpus_path": corpus_path
        }
    
    async def run_schema_induction_from_refinement(self, iteration: int, iteration_dir: str, refinement_result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run schema induction on the corpus produced by a refinement step.

        The corpus is read from ``refinement_result['corpus_path']`` when that
        file exists. Otherwise it is rebuilt in memory from
        ``refinement_result['refinement_data']``. Rows are then deduplicated
        on ``tag`` (first occurrence kept) before induction.

        Args:
            iteration: Iteration number.
            iteration_dir: Directory for this iteration. It is used both as the
                pipeline's temp-files directory and as its iteration directory.
            refinement_result: Result dict from the refinement step. Must
                contain ``total_pairs``.

        Returns:
            Status dict wrapping the raw ``run_pipeline`` output under ``result``.
        """
        from .initial_iteration.schema_induction_pipeline import SchemaInductionPipeline

        corpus_path = refinement_result.get('corpus_path')
        if corpus_path and os.path.exists(corpus_path):
            corpus_df = pd.read_parquet(corpus_path)
        else:
            corpus_df = self._refinement_data_to_corpus(iteration, refinement_result['refinement_data'])

        print(f"   🔄 Deduplicating refinement data based on 'tag' column...")
        original_count = len(corpus_df)
        # An empty corpus written without columns has no 'tag' to deduplicate on.
        if 'tag' in corpus_df.columns:
            corpus_df = corpus_df.drop_duplicates(subset=['tag'], keep='first')
        deduplicated_count = len(corpus_df)
        print(f"   📊 Deduplicated: {original_count} -> {deduplicated_count} rows ({original_count - deduplicated_count} duplicates removed)") 

        schema_pipeline = SchemaInductionPipeline(
            min_frequency=1,
            min_frequency_ratio=0.6,
            temp_files_dir=iteration_dir,
            iteration_dir=iteration_dir
        )

        schema_result = await schema_pipeline.run_pipeline(corpus_df)

        return {
            "iteration": iteration,
            "input_data": refinement_result["total_pairs"],
            "status": "schema_induction_completed",
            "iteration_dir": iteration_dir,
            "result": schema_result
        }

    @staticmethod
    def _refinement_data_to_corpus(iteration: int, refinement_data: Any) -> pd.DataFrame:
        """
        Convert in-memory refinement output into the corpus schema.

        Two input shapes are accepted:

        - A DataFrame (or list of dicts) with ``code``, ``datapoint_text`` and
          optional ``level``. This is the fallback-refinement shape.
        - A list of RefinementPipeline result dicts with ``success``,
          ``selected_codes`` and ``datapoint_text``. Only successful results
          with at least one code contribute; these get level ``low-level``.

        The returned frame always has the five corpus columns, even when empty.
        """
        columns = ["source_path", "chunk_index", "level", "tag", "chunk_text"]
        source = f"refinement_{iteration}"

        if isinstance(refinement_data, pd.DataFrame):
            records = refinement_data.to_dict("records")
        else:
            records = list(refinement_data or [])

        rows = []
        for record in records:
            if "selected_codes" in record:
                if not (record.get("success") and record["selected_codes"]):
                    continue
                pairs = [(code, "low-level") for code in record["selected_codes"]]
            else:
                pairs = [(record["code"], record.get("level", "low-level"))]
            for code, level in pairs:
                rows.append({
                    "source_path": source,
                    "chunk_index": len(rows),
                    "level": level,
                    "tag": code,
                    "chunk_text": record["datapoint_text"],
                })
        return pd.DataFrame(rows, columns=columns)
    
    async def save_final_corpus(self, iteration_dir: str, schema_result: Dict[str, Any]):
        """
        Write the iteration's final list of unique codes plus a JSON summary.

        Unique codes are taken from the first usable source, in this order:

        1. ``topologically_sorted_graph/topological_sort.parquet`` (``node`` column)
        2. ``high_level_codes/enhanced_corpus.parquet`` (``tag`` column)
        3. ``build_corpus/corpus.parquet`` (``tag`` column)

        The codes are saved as a single-column (``tag``) parquet file,
        ``final_corpus/final_corpus.parquet``, next to
        ``final_corpus/final_corpus_summary.json``. A hierarchical tree is then
        built with ``HierarchicalTreeCreator``. Its outcome (tree metadata, or
        the error message) is added to the summary. Tree-creation failures are
        printed, not raised.

        Args:
            iteration_dir: Iteration directory to read from and write into.
            schema_result: Schema-induction result for this iteration (unused here).

        Returns:
            None. Returns early, without writing a corpus or summary, when no
            usable fallback source is found.
        """
        import json

        print(f"🔄 Building final corpus from topological sort results (simplified format)...")

        final_corpus_dir = os.path.join(iteration_dir, "final_corpus")
        os.makedirs(final_corpus_dir, exist_ok=True)

        topological_sort_path = os.path.join(iteration_dir, "topologically_sorted_graph", "topological_sort.parquet")
        enhanced_corpus_path = os.path.join(iteration_dir, "high_level_codes", "enhanced_corpus.parquet")
        build_corpus_path = os.path.join(iteration_dir, "build_corpus", "corpus.parquet")
        corpus_df = None
        corpus_source = None
        source_file = None  # file the final codes were actually read from
        topological_sort_used = False

        # Step 1: preferred source - the unique nodes of the topological sort.
        if os.path.exists(topological_sort_path):
            print(f"🌳 USING TOPOLOGICAL SORT RESULTS:")
            print(f"✅ Found topological sort: {topological_sort_path}")

            try:
                topological_sort_df = pd.read_parquet(topological_sort_path)
                print(f"   - Loaded {len(topological_sort_df)} topological sort records")
                print(f"   - Columns: {list(topological_sort_df.columns)}")

                if 'node' in topological_sort_df.columns:
                    unique_codes = topological_sort_df['node'].unique()
                    print(f"   - Found {len(unique_codes)} unique refined codes")

                    corpus_df = pd.DataFrame({'tag': unique_codes})
                    corpus_source = 'topological_sort_refined_codes'
                    source_file = topological_sort_path
                    topological_sort_used = True

                    print(f"   🎯 Created final corpus with {len(corpus_df)} refined codes from topological sort")
                else:
                    print(f"   ❌ No 'node' column found in topological sort. Columns: {list(topological_sort_df.columns)}")

            except Exception as e:
                print(f"   ❌ Error loading topological sort: {e}")
                # Reset every piece of state so the fallback starts consistently.
                corpus_df = None
                corpus_source = None
                source_file = None
                topological_sort_used = False
        else:
            print(f"⚠️  Topological sort not found: {topological_sort_path}")

        # Step 2: fall back to the iteration's corpus files.
        if corpus_df is None:
            print(f"\n🔄 FALLBACK: USING ORIGINAL CORPUS FILES:")

            if os.path.exists(enhanced_corpus_path):
                print(f"✅ Enhanced corpus: {enhanced_corpus_path}")
                source_file = enhanced_corpus_path
                corpus_source = 'enhanced_corpus_fallback'
            elif os.path.exists(build_corpus_path):
                print(f"✅ Build corpus: {build_corpus_path}")
                source_file = build_corpus_path
                corpus_source = 'build_corpus_fallback'
            else:
                print(f"❌ No corpus files found in {iteration_dir}")
                print("   Expected files:")
                print(f"   - {topological_sort_path}")
                print(f"   - {enhanced_corpus_path}")
                print(f"   - {build_corpus_path}")
                return

            original_corpus = pd.read_parquet(source_file)
            print(f"   - Rows: {len(original_corpus)}")
            print(f"   - Columns: {list(original_corpus.columns)}")

            if 'tag' not in original_corpus.columns:
                print(f"   ❌ No 'tag' column found in corpus. Columns: {list(original_corpus.columns)}")
                return

            corpus_df = pd.DataFrame({'tag': original_corpus['tag'].unique()})
            print(f"   🎯 Extracted {len(corpus_df)} unique tags from fallback corpus")

        # Step 3: save the final corpus (only the 'tag' column).
        final_corpus_path = os.path.join(final_corpus_dir, "final_corpus.parquet")
        corpus_df.to_parquet(final_corpus_path, index=False)
        print(f"💾 Saved final corpus: {final_corpus_path}")

        # Step 4: summary.
        summary = {
            'total_codes': len(corpus_df),
            'unique_codes': corpus_df['tag'].nunique(),
            'iteration_dir': iteration_dir,
            'corpus_source': corpus_source,
            'topological_sort_used': topological_sort_used,
            'timestamp': pd.Timestamp.now().isoformat(),
            'columns': list(corpus_df.columns),
            'refinement_info': {
                'used_refined_codes': topological_sort_used,
                'source_file': source_file
            }
        }

        summary_path = os.path.join(final_corpus_dir, "final_corpus_summary.json")

        def write_summary():
            # The summary is rewritten after tree creation, so keep one writer.
            with open(summary_path, "w", encoding="utf-8") as f:
                json.dump(summary, f, indent=2)

        write_summary()

        print(f"📊 Created summary: {summary_path}")
        print()

        print("✅ FINAL CORPUS GENERATION COMPLETE!")
        print(f"   📁 Directory: {final_corpus_dir}")
        print(f"   📄 Corpus file: {final_corpus_path}")
        print(f"   📊 Total codes: {summary['total_codes']}")
        print(f"   📊 Unique codes: {summary['unique_codes']}")
        print(f"   📋 Source: {corpus_source}")
        print(f"   🌳 Used topological sort: {topological_sort_used}")
        print(f"   📋 Columns: {summary['columns']}")

        # Step 5: hierarchical tree for inference testing (best-effort).
        print(f"\n🌳 Creating hierarchical tree for inference testing...")
        try:
            from .hierarchical_tree_creator import HierarchicalTreeCreator

            tree_creator = HierarchicalTreeCreator(iteration_dir)
            tree_result = tree_creator.create_hierarchical_tree()

            print(f"✅ Hierarchical tree created successfully!")
            print(f"   💾 Saved files:")
            for file_path in tree_result["saved_files"]:
                print(f"      - {os.path.basename(file_path)}")

            summary["hierarchical_tree"] = {
                "created": True,
                "total_nodes": tree_result["hierarchical_tree"]["metadata"]["total_nodes"],
                "levels": tree_result["hierarchical_tree"]["metadata"]["levels"],
                "saved_files": [os.path.basename(f) for f in tree_result["saved_files"]]
            }
            write_summary()

        except Exception as e:
            print(f"⚠️  Warning: Failed to create hierarchical tree: {e}")
            import traceback
            traceback.print_exc()

            summary["hierarchical_tree"] = {
                "created": False,
                "error": str(e)
            }
            write_summary()

    async def run_full_pipeline(self, question: str) -> Dict[str, Any]:
        """
        Run the complete multi-iteration pipeline
        
        Args:
            question: Research question
            
        Returns:
            Complete pipeline results
        """
        # Clean up all existing iteration folders to prevent confusion
        print("🧹 Cleaning up existing iteration folders...")
        self.cleanup_all_iteration_folders()
        print("✅ All existing iteration folders cleaned up")
        print("=" * 80)
        print(f"🚀 Starting Multi-Iteration Schema Induction Pipeline")
        print(f"   Question: {question}")
        print(f"   Max iterations: {self.max_iterations}")
        print(f"   Base temp directory: {self.base_temp_dir}")
        print(f"   Chunk size: {self.chunk_size}")
        print(f"   Strategy: {self.strategy}")
        print(f"   Model: {self.model}")
        if self.custom_data_path:
            print(f"   Custom data: {self.custom_data_path}")
        print("=" * 80)
        
        start_time = time.time()
        all_results = {}
        
        try:
            # Iteration 1: Build Corpus + Schema Induction (no refinement)
            iteration_1_result = await self.run_iteration_1(question)
            all_results[1] = iteration_1_result
            
            # Iterations 2 to N: Refine Retrieval + Schema Induction (no build corpus)
            if self.max_iterations > 1:
                subsequent_results = await self.run_subsequent_iterations(2, self.max_iterations)
                for result in subsequent_results:
                    all_results[result["iteration"]] = result
            
            total_time = time.time() - start_time
            
            # Summary
            print(f"\n🎉 Multi-Iteration Pipeline Completed Successfully!")
            print(f"   Total iterations: {self.max_iterations}")
            print(f"   Total time: {total_time:.2f} seconds")
            print(f"   Results saved in: {self.base_temp_dir}")
            
            return {
                "success": True,
                "total_iterations": self.max_iterations,
                "total_time": total_time,
                "results": all_results,
                "base_temp_dir": self.base_temp_dir
            }
            
        except Exception as e:
            print(f"\n❌ Pipeline failed: {str(e)}")
            import traceback
            traceback.print_exc()
            return {
                "success": False,
                "error": str(e),
                "completed_iterations": len(all_results),
                "results": all_results
            }


async def main():
    """
    Demonstrate the multi-iteration schema induction pipeline end to end.

    Configures a pipeline, runs it on a sample research question, and prints
    either the directory holding the results or the failure message.
    """
    research_question = "How does Ali use signposts to orient the audience?"

    # Real data processing parameters for the demo run.
    pipeline_kwargs = dict(
        base_temp_dir="temp_files",
        max_iterations=2,
        model_url=os.environ.get('VLLM_QWEN_32B_URL'),
        chunk_size=256,
        overlap=50,
        strategy="strategy_1",
        model="32B",
        custom_data_path=None,  # None: use the pipeline's default data
    )
    pipeline = MultiIterationSchemaInduction(**pipeline_kwargs)

    outcome = await pipeline.run_full_pipeline(research_question)

    # Guard clause: report the failure and stop.
    if not outcome["success"]:
        print(f"\n❌ Pipeline failed: {outcome['error']}")
        return

    print(f"\n✅ Pipeline completed successfully!")
    print(f"   Results available in: {outcome['base_temp_dir']}")


if __name__ == "__main__":
    # Script entry point: run the async demo to completion on a new event loop.
    from asyncio import run as run_coroutine
    run_coroutine(main())