#!/usr/bin/env python3
"""
Schema Induction Pipeline with embedding-reuse optimizations and the previously missing cosine similarity step

Key optimizations:
1. Reuse existing embeddings instead of recomputing them
2. Only embed new high-level codes
3. Efficient clustering with existing embeddings
4. FIXED: Handle case where embeddings don't exist yet (first iteration)
5. FIXED: Use correct aiohttp approach for chat completions
6. FIXED: Enable concurrency for high-level code generation
7. FIXED: Add missing cosine similarity step
8. FIXED: Correct NLI classification method call
"""

import os
import sys
import asyncio
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from dotenv import load_dotenv
import argparse

# Load environment variables up front: SchemaInductionPipeline.__init__ falls back to
# os.getenv("VLLM_QWEN_32B_URL") / os.getenv("VLLM_QWEN_32B_MODEL") when no model
# URL / name is passed in.
load_dotenv()

# Add current directory to path for imports (inserted at the front, and only once).
# NOTE(review): the project imports below are package-relative (`from .embeddings ...`),
# so they presumably do not depend on this sys.path entry - confirm before removing it.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import the CORRECT modules that actually exist
from .embeddings import build_embeddings_parquet, build_embeddings_parquet_async
from .high_level_code_gen import OptimizedHighLevelCodeGenerator  # OPTIMIZED VERSION
from .cluster import cluster_fast
from .cosine_sim import CosineSimilarity
from .nli_classify_with_load_balancer import classify_similarities_optimized
from .topological_graph import build_enhanced_topological_graph
from .flip_label_processing import process_nli_results_for_conflict_detection
from ..conflict_relationship_detection.conflict_detection_resolver import detect_and_resolve_conflicts_advanced

# Constants
# Default base directory for temporary files, used by SchemaInductionPipeline.__init__
# when no temp_files_dir is given. It is the "temp_files" folder three dirname() steps
# up from this file's path, i.e. two directories above the one containing this module.
TEMP_FILES_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "temp_files")

