"""Run this script with 'torchrun'."""

import gzip
import logging
import os
import sys
from datetime import timedelta
from pathlib import Path
from typing import Optional, TextIO

import swanlab
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.parallel import DistributedDataParallel as DDP

from olmo.config import (
    CheckpointType,
    DataConfig,
    DDPGradSyncMode,
    DistributedStrategy,
    EvaluatorConfig,
    EvaluatorType,
    TrainConfig,
)
from olmo.data import build_train_dataloader
from olmo.eval import build_evaluators
from olmo.exceptions import OLMoCliError, OLMoConfigurationError
from olmo.model import OLMo
from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler
from olmo.torch_util import (
    SingleAccelerator,
    barrier,
    get_default_device,
    get_global_rank,
    get_local_rank,
    get_local_world_size,
    get_world_size,
    peak_gpu_memory,
    seed_all,
)
from olmo.train import Trainer
from olmo.util import (
    add_cached_path_clients,
    clean_opt,
    find_latest_checkpoint,
    log_extra_field,
    prepare_cli_environment,
)

log = logging.getLogger("train")


def setup_enhanced_logging(cfg: TrainConfig) -> None:
    """
    Enhanced logging setup for cluster jobs with file logging and better formatting.

    Replaces every handler on the root logger with:

    1. ``logs/train_rank_XXXX.log`` (every rank): DEBUG and above, tagged with the rank.
    2. stdout (every rank): INFO and above from rank 0, WARNING and above from all ranks.
    3. ``logs/train_master.log`` (rank 0 only): INFO and above from the main training
       loggers (``train``, ``__main__``, ``olmo.train``, ``olmo.model``, ...).
    4. ``logs/train_errors.log`` (rank 0 only): WARNING and above from any logger.
    5. ``logs/train_progress.log`` (rank 0 only): INFO and above whose message contains a
       progress keyword such as ``step=`` or ``loss=``.

    The root logger level is set to DEBUG so the per-rank file receives everything; each
    handler applies its own level. All log files are opened in append mode, UTF-8 encoded.

    :param cfg: Training config. Only ``cfg.save_folder`` is read; the ``logs`` directory is
        created under it with :meth:`pathlib.Path.mkdir`, so it must be a local path.
    """
    # Create logs directory (all ranks may do this concurrently; exist_ok makes that harmless).
    log_dir = Path(cfg.save_folder) / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    # Get rank information for distributed logging
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    world_size = get_world_size()

    def stamp_rank(record: logging.LogRecord) -> None:
        # The rank-aware formatters below reference these record attributes.
        record.global_rank = global_rank
        record.local_rank = local_rank
        record.world_size = world_size

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)

    # Remove existing handlers to avoid duplicate output, and close them so that file
    # handlers installed earlier do not keep their file descriptors open.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()

    # 1. File handler for individual rank logs (detailed)
    rank_log_file = log_dir / f"train_rank_{global_rank:04d}.log"
    file_handler = logging.FileHandler(rank_log_file, mode="a", encoding="utf-8")
    file_handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | Rank %(global_rank)04d | %(name)s:%(lineno)d | %(levelname)-8s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    file_handler.setLevel(logging.DEBUG)

    def add_rank_info(record: logging.LogRecord) -> bool:
        # Never drops a record; only attaches the rank fields.
        stamp_rank(record)
        return True

    file_handler.addFilter(add_rank_info)
    root_logger.addHandler(file_handler)

    # 2. Console handler with rank filtering (only rank 0 for INFO, all ranks for WARNING+)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s [Rank %(global_rank)04d/%(world_size)04d] %(levelname)-8s: %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    console_handler.setLevel(logging.INFO)

    def console_filter(record: logging.LogRecord) -> bool:
        stamp_rank(record)
        # Only show INFO+ from rank 0, but WARNING+ from all ranks
        return record.levelno >= logging.WARNING or global_rank == 0

    console_handler.addFilter(console_filter)
    root_logger.addHandler(console_handler)

    if global_rank == 0:
        # 3. Master log file (rank 0 only) - aggregated important logs
        master_log_file = log_dir / "train_master.log"
        master_handler = logging.FileHandler(master_log_file, mode="a", encoding="utf-8")
        master_formatter = logging.Formatter(
            fmt="%(asctime)s | %(name)s:%(lineno)d | %(levelname)-8s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        master_handler.setFormatter(master_formatter)
        master_handler.setLevel(logging.INFO)

        # Logger-name prefixes of the main training components; a tuple so a single
        # str.startswith call checks all of them.
        important_loggers = (
            "train",
            "__main__",
            "olmo.train",
            "olmo.model",
            "olmo.data",
            "olmo.optim",
            "olmo.checkpoint",
        )

        def master_filter(record: logging.LogRecord) -> bool:
            return record.name.startswith(important_loggers)

        master_handler.addFilter(master_filter)
        root_logger.addHandler(master_handler)

        # 4. Error log file (rank 0 only) - errors and warnings
        error_log_file = log_dir / "train_errors.log"
        error_handler = logging.FileHandler(error_log_file, mode="a", encoding="utf-8")
        error_handler.setFormatter(master_formatter)
        error_handler.setLevel(logging.WARNING)
        root_logger.addHandler(error_handler)

        # 5. Progress log file (rank 0 only) - training progress
        progress_log_file = log_dir / "train_progress.log"
        progress_handler = logging.FileHandler(progress_log_file, mode="a", encoding="utf-8")
        progress_handler.setFormatter(
            logging.Formatter(fmt="%(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        )
        progress_handler.setLevel(logging.INFO)

        progress_keywords = (
            "step=",
            "epoch=",
            "loss=",
            "lr=",
            "tokens/sec",
            "Training complete",
            "Checkpoint saved",
            "Starting training",
        )

        def progress_filter(record: logging.LogRecord) -> bool:
            # Format the message once, not once per keyword.
            message = record.getMessage()
            return any(keyword in message for keyword in progress_keywords)

        progress_handler.addFilter(progress_filter)
        root_logger.addHandler(progress_handler)

    # Log the logging setup
    log.info(f"Enhanced logging setup complete for rank {global_rank}/{world_size}")
    log.info(f"Log directory: {log_dir}")
    log.info(f"Rank log file: {rank_log_file}")
    if global_rank == 0:
        log.info(f"Master log file: {master_log_file}")
        log.info(f"Error log file: {error_log_file}")
        log.info(f"Progress log file: {progress_log_file}")


def main(cfg: TrainConfig) -> None:
    """
    Build every training component from ``cfg`` and run (or dry-run) training.

    In order, this:

    - validates and derives config values;
    - configures logging and saves the config (rank 0);
    - starts W&B or SwanLab tracking;
    - builds the data loader and evaluators;
    - builds the model and wraps it according to ``cfg.distributed_strategy`` (DDP, FSDP,
      or a single accelerator);
    - builds the optimizer and scheduler, and optionally opens a per-rank data-indices
      file;
    - hands everything to :class:`Trainer`.

    Inside the trainer context it resolves the latest checkpoint (when
    ``try_load_latest_save`` is set), optionally saves and reloads a pre-train checkpoint,
    loads ``cfg.load_path``, and finally calls ``trainer.fit()`` unless ``cfg.dry_run``.

    Called from the ``__main__`` block after the process group has been initialized.
    ``cfg`` is mutated in place (batch sizes, optimizer decay flags, load path,
    checkpoint settings).

    :raises OLMoConfigurationError: if ``run_name`` is missing, if ``config.yaml`` or a
        data-indices file already exists without ``save_overwrite``, or if the model /
        distributed-strategy settings are inconsistent or unsupported.
    """
    # Ensure run name set.
    if cfg.run_name is None:
        raise OLMoConfigurationError("--run_name is required")
    log_extra_field("run_name", cfg.run_name)

    # Setup enhanced logging early
    setup_enhanced_logging(cfg)

    # Log system information
    log.info("=== TRAINING SESSION START ===")
    log.info(f"Run name: {cfg.run_name}")
    log.info(f"Global rank: {get_global_rank()}/{get_world_size()}")
    log.info(f"Local rank: {get_local_rank()}/{get_local_world_size()}")
    log.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        # Report the device this rank is bound to, not always GPU 0.
        current_cuda_device = torch.cuda.current_device()
        log.info(f"CUDA device: {torch.cuda.get_device_name(current_cuda_device)}")
        log.info(
            f"CUDA memory: {torch.cuda.get_device_properties(current_cuda_device).total_memory / 1e9:.1f} GB"
        )

    # Sanity check
    if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
        log.warning(
            "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The "
            "setting has no effect."
        )

    barrier()

    # Set CUDA device.
    if torch.cuda.is_available():
        torch.cuda.set_device(f"cuda:{get_local_rank()}")
        torch.cuda.empty_cache()
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Fill some configuration options.
    cfg.model.precision = cfg.precision
    cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
    assert cfg.device_train_batch_size is not None  # for mypy
    cfg.device_train_grad_accum = (
        cfg.device_train_batch_size // cfg.device_train_microbatch_size
    )
    if cfg.optimizer.no_decay_norm_and_bias is not None:
        log.warning(
            "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this "
            "setting will take precedence over all other weight decay configurations. Please change "
            "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
        )
        cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
        cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
        cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.

    # Display and save configuration.
    if get_global_rank() == 0:
        if cfg.data.paths is not None and len(cfg.data.paths) < 50:
            log.info("Configuration:")
            log.info(cfg)
        # Skip saving when resuming from a checkpoint that lives directly in save_folder.
        if not cfg.dry_run and (
            cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)
        ):
            # Save config.
            save_path = Path(cfg.save_folder) / "config.yaml"
            if save_path.is_file() and not cfg.save_overwrite:
                raise OLMoConfigurationError(
                    f"{save_path} already exists, use --save_overwrite to overwrite"
                )
            else:
                log.info(f"Saving config to {save_path}")
                save_path.parent.mkdir(exist_ok=True, parents=True)
                cfg.save(save_path)
            del save_path

    barrier()

    # Maybe start W&B run (or SwanLab, if W&B is not configured).
    if cfg.wandb is not None and (
        get_global_rank() == 0 or not cfg.wandb.rank_zero_only
    ):
        wandb_dir = Path(cfg.save_folder) / "wandb"
        wandb_dir.mkdir(parents=True, exist_ok=True)
        wandb.init(
            dir=str(wandb_dir),
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            group=cfg.wandb.group,
            name=cfg.wandb.name,
            tags=cfg.wandb.tags,
            config=cfg.asdict(exclude=["wandb"]),
        )
    elif cfg.swanlab is not None and (
        get_global_rank() == 0 or not cfg.swanlab.rank_zero_only
    ):
        # NOTE: SwanLab logs go to the same `wandb` sub-directory as W&B would use.
        swanlab_dir = Path(cfg.save_folder) / "wandb"
        swanlab_dir.mkdir(parents=True, exist_ok=True)
        swanlab.init(
            logdir=str(swanlab_dir),
            project=cfg.swanlab.project,
            experiment_name=cfg.swanlab.name,
            tags=cfg.swanlab.tags,
            config=cfg.asdict(exclude=["swanlab"]),
        )

    barrier()

    # Set seed.
    seed_all(cfg.seed)

    train_loader = build_train_dataloader(cfg)

    # Construct evaluators.
    evaluators = build_evaluators(cfg, device)
    barrier()

    # Initialize the model.
    log.info("Building model...")
    olmo_model = OLMo(cfg.model)
    log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
    log.info(
        f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}"
    )
    if olmo_model.is_moe:
        log.info(
            f"Number of activated parameters: {olmo_model.num_activated_params():,d}"
        )
        log.info(
            f"Number of activated non-embedding parameters: {olmo_model.num_activated_params(include_embedding=False):,d}"
        )
    log.info(
        f"Peak GPU Memory (MB) before {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}"
    )

    # Compile one block at a time.
    if cfg.compile is not None:
        if cfg.model.block_group_size != 1:
            raise OLMoConfigurationError(
                "Compile is only supported with block_group_size 1."
            )
        for block in olmo_model.transformer.blocks:
            block.compile(**cfg.compile.asdict())

    olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)

    if cfg.distributed_strategy == DistributedStrategy.ddp:
        log.info("Wrapping model with DDP...")
        assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!"

        if cfg.model.init_device != "cuda":
            raise OLMoConfigurationError(
                "DDP does not work with init_device set to anything other than `cuda`."
            )

        if (
            cfg.ddp.find_unused_params is True
            and cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch
        ):
            raise OLMoConfigurationError(
                "`find_unused_params` is set to True. DDP needs to synchronize gradients for every micro-batch to avoid errors. Set `grad_sync_mode` to `micro_batch`."
            )

        param_init_fn = None

        # move to cuda before calling ddp
        dist_model = DDP(
            olmo_model.to(device), find_unused_parameters=cfg.ddp.find_unused_params
        )
    elif cfg.distributed_strategy == DistributedStrategy.fsdp:
        # Wrap the model in FSDP.
        log.info("Wrapping model with FSDP...")
        assert (
            cfg.fsdp is not None
        ), "DistributedStrategy fsdp needs cfg.fsdp to be set!"
        wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)

        if version.parse(torch.__version__) >= version.parse("2.1.0"):
            # This prevents any parameters from being initialized twice: FSDP only
            # materializes storage here, and reset_parameters() runs once below.
            def dummy_init_fn(module: torch.nn.Module) -> None:
                module.to_empty(device=get_default_device())

            param_init_fn = dummy_init_fn
        else:
            param_init_fn = None

        # Set up device mesh for hybrid sharding in order to specify which nodes are associated to a given model replica
        device_mesh = None
        hybrid_sharding_fsdp_kwargs = {}
        if cfg.fsdp.sharding_strategy in (
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ):
            if version.parse(torch.__version__) < version.parse("2.2.0"):
                # Device mesh was not added to PyTorch until v2.2.0
                raise OLMoConfigurationError(
                    "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
                )

            from torch.distributed.device_mesh import init_device_mesh

            # Default: one model replica per node.
            num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
                get_world_size() // get_local_world_size()
            )

            if num_model_replicas <= 0:
                raise OLMoConfigurationError(
                    "fsdp.hybrid_sharding_num_model_replicas must be a positive integer"
                )

            if get_world_size() % num_model_replicas != 0:
                raise OLMoConfigurationError(
                    "fsdp.hybrid_sharding_num_model_replicas must divide world size"
                )

            device_mesh = init_device_mesh(
                "cuda", (num_model_replicas, get_world_size() // num_model_replicas)
            )
            hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh

        dist_model = FSDP(
            olmo_model,
            sharding_strategy=cfg.fsdp.sharding_strategy,
            mixed_precision=cfg.fsdp_precision,
            auto_wrap_policy=wrap_policy,
            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
            limit_all_gathers=True,
            device_id=get_local_rank(),
            param_init_fn=param_init_fn,
            **hybrid_sharding_fsdp_kwargs,
        )
    elif cfg.distributed_strategy == DistributedStrategy.single:
        param_init_fn = None
        olmo_model = olmo_model.to(device)
        dist_model = SingleAccelerator(olmo_model)
    else:
        # Without this, dist_model / param_init_fn would be unbound below.
        raise OLMoConfigurationError(
            f"Unsupported distributed strategy: {cfg.distributed_strategy}"
        )

    # when param_init_fn is None, FSDP will call reset_parameters() automatically
    if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp:
        olmo_model.reset_parameters()

    log.info(
        f"Peak GPU Memory (MB) after {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}"
    )
    log.info("Model:")
    log.info(dist_model)

    # Construct optimizer and learning rate scheduler.
    optim = build_optimizer(cfg, dist_model)
    scheduler = build_scheduler(cfg)

    # Data indices file (one gzip-compressed TSV per rank).
    indices_file: Optional[TextIO] = None
    if cfg.save_data_indices:
        indices_file_path = (
            Path(cfg.save_folder) / f"data-indices/rank{get_global_rank()}.tsv.gz"
        )
        if indices_file_path.exists() and not cfg.save_overwrite:
            raise OLMoConfigurationError(
                f"{indices_file_path} already exists, use --save_overwrite to overwrite"
            )
        indices_file_path.parent.mkdir(exist_ok=True, parents=True)
        indices_file = gzip.open(indices_file_path, "wt")

    # Consolidate components into `Trainer` object.
    with Trainer(
        cfg=cfg,
        epoch=cfg.epoch,
        model=olmo_model,
        dist_model=dist_model,
        optim=optim,
        scheduler=scheduler,
        train_loader=train_loader,
        device=device,
        evaluators=evaluators,
        indices_file=indices_file,
    ) as trainer:
        if cfg.try_load_latest_save:
            # Prefer a local checkpoint; fall back to the remote save folder.
            checkpoint_dir = None
            if (
                cfg.save_folder is not None
                and (checkpoint_dir := find_latest_checkpoint(cfg.save_folder))
                is not None
            ):
                log.info("Setting load path to local checkpoint %s", checkpoint_dir)
                cfg.load_path = str(checkpoint_dir)
            elif (
                cfg.remote_save_folder is not None
                and (checkpoint_dir := find_latest_checkpoint(cfg.remote_save_folder))
                is not None
            ):
                log.info("Setting load path to remote checkpoint %s", checkpoint_dir)
                cfg.load_path = str(checkpoint_dir)
            if checkpoint_dir is not None and not cfg.restore_dataloader:
                log.info(
                    "You set restore_dataloader=False, but try_load_latest_save=True. If we were to run like "
                    "this, it would overwrite your previous checkpoints. I will assume you didn't mean that, "
                    "and set restore_dataloader=True."
                )
                cfg.restore_dataloader = True
            if checkpoint_dir is not None and cfg.reset_trainer_state:
                log.info(
                    "You set both reset_trainer_state=True, and try_load_latest_save=True. If we were to "
                    "run like this, it would reset your trainer state right now even though we're in the "
                    "middle of a run. I will assume you didn't mean that, and set "
                    "reset_trainer_state=False."
                )
                cfg.reset_trainer_state = False
            if checkpoint_dir is not None and cfg.reset_optimizer_state:
                log.info(
                    "You set both reset_optimizer_state=True, and try_load_latest_save=True. If we were to "
                    "run like this, it would reset your optimizer state right now even though we're in the "
                    "middle of a run. I will assume you didn't mean that, and set "
                    "reset_optimizer_state=False."
                )
                cfg.reset_optimizer_state = False

        if (
            not cfg.dry_run
            and not cfg.no_pre_train_checkpoint
            and cfg.load_path is None
        ):
            if cfg.distributed_strategy == DistributedStrategy.ddp:
                checkpoint_type = CheckpointType.unsharded

                if cfg.save_interval_unsharded is None:
                    log.warning(
                        "DDP requires setting `save_interval_unsharded`. Using the value set for `save_interval`."
                    )
                    cfg.save_interval_unsharded = cfg.save_interval

                if cfg.save_num_unsharded_checkpoints_to_keep == 0:
                    log.warning(
                        "DDP requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`."
                    )
                    cfg.save_num_unsharded_checkpoints_to_keep = (
                        cfg.save_num_checkpoints_to_keep
                    )
            elif cfg.distributed_strategy == DistributedStrategy.fsdp:
                checkpoint_type = (
                    CheckpointType.sharded
                    if cfg.save_num_checkpoints_to_keep != 0
                    else CheckpointType.unsharded
                )
            elif cfg.distributed_strategy == DistributedStrategy.single:
                checkpoint_type = CheckpointType.unsharded

                if cfg.save_interval_unsharded is None:
                    log.warning(
                        "single accelerator training requires setting `save_interval_unsharded`. Using the value set for `save_interval`."
                    )
                    cfg.save_interval_unsharded = cfg.save_interval

                if cfg.save_num_unsharded_checkpoints_to_keep == 0:
                    log.warning(
                        "single accelerator training requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`."
                    )
                    cfg.save_num_unsharded_checkpoints_to_keep = (
                        cfg.save_num_checkpoints_to_keep
                    )

            # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
            log.info("Saving pre-train checkpoint...")
            checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(
                checkpoint_type=checkpoint_type
            )
            log.info(f"Checkpoint saved to {checkpoint_path}")

            # And then we verify that we can load it.
            log.info("Attempting to load pre-train checkpoint...")
            trainer.restore_checkpoint(
                checkpoint_path,
                checkpoint_type=checkpoint_type,
                local_cache=local_checkpoint_cache,
            )
            log.info("Checkpoint successfully loaded")

            # NOTE: https://github.com/allenai/LLM/issues/233
            #  log.info("Removing pre-train checkpoint...")
            #  trainer.remove_checkpoint(checkpoint_type=checkpoint_type)
            #  log.info("Successfully removed checkpoint")

        if cfg.load_path is not None:
            log.info(f"Loading checkpoint from {cfg.load_path}...")
            trainer.restore_checkpoint(
                cfg.load_path,
                load_optimizer_state=not cfg.reset_optimizer_state,
                load_trainer_state=not cfg.reset_trainer_state,
                sharded_checkpointer=cfg.load_path_sharded_checkpointer,
            )
            log.info("Checkpoint successfully loaded")

            # If we have to, set a new scheduler: a fresh optimizer state mid-run gets a
            # warmup of `t_warmup` steps starting from the current global step.
            if cfg.reset_optimizer_state and not cfg.reset_trainer_state:
                trainer.scheduler = BoltOnWarmupScheduler.wrap(
                    trainer.scheduler,
                    trainer.global_step,
                    int(trainer.global_step + cfg.scheduler.t_warmup),
                )

        if (
            cfg.force_save_unsharded
            and cfg.distributed_strategy != DistributedStrategy.ddp
        ):
            log.info("Saving unsharded checkpoint...")
            checkpoint_path, _ = trainer.save_checkpoint(
                checkpoint_type=CheckpointType.unsharded
            )
            log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

        if not cfg.dry_run:
            log.info("Starting training...")
            trainer.fit()
            log.info("Training complete")
        else:
            log.info("Dry run complete")


if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError as e:
        print(f"failed to set multiprocessing start method: {e}")

    if torch.cuda.is_available():
        # Bind this process to its local GPU before creating the process group, to prevent
        # GPU 0 from picking up a bunch of tensors it shouldn't have.
        device_as_string = f"cuda:{get_local_rank()}"
        torch.cuda.set_device(device_as_string)

        # Initialize process group.
        dist.init_process_group(
            backend="nccl",
            timeout=timedelta(minutes=30),
            device_id=torch.device(device_as_string),
        )
    else:
        if torch.backends.mps.is_available():
            # Fill in single-process rendezvous settings (only where unset or empty) so the
            # script also runs on MPS without torchrun.
            for env_var, default_value in (
                ("RANK", "0"),
                ("WORLD_SIZE", "1"),
                ("MASTER_ADDR", "0.0.0.0"),
                ("MASTER_PORT", "24501"),
            ):
                if not os.getenv(env_var):
                    os.environ[env_var] = default_value
        dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30))

    prepare_cli_environment()
    # These INFO messages are emitted only now. Python's root logger defaults to WARNING,
    # so INFO records logged before any logging configuration are dropped (presumably
    # prepare_cli_environment() installs that configuration).
    log.info(f"Multiprocessing start method set to '{mp.get_start_method()}'")
    log.info("Process group initialized")
    log.info("CLI environment prepared")

    add_cached_path_clients()

    try:
        yaml_path, args_list = sys.argv[1], sys.argv[2:]
    except IndexError:
        raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]") from None

    cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])

    if cfg.data.dir is not None:
        # All .npy files but the last are used for training, the last for validation.
        # Sort first: Path.glob yields files in arbitrary, filesystem-dependent order, which
        # could give different ranks or restarts different train/validation splits.
        all_npy_data_paths = sorted(str(p) for p in Path(cfg.data.dir).glob("*.npy"))
        if len(all_npy_data_paths) < 2:
            raise OLMoConfigurationError(
                f"data.dir '{cfg.data.dir}' must contain at least two .npy files "
                "(all but the last are used for training, the last for validation)"
            )
        train_data_paths = all_npy_data_paths[:-1]
        val_data_paths = all_npy_data_paths[-1:]
        cfg.data.paths = train_data_paths
        cfg.evaluators.append(
            EvaluatorConfig(
                label="train-set-validation",
                data=DataConfig(
                    paths=val_data_paths,
                    num_workers=8,
                    pin_memory=True,
                    drop_last=True,
                ),
            )
        )

    if torch.backends.mps.is_available():
        log.info("Device is MPS. Updating config...")
        cfg.model.init_device = "mps"
        cfg.distributed_strategy = "single"  # type: ignore

    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
        log.info("Device is CPU. Updating config...")
        cfg.model.init_device = "cpu"
        cfg.distributed_strategy = "single"  # type: ignore
    main(cfg)
