#!/usr/bin/env python3
"""
Semantic Retrieval Module for Refine Iteration

This module implements semantic retrieval by:
1. Loading embeddings.parquet (which contains code embeddings)
2. Computing, for a given code, the cosine similarity with every other code in the global code set
3. Using the stored embeddings from embeddings.parquet (the first embedding per unique code)
4. Applying the minimum similarity threshold (sim_threshold=0.6); max_sim_threshold
   (default 100) is accepted and reported but not currently applied during retrieval
5. Selecting the top k codes (configurable, clamped to the range 1-10)
6. Including global frequency information from the topological graph mapping
7. Calculating a combined score: 0.7 * similarity + 0.3 * normalized_global_frequency
   (both weights are configurable)
8. Sorting and returning codes by highest combined score
9. Storing results in memory for later reference
"""

import logging
import os
import sys
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

# Add the directory three levels above this file to the import path.
# NOTE(review): every import in this module appears above this line, so nothing
# here uses the added path; it may be vestigial or relied upon by modules that
# import this one — confirm before removing.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..'))

class SemanticRetriever:
    """Semantic retrieval system for refine iteration using code embeddings.

    On construction, loads one embedding per unique code (``tag``) and the
    per-code global frequencies, then answers "which codes are most similar to
    code X?" queries ranked by a weighted sum of cosine similarity and
    normalized global frequency. Query results are cached in
    ``self.semantic_results``.
    """

    def __init__(self, embeddings_path: str, mapping_dir: str,
                 max_sim_threshold: int = 100, sim_threshold: float = 0.6,
                 top_k: int = 5, similarity_weight: float = 0.7, frequency_weight: float = 0.3):
        """
        Initialize the semantic retriever and eagerly load all data.

        Args:
            embeddings_path: Path to embeddings.parquet; must contain ``tag`` and
                ``embedding`` columns.
            mapping_dir: Path to the datapoint_code_mapping directory from the
                topological graph; must contain ``code_frequencies.parquet``.
            max_sim_threshold: Stored and reported by ``get_data_summary`` but not
                currently applied during retrieval (default: 100).
            sim_threshold: Minimum cosine similarity a candidate must reach
                (default: 0.6).
            top_k: Number of top codes to return; clamped to [1, 10] (default: 5).
            similarity_weight: Weight for similarity score (default: 0.7).
            frequency_weight: Weight for normalized frequency score (default: 0.3).

        Raises:
            FileNotFoundError: If either parquet file does not exist.
        """
        self.embeddings_path = embeddings_path
        self.mapping_dir = mapping_dir
        self.max_sim_threshold = max_sim_threshold
        self.sim_threshold = sim_threshold

        # Clamp top_k into the supported range [1, 10].
        self.top_k = max(1, min(10, top_k))

        self.similarity_weight = similarity_weight
        self.frequency_weight = frequency_weight

        # Populated by _load_data / _prepare_code_embeddings.
        self.embeddings_df = None
        self.code_embeddings = None
        self.unique_codes = None
        self.code_frequencies = None
        self.max_global_frequency = None

        # Cache of retrieval results: code_name -> top_k results
        self.semantic_results = {}

        self._load_data()

    def _load_data(self):
        """Load the embeddings table and per-code frequency information.

        Raises:
            FileNotFoundError: If the embeddings file or
                ``<mapping_dir>/code_frequencies.parquet`` does not exist.
        """
        if not os.path.exists(self.embeddings_path):
            raise FileNotFoundError(f"Embeddings file not found: {self.embeddings_path}")

        self.embeddings_df = pd.read_parquet(self.embeddings_path)

        freq_path = os.path.join(self.mapping_dir, "code_frequencies.parquet")
        if not os.path.exists(freq_path):
            raise FileNotFoundError(f"Code frequencies file not found: {freq_path}")

        freq_df = pd.read_parquet(freq_path)
        # Column-wise zip avoids building a Series per row as iterrows() does.
        self.code_frequencies = {
            code: {
                'global_frequency': global_freq,
                'incoming_edges': incoming_edges,
                'merge_score': merge_score,
            }
            for code, global_freq, incoming_edges, merge_score in zip(
                freq_df['code'],
                freq_df['global_frequency'],
                freq_df['incoming_edges'],
                freq_df['merge_score'],
            )
        }

        # Largest global frequency, used to scale frequencies into [0, 1].
        self.max_global_frequency = freq_df['global_frequency'].max() if len(freq_df) > 0 else 1

        self._prepare_code_embeddings()

    def _prepare_code_embeddings(self):
        """Build the per-code embedding matrix and lookup structures.

        Keeps the first embedding for each unique ``tag``, stacks them into
        ``self.code_embeddings`` (one row per entry of ``self.unique_codes``),
        and precomputes unit-length rows so each query is a single product.
        """
        unique_codes_data = self.embeddings_df.groupby('tag').first().reset_index()

        self.unique_codes = unique_codes_data['tag'].tolist()
        # O(1) name -> row lookup instead of scanning the list per query.
        self._code_to_index = {code: i for i, code in enumerate(self.unique_codes)}

        embeddings_list = [np.asarray(embedding) for embedding in unique_codes_data['embedding']]
        if embeddings_list:
            self.code_embeddings = np.array(embeddings_list)
        else:
            # Keep the matrix 2-D so row norms and the summary work when empty.
            self.code_embeddings = np.empty((0, 0))

        # Normalize once here rather than on every query. Rows with zero or
        # non-finite norm have no cosine similarity; flag them and divide by 1
        # instead of producing NaN (and RuntimeWarnings) via 0/0.
        norms = np.linalg.norm(self.code_embeddings, axis=1)
        self._comparable = np.isfinite(norms) & (norms > 0)
        safe_norms = np.where(self._comparable, norms, 1.0)
        self._normalized_embeddings = self.code_embeddings / safe_norms[:, np.newaxis]

    def _normalize_global_frequency(self, global_frequency: int) -> float:
        """
        Scale a global frequency by the maximum global frequency.

        Args:
            global_frequency: Raw global frequency count

        Returns:
            Normalized frequency score clamped to [0.0, 1.0]; 0.0 when the
            maximum frequency is 0.
        """
        if self.max_global_frequency == 0:
            return 0.0

        normalized_freq = global_frequency / self.max_global_frequency
        return min(1.0, max(0.0, normalized_freq))

    def _calculate_combined_score(self, similarity: float, global_frequency: int) -> float:
        """
        Calculate similarity_weight * similarity + frequency_weight * normalized_frequency.

        Args:
            similarity: Cosine similarity score
            global_frequency: Global frequency count

        Returns:
            Combined score
        """
        normalized_freq = self._normalize_global_frequency(global_frequency)
        return (self.similarity_weight * similarity) + (self.frequency_weight * normalized_freq)

    def get_code_embedding(self, code_name: str) -> Optional[np.ndarray]:
        """
        Get the stored embedding for a specific code.

        Args:
            code_name: The name of the code

        Returns:
            The embedding row, or None if the code is unknown
        """
        code_index = self._code_to_index.get(code_name)
        if code_index is None:
            return None
        return self.code_embeddings[code_index]

    def compute_similarities(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Score every other code against ``source_code``.

        Candidates below ``sim_threshold`` and codes whose embedding has zero or
        non-finite norm are skipped.

        Args:
            source_code: The name of the source code to compare against

        Returns:
            List of dicts (code, similarity, index, global_frequency,
            normalized_frequency, combined_score) sorted by combined_score,
            highest first. Empty if the source code is unknown or its embedding
            has zero/non-finite norm.
        """
        source_index = self._code_to_index.get(source_code)
        if source_index is None or not self._comparable[source_index]:
            return []

        similarities = self._normalized_embeddings @ self._normalized_embeddings[source_index]

        default_freq_info = {
            'global_frequency': 0,
            'incoming_edges': 0,
            'merge_score': 0.0
        }

        results = []
        for i, code in enumerate(self.unique_codes):
            # Skip the source itself and codes without a usable direction.
            if i == source_index or not self._comparable[i]:
                continue

            similarity_score = float(similarities[i])
            if not similarity_score >= self.sim_threshold:
                continue

            freq_info = self.code_frequencies.get(code, default_freq_info)
            global_freq = freq_info['global_frequency']

            results.append({
                'code': code,
                'similarity': similarity_score,
                'index': i,
                'global_frequency': global_freq,
                'normalized_frequency': self._normalize_global_frequency(global_freq),
                'combined_score': self._calculate_combined_score(similarity_score, global_freq)
            })

        results.sort(key=lambda x: x['combined_score'], reverse=True)
        return results

    def select_top_codes(self, similarity_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Keep the first ``top_k`` results and drop internal fields.

        Args:
            similarity_results: Results from ``compute_similarities`` (already
                sorted by combined score)

        Returns:
            Up to ``top_k`` dicts with code, similarity, global_frequency,
            normalized_frequency and combined_score.
        """
        keys = ('code', 'similarity', 'global_frequency', 'normalized_frequency', 'combined_score')
        return [{key: result[key] for key in keys} for result in similarity_results[:self.top_k]]

    def retrieve_semantic_codes(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Retrieve and cache the top semantic matches for a code.

        Args:
            source_code: The code name to find similar codes for

        Returns:
            List of top k codes sorted by combined score (highest first).
            Format: code, similarity, global_frequency, normalized_frequency,
            combined_score. Returns [] if retrieval fails; the error is logged.
        """
        try:
            similarity_results = self.compute_similarities(source_code)
            top_codes = self.select_top_codes(similarity_results)
            self.semantic_results[source_code] = top_codes
        except Exception:
            # Best-effort contract: callers still get [], but the failure is
            # recorded instead of being silently discarded.
            logging.getLogger(__name__).exception(
                "Semantic retrieval failed for code %r", source_code
            )
            return []
        return top_codes

    def process_all_codes(self) -> Dict[str, List[Dict[str, Any]]]:
        """
        Retrieve semantic matches for every known code.

        Returns:
            Dictionary mapping code_name to its top semantic codes
        """
        return {code: self.retrieve_semantic_codes(code) for code in self.unique_codes}

    def get_stored_results(self, code_name: Optional[str] = None) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get cached semantic retrieval results.

        Args:
            code_name: Specific code name to retrieve, or None for all results

        Returns:
            ``{code_name: results}`` (``[]`` if not cached) for a specific code,
            otherwise a shallow copy of all cached results.
        """
        if code_name is not None:
            return {code_name: self.semantic_results.get(code_name, [])}
        return self.semantic_results.copy()

    def get_data_summary(self) -> Dict[str, Any]:
        """Get a summary of loaded data and configuration ({} if nothing loaded)."""
        if self.embeddings_df is None:
            return {}

        embeddings = self.code_embeddings
        embedding_dimension = (
            embeddings.shape[1] if embeddings is not None and embeddings.ndim == 2 else 0
        )

        return {
            'total_records': len(self.embeddings_df),
            'unique_codes': len(self.unique_codes),
            'max_global_frequency': self.max_global_frequency,
            'embedding_dimension': embedding_dimension,
            'similarity_threshold': self.sim_threshold,
            'max_similarity_threshold': self.max_sim_threshold,
            'top_k': self.top_k,
            'similarity_weight': self.similarity_weight,
            'frequency_weight': self.frequency_weight,
            'sorting_criteria': 'COMBINED_SCORE',
            'stored_results_count': len(self.semantic_results)
        }


# Example usage and testing
def main():
    """Example usage of semantic retrieval.

    Loads data from example relative paths, retrieves codes similar to one
    example code, prints them, then prints the cached result count and a
    data summary.
    """
    # Example paths (adjust as needed)
    embeddings_path = "../temp_files/embeddings.parquet"
    mapping_dir = "../temp_files/topologically_sorted_graph/datapoint_code_mapping"

    retriever = SemanticRetriever(
        embeddings_path=embeddings_path,
        mapping_dir=mapping_dir,
        max_sim_threshold=100,
        sim_threshold=0.6,
        top_k=5,  # Default to 5, can be 1-10
        similarity_weight=0.7,
        frequency_weight=0.3
    )

    test_code = "economic interdependence shapes international conflict dynamics"

    results = retriever.retrieve_semantic_codes(test_code)
    for rank, result in enumerate(results, start=1):
        print(
            f"{rank}. {result['code']} "
            f"(combined={result['combined_score']:.3f}, "
            f"similarity={result['similarity']:.3f}, "
            f"global_frequency={result['global_frequency']})"
        )

    # get_stored_results returns {code_name: [results]}; count the cached
    # results for this code, not the dict's single key.
    stored = retriever.get_stored_results(test_code)
    print(f"\n💾 Stored results: {len(stored[test_code])} codes")

    summary = retriever.get_data_summary()
    print(f"\n📈 Data Summary: {summary}")


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