class SchemaInductionPipeline:
    """
    Schema Induction Pipeline with OPTIMIZATIONS
    
    KEY OPTIMIZATION: Reuses embeddings for original codes, but generates new embeddings
    for the newly created high-level codes from high-level code generation.
    """
    
    def __init__(self, 
                 iteration_number: int = 1,
                 temp_files_dir: str = None,
                 iteration_dir: str = None,
                 min_frequency: int = 2,
                 min_frequency_ratio: float = 0.1,
                 model_url: str = None,
                 model_name: str = None,
                 max_concurrency: int = 32):
        """
        Initialize the schema induction pipeline
        
        Args:
            iteration_number: Iteration number for this run
            temp_files_dir: Base directory for temporary files
            iteration_dir: Specific iteration directory
            min_frequency: Minimum frequency for relationships
            min_frequency_ratio: Minimum frequency ratio for relationships
            model_url: URL for the vLLM server
            model_name: Name of the model to use
            max_concurrency: Maximum number of concurrent requests
        """
        self.iteration_number = iteration_number
        self.min_frequency = min_frequency
        self.min_frequency_ratio = min_frequency_ratio
        self.temp_files_dir = temp_files_dir or TEMP_FILES_DIR
        self.iteration_dir = iteration_dir
        self.model_url = model_url or os.getenv("VLLM_QWEN_32B_URL")
        self.model_name = model_name or os.getenv("VLLM_QWEN_32B_MODEL")
        self.max_concurrency = max_concurrency
        
        # Setup iteration-specific directories
        if self.iteration_dir:
            self.embeddings_dir = os.path.join(self.iteration_dir, "embeddings")
            self.high_level_dir = os.path.join(self.iteration_dir, "high_level_codes")
            self.clustering_dir = os.path.join(self.iteration_dir, "cluster_sim")
            self.nli_dir = os.path.join(self.iteration_dir, "nli_classify")
            self.conflict_dir = os.path.join(self.iteration_dir, "conflict_detection")
            self.topological_dir = os.path.join(self.iteration_dir, "topologically_sorted_graph")
        else:
            self.embeddings_dir = os.path.join(self.temp_files_dir, "embeddings")
            self.high_level_dir = os.path.join(self.temp_files_dir, "high_level_codes")
            self.clustering_dir = os.path.join(self.temp_files_dir, "cluster_sim")
            self.nli_dir = os.path.join(self.temp_files_dir, "nli_classify")
            self.conflict_dir = os.path.join(self.temp_files_dir, "conflict_detection")
            self.topological_dir = os.path.join(self.temp_files_dir, "topologically_sorted_graph")
        
        # Create all directories
        for dir_path in [self.embeddings_dir, self.high_level_dir, 
                        self.clustering_dir, self.nli_dir, 
                        self.conflict_dir, self.topological_dir]:
            os.makedirs(dir_path, exist_ok=True)
        
        print(f"🚀 Initialized SchemaInductionPipeline")
        print(f"   Iteration: {self.iteration_number}")
        print(f"   Temp Files Dir: {self.temp_files_dir}")
        print(f"   Model URL: {self.model_url}")
        print(f"   Model: {self.model_name}")
        print(f"   Max Concurrency: {self.max_concurrency}")
    
    async def run_high_level_code_generation_with_embeddings(self, corpus_df: pd.DataFrame) -> Tuple[pd.DataFrame, str, np.ndarray, List[str]]:
        """
        Generate high-level codes for the corpus and load the embeddings that go with it.

        The input corpus is first written to ``original_corpus.parquet`` in the
        high-level directory. If ``embeddings.parquet`` already exists in the
        embeddings directory it is handed to the generator; otherwise a path
        inside the high-level directory is handed over instead (the generator
        is presumably expected to create that file - verify against
        OptimizedHighLevelCodeGenerator).

        Args:
            corpus_df: Input corpus DataFrame

        Returns:
            Tuple of (enhanced_corpus_df, embeddings_path, embeddings_array, codes_list)

        Raises:
            FileNotFoundError: If no embeddings file exists at the chosen path
                once generation has finished.
        """
        import shutil

        print("🧠 Running High-Level Code Generation with Embeddings...")

        # Keep a copy of the untouched input next to the generated output.
        os.makedirs(self.high_level_dir, exist_ok=True)
        corpus_df.to_parquet(
            os.path.join(self.high_level_dir, "original_corpus.parquet"),
            index=False,
        )

        # Prefer the shared embeddings file; fall back to a path in the high-level directory.
        shared_embeddings_path = os.path.join(self.embeddings_dir, "embeddings.parquet")
        if os.path.exists(shared_embeddings_path):
            print("✅ Found existing embeddings - will reuse them!")
            embeddings_path = shared_embeddings_path
        else:
            print("⚠️ No existing embeddings found - will generate new ones")
            embeddings_path = os.path.join(self.high_level_dir, "embeddings.parquet")

        code_generator = OptimizedHighLevelCodeGenerator(
            model_url=self.model_url,
            model_name=self.model_name,
            max_concurrency=self.max_concurrency
        )
        enhanced_corpus_df = await code_generator.run_optimized_high_level_generation(
            corpus_df=corpus_df,
            existing_embeddings_path=embeddings_path,
            output_dir=self.high_level_dir
        )

        # The embeddings file must exist by now, whichever path was chosen.
        if not os.path.exists(embeddings_path):
            raise FileNotFoundError(f"Embeddings not found at {embeddings_path}")

        embeddings_df = pd.read_parquet(embeddings_path)
        embeddings_array = np.array(embeddings_df['embedding'].tolist())
        # The code labels live in a 'code' column when present, else in 'tag'.
        label_column = 'code' if 'code' in embeddings_df.columns else 'tag'
        codes_list = embeddings_df[label_column].tolist()

        print(f"✅ High-level code generation completed with {len(enhanced_corpus_df)} records")

        # Publish the embeddings to the shared location unless a file is already there.
        source_present = os.path.exists(embeddings_path)
        if source_present and not os.path.exists(shared_embeddings_path):
            os.makedirs(self.embeddings_dir, exist_ok=True)
            shutil.copy2(embeddings_path, shared_embeddings_path)
            print(f"✅ Copied embeddings to {shared_embeddings_path} for subsequent iterations")

        return enhanced_corpus_df, embeddings_path, embeddings_array, codes_list
    
    async def run_embeddings_for_new_codes(self, corpus_df: pd.DataFrame, existing_embeddings_path: str, existing_embeddings_array: np.ndarray, existing_codes_list: List[str]) -> Tuple[str, np.ndarray, List[str]]:
        """
        OPTIMIZATION: Generate embeddings only for new high-level codes
        
        Args:
            corpus_df: Enhanced corpus with high-level codes
            existing_embeddings_path: Path to existing embeddings
            existing_embeddings_array: Existing embeddings array
            existing_codes_list: List of existing codes
            
        Returns:
            Tuple of (combined_embeddings_path, combined_embeddings_array, all_codes_list)
        """
        print("🔄 Generating embeddings for new high-level codes...")
        
        # Get all codes from enhanced corpus
        all_codes_from_corpus = corpus_df["tag"].unique().tolist()
        existing_codes_set = set(existing_codes_list)
        
        new_codes = list(set(all_codes_from_corpus) - existing_codes_set)
        
        # Create all_codes in the correct order: existing codes first, then new codes
        all_codes = existing_codes_list + new_codes
        
        if not new_codes:
            print("✅ No new codes to embed - reusing existing embeddings")
            return existing_embeddings_path, existing_embeddings_array, all_codes
        
        print(f"🔄 Embedding {len(new_codes)} new high-level codes...")
        
        # Generate embeddings for new codes only
        target_embeddings_path = os.path.join(self.embeddings_dir, "embeddings.parquet")
        
        # Create DataFrame with new codes
        new_codes_df = pd.DataFrame({'tag': new_codes})
        new_embeddings_path = os.path.join(self.embeddings_dir, "new_embeddings.parquet")
        
        # Generate embeddings for new codes
        _, _, _, new_embeddings_df = await build_embeddings_parquet_async(
            corpus_df=new_codes_df,
            output_parquet=new_embeddings_path
        )
        
        # Combine existing and new embeddings
        new_embeddings = np.array(new_embeddings_df['embedding'].tolist())
        combined_embeddings_array = np.vstack([existing_embeddings_array, new_embeddings])
        
        # Create combined embeddings DataFrame
        combined_embeddings_df = pd.DataFrame({
            'tag': all_codes,
            'embedding': combined_embeddings_array.tolist()
        })
        
        # Save combined embeddings
        target_embeddings_path = os.path.join(self.embeddings_dir, "embeddings.parquet")
        combined_embeddings_df.to_parquet(target_embeddings_path, index=False)
        
        print(f"✅ Combined embeddings: {existing_embeddings_array.shape[0]} existing + {len(new_codes)} new = {combined_embeddings_array.shape[0]} total")
        return target_embeddings_path, combined_embeddings_array, all_codes
    
    async def run_clustering_with_combined_embeddings(self, embeddings_array: np.ndarray, codes_list: List[str]) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Cluster the combined embeddings and persist the per-row cluster assignment.

        ``cluster_fast`` is run on the embeddings; its ``'labels'`` entry is
        turned into a two-column table (row position, cluster id) and written to
        ``cluster_assignments.parquet`` in the clustering directory.

        Args:
            embeddings_array: Combined embeddings array
            codes_list: List of all codes (accepted but not used in this method)

        Returns:
            Tuple of (cluster_df, cluster_results)
        """
        print("🔍 Running Clustering with Combined Embeddings...")

        cluster_results = cluster_fast(embeddings_array)

        labels = cluster_results['labels']
        chosen_k = cluster_results['optimal_k']
        print(f"✅ Clustering completed with k={chosen_k}")

        # 'code_index' is the position of each label in the clustering output.
        assignment_columns = {
            'code_index': range(len(labels)),
            'cluster_id': labels,
        }
        cluster_df = pd.DataFrame(assignment_columns)

        cluster_path = os.path.join(self.clustering_dir, "cluster_assignments.parquet")
        cluster_df.to_parquet(cluster_path, index=False)
        print(f"✅ Saved cluster assignments to {cluster_path}")

        return cluster_df, cluster_results
    
    async def run_cosine_similarity(self, cluster_df: pd.DataFrame, embeddings_array: np.ndarray, codes_list: List[str], corpus_df: pd.DataFrame, cluster_results: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Compute similarity pairs via CosineSimilarity and save them when non-empty.

        The work is delegated to ``CosineSimilarity.compute_similarities_optimized``.
        A non-empty result is written to ``similarity_pairs.parquet`` in the
        clustering directory; an empty result writes nothing.

        Args:
            cluster_df: DataFrame with cluster assignments (accepted but not used here)
            embeddings_array: Embeddings array
            codes_list: List of codes
            corpus_df: Corpus DataFrame
            cluster_results: Cluster results from clustering step

        Returns:
            List of similarity pairs
        """
        print("🔍 Running Cosine Similarity Computation...")

        similarity_pairs = CosineSimilarity().compute_similarities_optimized(
            cluster_results, codes_list, embeddings_array, corpus_df
        )
        print(f"✅ Computed {len(similarity_pairs)} similarity pairs")

        similarity_path = os.path.join(self.clustering_dir, "similarity_pairs.parquet")
        if not similarity_pairs:
            # Nothing to persist.
            return similarity_pairs

        pd.DataFrame(similarity_pairs).to_parquet(similarity_path, index=False)
        print(f"✅ Saved similarity pairs to {similarity_path}")
        return similarity_pairs
    
    async def run_nli_classification(self, similarity_pairs: List[Dict[str, Any]]) -> pd.DataFrame:
        """
        FIXED: Run NLI classification on similarity pairs
        
        Args:
            similarity_pairs: List of similarity pairs from cosine similarity
            
        Returns:
            NLI classified results DataFrame
        """
        print("🧠 Running NLI Classification...")
        
        if not similarity_pairs:
            print("⚠️ No similarity pairs to classify")
            return pd.DataFrame()
        
        # Use the correct function for NLI classification
        nli_results, relationship_matrix, unique_codes = await classify_similarities_optimized(
            similarity_pairs=similarity_pairs,
            output_dir=self.nli_dir
        )
        
        print(f"✅ NLI classification completed: {len(nli_results)} results")
        return pd.DataFrame(nli_results)
    
    async def run_conflict_detection(self, corpus_df: pd.DataFrame, nli_results_df: pd.DataFrame) -> pd.DataFrame:
        """
        Run conflict detection and save its matrix and code list for the graph step.

        The NLI results are pre-processed with
        ``process_nli_results_for_conflict_detection`` and passed to
        ``detect_and_resolve_conflicts_advanced``. The resulting relationship
        matrix is saved as ``final_relationship_matrix.npy`` and the code list
        as ``unique_codes.parquet`` (single ``code`` column), both in the
        conflict directory.

        Args:
            corpus_df: Corpus DataFrame
            nli_results_df: NLI results DataFrame

        Returns:
            The input ``corpus_df``, unchanged. The conflict-detection output is
            handed to the next step through the two files described above.
        """
        print("⚔️ Running Conflict Detection...")

        # Process NLI results for advanced conflict detection
        processed_nli_results = process_nli_results_for_conflict_detection(nli_results_df.to_dict("records"))

        # The resolver returns a 3-tuple. Unpack it straight away; the previous
        # code logged len() of the tuple itself, which always reported "3 records".
        final_matrix, final_codes, _code_to_datapoints = detect_and_resolve_conflicts_advanced(
            nli_results=processed_nli_results,
            output_dir=self.conflict_dir,
            corpus_df=corpus_df,
        )

        # Save the results to files for topological graph construction
        matrix_path = os.path.join(self.conflict_dir, "final_relationship_matrix.npy")
        codes_path = os.path.join(self.conflict_dir, "unique_codes.parquet")
        np.save(matrix_path, final_matrix)
        pd.DataFrame({"code": final_codes}).to_parquet(codes_path, index=False)

        print(f"✅ Conflict detection completed: {len(final_codes)} unique codes")
        return corpus_df  # Return the original corpus_df since topological graph reads from files
    
    async def run_topological_graph_construction(self, corpus_df: pd.DataFrame) -> pd.DataFrame:
        """
        Build the enhanced topological graph from the files left by earlier steps.

        The builder is pointed at ``build_corpus/corpus.parquet`` below the
        temp-files directory and at the relationship matrix and code list in
        the conflict directory; its output goes to the topological directory.
        If the corpus file does not exist yet, ``corpus_df`` is written there
        first (creating the ``build_corpus`` folder when needed). An existing
        corpus file is left untouched and used as-is.

        Args:
            corpus_df: Corpus DataFrame

        Returns:
            The input ``corpus_df``, unchanged. The graph itself is produced by
            the builder from the files above.
        """
        print("🕸️ Running Topological Graph Construction...")

        # Use the existing topological graph builder with correct parameters
        corpus_path = os.path.join(self.temp_files_dir, "build_corpus", "corpus.parquet")
        matrix_path = os.path.join(self.conflict_dir, "final_relationship_matrix.npy")
        codes_path = os.path.join(self.conflict_dir, "unique_codes.parquet")

        # Save corpus if not already saved. The "build_corpus" folder is not created
        # by __init__, so make sure it exists - otherwise to_parquet fails on a
        # fresh temp-files directory.
        if not os.path.exists(corpus_path):
            os.makedirs(os.path.dirname(corpus_path), exist_ok=True)
            corpus_df.to_parquet(corpus_path, index=False)

        # NOTE(review): min_frequency is hard-coded to 5 here while self.min_frequency
        # (default 2) is never consulted - confirm which value is intended.
        build_enhanced_topological_graph(
            corpus_path=corpus_path,
            matrix_path=matrix_path,
            codes_path=codes_path,
            output_dir=self.topological_dir,
            min_frequency=5
        )

        # The builder processes the files, we return the original corpus
        final_corpus_df = corpus_df

        print(f"✅ Topological graph construction completed: {len(final_corpus_df)} records")
        return final_corpus_df
    
    async def run_pipeline(self, corpus_df: pd.DataFrame) -> pd.DataFrame:
        """
        Run all seven pipeline stages in sequence, feeding each stage's output to the next.

        Stage order: high-level code generation (with embedding reuse),
        embedding of the new codes, clustering, cosine similarity, NLI
        classification, conflict detection, topological graph construction.

        Args:
            corpus_df: Input corpus DataFrame

        Returns:
            Final processed corpus DataFrame (the value returned by the last stage)
        """
        print("🚀 Starting Schema Induction Pipeline...")

        # Stage 1: high-level codes, plus the embeddings available so far.
        stage_one = await self.run_high_level_code_generation_with_embeddings(corpus_df)
        enriched_df, base_path, base_vectors, base_codes = stage_one

        # Stage 2: embed only the codes that have no embedding yet.
        stage_two = await self.run_embeddings_for_new_codes(
            enriched_df, base_path, base_vectors, base_codes
        )
        _merged_path, merged_vectors, merged_codes = stage_two

        # Stage 3: cluster the merged embeddings.
        stage_three = await self.run_clustering_with_combined_embeddings(merged_vectors, merged_codes)
        assignments_df, clustering_info = stage_three

        # Stage 4: cosine similarity.
        pairs = await self.run_cosine_similarity(
            assignments_df, merged_vectors, merged_codes, enriched_df, clustering_info
        )

        # Stage 5: NLI classification of the similarity pairs.
        nli_df = await self.run_nli_classification(pairs)

        # Stage 6: conflict detection.
        after_conflicts_df = await self.run_conflict_detection(enriched_df, nli_df)

        # Stage 7: topological graph construction.
        result_df = await self.run_topological_graph_construction(after_conflicts_df)

        print("✅ Schema Induction Pipeline completed successfully!")
        return result_df


