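"""
Evaluate a trained dense-retrieval (DualEncoder) checkpoint on the dev query set.

Usage:
    python evaluate.py --config path/to/config.yaml --checkpoint path/to/checkpoint
        [--collection_path PATH] [--batch_size N] [--output path/to/results.json]
"""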

import argparse
import os
import sys
import torch
import logging

# Put the parent directory first on sys.path so the `src` package imported below resolves when run as a script
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from src.config import load_config
from src.models import DualEncoder
from src.data import load_collection, load_queries, load_qrels
from src.utils import generate_embeddings, compute_retrieval_metrics, setup_logging
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


def _parse_args(argv=None):
    """Parse the evaluation script's command-line arguments.

    Args:
        argv: Optional list of argument strings; ``None`` means argparse reads
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``checkpoint``, ``collection_path``,
        ``output`` and ``batch_size`` attributes.
    """
    parser = argparse.ArgumentParser(description="Evaluate dense retrieval model")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to configuration YAML file",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint directory",
    )
    parser.add_argument(
        "--collection_path",
        type=str,
        default=None,
        help="Path to document collection (overrides config if provided)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to save evaluation results (optional)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="Batch size for embedding generation (overrides config if provided)",
    )
    return parser.parse_args(argv)


def _print_metrics(metrics):
    """Print ``metrics`` to stdout as a name-sorted table with 4-decimal values."""
    print("\n" + "=" * 50)
    print("Evaluation Results")
    print("=" * 50)
    for metric_name, value in sorted(metrics.items()):
        print(f"{metric_name:20s}: {value:.4f}")
    print("=" * 50)


def _save_metrics(metrics, output_path):
    """Write ``metrics`` as indented JSON to ``output_path``.

    Missing parent directories are created. A bare filename (no directory
    component) is written relative to the current working directory.
    """
    import json

    parent_dir = os.path.dirname(output_path)
    # os.path.dirname("results.json") == "" and os.makedirs("") raises,
    # so only create directories when the path actually has a directory part.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)


def main():
    """Evaluate a DualEncoder checkpoint on the dev queries and report metrics.

    Loads the YAML config given by ``--config`` and applies any command-line
    overrides (collection path, embedding batch size). Encodes the dev queries
    and the document collection with the checkpoint model, then computes
    retrieval metrics for k in {1, 5, 10, 20, 50, 100}. Results are printed to
    stdout and, when ``--output`` is given, also saved there as JSON.
    """
    args = _parse_args()

    # Setup logging
    setup_logging(log_level="INFO", rank=0)

    # Load configuration, then layer command-line overrides on top of it.
    config = load_config(args.config)
    logger.info(f"Loaded config from {args.config}")

    if args.collection_path is not None:
        config.data.collection_path = args.collection_path
        # An explicit collection path must take precedence over the
        # positive-only collection; otherwise the collection_path selection
        # further down would ignore the override.
        if config.data.use_positive_only_collection:
            logger.warning(
                "Overriding collection_path, setting use_positive_only_collection to False"
            )
            config.data.use_positive_only_collection = False

    if args.batch_size is not None:
        config.embedding_generation.batch_size = args.batch_size
        logger.info(
            f"Overriding batch size to {args.batch_size} for embedding generation"
        )

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load model and put it in evaluation mode before encoding.
    logger.info(f"Loading model from {args.checkpoint}")
    model = DualEncoder.from_pretrained(args.checkpoint)
    model.to(device)
    model.eval()

    # NOTE(review): the tokenizer comes from config.model.encoder_name, not from
    # the checkpoint directory; confirm the two always agree.
    tokenizer = AutoTokenizer.from_pretrained(config.model.encoder_name)

    # Load data
    logger.info("Loading data")
    collection_path = (
        config.data.collection_positive_only_path
        if config.data.use_positive_only_collection
        else config.data.collection_path
    )
    doc_ids, doc_id_to_text = load_collection(collection_path)

    dev_query_ids, dev_query_id_to_text = load_queries(config.data.queries_dev_path)
    dev_qrels = load_qrels(config.data.qrels_dev_path)

    logger.info(f"Loaded {len(doc_ids)} documents")
    logger.info(f"Loaded {len(dev_query_ids)} dev queries")

    batch_size = config.embedding_generation.batch_size

    # Generate embeddings (the second return value of generate_embeddings is unused here)
    logger.info("Generating query embeddings")
    query_embeddings, _ = generate_embeddings(
        model=model,
        ids=dev_query_ids,
        id_to_text=dev_query_id_to_text,
        tokenizer=tokenizer,
        max_length=config.model.max_query_length,
        batch_size=batch_size,
        device=device,
        is_query=True,
        show_progress=True,
        prefix=config.model.query_prefix,
    )

    logger.info("Generating document embeddings")
    doc_embeddings, _ = generate_embeddings(
        model=model,
        ids=doc_ids,
        id_to_text=doc_id_to_text,
        tokenizer=tokenizer,
        max_length=config.model.max_doc_length,
        batch_size=batch_size,
        device=device,
        is_query=False,
        show_progress=True,
        prefix=config.model.document_prefix,
    )

    # Compute metrics
    logger.info("Computing retrieval metrics")
    metrics = compute_retrieval_metrics(
        query_embeddings=query_embeddings,
        doc_embeddings=doc_embeddings,
        query_ids=dev_query_ids,
        doc_ids=doc_ids,
        qrels=dev_qrels,
        k_values=[1, 5, 10, 20, 50, 100],
    )

    _print_metrics(metrics)

    # Save results if output path provided
    if args.output:
        _save_metrics(metrics, args.output)
        logger.info(f"Saved results to {args.output}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
