"""
Dense retriever implementation using sentence transformers.
"""

import os
import numpy as np
import pickle
from typing import List, Dict, Any, Optional
import logging
from tqdm import tqdm

from .base_retriever import BaseRetriever, RetrievalResult

logger = logging.getLogger(__name__)


class DenseRetriever(BaseRetriever):
    """Dense retriever using sentence transformers and FAISS/numpy."""
    
    def __init__(self, config: Dict[str, Any], model_name: str, device: str = "cuda"):
        """Set up the retriever and try to load its embedding model.

        Args:
            config: Configuration mapping. ``batch_size`` (default 32) and
                ``max_length`` (default 512) are read from
                ``config['retrieval']['dense']`` when present.
            model_name: Name of the model handed to ``_load_model``.
            device: Device string forwarded to the base class and the model.
        """
        super().__init__(config, device)
        self.model_name = model_name

        # All dense-retrieval settings live under config['retrieval']['dense'].
        dense_cfg = config.get('retrieval', {}).get('dense', {})
        self.batch_size = dense_cfg.get('batch_size', 32)
        self.max_length = dense_cfg.get('max_length', 512)

        # The settings above are assigned first so that _load_model can rely
        # on them; the result is None when loading fails.
        self.model = self._load_model()

        # Index state, filled in by index_documents().
        self.index = None
        self.embeddings = None
        self.doc_ids = []
        self.documents = []
        
    def _load_model(self):
        """Load the dense retrieval model."""
        try:
            from sentence_transformers import SentenceTransformer
            logger.info(f"Loading dense retriever model: {self.model_name}")
            return SentenceTransformer(self.model_name, device=self.device)
        except ImportError:
            logger.warning("sentence-transformers not available")
            return None
        except Exception as e:
            logger.warning(f"Could not load model {self.model_name}: {e}")
            return None
    
    def index_documents(self, documents: List[Any], index_name: str) -> None:
        """Index documents using dense embeddings."""
        if self.model is None:
            logger.error("Model not loaded, cannot index documents")
            return
            
        index_dir = f"indexes/dense/{self.model_name.replace('/', '_')}/{index_name}"
        os.makedirs(index_dir, exist_ok=True)
        
        index_path = os.path.join(index_dir, 'index.pkl')
        
        # Check if index exists
        if os.path.exists(index_path):
            logger.info(f"Loading existing index from {index_dir}")
            try:
                with open(index_path, 'rb') as f:
                    index_data = pickle.load(f)
                    self.doc_ids = index_data['doc_ids']
                    self.doc_texts = index_data['doc_texts']  # Store texts instead of objects
                    self.embeddings = index_data['embeddings']
                # Recreate document objects if needed
                self.documents = [{'doc_id': doc_id, 'content': text}
                                 for doc_id, text in zip(self.doc_ids, self.doc_texts)]
                self._build_faiss_index()
                return
            except (EOFError, pickle.UnpicklingError, KeyError) as e:
                logger.warning(f"Corrupted index file, rebuilding: {e}")
                # Remove corrupted file
                os.remove(index_path)
        
        logger.info(f"Creating dense index for {len(documents)} documents")
        
        # Extract texts and IDs
        if isinstance(documents[0], str):
            # Documents are already strings
            texts = documents
            self.doc_ids = [str(i) for i in range(len(documents))]
        else:
            # Documents are objects with content attribute
            texts = [doc.content if hasattr(doc, 'content') else str(doc) for doc in documents]
            self.doc_ids = [doc.doc_id if hasattr(doc, 'doc_id') else str(i)
                           for i, doc in enumerate(documents)]
        
        # Store texts instead of document objects
        self.doc_texts = texts
        self.documents = [{'doc_id': doc_id, 'content': text}
                         for doc_id, text in zip(self.doc_ids, self.doc_texts)]
        
        # Encode documents in batches
        embeddings = []
        for i in tqdm(range(0, len(texts), self.batch_size), desc="Encoding documents"):
            batch = texts[i:i + self.batch_size]
            batch_embeddings = self.model.encode(
                batch,
                convert_to_tensor=True,
                show_progress_bar=False,
                device=self.device,
                normalize_embeddings=True
            )
            # Convert to numpy
            if hasattr(batch_embeddings, 'cpu'):
                batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)
        
        self.embeddings = np.vstack(embeddings).astype('float32')
        
        # Build FAISS index if available
        self._build_faiss_index()
        
        # Save index
        logger.info(f"Saving index to {index_dir}")
        with open(index_path, 'wb') as f:
            pickle.dump({
                'doc_ids': self.doc_ids,
                'doc_texts': self.doc_texts,  # Save texts instead of objects
                'embeddings': self.embeddings
            }, f)
    
    def _build_faiss_index(self):
        """Build FAISS index if available."""
        try:
            import faiss
            
            dimension = self.embeddings.shape[1]
            self.index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity
            
            # Add GPU support if available
            if self.device == "cuda":
                try:
                    res = faiss.StandardGpuResources()
                    self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
                except:
                    logger.warning("GPU not available for FAISS, using CPU")
            
            self.index.add(self.embeddings)
            logger.info("FAISS index built successfully")
            
        except ImportError:
            logger.warning("FAISS not available, will use numpy for search")
            self.index = None
    
    def search(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
        """Search for relevant documents using dense retrieval."""
        if self.model is None or self.embeddings is None:
            logger.error("Model or index not initialized")
            return []
        
        # Encode query
        query_embedding = self.model.encode(
            query,
            convert_to_tensor=True,
            device=self.device,
            normalize_embeddings=True
        )
        
        # Convert to numpy
        if hasattr(query_embedding, 'cpu'):
            query_embedding = query_embedding.cpu().numpy()
        query_embedding = query_embedding.reshape(1, -1).astype('float32')
        
        # Search using FAISS or numpy
        if self.index is not None:
            scores, indices = self.index.search(query_embedding, top_k)
            scores = scores[0]
            indices = indices[0]
        else:
            # Fallback to numpy cosine similarity
            similarities = np.dot(self.embeddings, query_embedding.T).flatten()
            top_indices = np.argsort(similarities)[::-1][:top_k]
            indices = top_indices
            scores = similarities[top_indices]
        
        # Create results
        results = []
        for i, (idx, score) in enumerate(zip(indices, scores)):
            if idx < len(self.doc_ids):
                result = RetrievalResult(
                    doc_id=self.doc_ids[idx],
                    score=float(score),
                    rank=i + 1,
                    content=self.doc_texts[idx] if hasattr(self, 'doc_texts') and idx < len(self.doc_texts) else None
                )
                results.append(result)
        
        return results
    
    def build_index(self, documents: List[str], metadata: Optional[List[Dict]] = None):
        """Build dense index from documents (compatibility method).
        
        Args:
            documents: List of document texts
            metadata: Optional metadata for each document
        """
        # Convert to expected format - use a namedtuple which is picklable
        from collections import namedtuple
        SimpleDoc = namedtuple('SimpleDoc', ['doc_id', 'content'])
        
        doc_objects = []
        for i, doc_text in enumerate(documents):
            doc_objects.append(SimpleDoc(doc_id=str(i), content=doc_text))
        
        self.index_documents(doc_objects, "default")
    
    def retrieve(self, query: str, k: int = 10) -> List[tuple]:
        """Retrieve top-k documents (backward compatibility)."""
        results = self.search(query, top_k=k)
        return [(int(r.doc_id), r.score) for r in results]
    
    def batch_retrieve(self, queries: List[str], k: int = 10) -> Dict[int, List[tuple]]:
        """Retrieve documents for multiple queries."""
        results = {}
        for i, query in enumerate(queries):
            results[i] = self.retrieve(query, k)
        return results
    
    def evaluate(self, retrieved: Dict[int, List[tuple]],
                 ground_truth: Dict[int, List[int]],
                 k_values: List[int]) -> Dict[str, Any]:
        """Evaluate retrieval performance."""
        metrics = {}
        
        # Calculate Recall@K
        recall_at_k = {}
        for k in k_values:
            recalls = []
            for query_id, retrieved_docs in retrieved.items():
                if query_id not in ground_truth:
                    continue
                
                relevant_docs = set(ground_truth[query_id])
                retrieved_at_k = set([doc_id for doc_id, _ in retrieved_docs[:k]])
                
                if relevant_docs:
                    recall = len(retrieved_at_k & relevant_docs) / len(relevant_docs)
                    recalls.append(recall)
            
            recall_at_k[k] = np.mean(recalls) if recalls else 0.0
        
        metrics['recall'] = recall_at_k
        
        # Calculate MRR
        mrrs = []
        for query_id, retrieved_docs in retrieved.items():
            if query_id not in ground_truth:
                continue
            
            relevant_docs = set(ground_truth[query_id])
            
            for i, (doc_id, _) in enumerate(retrieved_docs):
                if doc_id in relevant_docs:
                    mrrs.append(1.0 / (i + 1))
                    break
            else:
                mrrs.append(0.0)
        
        metrics['mrr'] = np.mean(mrrs) if mrrs else 0.0
        
        # Calculate nDCG@K
        ndcg_at_k = {}
        for k in k_values:
            ndcgs = []
            for query_id, retrieved_docs in retrieved.items():
                if query_id not in ground_truth:
                    continue
                
                relevant_docs = set(ground_truth[query_id])
                
                # Calculate DCG
                dcg = 0.0
                for i, (doc_id, _) in enumerate(retrieved_docs[:k]):
                    if doc_id in relevant_docs:
                        dcg += 1.0 / np.log2(i + 2)
                
                # Calculate IDCG
                idcg = sum(1.0 / np.log2(i + 2) for i in range(min(k, len(relevant_docs))))
                
                if idcg > 0:
                    ndcgs.append(dcg / idcg)
            
            ndcg_at_k[k] = np.mean(ndcgs) if ndcgs else 0.0
        
        metrics['ndcg'] = ndcg_at_k
        
        return metrics


class ContrieverRetriever(DenseRetriever):
    """Contriever-specific dense retriever.

    Loads ``facebook/contriever`` through Hugging Face ``transformers`` and
    wraps it in a small adapter exposing ``encode()``, the method that the
    inherited ``index_documents`` and ``search`` call on ``self.model``. When
    the model cannot be loaded that way, the parent's sentence-transformers
    loader is used instead.
    """

    class _TransformersEncoder:
        """Give a raw transformers model a SentenceTransformer-like ``encode``."""

        def __init__(self, model, tokenizer, max_length, batch_size):
            self.model = model
            self.tokenizer = tokenizer
            self.max_length = max_length
            self.batch_size = max(1, int(batch_size))

        def __call__(self, *args, **kwargs):
            # Keep the wrapper callable like the wrapped model.
            return self.model(*args, **kwargs)

        def encode(self, sentences, normalize_embeddings=False, **_unused):
            """Embed one text or a list of texts with masked mean pooling.

            Args:
                sentences: A single string or a sequence of strings.
                normalize_embeddings: Scale every vector to unit L2 norm.
                **_unused: Other keyword arguments the parent class passes
                    (``convert_to_tensor``, ``show_progress_bar``,
                    ``device``) are accepted and ignored.

            Returns:
                A float32 numpy array: one vector for a string input, one
                row per text for a sequence input.
            """
            import torch

            single = isinstance(sentences, str)
            texts = [sentences] if single else list(sentences)
            if not texts:
                return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

            # Follow the model's real placement: it stays on the CPU when
            # CUDA was requested but is not available.
            device = next(self.model.parameters()).device

            chunks = []
            with torch.no_grad():
                for start in range(0, len(texts), self.batch_size):
                    inputs = self.tokenizer(
                        texts[start:start + self.batch_size],
                        padding=True,
                        truncation=True,
                        max_length=self.max_length,
                        return_tensors="pt"
                    )
                    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

                    hidden = self.model(**inputs).last_hidden_state
                    # Mean pooling over real tokens only; padding positions
                    # are zeroed out and excluded from the token count.
                    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                    chunks.append(pooled.cpu().numpy())

            embeddings = np.vstack(chunks).astype(np.float32)
            if normalize_embeddings:
                norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
                # Guard against division by zero for an all-zero vector.
                embeddings = embeddings / np.maximum(norms, 1e-12)
            return embeddings[0] if single else embeddings

    def __init__(self, config: Dict[str, Any], device: str = "cuda"):
        """Create a retriever bound to the ``facebook/contriever`` model."""
        super().__init__(config, "facebook/contriever", device)

    def _load_model(self):
        """Load Contriever model.

        Returns:
            An encoder wrapping the transformers model and its tokenizer, or
            whatever the parent's sentence-transformers loader returns when
            the transformers model cannot be loaded.
        """
        try:
            from transformers import AutoTokenizer, AutoModel
            import torch

            logger.info("Loading Contriever model")
            self.tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
            model = AutoModel.from_pretrained("facebook/contriever")

            if str(self.device).startswith("cuda") and torch.cuda.is_available():
                model = model.to(self.device)

            model.eval()
        except Exception as e:
            logger.warning(f"Could not load Contriever: {e}")
            # Fall back to sentence-transformers
            return super()._load_model()

        # The parent class calls self.model.encode(...), which a bare
        # transformers model does not provide.
        return self._TransformersEncoder(
            model, self.tokenizer, self.max_length, self.batch_size
        )

    def _encode_texts(self, texts: List[str]) -> np.ndarray:
        """Encode texts into L2-normalised embeddings.

        Works with either backend, since both expose ``encode``.

        Raises:
            RuntimeError: If no model could be loaded.
        """
        if self.model is None:
            raise RuntimeError("Contriever model is not loaded")
        return np.asarray(self.model.encode(list(texts), normalize_embeddings=True))