"""
Evaluation metrics and early stopping utilities.
"""

import logging
from typing import Dict, List, Optional
import numpy as np
import faiss

logger = logging.getLogger(__name__)


def compute_retrieval_metrics(
    query_embeddings: np.ndarray,
    doc_embeddings: np.ndarray,
    query_ids: List[str],
    doc_ids: List[str],
    qrels: Dict[str, Dict[str, int]],
    k_values: Optional[List[int]] = None,
) -> Dict[str, float]:
    """
    Compute retrieval metrics using FAISS for efficient search.

    Documents are ranked per query by inner product (equal to cosine
    similarity only if the caller passes normalized embeddings; this function
    does not normalize). Only queries that have at least one document with a
    relevance label > 0 in ``qrels`` are evaluated; all others are skipped.

    Computes, averaged over the evaluated queries:
    - NDCG@k (linear gain, log2 discount)
    - Recall@k
    - MRR@k

    Args:
        query_embeddings: Query embeddings (num_queries, embedding_dim)
        doc_embeddings: Document embeddings (num_docs, embedding_dim)
        query_ids: List of query IDs, aligned with rows of query_embeddings
        doc_ids: List of document IDs, aligned with rows of doc_embeddings
        qrels: Ground truth qrels (query_id -> doc_id -> relevance)
        k_values: k values for metrics. Defaults to [1, 5, 10, 20, 50, 100].

    Returns:
        Dict of metric_name -> value (keys "ndcg@k", "recall@k", "mrr@k").
        Empty if no query is evaluable or k_values is empty.
    """
    if k_values is None:
        k_values = [1, 5, 10, 20, 50, 100]
    if not k_values:
        logger.warning("No k values provided; no metrics computed")
        return {}

    # Filter queries that have qrels and relevant documents. The positive
    # judgments are kept so they are built once per query, not once per k.
    valid_query_indices = []
    valid_query_ids = []
    relevant_by_query = []
    for i, query_id in enumerate(query_ids):
        relevant_docs = {
            doc_id: rel for doc_id, rel in qrels.get(query_id, {}).items() if rel > 0
        }
        if relevant_docs:
            valid_query_indices.append(i)
            valid_query_ids.append(query_id)
            relevant_by_query.append(relevant_docs)

    # Checked before building the index so no work is wasted when there is
    # nothing to evaluate.
    if not valid_query_indices:
        logger.warning("No valid queries with qrels found!")
        return {}

    logger.info(
        f"Filtered {len(valid_query_ids)} queries with qrels out of {len(query_ids)} total queries"
    )

    logger.info("Building FAISS index for retrieval")

    # Build FAISS index
    embedding_dim = doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(
        embedding_dim
    )  # Inner product (cosine similarity for normalized embeddings)
    index.add(doc_embeddings.astype(np.float32))

    # Get embeddings for valid queries only
    valid_query_embeddings = query_embeddings[valid_query_indices]

    # Search once with the largest k; smaller cutoffs are prefixes of it.
    max_k = max(k_values)
    logger.info(
        f"Searching for top-{max_k} documents for {len(valid_query_ids)} queries"
    )
    scores, indices = index.search(valid_query_embeddings.astype(np.float32), max_k)

    logger.info(f"Indices for first few queries: {indices[:3]}")
    logger.info(f"Scores for first few queries: {scores[:3]}")

    # FAISS pads the result with label -1 when the index holds fewer than
    # max_k vectors. Those slots must be dropped: doc_ids[-1] would otherwise
    # silently resolve to the last document.
    ranked_doc_ids = [[doc_ids[idx] for idx in row if idx >= 0] for row in indices]

    # Best achievable ordering of relevance labels, per query.
    ideal_rels_by_query = [
        sorted(relevant_docs.values(), reverse=True)
        for relevant_docs in relevant_by_query
    ]

    # Compute metrics
    metrics = {}

    for k in k_values:
        ndcg_scores = []
        recall_scores = []
        mrr_scores = []

        for ranked, relevant_docs, ideal_rels in zip(
            ranked_doc_ids, relevant_by_query, ideal_rels_by_query
        ):
            dcg = 0.0
            num_relevant_retrieved = 0
            mrr = 0.0
            for rank, doc_id in enumerate(ranked[:k]):
                rel = relevant_docs.get(doc_id)
                if rel is None:
                    continue
                dcg += rel / np.log2(rank + 2)  # rank+2 because rank is 0-indexed
                num_relevant_retrieved += 1
                if mrr == 0.0:
                    # Reciprocal rank of the first relevant hit only
                    mrr = 1.0 / (rank + 1)

            # Ideal DCG: the k highest relevance labels in descending order
            idcg = sum(
                rel / np.log2(rank + 2) for rank, rel in enumerate(ideal_rels[:k])
            )

            ndcg_scores.append(dcg / idcg if idcg > 0 else 0.0)
            # relevant_docs is non-empty for every evaluated query
            recall_scores.append(num_relevant_retrieved / len(relevant_docs))
            mrr_scores.append(mrr)

        # Average metrics over evaluated queries
        metrics[f"ndcg@{k}"] = np.mean(ndcg_scores)
        metrics[f"recall@{k}"] = np.mean(recall_scores)
        metrics[f"mrr@{k}"] = np.mean(mrr_scores)

    logger.info("Computed retrieval metrics:")
    for metric_name, value in metrics.items():
        logger.info(f"  {metric_name}: {value:.4f}")

    return metrics


