#!/usr/bin/env python3
"""
Context Retrievers Module

This module provides different strategies for retrieving context for question answering.
Each retriever implements a common interface for easy swapping and extension.
"""

import os
import json
import pandas as pd
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
import asyncio

class BaseContextRetriever(ABC):
    """Abstract interface shared by every context retrieval strategy.

    Concrete retrievers implement an asynchronous lookup that produces a
    context dictionary, plus a synchronous helper that renders such a
    dictionary as a short human-readable string.
    """

    @abstractmethod
    async def retrieve_context(self, question: str, **kwargs) -> Dict[str, Any]:
        """
        Look up context relevant to ``question``.

        Args:
            question: The question to retrieve context for
            **kwargs: Retriever-specific options

        Returns:
            Dictionary describing the retrieved context
        """
        ...

    @abstractmethod
    def get_context_summary(self, context: Dict[str, Any]) -> str:
        """
        Describe a retrieved context in one line for logging/debugging.

        Args:
            context: A dictionary previously returned by retrieve_context

        Returns:
            Short textual summary of the context
        """
        ...


class DataRetrievalContextRetriever(BaseContextRetriever):
    """Context retriever using graph-based data retrieval"""

    def __init__(self, embeddings_path: str, topological_graph_dir: str,
                 alpha: float = 0.85, max_iterations: int = 5, top_k: int = 10):
        """
        Initialize the data retrieval context retriever

        Args:
            embeddings_path: Path to embeddings parquet file. The directory
                three levels above this path is handed to the underlying
                retriever as its ``base_temp_dir``.
            topological_graph_dir: Directory containing topological graph files
                (stored on the instance; not otherwise used by this class)
            alpha: Damping factor for relevance propagation
            max_iterations: Maximum iterations for relevance propagation
            top_k: Default number of top chunks to retrieve
        """
        self.embeddings_path = embeddings_path
        self.topological_graph_dir = topological_graph_dir
        self.alpha = alpha
        self.max_iterations = max_iterations
        self.top_k = top_k

        # Imported here rather than at module level, so the dependency is
        # only resolved when a retriever is actually constructed.
        from .data_retrieval import GraphBasedDataRetriever
        self.retriever = GraphBasedDataRetriever(
            base_temp_dir=os.path.dirname(os.path.dirname(os.path.dirname(embeddings_path))),
            alpha=alpha,
            max_iterations=max_iterations
        )

    async def retrieve_context(self, question: str, **kwargs) -> Dict[str, Any]:
        """
        Retrieve context using graph-based data retrieval

        Args:
            question: The question to retrieve context for
            **kwargs: ``top_k`` overrides the instance default for this call

        Returns:
            Dictionary with the keys ``method``, ``question``, ``chunks``
            (list of ``{'content', 'score', 'chunk_id'}`` dicts),
            ``total_chunks`` and ``retrieval_scores``. When retrieval fails
            the list-valued keys are empty and an additional ``error`` key
            carries the exception message.
        """
        print(f"🔍 Retrieving context using data retrieval for: {question[:50]}...")

        # Override top_k if provided in kwargs
        top_k = kwargs.get('top_k', self.top_k)

        try:
            # Get relevant chunks as (chunk_id, score) pairs
            top_chunks = await self.retriever.retrieve_relevant_chunks(question, top_k)

            # Resolve each chunk id to its content (preserve chunk_id)
            context_chunks = [
                {
                    'content': self.retriever._get_datachunk_content(chunk_id),
                    'score': score,
                    'chunk_id': chunk_id
                }
                for chunk_id, score in top_chunks
            ]

            return {
                'method': 'data_retrieval',
                'question': question,
                'chunks': context_chunks,
                'total_chunks': len(context_chunks),
                'retrieval_scores': [chunk['score'] for chunk in context_chunks]
            }

        except Exception as e:
            # Best-effort: report the failure in the result instead of raising.
            print(f"❌ Data retrieval failed: {e}")
            return {
                'method': 'data_retrieval',
                'question': question,
                'chunks': [],
                'total_chunks': 0,
                # Same keys as the success result so callers can index
                # 'retrieval_scores' without checking for 'error' first.
                'retrieval_scores': [],
                'error': str(e)
            }

    def get_context_summary(self, context: Dict[str, Any]) -> str:
        """
        Get summary of retrieved context

        Args:
            context: The context dictionary returned by retrieve_context

        Returns:
            The failure message if ``context`` has an ``error`` key, a
            "no chunks" notice if it has no chunks, otherwise the chunk
            count and the mean chunk score.
        """
        if 'error' in context:
            return f"Data retrieval failed: {context['error']}"

        chunks = context.get('chunks', [])
        if not chunks:
            return "No chunks retrieved"

        avg_score = sum(chunk['score'] for chunk in chunks) / len(chunks)
        return f"Retrieved {len(chunks)} chunks (avg score: {avg_score:.3f})"


