#!/usr/bin/env python3
"""
Optimized Cosine Similarity Module

This module optimizes cosine similarity computation by:
1. Processing data directly without file I/O
2. Implementing top-k filtering per cluster
3. Supporting ratio-based filtering for smaller clusters
4. Providing direct data input/output
"""

import os
import numpy as np
import pandas as pd
from typing import List, Dict, Any
from tqdm import tqdm

class CosineSimilarity:
    """Per-cluster cosine similarity with range filtering and truncation.

    For every cluster the class builds the pairwise cosine-similarity matrix,
    keeps the pairs whose similarity lies in
    ``[similarity_threshold, max_similarity_threshold]``, de-duplicates them on
    the lowercased code text, and returns the highest-scoring
    ``PAIR_KEEP_RATIO`` share of what is left.

    Note: ``top_k`` and ``use_ratio_filtering`` are stored on the instance but
    are not read by any method of this class, and ``calculate_cluster_top_k``
    is not called by the similarity pipeline below.
    """

    # Share of the in-range pairs of a cluster that is returned.
    PAIR_KEEP_RATIO = 0.1
    # Share of the cluster size used by calculate_cluster_top_k.
    CLUSTER_TOP_K_RATIO = 0.3
    # Jaccard word overlap at or above which two codes count as duplicates.
    DUPLICATE_WORD_OVERLAP = 0.85

    def __init__(self, top_k: int = 100, similarity_threshold: float = 0.50,
                 max_similarity_threshold: float = 0.90, use_ratio_filtering: bool = True):
        """Store the filtering configuration.

        Args:
            top_k: Stored only; not consulted by the methods of this class.
            similarity_threshold: Inclusive lower bound for a kept pair.
            max_similarity_threshold: Inclusive upper bound for a kept pair.
            use_ratio_filtering: Stored only; not consulted by the methods of
                this class.
        """
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.max_similarity_threshold = max_similarity_threshold
        self.use_ratio_filtering = use_ratio_filtering

    def normalize_code(self, code: str) -> str:
        """Return ``code`` lowercased, without punctuation, single-spaced."""
        import re

        code = code.lower().strip()
        # Drop everything that is neither a word character nor whitespace,
        # then collapse whitespace runs that the removal may have created.
        code = re.sub(r'[^\w\s]', '', code)
        return re.sub(r'\s+', ' ', code).strip()

    def are_duplicate_codes(self, code_a: str, code_b: str) -> bool:
        """Return True when two codes are essentially the same text.

        Codes match when they are identical, identical after
        ``normalize_code``, differ only by one trailing ``s``, or share at
        least ``DUPLICATE_WORD_OVERLAP`` of their combined word set.
        """
        if code_a == code_b:
            return True

        norm_a = self.normalize_code(code_a)
        norm_b = self.normalize_code(code_b)
        if norm_a == norm_b:
            return True

        # Naive plural/singular check: one side is the other plus an 's'.
        if norm_a.endswith('s') and norm_a[:-1] == norm_b:
            return True
        if norm_b.endswith('s') and norm_b[:-1] == norm_a:
            return True

        words_a = set(norm_a.split())
        words_b = set(norm_b.split())
        if words_a and words_b:
            overlap = len(words_a & words_b) / len(words_a | words_b)
            if overlap >= self.DUPLICATE_WORD_OVERLAP:
                return True

        return False

    def calculate_cluster_top_k(self, cluster_size: int) -> int:
        """Return ``CLUSTER_TOP_K_RATIO`` of ``cluster_size``, at least 1."""
        return max(1, int(cluster_size * self.CLUSTER_TOP_K_RATIO))

    def compute_cluster_similarity(self, cluster_embeddings: np.ndarray,
                                 cluster_codes: List[str], cluster_id: int) -> List[Dict]:
        """Return the strongest in-range similarity pairs of one cluster.

        Args:
            cluster_embeddings: 2-D array with one embedding row per code.
            cluster_codes: Code strings, aligned with the embedding rows.
            cluster_id: Identifier copied into every returned pair.

        Returns:
            Pair dicts sorted by descending similarity. Each holds the
            lowercased/stripped codes (``code_a``, ``code_b``), the untouched
            originals (``code_a_original``, ``code_b_original``), the
            ``similarity`` as a Python float and the row indices
            (``index_a`` < ``index_b``). Only
            ``int(PAIR_KEEP_RATIO * number_of_in_range_pairs)`` pairs are
            returned, so a cluster with fewer than 10 in-range pairs yields
            an empty list under the default ratio.
        """
        cluster_size = len(cluster_codes)
        if cluster_size < 2:
            return []

        codes_lower = [code.lower().strip() for code in cluster_codes]

        vectors = np.asarray(cluster_embeddings)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # An all-zero embedding has no direction. Dividing by 1 instead of 0
        # keeps the row at zero (similarity 0 to everything) and avoids the
        # NaN values and RuntimeWarning a division by zero would produce.
        norms[norms == 0] = 1.0
        unit_vectors = vectors / norms
        similarity_matrix = np.dot(unit_vectors, unit_vectors.T)

        # Strict upper triangle in row-major order: every unordered pair
        # (i < j) exactly once, self-similarity excluded.
        rows, cols = np.triu_indices(cluster_size, k=1)
        values = similarity_matrix[rows, cols]
        in_range = (values >= self.similarity_threshold) & (values <= self.max_similarity_threshold)

        pairs_in_range = []
        seen_pairs = set()
        candidates = zip(rows[in_range].tolist(), cols[in_range].tolist(),
                         values[in_range].tolist())
        for i, j, similarity in candidates:
            # Codes that differ only by case or padding collapse to the same
            # key; only the first pair seen for a key is kept.
            pair_key = tuple(sorted([codes_lower[i], codes_lower[j]]))
            if pair_key in seen_pairs:
                continue
            seen_pairs.add(pair_key)

            pairs_in_range.append({
                'cluster_id': cluster_id,
                'code_a': codes_lower[i],
                'code_b': codes_lower[j],
                'code_a_original': cluster_codes[i],
                'code_b_original': cluster_codes[j],
                'similarity': similarity,
                'index_a': i,
                'index_b': j
            })

        actual_limit = int(self.PAIR_KEEP_RATIO * len(pairs_in_range))

        # list.sort is stable, so equally similar pairs keep discovery order.
        pairs_in_range.sort(key=lambda pair: pair['similarity'], reverse=True)
        return pairs_in_range[:actual_limit]

    def _attach_datapoints(self, pairs: List[Dict], corpus_df: pd.DataFrame = None) -> None:
        """Add ``datapoints_a`` / ``datapoints_b`` to every pair, in place.

        The corpus is matched on the ORIGINAL code text: ``corpus_df['tag']``
        is compared with ``code_a_original`` / ``code_b_original``. Looking up
        the lowercased ``code_a`` / ``code_b`` instead would miss every tag
        that contains uppercase letters or surrounding whitespace. Codes
        without a matching row, or a missing corpus, give an empty list.
        """
        datapoints = {}
        if corpus_df is not None and pairs:
            tags = corpus_df['tag']
            # Only codes that made it into a returned pair need a lookup.
            wanted = {pair[key] for pair in pairs
                      for key in ('code_a_original', 'code_b_original')}
            for code in wanted:
                texts = corpus_df.loc[tags == code, 'chunk_text']
                if not texts.empty:
                    datapoints[code] = texts.tolist()

        for pair in pairs:
            pair['datapoints_a'] = datapoints.get(pair['code_a_original'], [])
            pair['datapoints_b'] = datapoints.get(pair['code_b_original'], [])

    def _print_summary(self, all_similarities: List[Dict], cluster_count: int) -> None:
        """Print pair count, similarity statistics and a duplicate report."""
        print(f"✅ Computed {len(all_similarities)} similarity pairs across {cluster_count} clusters")

        if not all_similarities:
            return

        similarities = [pair['similarity'] for pair in all_similarities]
        print(f"   Similarity range: {min(similarities):.3f} - {max(similarities):.3f}")
        print(f"   Average similarity: {np.mean(similarities):.3f}")

        # Pairs are only de-duplicated within a cluster, so the same code
        # pair can still show up again in a different cluster.
        unique_pairs = set()
        repeated = []
        for pair in all_similarities:
            pair_key = tuple(sorted([pair['code_a'], pair['code_b']]))
            if pair_key in unique_pairs:
                repeated.append(pair)
            else:
                unique_pairs.add(pair_key)

        print(f"   Unique pairs: {len(unique_pairs)} (should equal total pairs: {len(all_similarities)})")
        if repeated:
            print(f"   ⚠️ Warning: Found {len(all_similarities) - len(unique_pairs)} duplicate pairs!")
            for pair in repeated:
                print(f"      Duplicate example: {pair['code_a']} <-> {pair['code_b']}")
        else:
            print(f"   ✅ No duplicate pairs detected")

        print(f"   📝 Sample pairs (first 3):")
        for i, pair in enumerate(all_similarities[:3]):
            print(f"      {i+1}. {pair['code_a']} <-> {pair['code_b']} (sim: {pair['similarity']:.3f})")

    def compute_similarities_optimized(self, cluster_results: Dict[str, Any],
                                    codes: List[str], embeddings: np.ndarray, corpus_df: pd.DataFrame = None) -> List[Dict]:
        """Compute similarity pairs for every cluster and print a summary.

        Args:
            cluster_results: Mapping with a ``'labels'`` entry holding one
                cluster label per code (list or array).
            codes: Code strings aligned with ``embeddings``.
            embeddings: 2-D array-like with one embedding row per code.
            corpus_df: Optional DataFrame with ``'tag'`` and ``'chunk_text'``
                columns; rows whose tag equals a code become that code's
                datapoints.

        Returns:
            The concatenated pair dicts of all clusters, each extended with
            ``datapoints_a`` and ``datapoints_b`` lists.
        """
        print("🔍 Computing optimized cosine similarities...")

        # Accept plain lists as well: the element-wise comparison and the
        # fancy indexing below both need real arrays.
        labels = np.asarray(cluster_results['labels'])
        embeddings = np.asarray(embeddings)
        unique_labels = np.unique(labels)

        all_similarities = []
        for cluster_id in tqdm(unique_labels, desc="Processing clusters"):
            cluster_indices = np.where(labels == cluster_id)[0]
            if len(cluster_indices) < 2:
                continue

            cluster_similarities = self.compute_cluster_similarity(
                embeddings[cluster_indices],
                [codes[i] for i in cluster_indices],
                int(cluster_id)
            )
            self._attach_datapoints(cluster_similarities, corpus_df)
            all_similarities.extend(cluster_similarities)

        self._print_summary(all_similarities, len(unique_labels))
        return all_similarities

    def validate_similarity_pairs(self, similarity_pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Check pairs for missing fields, duplicates and out-of-range scores.

        Returns:
            A dict with ``total_pairs``, ``unique_pairs``, ``duplicates``,
            ``issues`` (human-readable messages) and ``has_issues``. The same
            keys are present for empty input.
        """
        if not similarity_pairs:
            return {'total_pairs': 0, 'unique_pairs': 0, 'duplicates': 0,
                    'issues': [], 'has_issues': False}

        required_fields = ['code_a', 'code_b', 'similarity', 'cluster_id']
        issues = []
        unique_pairs = set()
        duplicate_count = 0

        for i, pair in enumerate(similarity_pairs):
            missing_fields = [field for field in required_fields if field not in pair]
            if missing_fields:
                issues.append(f"Pair {i}: Missing fields {missing_fields}")

            # Order-insensitive duplicate check on the stored code values.
            if 'code_a' in pair and 'code_b' in pair:
                pair_key = tuple(sorted([pair['code_a'], pair['code_b']]))
                if pair_key in unique_pairs:
                    duplicate_count += 1
                    issues.append(f"Pair {i}: Duplicate pair {pair['code_a']} <-> {pair['code_b']}")
                else:
                    unique_pairs.add(pair_key)

            # A missing similarity defaults to 0 here; it has already been
            # reported as a missing field above.
            similarity = pair.get('similarity', 0)
            if not (0 <= similarity <= 1):
                issues.append(f"Pair {i}: Invalid similarity {similarity}")

        return {
            'total_pairs': len(similarity_pairs),
            'unique_pairs': len(unique_pairs),
            'duplicates': duplicate_count,
            'issues': issues,
            'has_issues': len(issues) > 0
        }

    def save_similarities(self, similarities: List[Dict], output_path: str):
        """Write the pairs to ``output_path`` as a parquet file.

        Nothing is written when ``similarities`` is empty. The parent
        directory is created when the path has one.
        """
        if not similarities:
            print("⚠️ No similarities to save")
            return

        # A bare filename has an empty dirname, and os.makedirs('') raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        df = pd.DataFrame(similarities)
        df.to_parquet(output_path)
        print(f"💾 Saved {len(similarities)} similarity pairs to {output_path}")

def compute_similarities_optimized(cluster_results: Dict[str, Any], 
                                codes: List[str], embeddings: np.ndarray,
                                top_k: int = 100, similarity_threshold: float = 0.50,
                                max_similarity_threshold: float = 0.90,
                                use_ratio_filtering: bool = True, corpus_df: pd.DataFrame = None) -> List[Dict]:
    """
    Functional shortcut around ``CosineSimilarity.compute_similarities_optimized``.

    Builds a one-off ``CosineSimilarity`` from the keyword settings and runs
    it over the supplied clustering.

    Args:
        cluster_results: Clustering output; must provide a 'labels' entry
        codes: Code strings aligned with ``embeddings``
        embeddings: Embedding vectors, one row per code
        top_k: Forwarded to the ``CosineSimilarity`` constructor
        similarity_threshold: Lower similarity bound for a kept pair
        max_similarity_threshold: Upper similarity bound for a kept pair
        use_ratio_filtering: Forwarded to the ``CosineSimilarity`` constructor
        corpus_df: Optional corpus DataFrame used to attach datapoints

    Returns:
        The list of similarity pair dicts produced by the class method
    """
    settings = {
        'top_k': top_k,
        'similarity_threshold': similarity_threshold,
        'max_similarity_threshold': max_similarity_threshold,
        'use_ratio_filtering': use_ratio_filtering,
    }
    engine = CosineSimilarity(**settings)

    pairs = engine.compute_similarities_optimized(cluster_results, codes, embeddings, corpus_df)
    return pairs

# Backward compatibility function
def compute_cluster_similarities(max_workers: int = 4):
    """Backward compatibility entry point: load from files, compute, save.

    Reads ``../temp_files/embeddings.parquet`` (columns ``'embedding'`` and
    ``'text'``) and every ``../temp_files/clusters/cluster_*.parquet`` file
    (column ``'text'``) relative to this module. The cluster id is the
    integer between ``cluster_`` and the extension in the file name. The
    resulting pairs are written to
    ``../temp_files/cluster_sim/cluster_sim_optimized.parquet``.

    Codes that appear in no cluster file are left out of the computation
    instead of being silently merged into cluster 0. Every occurrence of a
    repeated code text receives the cluster label. When a text is listed in
    several cluster files, the file that sorts last by path wins.

    Args:
        max_workers: Accepted for backward compatibility; not used.

    Returns:
        The list of similarity pairs, or None when no cluster file exists.
    """
    import glob
    from collections import defaultdict

    temp_dir = os.path.join(os.path.dirname(__file__), "..", "temp_files")

    # Load embeddings
    embeddings_df = pd.read_parquet(os.path.join(temp_dir, "embeddings.parquet"))
    embeddings = np.array(embeddings_df['embedding'].tolist())
    codes = embeddings_df['text'].tolist()

    # Sorted so that overlapping cluster files resolve deterministically.
    cluster_files = sorted(glob.glob(os.path.join(temp_dir, "clusters", "cluster_*.parquet")))

    if not cluster_files:
        print("❌ No cluster files found. Please run clustering first.")
        return

    # One pass over the codes replaces a linear codes.index() scan per row
    # and also covers repeated code texts.
    positions = defaultdict(list)
    for idx, text in enumerate(codes):
        positions[text].append(idx)

    # Reconstruct cluster labels from files, tracking which rows were labelled.
    labels = np.zeros(len(codes), dtype=int)
    assigned = np.zeros(len(codes), dtype=bool)
    for cluster_file in cluster_files:
        cluster_id = int(os.path.basename(cluster_file).split('_')[1].split('.')[0])
        for text in pd.read_parquet(cluster_file)['text']:
            hits = positions.get(text)
            if hits:
                labels[hits] = cluster_id
                assigned[hits] = True

    unassigned_count = int((~assigned).sum())
    if unassigned_count:
        print(f"⚠️ Skipping {unassigned_count} codes that are not in any cluster file")

    clustered_codes = [code for code, is_assigned in zip(codes, assigned) if is_assigned]
    cluster_results = {'labels': labels[assigned]}

    # Compute similarities
    similarities = compute_similarities_optimized(cluster_results, clustered_codes, embeddings[assigned])

    # Save results
    output_path = os.path.join(temp_dir, "cluster_sim", "cluster_sim_optimized.parquet")
    CosineSimilarity().save_similarities(similarities, output_path)

    return similarities

if __name__ == "__main__":
    # Script entry point: run the file-based pipeline, which loads the stored
    # embeddings and cluster files, computes the similarities and saves them.
    compute_cluster_similarities()