#!/usr/bin/env python3
"""
Datapoint Retrieval Module

This module implements fast retrieval for single datapoints by:
1. Getting the codes linked to a datapoint from the mapping
2. Randomly sampling up to `sample_size` codes (default 20; all codes if fewer)
3. For each sampled code, calling both graph and semantic retrieval
4. Running the per-code lookups in thread-pool batches for speed
5. Combining the results, removing duplicates, and ranking by score
"""

import logging
import os
import random
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Set, Tuple, Optional

import numpy as np
import pandas as pd

# Add parent directories to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..'))

from utils.refine_iteration.retrieval.semantic_retrieval import SemanticRetriever
from utils.refine_iteration.retrieval.graph_retrieval import GraphRetriever

class DatapointRetriever:
    """Fast datapoint retrieval system combining semantic and graph retrieval"""
    
    def __init__(self, embeddings_path: str, mapping_dir: str, cliques_dir: str, 
                 sample_size: int = 20, total_codes_per_original: int = 20,
                 max_workers: int = 4) -> None:
        """
        Initialize datapoint retriever

        Stores the configuration, builds the semantic and graph retrievers,
        loads the datapoint-to-codes mapping from ``mapping_dir`` and creates
        empty per-code result caches with hit/miss counters.

        NOTE(review): nothing in this constructor seeds the ``random`` module,
        so the draws made by ``sample_codes`` are not reproducible unless the
        caller seeds ``random`` itself.

        Args:
            embeddings_path: Path to embeddings.parquet file
            mapping_dir: Path to datapoint_code_mapping directory
            cliques_dir: Path to cliques directory
            sample_size: Number of codes to sample from datapoint (default: 20)
            total_codes_per_original: Total codes to return per original code (default: 20)
            max_workers: Number of parallel workers for processing (default: 4)

        Raises:
            FileNotFoundError: If the datapoint mapping file is missing
                (raised while loading the mapping).
        """
        self.embeddings_path = embeddings_path
        self.mapping_dir = mapping_dir
        self.cliques_dir = cliques_dir
        self.sample_size = sample_size
        self.total_codes_per_original = total_codes_per_original
        self.max_workers = max_workers
        
        # Initialize retrievers
        # NOTE(review): top_k is fixed at 10 here regardless of
        # total_codes_per_original — confirm this cap is intended.
        self.semantic_retriever = SemanticRetriever(embeddings_path, mapping_dir, top_k=10)
        self.graph_retriever = GraphRetriever(cliques_dir, mapping_dir)
        
        # Load datapoint-code mapping (reads datapoint_to_codes.parquet from mapping_dir)
        self.datapoint_code_mapping = self._load_datapoint_mapping()
        
        # OPTIMIZATION: Add caching for frequently accessed results
        self.semantic_cache = {}  # code -> results
        self.graph_cache = {}     # code -> results
        # Counters reported as cache performance statistics
        self.cache_hits = 0
        self.cache_misses = 0
        

    def _extract_datapoint_id_from_chunk_text(self, chunk_text):
        """Extract a datapoint ID from chunk text - handles various formats.

        Resolution order:
        1. the first run of digits in the text (a number at the very start,
           e.g. '25 === Test write to log file ===' -> '25', is by
           construction also the first number in the text);
        2. otherwise the first whitespace-separated word.

        Args:
            chunk_text: Text that may embed a datapoint ID; may be None or empty.

        Returns:
            The extracted ID as a string, or None when the text is None,
            empty or whitespace only.
        """
        if not chunk_text:
            return None

        # One search covers both "number at the start" and "number anywhere":
        # leading whitespace never contains digits, so the first digit run
        # found here is the leading number whenever there is one.
        digits = re.search(r'\d+', chunk_text)
        if digits:
            return digits.group(0)

        # No digits at all: fall back to the first word (None if only whitespace)
        words = chunk_text.split()
        return words[0] if words else None
    
    def _load_datapoint_mapping(self) -> Dict[str, List[str]]:
        """Load datapoint to codes mapping"""
        mapping_path = os.path.join(self.mapping_dir, "datapoint_to_codes.parquet")
        if not os.path.exists(mapping_path):
            raise FileNotFoundError(f"Datapoint mapping file not found: {mapping_path}")
        
        df = pd.read_parquet(mapping_path)
        mapping = {}
        for _, row in df.iterrows():
            datapoint_id = row["datapoint"]
            code = row["code"]
            if datapoint_id not in mapping:
                mapping[datapoint_id] = []
            mapping[datapoint_id].append(code)
        
        return mapping
        
        return mapping
    
    def get_codes_for_datapoint(self, datapoint_id) -> List[str]:
        """Look up the codes linked to a datapoint, tolerating int/str ID mismatches.

        Lookup order (the first hit wins):
        1. a string longer than 10 characters is treated as possible chunk
           text: the ID embedded in it is extracted and resolved recursively;
        2. exact key match;
        3. an int ID retried as str, then a digit-only str ID retried as int;
        4. any key whose ``str()`` equals ``str(datapoint_id)``;
        5. any string key whose embedded datapoint ID equals ``str(datapoint_id)``.

        Args:
            datapoint_id: Datapoint ID (int or str) or chunk text containing one.

        Returns:
            The code list stored in the mapping for the matched key, or an
            empty list when nothing matches.
        """
        lookup = self.datapoint_code_mapping

        # Step 1: long strings may be chunk text rather than a bare ID
        if isinstance(datapoint_id, str) and len(datapoint_id) > 10:
            embedded_id = self._extract_datapoint_id_from_chunk_text(datapoint_id)
            if embedded_id and embedded_id != datapoint_id:
                codes = self.get_codes_for_datapoint(embedded_id)
                if codes:
                    return codes

        # Step 2: exact key
        if datapoint_id in lookup:
            return lookup[datapoint_id]

        # Step 3a: int ID stored under its string form
        try:
            if isinstance(datapoint_id, int):
                text_key = str(datapoint_id)
                if text_key in lookup:
                    return lookup[text_key]
        except (ValueError, TypeError):
            pass

        # Step 3b: digit-only string ID stored under its int form
        try:
            if isinstance(datapoint_id, str) and datapoint_id.isdigit():
                number_key = int(datapoint_id)
                if number_key in lookup:
                    return lookup[number_key]
        except (ValueError, TypeError):
            pass

        # Step 4: compare string forms of every key
        for key, codes in lookup.items():
            if str(key) == str(datapoint_id):
                return codes

        # Step 5: keys that are themselves chunk text carrying the ID
        for key, codes in lookup.items():
            if not isinstance(key, str):
                continue
            if self._extract_datapoint_id_from_chunk_text(key) == str(datapoint_id):
                return codes

        return []
    
    def sample_codes(self, codes: List[str]) -> List[str]:
        """Random sample codes, or return all if fewer than sample_size"""
        if len(codes) <= self.sample_size:
            return codes
        return random.sample(codes, self.sample_size)
    
    def retrieve_codes_for_single_code(self, code: str) -> List[Dict[str, Any]]:
        """
        Retrieve codes for a single original code using both semantic and graph retrieval
        
        Args:
            code: The original code to find related codes for
            
        Returns:
            List of related codes (up to total_codes_per_original)
        """
        all_related_codes = []
        
        try:
            # Step 1: Get graph retrieval results (top 10)
            # Step 1: Get graph retrieval results (with caching)
            if code in self.graph_cache:
                graph_results = self.graph_cache[code]
                self.cache_hits += 1
            else:
                graph_results = self.graph_retriever.retrieve_graph_codes(code)
                self.graph_cache[code] = graph_results
                self.cache_misses += 1
            graph_codes = [{"code": result["code"], "score": result["global_frequency"], "source": "graph", "relationship": result["relationship"]} for result in graph_results]
            all_related_codes.extend(graph_codes)
            
            # Step 2: Calculate remaining slots for semantic retrieval (10 total)
            remaining_slots = self.total_codes_per_original - len(graph_codes)
            
            if remaining_slots > 0:
                # Step 3: Get semantic retrieval results (fill remaining slots)
                # Step 3: Get semantic retrieval results (with caching)
                if code in self.semantic_cache:
                    semantic_results = self.semantic_cache[code]
                    self.cache_hits += 1
                else:
                    semantic_results = self.semantic_retriever.retrieve_semantic_codes(code)
                    self.semantic_cache[code] = semantic_results
                    self.cache_misses += 1
                semantic_codes = [{"code": result["code"], "score": result["combined_score"], "source": "semantic", "similarity": result["similarity"]} for result in semantic_results[:remaining_slots]]
                all_related_codes.extend(semantic_codes)
            
            # Remove the original code itself and return results
            # Filter out the original code itself
            all_related_codes = [item for item in all_related_codes if item["code"] != code]
            
            return all_related_codes
            
        except Exception as e:
            # If retrieval fails for this code, continue with others
            pass
        
        return all_related_codes
    
    def retrieve_codes_parallel(self, original_codes: List[str]) -> Set[str]:
        """
        Retrieve codes for multiple original codes in parallel
        
        Args:
            original_codes: List of original codes to find related codes for
            
        Returns:
            Set of all related codes (combined from all original codes)
        """
        all_candidate_codes = []
        
        # Use ThreadPoolExecutor for parallel processing
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_code = {
                executor.submit(self.retrieve_codes_for_single_code, code): code 
                for code in original_codes
            }
            
            # Collect results as they complete
            for future in as_completed(future_to_code):
                code = future_to_code[future]
                try:
                    related_codes = future.result()
                    all_candidate_codes.extend(related_codes)
                except Exception as e:
                    # If one code fails, continue with others
                    pass
        
        # Deduplicate, sort by score, and filter top 200
        return self._deduplicate_and_filter(all_candidate_codes)
    
    def retrieve_codes_vectorized(self, original_codes: List[str]) -> Set[str]:
        """
        Vectorized retrieval for multiple codes (optimized version)
        
        Args:
            original_codes: List of original codes to find related codes for
            
        Returns:
            Set of all related codes (combined from all original codes)
        """
        all_candidate_codes = []
        
        # Batch process codes for better performance
        batch_size = min(self.max_workers, len(original_codes))
        
        for i in range(0, len(original_codes), batch_size):
            batch = original_codes[i:i + batch_size]
            
            # Process batch in parallel
            batch_results = self.retrieve_codes_parallel(batch)
            all_candidate_codes.extend(batch_results)
        
        # Deduplicate, sort by score, and filter top 200
        return self._deduplicate_and_filter(all_candidate_codes)
    
    def retrieve_for_datapoint(self, datapoint_id: str, use_parallel: bool = True) -> Dict[str, Any]:
        """
        Main method to retrieve candidate codes for a datapoint
        
        Args:
            datapoint_id: The datapoint ID to retrieve codes for
            use_parallel: Whether to use parallel processing (default: True)
            
        Returns:
            Dictionary with retrieval results and metadata
        """
        start_time = time.time()
        
        # Step 1: Get codes linked to this datapoint
        original_codes = self.get_codes_for_datapoint(datapoint_id)
        
        if not original_codes:
            return {
                "datapoint_id": datapoint_id,
                'original_codes': [],
                'sampled_codes': [],
                'candidate_codes': [],
                'total_candidates': 0,
                'processing_time': time.time() - start_time,
                'method': 'none'
            }
        
        # Step 2: Random sample codes
        sampled_codes = self.sample_codes(original_codes)
        
        # Step 3: Retrieve related codes
        if use_parallel:
            candidate_codes = self.retrieve_codes_vectorized(sampled_codes)
            candidate_codes_list = [item["code"] for item in candidate_codes]
            method = 'parallel'
        else:
            # Sequential processing
            candidate_codes = []
            for code in sampled_codes:
                related_codes = self.retrieve_codes_for_single_code(code)
                candidate_codes.extend(related_codes)
            method = 'sequential'
        
        # Deduplicate, sort, and filter for sequential processing
        if not use_parallel:
            candidate_codes = self._deduplicate_and_filter(candidate_codes)
        
        candidate_codes_list = [item["code"] for item in candidate_codes]
        processing_time = time.time() - start_time
        
        return {
            "datapoint_id": datapoint_id,
            "original_codes": original_codes,
            "sampled_codes": sampled_codes,
            "candidate_codes": candidate_codes_list,
            "total_candidates": len(candidate_codes_list),
            "processing_time": processing_time,
            "method": method,
            "sampling_ratio": len(sampled_codes) / len(original_codes) if original_codes else 0
        }
    
    def batch_retrieve_for_datapoints(self, datapoint_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Batch retrieve codes for multiple datapoints
        
        Args:
            datapoint_ids: List of datapoint IDs to process
            
        Returns:
            List of retrieval results for each datapoint
        """
        results = []
        
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all datapoint retrieval tasks
            future_to_id = {
                executor.submit(self.retrieve_for_datapoint, datapoint_id): datapoint_id 
                for datapoint_id in datapoint_ids
            }
            
            # Collect results as they complete
            for future in as_completed(future_to_id):
                datapoint_id = future_to_id[future]
                try:
                    result = future.result()
                    results.append(result)
                except Exception as e:
                    # If one datapoint fails, continue with others
                    results.append({
                        "datapoint_id": datapoint_id,
                        'error': str(e),
                        'total_candidates': 0
                    })
        
        return results
    
    def _deduplicate_and_filter(self, candidate_codes: List[Dict[str, Any]],
                                max_candidates: Optional[int] = 200) -> List[Dict[str, Any]]:
        """
        Deduplicate codes, sort by score, and keep the top candidates

        Entries that are not dicts or that lack a "code" key are ignored. For a
        code seen several times, the entry with the strictly highest "score"
        wins (the first one seen on ties); a missing "score" counts as 0.

        Args:
            candidate_codes: List of candidate codes with scores
            max_candidates: Maximum number of entries to return (default: 200);
                None disables the cap

        Returns:
            List of deduplicated codes sorted by score (descending) and capped
            at ``max_candidates`` entries

        Raises:
            ValueError: If ``max_candidates`` is negative
        """
        if max_candidates is not None and max_candidates < 0:
            raise ValueError("max_candidates must be non-negative or None")
        if not candidate_codes:
            return []

        # Deduplicate by code, keeping the one with highest score
        best_by_code: Dict[Any, Dict[str, Any]] = {}
        for item in candidate_codes:
            if not (isinstance(item, dict) and "code" in item):
                continue
            current = best_by_code.get(item["code"])
            if current is None or item.get("score", 0) > current.get("score", 0):
                best_by_code[item["code"]] = item

        # Stable sort: equal scores keep their first-seen order
        ranked = sorted(best_by_code.values(), key=lambda entry: entry.get("score", 0), reverse=True)

        # Slicing with None returns everything, so "no cap" needs no special case
        return ranked[:max_candidates]
    
    def get_statistics(self) -> Dict[str, Any]:
        """Summarize the retriever: mapping size, configuration and retriever presence."""
        stats: Dict[str, Any] = {'total_datapoints': len(self.datapoint_code_mapping)}

        # Plain configuration values are reported under their attribute names
        for option in ('sample_size', 'total_codes_per_original', 'max_workers'):
            stats[option] = getattr(self, option)

        stats['semantic_retriever_loaded'] = self.semantic_retriever is not None
        stats['graph_retriever_loaded'] = self.graph_retriever is not None
        return stats
    def _extract_numeric_id_from_chunk_text(self, chunk_text: str) -> Optional[str]:
        """
        Extract the numeric ID from the beginning of a chunk text string.

        Only a run of digits at the very start of the text (after stripping
        surrounding whitespace) counts; digits that appear later are ignored.

        Args:
            chunk_text: Text expected to begin with a numeric ID; may be None
                or empty.

        Returns:
            The leading digits as a string, or None when the text is None,
            empty, or does not start with a digit.
        """
        # Guard first: calling .strip() on None would raise AttributeError
        if not chunk_text:
            return None
        match = re.match(r'\d+', chunk_text.strip())
        return match.group(0) if match else None



def get_cache_statistics(self) -> Dict[str, Any]:
    """Get cache performance statistics

    Returns:
        Dict with the hit and miss counters, the hit rate in percent (0 when
        no lookups have been made yet) and the number of entries in the
        semantic and graph caches.
    """
    total_requests = self.cache_hits + self.cache_misses
    hit_rate = (self.cache_hits / total_requests * 100) if total_requests > 0 else 0

    return {
        "cache_hits": self.cache_hits,
        "cache_misses": self.cache_misses,
        "hit_rate_percent": hit_rate,
        "semantic_cache_size": len(self.semantic_cache),
        "graph_cache_size": len(self.graph_cache)
    }


# The function above lives outside the class body, so it has to be attached
# explicitly for ``retriever.get_cache_statistics()`` to work. (Nested under
# the ``__main__`` guard it was never a method of DatapointRetriever.)
DatapointRetriever.get_cache_statistics = get_cache_statistics


# Example usage and testing
if __name__ == "__main__":
    # Test the datapoint retriever
    embeddings_path = "../temp_files/embeddings.parquet"
    mapping_dir = "../temp_files/topologically_sorted_graph/datapoint_code_mapping"
    cliques_dir = "../temp_files/topologically_sorted_graph/cliques"

    if all(os.path.exists(path) for path in [embeddings_path, mapping_dir, cliques_dir]):
        retriever = DatapointRetriever(
            embeddings_path=embeddings_path,
            mapping_dir=mapping_dir,
            cliques_dir=cliques_dir,
            sample_size=20,
            total_codes_per_original=10,
            max_workers=4
        )

        print("📊 Datapoint Retriever Statistics:")
        stats = retriever.get_statistics()
        for key, value in stats.items():
            print(f"   {key}: {value}")

        # Test with a sample datapoint (the first key of the loaded mapping)
        datapoint_ids = list(retriever.datapoint_code_mapping.keys())
        if datapoint_ids:
            test_datapoint = datapoint_ids[0]
            print(f"\n🧪 Testing with datapoint: {test_datapoint}")

            result = retriever.retrieve_for_datapoint(test_datapoint)

            print("\n📈 Results:")
            print(f"   Original codes: {len(result['original_codes'])}")
            print(f"   Sampled codes: {len(result['sampled_codes'])}")
            print(f"   Candidate codes: {result['total_candidates']}")
            print(f"   Processing time: {result['processing_time']:.3f}s")
            print(f"   Method: {result['method']}")
        else:
            print("❌ No datapoints found in mapping")
    else:
        print("❌ Required files not found")