class TopologicalGraphContextRetriever(BaseContextRetriever):
    """Context retriever using topological graph summaries"""

    def __init__(self, graph_dir: str):
        """
        Initialize the topological graph context retriever

        Args:
            graph_dir: Directory containing topological graph files
        """
        self.graph_dir = graph_dir

    def _load_grouped(self, filename: str, key_column: str, value_column: str) -> Dict[Any, List[Any]]:
        """
        Group one column of a parquet file in ``graph_dir`` by another column

        Args:
            filename: Name of the parquet file inside ``graph_dir``
            key_column: Column whose values become the dictionary keys
            value_column: Column whose values are collected per key

        Returns:
            Mapping of key to the list of values seen for it, in file order.
            Empty when the file does not exist.
        """
        path = os.path.join(self.graph_dir, filename)
        grouped: Dict[Any, List[Any]] = {}
        if not os.path.exists(path):
            return grouped

        frame = pd.read_parquet(path)
        # Walk the two columns directly: this avoids building a Series per
        # row (as iterrows does) and keeps each column's own dtype.
        for key, value in zip(frame[key_column], frame[value_column]):
            grouped.setdefault(key, []).append(value)
        return grouped

    async def retrieve_context(self, question: str, **kwargs) -> Dict[str, Any]:
        """
        Retrieve context using topological graph summaries

        Reads ``hierarchy.parquet`` (columns ``level`` and ``node``) and
        ``code_datapoints.parquet`` (columns ``code`` and ``datapoint``) from
        ``graph_dir``. A missing file contributes an empty mapping.

        Args:
            question: The question to retrieve context for (echoed back in
                the result; it does not influence what is loaded)
            **kwargs: Accepted for interface compatibility; unused

        Returns:
            Dictionary with the keys ``method``, ``question``, ``hierarchy``
            (level -> codes), ``graph_summary``, ``datapoints``
            (code -> datapoints) and ``total_codes``. When loading fails the
            mappings are empty and an additional ``error`` key carries the
            exception message.
        """
        print(f"🔍 Retrieving context using topological graph for: {question[:50]}...")

        try:
            # Load hierarchy structure only (not the massive content)
            hierarchy = self._load_grouped("hierarchy.parquet", 'level', 'node')

            # Load datapoint tracking
            datapoints = self._load_grouped("code_datapoints.parquet", 'code', 'datapoint')

            # Create a concise summary of the graph structure
            graph_summary = self._create_graph_summary(hierarchy, datapoints)

            return {
                'method': 'topological_graph',
                'question': question,
                'hierarchy': hierarchy,
                'graph_summary': graph_summary,
                'datapoints': datapoints,
                'total_codes': len(datapoints)
            }

        except Exception as e:
            # Best-effort: report the failure in the result instead of raising.
            print(f"❌ Topological graph retrieval failed: {e}")
            return {
                'method': 'topological_graph',
                'question': question,
                'hierarchy': {},
                'graph_summary': "Failed to load graph structure",
                'datapoints': {},
                # Same keys as the success result so callers can index
                # 'total_codes' without checking for 'error' first.
                'total_codes': 0,
                'error': str(e)
            }

    def _create_graph_summary(self, hierarchy: Dict[int, List[str]], datapoints: Dict[str, List[str]]) -> str:
        """
        Create a concise summary of the graph structure

        Args:
            hierarchy: Mapping of level to the codes found at that level
            datapoints: Mapping of code to its datapoints

        Returns:
            Multi-line text: per level (in sorted order) the code count and
            up to three example codes, followed by overall totals.
        """
        summary_parts = []

        # Add hierarchy levels
        for level in sorted(hierarchy):
            codes = hierarchy[level]
            summary_parts.append(f"Level {level}: {len(codes)} codes")
            # Add a few example codes from each level
            if codes:
                # Codes are converted with str() because str.join rejects
                # non-string items, and file contents need not be strings.
                examples = ', '.join(str(code) for code in codes[:3])
                summary_parts.append(f"  Examples: {examples}")

        # Add overall statistics
        total_codes = len(datapoints)
        total_datapoints = sum(len(dps) for dps in datapoints.values())
        summary_parts.append(f"\nTotal: {total_codes} codes, {total_datapoints} datapoints")

        return "\n".join(summary_parts)

    def get_context_summary(self, context: Dict[str, Any]) -> str:
        """
        Get summary of retrieved context

        Args:
            context: The context dictionary returned by retrieve_context

        Returns:
            The failure message if ``context`` has an ``error`` key,
            otherwise the number of levels, codes and datapoints.
        """
        if 'error' in context:
            return f"Topological graph retrieval failed: {context['error']}"

        hierarchy = context.get('hierarchy', {})
        datapoints = context.get('datapoints', {})
        total_codes = len(datapoints)
        total_datapoints = sum(len(dps) for dps in datapoints.values())

        return f"Graph with {len(hierarchy)} levels, {total_codes} codes, {total_datapoints} datapoints"


