"""Run this script with 'torchrun'."""

import gzip
import logging
import os
import sys
from datetime import timedelta
from pathlib import Path
from typing import Optional, TextIO
import swanlab

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.parallel import DistributedDataParallel as DDP

from olmo.config import (
    CheckpointType,
    DDPGradSyncMode,
    DistributedStrategy,
    TrainConfig,
)
from olmo.data import build_train_dataloader
from olmo.eval import build_evaluators
from olmo.exceptions import OLMoCliError, OLMoConfigurationError
from olmo.model import OLMo
from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler
from olmo.torch_util import (
    SingleAccelerator,
    barrier,
    get_default_device,
    get_global_rank,
    get_local_rank,
    get_local_world_size,
    get_world_size,
    peak_gpu_memory,
    seed_all,
)
from olmo.train import Trainer
from olmo.util import (
    add_cached_path_clients,
    clean_opt,
    find_latest_checkpoint,
    log_extra_field,
    prepare_cli_environment,
)
from olmo.config import EvaluatorConfig, EvaluatorType, DataConfig

log = logging.getLogger("train")


def setup_file_logging(cfg: TrainConfig) -> None:
    """Attach per-rank file logging and an all-rank console handler to the root logger.

    Creates ``<save_folder>/logs`` and adds these handlers to the root logger:

    * a DEBUG-level file handler writing ``train_rank_<global_rank>.log`` (every rank);
    * on global rank 0 only, an INFO-level handler writing ``train_summary.log``. It keeps
      records from the core training loggers and from their child loggers;
    * an INFO-level console handler on ``sys.stdout``. It replaces any root handler that
      already wrote to ``sys.stdout``, so console lines are not duplicated.

    The root logger level is set to DEBUG so the per-rank file receives everything.

    NOTE(review): the formatters use a ``global_rank`` record attribute, which is not a
    standard ``LogRecord`` field. Every caller in this script runs ``prepare_cli_environment()``
    first, which presumably installs it. Confirm this before calling the function from
    anywhere else.

    :param cfg: Training config; only ``cfg.save_folder`` is read.
    """
    log_dir = Path(cfg.save_folder) / "logs"
    # Several ranks may race to create the directory; exist_ok makes that harmless.
    log_dir.mkdir(parents=True, exist_ok=True)

    global_rank = get_global_rank()
    root_logger = logging.getLogger()

    # Per-rank file with full detail (DEBUG and up).
    log_file = log_dir / f"train_rank_{global_rank}.log"
    file_formatter = logging.Formatter(
        "%(asctime)s | Rank %(global_rank)s | %(name)s:%(lineno)s | %(levelname)s | %(message)s"
    )
    file_formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
    file_handler = logging.FileHandler(log_file, mode="a")
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(file_handler)

    # Condensed summary file, written by rank 0 only.
    summary_log_file: Optional[Path] = None
    if global_rank == 0:
        summary_log_file = log_dir / "train_summary.log"
        summary_handler = logging.FileHandler(summary_log_file, mode="a")
        summary_handler.setFormatter(file_formatter)
        summary_handler.setLevel(logging.INFO)

        summary_loggers = ("train", "olmo.train", "olmo.model", "olmo.data", "__main__")

        def summary_filter(record: logging.LogRecord) -> bool:
            # Match a listed logger or any descendant in the dotted logger hierarchy
            # (e.g. loggers named "olmo.data.<submodule>"), not only the exact names.
            return any(record.name == name or record.name.startswith(name + ".") for name in summary_loggers)

        summary_handler.addFilter(summary_filter)
        root_logger.addHandler(summary_handler)

    # Console output on every rank (INFO and up).
    console_handler = logging.StreamHandler(sys.stdout)
    console_formatter = logging.Formatter("%(asctime)s [Rank %(global_rank)s] %(levelname)s: %(message)s")
    console_formatter.default_time_format = "%H:%M:%S"
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)

    # Drop any existing stdout handler to avoid duplicate console lines. FileHandler subclasses
    # StreamHandler, but its stream is the log file, so the handlers added above are kept.
    for handler in root_logger.handlers[:]:
        if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stdout:
            root_logger.removeHandler(handler)

    root_logger.addHandler(console_handler)
    root_logger.setLevel(logging.DEBUG)

    log.info(f"File logging setup complete. Logs will be saved to: {log_dir}")
    log.info(f"Main log file: {log_file}")
    if global_rank == 0:
        log.info(f"Summary log file: {summary_log_file}")


