"""
Cross-encoder reranker implementation.
"""

import os
import numpy as np
from typing import List, Dict, Any, Optional
import logging
from tqdm import tqdm

from .base_retriever import BaseRetriever, RetrievalResult

logger = logging.getLogger(__name__)


class CrossEncoderRetriever(BaseRetriever):
    """Cross-encoder reranker for two-stage retrieval.

    Stage 1 obtains up to ``rerank_top_k`` candidates from an optional base
    retriever (or, without one, the first stored documents). Stage 2 rescores
    each ``[query, document_text]`` pair with the cross-encoder and returns the
    best ``top_k``. Document text is looked up in ``self.documents``, which is
    filled by :meth:`index_documents` / :meth:`build_index`.
    """

    def __init__(self, config: Dict[str, Any], model_name: Optional[str] = None, device: str = "cuda"):
        """Create the reranker and load its model.

        Args:
            config: Global config. Reads ``config['retrieval']['cross_encoder']``
                keys ``models`` (first entry is the default model),
                ``rerank_top_k`` (default 100) and ``batch_size`` (default 16).
            model_name: Model to load. Defaults to the first configured model,
                else ``"BAAI/bge-reranker-base"``.
            device: Device string forwarded to ``BaseRetriever.__init__``.
        """
        super().__init__(config, device)

        ce_config = config.get('retrieval', {}).get('cross_encoder', {})

        # Use provided model or default from config
        if model_name is None:
            models = ce_config.get('models', [])
            model_name = models[0] if models else "BAAI/bge-reranker-base"

        self.model_name = model_name
        self.rerank_top_k = ce_config.get('rerank_top_k', 100)
        self.batch_size = ce_config.get('batch_size', 16)

        # Optional first-stage retriever (see set_base_retriever / build_index).
        self.base_retriever = None
        # _load_model (and subclass overrides) read self.model_name and
        # self.device, so they must be set before this call. self.device is
        # presumably set by BaseRetriever.__init__ — TODO confirm.
        self.model = self._load_model()

        # doc_id -> document object exposing .doc_id and .content
        self.documents = {}

    def _load_model(self):
        """Load the sentence-transformers ``CrossEncoder``; return None on failure."""
        try:
            from sentence_transformers import CrossEncoder
            logger.info("Loading cross-encoder model: %s", self.model_name)
            return CrossEncoder(self.model_name, device=self.device)
        except ImportError:
            logger.warning("sentence-transformers not available for CrossEncoder")
            return None
        except Exception as e:
            # Best-effort: search() degrades to first-stage results without a model.
            logger.warning("Could not load cross-encoder model: %s", e)
            return None

    def set_base_retriever(self, base_retriever: BaseRetriever):
        """Set the base retriever for initial retrieval."""
        self.base_retriever = base_retriever
        logger.info(f"Base retriever set: {type(base_retriever).__name__}")

    def index_documents(self, documents: List[Any], index_name: str) -> None:
        """Store documents for reranking and forward them to the base retriever.

        Args:
            documents: Objects exposing ``doc_id`` and ``content``.
            index_name: Passed through to the base retriever's ``index_documents``.
        """
        logger.info(f"Storing {len(documents)} documents for cross-encoder reranking")

        # Replaces (does not extend) any previously stored documents.
        self.documents = {doc.doc_id: doc for doc in documents}

        if self.base_retriever is not None:
            self.base_retriever.index_documents(documents, index_name)

    def _initial_candidates(self, query: str) -> List[RetrievalResult]:
        """Stage 1: up to ``rerank_top_k`` candidates for ``query``."""
        if self.base_retriever is not None:
            return self.base_retriever.search(query, top_k=self.rerank_top_k)
        # No base retriever: take stored documents in insertion order
        # (not recommended for a large corpus).
        return [
            RetrievalResult(doc_id=doc_id, score=0, rank=i + 1, content=doc.content)
            for i, (doc_id, doc) in enumerate(list(self.documents.items())[:self.rerank_top_k])
        ]

    def _score_pairs(self, pairs: List[List[str]]) -> List[float]:
        """Score ``[query, text]`` pairs with ``self.model.predict``.

        The pairs are scored in chunks of ``self.batch_size``. Subclasses whose
        model has a different API can override this method.
        """
        scores = []
        for start in range(0, len(pairs), self.batch_size):
            batch = pairs[start:start + self.batch_size]
            scores.extend(self.model.predict(batch, show_progress_bar=False))
        return scores

    def search(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
        """Two-stage retrieval: base retrieval + cross-encoder reranking.

        Args:
            query: Query text.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` results ranked by cross-encoder score (rank starts
            at 1). If no model is loaded, or no candidate has stored text, the
            first-stage results are returned truncated to ``top_k``. Candidates
            whose ``doc_id`` is not in ``self.documents`` are dropped.
        """
        initial_results = self._initial_candidates(query)
        if not initial_results:
            return []

        if self.model is None:
            logger.warning("Cross-encoder model not available, returning initial results")
            return initial_results[:top_k]

        # Only candidates whose text is stored can be scored.
        valid_results = [r for r in initial_results if self.documents.get(r.doc_id) is not None]
        if not valid_results:
            return initial_results[:top_k]

        pairs = [[query, self.documents[r.doc_id].content] for r in valid_results]
        scores = self._score_pairs(pairs)

        # Stable descending sort: equal scores keep their first-stage order
        # (argsort(...)[::-1] would reverse them).
        order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")

        results = []
        for rank, idx in enumerate(order[:top_k], start=1):
            doc_id = valid_results[idx].doc_id
            results.append(
                RetrievalResult(
                    doc_id=doc_id,
                    score=float(scores[idx]),
                    rank=rank,
                    content=self.documents[doc_id].content,
                )
            )
        return results

    def batch_search(self, queries: List[tuple], top_k: int = 10) -> Dict[str, List[RetrievalResult]]:
        """Rerank each ``(query_id, query_text)`` pair; return ``{query_id: results}``."""
        results = {}

        for query_id, query_text in tqdm(queries, desc="Cross-encoder reranking"):
            results[query_id] = self.search(query_text, top_k)

        return results

    def build_index(self, documents: List[str], metadata: Optional[List[Dict]] = None):
        """Build index from documents (compatibility method).

        Documents get string ids ``"0"``, ``"1"``, ... by position. If no base
        retriever is set, a ``BM25Retriever`` is created and indexed as well.

        Args:
            documents: List of document texts
            metadata: Optional metadata for each document (forwarded to BM25 only)
        """
        class SimpleDoc:
            # Minimal document shape expected by index_documents.
            def __init__(self, doc_id, content):
                self.doc_id = doc_id
                self.content = content

        doc_objects = [SimpleDoc(str(i), doc_text) for i, doc_text in enumerate(documents)]

        self.index_documents(doc_objects, "default")

        if self.base_retriever is None:
            # Use BM25 as default base retriever.
            # NOTE(review): absolute import, while this module imports
            # base_retriever relatively — confirm the package layout.
            from src.retrieval.bm25_retriever import BM25Retriever
            self.base_retriever = BM25Retriever({})
            self.base_retriever.build_index(documents, metadata)

    def retrieve(self, query: str, k: int = 10) -> List[tuple]:
        """Return ``[(int(doc_id), score), ...]`` (backward compatibility).

        ``int(doc_id)`` assumes numeric ids such as those from build_index.
        """
        results = self.search(query, top_k=k)
        return [(int(r.doc_id), r.score) for r in results]

    def batch_retrieve(self, queries: List[str], k: int = 10) -> Dict[int, List[tuple]]:
        """Retrieve for each query; keys are the queries' positions in ``queries``."""
        return {i: self.retrieve(query, k) for i, query in enumerate(queries)}

    def evaluate(self, retrieved: Dict[int, List[tuple]],
                 ground_truth: Dict[int, List[int]],
                 k_values: List[int]) -> Dict[str, Any]:
        """Compute Recall@K and MRR.

        Args:
            retrieved: ``{query_id: [(doc_id, score), ...]}`` in rank order.
            ground_truth: ``{query_id: [relevant doc_id, ...]}``; queries
                missing here are skipped.
            k_values: Cutoffs for recall.

        Returns:
            ``{'recall': {k: mean recall}, 'mrr': mean reciprocal rank}``.
            Recall ignores queries with an empty relevant list. MRR counts
            them as 0.
        """
        metrics = {}

        recall_at_k = {}
        for k in k_values:
            recalls = []
            for query_id, retrieved_docs in retrieved.items():
                if query_id not in ground_truth:
                    continue

                relevant_docs = set(ground_truth[query_id])
                retrieved_at_k = {doc_id for doc_id, _ in retrieved_docs[:k]}

                if relevant_docs:
                    recalls.append(len(retrieved_at_k & relevant_docs) / len(relevant_docs))

            recall_at_k[k] = np.mean(recalls) if recalls else 0.0

        metrics['recall'] = recall_at_k

        mrrs = []
        for query_id, retrieved_docs in retrieved.items():
            if query_id not in ground_truth:
                continue

            relevant_docs = set(ground_truth[query_id])

            for i, (doc_id, _) in enumerate(retrieved_docs):
                if doc_id in relevant_docs:
                    mrrs.append(1.0 / (i + 1))
                    break
            else:
                # No relevant document retrieved.
                mrrs.append(0.0)

        metrics['mrr'] = np.mean(mrrs) if mrrs else 0.0

        return metrics


class MonoT5Reranker(CrossEncoderRetriever):
    """MonoT5 reranker implementation.

    A (query, document) pair is scored by the logit margin between the tokens
    "true" and "false" that the model generates first after the prompt
    ``"Query: ... Document: ... Relevant:"``.
    """

    MODEL_NAME = "castorini/monot5-base-msmarco"

    class _PredictAdapter:
        """Wrap a T5 model so the inherited ``search`` can score with it.

        ``CrossEncoderRetriever.search`` calls
        ``self.model.predict(pairs, show_progress_bar=False)``, which
        ``T5ForConditionalGeneration`` does not provide. This adapter adds that
        method on top of ``_score_pair``. Other non-dunder attributes
        (``generate``, ``device``, ...) are forwarded to the wrapped model.
        """

        def __init__(self, model, score_pair):
            self._model = model
            self._score_pair = score_pair

        def predict(self, pairs, show_progress_bar=False):
            # show_progress_bar is accepted only for CrossEncoder signature parity.
            return [self._score_pair(query, document) for query, document in pairs]

        def __getattr__(self, name):
            # Called only for names not found on the adapter. Read _model via
            # vars() so a partially initialised adapter cannot recurse here.
            model = vars(self).get("_model")
            if model is None:
                raise AttributeError(name)
            return getattr(model, name)

    def __init__(self, config: Dict[str, Any], device: str = "cuda"):
        super().__init__(config, self.MODEL_NAME, device)

    def _load_model(self):
        """Load MonoT5 wrapped in ``_PredictAdapter``.

        If loading fails, falls back to the parent loader.
        """
        try:
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            import torch

            logger.info("Loading MonoT5 reranker model")
            tokenizer = T5Tokenizer.from_pretrained(self.MODEL_NAME)
            model = T5ForConditionalGeneration.from_pretrained(self.MODEL_NAME)

            if self.device == "cuda" and torch.cuda.is_available():
                model = model.cuda()

            model.eval()
        except Exception as e:
            logger.warning(f"Could not load MonoT5 model: {e}")
            return super()._load_model()

        # Set the tokenizer only once the model loaded, so a failed load leaves
        # no half-initialised MonoT5 state behind.
        self.tokenizer = tokenizer
        return self._PredictAdapter(model, self._score_pair)

    def _score_pair(self, query: str, document: str) -> float:
        """Score a query-document pair using MonoT5.

        Returns:
            ``logit("true") - logit("false")`` for the first generated token
            (higher = more relevant), or 0.0 when the MonoT5 model is not loaded.
        """
        # A fallback CrossEncoder has no generate(); only score with MonoT5 itself.
        if not isinstance(self.model, self._PredictAdapter) or not hasattr(self, 'tokenizer'):
            return 0.0

        import torch

        input_text = f"Query: {query} Document: {document} Relevant:"

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        )

        # Follow the model's actual device: device == "cuda" does not imply the
        # model was moved (_load_model skips .cuda() when no GPU is available).
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # First sub-token ids of "true" / "false" (encode() appends </s>).
        true_id = self.tokenizer.encode("true")[0]
        false_id = self.tokenizer.encode("false")[0]

        with torch.no_grad():
            # max_length=2 = decoder start token + one generated token.
            outputs = self.model.generate(**inputs, max_length=2, return_dict_in_generate=True, output_scores=True)

        # Scores for the first generated token of the single input.
        scores = outputs.scores[0][0]

        return scores[true_id].item() - scores[false_id].item()


