"""
ColBERT retriever implementation.
"""

import os
import pickle
import numpy as np
from typing import List, Dict, Any, Optional
import logging
from tqdm import tqdm

from .base_retriever import BaseRetriever, RetrievalResult

logger = logging.getLogger(__name__)


class ColBERTRetriever(BaseRetriever):
    """ColBERT-style late-interaction retriever with TF-IDF fallbacks.

    The backend is chosen once, in ``_load_colbert``:

    * ``colbert-ai`` importable: its checkpoint is loaded, but indexing and
      search currently delegate to the TF-IDF fallback (integration pending).
    * only ``transformers`` importable: documents and queries are encoded into
      token-level embeddings (``last_hidden_state``) and scored with MaxSim.
    * nothing loadable: ``self.model`` stays None and TF-IDF is used.
    """

    def __init__(self, config: Dict[str, Any], device: str = "cuda"):
        super().__init__(config, device)
        self.model_name = config.get('retrieval', {}).get('colbert', {}).get('model', 'colbert-ir/colbertv2.0')
        self.model = None
        self.tokenizer = None
        # Initialized up front so later checks never depend on hasattr().
        self.use_colbert_lib = False
        self.index = None
        self.doc_ids = []
        self.documents = []

        # Try to load ColBERT
        self._load_colbert()

    def _load_colbert(self):
        """Load the ColBERT encoder, preferring colbert-ai over transformers.

        Sets ``self.model``, ``self.tokenizer`` and ``self.use_colbert_lib``.
        Load failures are logged and leave ``self.model`` as None, so indexing
        and search degrade to TF-IDF instead of raising.
        """
        # First try actual colbert-ai library
        try:
            import torch  # noqa: F401 -- probe only: colbert-ai needs torch
            logger.info("Attempting to load ColBERT from colbert-ai library")
            from colbert.modeling.checkpoint import Checkpoint
            from colbert.infra import ColBERTConfig

            # Create config for ColBERT
            self.colbert_config = ColBERTConfig(
                root="./indexes/colbert",
                experiment="peerqa",
                index_root=None,
                nranks=1,
                checkpoint=self.model_name
            )

            # Load checkpoint
            logger.info(f"Loading ColBERT checkpoint: {self.model_name}")
            self.checkpoint = Checkpoint(self.model_name, colbert_config=self.colbert_config)
            self.model = self.checkpoint
            self.use_colbert_lib = True

            # Also setup tokenizer for compatibility
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            logger.info("ColBERT loaded successfully using colbert-ai library")

        except ImportError:
            logger.info("colbert-ai not available, falling back to transformers")
            self.use_colbert_lib = False
            # Fallback to transformers
            try:
                import torch
                from transformers import AutoTokenizer, AutoModel

                logger.info(f"Loading ColBERT model via transformers: {self.model_name}")
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModel.from_pretrained(self.model_name)

                # Only move to the GPU when one is actually present; _encode
                # follows the model's parameters, so inputs land on the same device.
                if self.device == "cuda" and torch.cuda.is_available():
                    self.model = self.model.cuda()

                self.model.eval()

            except ImportError:
                logger.warning("Transformers not available, using fallback")
                self.model = None
            except Exception as e:
                logger.warning(f"Could not load ColBERT model: {e}")
                self.model = None
        except Exception as e:
            logger.warning(f"Could not load ColBERT via colbert-ai: {e}")
            self.model = None
            self.use_colbert_lib = False

    def build_index(self, documents: List[str], metadata: Optional[List[Dict]] = None):
        """Build an index from raw document texts (compatibility method).

        Each text is wrapped in an object with ``doc_id`` (its position, as a
        string) and ``content`` and passed to ``index_documents`` under the
        index name ``"default"``.

        Args:
            documents: List of document texts.
            metadata: Accepted for interface compatibility; currently unused.
        """
        from types import SimpleNamespace

        # SimpleNamespace instances are picklable, unlike a class defined inside
        # this method, so the transformers path can cache the index to disk.
        doc_objects = [
            SimpleNamespace(doc_id=str(i), content=doc_text)
            for i, doc_text in enumerate(documents)
        ]
        self.index_documents(doc_objects, "default")

    @staticmethod
    def _as_int_id(doc_id: Any) -> Any:
        """Return ``int(doc_id)`` when possible, otherwise ``doc_id`` unchanged."""
        try:
            return int(doc_id)
        except (TypeError, ValueError):
            return doc_id

    def batch_retrieve(self, queries: List[str], k: int = 10) -> Dict[int, List[tuple]]:
        """Retrieve documents for multiple queries.

        Args:
            queries: List of query strings.
            k: Number of documents to retrieve per query.

        Returns:
            Mapping from query position to a list of ``(doc_id, score)`` tuples.
            Numeric doc ids (as produced by ``build_index``) are returned as
            ints; any other id is returned unchanged instead of raising.
        """
        return {
            query_index: [
                (self._as_int_id(result.doc_id), result.score)
                for result in self.search(query, top_k=k)
            ]
            for query_index, query in enumerate(queries)
        }

    def evaluate(self, retrieved: Dict[int, List[tuple]],
                 ground_truth: Dict[int, List[int]],
                 k_values: List[int]) -> Dict[str, Any]:
        """Compute Recall@K and MRR for a retrieval run.

        Only queries present in both ``retrieved`` and ``ground_truth`` with
        at least one relevant document are scored, and the same query set is
        used for every metric.

        Args:
            retrieved: Query id -> ranked list of ``(doc_id, score)`` tuples.
            ground_truth: Query id -> relevant doc ids.
            k_values: Cut-offs at which to report recall.

        Returns:
            ``{'recall': {k: mean recall@k}, 'mrr': mean reciprocal rank}``;
            means are 0.0 when no query qualifies.
        """
        scored = []
        for query_id, retrieved_docs in retrieved.items():
            if query_id not in ground_truth:
                continue
            relevant = set(ground_truth[query_id])
            if relevant:
                scored.append((retrieved_docs, relevant))

        recall_at_k = {}
        for k in k_values:
            recalls = [
                len({doc_id for doc_id, _ in docs[:k]} & relevant) / len(relevant)
                for docs, relevant in scored
            ]
            recall_at_k[k] = np.mean(recalls) if recalls else 0.0

        reciprocal_ranks = []
        for docs, relevant in scored:
            reciprocal_rank = 0.0
            for rank, (doc_id, _) in enumerate(docs, start=1):
                if doc_id in relevant:
                    reciprocal_rank = 1.0 / rank
                    break
            reciprocal_ranks.append(reciprocal_rank)

        return {
            'recall': recall_at_k,
            'mrr': np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0,
        }

    def _encode(self, text: str) -> np.ndarray:
        """Encode ``text`` with the transformers model.

        Returns ``last_hidden_state`` as a numpy array of shape
        ``(1, num_tokens, hidden_dim)`` (a batch of one sequence).
        """
        import torch

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        )
        # Send inputs to wherever the weights live; self.device alone may name
        # a GPU that _load_colbert found unavailable.
        model_device = next(self.model.parameters()).device
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.cpu().numpy()

    @staticmethod
    def _corpus_fingerprint(documents: List[Any]) -> str:
        """Return a SHA-256 digest over every document's ``doc_id`` and ``content``."""
        import hashlib

        digest = hashlib.sha256()
        for doc in documents:
            for field in (doc.doc_id, doc.content):
                data = str(field).encode('utf-8', errors='surrogatepass')
                # Length-prefix each field so concatenations cannot collide.
                digest.update(len(data).to_bytes(8, 'little'))
                digest.update(data)
        return digest.hexdigest()

    def _load_cached_index(self, index_path: str, fingerprint: str) -> bool:
        """Restore a pickled index from ``index_path`` if it matches ``fingerprint``.

        Returns True when the cache was loaded into ``self``. Returns False
        when there is no cache, it cannot be read, or it was built from a
        different corpus; the caller then re-encodes.
        """
        if not os.path.exists(index_path):
            return False
        # NOTE(review): pickle.load can execute code from the file; only point
        # index directories at files this project wrote itself.
        try:
            with open(index_path, 'rb') as f:
                index_data = pickle.load(f)
            index = index_data['index']
            doc_ids = index_data['doc_ids']
            documents = index_data['documents']
            # Caches written before fingerprints existed are checked by
            # recomputing the fingerprint from their stored documents.
            cached_fingerprint = index_data.get('fingerprint') or self._corpus_fingerprint(documents)
        except (OSError, EOFError, pickle.UnpicklingError, ImportError,
                AttributeError, KeyError, TypeError) as e:
            logger.warning(f"Ignoring unreadable ColBERT index at {index_path}: {e}")
            return False

        if cached_fingerprint != fingerprint:
            logger.info(f"ColBERT index at {index_path} was built from a different corpus; rebuilding")
            return False

        logger.info(f"Loading existing ColBERT index from {os.path.dirname(index_path)}")
        self.index = index
        self.doc_ids = doc_ids
        self.documents = documents
        return True

    def _save_index(self, index_path: str, fingerprint: str) -> None:
        """Pickle the current index to ``index_path``; failures are logged, not raised."""
        import contextlib

        logger.info(f"Saving ColBERT index to {index_path}")
        payload = {
            'index': self.index,
            'doc_ids': self.doc_ids,
            'documents': self.documents,
            'fingerprint': fingerprint,
        }
        # Write to a temp file and rename, so a failed dump never leaves a
        # truncated pickle that a later run would try to load.
        tmp_path = index_path + '.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(payload, f)
            os.replace(tmp_path, index_path)
        except (OSError, pickle.PicklingError, AttributeError, TypeError) as e:
            # The in-memory index is still usable; only the disk cache is lost.
            logger.warning(f"Could not save ColBERT index to {index_path}: {e}")
            with contextlib.suppress(OSError):
                os.remove(tmp_path)

    def index_documents(self, documents: List[Any], index_name: str) -> None:
        """Index documents (objects with ``doc_id`` and ``content``).

        On the transformers backend, token embeddings are cached under
        ``indexes/colbert/<index_name>/colbert_index.pkl``. A cache is reused
        only if it was built from exactly the same ids and contents. The other
        backends build an in-memory TF-IDF index via ``_fallback_index``.
        """
        if self.model is None:
            logger.warning("ColBERT model not loaded, using fallback indexing")
            self._fallback_index(documents, index_name)
            return

        # Check if we're using colbert-ai library
        if getattr(self, 'use_colbert_lib', False):
            # Use fallback for now - full colbert-ai integration requires more setup
            logger.info("Using fallback indexing for ColBERT (full integration pending)")
            self._fallback_index(documents, index_name)
            return

        documents = list(documents)
        index_dir = os.path.join("indexes", "colbert", index_name)
        os.makedirs(index_dir, exist_ok=True)
        index_path = os.path.join(index_dir, 'colbert_index.pkl')
        fingerprint = self._corpus_fingerprint(documents)

        if self._load_cached_index(index_path, fingerprint):
            return

        logger.info(f"Creating ColBERT index for {len(documents)} documents")

        self.doc_ids = [doc.doc_id for doc in documents]
        self.documents = documents

        # One (1, num_tokens, hidden_dim) array per document.
        doc_embeddings = [
            self._encode(doc.content)
            for doc in tqdm(documents, desc="Encoding documents with ColBERT")
        ]

        self.index = {
            'embeddings': doc_embeddings,
            'doc_ids': self.doc_ids
        }
        self._save_index(index_path, fingerprint)

    def _fallback_index(self, documents: List[Any], index_name: str):
        """Build an in-memory TF-IDF index; ``index_name`` is unused here.

        If scikit-learn cannot build a vocabulary (empty corpus, or only stop
        words), the error is logged and ``self.index`` is cleared, so searches
        return no results instead of raising.
        """
        logger.info("Using fallback indexing (TF-IDF style)")

        from sklearn.feature_extraction.text import TfidfVectorizer

        documents = list(documents)
        self.doc_ids = [doc.doc_id for doc in documents]
        self.documents = documents
        texts = [doc.content for doc in documents]

        vectorizer = TfidfVectorizer(max_features=10000, stop_words='english')
        try:
            doc_vectors = vectorizer.fit_transform(texts)
        except ValueError as e:
            logger.warning(f"Could not build TF-IDF fallback index: {e}")
            self.index = None
            return

        self.index = {
            'vectorizer': vectorizer,
            'doc_vectors': doc_vectors,
            'doc_ids': self.doc_ids
        }

    @staticmethod
    def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
        """Scale each row (last axis) of ``vectors`` to unit L2 norm."""
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / np.maximum(norms, 1e-12)

    def _rank_results(self, scores: Any, top_k: int) -> List[RetrievalResult]:
        """Turn per-document ``scores`` into the ``top_k`` best RetrievalResults (rank 1 first)."""
        scores = np.asarray(scores)
        top_indices = np.argsort(scores)[::-1][:max(top_k, 0)]
        return [
            RetrievalResult(
                doc_id=self.doc_ids[idx],
                score=float(scores[idx]),
                rank=rank,
                content=self.documents[idx].content if self.documents else None
            )
            for rank, idx in enumerate(top_indices, start=1)
        ]

    def search(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
        """Return the ``top_k`` documents for ``query``.

        With a transformers embedding index, documents are scored by MaxSim:
        for each query token, the best cosine similarity against the
        document's tokens, summed over query tokens. Otherwise this delegates
        to ``_fallback_search``.
        """
        if self.model is None or self.index is None:
            return self._fallback_search(query, top_k)

        # If we're using colbert-ai library, use fallback for now
        if getattr(self, 'use_colbert_lib', False):
            return self._fallback_search(query, top_k)

        # Check if we have embeddings in the index (for non-fallback search)
        if 'embeddings' not in self.index:
            return self._fallback_search(query, top_k)

        # ColBERT's MaxSim compares unit-normalized token vectors; normalizing
        # here also covers indexes cached with raw hidden states.
        query_vecs = self._l2_normalize(self._encode(query)[0])
        scores = np.array([
            np.max(query_vecs @ self._l2_normalize(doc_emb[0]).T, axis=1).sum()
            for doc_emb in self.index['embeddings']
        ], dtype=float)

        return self._rank_results(scores, top_k)

    def _fallback_search(self, query: str, top_k: int) -> List[RetrievalResult]:
        """Rank documents by TF-IDF cosine similarity to ``query``.

        Returns an empty list (and logs an error) when no TF-IDF index exists.
        """
        if self.index is None:
            logger.error("No index available for search")
            return []

        # Check if we have the required components
        if 'vectorizer' not in self.index or 'doc_vectors' not in self.index:
            logger.error("Required components not in index")
            return []

        from sklearn.metrics.pairwise import cosine_similarity

        query_vector = self.index['vectorizer'].transform([query])
        similarities = cosine_similarity(query_vector, self.index['doc_vectors'])[0]
        return self._rank_results(similarities, top_k)


class ColBERTv2Retriever(ColBERTRetriever):
    """ColBERTv2 specific implementation.

    Adds compression settings read from ``config['retrieval']['colbert']``:
    ``nbits`` (default 2) and ``kmeans_niters`` (default 4). Indexing and
    search are inherited unchanged from ``ColBERTRetriever``.
    """

    def __init__(self, config: Dict[str, Any], device: str = "cuda"):
        super().__init__(config, device)
        colbert_cfg = config.get('retrieval', {}).get('colbert', {})
        self.nbits = colbert_cfg.get('nbits', 2)
        # Stored for configuration parity; no method in this class reads it.
        self.kmeans_niters = colbert_cfg.get('kmeans_niters', 4)

    def _compress_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
        """Quantize embeddings after L2-normalizing along the last axis.

        This is a simplified stand-in for ColBERTv2's residual compression:

        * ``nbits == 2``: the sign of each component (-1, 0 or 1).
        * ``nbits == 4``: each component is rounded to the nearest of
          -1, -0.5, 0, 0.5, 1 (exact ties resolve to the lower level).
        * any other value: the normalized embeddings, unquantized.
        """
        norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
        normalized = embeddings / (norms + 1e-10)

        if self.nbits == 2:
            return np.sign(normalized)

        if self.nbits == 4:
            levels = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
            # Round to the *nearest* level. np.digitize returns the upper bin
            # edge instead, which shifted every value up (e.g. 0 -> 0.5).
            nearest = np.abs(normalized[..., np.newaxis] - levels).argmin(axis=-1)
            return levels[nearest]

        return normalized