def main(cfg: TrainConfig) -> None:
    """Build every training component from ``cfg`` and run (or dry-run) training.

    Expects the process group to be initialized already, because it calls ``barrier()``.

    Steps, in order:

    1. Validate ``cfg`` and fill in the derived batch-size fields.
    2. Save the config (rank 0 only).
    3. Optionally start W&B or SwanLab.
    4. Build the data loader, evaluators and model.
    5. Wrap the model for DDP, FSDP or a single accelerator.
    6. Build the optimizer and scheduler.
    7. Optionally resolve and restore a checkpoint.
    8. Call ``Trainer.fit()``, unless ``cfg.dry_run`` is set.

    ``cfg`` is mutated in place (precision, batch sizes, load path, reset flags, checkpoint
    settings).

    :param cfg: The training configuration.
    :raises OLMoConfigurationError: If the run name is missing, if a config or data-indices file
        already exists without ``save_overwrite``, or if the model/strategy combination is
        unsupported.
    """
    log.info("=== STARTING MAIN TRAINING FUNCTION ===")
    log.info(f"Run name: {cfg.run_name}")
    log.info(f"Save folder: {cfg.save_folder}")
    log.info(f"Global rank: {get_global_rank()}, World size: {get_world_size()}")

    # Ensure run name set.
    if cfg.run_name is None:
        raise OLMoConfigurationError("--run_name is required")
    log_extra_field("run_name", cfg.run_name)

    # Sanity check
    if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
        log.warning(
            "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The "
            "setting has no effect."
        )

    barrier()

    # Set CUDA device.
    if torch.cuda.is_available():
        torch.cuda.set_device(f"cuda:{get_local_rank()}")
        torch.cuda.empty_cache()
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Fill some configuration options.
    cfg.model.precision = cfg.precision
    cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
    assert cfg.device_train_batch_size is not None  # for mypy
    cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
    if cfg.optimizer.no_decay_norm_and_bias is not None:
        log.warning(
            "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this "
            "setting will take precedence over all other weight decay configurations. Please change "
            "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
        )
        cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
        cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
        cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.

    # Display and save configuration (rank 0 only).
    if get_global_rank() == 0:
        # Skip printing the config when the data path list is very long.
        if cfg.data.paths is not None and len(cfg.data.paths) < 50:
            log.info("Configuration:")
            log.info(cfg)
        # Don't re-save when loading a checkpoint that sits directly under save_folder
        # (presumably a resume of this same run).
        if not cfg.dry_run and (cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)):
            # Save config.
            save_path = Path(cfg.save_folder) / "config.yaml"
            if save_path.is_file() and not cfg.save_overwrite:
                raise OLMoConfigurationError(f"{save_path} already exists, use --save_overwrite to overwrite")
            else:
                log.info(f"Saving config to {save_path}")
                save_path.parent.mkdir(exist_ok=True, parents=True)
                cfg.save(save_path)
            del save_path

    barrier()

    # Maybe start an experiment tracker. W&B takes precedence when both are configured.
    # The backend is chosen from the config alone, so every rank agrees on it. Each
    # backend's `rank_zero_only` then decides which ranks initialize it.
    if cfg.wandb is not None:
        if get_global_rank() == 0 or not cfg.wandb.rank_zero_only:
            wandb_dir = Path(cfg.save_folder) / "wandb"
            wandb_dir.mkdir(parents=True, exist_ok=True)
            wandb.init(
                dir=str(wandb_dir),
                project=cfg.wandb.project,
                entity=cfg.wandb.entity,
                group=cfg.wandb.group,
                name=cfg.wandb.name,
                tags=cfg.wandb.tags,
                mode=cfg.wandb.mode,
                config=cfg.asdict(exclude=["wandb"]),
            )
    elif cfg.swanlab is not None:
        if get_global_rank() == 0 or not cfg.swanlab.rank_zero_only:
            # NOTE: SwanLab logs go into a sub-directory that is also named "wandb".
            swanlab_dir = Path(cfg.save_folder) / "wandb"
            swanlab_dir.mkdir(parents=True, exist_ok=True)
            swanlab.init(
                logdir=str(swanlab_dir),
                project=cfg.swanlab.project,
                experiment_name=cfg.swanlab.name,
                tags=cfg.swanlab.tags,
                config=cfg.asdict(exclude=["swanlab"]),
            )

    barrier()

    # Set seed.
    seed_all(cfg.seed)

    train_loader = build_train_dataloader(cfg)

    # Construct evaluators.
    evaluators = build_evaluators(cfg, device)
    barrier()

    # Initialize the model.
    log.info("Building model...")
    olmo_model = OLMo(cfg.model)
    log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
    log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
    log.info(f"Peak GPU Memory (MB) before {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}")

    # Compile one block at a time.
    if cfg.compile is not None:
        if cfg.model.block_group_size != 1:
            raise OLMoConfigurationError("Compile is only supported with block_group_size 1.")
        for block in olmo_model.transformer.blocks:
            block.compile(**cfg.compile.asdict())

    olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)

    if cfg.distributed_strategy == DistributedStrategy.ddp:
        log.info("Wrapping model with DDP...")
        assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!"

        if cfg.model.init_device != "cuda":
            raise OLMoConfigurationError("DDP does not work with init_device set to anything other than `cuda`.")

        if cfg.ddp.find_unused_params is True and cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch:
            raise OLMoConfigurationError(
                "`find_unused_params` is set to True. DDP needs to synchronize gradients for every micro-batch to avoid errors. Set `grad_sync_mode` to `micro_batch`."
            )

        param_init_fn = None

        # move to cuda before calling ddp
        dist_model = DDP(olmo_model.to(device), find_unused_parameters=cfg.ddp.find_unused_params)
    elif cfg.distributed_strategy == DistributedStrategy.fsdp:
        # Wrap the model in FSDP.
        log.info("Wrapping model with FSDP...")
        assert cfg.fsdp is not None, "DistributedStrategy fsdp needs cfg.fsdp to be set!"
        wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)

        if version.parse(torch.__version__) >= version.parse("2.1.0"):
            # This prevents any parameters from being initialized twice
            def dummy_init_fn(module: torch.nn.Module) -> None:
                module.to_empty(device=get_default_device())

            param_init_fn = dummy_init_fn
        else:
            param_init_fn = None

        # Set up a device mesh for hybrid sharding to specify which nodes are associated with a given model replica
        device_mesh = None
        hybrid_sharding_fsdp_kwargs = {}
        if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
            if version.parse(torch.__version__) < version.parse("2.2.0"):
                # Device mesh was not added to PyTorch until v2.2.0
                raise OLMoConfigurationError(
                    "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
                )

            from torch.distributed.device_mesh import init_device_mesh

            num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
                get_world_size() // get_local_world_size()
            )

            if num_model_replicas <= 0:
                raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")

            if get_world_size() % num_model_replicas != 0:
                raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide world size")

            device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
            hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh

        dist_model = FSDP(
            olmo_model,
            sharding_strategy=cfg.fsdp.sharding_strategy,
            mixed_precision=cfg.fsdp_precision,
            auto_wrap_policy=wrap_policy,
            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
            limit_all_gathers=True,
            device_id=get_local_rank(),
            param_init_fn=param_init_fn,
            **hybrid_sharding_fsdp_kwargs,
        )
    elif cfg.distributed_strategy == DistributedStrategy.single:
        param_init_fn = None
        olmo_model = olmo_model.to(device)
        dist_model = SingleAccelerator(olmo_model)
    else:
        raise OLMoConfigurationError(f"Unknown distributed strategy: {cfg.distributed_strategy}")

    # when param_init_fn is None, FSDP will call reset_parameters() automatically
    if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp:
        olmo_model.reset_parameters()

    log.info(f"Peak GPU Memory (MB) after {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}")
    log.info("Model:")
    log.info(dist_model)

    # Construct optimizer and learning rate scheduler.
    optim = build_optimizer(cfg, dist_model)
    scheduler = build_scheduler(cfg)

    # Data indices file (one gzip-compressed TSV per rank). It is handed to the Trainer; this
    # function does not close it itself.
    indices_file: Optional[TextIO] = None
    if cfg.save_data_indices:
        indices_file_path = Path(cfg.save_folder) / f"data-indices/rank{get_global_rank()}.tsv.gz"
        if indices_file_path.exists() and not cfg.save_overwrite:
            raise OLMoConfigurationError(f"{indices_file_path} already exists, use --save_overwrite to overwrite")
        indices_file_path.parent.mkdir(exist_ok=True, parents=True)
        indices_file = gzip.open(indices_file_path, "wt")

    # Consolidate components into `Trainer` object.
    with Trainer(
        cfg=cfg,
        epoch=cfg.epoch,
        model=olmo_model,
        dist_model=dist_model,
        optim=optim,
        scheduler=scheduler,
        train_loader=train_loader,
        device=device,
        evaluators=evaluators,
        indices_file=indices_file,
    ) as trainer:
        if cfg.try_load_latest_save:
            # Prefer the newest local checkpoint, then the newest remote one.
            checkpoint_dir = None
            if (
                cfg.save_folder is not None
                and (checkpoint_dir := find_latest_checkpoint(cfg.save_folder)) is not None
            ):
                log.info("Setting load path to local checkpoint %s", checkpoint_dir)
                cfg.load_path = str(checkpoint_dir)
            elif (
                cfg.remote_save_folder is not None
                and (checkpoint_dir := find_latest_checkpoint(cfg.remote_save_folder)) is not None
            ):
                log.info("Setting load path to remote checkpoint %s", checkpoint_dir)
                cfg.load_path = str(checkpoint_dir)
            if checkpoint_dir is not None and not cfg.restore_dataloader:
                log.info(
                    "You set restore_dataloader=False, but try_load_latest_save=True. If we were to run like "
                    "this, it would overwrite your previous checkpoints. I will assume you didn't mean that, "
                    "and set restore_dataloader=True."
                )
                cfg.restore_dataloader = True
            if checkpoint_dir is not None and cfg.reset_trainer_state:
                log.info(
                    "You set both reset_trainer_state=True, and try_load_latest_save=True. If we were to "
                    "run like this, it would reset your trainer state right now even though we're in the "
                    "middle of a run. I will assume you didn't mean that, and set "
                    "reset_trainer_state=False."
                )
                cfg.reset_trainer_state = False
            if checkpoint_dir is not None and cfg.reset_optimizer_state:
                log.info(
                    "You set both reset_optimizer_state=True, and try_load_latest_save=True. If we were to "
                    "run like this, it would reset your optimizer state right now even though we're in the "
                    "middle of a run. I will assume you didn't mean that, and set "
                    "reset_optimizer_state=False."
                )
                cfg.reset_optimizer_state = False

        if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
            if cfg.distributed_strategy == DistributedStrategy.fsdp:
                checkpoint_type = (
                    CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
                )
            elif cfg.distributed_strategy in (DistributedStrategy.ddp, DistributedStrategy.single):
                # Both strategies take an unsharded pre-train checkpoint here. If the unsharded
                # checkpoint settings are unset, fall back to the generic ones.
                checkpoint_type = CheckpointType.unsharded
                strategy_label = (
                    "DDP" if cfg.distributed_strategy == DistributedStrategy.ddp else "single accelerator training"
                )

                if cfg.save_interval_unsharded is None:
                    log.warning(
                        f"{strategy_label} requires setting `save_interval_unsharded`. "
                        "Using the value set for `save_interval`."
                    )
                    cfg.save_interval_unsharded = cfg.save_interval

                if cfg.save_num_unsharded_checkpoints_to_keep == 0:
                    log.warning(
                        f"{strategy_label} requires setting `save_num_unsharded_checkpoints_to_keep`. "
                        "Using the value set for `save_num_checkpoints_to_keep`."
                    )
                    cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep
            else:
                raise OLMoConfigurationError(f"Unknown distributed strategy: {cfg.distributed_strategy}")

            # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
            log.info("Saving pre-train checkpoint...")
            checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(checkpoint_type=checkpoint_type)
            log.info(f"Checkpoint saved to {checkpoint_path}")

            # Then we verify that we can load it.
            log.info("Attempting to load pre-train checkpoint...")
            trainer.restore_checkpoint(
                checkpoint_path, checkpoint_type=checkpoint_type, local_cache=local_checkpoint_cache
            )
            log.info("Checkpoint successfully loaded")

            # NOTE: https://github.com/allenai/LLM/issues/233
            #  log.info("Removing pre-train checkpoint...")
            #  trainer.remove_checkpoint(checkpoint_type=checkpoint_type)
            #  log.info("Successfully removed checkpoint")

        if cfg.load_path is not None:
            log.info(f"Loading checkpoint from {cfg.load_path}...")
            trainer.restore_checkpoint(
                cfg.load_path,
                load_optimizer_state=not cfg.reset_optimizer_state,
                load_trainer_state=not cfg.reset_trainer_state,
                sharded_checkpointer=cfg.load_path_sharded_checkpointer,
            )
            log.info("Checkpoint successfully loaded")

            # If the optimizer was reset but the step counter was kept, re-warm the LR from the
            # current step by wrapping the existing scheduler.
            if cfg.reset_optimizer_state and not cfg.reset_trainer_state:
                trainer.scheduler = BoltOnWarmupScheduler.wrap(
                    trainer.scheduler,
                    trainer.global_step,
                    int(trainer.global_step + cfg.scheduler.t_warmup),
                )

        if cfg.force_save_unsharded and cfg.distributed_strategy != DistributedStrategy.ddp:
            log.info("Saving unsharded checkpoint...")
            checkpoint_path, _ = trainer.save_checkpoint(checkpoint_type=CheckpointType.unsharded)
            log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

        if not cfg.dry_run:
            log.info("Starting training...")
            trainer.fit()
            log.info("Training complete")
        else:
            log.info("Dry run complete")