class DuoBERTReranker(CrossEncoderRetriever):
    """DuoBERT-style pairwise reranker.

    Every pair of first-stage candidates is compared with :meth:`_compare_pair`.
    Each document's final score is the sum of its pairwise preferences.

    NOTE(review): "bert-base-uncased" is not a ranking checkpoint, so its
    cross-encoder head is presumably untrained and scores are not meaningful
    until a real DuoBERT/cross-encoder model is configured — TODO confirm.
    """

    def __init__(self, config: Dict[str, Any], device: str = "cuda"):
        # DuoBERT typically uses a BERT-based model fine-tuned for pairwise ranking
        super().__init__(config, "bert-base-uncased", device)
        self.pairwise_comparison = True

    def search(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
        """Search with pairwise DuoBERT reranking.

        Returns:
            Up to ``top_k`` results ranked by aggregated pairwise score. If no
            model is loaded, or fewer than two candidates have stored text, the
            first-stage results are returned truncated to ``top_k``. Candidates
            missing from ``self.documents`` are dropped.
        """
        # Stage 1: initial candidates
        if self.base_retriever is not None:
            initial_results = self.base_retriever.search(query, top_k=self.rerank_top_k)
        else:
            initial_results = [
                RetrievalResult(doc_id=doc_id, score=0, rank=i + 1, content=doc.content)
                for i, (doc_id, doc) in enumerate(list(self.documents.items())[:self.rerank_top_k])
            ]

        if not initial_results:
            return []

        if self.model is None:
            # Without a model every comparison is 0. Sorting would then only
            # scramble (previously: reverse) the first-stage order.
            logger.warning("Cross-encoder model not available, returning initial results")
            return initial_results[:top_k]

        # Only candidates whose text is stored can be compared.
        candidates = [r for r in initial_results if self.documents.get(r.doc_id) is not None]
        if len(candidates) <= 1:
            return initial_results[:top_k]

        texts = [self.documents[r.doc_id].content for r in candidates]
        n = len(candidates)

        # Antisymmetric preference matrix: pairwise[i, j] > 0 means i beats j.
        pairwise = np.zeros((n, n))
        for i in range(n):
            for j in range(i + 1, n):
                score = self._compare_pair(query, texts[i], texts[j])
                pairwise[i, j] = score
                pairwise[j, i] = -score

        # Aggregate by simple row sum.
        # NOTE: with the current score-difference _compare_pair, row i sums to
        # n*s_i - sum(s), i.e. the same order as pointwise scoring. The O(n^2)
        # comparisons only pay off with a genuinely pairwise model.
        final_scores = pairwise.sum(axis=1)

        # Stable descending sort keeps first-stage order among ties.
        order = np.argsort(-final_scores, kind="stable")

        return [
            RetrievalResult(
                doc_id=candidates[idx].doc_id,
                score=float(final_scores[idx]),
                rank=rank,
                content=texts[idx],
            )
            for rank, idx in enumerate(order[:top_k], start=1)
        ]

    def _compare_pair(self, query: str, doc1: str, doc2: str) -> float:
        """Return the preference for ``doc1`` over ``doc2`` (positive favours ``doc1``).

        Approximates a pairwise model with the difference of two pointwise
        cross-encoder scores. Returns 0.0 when no model is loaded. In practice,
        DuoBERT uses a specialized pairwise model.
        """
        if self.model is None:
            return 0.0

        # Score both documents in one predict call instead of two.
        score1, score2 = self.model.predict([[query, doc1], [query, doc2]], show_progress_bar=False)

        return float(score1 - score2)