"""
Main training script entry point.

Usage:
    # Single GPU
    python train.py --config path/to/config.yaml

    # Multi-GPU with torchrun
    torchrun --nproc_per_node=4 train.py --config path/to/config.yaml
"""

import argparse
import os
import sys
import torch
import torch.distributed as dist

# Put the parent of this script's directory on sys.path so the `src` package is importable
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from src.config import load_config
from src.training import Trainer


def setup_distributed(timeout_minutes=240, backend="nccl"):
    """Initialise torch.distributed when the script is launched via torchrun.

    A distributed launch is detected by the presence of the ``LOCAL_RANK``
    environment variable. When it is set, ``RANK`` and ``WORLD_SIZE`` are read
    from the environment as well, and the default process group is created
    with ``init_method="env://"``.

    Args:
        timeout_minutes: Timeout passed to ``init_process_group``.
            Defaults to 240 minutes.
        backend: Process-group backend. Defaults to ``"nccl"``.

    Returns:
        This process's local rank, or -1 when ``LOCAL_RANK`` is not set
        (single-process training).

    Raises:
        KeyError: if ``LOCAL_RANK`` is set but ``WORLD_SIZE`` or ``RANK`` is not.
    """
    from datetime import timedelta

    if "LOCAL_RANK" not in os.environ:
        return -1

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])

    # Bind this process to its own GPU before the process group is created.
    # NCCL collectives run on torch.cuda.current_device(), which is cuda:0
    # unless set explicitly, so without this every rank would target GPU 0.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend=backend,
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(minutes=timeout_minutes),
    )

    return local_rank


def main():
    """Parse CLI arguments, build the Trainer and run training.

    The process group created by ``setup_distributed`` is destroyed in a
    ``finally`` block, so it is released even if constructing the trainer or
    training raises.
    """
    parser = argparse.ArgumentParser(description="Train dense retrieval model")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to configuration YAML file",
    )
    args = parser.parse_args()

    # Load configuration first: a bad config fails before any process group
    # has been initialised.
    config = load_config(args.config)

    # -1 means single-process training; otherwise this worker's local rank.
    local_rank = setup_distributed()

    try:
        trainer = Trainer(config=config, local_rank=local_rank)
        trainer.train()
    finally:
        # Only tear down a process group that is actually live, so this does
        # not raise if nothing was initialised or it was already destroyed.
        if local_rank != -1 and dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    # Run training only when executed as a script (directly or via torchrun),
    # not when this module is imported.
    main()
