"""
Base retriever interface and common utilities.
"""

import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import numpy as np

# Module-level logger; MemoryTracker.track_memory reports memory deltas through it.
logger = logging.getLogger(__name__)


@dataclass
class RetrievalResult:
    """Represents a single retrieval hit.

    A plain dataclass: the generated ``__init__`` takes the fields positionally
    in the order declared below, with ``content`` and ``metadata`` optional.
    """
    doc_id: str  # document identifier; evaluation compares it against qrels doc ids
    score: float  # retriever-assigned score; NOTE(review): presumably higher is better — confirm per retriever
    rank: int  # position in the result list; NOTE(review): 0- vs 1-based is not fixed here — confirm
    content: Optional[str] = None  # optional document text
    metadata: Optional[Dict[str, Any]] = None  # optional extra per-hit fields


class BaseRetriever(ABC):
    """Abstract base class for all retrievers.

    Subclasses implement :meth:`index_documents` and :meth:`search`. On top of
    ``search`` this class provides batch search with latency bookkeeping and
    IR evaluation (Recall@k, nDCG@k, MRR), optionally stratified by a column
    of a queries DataFrame.
    """

    def __init__(self, config: Dict[str, Any], device: str = "cpu"):
        self.config = config
        self.device = device
        self.index = None
        # Filled by batch_search and merged into the dict returned by evaluate.
        self.performance_metrics: Dict[str, float] = {}

    @abstractmethod
    def index_documents(self, documents: List[Any], index_name: str) -> None:
        """Index documents for retrieval."""

    @abstractmethod
    def search(self, query: str, top_k: int = 10) -> List["RetrievalResult"]:
        """Search for relevant documents."""

    def batch_search(
        self,
        queries: Union[Mapping[str, str], Iterable[Union[str, Tuple[str, str]]]],
        top_k: int = 10,
    ) -> Dict[str, List["RetrievalResult"]]:
        """Run :meth:`search` for several queries and record the mean latency.

        Args:
            queries: A mapping ``{query_id: query_text}``, an iterable of
                ``(query_id, query_text)`` pairs, or an iterable of plain query
                strings (each string is then used as its own id).
            top_k: Number of results requested per query.

        Returns:
            ``{query_id: results}`` for every query in the batch.

        Side effects:
            Sets ``self.performance_metrics['avg_query_latency']`` to the mean
            seconds spent per ``search`` call (0.0 for an empty batch).
        """
        if isinstance(queries, Mapping):
            pairs = list(queries.items())
        else:
            # Bare strings were previously unpacked character-by-character;
            # treat each one as a query keyed by its own text.
            pairs = [(q, q) if isinstance(q, str) else q for q in queries]

        results = {}
        total_time = 0.0
        for query_id, query_text in pairs:
            start = time.perf_counter()  # monotonic, high-resolution clock
            results[query_id] = self.search(query_text, top_k)
            total_time += time.perf_counter() - start

        self.performance_metrics['avg_query_latency'] = total_time / len(pairs) if pairs else 0.0
        return results

    def evaluate(self, queries: Dict[str, str], qrels: Dict[str, Dict[str, int]],
                 k_values: List[int],
                 ndcg_k_values: Optional[List[int]] = None) -> Dict[str, float]:
        """Evaluate retrieval performance.

        Each judged query is searched once, at the deepest cutoff any metric
        needs, and that ranking is truncated for smaller cutoffs. This assumes
        ``search(q, top_k=k)`` is a prefix of ``search(q, top_k=K)`` for
        ``k < K``; exact retrievers satisfy this, approximate indexes may
        deviate slightly.

        Only documents with relevance > 0 count as relevant. Queries missing
        from ``qrels`` or without any relevant document are skipped by every
        metric.

        Args:
            queries: ``{query_id: query_text}``.
            qrels: ``{query_id: {doc_id: graded_relevance}}``.
            k_values: Cutoffs for Recall@k; the largest also bounds MRR.
            ndcg_k_values: Cutoffs for nDCG@k. Defaults to whichever of 10 and
                20 appear in ``k_values``.

        Returns:
            ``recall@k`` for each k, ``ndcg@k`` for each nDCG cutoff, ``mrr``,
            then every entry of ``self.performance_metrics``.

        Raises:
            ValueError: if ``k_values`` is empty.
        """
        if not k_values:
            raise ValueError("k_values must contain at least one cutoff")
        if ndcg_k_values is None:
            ndcg_k_values = [k for k in (10, 20) if k in k_values]

        mrr_depth = max(k_values)
        search_depth = max([mrr_depth, *ndcg_k_values])

        recalls: Dict[int, List[float]] = {k: [] for k in k_values}
        ndcgs: Dict[int, List[float]] = {k: [] for k in ndcg_k_values}
        mrrs: List[float] = []

        for q_id, query in queries.items():
            judgments = qrels.get(q_id)
            if not judgments:
                continue
            relevant_docs = {doc for doc, rel in judgments.items() if rel > 0}
            if not relevant_docs:
                continue

            results = self.search(query, top_k=search_depth)
            ranked_ids = [r.doc_id for r in results]

            for k, bucket in recalls.items():
                hits = len(set(ranked_ids[:k]) & relevant_docs)
                bucket.append(hits / len(relevant_docs))

            for k, bucket in ndcgs.items():
                bucket.append(self._calculate_ndcg(results, judgments, k))

            reciprocal_rank = 0.0
            for rank, doc_id in enumerate(ranked_ids[:mrr_depth], start=1):
                if doc_id in relevant_docs:
                    reciprocal_rank = 1.0 / rank
                    break
            mrrs.append(reciprocal_rank)

        metrics: Dict[str, float] = {}
        for k, values in recalls.items():
            metrics[f'recall@{k}'] = float(np.mean(values)) if values else 0.0
        for k, values in ndcgs.items():
            metrics[f'ndcg@{k}'] = float(np.mean(values)) if values else 0.0
        metrics['mrr'] = float(np.mean(mrrs)) if mrrs else 0.0

        # Latency and other bookkeeping gathered by batch_search.
        metrics.update(self.performance_metrics)
        return metrics

    def _calculate_ndcg(self, results: List["RetrievalResult"],
                        relevance: Dict[str, int], k: int) -> float:
        """Calculate nDCG@k for a single query.

        Uses exponential gain ``2**rel - 1`` with a ``log2(position + 1)``
        discount (positions counted from 1). Non-positive grades contribute
        no gain, so negative judgments cannot shrink the ideal DCG below the
        achievable one (which would yield nDCG > 1).

        Returns:
            nDCG in [0, 1]; 0.0 when there is no positive judgment.
        """
        dcg = 0.0
        for i, result in enumerate(results[:k]):
            rel = max(relevance.get(result.doc_id, 0), 0)
            dcg += (2 ** rel - 1) / np.log2(i + 2)

        ideal_relevances = sorted((rel for rel in relevance.values() if rel > 0),
                                  reverse=True)[:k]
        idcg = sum((2 ** rel - 1) / np.log2(i + 2)
                   for i, rel in enumerate(ideal_relevances))

        return float(dcg / idcg) if idcg > 0 else 0.0

    def evaluate_with_stratification(self, queries_df, qrels: Dict[str, Dict[str, int]],
                                     k_values: List[int], stratify_by: str = 'domain') -> Dict[str, Dict[str, float]]:
        """Evaluate retrieval performance per stratum and overall.

        Args:
            queries_df: DataFrame-like object with ``question_id`` and
                ``question`` columns, plus the ``stratify_by`` column.
            qrels: ``{query_id: {doc_id: graded_relevance}}``.
            k_values: Cutoffs forwarded to :meth:`evaluate`.
            stratify_by: Column whose distinct values define the strata.

        Returns:
            ``{str(stratum_value): metrics}`` for each distinct value, plus an
            ``'overall'`` entry over all queries. If the column is missing,
            a warning is logged and only ``'overall'`` is returned.
        """
        stratified_metrics = {}

        if stratify_by in queries_df.columns:
            for value in queries_df[stratify_by].unique():
                subset_df = queries_df[queries_df[stratify_by] == value]
                subset_queries = dict(zip(subset_df['question_id'], subset_df['question']))
                stratified_metrics[str(value)] = self.evaluate(subset_queries, qrels, k_values)
        else:
            logger.warning("Stratification column %r not found; reporting overall metrics only",
                           stratify_by)

        all_queries = dict(zip(queries_df['question_id'], queries_df['question']))
        stratified_metrics['overall'] = self.evaluate(all_queries, qrels, k_values)

        return stratified_metrics


class MemoryTracker:
    """Utility class to track process memory usage via ``psutil``."""

    @staticmethod
    def get_memory_usage():
        """Return the current process's resident set size (RSS) in MB."""
        import psutil  # imported lazily: psutil is only needed when this is called
        process = psutil.Process()
        return process.memory_info().rss / 1024 / 1024  # bytes -> MB

    @staticmethod
    def track_memory(func):
        """Decorator that logs the RSS change across each call to *func*.

        The delta (after minus before, in MB) is logged at INFO level through
        the module logger once *func* returns; it can be negative if memory
        was released. Nothing is logged if *func* raises. The wrapper keeps
        *func*'s ``__name__``, ``__doc__`` and ``__wrapped__`` via
        ``functools.wraps``.
        """
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            mem_before = MemoryTracker.get_memory_usage()
            result = func(*args, **kwargs)
            mem_after = MemoryTracker.get_memory_usage()

            logger.info("Memory usage for %s: %.2f MB", func.__name__, mem_after - mem_before)

            return result
        return wrapper