# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional

import torch
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed import get_process_group_ranks


import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.metrics import (
    build_metrics_processor,
    ensure_pp_loss_visible,
)
from torchtitan.config_manager import ConfigManager, JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)

from torchtitan.tools.utils import debug_memory_on_rank, print_memory_on_rank, debug_on_rank

from torch.profiler import profile, record_function, ProfilerActivity

class Trainer(torch.distributed.checkpoint.stateful.Stateful):
    """Builds all training components from a ``JobConfig`` and runs the loop.

    Implements ``Stateful`` so that the current ``step`` is saved/restored by
    the checkpointer (it is registered under the ``"train_state"`` key).
    """

    job_config: JobConfig
    gc_handler: utils.GarbageCollection

    parallel_dims: ParallelDims
    train_spec: train_spec_module.TrainSpec
    world_mesh: torch.distributed.DeviceMesh

    dataloader: train_spec_module.BaseDataLoader
    metrics_processor: train_spec_module.MetricsProcessor
    checkpointer: CheckpointManager
    train_context: Generator[None, None, None]

    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    pp_has_first_stage: bool
    pp_has_last_stage: bool

    device: torch.device

    # states
    step: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: JobConfig):
        """Build every training component described by ``job_config``.

        Ordering matters: RNGs are seeded first, the device is set before the
        TorchFT manager is created, meshes are built after distributed init,
        and trainer state (``self.step``) exists before the checkpointer is
        constructed.
        """
        self.job_config = job_config

        # Seed every RNG as early as possible, before any component that may
        # consume randomness is constructed.
        self._seed_all_rngs(job_config)

        logger.info(f"Starting job: {job_config.job.description}")

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        if job_config.job.print_args:
            logger.info(f"Running with args: {job_config.to_dict()}")

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp_ring=parallelism_config.context_parallel_degree,
            cp_ulysses=parallelism_config.context_parallel_ulysses_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not parallelism_config.disable_loss_parallel,
        )
        dist_utils.init_distributed(job_config)

        # build meshes
        self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type)
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        # Debug: log the context-parallel topology (rank 0 only).
        if parallel_dims.cp_enabled and torch.distributed.get_rank() == 0:
            self._log_cp_topology()

        self.ft_manager = ft.init_ft_manager(job_config)
        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
        # dataloader must be changed.
        if self.ft_manager.enabled:
            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

        # build dataloader
        tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        self.dataloader = self.train_spec.build_dataloader_fn(
            dp_world_size=dp_degree,
            dp_rank=dp_rank,
            tokenizer=tokenizer,
            job_config=job_config,
        )

        # Re-seed the torch RNGs after dataloader creation in case a component
        # reseeded them during initialization, then log RNG state digests.
        self._reseed_and_log_rng_state(job_config.training.seed)

        # build model (using meta init)
        model_cls = self.train_spec.cls
        model_args = self.train_spec.config[job_config.model.flavor]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        logger.info(
            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
        )
        debug_memory_on_rank(0, "Debug point 1: Before model creation")
        with torch.device("meta"):
            model = model_cls.from_model_args(model_args)

        debug_memory_on_rank(0, "Debug point 2: After model creation (meta device)")

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # metrics logging
        build_metrics_processor_fn = (
            build_metrics_processor
            if self.train_spec.build_metrics_processor_fn is None
            else self.train_spec.build_metrics_processor_fn
        )
        self.metrics_processor = build_metrics_processor_fn(
            job_config, parallel_dims, model_args
        )
        color = self.metrics_processor.color

        # calculate model size and flops per token
        (
            model_param_count,
            self.metrics_processor.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        logger.info(
            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
        )

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        debug_memory_on_rank(0, "Debug point 3: Before parallelization")

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                world_mesh,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()

            # confirm that user will be able to view loss metrics on the console
            ensure_pp_loss_visible(parallel_dims, job_config, color)
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            debug_memory_on_rank(0, "Debug point 4: Before parallelize_fn")
            model = self.train_spec.parallelize_fn(
                model, world_mesh, parallel_dims, job_config
            )
            debug_memory_on_rank(0, "Debug point 5: After parallelize_fn")

            debug_memory_on_rank(0, "Debug point 6: Before model.to_empty")
            model.to_empty(device=init_device)
            debug_memory_on_rank(0, "Debug point 7: After model.to_empty")

            with torch.no_grad():
                debug_memory_on_rank(0, "Debug point 8: Before init_weights")
                model.init_weights(buffer_device=buffer_device)
                debug_memory_on_rank(0, "Debug point 9: After init_weights")
            model.train()

            self.model_parts = [model]

        if self.ft_manager.enabled:
            self.ft_manager.set_all_reduce_hook(self.model_parts)

        # initialize device memory monitor and get peak flops for MFU calculation
        device_memory_monitor = self.metrics_processor.device_memory_monitor
        gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
        logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
        device_mem_stats = device_memory_monitor.get_peak_stats()
        logger.info(
            f"{device_type.upper()} memory usage for model: "
            f"{device_mem_stats.max_reserved_gib:.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)"
        )

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )
        # Post optimizer step model converters hook.
        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )
        self.metrics_processor.optimizers = self.optimizers

        # Initialize trainer states that will be saved in checkpoint.
        # These attributes must be initialized before checkpoint loading.
        self.step = 0

        debug_memory_on_rank(0, "Debug point 10: Before checkpoint initialization")
        self.checkpointer = CheckpointManager(
            dataloader=self.dataloader,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=self.ft_manager,
        )
        debug_memory_on_rank(0, "Debug point 11: After checkpoint initialization")

        self.train_context = dist_utils.get_train_context(
            parallel_dims.loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )

        logger.info(
            "Trainer is initialized with "
            f"local batch size {job_config.training.batch_size}, "
            f"global batch size {job_config.training.batch_size * dp_degree}, "
            f"sequence length {job_config.training.seq_len}, "
            f"total steps {job_config.training.steps} "
            f"(warmup {job_config.lr_scheduler.warmup_steps})."
        )
        debug_memory_on_rank(0, "Debug point 12: End of __init__")

    @staticmethod
    def _seed_all_rngs(job_config: JobConfig) -> None:
        """Seed Python, NumPy and PyTorch RNGs and optionally enable determinism.

        This runs before distributed init, so every rank uses the raw
        ``training.seed`` here; ``dist_utils.set_determinism`` runs later.
        When ``training.seed`` is None the RNGs are left untouched
        (``torch.manual_seed(None)`` would raise ``TypeError``). The
        ``training.deterministic`` flags are applied regardless of the seed.
        """
        import random

        seed_value = job_config.training.seed

        if seed_value is not None:
            # Python built-in random module
            random.seed(seed_value)

            # NumPy is optional; the import itself must be inside the guard.
            try:
                import numpy as np
            except ImportError:
                np = None
            if np is not None:
                # np.random.seed only accepts values in [0, 2**32).
                np.random.seed(seed_value % 2**32)

            # torch.manual_seed also seeds all CUDA devices; the explicit CUDA
            # call is kept for clarity.
            torch.manual_seed(seed_value)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed_value)

            # PYTHONHASHSEED is read only at interpreter start-up, so this does
            # NOT change hashing in the current process; it only propagates to
            # child Python processes, which require a value in [0, 2**32 - 1].
            os.environ["PYTHONHASHSEED"] = str(seed_value % 2**32)

        # Enable deterministic algorithms if specified in config
        if job_config.training.deterministic:
            torch.use_deterministic_algorithms(True, warn_only=True)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            # Set environment variable for deterministic CuBLAS operations
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

        if seed_value is not None:
            logger.info(f"🌱 Comprehensive seeding enabled with seed={seed_value} for exact reproducibility")
        else:
            logger.info("🌱 training.seed is not set; skipping early RNG seeding")
        if job_config.training.deterministic:
            logger.info("🔒 Deterministic algorithms enabled (expect performance impact)")

    def _log_cp_topology(self) -> None:
        """Log rank 0's context-parallel process groups and their node placement.

        Ranks per node are taken from ``LOCAL_WORLD_SIZE`` (set by torchrun),
        defaulting to 8 when it is absent.
        """
        parallel_dims = self.parallel_dims
        world_mesh = self.world_mesh
        ranks_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", "8"))
        world_size = torch.distributed.get_world_size()

        logger.info("=== Context Parallel Topology Debug ===")
        if parallel_dims.cp_ring > 1:
            ring_ranks = get_process_group_ranks(world_mesh.get_group("cp_ring"))
            logger.info(f"CP Ring group for rank 0: {ring_ranks}")
            logger.info(
                self._format_node_distribution(ring_ranks, ranks_per_node, world_size)
            )

        if parallel_dims.cp_ulysses > 1:
            ulysses_ranks = get_process_group_ranks(world_mesh.get_group("cp_ulysses"))
            logger.info(f"CP Ulysses group for rank 0: {ulysses_ranks}")
            logger.info(
                self._format_node_distribution(ulysses_ranks, ranks_per_node, world_size)
            )
        logger.info("=========================================")

    @staticmethod
    def _format_node_distribution(
        ranks: list[int], ranks_per_node: int, world_size: int
    ) -> str:
        """Return ``"Node distribution: Node 0=[...], Node 1=[...], ..."``.

        A rank ``r`` lives on node ``r // ranks_per_node``. Every node in
        ``range(ceil(world_size / ranks_per_node))`` is listed, even when empty.
        """
        ranks_per_node = max(1, ranks_per_node)
        num_nodes = max(1, (world_size + ranks_per_node - 1) // ranks_per_node)
        per_node = ", ".join(
            f"Node {node}={[r for r in ranks if r // ranks_per_node == node]}"
            for node in range(num_nodes)
        )
        return f"Node distribution: {per_node}"

    @staticmethod
    def _reseed_and_log_rng_state(seed_value: Optional[int]) -> None:
        """Re-seed the default torch RNGs with ``seed_value`` and log digests.

        The digests are CRC32 checksums of the full RNG state bytes, so they
        are stable across processes and runs (unlike ``hash(str)``, which is
        randomized per interpreter).

        NOTE(review): this resets the default torch generators to the raw seed
        and so overrides whatever ``dist_utils.set_determinism`` installed
        (e.g. any per-rank offset it may apply) — confirm that is intended.
        """
        import zlib

        if seed_value is None:
            logger.info("training.seed is not set; skipping post-dataloader RNG re-seed")
            return

        torch.manual_seed(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed_value)

        def _digest(state: torch.Tensor) -> int:
            # RNG states are uint8 tensors; checksum every byte of them.
            return zlib.crc32(bytes(state.cpu().tolist()))

        # Verify and log RNG states for debugging (only on rank 0)
        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            logger.info("🔍 Final RNG State Verification:")
            logger.info(f"  - PyTorch CPU RNG state crc32: {_digest(torch.get_rng_state())}")
            if torch.cuda.is_available():
                logger.info(f"  - PyTorch CUDA RNG state crc32: {_digest(torch.cuda.get_rng_state())}")
            logger.info(f"✅ All seeds verified with seed={seed_value} - exact reproducibility should be achieved")
        elif not torch.distributed.is_initialized():
            logger.info(f"🔒 Single process: Final seed verification with seed={seed_value}")

    @staticmethod
    def _device_synchronize() -> None:
        """Wait for queued work on the active accelerator so timings are accurate.

        Uses ``utils.device_module`` (as the rest of this file does) instead of
        hard-coding ``torch.cuda``.
        """
        utils.device_module.synchronize()

    def _print_config_summary(self) -> None:
        """Print the parallelism / attention-mode configuration summary to stdout."""
        dist = torch.distributed
        world_size = (
            dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
        )
        cp_ring = self.parallel_dims.cp_ring
        cp_ulysses = self.parallel_dims.cp_ulysses

        # Determine which attention mechanism is active
        if cp_ring > 1 and cp_ulysses > 1:
            attention_mode = "🔥 ATTENTION MODE: Hybrid Ulysses + Ring Attention"
        elif cp_ulysses > 1:
            attention_mode = "🚀 ATTENTION MODE: Ulysses Attention Only"
        elif cp_ring > 1:
            attention_mode = "⭕ ATTENTION MODE: Ring Attention Only"
        else:
            attention_mode = "💻 ATTENTION MODE: Standard Attention (No CP)"

        print("\n" + "=" * 80)
        print("PERFORMANCE PROFILING - CONFIGURATION SUMMARY")
        print("=" * 80)
        print(f"Context Parallel Ring Degree: {cp_ring}")
        print(f"Context Parallel Ulysses Degree: {cp_ulysses}")
        print(f"Total World Size: {world_size}")
        print(f"Tensor Parallel Degree: {self.parallel_dims.tp}")
        print(f"Data Parallel Degree: {self.parallel_dims.dp_shard * self.parallel_dims.dp_replicate}")
        print(f"Sequence Length: {self.job_config.training.seq_len}")
        print(f"Batch Size: {self.job_config.training.batch_size}")
        print(attention_mode)
        print("=" * 80 + "\n")

    def batch_generator(
        self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
        """Yield ``(input_dict, labels)`` batches with tensors moved to the device.

        For every batch, the token count (``labels.numel()``) and the time spent
        fetching it are recorded on the metrics processor.
        """
        device_type = utils.device_type
        data_iterator = iter(data_iterable)

        while True:
            # Start timing *before* the fetch so the dataloader's work is measured.
            data_load_start = time.perf_counter()
            try:
                batch = next(data_iterator)
            except StopIteration:
                # PEP 479: StopIteration must not escape a generator; just end it.
                return
            input_dict, labels = batch
            self.metrics_processor.ntokens_since_last_log += labels.numel()
            self.metrics_processor.data_loading_times.append(
                time.perf_counter() - data_load_start
            )

            # Move tensors to the appropriate device
            for k, v in input_dict.items():
                if isinstance(v, torch.Tensor):
                    input_dict[k] = v.to(device_type)
            labels = labels.to(device_type)

            yield input_dict, labels

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """Run one optimizer step (forward, backward, clip, step) and log metrics.

        Rank 0 additionally prints wall-clock timings for the forward pass (non-PP
        only), the optimizer step and the whole step. It synchronizes the device
        around each timed region so the timings include queued device work.
        """
        # Only rank 0 synchronizes and reports step timings.
        is_timing_rank = torch.distributed.get_rank() == 0
        if is_timing_rank:
            self._device_synchronize()
            step_start_time = time.perf_counter()

            # Print the configuration summary on steps 1, 11, 21, ...
            if self.step % 10 == 1:
                self._print_config_summary()

        self.optimizers.zero_grad()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        # apply context parallelism if cp is enabled
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        inputs = input_dict["input"]
        debug_memory_on_rank(0, "Debug point 13: Before context parallel setup")
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=(
                    world_mesh["cp_ulysses", "cp_ring"]
                    if parallel_dims.cp_ulysses > 1
                    else world_mesh["cp_ring"]
                ),
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )
        debug_memory_on_rank(0, "Debug point 14: After context parallel setup")

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward / backward inside step() call
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            # accumulate losses across pipeline microbatches
            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward / backward
            if is_timing_rank:
                self._device_synchronize()
                forward_start = time.perf_counter()

            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                pred = model_parts[0](inputs)

                if is_timing_rank:
                    self._device_synchronize()
                    forward_time = time.perf_counter() - forward_start
                    print(f"[STEP-TIMING] Forward pass: {forward_time*1000:.2f}ms")

                loss = self.loss_fn(pred, labels)

                # need to free pred before bwd to avoid peaking memory
                del pred

                loss.backward()

        dist_utils.clip_grad_norm_(
            [p for m in model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
        )
        self.checkpointer.maybe_wait_for_staging()

        if is_timing_rank:
            self._device_synchronize()
            optimizer_start = time.perf_counter()

        self.optimizers.step()
        self.lr_schedulers.step()

        if is_timing_rank:
            self._device_synchronize()
            optimizer_time = time.perf_counter() - optimizer_start
            total_step_time = time.perf_counter() - step_start_time
            print(f"[STEP-TIMING] Optimizer step: {optimizer_time*1000:.2f}ms")
            print(f"[STEP-TIMING] === TOTAL STEP TIME: {total_step_time*1000:.2f}ms ===\n")

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
            or self.ft_manager.enabled
        ):
            loss = loss.detach()
            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
            )
        else:
            global_avg_loss = global_max_loss = loss.detach().item()

        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)

    @record
    def train(self):
        """Run the training loop from the (possibly restored) step to ``training.steps``.

        Loads a checkpoint first, then for each batch runs GC, one train step,
        checkpoint saving and profiler stepping. After step 1 the process-group
        timeouts are reduced to ``comm.train_timeout_seconds``.
        """
        job_config = self.job_config
        debug_memory_on_rank(0, "Debug point 16: Start of train()")

        # Print configuration summary for performance analysis
        if torch.distributed.get_rank() == 0:
            self._print_config_summary()

        self.checkpointer.load(step=job_config.checkpoint.load_step)
        logger.info(f"Training starts at step {self.step + 1}.")

        with (
            maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler,
            maybe_enable_memory_snapshot(
                job_config, global_step=self.step
            ) as memory_profiler,
            ft.maybe_semi_sync_training(
                job_config,
                ft_manager=self.ft_manager,
                model=self.model_parts[0],
                optimizer=self.optimizers,
                sync_every=job_config.fault_tolerance.sync_steps,
            ),
        ):
            for inputs, labels in self.batch_generator(self.dataloader):
                if self.step >= job_config.training.steps:
                    break
                self.step += 1
                self.gc_handler.run(self.step)
                self.train_step(inputs, labels)
                self.checkpointer.save(
                    self.step, force=(self.step == job_config.training.steps)
                )

                # signal the profiler that the next profiling step has started
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                # reduce timeout after first train step for faster signal
                # (assuming lazy init and compilation are finished)
                if self.step == 1:
                    dist_utils.set_pg_timeouts(
                        timeout=timedelta(
                            seconds=job_config.comm.train_timeout_seconds
                        ),
                        world_mesh=self.world_mesh,
                    )

        if torch.distributed.get_rank() == 0:
            logger.info("Sleeping 2 seconds for other ranks to complete")
            time.sleep(2)

        self.metrics_processor.close()
        logger.info("Training completed")

    def state_dict(self) -> dict[str, Any]:
        """Return the trainer state saved under ``"train_state"`` in checkpoints."""
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Restore the trainer state produced by :meth:`state_dict`."""
        self.step = state_dict["step"]

    def close(self) -> None:
        """Close the checkpointer, if one was created."""
        checkpointer = getattr(self, "checkpointer", None)
        if checkpointer is not None:
            checkpointer.close()

def debug_on_rank(rank: int = 0, label: str = "") -> None:
    """Drop into ``breakpoint()`` on a single rank, when explicitly enabled.

    This is a no-op unless the environment variable ``TORCHTITAN_DEBUG_BREAK``
    is set to ``"1"``. That keeps stray calls harmless in normal runs, where
    hitting a breakpoint on one rank would stall every other rank in the next
    collective. When torch.distributed is not initialized, the current process
    is treated as rank 0.

    Args:
        rank: The rank to break on. Defaults to 0.
        label: Optional label for the debug point, printed before breaking.
    """
    if os.environ.get("TORCHTITAN_DEBUG_BREAK") != "1":
        return

    dist = torch.distributed
    current_rank = (
        dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    )
    if current_rank != rank:
        return

    suffix = f" ({label})" if label else ""
    print(f"Breaking on rank {rank}{suffix}")
    breakpoint()

if __name__ == "__main__":
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    trainer: Optional[Trainer] = None

    try:
        # The trainer is needed both for training and for writing a seed checkpoint.
        trainer = Trainer(config)

        if not config.checkpoint.create_seed_checkpoint:
            trainer.train()
        else:
            # A seed checkpoint must be unsharded and requires checkpointing on.
            assert int(os.environ["WORLD_SIZE"]) == 1, (
                "Must create seed checkpoint using a single device, to disable sharding."
            )
            assert config.checkpoint.enable_checkpoint, (
                "Must enable checkpointing when creating a seed checkpoint."
            )
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
    finally:
        # Release the checkpointer and the process group even when an error occurred.
        if trainer is not None:
            trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")