class HybridContextRetriever(BaseContextRetriever):
    """Context retriever that combines multiple strategies"""

    def __init__(self, retrievers: List[BaseContextRetriever], weights: Optional[List[float]] = None):
        """
        Initialize hybrid context retriever

        Args:
            retrievers: List of context retrievers to combine (may be empty)
            weights: Optional weights for each retriever (default: equal
                weights). ``None`` and an empty list both select the default.

        Raises:
            ValueError: If the number of weights differs from the number of
                retrievers
        """
        self.retrievers = retrievers

        if weights:
            self.weights = weights
        elif retrievers:
            self.weights = [1.0 / len(retrievers)] * len(retrievers)
        else:
            # No retrievers: nothing to weight. Guarding this case avoids a
            # division by zero when computing the equal share.
            self.weights = []

        if len(self.weights) != len(self.retrievers):
            raise ValueError("Number of weights must match number of retrievers")

    async def retrieve_context(self, question: str, **kwargs) -> Dict[str, Any]:
        """
        Retrieve context using multiple strategies

        The retrievers are awaited one after another, in order. Each result
        is tagged with its retriever's weight under the ``weight`` key. A
        retriever that raises is represented by a placeholder result with
        ``method`` set to ``retriever_<index>`` and an ``error`` message.

        Args:
            question: The question to retrieve context for
            **kwargs: Forwarded unchanged to every retriever

        Returns:
            Dictionary with the keys ``method``, ``question``,
            ``retriever_results`` (one entry per retriever) and
            ``total_retrievers``.
        """
        print(f"🔍 Retrieving context using hybrid approach for: {question[:50]}...")

        results = []
        for i, (retriever, weight) in enumerate(zip(self.retrievers, self.weights)):
            try:
                result = await retriever.retrieve_context(question, **kwargs)
                result['weight'] = weight
                results.append(result)
            except Exception as e:
                # One failing strategy must not abort the others.
                print(f"❌ Retriever {i} failed: {e}")
                results.append({
                    'method': f'retriever_{i}',
                    'question': question,
                    'error': str(e),
                    'weight': weight
                })

        return {
            'method': 'hybrid',
            'question': question,
            'retriever_results': results,
            'total_retrievers': len(self.retrievers)
        }

    def get_context_summary(self, context: Dict[str, Any]) -> str:
        """
        Get summary of retrieved context

        Args:
            context: The context dictionary returned by retrieve_context

        Returns:
            The number of retriever results followed by a per-result
            "<method>: Success" or "<method>: Failed" marker, where a result
            counts as failed when it has an ``error`` key.
        """
        results = context.get('retriever_results', [])
        summaries = [
            f"{result['method']}: Failed" if 'error' in result else f"{result['method']}: Success"
            for result in results
        ]

        return f"Hybrid retrieval with {len(results)} retrievers - {'; '.join(summaries)}"


class NoContextRetriever(BaseContextRetriever):
    """Retriever for standard mode: always yields an empty context"""

    async def retrieve_context(self, question: str, **kwargs) -> Dict[str, Any]:
        """Build a context for ``question`` that contains no chunks"""
        empty_context: Dict[str, Any] = dict(
            method='no_context',
            question=question,
            chunks=[],
            total_chunks=0,
        )
        return empty_context

    def get_context_summary(self, context: Dict[str, Any]) -> str:
        """Return the fixed summary line; ``context`` is not inspected"""
        summary = "No context retrieved (standard mode)"
        return summary