"""
Main Trainer class for dense retrieval training.
"""

import logging
import os
import json
import time
from typing import Dict, Optional, Tuple, Any, List, Set
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from tqdm import tqdm

from ..config import ExperimentConfig
from ..models import DualEncoder
from ..losses import InfoNCELoss
from ..data import (
    DenseRetrievalDataset,
    DenseRetrievalCollator,
    load_collection,
    load_queries,
    load_qrels,
    save_training_qrels,
)
from ..samplers import create_batch_sampler, create_negative_sampler
from ..utils import (
    setup_logging,
    generate_embeddings,
    generate_embeddings_distributed,
    save_embeddings,
    compute_retrieval_metrics,
    EarlyStoppingMonitor,
)
from .chunked_distributed_sampler import ChunkedDistributedSampler

logger = logging.getLogger(__name__)


class Trainer:
    """
    Main trainer class for dense retrieval.

    Orchestrates:
    - Data loading
    - Model training
    - Embedding generation
    - Batch sampling
    - Negative sampling
    - Evaluation
    - Checkpointing
    """

    def __init__(self, config: ExperimentConfig, local_rank: int = -1):
        """
        Build a fully configured trainer.

        Construction resolves the distributed topology, configures logging,
        creates the experiment directories, selects the device, seeds the RNGs,
        loads the data and then builds the model, loss, optimizer/scheduler and
        samplers, in that order.

        Args:
            config: Experiment configuration
            local_rank: Local rank for distributed training (-1 for single GPU)
        """
        self.config = config
        self.local_rank = local_rank

        # Any local rank other than -1 means distributed mode.
        self.is_distributed = local_rank != -1
        self.world_size = dist.get_world_size() if self.is_distributed else 1
        self.rank = dist.get_rank() if self.is_distributed else 0

        # Logging is configured first so every later step can report progress.
        setup_logging(
            os.path.join(config.experiment_dir, config.logging.log_file),
            config.logging.log_level,
            self.rank,
        )

        logger.info(f"Initializing Trainer for experiment: {config.experiment_name}")
        logger.info(f"Experiment directory: {config.experiment_dir}")
        logger.info(
            f"Distributed training: {self.is_distributed} (rank {self.rank}/{self.world_size})"
        )

        self._create_directories()

        # Device selection: one CUDA device per process when distributed,
        # otherwise CUDA if present and CPU as a fallback.
        if not self.is_distributed:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(self.device)
        logger.info(f"Using device: {self.device}")

        self._set_seeds(config.seed)

        # Data -> model -> loss -> optimizer/scheduler -> samplers.
        for build_step in (
            self._load_data,
            self._initialize_model,
            self._initialize_loss,
            self._initialize_optimizer,
            self._initialize_samplers,
        ):
            build_step()

        # Only rank 0 owns a TensorBoard writer; other ranks keep None.
        if self.rank != 0:
            self.writer = None
        else:
            tensorboard_dir = os.path.join(
                config.experiment_dir, config.logging.tensorboard_dir
            )
            self.writer = SummaryWriter(tensorboard_dir)
            logger.info(f"Tensorboard directory: {tensorboard_dir}")

        # Early stopping tracks the configured metric, higher is better.
        self.early_stopping = EarlyStoppingMonitor(
            patience=config.training.early_stopping_patience,
            metric_name=config.training.early_stopping_metric,
            mode="max",
        )

        # Progress counters.
        self.global_step = 0
        self.current_epoch = 0
        self.best_checkpoint_saved = False  # becomes True once a best checkpoint exists

        # Explicit embeddings, query order and sampled negatives reused across epochs.
        self.cached_query_embeddings = None
        self.cached_doc_embeddings = None
        self.cached_query_order = None
        self.cached_sampled_negatives = None

        # Forward-pass caches from the previous epoch (epoch i reads epoch i-1).
        # Kept as id -> vector dicts so no full (num_docs, dim) array is allocated.
        self._prev_epoch_query_emb_by_id: Optional[Dict[str, np.ndarray]] = None
        self._prev_epoch_doc_emb_by_id: Optional[Dict[str, np.ndarray]] = None
        self._prev_epoch_cache_epoch: Optional[int] = None

        # Forward-pass caches being filled during the current epoch.
        self._cur_epoch_query_emb_by_id: Optional[Dict[str, np.ndarray]] = None
        self._cur_epoch_doc_emb_by_id: Optional[Dict[str, np.ndarray]] = None

        logger.info("Trainer initialized successfully")

    def _create_directories(self):
        """Create the experiment directory tree and save the config (rank 0 only)."""
        if self.rank != 0:
            return

        root = self.config.experiment_dir
        os.makedirs(root, exist_ok=True)

        # Sub-directories that every experiment gets.
        for subdir in ("checkpoints", "embeddings", "training_data", "results", "log"):
            os.makedirs(os.path.join(root, subdir), exist_ok=True)

        # Step-wise checkpoints/results only exist when intermediate evaluation is on.
        if self.config.training.epoch_intermediate_eval:
            for subdir in ("checkpoints_by_steps", "results_by_steps"):
                os.makedirs(os.path.join(root, subdir), exist_ok=True)

        # Persist the configuration next to the experiment outputs.
        from ..config import save_config

        config_path = os.path.join(root, "config.yaml")
        save_config(self.config, config_path)
        logger.info(f"Saved config to {config_path}")

    def _set_seeds(self, seed: int):
        """Seed Python's `random`, NumPy and PyTorch (plus all CUDA devices if available)."""
        import random

        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        logger.info(f"Set random seed to {seed}")

    def _load_data(self):
        """
        Load the training collection, train/dev queries and qrels, and the eval collection.

        The training collection comes from ``collection_positive_only_path`` when
        ``use_positive_only_collection`` is set, otherwise from ``collection_path``.
        If ``eval_collection_path`` is configured it is loaded separately;
        otherwise evaluation shares the already-loaded training collection
        objects (``eval_doc_ids`` / ``eval_doc_id_to_text`` alias
        ``doc_ids`` / ``doc_id_to_text``).
        """
        logger.info("Loading data")
        data_cfg = self.config.data

        # Training collection (positive-only variant or the full one).
        collection_path = (
            data_cfg.collection_positive_only_path
            if data_cfg.use_positive_only_collection
            else data_cfg.collection_path
        )
        self.doc_ids, self.doc_id_to_text = load_collection(collection_path)

        # Train queries and qrels
        self.train_query_ids, self.train_query_id_to_text = load_queries(
            data_cfg.queries_train_path
        )
        self.train_qrels = load_qrels(data_cfg.qrels_train_path)

        # Dev queries and qrels
        self.dev_query_ids, self.dev_query_id_to_text = load_queries(
            data_cfg.queries_dev_path
        )
        self.dev_qrels = load_qrels(data_cfg.qrels_dev_path)

        # Evaluation collection
        if data_cfg.eval_collection_path:
            self.eval_doc_ids, self.eval_doc_id_to_text = load_collection(
                data_cfg.eval_collection_path
            )
            logger.info(f"Loaded {len(self.eval_doc_ids)} evaluation documents from separate collection")
        else:
            # The training collection above was read from the very path this
            # fallback would use (positive-only or full), so share the loaded
            # objects instead of parsing the same file a second time.
            self.eval_doc_ids = self.doc_ids
            self.eval_doc_id_to_text = self.doc_id_to_text
            logger.info("Using training collection for evaluation")

        logger.info(f"Loaded {len(self.doc_ids)} training documents")
        logger.info(f"Loaded {len(self.train_query_ids)} training queries")
        logger.info(f"Loaded {len(self.dev_query_ids)} dev queries")

    def _initialize_model(self):
        """Build the dual encoder on the target device, wrap it in DDP if needed, load the tokenizer."""
        logger.info("Initializing model")

        model_cfg = self.config.model
        self.model = DualEncoder(
            encoder_name=model_cfg.encoder_name,
            pooling_strategy=model_cfg.pooling_strategy,
            embedding_dim=model_cfg.embedding_dim,
            normalize_embeddings=model_cfg.normalize_embeddings,
        )
        self.model.to(self.device)

        # In distributed mode each process wraps its replica for gradient sync.
        if self.is_distributed:
            ddp_options = dict(
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=True,
            )
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, **ddp_options
            )
            logger.info("Wrapped model with DistributedDataParallel")

        # The tokenizer is loaded from the same checkpoint name as the encoder.
        self.tokenizer = AutoTokenizer.from_pretrained(model_cfg.encoder_name)
        logger.info(f"Loaded tokenizer: {self.config.model.encoder_name}")

    def _initialize_loss(self):
        """Create the InfoNCE loss from the ``loss`` section of the config."""
        logger.info("Initializing loss function")

        loss_cfg = self.config.loss
        # Each option is forwarded under the same name it has in the config.
        option_names = (
            "temperature",
            "use_mined_negatives",
            "use_sampled_negatives",
            "use_inbatch_negatives",
            "gather_across_gpus",
        )
        self.loss_fn = InfoNCELoss(
            **{name: getattr(loss_cfg, name) for name in option_names}
        )

    def _initialize_optimizer(self):
        """Create the AdamW optimizer and a linear warmup/decay learning-rate schedule."""
        logger.info("Initializing optimizer and scheduler")

        train_cfg = self.config.training

        # Under DDP the parameters are taken from the wrapped module.
        core_model = self.model.module if self.is_distributed else self.model

        self.optimizer = torch.optim.AdamW(
            core_model.parameters(),
            lr=train_cfg.learning_rate,
            weight_decay=train_cfg.weight_decay,
            eps=train_cfg.adam_epsilon,
        )

        # Queries consumed per optimizer step across all GPUs; floor division
        # means a trailing partial step is not counted.
        queries_per_step = (
            train_cfg.per_gpu_batch_size
            * self.world_size
            * train_cfg.gradient_accumulation_steps
        )
        steps_per_epoch = len(self.train_query_ids) // queries_per_step
        total_steps = steps_per_epoch * train_cfg.num_epochs

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=train_cfg.warmup_steps,
            num_training_steps=total_steps,
        )

        logger.info(f"Optimizer: AdamW with lr={self.config.training.learning_rate}")
        logger.info(
            f"Total training steps: {total_steps} ({steps_per_epoch} per epoch)"
        )
        logger.info(f"Warmup steps: {self.config.training.warmup_steps}")

    def _initialize_samplers(self):
        """Instantiate the batch and negative samplers, or set them to None when disabled."""
        logger.info("Initializing samplers")

        seed = self.config.seed

        # Batch sampler
        batch_cfg = self.config.batch_sampler
        if not batch_cfg.enabled:
            self.batch_sampler = None
            logger.info("Batch sampler disabled")
        else:
            self.batch_sampler = create_batch_sampler(
                name=batch_cfg.name,
                seed=seed,
                args=batch_cfg.args,
            )

        # Negative sampler
        negative_cfg = self.config.negative_sampler
        if not negative_cfg.enabled:
            self.negative_sampler = None
            logger.info("Negative sampler disabled")
        else:
            self.negative_sampler = create_negative_sampler(
                name=negative_cfg.name,
                seed=seed,
                args=negative_cfg.args,
            )

    def _maybe_generate_embeddings(self, epoch: int):
        """
        Explicitly encode train queries and/or documents when this epoch calls for it.

        Nothing is generated when embedding generation is disabled or when
        ``epoch`` is not a multiple of the configured frequency. After epoch 0,
        a side whose forward-pass caching flag is set is skipped. Results are
        kept on ``self.cached_query_embeddings`` / ``self.cached_doc_embeddings``
        and saved under ``embeddings/epoch-<epoch>/`` by rank 0.
        """
        emb_cfg = self.config.embedding_generation
        if not emb_cfg.enabled:
            logger.info("Embedding generation disabled")
            return

        # Sides served by the forward-pass cache need no explicit pass after epoch 0.
        use_cached_q = bool(emb_cfg.cache_query_embeddings)
        use_cached_d = bool(emb_cfg.cache_document_embeddings)

        if epoch % emb_cfg.frequency != 0:
            logger.info(
                f"Skipping embedding generation (frequency: {self.config.embedding_generation.frequency})"
            )
            return

        logger.info(f"Generating embeddings for epoch {epoch}")

        # Encode with the bare module when the model is wrapped in DDP.
        model = self.model.module if self.is_distributed else self.model

        is_main_rank = self.rank == 0
        after_first_epoch = epoch > 0

        # Arguments shared by the query and document encoding calls.
        common_kwargs = dict(
            model=model,
            tokenizer=self.tokenizer,
            batch_size=emb_cfg.batch_size,
            device=self.device,
            show_progress=is_main_rank,
            rank=self.rank,
            world_size=self.world_size,
        )

        # Queries
        if emb_cfg.generate_queries and (not after_first_epoch or not use_cached_q):
            logger.info("Generating query embeddings (distributed)")
            self.cached_query_embeddings, _ = generate_embeddings_distributed(
                ids=self.train_query_ids,
                id_to_text=self.train_query_id_to_text,
                max_length=self.config.model.max_query_length,
                is_query=True,
                prefix=self.config.model.query_prefix,
                **common_kwargs,
            )

            if is_main_rank:
                save_path = os.path.join(
                    self.config.experiment_dir,
                    "embeddings",
                    f"epoch-{epoch}",
                    "train_queries",
                )
                save_embeddings(
                    self.cached_query_embeddings, self.train_query_ids, save_path
                )

        # Documents
        if emb_cfg.generate_documents and (not after_first_epoch or not use_cached_d):
            logger.info("Generating document embeddings (distributed)")
            self.cached_doc_embeddings, _ = generate_embeddings_distributed(
                ids=self.doc_ids,
                id_to_text=self.doc_id_to_text,
                max_length=self.config.model.max_doc_length,
                is_query=False,
                prefix=self.config.model.document_prefix,
                **common_kwargs,
            )

            if is_main_rank:
                save_path = os.path.join(
                    self.config.experiment_dir,
                    "embeddings",
                    f"epoch-{epoch}",
                    "documents",
                )
                save_embeddings(self.cached_doc_embeddings, self.doc_ids, save_path)

        # Keep all ranks in step before training continues.
        if self.is_distributed:
            dist.barrier()

        logger.info("Embedding generation complete")

    def _gather_object_to_rank0(self, obj: Any) -> Any:
        """Gather a Python object from all ranks to rank 0 and return list on rank 0."""
        if not self.is_distributed:
            return [obj]

        gathered: List[Any] = [None for _ in range(self.world_size)] if self.rank == 0 else []
        dist.gather_object(obj, gathered if self.rank == 0 else None, dst=0)
        return gathered

    def _merge_embedding_dicts(self, dicts: List[Optional[Dict[str, np.ndarray]]]) -> Dict[str, np.ndarray]:
        merged: Dict[str, np.ndarray] = {}
        for d in dicts:
            if not d:
                continue
            # Last write wins (should be identical for query ids; docs may repeat)
            merged.update(d)
        return merged

    def _cache_forward_embeddings_from_batch(self, batch: Dict[str, Any], outputs: Dict[str, Any]):
        """Accumulate detached forward embeddings for queries and docs into epoch-local caches."""
        if self._cur_epoch_query_emb_by_id is None:
            self._cur_epoch_query_emb_by_id = {}
        if self._cur_epoch_doc_emb_by_id is None:
            self._cur_epoch_doc_emb_by_id = {}

        # Queries
        query_ids: List[str] = batch.get("query_ids", [])
        if query_ids:
            q_emb = outputs.get("query_embeddings", None)
            if q_emb is not None:
                q_np = q_emb.detach().to(dtype=torch.float32).cpu().numpy()
                for qid, emb in zip(query_ids, q_np):
                    self._cur_epoch_query_emb_by_id[str(qid)] = emb

        # Positives
        pos_ids_2d = batch.get("positive_doc_ids", None)
        pos_emb = outputs.get("positive_embeddings", None)
        pos_mask = outputs.get("positive_mask", None)
        if pos_ids_2d is not None and pos_emb is not None and pos_mask is not None:
            pos_np = pos_emb.detach().to(dtype=torch.float32).cpu().numpy()  # (B, P, D)
            pos_mask_np = pos_mask.detach().cpu().numpy().astype(bool)  # (B, P)
            for i in range(len(pos_ids_2d)):
                for j, did in enumerate(pos_ids_2d[i]):
                    if not did:
                        continue
                    if pos_mask_np[i, j]:
                        self._cur_epoch_doc_emb_by_id[str(did)] = pos_np[i, j]

        # Mined negatives
        mined_ids_2d = batch.get("mined_negative_doc_ids", None)
        mined_emb = outputs.get("mined_negative_embeddings", None)
        mined_mask = outputs.get("mined_negative_mask", None)
        if mined_ids_2d is not None and mined_emb is not None and mined_mask is not None:
            mined_np = mined_emb.detach().to(dtype=torch.float32).cpu().numpy()  # (B, M, D)
            mined_mask_np = mined_mask.detach().cpu().numpy().astype(bool)  # (B, M)
            for i in range(len(mined_ids_2d)):
                for j, did in enumerate(mined_ids_2d[i]):
                    if not did:
                        continue
                    if mined_mask_np[i, j]:
                        self._cur_epoch_doc_emb_by_id[str(did)] = mined_np[i, j]

        # Sampled negatives
        sampled_ids_2d = batch.get("sampled_negative_doc_ids", None)
        sampled_emb = outputs.get("sampled_negative_embeddings", None)
        sampled_mask = outputs.get("sampled_negative_mask", None)
        if sampled_ids_2d is not None and sampled_emb is not None and sampled_mask is not None:
            sampled_np = sampled_emb.detach().to(dtype=torch.float32).cpu().numpy()  # (B, S, D)
            sampled_mask_np = sampled_mask.detach().cpu().numpy().astype(bool)  # (B, S)
            for i in range(len(sampled_ids_2d)):
                for j, did in enumerate(sampled_ids_2d[i]):
                    if not did:
                        continue
                    if sampled_mask_np[i, j]:
                        self._cur_epoch_doc_emb_by_id[str(did)] = sampled_np[i, j]

    def _get_required_positive_doc_ids_for_samplers(self, epoch: int) -> Set[str]:
        """Compute set of positive doc IDs needed by the batch sampler for this epoch (rank 0)."""
        if not self.config.batch_sampler.enabled:
            return set()
        if self.batch_sampler is None:
            return set()

        # Currently only HOBIT exposes a helper; otherwise we conservatively return empty.
        helper = getattr(self.batch_sampler, "get_required_positive_doc_ids", None)
        if callable(helper):
            return set(
                helper(
                    query_ids=self.train_query_ids,
                    qrels=self.train_qrels,
                    epoch=epoch,
                )
            )
        return set()

    def _maybe_fill_missing_embeddings_for_samplers(self, epoch: int):
        """
        Explicitly generate embeddings for IDs missing from the forward-pass caches.

        Runs on rank 0 only; other ranks return immediately. For each side whose
        caching flag is enabled, the previous-epoch cache is created if absent
        and every required id not yet in it is encoded and stored:

        - queries: train queries that have at least one qrel with relevance > 0;
        - documents: the ids returned by
          ``_get_required_positive_doc_ids_for_samplers``.

        Required/missing counts and timings are logged, and the counts are
        written to TensorBoard when a writer is available.

        Args:
            epoch: Epoch whose samplers are about to run; also the TensorBoard step.
        """
        if self.rank != 0:
            return

        emb_cfg = self.config.embedding_generation
        model_cfg = self.config.model

        # Encode with the bare module when the model is wrapped in DDP.
        model = self.model.module if self.is_distributed else self.model

        required_q_count = missing_q_count = 0
        required_d_count = missing_d_count = 0

        t_fill_start = time.perf_counter()

        # NOTE(review): the forward-pass cache stores keys as str(id) while the
        # lookups below use the raw ids — this assumes ids are already strings.

        # Queries: only those with at least one positive are needed by the samplers.
        if emb_cfg.cache_query_embeddings:
            if self._prev_epoch_query_emb_by_id is None:
                self._prev_epoch_query_emb_by_id = {}

            required_qids = [
                qid
                for qid in self.train_query_ids
                if qid in self.train_qrels
                and any(rel > 0 for rel in self.train_qrels[qid].values())
            ]
            required_q_count = len(required_qids)
            missing_q_count = self._generate_missing_into_cache(
                model=model,
                cache=self._prev_epoch_query_emb_by_id,
                required_ids=required_qids,
                id_to_text=self.train_query_id_to_text,
                max_length=model_cfg.max_query_length,
                is_query=True,
                prefix=model_cfg.query_prefix,
                what="query",
                label="Query",
                epoch=epoch,
            )

        # Documents: positives the batch sampler reports needing.
        if emb_cfg.cache_document_embeddings:
            if self._prev_epoch_doc_emb_by_id is None:
                self._prev_epoch_doc_emb_by_id = {}

            required_pos_doc_ids = self._get_required_positive_doc_ids_for_samplers(epoch)
            required_d_count = len(required_pos_doc_ids)
            missing_d_count = self._generate_missing_into_cache(
                model=model,
                cache=self._prev_epoch_doc_emb_by_id,
                required_ids=required_pos_doc_ids,
                id_to_text=self.doc_id_to_text,
                max_length=model_cfg.max_doc_length,
                is_query=False,
                prefix=model_cfg.document_prefix,
                what="positive doc",
                label="Doc",
                epoch=epoch,
            )

        logger.info(
            "[CacheFill] Done | epoch=%d required_q=%d missing_q=%d required_pos_docs=%d missing_pos_docs=%d | total=%.3fs",
            epoch,
            required_q_count,
            missing_q_count,
            required_d_count,
            missing_d_count,
            time.perf_counter() - t_fill_start,
        )

        if self.writer is not None and epoch is not None:
            scalars = (
                ("forward_cache/required_queries_for_samplers", required_q_count),
                ("forward_cache/missing_queries_filled", missing_q_count),
                ("forward_cache/required_positive_docs_for_samplers", required_d_count),
                ("forward_cache/missing_positive_docs_filled", missing_d_count),
            )
            for tag, count in scalars:
                self.writer.add_scalar(tag, float(count), epoch)

    def _generate_missing_into_cache(
        self,
        *,
        model,
        cache,
        required_ids,
        id_to_text,
        max_length,
        is_query,
        prefix,
        what,
        label,
        epoch,
    ) -> int:
        """Encode the ids in ``required_ids`` that ``cache`` lacks, store them, and return how many were missing.

        ``what`` and ``label`` are only inserted into the log messages.
        """
        missing_ids = [item_id for item_id in required_ids if item_id not in cache]
        if not missing_ids:
            return 0

        t0 = time.perf_counter()
        logger.info(
            "[CacheFill] Generating %d missing %s embeddings explicitly for epoch %d samplers",
            len(missing_ids),
            what,
            epoch,
        )
        embs, _ = generate_embeddings(
            model=model,
            ids=missing_ids,
            id_to_text=id_to_text,
            tokenizer=self.tokenizer,
            max_length=max_length,
            # Never ask for a batch larger than the (non-empty) list of missing ids.
            batch_size=min(self.config.embedding_generation.batch_size, len(missing_ids)),
            device=self.device,
            is_query=is_query,
            show_progress=True,
            prefix=prefix,
        )
        for item_id, emb in zip(missing_ids, embs):
            cache[item_id] = emb
        logger.info(
            "[CacheFill] %s fill done in %.3fs",
            label,
            time.perf_counter() - t0,
        )
        return len(missing_ids)

    def _prepare_sampler_embeddings(self, epoch: int):
        """
        Select the embedding source the samplers will use for this epoch.

        With forward-pass caching enabled for a side, epoch ``i`` is meant to use
        the cache built during epoch ``i-1``. Epoch 0, or a missing/stale
        previous-epoch cache, leaves everything untouched. Otherwise rank 0
        fills any ids missing from the cache and the array caches of the
        cached side(s) are reset to None on every rank.
        """
        emb_cfg = self.config.embedding_generation
        use_cached_q = bool(emb_cfg.cache_query_embeddings)
        use_cached_d = bool(emb_cfg.cache_document_embeddings)

        # Nothing to prepare when neither side uses the forward-pass cache.
        if not use_cached_q and not use_cached_d:
            return

        # The first epoch relies on the explicitly generated embeddings.
        if epoch == 0:
            logger.info(
                "[ForwardCache] Epoch 0: using explicit embeddings for samplers (cache will be built during training)"
            )
            return

        # The cache must come from exactly the previous epoch (None never matches).
        prev_epoch = epoch - 1
        if self._prev_epoch_cache_epoch != prev_epoch:
            logger.info(
                "No epoch-%d forward cache available; falling back to explicit embeddings for samplers.",
                prev_epoch,
            )
            return

        # Only rank 0 reports on and completes the caches. They are not sent to
        # other ranks: just rank 0 runs the batch/negative samplers, and
        # shipping the dicts around would be a large pickle/broadcast.
        if self.rank == 0:
            qn = len(self._prev_epoch_query_emb_by_id or {}) if use_cached_q else 0
            dn = len(self._prev_epoch_doc_emb_by_id or {}) if use_cached_d else 0
            logger.info(
                "[ForwardCache] Using epoch-%d forward cache for samplers | cache_query=%s (%d) cache_doc=%s (%d)",
                prev_epoch,
                str(use_cached_q),
                qn,
                str(use_cached_d),
                dn,
            )
            self._maybe_fill_missing_embeddings_for_samplers(epoch)

        # In cached mode the array caches are cleared; samplers get the dicts instead.
        if use_cached_q:
            self.cached_query_embeddings = None
        if use_cached_d:
            self.cached_doc_embeddings = None

    def _maybe_run_batch_sampler(self, epoch: int):
        """
        Run the batch sampler if this epoch calls for it and share the query order.

        When the batch sampler is disabled, the original ``train_query_ids``
        order is used. When ``epoch`` is not a multiple of the configured
        frequency, the previously cached order is kept. Otherwise rank 0 runs
        the sampler and, in distributed mode, broadcasts the order (as indices
        into ``train_query_ids``) to all ranks.

        The length of the order is broadcast before the order itself, so the
        receiving ranks size their buffer from what rank 0 actually produced
        rather than assuming one entry per training query.

        Args:
            epoch: Current epoch index.
        """
        if not self.config.batch_sampler.enabled:
            logger.info("Batch sampler disabled, using original query order")
            self.cached_query_order = self.train_query_ids
            return

        # Check if we need to run batch sampler this epoch
        if epoch % self.config.batch_sampler.frequency != 0:
            logger.info(
                f"Reusing cached query order (frequency: {self.config.batch_sampler.frequency})"
            )
            return

        logger.info(f"Running batch sampler for epoch {epoch}")

        if self.is_distributed and self.rank != 0:
            logger.info(
                "Waiting for rank 0 to finish batch sampler and broadcast query order..."
            )

        # Only rank 0 runs the sampler
        if self.rank == 0:
            # After epoch 0, a side with forward-pass caching enabled is passed
            # as an id -> embedding dict instead of an array.
            use_cached_q = bool(self.config.embedding_generation.cache_query_embeddings) and epoch > 0
            use_cached_d = bool(self.config.embedding_generation.cache_document_embeddings) and epoch > 0
            t0 = time.perf_counter()
            self.cached_query_order = self.batch_sampler.sample(
                query_ids=self.train_query_ids,
                query_embeddings=None if use_cached_q else self.cached_query_embeddings,
                doc_embeddings=None if use_cached_d else self.cached_doc_embeddings,
                doc_ids=self.doc_ids,
                qrels=self.train_qrels,
                epoch=epoch,
                writer=self.writer,
                dataset_name=self.config.data.dataset_name,
                query_embeddings_by_id=self._prev_epoch_query_emb_by_id if use_cached_q else None,
                doc_embeddings_by_id=self._prev_epoch_doc_emb_by_id if use_cached_d else None,
            )
            logger.info(
                f"Batch sampler generated order for {len(self.cached_query_order)} queries"
            )
            logger.info(
                "Batch sampler runtime: %.3fs",
                time.perf_counter() - t0,
            )

        # Broadcast query order to all ranks
        if self.is_distributed:
            if self.rank == 0:
                # Convert to indices for broadcasting
                query_id_to_idx = {
                    qid: idx for idx, qid in enumerate(self.train_query_ids)
                }
                order_tensor = torch.tensor(
                    [query_id_to_idx[qid] for qid in self.cached_query_order],
                    dtype=torch.long,
                    device=self.device,
                )
                length_tensor = torch.tensor(
                    [order_tensor.numel()], dtype=torch.long, device=self.device
                )
            else:
                order_tensor = None
                length_tensor = torch.zeros(1, dtype=torch.long, device=self.device)

            # Broadcast requires equally sized tensors on every rank, and the
            # sampler may return a different number of queries than
            # train_query_ids holds, so agree on the length first.
            dist.broadcast(length_tensor, src=0)
            order_len = int(length_tensor.item())

            if self.rank != 0:
                order_tensor = torch.empty(
                    order_len, dtype=torch.long, device=self.device
                )
            # Every rank knows order_len here, so skipping the empty case is consistent.
            if order_len > 0:
                dist.broadcast(order_tensor, src=0)

            # Convert back to query IDs
            self.cached_query_order = [
                self.train_query_ids[idx] for idx in order_tensor.cpu().tolist()
            ]

        logger.info("Batch sampling complete")

    def _maybe_run_negative_sampler(self, epoch: int):
        """Refresh ``self.cached_sampled_negatives`` for this epoch when required.

        Three outcomes are possible:

        - Sampler disabled in the config: the cache is reset to an empty dict.
        - ``epoch`` is not a multiple of ``negative_sampler.frequency``: the
          cache left by an earlier call is kept untouched.
        - Otherwise rank 0 calls ``self.negative_sampler.sample(...)`` and, in
          distributed mode, the result is broadcast from rank 0 so that every
          rank rebuilds the same ``query_id -> [doc_id, ...]`` mapping.

        Both early-return conditions depend only on the config and ``epoch``,
        so (assuming all ranks share the same config) every rank takes the same
        branch and the collective calls below stay matched.

        Args:
            epoch: Zero-based index of the epoch about to be trained.
        """
        sampler_cfg = self.config.negative_sampler

        if not sampler_cfg.enabled:
            logger.info("Negative sampler disabled")
            self.cached_sampled_negatives = {}
            return

        # Check if we need to run negative sampler this epoch
        if epoch % sampler_cfg.frequency != 0:
            logger.info(
                f"Reusing cached sampled negatives (frequency: {self.config.negative_sampler.frequency})"
            )
            return

        logger.info(f"Running negative sampler for epoch {epoch}")

        # Only rank 0 runs the sampler
        if self.rank == 0:
            # From epoch 1 onwards, when caching is enabled, the sampler receives
            # the per-id embeddings kept in self._prev_epoch_*_emb_by_id instead
            # of the cached embedding arrays.
            embedding_cfg = self.config.embedding_generation
            use_cached_q = bool(embedding_cfg.cache_query_embeddings) and epoch > 0
            use_cached_d = bool(embedding_cfg.cache_document_embeddings) and epoch > 0
            self.cached_sampled_negatives = self.negative_sampler.sample(
                query_ids=self.train_query_ids,
                doc_ids=self.doc_ids,
                query_embeddings=None if use_cached_q else self.cached_query_embeddings,
                doc_embeddings=None if use_cached_d else self.cached_doc_embeddings,
                qrels=self.train_qrels,
                num_samples=sampler_cfg.num_samples,
                epoch=epoch,
                writer=self.writer,
                dataset_name=self.config.data.dataset_name,
                query_embeddings_by_id=self._prev_epoch_query_emb_by_id if use_cached_q else None,
                doc_embeddings_by_id=self._prev_epoch_doc_emb_by_id if use_cached_d else None,
            )
            logger.info(
                f"Sampled negatives for {len(self.cached_sampled_negatives)} queries"
            )

        # Broadcast sampled negatives to all ranks (convert to indices)
        if self.is_distributed:
            if self.rank == 0:
                doc_id_to_idx = {did: idx for idx, did in enumerate(self.doc_ids)}
                # Flatten to one index tensor plus a per-query length tensor,
                # both laid out in self.train_query_ids order.
                flat_indices = []
                lengths = []
                for qid in self.train_query_ids:
                    neg_doc_ids = self.cached_sampled_negatives.get(qid, [])
                    flat_indices.extend(doc_id_to_idx[did] for did in neg_doc_ids)
                    lengths.append(len(neg_doc_ids))

                neg_tensor = torch.tensor(
                    flat_indices, dtype=torch.long, device=self.device
                )
                lengths_tensor = torch.tensor(
                    lengths, dtype=torch.long, device=self.device
                )
            else:
                # The size of the flattened tensor is only known once the
                # lengths have arrived, so the receive buffer for the indices
                # is allocated after the first broadcast (no throw-away
                # num_queries * num_samples allocation).
                neg_tensor = None
                lengths_tensor = torch.zeros(
                    len(self.train_query_ids), dtype=torch.long, device=self.device
                )

            dist.broadcast(lengths_tensor, src=0)
            if self.rank != 0:
                total_negs = int(lengths_tensor.sum().item())
                neg_tensor = torch.zeros(
                    total_negs, dtype=torch.long, device=self.device
                )
            dist.broadcast(neg_tensor, src=0)

            # Reconstruct the dictionary identically on every rank
            lengths = lengths_tensor.cpu().tolist()
            neg_indices = neg_tensor.cpu().tolist()
            rebuilt = {}
            offset = 0
            for qid, length in zip(self.train_query_ids, lengths):
                rebuilt[qid] = [
                    self.doc_ids[idx] for idx in neg_indices[offset : offset + length]
                ]
                offset += length
            self.cached_sampled_negatives = rebuilt

        logger.info("Negative sampling complete")

    def train(self):
        """Run the full training loop.

        For each epoch in ``range(config.training.num_epochs)`` this method:

        1. generates embeddings and prepares the sampler inputs,
        2. runs the batch sampler and the negative sampler,
        3. saves the training qrels for the epoch,
        4. trains one epoch,
        5. if evaluation is enabled, evaluates, checks early stopping and
           saves the best checkpoint (see ``_run_epoch_evaluation``),
        6. saves a periodic checkpoint and, in distributed mode, waits on a
           barrier.

        After the loop rank 0 saves a ``"final"`` checkpoint. The TensorBoard
        writer is closed in a ``finally`` block so it is also closed when an
        epoch raises; the exception itself still propagates to the caller.
        """
        try:
            logger.info("=" * 80)
            logger.info("Starting training")
            logger.info("=" * 80)

            for epoch in range(self.config.training.num_epochs):
                self.current_epoch = epoch
                logger.info(f"\n{'=' * 80}")
                logger.info(f"Epoch {epoch + 1}/{self.config.training.num_epochs}")
                logger.info(f"{'=' * 80}")

                # Step 1: Generate embeddings (if needed)
                self._maybe_generate_embeddings(epoch)

                # Step 1b: If caching enabled, prepare sampler-time embeddings from epoch-1 forward cache
                self._prepare_sampler_embeddings(epoch)

                # Step 2: Run batch sampler (if needed)
                self._maybe_run_batch_sampler(epoch)

                # Step 3: Run negative sampler (if needed)
                self._maybe_run_negative_sampler(epoch)

                # Step 4: Save training qrels for this epoch
                self._save_epoch_training_qrels(epoch)

                # Step 5: Train one epoch
                self._train_one_epoch(epoch)

                # Step 6: Evaluate (if enabled)
                if self.config.training.enable_evaluation:
                    # All ranks receive the same decision, so they break together
                    # and skip the periodic checkpoint and barrier for this epoch.
                    if self._run_epoch_evaluation(epoch):
                        break
                else:
                    logger.info("Evaluation disabled, skipping evaluation for this epoch")

                # Step 7: Save checkpoint (if needed)
                if epoch % self.config.training.checkpoint_frequency == 0:
                    self._save_checkpoint(epoch)

                # Synchronize processes
                if self.is_distributed:
                    dist.barrier()

            # Save final checkpoint
            if self.rank == 0:
                self._save_checkpoint("final")

                # Ensure best checkpoint is saved if we haven't already (only if evaluation was enabled)
                # NOTE(review): the model holds its latest weights at this point,
                # so this fallback stores the current weights under the best
                # epoch's label - confirm that is the intended behaviour.
                if (
                    self.config.training.enable_evaluation
                    and not self.best_checkpoint_saved
                    and self.early_stopping.best_epoch is not None
                ):
                    logger.info(
                        f"Saving best model from epoch {self.early_stopping.best_epoch} to 'best' folder"
                    )
                    self._save_best_checkpoint(self.early_stopping.best_epoch)

            logger.info("\n" + "=" * 80)
            logger.info("Training completed")
            if self.config.training.enable_evaluation:
                # NOTE(review): within this class only rank 0 calls
                # early_stopping.update(); confirm get_best_value() is a number
                # on the other ranks, otherwise the ':.4f' format would fail there.
                logger.info(
                    f"Best {self.early_stopping.metric_name}: {self.early_stopping.get_best_value():.4f} "
                    f"at epoch {self.early_stopping.get_best_epoch()}"
                )
            logger.info("=" * 80)
        finally:
            # Close the writer on both the success and the failure path.
            if self.writer is not None:
                self.writer.close()

    def _run_epoch_evaluation(self, epoch: int) -> bool:
        """Evaluate after an epoch and return whether training should stop.

        Rank 0 evaluates, saves the metrics, updates the early-stopping monitor
        and saves the best checkpoint when this epoch is the best so far. In
        distributed mode the stop decision is then broadcast from rank 0 so
        every rank returns the same value.

        Args:
            epoch: Zero-based index of the epoch that just finished.

        Returns:
            True when early stopping was triggered, False otherwise.
        """
        should_stop = False
        if self.rank == 0:
            metrics = self._evaluate(epoch)

            # Save metrics
            self._save_metrics(epoch, metrics)

            # Check early stopping
            if self.early_stopping.update(metrics, epoch):
                logger.info("Early stopping triggered")
                should_stop = True

            # Save best checkpoint if this is the best epoch so far
            if self.early_stopping.is_best_epoch(epoch):
                logger.info(
                    f"New best model at epoch {epoch}, saving to 'best' folder"
                )
                self._save_best_checkpoint(epoch)
                self.best_checkpoint_saved = True

        # Broadcast early stopping decision to all ranks
        if self.is_distributed:
            should_stop_tensor = torch.tensor(
                [1 if should_stop else 0], dtype=torch.long, device=self.device
            )
            dist.broadcast(should_stop_tensor, src=0)
            should_stop = bool(should_stop_tensor.item())

            if should_stop and self.rank != 0:
                logger.info(
                    f"Rank {self.rank}: Early stopping triggered (broadcast from rank 0)"
                )

        return should_stop

    def _save_epoch_training_qrels(self, epoch: int):
        """Write the qrels used for training in this epoch to a TSV file.

        Does nothing on ranks other than 0. For every query in
        ``self.cached_query_order`` that has an entry in ``self.train_qrels``,
        the judged documents are split into positives (relevance > 0) and mined
        negatives (relevance == 0). These are passed to
        ``save_training_qrels`` together with ``self.cached_sampled_negatives``.

        Args:
            epoch: Epoch index; used in the output file name.
        """
        if self.rank != 0:
            return

        positives_by_query = {}
        mined_negatives_by_query = {}

        for qid in self.cached_query_order:
            if qid not in self.train_qrels:
                continue

            # Split the judgments for this query by relevance label.
            positives = []
            mined_negatives = []
            for doc_id, rel in self.train_qrels[qid].items():
                if rel > 0:
                    positives.append(doc_id)
                elif rel == 0:
                    mined_negatives.append(doc_id)

            positives_by_query[qid] = positives
            mined_negatives_by_query[qid] = mined_negatives

        # Output location: <experiment_dir>/training_data/epoch_<epoch>_qrels.tsv
        output_file = os.path.join(
            self.config.experiment_dir, "training_data", f"epoch_{epoch}_qrels.tsv"
        )

        # The sampled-negative cache is already a query_id -> [doc_ids] mapping
        # and is handed over unchanged.
        save_training_qrels(
            output_path=output_file,
            query_ids=self.cached_query_order,
            query_id_to_positives=positives_by_query,
            query_id_to_mined_negatives=mined_negatives_by_query,
            query_id_to_sampled_negatives=self.cached_sampled_negatives,
        )
        logger.info(f"Saved training qrels for epoch {epoch} to {output_file}")

    def _train_one_epoch(self, epoch: int):
        """Train the model for one epoch over the current query order.

        Builds the epoch's dataloader, runs the forward/backward loop with
        optional mixed precision and gradient accumulation, optionally caches
        the forward embeddings for the next epoch's samplers, and runs
        intermediate evaluations at the planned steps.

        Notes:
            - The optimizer only steps when ``(step + 1)`` is a multiple of
              ``gradient_accumulation_steps``. If the number of batches is not
              such a multiple, the gradients of the trailing batches are never
              applied (they are zeroed at the start of the next call).
            - A new ``GradScaler`` is created on every call when fp16 is on.
            - If the dataloader yields no batches, a warning is logged and the
              average-loss logging is skipped instead of dividing by zero.

        Args:
            epoch: Zero-based index of the epoch to train.
        """
        logger.info(f"Training epoch {epoch}")

        training_cfg = self.config.training
        embedding_cfg = self.config.embedding_generation
        accumulation_steps = training_cfg.gradient_accumulation_steps
        use_fp16 = training_cfg.fp16
        cache_forward = bool(
            embedding_cfg.cache_query_embeddings
            or embedding_cfg.cache_document_embeddings
        )

        dataloader = self._build_epoch_dataloader(epoch)
        intermediate_eval_steps = self._plan_intermediate_eval_steps(
            epoch, len(dataloader)
        )

        # Batch entries that carry ids for bookkeeping and must not be passed
        # to the model's forward call.
        non_model_keys = {
            "query_ids",
            "positive_doc_ids",
            "mined_negative_doc_ids",
            "sampled_negative_doc_ids",
        }

        # Training loop
        self.model.train()
        total_loss = 0.0
        num_steps = 0

        # Enable mixed precision if requested
        scaler = torch.cuda.amp.GradScaler() if use_fp16 else None

        # Progress bar (only on rank 0)
        if self.rank == 0:
            pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
        else:
            pbar = dataloader

        self.optimizer.zero_grad()

        # Initialize epoch-local forward caches
        self._cur_epoch_query_emb_by_id = {}
        self._cur_epoch_doc_emb_by_id = {}

        for step, batch in enumerate(pbar):
            # Move batch to device
            batch = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
            model_inputs = {
                k: v for k, v in batch.items() if k not in non_model_keys
            }

            # Forward and backward pass; the loss is divided so that the
            # accumulated gradient matches one large batch.
            if use_fp16:
                with torch.cuda.amp.autocast():
                    outputs = self.model(**model_inputs)
                    loss = self.loss_fn(**outputs)
                    loss = loss / accumulation_steps
                scaler.scale(loss).backward()
            else:
                outputs = self.model(**model_inputs)
                loss = self.loss_fn(**outputs)
                loss = loss / accumulation_steps
                loss.backward()

            # Read the loss back once per step and reuse it for all logging.
            step_loss = loss.item()
            unscaled_loss = step_loss * accumulation_steps

            # Cache forward embeddings (no grad) for use in next epoch's samplers
            if cache_forward:
                self._cache_forward_embeddings_from_batch(batch, outputs)

                # Periodic progress log for cache construction
                if (
                    self.rank == 0
                    and self.global_step % max(1, training_cfg.log_steps) == 0
                ):
                    logger.info(
                        "[ForwardCache] Progress | epoch=%d step=%d cached_q=%d cached_docs=%d",
                        epoch,
                        step,
                        len(self._cur_epoch_query_emb_by_id or {}),
                        len(self._cur_epoch_doc_emb_by_id or {}),
                    )

            # Optimizer step (with gradient accumulation)
            if (step + 1) % accumulation_steps == 0:
                # Gradients must be unscaled before clipping under fp16
                if use_fp16:
                    scaler.unscale_(self.optimizer)

                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), training_cfg.max_grad_norm
                )

                # Optimizer step
                if use_fp16:
                    scaler.step(self.optimizer)
                    scaler.update()
                else:
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()

                self.global_step += 1

                # Logging
                if (
                    self.global_step % training_cfg.log_steps == 0
                    and self.rank == 0
                ):
                    current_lr = self.scheduler.get_last_lr()[0]
                    logger.info(
                        f"Step {self.global_step}: loss={unscaled_loss:.4f}, "
                        f"lr={current_lr:.2e}"
                    )
                    if self.writer is not None:
                        self.writer.add_scalar(
                            "train/loss", unscaled_loss, self.global_step
                        )
                        self.writer.add_scalar(
                            "train/learning_rate", current_lr, self.global_step
                        )

            total_loss += step_loss
            num_steps += 1

            # Update progress bar
            if self.rank == 0:
                pbar.set_postfix({"loss": f"{unscaled_loss:.4f}"})

            # Check if we should perform intermediate evaluation at this step
            if step + 1 in intermediate_eval_steps:
                logger.info(
                    f"\n{'='*80}\nIntermediate evaluation at epoch {epoch}, step {step + 1}\n{'='*80}"
                )

                # Set model back to eval mode for evaluation
                self.model.eval()

                # Only rank 0 performs evaluation
                if self.rank == 0:
                    step_metrics = self._evaluate_at_step(epoch, step + 1)
                    self._save_step_metrics(epoch, step + 1, step_metrics)
                    self._save_step_checkpoint(epoch, step + 1)

                # Synchronize all processes before resuming training
                if self.is_distributed:
                    dist.barrier()

                # Set model back to train mode
                self.model.train()

                logger.info(
                    f"Resuming training after intermediate evaluation\n{'='*80}\n"
                )

        if num_steps > 0:
            avg_loss = total_loss / num_steps * accumulation_steps
            logger.info(f"Epoch {epoch} average loss: {avg_loss:.4f}")

            if self.rank == 0 and self.writer is not None:
                self.writer.add_scalar("train/epoch_loss", avg_loss, epoch)
        else:
            # With drop_last=True a dataset smaller than one batch yields no
            # steps; there is no average to report.
            logger.warning(
                "Epoch %d produced no training batches; skipping average loss logging",
                epoch,
            )

        # Finalize forward caches for next epoch
        if cache_forward:
            self._finalize_forward_caches(epoch)

        # Free epoch-local references
        self._cur_epoch_query_emb_by_id = None
        self._cur_epoch_doc_emb_by_id = None

    def _build_epoch_dataloader(self, epoch: int) -> DataLoader:
        """Build the dataset, collator and dataloader for one training epoch.

        Args:
            epoch: Zero-based epoch index; added to the configured seed for the
                dataset.

        Returns:
            A ``DataLoader`` over the epoch's ``DenseRetrievalDataset``.
        """
        # Create dataset with current query order and sampled negatives
        train_dataset = DenseRetrievalDataset(
            query_ids=self.train_query_ids,
            query_id_to_text=self.train_query_id_to_text,
            doc_id_to_text=self.doc_id_to_text,
            qrels=self.train_qrels,
            max_positives=self.config.training.max_positives_per_query,
            max_mined_negatives=self.config.training.max_mined_negatives_per_query,
            query_order=self.cached_query_order,
            sampled_negatives=self.cached_sampled_negatives,
            seed=self.config.seed + epoch,
        )

        # Create collator
        collator = DenseRetrievalCollator(
            tokenizer=self.tokenizer,
            max_query_length=self.config.model.max_query_length,
            max_doc_length=self.config.model.max_doc_length,
            use_mined_negatives=self.config.loss.use_mined_negatives,
            use_sampled_negatives=self.config.loss.use_sampled_negatives,
            query_prefix=self.config.model.query_prefix,
            document_prefix=self.config.model.document_prefix,
        )

        if self.is_distributed:
            # Use ChunkedDistributedSampler to preserve sequential ordering
            # from batch sampler (important for curriculum learning)
            sampler = ChunkedDistributedSampler(
                train_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                batch_size=self.config.training.per_gpu_batch_size,
                shuffle=False,  # Don't shuffle - preserve batch sampler order
                drop_last=True,  # Drop incomplete last batch to avoid NCCL hangs
            )
        else:
            sampler = None

        return DataLoader(
            train_dataset,
            batch_size=self.config.training.per_gpu_batch_size,
            sampler=sampler,
            collate_fn=collator,
            num_workers=0,
            pin_memory=True,
            drop_last=True,  # Drop incomplete last batch to avoid size mismatches
        )

    def _plan_intermediate_eval_steps(
        self, epoch: int, total_steps_in_epoch: int
    ) -> List[int]:
        """Return the 1-based in-epoch steps at which to run intermediate evaluation.

        Intermediate evaluation applies only when it is enabled in the config,
        evaluation itself is enabled, and ``epoch`` is below
        ``num_epoch_intermediate_eval``. Each configured percentage is turned
        into a step count; steps equal to 0 or to the epoch length are dropped.

        Args:
            epoch: Zero-based epoch index.
            total_steps_in_epoch: Number of batches the dataloader will yield.

        Returns:
            The list of step numbers, empty when nothing is scheduled.
        """
        training_cfg = self.config.training
        if not (
            training_cfg.epoch_intermediate_eval
            and training_cfg.enable_evaluation
            and epoch < training_cfg.num_epoch_intermediate_eval
        ):
            return []

        intermediate_eval_steps = []
        for percentage in training_cfg.epoch_eval_percentages:
            step_at_percentage = int(total_steps_in_epoch * percentage / 100)
            if 0 < step_at_percentage < total_steps_in_epoch:
                intermediate_eval_steps.append(step_at_percentage)

        if intermediate_eval_steps:
            logger.info(
                f"Intermediate evaluation enabled for epoch {epoch} at steps: {intermediate_eval_steps}"
            )
        return intermediate_eval_steps

    def _finalize_forward_caches(self, epoch: int):
        """Promote this epoch's forward-embedding caches for the next epoch.

        In distributed mode the per-rank caches are gathered to rank 0 and
        merged there, and the cache-epoch marker is broadcast to all ranks.
        Otherwise the epoch-local dicts are reused directly.

        Args:
            epoch: Zero-based index of the epoch that produced the caches.
        """
        if self.is_distributed:
            if self.rank == 0:
                logger.info("[ForwardCache] Gathering per-rank caches to rank 0...")
            gathered_q = self._gather_object_to_rank0(self._cur_epoch_query_emb_by_id)
            gathered_d = self._gather_object_to_rank0(self._cur_epoch_doc_emb_by_id)
            if self.rank == 0:
                per_rank_q = [len(x or {}) for x in gathered_q]
                per_rank_d = [len(x or {}) for x in gathered_d]
                logger.info(
                    "[ForwardCache] Per-rank cache sizes | queries=%s docs=%s",
                    str(per_rank_q),
                    str(per_rank_d),
                )
                self._prev_epoch_query_emb_by_id = self._merge_embedding_dicts(gathered_q)
                self._prev_epoch_doc_emb_by_id = self._merge_embedding_dicts(gathered_d)
        else:
            self._prev_epoch_query_emb_by_id = self._cur_epoch_query_emb_by_id
            self._prev_epoch_doc_emb_by_id = self._cur_epoch_doc_emb_by_id

        # Rank 0 is the only rank in single-process mode and the only rank that
        # holds the merged caches in distributed mode.
        if self.rank == 0:
            self._prev_epoch_cache_epoch = epoch
            logger.info(
                "[ForwardCache] Epoch %d cached: queries=%d docs=%d",
                epoch,
                len(self._prev_epoch_query_emb_by_id),
                len(self._prev_epoch_doc_emb_by_id),
            )
            if self.writer is not None:
                self.writer.add_scalar(
                    "forward_cache/cached_queries",
                    float(len(self._prev_epoch_query_emb_by_id)),
                    epoch,
                )
                self.writer.add_scalar(
                    "forward_cache/cached_docs",
                    float(len(self._prev_epoch_doc_emb_by_id)),
                    epoch,
                )

        if self.is_distributed:
            # Ensure all ranks see cache epoch marker via broadcast
            cache_epoch_list = [self._prev_epoch_cache_epoch] if self.rank == 0 else [None]
            dist.broadcast_object_list(cache_epoch_list, src=0)
            self._prev_epoch_cache_epoch = cache_epoch_list[0]

    def _evaluate_at_step(self, epoch: int, step: int) -> Dict[str, float]:
        """
        Compute dev-set retrieval metrics part-way through an epoch.

        Dev queries and the evaluation documents are both encoded with the
        current model, metrics are computed at the fixed cut-offs
        1/5/10/20/50/100, and each metric is written to TensorBoard under
        ``eval_by_step/`` against ``self.global_step``.

        Args:
            epoch: Current epoch number
            step: Current step number within the epoch

        Returns:
            Dictionary of evaluation metrics
        """
        logger.info(f"Evaluating at epoch {epoch}, step {step}")

        # Unwrap the distributed wrapper, if any, before encoding.
        encoder = self.model.module if self.is_distributed else self.model

        # Arguments common to the query pass and the document pass.
        shared_kwargs = dict(
            model=encoder,
            tokenizer=self.tokenizer,
            batch_size=self.config.embedding_generation.batch_size,
            device=self.device,
            show_progress=True,
        )

        query_vectors, _ = generate_embeddings(
            ids=self.dev_query_ids,
            id_to_text=self.dev_query_id_to_text,
            max_length=self.config.model.max_query_length,
            is_query=True,
            prefix=self.config.model.query_prefix,
            **shared_kwargs,
        )

        # Documents are re-encoded on every call rather than read from a cache.
        doc_vectors, _ = generate_embeddings(
            ids=self.eval_doc_ids,
            id_to_text=self.eval_doc_id_to_text,
            max_length=self.config.model.max_doc_length,
            is_query=False,
            prefix=self.config.model.document_prefix,
            **shared_kwargs,
        )

        results = compute_retrieval_metrics(
            query_embeddings=query_vectors,
            doc_embeddings=doc_vectors,
            query_ids=self.dev_query_ids,
            doc_ids=self.eval_doc_ids,
            qrels=self.dev_qrels,
            k_values=[1, 5, 10, 20, 50, 100],
        )

        logger.info(f"Step {step} metrics: {results}")

        # Step-based TensorBoard tags keep these apart from per-epoch metrics.
        if self.writer is not None:
            for name, score in results.items():
                self.writer.add_scalar(f"eval_by_step/{name}", score, self.global_step)

        return results

    def _save_step_metrics(self, epoch: int, step: int, metrics: Dict[str, float]):
        """Save metrics from step-wise evaluation to a JSON file.

        The file is written to
        ``<experiment_dir>/results_by_steps/epoch-<epoch>-step-<step>_metrics.json``.
        The target directory is created if it does not exist yet, matching the
        checkpoint savers in this class.

        Args:
            epoch: Epoch index used in the file name.
            step: In-epoch step number used in the file name.
            metrics: Metric name to value mapping to serialise.
        """
        metrics_path = os.path.join(
            self.config.experiment_dir,
            "results_by_steps",
            f"epoch-{epoch}-step-{step}_metrics.json",
        )
        # Avoid FileNotFoundError when the results folder has not been created.
        os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Saved step metrics to {metrics_path}")

    def _save_step_checkpoint(self, epoch: int, step: int):
        """Write a checkpoint taken part-way through an epoch.

        The model and a ``training_state.pt`` file (epoch, in-epoch step,
        global step, optimizer/scheduler state, config) are stored under
        ``<experiment_dir>/checkpoints_by_steps/epoch-<epoch>-step-<step>``.

        Args:
            epoch: Current epoch number.
            step: Current step number within the epoch.
        """
        logger.info(f"Saving checkpoint for epoch {epoch}, step {step}")

        target_dir = os.path.join(
            self.config.experiment_dir,
            "checkpoints_by_steps",
            f"epoch-{epoch}-step-{step}",
        )
        os.makedirs(target_dir, exist_ok=True)

        # Save the unwrapped model when running under the distributed wrapper.
        unwrapped = self.model.module if self.is_distributed else self.model
        unwrapped.save_pretrained(target_dir)

        # Store the config's attribute dict when it has one, else its string form.
        if hasattr(self.config, "__dict__"):
            config_snapshot = self.config.__dict__
        else:
            config_snapshot = str(self.config)

        state = {
            "epoch": epoch,
            "step_in_epoch": step,
            "global_step": self.global_step,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": config_snapshot,
        }
        state_file = os.path.join(target_dir, "training_state.pt")
        torch.save(state, state_file)

        logger.info(f"Step checkpoint saved to {target_dir}")

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """Evaluate on the dev set at the end of an epoch.

        Encodes the dev queries (and saves those embeddings under
        ``<experiment_dir>/embeddings/epoch-<epoch>/dev_queries``), encodes the
        evaluation documents, computes retrieval metrics at the cut-offs
        1/5/10/20/50/100 and logs them to TensorBoard under ``eval/``.

        The model is switched to eval mode for the duration of the call and
        its previous mode is restored afterwards. The training loop leaves the
        model in train mode, and the intermediate-evaluation path sets eval
        mode explicitly before evaluating; this does the same here.
        NOTE(review): ``generate_embeddings`` may already switch to eval mode
        itself - confirm; calling ``eval()`` twice is harmless.

        Args:
            epoch: Zero-based index of the epoch that just finished.

        Returns:
            Dictionary of evaluation metrics.
        """
        logger.info(f"Evaluating epoch {epoch}")

        was_training = self.model.training
        self.model.eval()
        try:
            # Get the underlying model
            model = self.model.module if self.is_distributed else self.model

            # Generate dev query embeddings
            dev_query_embeddings, _ = generate_embeddings(
                model=model,
                ids=self.dev_query_ids,
                id_to_text=self.dev_query_id_to_text,
                tokenizer=self.tokenizer,
                max_length=self.config.model.max_query_length,
                batch_size=self.config.embedding_generation.batch_size,
                device=self.device,
                is_query=True,
                show_progress=True,
                prefix=self.config.model.query_prefix,
            )

            # Save dev query embeddings
            save_path = os.path.join(
                self.config.experiment_dir, "embeddings", f"epoch-{epoch}", "dev_queries"
            )
            save_embeddings(dev_query_embeddings, self.dev_query_ids, save_path)

            # Always generate fresh document embeddings for evaluation
            doc_embeddings, _ = generate_embeddings(
                model=model,
                ids=self.eval_doc_ids,
                id_to_text=self.eval_doc_id_to_text,
                tokenizer=self.tokenizer,
                max_length=self.config.model.max_doc_length,
                batch_size=self.config.embedding_generation.batch_size,
                device=self.device,
                is_query=False,
                show_progress=True,
                prefix=self.config.model.document_prefix,
            )
        finally:
            # Leave the model in the mode the caller had it in.
            if was_training:
                self.model.train()

        # Compute metrics
        metrics = compute_retrieval_metrics(
            query_embeddings=dev_query_embeddings,
            doc_embeddings=doc_embeddings,
            query_ids=self.dev_query_ids,
            doc_ids=self.eval_doc_ids,
            qrels=self.dev_qrels,
            k_values=[1, 5, 10, 20, 50, 100],
        )

        # Log to tensorboard
        if self.writer is not None:
            for metric_name, value in metrics.items():
                self.writer.add_scalar(f"eval/{metric_name}", value, epoch)

        return metrics

    def _save_metrics(self, epoch: int, metrics: Dict[str, float]):
        """Save end-of-epoch metrics to a JSON file.

        The file is written to
        ``<experiment_dir>/results/epoch-<epoch>_metrics.json``. The target
        directory is created if it does not exist yet, matching the checkpoint
        savers in this class.

        Args:
            epoch: Epoch index used in the file name.
            metrics: Metric name to value mapping to serialise.
        """
        metrics_path = os.path.join(
            self.config.experiment_dir, "results", f"epoch-{epoch}_metrics.json"
        )
        # Avoid FileNotFoundError when the results folder has not been created.
        os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Saved metrics to {metrics_path}")

    def _save_checkpoint(self, epoch):
        """Write a regular checkpoint; a no-op on ranks other than 0.

        Args:
            epoch: Either an integer epoch index, stored under
                ``checkpoints/epoch-<epoch>``, or a non-integer label (the
                training loop passes ``"final"``), used directly as the folder
                name. For a label, the recorded epoch is ``self.current_epoch``.
        """
        if self.rank != 0:
            return

        logger.info(f"Saving checkpoint for epoch {epoch}")

        is_numbered = isinstance(epoch, int)
        folder_name = f"epoch-{epoch}" if is_numbered else epoch
        target_dir = os.path.join(
            self.config.experiment_dir, "checkpoints", folder_name
        )
        os.makedirs(target_dir, exist_ok=True)

        # Save the unwrapped model when running under the distributed wrapper.
        unwrapped = self.model.module if self.is_distributed else self.model
        unwrapped.save_pretrained(target_dir)

        # Store the config's attribute dict when it has one, else its string form.
        if hasattr(self.config, "__dict__"):
            config_snapshot = self.config.__dict__
        else:
            config_snapshot = str(self.config)

        state = {
            "epoch": epoch if is_numbered else self.current_epoch,
            "global_step": self.global_step,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": config_snapshot,
        }
        state_file = os.path.join(target_dir, "training_state.pt")
        torch.save(state, state_file)

        logger.info(f"Checkpoint saved to {target_dir}")

    def _save_best_checkpoint(self, epoch):
        """Write the current model to ``checkpoints/best``; a no-op off rank 0.

        Alongside the model, ``training_state.pt`` records the given epoch, the
        global step, optimizer/scheduler state, the best metric name and value
        reported by the early-stopping monitor, and the config.

        Args:
            epoch: Epoch number to record as the best epoch.
        """
        if self.rank != 0:
            return

        logger.info(f"Saving best checkpoint from epoch {epoch}")

        target_dir = os.path.join(self.config.experiment_dir, "checkpoints", "best")
        os.makedirs(target_dir, exist_ok=True)

        # Save the unwrapped model when running under the distributed wrapper.
        unwrapped = self.model.module if self.is_distributed else self.model
        unwrapped.save_pretrained(target_dir)

        # Store the config's attribute dict when it has one, else its string form.
        if hasattr(self.config, "__dict__"):
            config_snapshot = self.config.__dict__
        else:
            config_snapshot = str(self.config)

        state = {
            "epoch": epoch,
            "global_step": self.global_step,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_metric_value": self.early_stopping.get_best_value(),
            "best_metric_name": self.early_stopping.metric_name,
            "config": config_snapshot,
        }
        state_file = os.path.join(target_dir, "training_state.pt")
        torch.save(state, state_file)

        logger.info(f"Best checkpoint saved to {target_dir}")
