# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
import time
from contextlib import ExitStack
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional

import numpy as np

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed import distributed_c10d as c10d
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.batch_warmup import BatchWarmupStrategy
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderStopIteration
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
from torchtitan.components.grad_clipping import build_grad_clipper
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.components.metrics import (
    build_metrics_processor,
    ensure_pp_loss_visible,
)
from torchtitan.config import ConfigManager, JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.distributed.sync_tensors import (
    sync_tensor_across_group,
    transfer_layer_across_pp,
)
from torchtitan.models.llama3.model.model import qk_clip_rescale_
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)


class Trainer(torch.distributed.checkpoint.stateful.Stateful):
    """Owns the components and mutable state of a single training job.

    ``__init__`` builds every component from a ``JobConfig``. The class derives
    from ``Stateful`` and registers itself with the ``CheckpointManager`` under
    the ``"train_state"`` key.

    The annotations below only declare attribute types; the values are assigned
    in ``__init__``. ``__init__`` also assigns several attributes that are not
    declared here (``dp_degree``, ``dp_rank``, ``depth_increase_step``,
    ``names``, ``grad_clipper``, ``base_loss_fn``, ``last_entropy``,
    ``maybe_enable_amp`` and, with pipeline parallelism, ``pp_schedule``).
    """

    # core configs
    job_config: JobConfig
    parallel_dims: ParallelDims
    train_spec: train_spec_module.TrainSpec

    # swappable training components in TrainSpec
    # None when the train spec provides no build_tokenizer_fn
    tokenizer: train_spec_module.BaseTokenizer | None
    dataloader: train_spec_module.BaseDataLoader
    # one entry without pipeline parallelism; whatever pipelining_fn returns with it
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer
    # only assigned when job_config.validation.enabled is true
    validator: train_spec_module.BaseValidator
    metrics_processor: train_spec_module.MetricsProcessor
    model_args: train_spec_module.BaseModelArgs

    # non-swappable training components
    checkpointer: CheckpointManager
    ft_manager: FTManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    train_context: Generator[None, None, None]
    gradient_accumulation_steps: int
    # NOTE: set to None in __init__ when training.enable_batch_warmup is false
    batch_warmup_strategy: BatchWarmupStrategy
    # the two pp_* flags are only assigned when pipeline parallelism is enabled
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # additional training states
    step: int
    ntokens_seen: int
    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: JobConfig):
        """Build all training components described by ``job_config``.

        In order, this selects the local device from ``LOCAL_RANK``, initializes
        distributed communication and the parallelism meshes, builds the
        tokenizer, dataloader and a meta-device model, applies the model
        converters, derives the gradient-accumulation step count, parallelizes
        and initializes the model (pipelined or not), optionally distributes and
        applies "compression" vectors to the model parts, and finally creates
        the optimizers, LR schedulers, checkpoint manager, training context and
        (if enabled) the validator.

        Args:
            job_config: Job configuration that drives every choice below.

        Raises:
            RuntimeError: If pipeline parallelism is enabled but the train spec
                provides no ``pipelining_fn``.
            AssertionError: If the global batch size is not a positive multiple
                of ``local_batch_size * dp_degree``, or if validation is enabled
                while the train spec has no ``build_validator_fn``.
            KeyError: If ``LOCAL_RANK`` or ``WORLD_SIZE`` is missing from the
                environment.
        """
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        logger.info(f"Starting job: {job_config.job.description}")

        # Importing the module is done purely for its side effects.
        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        if job_config.job.print_args:
            logger.info(f"Running with args: {job_config.to_dict()}")

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(
            job_config.comm,
            enable_cpu_backend=job_config.training.enable_cpu_offload,
            base_folder=job_config.job.dump_folder,
        )
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        # The fault-tolerance manager may override the data-parallel degree/rank.
        self.ft_manager = FTManager(job_config.fault_tolerance)
        dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

        # build tokenizer and dataloader
        self.tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

      
        # Kept on the instance so the dataloader can be rebuilt later
        # (see _set_dataloader_seq_len).
        self.dp_degree = dp_degree
        self.dp_rank = dp_rank
        # NOTE(review): hard-coded value and empty list; neither is read in the
        # visible part of this file - confirm where they are consumed.
        self.depth_increase_step = 2020
        self.names = []
        self.dataloader = self.train_spec.build_dataloader_fn(
            dp_world_size=dp_degree,
            dp_rank=dp_rank,
            tokenizer=self.tokenizer,
            job_config=job_config,
        )

        # build model (using meta init)
        model_args = self.train_spec.model_args[job_config.model.flavor]
        # set the model args from training job configs
        model_args.update_from_config(job_config)
        self.model_args = model_args

        logger.info(
            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
        )
        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # metrics logging
        # A train-spec-specific builder takes precedence over the default one.
        build_metrics_processor_fn = (
            build_metrics_processor
            if self.train_spec.build_metrics_processor_fn is None
            else self.train_spec.build_metrics_processor_fn
        )
        self.metrics_processor = build_metrics_processor_fn(
            job_config, parallel_dims, model_args
        )
        color = self.metrics_processor.color

        # calculate model size and flops per token
        (
            model_param_count,
            self.metrics_processor.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        logger.info(
            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
        )

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        # From here on self.loss_fn is the wrapped (rescaled) loss function.
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # Build gradient clipper from config
        self.grad_clipper = build_grad_clipper(job_config.gradient_clipping)
        logger.info(
            f"Gradient clipper initialized: method={job_config.gradient_clipping.method} "
            f"scope={job_config.gradient_clipping.scope}"
        )

        if self.job_config.training.enable_batch_warmup:
            # Initialize batch warmup strategy
            self.batch_warmup_strategy = BatchWarmupStrategy(
                job_config=job_config,
                parallel_dims=parallel_dims,
                target_gradient_accumulation_steps=self.gradient_accumulation_steps,
                dp_degree=dp_degree,
            )
        else:
            self.batch_warmup_strategy = None

        # Store the base loss function for dynamic scaling
        # We'll apply dynamic scaling in train_step where we have access to self.step
        # NOTE(review): base_loss_fn is the already-rescaled self.loss_fn from
        # above, and get_dynamic_loss_fn divides by the current accumulation
        # steps again - confirm this is not an unintended double scaling.
        self.base_loss_fn = self.loss_fn

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            # Materialize each part on the init device, then initialize weights.
            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
            # Compute and distribute compression vectors if compression is enabled
            if (
                hasattr(self.model_parts[0], "model_args")
                and hasattr(self.model_parts[0].model_args, "use_compression")
                and self.model_parts[0].model_args.use_compression
            ):

                # Get mesh information
                pp_mesh = world_mesh["pp"]

                # Get DP shard mesh if available
                dp_shard_mesh = None
                if "dp_shard" in world_mesh.mesh_dim_names:
                    dp_shard_mesh = world_mesh["dp_shard"]

                # Get PP group and rank information
                pp_group = pp_mesh.get_group()
                pp_rank = pp_mesh.get_local_rank()
                pp_size = pp_mesh.size()

                # Define source PP rank
                src_pp_rank = pp_size - 1

                logger.info(
                    f"PP Rank {pp_rank}/{pp_size}: Setting up compression vectors"
                )

                # NOTE(review): although the last PP rank is the designated
                # source, nothing here is guarded by a rank check, so every PP
                # rank computes vectors from its own last local model part.
                logger.info(f"Computing compression vectors on PP rank {pp_rank}")

                # Get the last model part
                final_model_part = self.model_parts[-1]

                # Compute compression vectors
                try:
                    rcv, _ = final_model_part.get_rcv()
                    rcv = rcv.contiguous()
                except Exception as e:
                    logger.error(f"Error computing RCV: {str(e)}")
                    raise

                # Get fixed embeddings if available
                fixed_embeddings = None
                if hasattr(final_model_part, "fixed_tok_embeddings"):
                    fixed_embeddings = (
                        final_model_part.fixed_tok_embeddings.weight.data.contiguous()
                    )
                else:
                    logger.warning("No fixed embeddings found")

                # Synchronize to ensure computation is complete
                dist.barrier()

                # Distribute compression vectors to all PP ranks
                rcv = transfer_layer_across_pp(
                    rcv, src_pp_rank, pp_group, dp_shard_mesh
                )

                # Distribute fixed embeddings if available
                # NOTE(review): ranks without local fixed embeddings (other than
                # the source) skip this call; if transfer_layer_across_pp is a
                # collective, that could leave ranks out of step - confirm.
                if fixed_embeddings is not None or pp_rank == src_pp_rank:
                    try:
                        fixed_embeddings = transfer_layer_across_pp(
                            fixed_embeddings, src_pp_rank, pp_group, dp_shard_mesh
                        )
                        if fixed_embeddings is None:
                            logger.warning(
                                f"PP Rank {pp_rank}: Fixed embeddings is None after transfer"
                            )
                    except Exception as e:
                        logger.error(
                            f"PP Rank {pp_rank}: Error transferring fixed embeddings: {str(e)}"
                        )
                        raise

                # Synchronize to ensure computation is complete
                dist.barrier()

                # Apply compression vectors to all model parts
                for m in self.model_parts:
                    with torch.no_grad():
                        if rcv is None:
                            logger.error(
                                f"PP Rank {pp_rank}: RCV is None, cannot apply compression"
                            )
                            # Skips the rest of this part, including m.train().
                            continue

                        # If rcv is registered as a parameter, copy into its storage
                        # (this inner no_grad is redundant: we are already inside one)
                        with torch.no_grad():
                            if hasattr(m, "rcv") and isinstance(
                                m.rcv, torch.nn.Parameter
                            ):
                                m.rcv.data.copy_(rcv)
                            else:
                                m.rcv = rcv.clone()

                        # Apply regularization
                        try:
                            m.regularize_weights()
                        except Exception as e:
                            logger.error(
                                f"PP Rank {pp_rank}: Error regularizing weights: {str(e)}"
                            )
                            raise

                        # Apply embedding regularization if needed
                        if hasattr(m, "regularize_embeddings"):
                            try:
                                m.regularize_embeddings()
                            except Exception as e:
                                logger.error(
                                    f"PP Rank {pp_rank}: Error regularizing embeddings: {str(e)}"
                                )
                                raise

                        # Copy fixed embeddings if available
                        if fixed_embeddings is not None and hasattr(
                            m, "copy_embedding_weights"
                        ):
                            try:
                                m.copy_embedding_weights(fixed_embeddings)
                            except Exception as e:
                                logger.error(
                                    f"PP Rank {pp_rank}: Error copying fixed embeddings: {str(e)}"
                                )
                                raise
                    m.train()

                logger.info(
                    f"PP Rank {pp_rank}: Successfully initialized compression vectors"
                )

            # confirm that user will be able to view loss metrics on the console
            ensure_pp_loss_visible(parallel_dims, job_config, color)
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)

            model.train()

            self.model_parts = [model]

            # Compute and apply compression vectors if compression is enabled
            if (
                hasattr(model, "model_args")
                and hasattr(model.model_args, "use_compression")
                and model.model_args.use_compression
            ):
                logger.info("Setting up compression vectors for non-PP model")

                # Get DP shard mesh if available
                dp_shard_mesh = None
                if "dp_shard" in world_mesh.mesh_dim_names:
                    dp_shard_mesh = world_mesh["dp_shard"]

                # Get global rank for source designation
                global_rank = dist.get_rank()
                src_rank = 0  # Use rank 0 as source

                # Compute compression vectors on source rank only
                rcv = None
                fixed_embeddings = None
                if global_rank == src_rank:
                    try:
                        logger.info(
                            f"Rank {global_rank}: Computing compression vectors"
                        )
                        rcv, _ = model.get_rcv()
                        rcv = rcv.contiguous()
                    except Exception as e:
                        logger.error(
                            f"Rank {global_rank}: Error computing RCV: {str(e)}"
                        )
                        raise

                    # Get fixed embeddings if available
                    if hasattr(model, "fixed_tok_embeddings"):
                        fixed_embeddings = (
                            model.fixed_tok_embeddings.weight.data.contiguous()
                        )
                    else:
                        logger.warning(f"Rank {global_rank}: No fixed embeddings found")

                # Synchronize RCV across all ranks using the reusable sync function
                logger.info(f"Rank {global_rank}: Synchronizing RCV across all ranks")
                rcv = sync_tensor_across_group(
                    tensor=rcv,
                    src_rank=src_rank,
                    device=self.device,
                    shard_mesh=dp_shard_mesh,
                )

                # Synchronize fixed embeddings across all ranks if available on source
                has_fixed_embeddings = fixed_embeddings is not None
                if global_rank == src_rank:
                    logger.info(
                        f"Rank {global_rank}: Source has fixed embeddings: {has_fixed_embeddings}"
                    )

                # Broadcast whether fixed embeddings are available
                # (every rank then takes the same branch below)
                has_embeddings_tensor = torch.tensor(
                    int(has_fixed_embeddings), device=self.device
                )
                dist.broadcast(has_embeddings_tensor, src=src_rank)

                # Only sync if source rank actually has fixed embeddings
                if has_embeddings_tensor.item():
                    fixed_embeddings = sync_tensor_across_group(
                        tensor=fixed_embeddings,
                        src_rank=src_rank,
                        device=self.device,
                        shard_mesh=dp_shard_mesh,
                    )
                    logger.info(f"Rank {global_rank}: Synchronized fixed embeddings")
                else:
                    fixed_embeddings = None
                    logger.info(
                        f"Rank {global_rank}: No fixed embeddings available from source"
                    )

                # Apply compression vectors to the model
                with torch.no_grad():
                    if rcv is None:
                        # Logged only; initialization continues without compression.
                        logger.error(
                            f"Rank {global_rank}: RCV is None, cannot apply compression"
                        )
                    else:
                        # (this inner no_grad is redundant: we are already inside one)
                        with torch.no_grad():
                            if hasattr(model, "rcv") and isinstance(
                                model.rcv, torch.nn.Parameter
                            ):
                                model.rcv.data.copy_(rcv)
                            else:
                                model.rcv = rcv.clone()

                        # Apply regularization
                        try:
                            model.regularize_weights()
                        except Exception as e:
                            logger.error(
                                f"Rank {global_rank}: Error regularizing weights: {str(e)}"
                            )
                            raise

                        # Apply embedding regularization if needed
                        if hasattr(model, "regularize_embeddings"):
                            try:
                                model.regularize_embeddings()
                            except Exception as e:
                                logger.error(
                                    f"Rank {global_rank}: Error regularizing embeddings: {str(e)}"
                                )
                                raise

                        # Copy fixed embeddings if available
                        if fixed_embeddings is not None and hasattr(
                            model, "copy_embedding_weights"
                        ):
                            try:
                                model.copy_embedding_weights(fixed_embeddings)
                            except Exception as e:
                                logger.error(
                                    f"Rank {global_rank}: Error copying fixed embeddings: {str(e)}"
                                )
                                raise

                model.train()

                logger.info(
                    f"Rank {global_rank}: Successfully initialized compression vectors for non-PP model"
                )

        self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)

        # initialize device memory monitor and get peak flops for MFU calculation
        device_memory_monitor = self.metrics_processor.device_memory_monitor
        gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
        logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
        device_mem_stats = device_memory_monitor.get_peak_stats()
        logger.info(
            f"{device_type.upper()} memory usage for model: "
            f"{device_mem_stats.max_reserved_gib:.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)"
        )

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config.lr_scheduler, job_config.training.steps
        )
        # Post optimizer step model converters hook.
        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )
        self.metrics_processor.optimizers = self.optimizers

        # Initialize trainer states that will be saved in checkpoint.
        # These attributes must be initialized before checkpoint loading.
        self.step = 0
        self.ntokens_seen = 0
        self.last_entropy = 0.0  # Track entropy of final model logits

        # Optional extra state; only checkpointed when the first model part
        # exposes a `_powersgd_container` attribute.
        _powersgd_container = getattr(self.model_parts[0], "_powersgd_container", None)

        self.checkpointer = CheckpointManager(
            dataloader=self.dataloader,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            # Entries whose value is None are dropped from the states dict.
            states={
                k: v
                for k, v in {
                    "train_state": self,
                    "powersgd": _powersgd_container,
                }.items()
                if v is not None
            },
            checkpoint_config=job_config.checkpoint,
            sd_adapter=(
                self.train_spec.state_dict_adapter(model_args)
                if self.train_spec.state_dict_adapter
                else None
            ),
            base_folder=job_config.job.dump_folder,
            ft_manager=self.ft_manager,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

        # Build validator if validation is configured
        if job_config.validation.enabled:
            assert self.train_spec.build_validator_fn is not None

            # The pipeline attributes only exist when PP is enabled.
            pp_schedule, pp_has_first_stage, pp_has_last_stage = (
                (
                    self.pp_schedule,
                    self.pp_has_first_stage,
                    self.pp_has_last_stage,
                )
                if parallel_dims.pp_enabled
                else (None, None, None)
            )

            # The validator gets a freshly built loss function, not the
            # accumulation-rescaled self.loss_fn.
            self.validator = self.train_spec.build_validator_fn(
                job_config=job_config,
                dp_world_size=dp_degree,
                dp_rank=dp_rank,
                tokenizer=self.tokenizer,
                parallel_dims=parallel_dims,
                loss_fn=self.train_spec.build_loss_fn(job_config),
                validation_context=self.train_context,
                maybe_enable_amp=self.maybe_enable_amp,
                metrics_processor=self.metrics_processor,
                pp_schedule=pp_schedule,
                pp_has_first_stage=pp_has_first_stage,
                pp_has_last_stage=pp_has_last_stage,
            )

        # Log batch warmup information if enabled
        warmup_info = ""
        if job_config.training.enable_batch_warmup:
            warmup_info = (
                f", batch warmup enabled ({job_config.training.batch_warmup_strategy} "
                f"strategy for {job_config.training.batch_warmup_steps} steps, "
                f"start ratio {job_config.training.batch_warmup_start_ratio})"
            )

        logger.info(
            "Trainer is initialized with "
            f"local batch size {job_config.training.local_batch_size}, "
            f"global batch size {global_batch_size}, "
            f"gradient accumulation steps {self.gradient_accumulation_steps}, "
            f"sequence length {job_config.training.seq_len}, "
            f"total steps {job_config.training.steps} "
            f"(warmup {job_config.lr_scheduler.warmup_steps})"
            f"{warmup_info}"
        )

    def enable_layers_all_weights(self, m, names, layers):
        """Call ``enable_all_layer_weights`` on every entry of ``self.model_parts``.

        Args:
            m: Not used. It is kept only so existing call sites keep working;
                the loop always walks ``self.model_parts``.
            names: Forwarded to each model part as ``name_list``.
            layers: Forwarded to each model part as ``layer_ids``.
        """
        for model_part in self.model_parts:
            model_part.enable_all_layer_weights(name_list=names, layer_ids=layers)

    def _set_dataloader_seq_len(self, new_seq_len: int):
    # Update config so newly built components know the new target
        self.job_config.training.seq_len = new_seq_len

        # If your dataloader exposes a setter, prefer that (no worker respawn):

        self.dataloader = self.train_spec.build_dataloader_fn(
            dp_world_size=self.dp_degree,
            dp_rank=self.dp_rank,
            tokenizer=self.tokenizer,
            job_config=self.job_config,
        )
    def set_sequence_length(self, new_seq_len: int, step: int):
        self._set_dataloader_seq_len(new_seq_len=new_seq_len)
        data_iterator = self.batch_generator(self.dataloader)
        
        current_grad_accum_steps = self.gradient_accumulation_steps

        
        for i in range(step):
            for microbatch in range(current_grad_accum_steps):
                input_dict, labels = next(data_iterator)

        return data_iterator  

    def batch_generator(
        self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
        """Yield ``(input_dict, labels)`` batches moved to the training device.

        For every batch pulled from ``data_iterable`` this also updates
        ``self.ntokens_seen`` and the metrics processor's token counter with
        ``labels.numel()``, and records how long fetching the batch took.

        Args:
            data_iterable: Source of ``(input_dict, labels)`` pairs.

        Yields:
            The same ``input_dict`` object, with its tensor values replaced by
            device copies, together with ``labels`` on the device.

        Raises:
            DataloaderStopIteration: When ``data_iterable`` is exhausted.
        """
        target_device = utils.device_type
        source = iter(data_iterable)

        while True:
            fetch_started_at = time.perf_counter()
            try:
                fetched = next(source)
            except StopIteration as ex:
                # If data runs out during gradient accumulation, that
                # entire step will not be executed.
                raise DataloaderStopIteration() from ex
            input_dict, labels = fetched

            # Bookkeeping: token counts first, then the fetch duration.
            token_count = labels.numel()
            self.ntokens_seen += token_count
            self.metrics_processor.ntokens_since_last_log += token_count
            fetch_times = self.metrics_processor.data_loading_times
            fetch_times.append(time.perf_counter() - fetch_started_at)

            # Move tensors to the appropriate device; non-tensor values are
            # left untouched and the dict is updated in place.
            for key, value in input_dict.items():
                if not isinstance(value, torch.Tensor):
                    continue
                input_dict[key] = value.to(target_device)
            labels = labels.to(target_device)

            yield input_dict, labels

    def _maybe_align_dataloader_to_tokens(self) -> None:
        """Optionally fast-forward the dataloader iterator to match target token offset.

        Two ways to enable:
        - Set env TT_ALIGN_DATA_BY_TOKENS=1 to compute skip microbatches from ntokens_seen
        - Or set env TT_RESUME_SKIP_MICROBATCHES=<int> to explicitly skip that many microbatches
        """
        import os

        align_enabled = os.environ.get("TT_ALIGN_DATA_BY_TOKENS", "0") == "1"
        explicit_skip = os.environ.get("TT_RESUME_SKIP_MICROBATCHES")
        if not align_enabled and explicit_skip is None:
            return

        if not hasattr(self, "dataloader") or self.dataloader is None:
            return

        try:
            seq_len = self.job_config.training.seq_len
            world_mesh = self.parallel_dims.world_mesh
            dp_degree = world_mesh["dp"].size() if self.parallel_dims.dp_enabled else 1
            local_bsz = self.job_config.training.local_batch_size
            tokens_per_microbatch = local_bsz * dp_degree * seq_len

            if tokens_per_microbatch <= 0:
                return

            if explicit_skip is not None:
                skip_microbatches = int(explicit_skip)
            else:
                skip_microbatches = int(self.ntokens_seen // tokens_per_microbatch)

            if skip_microbatches <= 0:
                return

            logger.info(
                f"Align dataloader: skipping {skip_microbatches} microbatches "
                f"(tokens_per_microbatch={tokens_per_microbatch}, ntokens_seen={self.ntokens_seen})"
            )

            it = iter(self.dataloader)
            for _ in range(skip_microbatches):
                try:
                    next(it)
                except StopIteration:
                    logger.warning("Ran out of data while aligning; stop skipping.")
                    break
        except Exception as ex:
            logger.warning(f"Failed aligning dataloader to tokens: {ex}")

    def get_dynamic_loss_fn(self):
        """Return a loss callable scaled by the current gradient-accumulation count.

        The returned closure reads ``self.step`` and ``self.batch_warmup_strategy``
        each time it is called, so it always divides by the accumulation count
        the strategy reports for the step in progress at call time.
        """

        def dynamic_scaled_loss(predictions, targets):
            unscaled = self.base_loss_fn(predictions, targets)
            accumulation_steps = (
                self.batch_warmup_strategy.get_gradient_accumulation_steps(self.step)
            )
            # Dividing by the accumulation count keeps the summed gradient
            # comparable between warmup and steady state.
            return unscaled / accumulation_steps

        return dynamic_scaled_loss

    def forward_backward_step(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Run forward and backward for a single microbatch and return its loss.

        Args:
            input_dict: Batch inputs; only the ``"input"`` entry is read here.
            labels: Targets passed to the loss function, or to the pipeline
                schedule on ranks holding the last pipeline stage.

        Returns:
            The microbatch loss. With pipeline parallelism this is the mean of
            the losses collected by the schedule, or a ``[-1.0]`` placeholder
            tensor on ranks that do not hold the last stage.

        Side effects:
            Overwrites ``self.last_entropy`` (always ``0.0`` under pipeline
            parallelism). In the non-PP path ``loss.backward()`` is called here;
            in the PP path forward/backward run inside ``pp_schedule.step``.
        """
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        if self.job_config.training.enable_batch_warmup:
            current_loss_fn = self.get_dynamic_loss_fn()
        else:
            current_loss_fn = self.loss_fn

        # apply context parallelism if cp is enabled
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        inputs = input_dict["input"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Swap in the dynamically scaled loss on the schedule. The closure
            # queries the warmup strategy itself on every call, so there is no
            # need to fetch the accumulation count here.
            # NOTE(review): this branch is gated on batch_warmup_strategy.enabled
            # while the non-PP path uses job_config.training.enable_batch_warmup;
            # confirm the two flags always agree.
            if self.batch_warmup_strategy.enabled:
                self.pp_schedule.loss_fn = self.get_dynamic_loss_fn()

            # Pipeline Parallel forward / backward inside step() call
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            # accumulate losses across pipeline microbatches
            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )

            # TODO: Entropy calculation for PP is not yet implemented
            # The forward pass happens inside the PP schedule, making it difficult to access logits
            self.last_entropy = 0.0
        else:
            # Non-PP forward / backward
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    model = model_parts[0]
                    # The entropy flag lives on the model args of the wrapped
                    # module when a ``module`` attribute is present.
                    should_compute_entropy = getattr(
                        model, "module", model
                    ).model_args.log_final_entropy

                    if should_compute_entropy:
                        pred, entropy = model(
                            inputs, eos_id=self.tokenizer.eos_id, return_entropy=True
                        )
                        self.last_entropy = (
                            entropy.item() if entropy is not None else 0.0
                        )
                    else:
                        pred = model(
                            inputs, eos_id=self.tokenizer.eos_id, return_entropy=False
                        )
                        self.last_entropy = 0.0

                    loss = current_loss_fn(pred, labels)
                # Drop the predictions before backward to avoid a memory peak.
                del pred
                loss.backward()

        return loss

    def train_step(
        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ) -> None:
        """Run one optimizer step: accumulate gradients, clip, update, and log.

        Sequence: zero grads, apply the layer-enabling schedule, reset per-step
        diagnostics, regularize weights, barrier, run
        ``forward_backward_step`` for each accumulation microbatch, clip
        gradients, step the optimizers and LR schedulers, apply QK-clip
        rescaling, and (on logging steps only) reduce and log metrics.

        Args:
            data_iterator: Iterator yielding ``(input_dict, labels)`` microbatches.

        Raises:
            Propagates whatever ``next(data_iterator)`` raises; if that happens
            mid-accumulation the optimizer step for this call is not executed.
        """
        self.optimizers.zero_grad()

        self._apply_progressive_depth_schedule()
        self._reset_per_step_diagnostics()

        # Read the LR before the schedulers advance below so the logged value
        # is the one reported for this step's update.
        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]

        # Regularize weights and embeddings
        with torch.no_grad():
            for m in self.model_parts:
                m.regularize_weights()
                if m.tok_embeddings is not None:
                    m.regularize_embeddings()

        dist.barrier()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        parallel_dims = self.parallel_dims
        clip_cfg = self.job_config.gradient_clipping

        if self.job_config.training.enable_batch_warmup:
            # Get dynamic gradient accumulation steps for warmup
            current_grad_accum_steps = (
                self.batch_warmup_strategy.get_gradient_accumulation_steps(self.step)
            )
        else:
            current_grad_accum_steps = self.gradient_accumulation_steps

        accumulated_losses = []
        # If data runs out during gradient accumulation, that
        # entire step will not be executed.
        for microbatch in range(current_grad_accum_steps):
            input_dict, labels = next(data_iterator)
            is_accum = microbatch < current_grad_accum_steps - 1
            if is_accum:
                # Suppress gradient synchronization on all but the last microbatch.
                with ExitStack() as stack:
                    for model_part in self.model_parts:
                        # Classic DDP provides no_sync context manager
                        if hasattr(model_part, "no_sync"):
                            stack.enter_context(model_part.no_sync())
                        # Composable DDP (replicate) exposes set_requires_gradient_sync
                        elif hasattr(model_part, "set_requires_gradient_sync"):
                            model_part.set_requires_gradient_sync(False)

                    loss = self.forward_backward_step(input_dict, labels)
            else:
                # Ensure composable DDP syncs on the final accumulation microbatch
                for model_part in self.model_parts:
                    if hasattr(model_part, "set_requires_gradient_sync"):
                        model_part.set_requires_gradient_sync(True)
                loss = self.forward_backward_step(input_dict, labels)
            accumulated_losses.append(loss.detach())

        should_log_this_step = self.metrics_processor.should_log(self.step)

        # Log gradient clipping details on first step
        if self.step == 1:
            last_layer_info = (
                f", max_norm_last_layer={clip_cfg.max_norm_last_layer}"
                if clip_cfg.max_norm_last_layer is not None
                else ""
            )
            logger.info(
                f"Applying gradient clipping (step {self.step}): "
                f"method={clip_cfg.method}, "
                f"max_norm={clip_cfg.max_norm}{last_layer_info}"
            )

        # Pre/post-clip per-layer norms are only gathered on logging steps and
        # only for global vanilla clipping with a clipper that supports it.
        collect_layer_norms = (
            hasattr(self.grad_clipper, "compute_per_layer_norms")
            and clip_cfg.method == "vanilla"
            and clip_cfg.scope == "global"
            and should_log_this_step
        )

        preclip_layer_norms: dict[str, float] | None = None
        postclip_layer_norms: dict[str, float] | None = None
        if collect_layer_norms:
            preclip_layer_norms = self._try_compute_per_layer_norms(parallel_dims)

        grad_norm = self.grad_clipper.step(
            self.model_parts,
            parallel_dims,
            clip_cfg.max_norm,
            clip_cfg.max_norm_last_layer,
        )

        if collect_layer_norms:
            postclip_layer_norms = self._try_compute_per_layer_norms(parallel_dims)

        # Flatten every available per-layer norm source into one metrics dict;
        # the key prefix identifies where each value came from.
        per_layer_grad_norms = {}
        if should_log_this_step and hasattr(self.grad_clipper, "get_per_layer_norms"):
            for layer_key, norm_value in self.grad_clipper.get_per_layer_norms().items():
                per_layer_grad_norms[
                    f"grad_clip_{layer_key}"
                ] = self._layer_norm_to_float(norm_value)
        if preclip_layer_norms is not None:
            for layer_key, norm_value in preclip_layer_norms.items():
                per_layer_grad_norms[
                    f"grad_clip_pre_{layer_key}"
                ] = self._layer_norm_to_float(norm_value)
        if postclip_layer_norms is not None:
            for layer_key, norm_value in postclip_layer_norms.items():
                per_layer_grad_norms[
                    f"grad_clip_post_{layer_key}"
                ] = self._layer_norm_to_float(norm_value)

        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

        # Apply QK-clip rescaling if enabled. The returned stats are currently
        # not forwarded to the metrics processor.
        qk_clip_stats = self._apply_qk_clip_rescaling()  # noqa: F841

        # log metrics
        if not should_log_this_step:
            return

        # Reduce the data collected over gradient accumulation steps.
        loss = torch.sum(torch.stack(accumulated_losses))

        psgd_metrics = self._collect_powersgd_metrics()

        if parallel_dims.dp_cp_enabled:
            loss = loss.detach()
            ft_pg = self.ft_manager.loss_sync_pg
            global_avg_loss, global_max_loss, global_ntokens_seen = (
                dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                dist_utils.dist_sum(
                    torch.tensor(
                        self.ntokens_seen, dtype=torch.int64, device=self.device
                    ),
                    parallel_dims.world_mesh["dp_cp"],
                    ft_pg,
                ),
            )
        else:
            global_avg_loss = global_max_loss = loss.detach().item()
            global_ntokens_seen = self.ntokens_seen

        # The warmup strategy is still queried on logging steps, but its result
        # is currently not forwarded to the metrics processor.
        if self.job_config.training.enable_batch_warmup:
            warmup_info = self.batch_warmup_strategy.log_warmup_info(self.step)  # noqa: F841
        else:
            warmup_info = {}  # noqa: F841

        # Currently disabled metric groups: final-logits entropy
        # (self.last_entropy), warmup_info, qk_clip_stats, compression energy
        # stats and w2 stable-rank stats.
        extra_metrics = {
            "n_tokens_seen": global_ntokens_seen,
            "lr": lr,
            **per_layer_grad_norms,  # Include per-layer gradient norm metrics
            **psgd_metrics,
        }
        self.metrics_processor.log(
            self.step,
            global_avg_loss,
            global_max_loss,
            grad_norm.item(),
            extra_metrics=extra_metrics,
        )

    def _apply_progressive_depth_schedule(self) -> None:
        """Apply the hard-coded layer-enabling schedule for the current step.

        At step 1, ``disable_trainable_weights`` is called for layers 8..16 on
        every model part and ``enable_layers_all_weights`` for layers 0..7. At
        fixed offsets from ``self.depth_increase_step`` a longer prefix of
        layers is passed to ``enable_layers_all_weights``.
        """
        if self.step == 1:
            for m in self.model_parts:
                # Only the value returned for the last model part is kept.
                self.names = m.disable_trainable_weights(list(range(8, 16 + 1)))
            logger.info(f"named params {self.names}")
            self.enable_layers_all_weights(
                self.model_parts, self.names, list(range(0, 8))
            )
            return

        # NOTE(review): self.names is only assigned at step 1 in this method;
        # a run resumed at a later step presumably needs it restored elsewhere
        # -- confirm.
        # Offset from depth_increase_step -> number of leading layers enabled.
        enabled_layers_by_offset = {0: 9, 250: 10, 500: 12, 750: 14, 1000: 16 + 1}
        n_enabled = enabled_layers_by_offset.get(self.step - self.depth_increase_step)
        if n_enabled is not None:
            self.enable_layers_all_weights(
                self.model_parts, self.names, list(range(0, n_enabled))
            )

    def _reset_per_step_diagnostics(self) -> None:
        """Reset per-step model diagnostics so they do not accumulate across steps."""
        for model_part in self.model_parts:
            # Walk through any wrapper attributes to reach the innermost model.
            target_model = model_part
            for attr in ("module", "_ddp_wrapped_module", "_orig_mod", "model"):
                if hasattr(target_model, attr):
                    target_model = getattr(target_model, attr)

            if hasattr(target_model, "reset_compression_energy_stats"):
                target_model.reset_compression_energy_stats()

            if hasattr(target_model, "reset_w2_stable_rank_stats"):
                target_model.reset_w2_stable_rank_stats()

            # Enable attention metrics collection for QK-clip if needed
            if (
                hasattr(target_model, "model_args")
                and hasattr(target_model.model_args, "use_qk_clip")
                and target_model.model_args.use_qk_clip
                and hasattr(target_model, "enable_attention_metrics_collection")
            ):
                target_model.clear_attention_metrics()
                target_model.enable_attention_metrics_collection(True)

    def _try_compute_per_layer_norms(self, parallel_dims):
        """Return the clipper's per-layer norms, or ``None`` if computing them fails."""
        try:
            return self.grad_clipper.compute_per_layer_norms(
                self.model_parts, parallel_dims
            )
        except Exception:
            # Diagnostics only: never let them break the training step.
            logger.debug("Per-layer gradient norm computation failed", exc_info=True)
            return None

    @staticmethod
    def _layer_norm_to_float(value):
        """Convert a tensor norm to a Python float; pass other values through.

        Single-element tensors yield their value, larger tensors their mean.
        """
        if not isinstance(value, torch.Tensor):
            return value
        value = value.detach().float().cpu()
        return value.item() if value.numel() == 1 else value.mean().item()

    def _collect_powersgd_metrics(self) -> dict[str, float]:
        """Build loggable metrics from the PowerSGD state's ``latest_error_norms``.

        Returns an empty dict when the first model part has no PowerSGD
        container/state, when no norms are recorded, or when anything goes
        wrong while reading them.
        """
        err_prefix = "psgd_errnorm/"
        cp_prefix = "psgd_input_cp_norm/"
        try:
            container = getattr(self.model_parts[0], "_powersgd_container", None)
            state = getattr(container, "powersgd_state", None)
            if state is None:
                return {}
            latest = getattr(state, "latest_error_norms", None)
            if not isinstance(latest, dict) or len(latest) == 0:
                return {}

            # Separate categories to avoid one crowding out the other
            err_items = sorted(
                (k, v) for k, v in latest.items() if str(k).startswith(err_prefix)
            )
            cp_items = sorted(
                (k, v) for k, v in latest.items() if str(k).startswith(cp_prefix)
            )
            other_items = sorted(
                (k, v)
                for k, v in latest.items()
                if not str(k).startswith((err_prefix, cp_prefix))
            )

            metrics: dict[str, float] = {}

            # Summary stats per category
            for stat_name, items in (
                ("psgd_errnorm", err_items),
                ("psgd_input_cp", cp_items),
            ):
                if items:
                    vals = [float(v) for _, v in items]
                    metrics[f"{stat_name}_count"] = float(len(vals))
                    metrics[f"{stat_name}_mean"] = float(np.mean(vals))
                    metrics[f"{stat_name}_max"] = float(np.max(vals))

            # Individual entries: three fixed bucket keys are renamed, keys
            # under "<prefix>layers" are kept, everything else is dropped.
            for prefix, items, cap in (
                (err_prefix, err_items, 1024),
                (cp_prefix, cp_items, 1024),
            ):
                aliases = {
                    f"{prefix}b0_t0": f"{prefix}tok_embeddings.weight",
                    f"{prefix}b12_t3": f"{prefix}output.weight",
                    f"{prefix}b13_t1": f"{prefix}norm.weight",
                }
                for k, v in items[:cap]:
                    key = str(k)
                    if key in aliases:
                        name = aliases[key]
                    elif key.startswith(f"{prefix}layers"):
                        name = key
                    else:
                        continue
                    metrics[name] = float(v)

            for k, v in other_items[:64]:
                metrics[str(k)] = float(v)
            return metrics
        except Exception:
            # Diagnostics only: drop all PowerSGD metrics rather than fail the step.
            logger.debug("Collecting PowerSGD metrics failed", exc_info=True)
            return {}

    @record
    def train(self):
        """Load the checkpoint, repair PowerSGD state, and run the training loop.

        The loop increments ``self.step`` until ``job_config.training.steps``
        is reached or the dataloader runs out (``DataloaderStopIteration``).
        Each iteration runs ``train_step``, optional validation, checkpoint
        saving and profiler stepping. After the first step the process-group
        timeout is lowered to ``job_config.comm.train_timeout_seconds``.
        """
        job_config = self.job_config

        self.checkpointer.load(step=job_config.checkpoint.load_step)
        self._restore_powersgd_state_after_load()

        logger.info(f"Training starts at step {self.step + 1}")

        leaf_folder = (
            f"replica_{self.ft_manager.replica_id}" if self.ft_manager.enabled else ""
        )

        # Optionally align dataloader by skipping microbatches to match token offset on resume.
        # Enable by setting env var TT_ALIGN_DATA_BY_TOKENS=1 (or TT_RESUME_SKIP_MICROBATCHES to an integer).
        try:
            self._maybe_align_dataloader_to_tokens()
        except Exception as _ex:
            logger.warning(f"Data alignment by tokens failed: {_ex}")

        with (
            maybe_enable_profiling(
                job_config.profiling,
                global_step=self.step,
                base_folder=job_config.job.dump_folder,
                leaf_folder=leaf_folder,
            ) as torch_profiler,
            maybe_enable_memory_snapshot(
                job_config.profiling,
                global_step=self.step,
                base_folder=job_config.job.dump_folder,
                leaf_folder=leaf_folder,
            ) as memory_profiler,
            maybe_semi_sync_training(
                job_config.fault_tolerance,
                ft_manager=self.ft_manager,
                model=self.model_parts[0],
                n_layers=getattr(self.model_args, "n_layers", 0),
                optimizer=self.optimizers,
                fragment_fn=getattr(self.train_spec, "fragment_fn", None),
            ),
        ):
            data_iterator = self.batch_generator(self.dataloader)
            while self.step < job_config.training.steps:
                self.step += 1
                if self.step == self.depth_increase_step:
                    # From this step on, batches come from the iterator that
                    # set_sequence_length returns for half the configured seq_len.
                    data_iterator = self.set_sequence_length(
                        job_config.training.seq_len // 2, self.step
                    )
                self.gc_handler.run(self.step)
                try:
                    self.train_step(data_iterator)
                except DataloaderStopIteration:
                    logger.warning("Ran out of data; last step was canceled.")
                    break

                # Run validation if validator is available
                if (
                    self.job_config.validation.enabled
                    and self.validator.should_validate(self.step)
                ):
                    self.validator.validate(self.model_parts, self.step)

                self.checkpointer.save(
                    self.step, last_step=(self.step == job_config.training.steps)
                )

                # signal the profiler that the next profiling step has started
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                # reduce timeout after first train step for faster signal
                # (assuming lazy init and compilation are finished)
                if self.step == 1:
                    dist_utils.set_pg_timeouts(
                        timeout=timedelta(
                            seconds=job_config.comm.train_timeout_seconds
                        ),
                        world_mesh=self.parallel_dims.world_mesh,
                    )

        if torch.distributed.get_rank() == 0:
            logger.info("Sleeping 2 seconds for other ranks to complete")
            time.sleep(2)

        logger.info("Training completed")

    def _restore_powersgd_state_after_load(self) -> None:
        """Re-attach and sanity-check PowerSGD state after a checkpoint load.

        No-op when the first model part has no ``_powersgd_container`` or the
        container has no ``powersgd_state``. Otherwise this:

        - re-attaches the model's ``_ddp_pg`` process group, if present;
        - logs which of the P/Q/error buffers and config fields were restored;
        - drops empty tensors from the buffers and moves the rest to
          ``self.device``;
        - clears all buffers and the compression counters when an empty tensor
          was found and ``self.step`` is already at or past
          ``start_powerSGD_iter``.
        """
        model = self.model_parts[0]
        container = getattr(model, "_powersgd_container", None)
        ddp_pg = getattr(model, "_ddp_pg", None)
        if container is None or container.powersgd_state is None:
            return
        powersgd_state = container.powersgd_state

        # Restore the process group
        if ddp_pg is not None:
            powersgd_state.process_group = ddp_pg
            logger.info("PowerSGD: Restored process group")

        # Loaded iteration and buffers are preserved; the custom hook handles
        # immediate resume.

        buffer_attrs = ("p_memory_dict", "q_memory_dict", "error_dict")

        # Log PowerSGD state restoration details
        compression_matrices_loaded = False
        for attr in buffer_attrs:
            if not hasattr(powersgd_state, attr):
                continue
            attr_dict = getattr(powersgd_state, attr)
            if isinstance(attr_dict, dict) and len(attr_dict) > 0:
                logger.info(
                    f"PowerSGD: Restored {attr} with {len(attr_dict)} entries"
                )
                compression_matrices_loaded = True
            else:
                logger.info(
                    f"PowerSGD: {attr} is empty (will be rebuilt during training)"
                )

        # Log key PowerSGD configuration from checkpoint
        config_info = [
            f"{attr}={getattr(powersgd_state, attr)}"
            for attr in (
                "matrix_approximation_rank",
                "start_powerSGD_iter",
                "min_compression_rate",
                "use_error_feedback",
                "warm_start",
            )
            if hasattr(powersgd_state, attr)
        ]
        if config_info:
            logger.info(
                f"PowerSGD: Restored configuration - {', '.join(config_info)}"
            )

        start_iter = getattr(powersgd_state, "start_powerSGD_iter", 0)

        if not compression_matrices_loaded:
            # No compression matrices in checkpoint
            if self.step >= start_iter:
                logger.info(
                    f"PowerSGD: No compression matrices in checkpoint, but step {self.step} >= start_iter {start_iter}. "
                    f"PowerSGD will initialize matrices on first use."
                )
            else:
                logger.info(
                    "PowerSGD: No compression matrices in checkpoint, will initialize when needed"
                )
            return

        # Validate compression matrices and ensure correct device placement
        param_totals = dict.fromkeys(buffer_attrs, 0)
        matrices_valid = True
        for attr_name in buffer_attrs:
            if not hasattr(powersgd_state, attr_name):
                continue
            attr_dict = getattr(powersgd_state, attr_name)
            if not isinstance(attr_dict, dict):
                continue

            keys_to_remove = []
            for key, tensor in list(attr_dict.items()):
                if not hasattr(tensor, "numel"):
                    continue

                # Check for empty tensors (common cause of the reshape error)
                if tensor.numel() == 0:
                    logger.warning(
                        f"PowerSGD: Found empty tensor in {attr_name}[{key}], removing it"
                    )
                    keys_to_remove.append(key)
                    matrices_valid = False
                    continue

                # Move tensor to the correct device if needed
                if tensor.device != self.device:
                    attr_dict[key] = tensor.to(self.device)
                    logger.info(
                        f"PowerSGD: Moved {attr_name}[{key}] to {self.device}"
                    )

                param_totals[attr_name] += tensor.numel()

            # Remove invalid tensors
            for key in keys_to_remove:
                del attr_dict[key]

        if not matrices_valid:
            if self.step >= start_iter:
                # Past the start iteration with broken matrices: force a
                # complete reinitialization.
                logger.warning(
                    f"PowerSGD: Invalid compression matrices detected at step {self.step} "
                    f"(>= start_iter {start_iter}). Clearing all matrices for complete reinitialization."
                )
                for attr_name in buffer_attrs:
                    if hasattr(powersgd_state, attr_name):
                        getattr(powersgd_state, attr_name).clear()

                # Reset compression statistics since we're starting fresh
                if hasattr(powersgd_state, "total_numel_before_compression"):
                    powersgd_state.total_numel_before_compression = 0
                if hasattr(powersgd_state, "total_numel_after_compression"):
                    powersgd_state.total_numel_after_compression = 0

                logger.info(
                    "PowerSGD: All compression matrices cleared, will reinitialize on first use"
                )
            else:
                logger.info(
                    f"PowerSGD: Invalid matrices removed, but step {self.step} < start_iter {start_iter}, "
                    f"so matrices will be initialized when compression starts."
                )

        # The totals reflect what was counted during validation, i.e. before
        # any clearing performed above.
        total_p_params = param_totals["p_memory_dict"]
        total_q_params = param_totals["q_memory_dict"]
        total_error_params = param_totals["error_dict"]
        logger.info(
            f"PowerSGD: Validated compression matrices - "
            f"P: {total_p_params:,} params, Q: {total_q_params:,} params, "
            f"Error: {total_error_params:,} params"
        )

        # Log compression statistics if available
        if hasattr(powersgd_state, "total_numel_before_compression") and hasattr(
            powersgd_state, "total_numel_after_compression"
        ):
            before = powersgd_state.total_numel_before_compression
            after = powersgd_state.total_numel_after_compression
            if before > 0 and after > 0:
                ratio = before / after
                logger.info(
                    f"PowerSGD: Compression ratio from checkpoint - "
                    f"{ratio:.2f}:1 ({((before - after) / before * 100):.1f}% reduction)"
                )

    def _calculate_per_layer_grad_norms(self, parallel_dims) -> dict[str, float]:
        """
        Calculate gradient norms for attention and feed-forward weights per layer.

        Returns:
            Dictionary with keys like "grad_norm_attn_layer_0", "grad_norm_ffn_layer_0", etc.
        """

        per_layer_norms = {}

        # Process each model part (for pipeline parallelism)
        for model_part in self.model_parts:
            # Access layers from the model
            if hasattr(model_part, "layers"):
                layers = model_part.layers
            elif hasattr(model_part, "module") and hasattr(model_part.module, "layers"):
                layers = model_part.module.layers
            else:
                continue

            for layer_id, layer_module in layers.items():
                # Calculate attention weight gradient norm
                if hasattr(layer_module, "attention") and hasattr(
                    layer_module.attention, "wo"
                ):
                    wo_weight = layer_module.attention.wo.weight
                    if wo_weight.grad is not None:
                        attn_params = [wo_weight]
                        attn_grad_norm = dist_utils.clip_grad_norm_(
                            attn_params,
                            float("inf"),  # No clipping, just calculate norm
                            foreach=True,
                            pp_mesh=(
                                parallel_dims.world_mesh["pp"]
                                if parallel_dims.pp_enabled
                                else None
                            ),
                            ep_dense_params_mesh_ndim=(
                                parallel_dims.dense_params_mesh_ndim
                                if parallel_dims.ep_enabled
                                else None
                            ),
                        )
                        per_layer_norms[
                            f"grad_norm_attn_layer_{layer_id}"
                        ] = attn_grad_norm.item()

                # Calculate feed-forward weight gradient norm
                if hasattr(layer_module, "feed_forward") and hasattr(
                    layer_module.feed_forward, "w2"
                ):
                    w2_weight = layer_module.feed_forward.w2.weight
                    if w2_weight.grad is not None:
                        ffn_params = [w2_weight]
                        ffn_grad_norm = dist_utils.clip_grad_norm_(
                            ffn_params,
                            float("inf"),  # No clipping, just calculate norm
                            foreach=True,
                            pp_mesh=(
                                parallel_dims.world_mesh["pp"]
                                if parallel_dims.pp_enabled
                                else None
                            ),
                            ep_dense_params_mesh_ndim=(
                                parallel_dims.dense_params_mesh_ndim
                                if parallel_dims.ep_enabled
                                else None
                            ),
                        )
                        per_layer_norms[
                            f"grad_norm_ffn_layer_{layer_id}"
                        ] = ffn_grad_norm.item()

        return per_layer_norms

    def _calculate_compression_energy_loss(self) -> dict[str, float]:
        """
        Calculate compression energy loss metrics from all model parts.

        Returns:
            Dictionary with compression energy loss statistics
        """
        compression_stats = {}

        # Process each model part (for pipeline parallelism)
        for model_part in self.model_parts:
            # Aggressively unwrap common wrappers (DDP/FSDP/compile wrappers)
            target_model = model_part
            for attr in ("module", "_ddp_wrapped_module", "_orig_mod", "model"):
                if hasattr(target_model, attr):
                    target_model = getattr(target_model, attr)

            if hasattr(target_model, "get_compression_energy_loss_stats"):
                part_stats = target_model.get_compression_energy_loss_stats()
                compression_stats.update(part_stats)

        return compression_stats

    def _calculate_w2_stable_rank_stats(self) -> dict[str, float]:
        """
        Calculate w2 stable rank metrics from all model parts.

        Returns:
            Dictionary with w2 stable rank statistics
        """
        stable_rank_stats = {}

        # Process each model part (for pipeline parallelism)
        for model_part in self.model_parts:
            # Aggressively unwrap common wrappers (DDP/FSDP/compile wrappers)
            target_model = model_part
            for attr in ("module", "_ddp_wrapped_module", "_orig_mod", "model"):
                if hasattr(target_model, attr):
                    target_model = getattr(target_model, attr)
            if hasattr(target_model, "stable_rank_w2"):
                for layer_id, stable_rank in target_model.stable_rank_w2.items():
                    stable_rank_stats[f"w2_stable_rank_layer_{layer_id}"] = stable_rank

        return stable_rank_stats

    def _apply_qk_clip_rescaling(self) -> dict[str, Any]:
        """
        Apply QK-Clip rescaling to model parts if enabled and collect
        per-layer entropy metrics. We perform a distributed MAX reduction
        of S_max^h across the dp_cp mesh to ensure identical rescaling
        on all devices.

        Returns a flat dict of metrics suitable for logging.
        """
        qk_clip_stats: dict[str, Any] = {}

        parallel_dims = self.parallel_dims
        dp_cp_mesh = (
            parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
        )
        ft_pg = self.ft_manager.loss_sync_pg if parallel_dims.dp_cp_enabled else None

        # Iterate each model part (handles PP)
        for model_part in self.model_parts:
            # Unwrap common wrappers
            target_model = model_part
            for attr in ("module", "_ddp_wrapped_module", "_orig_mod", "model"):
                if hasattr(target_model, attr):
                    target_model = getattr(target_model, attr)

            # Skip if QK-Clip is not enabled
            if not (
                hasattr(target_model, "model_args")
                and getattr(target_model.model_args, "use_qk_clip", False)
            ):
                continue

            # Access attention metrics captured during forward
            if not hasattr(target_model, "_attention_metrics"):
                continue

            threshold = getattr(target_model.model_args, "qk_clip_threshold", 8.0)
            alpha = getattr(target_model.model_args, "qk_clip_alpha", 0.5)

            # Aggregate entropy across layers for logging
            entropy_accum = 0.0
            entropy_count = 0

            # Walk layers for which we have metrics
            for layer_id, metrics in list(target_model._attention_metrics.items()):
                # Entropy logging (mean over B,H,T)
                if "entropy" in metrics:
                    try:
                        # Layer-wide mean entropy (B,H,T -> scalar)
                        entropy_mean = metrics["entropy"].mean().detach().float()
                        qk_clip_stats[f"layer_{layer_id}_entropy_mean"] = float(
                            entropy_mean.item()
                        )
                        entropy_accum += float(entropy_mean.item())
                        entropy_count += 1

                        # Per-head mean entropy (mean over B and T -> H)
                        entropy_per_head = (
                            metrics["entropy"].mean(dim=(0, 2)).detach().float()
                        )  # (H,)
                        if dp_cp_mesh is not None:
                            # First reduce on extra pg (FT optimizer wrapper) if present
                            if ft_pg is not None:
                                entropy_per_head = funcol.all_reduce(
                                    entropy_per_head,
                                    reduceOp=c10d.ReduceOp.AVG.name,
                                    group=ft_pg,
                                )
                            entropy_per_head = funcol.all_reduce(
                                entropy_per_head,
                                reduceOp=c10d.ReduceOp.AVG.name,
                                group=dp_cp_mesh,
                            )
                    except Exception:
                        pass

                # QK-Clip per-head maxima S_max^h if available
                if "qk_row_max" in metrics:
                    qk_row_max = metrics["qk_row_max"]  # (B, H, T)
                    # Local per-head S_max over B and T
                    smax_per_head = qk_row_max.amax(dim=2).amax(dim=0)  # (H,)

                    # All-reduce (MAX) across dp_cp mesh to synchronize S_max globally
                    if dp_cp_mesh is not None:
                        # Reduce on extra_pg first if present, then across the dp_cp device mesh
                        if ft_pg is not None:
                            smax_per_head = funcol.all_reduce(
                                smax_per_head,
                                reduceOp=c10d.ReduceOp.MAX.name,
                                group=ft_pg,
                            )
                        smax_per_head = funcol.all_reduce(
                            smax_per_head,
                            reduceOp=c10d.ReduceOp.MAX.name,
                            group=dp_cp_mesh,
                        )

                    # Log S_max always if computed
                    qk_clip_stats[f"layer_{layer_id}_smax"] = np.mean(
                        smax_per_head.detach().cpu().tolist()
                    )

                    # Compute gamma_h = min(1, tau / S_max)
                    gamma_h = torch.minimum(
                        torch.ones_like(smax_per_head),
                        threshold / (smax_per_head + 1e-12),
                    )

                    needs_rescaling = (smax_per_head > threshold).any().item()
                    if needs_rescaling:
                        # Apply in-place rescaling of projection weights
                        layer = target_model.layers[str(layer_id)]
                        qk_clip_rescale_(layer.attention, gamma_h, alpha)

                        qk_clip_stats[f"layer_{layer_id}_gamma"] = np.mean(
                            gamma_h.detach().cpu().tolist()
                        )
                        qk_clip_stats[f"layer_{layer_id}_heads_clipped"] = int(
                            (gamma_h < 1.0).sum().item()
                        )
                    else:
                        qk_clip_stats[f"layer_{layer_id}_gamma"] = 1.0
                        qk_clip_stats[f"layer_{layer_id}_heads_clipped"] = 0

            # Global entropy (average of per-layer means)
            if entropy_count > 0:
                qk_clip_stats["entropy_mean_over_layers"] = (
                    entropy_accum / entropy_count
                )

        return qk_clip_stats

    def state_dict(self) -> dict[str, Any]:
        return {
            "step": self.step,
            "ntokens_seen": self.ntokens_seen,
            "final_logits_entropy": self.last_entropy,
        }

    def load_state_dict(self, state_dict: dict[str, Any]):
        self.step = state_dict["step"]
        self.ntokens_seen = state_dict["ntokens_seen"]
        self.last_entropy = state_dict.get(
            "last_entropy", 0.0
        )  # Default for older checkpoints

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()
        if self.metrics_processor:
            self.metrics_processor.close()


if __name__ == "__main__":
    # Script entry point: parse the job config, build the trainer, then either
    # train or write a step-0 seed checkpoint.
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    trainer: Optional[Trainer] = None

    try:
        trainer = Trainer(config)

        if not config.checkpoint.create_seed_checkpoint:
            trainer.train()
        else:
            # Seed checkpoints are written from exactly one process.
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, last_step=True)
            logger.info("Created seed checkpoint")
    except Exception:
        # The trainer may not exist if its construction failed; close it when
        # it does, then let the error propagate.
        if trainer:
            trainer.close()
        raise

    # Only reached on success, since the handler above always re-raises.
    trainer.close()
    torch.distributed.destroy_process_group()
    logger.info("Process group destroyed")