def worker(local_rank: int, cfg: TrainConfig, world_size: int, node_rank: int):
    """
    Run training in one spawned process (one per local GPU).

    ``torch.multiprocessing.spawn`` calls this and passes ``local_rank`` as the first argument.
    The function:

    * exports the torchrun-style rank environment variables;
    * binds the process to its GPU;
    * initializes the NCCL process group from the environment;
    * runs :func:`main`;
    * always destroys the process group afterwards.

    :param local_rank: Index of this process's GPU on the current node.
    :param cfg: The loaded training config (pickled into the child process by ``spawn``).
    :param world_size: Total number of training processes across all nodes.
    :param node_rank: Rank of this node among all nodes.
    """
    # 1. Set up distributed environment for this worker.
    # Assumes every node has the same GPU count, matching how the launcher computes world_size.
    gpus_per_node = torch.cuda.device_count()
    global_rank = node_rank * gpus_per_node + local_rank

    # Mirror the variables torchrun would set, so the olmo.torch_util rank helpers work.
    # LOCAL_WORLD_SIZE feeds get_local_world_size(), which main() uses to derive the default
    # number of hybrid-sharding model replicas. Without it, that value does not reflect the GPUs per node.
    os.environ["RANK"] = str(global_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(gpus_per_node)

    # MASTER_ADDR and MASTER_PORT are inherited from the parent process.
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]

    print(
        f"Initializing worker: rank={global_rank}, local_rank={local_rank}, "
        f"world_size={world_size}, master={master_addr}:{master_port}"
    )

    try:
        # Bind this process to its GPU before creating the NCCL process group. Any NCCL
        # communicator setup then happens on the correct device.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            timeout=timedelta(minutes=60),
            rank=global_rank,
            world_size=world_size,
        )

        # 2. Prepare environment and call the main training function.
        prepare_cli_environment()
        add_cached_path_clients()

        # Setup file logging, which depends on the rank.
        setup_file_logging(cfg)

        print(f"Starting main training function on rank {global_rank}...")
        main(cfg)

    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


