#!/usr/bin/env python3
"""
Final Data Retrieval Script

This script implements data retrieval that works with the final iteration
from the multi-iteration schema induction pipeline. It uses the topological
graph's final datachunk-to-code mappings for retrieval.

Flow:
1. Each iteration: LLM code selector returns 20 codes per datachunk (refined from previous iteration)
2. Store embeddings for these codes
3. Run schema induction pipeline including topological graph
4. Save topological graph output as reference for next iteration
5. After multi-iteration complete: Use final datachunk-to-code mappings for data retrieval

Key features:
- Automatically detects the latest iteration
- Uses the final topological graph's datachunk-to-code mappings
- Handles case-insensitive matching between codes and embeddings
- Calculates similarity between codes and question
- Uses similarity scores to determine which datachunks are useful for QA
"""

import os
import numpy as np
import pandas as pd
import networkx as nx
from typing import List, Dict, Any, Tuple, Optional
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
import aiohttp
import json
from dotenv import load_dotenv
import glob

load_dotenv()

class GraphBasedDataRetriever:
    """Final data retrieval that works with the latest iteration's topological graph output"""
    
    def __init__(self, base_temp_dir: str, alpha: float = 0.85, max_iterations: int = 5):
        """
        Initialize the final data retriever.

        Construction is not lazy: after storing the configuration it calls
        ``_detect_latest_iteration()`` and then ``_load_data()``, so it
        touches the filesystem before returning.

        Args:
            base_temp_dir: Base directory containing iteration folders
            alpha: Damping factor for relevance propagation (default: 0.85)
            max_iterations: Maximum iterations for relevance propagation (default: 5)

        Raises:
            FileNotFoundError: Propagated from ``_detect_latest_iteration`` or
                ``_load_data`` when a required file or directory is missing.
        """
        self.base_temp_dir = base_temp_dir
        self.alpha = alpha
        self.max_iterations = max_iterations

        # Data paths (filled in by _detect_latest_iteration)
        self.latest_iteration_dir = None
        self.embeddings_path = None
        self.topological_graph_dir = None

        # Loaded data (filled in by _load_data)
        self.embeddings_df = None
        self.graph = None
        self.sorted_nodes = None
        self.hierarchy = None
        self.code_to_datachunks = None  # Final datachunk-to-code mapping
        self.code_to_embedding = None   # Case-insensitive code to embedding mapping
        # Declared here so the attribute always exists; construct_graph_matrix
        # reads it and treats None as "no relationship data available".
        self.relationship_summary = None

        # Relationship weights for relevance propagation
        self.weights = {
            'upward': 0.5,    # child to parent
            'downward': 0.5,  # parent to child
            'sibling': 0.0    # sibling to sibling (ignored)
        }

        self._detect_latest_iteration()
        self._load_data()
    
    def _detect_latest_iteration(self):
        """Detect the latest iteration directory and set paths.

        Scans ``base_temp_dir`` for directories named ``iteration_*`` and picks
        the one with the highest integer after the first underscore. Names
        whose suffix is not an integer rank as iteration 0. On success it sets
        ``latest_iteration_dir``, ``embeddings_path`` and
        ``topological_graph_dir``.

        Raises:
            FileNotFoundError: If no iteration directory exists, or the chosen
                iteration lacks ``embeddings/embeddings.parquet`` or the
                ``topologically_sorted_graph`` directory.
        """
        print("🔍 Detecting latest iteration for final data retrieval...")

        # Escape the base path so characters like '[' or '*' in it are matched
        # literally; only the trailing "iteration_*" acts as a wildcard.
        iteration_pattern = os.path.join(glob.escape(self.base_temp_dir), "iteration_*")
        # Keep directories only: a stray file such as "iteration_9" must not be
        # selected as the latest iteration.
        iteration_dirs = [
            path for path in glob.glob(iteration_pattern) if os.path.isdir(path)
        ]

        if not iteration_dirs:
            raise FileNotFoundError(f"No iteration directories found in {self.base_temp_dir}")

        def extract_iteration_number(path):
            """Return the integer in an "iteration_XX" name, or 0 if unparsable."""
            dirname = os.path.basename(path)
            try:
                return int(dirname.split('_')[1])
            except (IndexError, ValueError):
                return 0

        # max() returns the first maximal entry, which is the same element a
        # stable descending sort would put at index 0.
        self.latest_iteration_dir = max(iteration_dirs, key=extract_iteration_number)
        latest_iteration_num = extract_iteration_number(self.latest_iteration_dir)

        print(f"   ✅ Latest iteration detected: {latest_iteration_num}")
        print(f"   📁 Using directory: {self.latest_iteration_dir}")

        # Set paths for the latest iteration
        self.embeddings_path = os.path.join(self.latest_iteration_dir, "embeddings", "embeddings.parquet")
        self.topological_graph_dir = os.path.join(self.latest_iteration_dir, "topologically_sorted_graph")

        # Verify required files exist
        if not os.path.exists(self.embeddings_path):
            raise FileNotFoundError(f"Required file not found: {self.embeddings_path}")

        if not os.path.exists(self.topological_graph_dir):
            raise FileNotFoundError(f"Required directory not found: {self.topological_graph_dir}")

        print(f"   ✅ All required files found for iteration {latest_iteration_num}")
    
    def _load_data(self):
        """Read the latest iteration's artifacts into memory.

        Loads, in order:
          * the embeddings parquet (mandatory), indexed by lower-cased ``tag``;
          * ``datapoint_code_mapping/code_to_datapoints.parquet`` (optional),
            grouped into ``code -> [datapoint, ...]``; its keys become
            ``sorted_nodes``;
          * ``hierarchy.parquet`` (optional), grouped into ``level -> [node, ...]``;
          * ``conflict_detection/relationship_summary.parquet`` (optional).

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
        """
        print("📊 Loading data from latest iteration for final data retrieval...")

        # --- Embeddings: the only artifact whose absence is fatal ---
        if not os.path.exists(self.embeddings_path):
            raise FileNotFoundError(f"Embeddings file not found: {self.embeddings_path}")

        self.embeddings_df = pd.read_parquet(self.embeddings_path)
        print(f"   ✅ Loaded embeddings: {len(self.embeddings_df)} records")

        # Lower-cased tag -> full row; a later duplicate tag replaces an earlier one
        self.embeddings_df['tag_lower'] = self.embeddings_df['tag'].str.lower()
        self.code_to_embedding = {
            row['tag_lower']: row for _, row in self.embeddings_df.iterrows()
        }
        print(f"   ✅ Created case-insensitive embedding mapping: {len(self.code_to_embedding)} codes")

        # --- Final code -> datachunk mapping from the topological graph output ---
        mapping_path = os.path.join(self.topological_graph_dir, "datapoint_code_mapping", "code_to_datapoints.parquet")
        if not os.path.exists(mapping_path):
            print(f"   ⚠️ Topological graph mapping file not found: {mapping_path}")
            self.code_to_datachunks = {}
            self.sorted_nodes = []
        else:
            mapping_df = pd.read_parquet(mapping_path)
            self.code_to_datachunks = {}

            for _, row in mapping_df.iterrows():
                # 'datapoint' holds the datachunk content itself
                code, datachunk = row['code'], row['datapoint']
                self.code_to_datachunks.setdefault(code, []).append(datachunk)

            print(f"   ✅ Loaded final datachunk-to-code mapping: {len(self.code_to_datachunks)} codes")
            print(f"       Total datachunk references: {sum(len(chunks) for chunks in self.code_to_datachunks.values())}")

            # The mapping's codes, in first-seen order, serve as the node list
            self.sorted_nodes = list(self.code_to_datachunks.keys())
            print(f"   ✅ Using {len(self.sorted_nodes)} codes from final topological graph mapping")

        # --- Hierarchy: level -> codes (left untouched when the file is absent) ---
        hierarchy_path = os.path.join(self.topological_graph_dir, "hierarchy.parquet")
        if os.path.exists(hierarchy_path):
            hierarchy_df = pd.read_parquet(hierarchy_path)
            self.hierarchy = {}
            for _, row in hierarchy_df.iterrows():
                # The code lives in the 'node' column, not 'code'
                level, code = row['level'], row['node']
                self.hierarchy.setdefault(level, []).append(code)
            print(f"   ✅ Loaded hierarchy with {len(self.hierarchy)} levels")

        # --- Relationship summary produced by conflict detection ---
        relationship_summary_path = os.path.join(
            self.latest_iteration_dir, "conflict_detection", "relationship_summary.parquet"
        )
        if not os.path.exists(relationship_summary_path):
            print("   ⚠️ Relationship summary not found")
            self.relationship_summary = None
        else:
            self.relationship_summary = pd.read_parquet(relationship_summary_path)
            print(f"   ✅ Loaded relationship summary: {len(self.relationship_summary)} relationships")
    
    async def embed_question(self, question: str) -> np.ndarray:
        """Return the embedding vector for ``question``.

        The question is wrapped in a one-row corpus frame (``tag`` /
        ``chunk_text`` columns) and handed to ``build_embeddings_parquet``;
        the first entry of the second element of its 4-tuple result is
        returned.

        Args:
            question: Free-text question to embed.

        Returns:
            The first embedding produced for the one-row corpus.
        """
        print(f"🔍 Embedding question: '{question[:50]}...'")

        # Resolved at call time, relative to this module's package.
        # Presumably the same embedding routine used for the codes — verify.
        from .initial_iteration.embeddings import build_embeddings_parquet

        single_row_corpus = pd.DataFrame({'tag': ['question'], 'chunk_text': [question]})

        # output_parquet=None: presumably skips writing a file — confirm in the helper
        _, embeddings, _, _ = await build_embeddings_parquet(
            corpus_df=single_row_corpus,
            output_parquet=None,
        )

        question_embedding = embeddings[0]
        print(f"   ✅ Question embedded successfully")

        return question_embedding
    
    def compute_question_code_similarities(self, question_embedding: np.ndarray) -> Dict[str, float]:
        """Compute similarity scores between question and codes using case-insensitive lookup"""
        print("🔍 Computing question-code similarities with final codes...")
        
        # Compute similarities for codes from final topological graph mapping
        similarities = {}
        question_embedding_reshaped = question_embedding.reshape(1, -1)

        for code in self.sorted_nodes:
            # Look up embedding with case-insensitive matching
            code_lower = code.lower()
            if code_lower in self.code_to_embedding:
                embedding_row = self.code_to_embedding[code_lower]
                embedding = embedding_row['embedding']
                if isinstance(embedding, str):
                    # Parse if stored as string
                    embedding = np.array(json.loads(embedding))
                
                code_embedding_reshaped = embedding.reshape(1, -1)
                similarity = cosine_similarity(question_embedding_reshaped, code_embedding_reshaped)[0][0]
                similarities[code] = similarity
            else:
                print(f"   ⚠️ No embedding found for code: {code[:60]}...")
                similarities[code] = 0.0
        
        # Sort and show top similarities
        sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        print(f"   📊 Top 5 code similarities:")
        for i, (code, sim) in enumerate(sorted_similarities[:5]):
            print(f"      {i+1}. {code[:60]}... (sim: {sim:.3f})")
        
        return similarities
    
    def construct_graph_matrix(self, question_similarities: Dict[str, float]) -> Tuple[np.ndarray, List[str], Dict[str, float]]:
        """Construct graph matrix and base relevance from existing relationship data"""
        print("🔗 Constructing graph matrix from final codes...")
        
        # Get all unique codes from the final topological graph mapping
        if self.sorted_nodes:
            unique_codes = self.sorted_nodes
        else:
            unique_codes = list(question_similarities.keys())
        
        n_codes = len(unique_codes)
        print(f"   📊 Building {n_codes}x{n_codes} graph matrix")
        
        # Create code to index mapping
        code_to_idx = {code: i for i, code in enumerate(unique_codes)}
        
        # Initialize graph matrix
        graph_matrix = np.zeros((n_codes, n_codes))
        
        # Initialize base relevance scores
        base_relevance = np.zeros(n_codes)
        for i, code in enumerate(unique_codes):
            base_relevance[i] = question_similarities.get(code, 0.0)
        
        # Build graph from hierarchy relationships
        if self.hierarchy:
            for level, codes in self.hierarchy.items():
                for code in codes:
                    if code in code_to_idx:
                        code_idx = code_to_idx[code]
                        
                        # Add relationships to parent level
                        if level > 0:
                            parent_level = level - 1
                            if parent_level in self.hierarchy:
                                for parent_code in self.hierarchy[parent_level]:
                                    if parent_code in code_to_idx:
                                        parent_idx = code_to_idx[parent_code]
                                        # Child to parent (upward)
                                        graph_matrix[code_idx, parent_idx] = self.weights['upward']
                                        # Parent to child (downward)
                                        graph_matrix[parent_idx, code_idx] = self.weights['downward']
        
        # Add relationships from relationship summary if available
        if self.relationship_summary is not None:
            for _, row in self.relationship_summary.iterrows():
                code_a = row.get('code_a', '')
                code_b = row.get('code_b', '')
                relationship = row.get('relationship', '')
                
                if code_a in code_to_idx and code_b in code_to_idx:
                    idx_a = code_to_idx[code_a]
                    idx_b = code_to_idx[code_b]
                    
                    if relationship == 'IMPLIES':
                        # A implies B: A -> B
                        graph_matrix[idx_a, idx_b] = 0.8
                    elif relationship == 'MUTUAL':
                        # A and B are mutual: A <-> B
                        graph_matrix[idx_a, idx_b] = 0.9
                        graph_matrix[idx_b, idx_a] = 0.9
        
        print(f"   ✅ Graph matrix constructed with {np.count_nonzero(graph_matrix)} edges")
        
        return graph_matrix, unique_codes, {code: base_relevance[i] for i, code in enumerate(unique_codes)}
    
    def propagate_relevance(self, graph_matrix: np.ndarray, base_relevance: Dict[str, float], 
                          unique_codes: List[str]) -> Dict[str, float]:
        """Propagate relevance through the graph using iterative updates"""
        print(f"🔄 Propagating relevance through graph (α={self.alpha}, max_iter={self.max_iterations})...")
        
        n_codes = len(unique_codes)
        code_to_idx = {code: i for i, code in enumerate(unique_codes)}
        
        # Initialize relevance vector
        relevance = np.zeros(n_codes)
        for i, code in enumerate(unique_codes):
            relevance[i] = base_relevance.get(code, 0.0)
        
        # Normalize graph matrix (row normalization for proper probability distribution)
        row_sums = graph_matrix.sum(axis=1)
        row_sums[row_sums == 0] = 1  # Avoid division by zero
        normalized_matrix = graph_matrix / row_sums[:, np.newaxis]
        
        # Iterative relevance propagation
        for iteration in range(self.max_iterations):
            old_relevance = relevance.copy()
            
            # Update relevance: R = α * M * R + (1-α) * B
            relevance = self.alpha * (normalized_matrix @ relevance) + (1 - self.alpha) * np.array([base_relevance.get(code, 0.0) for code in unique_codes])
            
            # Check for convergence
            if np.allclose(relevance, old_relevance, atol=1e-6):
                print(f"   ✅ Convergence reached after {iteration + 1} iterations")
                break
        else:
            print(f"   ⚠️ Maximum iterations ({self.max_iterations}) reached")
        
        # Convert back to dictionary
        final_relevance = {code: relevance[i] for i, code in enumerate(unique_codes)}
        
        # Show top relevance scores
        sorted_relevance = sorted(final_relevance.items(), key=lambda x: x[1], reverse=True)
        print(f"   📊 Top 5 relevance scores after propagation:")
        for i, (code, rel) in enumerate(sorted_relevance[:5]):
            print(f"      {i+1}. {code[:60]}... (rel: {rel:.3f})")
        
        return final_relevance
    
    def score_data_chunks(self, relevance_scores: Dict[str, float]) -> List[Tuple[str, float]]:
        """Rank datachunks by the summed relevance of the codes attached to them.

        Args:
            relevance_scores: Code -> relevance. Codes absent from
                ``code_to_datachunks`` contribute nothing.

        Returns:
            ``(datachunk, score)`` pairs ordered from highest to lowest score;
            ties keep first-seen order.
        """
        print("📊 Scoring datachunks based on code relevance...")

        totals = {}
        for code, relevance in relevance_scores.items():
            # A chunk referenced by several codes accumulates all their relevance
            for datachunk in self.code_to_datachunks.get(code, ()):
                totals[datachunk] = totals.get(datachunk, 0.0) + relevance

        # Highest score first; list.sort is stable, like sorted()
        sorted_chunks = list(totals.items())
        sorted_chunks.sort(key=lambda pair: pair[1], reverse=True)

        print(f"   ✅ Scored {len(sorted_chunks)} datachunks")
        if sorted_chunks:
            print(f"   📊 Top 3 datachunk scores:")
            for i, (chunk, score) in enumerate(sorted_chunks[:3]):
                print(f"      {i+1}. {chunk[:60]}... (score: {score:.3f})")

        return sorted_chunks
    
    def _get_datachunk_content(self, datachunk: str) -> str:
        """Return the text of ``datachunk``.

        This is an identity function: the value passed in is handed back
        unchanged, because a datachunk is already its own content here
        (datapoints are under 256 tokens, so datapoint = datachunk).
        """
        content = datachunk
        return content
    
    async def retrieve_relevant_chunks(self, question: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Run the full retrieval pipeline for one question.

        Pipeline: embed the question -> score every final code against it ->
        build the code graph -> propagate relevance -> score datachunks ->
        keep the first ``top_k`` entries of the ranked list.

        Args:
            question: Question text.
            top_k: Number of leading ranked chunks to keep (default: 10).

        Returns:
            ``(datachunk, score)`` pairs, best first.
        """
        print(f"🔍 Retrieving top-{top_k} relevant datachunks for: '{question[:50]}...'")

        # 1. Question -> embedding vector
        query_vector = await self.embed_question(question)

        # 2. Similarity of the question to each final code
        code_similarities = self.compute_question_code_similarities(query_vector)

        # 3. Code graph plus seed relevance
        adjacency, codes, seed_relevance = self.construct_graph_matrix(code_similarities)

        # 4. Spread relevance along graph edges
        propagated = self.propagate_relevance(adjacency, seed_relevance, codes)

        # 5-6. Rank datachunks and keep the head of the ranking
        top_chunks = self.score_data_chunks(propagated)[:top_k]

        print(f"   ✅ Retrieved {len(top_chunks)} relevant datachunks")

        return top_chunks
    
    def get_iteration_info(self) -> Dict[str, Any]:
        """Get information about the current iteration being used.

        Returns:
            A dict with:
              * ``iteration_number``: integer after the first underscore of the
                iteration directory name, or 0 when it cannot be parsed (the
                same fallback ``_detect_latest_iteration`` uses);
              * ``iteration_dir``, ``embeddings_path``, ``topological_graph_dir``:
                the paths in use;
              * ``total_codes``: number of rows in the embeddings frame;
              * ``total_datachunks``: number of code -> datachunk references
                (a chunk attached to several codes is counted once per code);
              * ``final_codes``: number of codes in ``sorted_nodes``.
        """
        # normpath drops a trailing separator so basename is never empty
        dirname = os.path.basename(os.path.normpath(self.latest_iteration_dir))
        try:
            iteration_number = int(dirname.split('_')[1])
        except (IndexError, ValueError):
            # e.g. "iteration_final": report 0 instead of raising
            iteration_number = 0

        return {
            'iteration_number': iteration_number,
            'iteration_dir': self.latest_iteration_dir,
            'embeddings_path': self.embeddings_path,
            'topological_graph_dir': self.topological_graph_dir,
            'total_codes': len(self.embeddings_df) if self.embeddings_df is not None else 0,
            'total_datachunks': sum(len(chunks) for chunks in self.code_to_datachunks.values()) if self.code_to_datachunks else 0,
            'final_codes': len(self.sorted_nodes) if self.sorted_nodes else 0
        }


# Example usage
async def main():
    """Demonstrate the retriever end to end.

    Builds a retriever over ``temp_files_test_real_refinement``, prints which
    iteration it picked up along with its size counters, then retrieves and
    prints the five best datachunks for a sample question.
    """
    data_retriever = GraphBasedDataRetriever(
        base_temp_dir="temp_files_test_real_refinement",
        alpha=0.85,
        max_iterations=5,
    )

    # Summary of the iteration that was auto-detected
    info = data_retriever.get_iteration_info()
    print(f"📊 Using iteration {info['iteration_number']}")
    print(f"   Total codes: {info['total_codes']}")
    print(f"   Total datachunks: {info['total_datachunks']}")
    print(f"   Final codes: {info['final_codes']}")

    # Run one sample query through the whole pipeline
    top_chunks = await data_retriever.retrieve_relevant_chunks(
        "What are the key themes in the news timeline events?",
        top_k=5,
    )

    print(f"\n🎯 Top {len(top_chunks)} relevant datachunks:")
    for i, (chunk, score) in enumerate(top_chunks):
        print(f"{i+1}. Score: {score:.3f}")
        print(f"   Content: {chunk[:100]}...")
        print()


# Script entry point: run the async example only when this file is executed
# directly, not when it is imported.
# NOTE(review): embed_question performs a package-relative import
# (`from .initial_iteration.embeddings import ...`), which Python rejects when a
# file is run as a plain script with no parent package. This presumably has to
# be launched as `python -m <package>.<module>` — confirm the intended invocation.
if __name__ == "__main__":
    asyncio.run(main())