class EarlyStoppingMonitor:
    """
    Monitor for early stopping based on a metric.

    Tracks the best value seen for one named metric and counts consecutive
    updates without improvement. Once that count reaches ``patience``,
    ``update`` returns True and ``should_stop_training`` is set.
    """

    def __init__(
        self,
        patience: int = 5,
        metric_name: str = "ndcg@10",
        mode: str = "max",
    ):
        """
        Initialize early stopping monitor.

        Args:
            patience: Number of consecutive non-improving updates to tolerate
                before stopping
            metric_name: Name of metric to monitor
            mode: "max" (higher is better) or "min" (lower is better)

        Raises:
            ValueError: If mode is neither "max" nor "min".
        """
        # Reject unknown modes up front; otherwise any other string would be
        # silently treated as "min" by the comparison in update().
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

        self.patience = patience
        self.metric_name = metric_name
        self.mode = mode

        # Start from the worst possible value so the first update always
        # counts as an improvement.
        self.best_value = float("-inf") if mode == "max" else float("inf")
        self.best_epoch = 0
        self.counter = 0  # consecutive updates without improvement
        self.should_stop_training = False

        logger.info(
            f"Early stopping: monitoring {metric_name} with patience {patience}"
        )

    def update(self, metrics: Dict[str, float], epoch: int) -> bool:
        """
        Update the monitor with new metrics.

        If the monitored metric is missing from ``metrics``, a warning is
        logged and the monitor state is left untouched.

        Args:
            metrics: Dict of metric values
            epoch: Current epoch number

        Returns:
            True if training should stop, False otherwise
        """
        if self.metric_name not in metrics:
            logger.warning(f"Metric {self.metric_name} not found in metrics")
            return False

        current_value = metrics[self.metric_name]

        # Strict comparison: a tie with the best value is not an improvement
        if self.mode == "max":
            improved = current_value > self.best_value
        else:
            improved = current_value < self.best_value

        if improved:
            self.best_value = current_value
            self.best_epoch = epoch
            self.counter = 0
            logger.info(
                f"New best {self.metric_name}: {current_value:.4f} at epoch {epoch}"
            )
            return False

        self.counter += 1
        logger.info(
            f"No improvement in {self.metric_name} for {self.counter} epochs "
            f"(best: {self.best_value:.4f} at epoch {self.best_epoch})"
        )

        if self.counter >= self.patience:
            logger.info(
                f"Early stopping triggered after {self.counter} epochs without improvement"
            )
            self.should_stop_training = True
            return True

        return False

    def get_best_epoch(self) -> int:
        """Get the epoch with the best metric value (0 until the first improvement)."""
        return self.best_epoch

    def get_best_value(self) -> float:
        """Get the best metric value (+/-inf until the first improvement)."""
        return self.best_value

    def is_best_epoch(self, epoch: int) -> bool:
        """Check if the given epoch is the best so far."""
        return epoch == self.best_epoch