def split_npy_data_paths(data_dir: str) -> "tuple[list[str], list[str]]":
    """Split the ``*.npy`` files directly inside ``data_dir`` into training and validation paths.

    Files are sorted by path, so every node and every rerun produces the same split. The last
    file becomes the single validation file and the rest are training files.

    :param data_dir: Directory holding the ``.npy`` data files.
    :returns: ``(train_paths, val_paths)`` as lists of path strings; ``val_paths`` has exactly one
        element.
    :raises OLMoConfigurationError: If fewer than two ``.npy`` files are found, since both splits
        need at least one file.
    """
    # glob() yields entries in filesystem order; sort so the split is deterministic.
    npy_paths = [str(p) for p in sorted(Path(data_dir).glob("*.npy"))]
    if len(npy_paths) < 2:
        raise OLMoConfigurationError(
            f"data.dir={data_dir!r} must contain at least two .npy files (training + validation), "
            f"found {len(npy_paths)}"
        )
    return npy_paths[:-1], npy_paths[-1:]


def _run_in_current_process(cfg: "TrainConfig", backend: str) -> None:
    """Run training in this process without spawning workers, then tear down the process group.

    The process group is initialized from environment variables, so RANK, WORLD_SIZE,
    MASTER_ADDR and MASTER_PORT must already be set.

    :param cfg: The loaded training config.
    :param backend: ``torch.distributed`` backend name, e.g. ``"gloo"`` or ``"nccl"``.
    """
    dist.init_process_group(backend=backend)
    try:
        prepare_cli_environment()
        add_cached_path_clients()
        setup_file_logging(cfg)
        main(cfg)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    # This script acts as its own launcher: it spawns one worker process per local GPU.
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError as e:
        print(f"Failed to set multiprocessing start method: {e}")

    # Load configuration from YAML file and CLI arguments.
    try:
        yaml_path, args_list = sys.argv[1], sys.argv[2:]
    except IndexError:
        raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]") from None
    cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])

    # If a data directory is given, train on all but one of its .npy files and add an
    # evaluator on the held-out file.
    if cfg.data.dir is not None:
        train_data_paths, val_data_paths = split_npy_data_paths(cfg.data.dir)
        cfg.data.paths = train_data_paths
        cfg.evaluators.append(
            EvaluatorConfig(
                label="train-set-validation",
                data=DataConfig(
                    paths=val_data_paths,
                    num_workers=8,
                    pin_memory=True,
                    drop_last=True,
                ),
            )
        )

    # Handle non-GPU or single-process backends
    if not torch.cuda.is_available():
        print("Running in single process mode (CPU or MPS).")
        # main() calls barrier(), so a one-process group is still needed.
        os.environ.setdefault("RANK", "0")
        os.environ.setdefault("LOCAL_RANK", "0")
        os.environ.setdefault("WORLD_SIZE", "1")
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        _run_in_current_process(cfg, backend="gloo")
    else:
        # Distributed training setup for Nebula
        print("=== OLMo Distributed Training Launcher for Nebula ===")

        # Node-level info from Nebula's environment: RANK is this node's rank and WORLD_SIZE is
        # the number of nodes. Fall back to a single local node.
        node_rank = int(os.environ.get("RANK", "0"))
        num_nodes = int(os.environ.get("WORLD_SIZE", "1"))

        if os.environ.get("MASTER_ADDR") is None or os.environ.get("MASTER_PORT") is None:
            print("MASTER_ADDR or MASTER_PORT not set. Using 127.0.0.1:29500 defaults for the missing value(s).")
            # Only fill in what is missing, so an explicitly provided master address is kept.
            os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
            os.environ.setdefault("MASTER_PORT", "29500")

        print(f"Node rank: {node_rank}, Total nodes: {num_nodes}")
        print(f"Master: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")

        gpus_per_node = torch.cuda.device_count()
        world_size = num_nodes * gpus_per_node
        print(f"Found {gpus_per_node} GPUs per node. Total world size: {world_size}")

        # This check must come before the world-size branch. With zero GPUs the world size is 0,
        # which would fall into the single-process path and fail on torch.cuda.set_device(0).
        if gpus_per_node == 0:
            raise RuntimeError("No GPUs found on this node, but CUDA is available. Check CUDA setup.")

        if world_size <= 1:
            print("World size is 1, running in single process mode.")
            os.environ.setdefault("RANK", "0")
            os.environ.setdefault("LOCAL_RANK", "0")
            os.environ.setdefault("WORLD_SIZE", "1")
            torch.cuda.set_device(0)
            _run_in_current_process(cfg, backend="nccl")
        else:
            # Spawn one worker per local GPU; spawn passes the local rank as worker's first argument.
            mp.spawn(
                worker,
                args=(cfg, world_size, node_rank),
                nprocs=gpus_per_node,
                join=True,
            )