async def main():
    """Command-line entry point: load a corpus parquet, run the pipeline, save the result.

    Reads ``--input`` (required), runs ``SchemaInductionPipeline`` with
    ``--output`` (required) as its temp-files directory, and writes the final
    corpus to ``final_corpus_optimized.parquet`` inside ``--output``.
    ``--iteration``, ``--model_url`` and ``--model_name`` are optional.
    """
    cli = argparse.ArgumentParser(description='Optimized Schema Induction Pipeline')
    cli.add_argument('--iteration', type=int, default=1, help='Iteration number')
    # Required string options.
    for flag, description in (
        ('--input', 'Input corpus parquet file'),
        ('--output', 'Output directory'),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)
    # Optional string options.
    for flag, description in (
        ('--model_url', 'vLLM server URL'),
        ('--model_name', 'Model name'),
    ):
        cli.add_argument(flag, type=str, help=description)

    args = cli.parse_args()

    print(f"📂 Loading input corpus from {args.input}")
    corpus_df = pd.read_parquet(args.input)
    print(f"✅ Loaded corpus with {len(corpus_df)} records")

    pipeline_settings = {
        "iteration_number": args.iteration,
        "temp_files_dir": args.output,
        "model_url": args.model_url,
        "model_name": args.model_name,
    }
    pipeline = SchemaInductionPipeline(**pipeline_settings)

    final_corpus_df = await pipeline.run_pipeline(corpus_df)

    final_output_path = os.path.join(args.output, "final_corpus_optimized.parquet")
    final_corpus_df.to_parquet(final_output_path, index=False)
    print(f"✅ Saved final corpus to {final_output_path}")


if __name__ == "__main__":
    asyncio.run(main())