# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os

DEBUG_MODE = os.environ.get("DEBUG_MODE", "False") == "True"

# Restrict this process to a single GPU. This is done before `import torch`
# below so it is in place before the CUDA runtime initializes and reads it.
# NOTE(review): the *global* RANK is used as the physical GPU index, which is
# only a valid index if all ranks share one node (or the launcher hands out
# node-local RANKs). Trainer.__init__ separately builds its device from
# LOCAL_RANK — confirm the two agree for your launcher.
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]

import importlib
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional

import torch
# Cap intra-op CPU parallelism for this process.
torch.set_num_threads(12)

if DEBUG_MODE:
    # Pause every rank in the debugger at import time. No barrier is issued:
    # the default process group is only created later, when Trainer.__init__
    # calls dist_utils.init_distributed(), so torch.distributed.barrier()
    # would raise here.
    breakpoint()

from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.metrics import (
    build_metrics_processor,
    ensure_pp_loss_visible,
)
from torchtitan.config_manager import ConfigManager, JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)

import math
from typing import List, Dict
import traceback
from pathlib import Path
import json


class Trainer(torch.distributed.checkpoint.stateful.Stateful):
    """Training driver built from a ``JobConfig``.

    It builds the parallelized model, optimizers, LR schedulers, dataloaders and
    checkpointer, then runs the training loop (with optional periodic
    validation). The instance registers itself with the checkpointer as
    ``"train_state"``; only ``step`` is persisted through ``state_dict``.
    """

    job_config: JobConfig
    gc_handler: utils.GarbageCollection

    parallel_dims: ParallelDims
    train_spec: train_spec_module.TrainSpec
    world_mesh: torch.distributed.DeviceMesh

    dataloader: train_spec_module.BaseDataLoader
    metrics_processor: train_spec_module.MetricsProcessor
    checkpointer: CheckpointManager
    train_context: Generator[None, None, None]

    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    pp_has_first_stage: bool
    pp_has_last_stage: bool

    device: torch.device

    # states
    step: int
    # number of train_step() micro-steps per optimizer step
    grad_acc_steps: int

    # evaluation states
    eval_samples: List[Dict[str, torch.Tensor]]
    tokenizer: Optional[Any]

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: JobConfig):
        self.job_config = job_config

        logger.info(f"Starting job: {job_config.job.description}")

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        if job_config.job.print_args:
            logger.info(f"Running with args: {job_config.to_dict()}")

        device_module, device_type = utils.device_module, utils.device_type
        # NOTE(review): the module header pins CUDA_VISIBLE_DEVICES to a single
        # GPU, so this index is only valid if LOCAL_RANK is 0 for every process
        # (or more GPUs remain visible) — confirm for your launcher.
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp_ring=parallelism_config.context_parallel_degree,
            cp_ulysses=parallelism_config.context_parallel_ulysses_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not parallelism_config.disable_loss_parallel,
        )
        dist_utils.init_distributed(job_config)

        # build meshes
        self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type)
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        self.ft_manager = ft.init_ft_manager(job_config)
        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
        # dataloader must be changed.
        if self.ft_manager.enabled:
            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

        # build dataloader
        tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        self.dataloader = self.train_spec.build_dataloader_fn(
            dp_world_size=dp_degree,
            dp_rank=dp_rank,
            tokenizer=tokenizer,
            job_config=job_config,
        )

        if job_config.evaluation.enable_eval:
            self.valid_dataloader = self.train_spec.build_dataloader_fn(
                dp_world_size=dp_degree,
                dp_rank=dp_rank,
                tokenizer=tokenizer,
                job_config=job_config,
                split="validation",
            )

        # build model (using meta init)
        model_cls = self.train_spec.cls
        model_args = self.train_spec.config[job_config.model.flavor]
        # set the model args from training job configs
        model_args.update_from_config(job_config, tokenizer)

        # Hand the Ulysses / ring process groups to yunchang.
        from yunchang import set_seq_parallel_pg, patch_ulysses_ring_pg

        # TODO: uncomment when using 2D context parallel:
        # set_seq_parallel_pg(1, 8, torch.distributed.get_rank(), world_size, False)
        if parallel_dims.cp_ulysses > 1:
            cp_2d_mesh = world_mesh["cp_ulysses", "cp_ring"]
            patch_ulysses_ring_pg(cp_2d_mesh.get_group(0), cp_2d_mesh.get_group(1))
        else:
            # Ulysses is disabled, so yunchang gets a single-rank group.
            # new_group() is collective over the default group: every rank must
            # enter it with the same `ranks` argument and in the same order.
            # Hence every rank creates every singleton group and keeps only its own.
            my_rank = torch.distributed.get_rank()
            singleton_pg = None
            for r in range(torch.distributed.get_world_size()):
                pg = torch.distributed.new_group([r])
                if r == my_rank:
                    singleton_pg = pg
            patch_ulysses_ring_pg(singleton_pg, world_mesh["cp_ring"].get_group())

        logger.info(
            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
        )
        with torch.device("meta"):
            model = model_cls.from_model_args(model_args)

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # metrics logging
        build_metrics_processor_fn = (
            build_metrics_processor
            if self.train_spec.build_metrics_processor_fn is None
            else self.train_spec.build_metrics_processor_fn
        )
        self.metrics_processor = build_metrics_processor_fn(
            job_config, parallel_dims, model_args
        )
        color = self.metrics_processor.color

        # calculate model size and flops per token
        (
            model_param_count,
            self.metrics_processor.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        logger.info(
            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
        )

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        pretrained_path = job_config.model.pretrained_checkpoint_path

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                world_mesh,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()

            # Load pretrained weights (over the fresh init) if specified
            if pretrained_path:
                load_pretrained_checkpoint(
                    self.model_parts,
                    pretrained_path,
                    self.device,
                    model_args,
                )

            # confirm that user will be able to view loss metrics on the console
            ensure_pp_loss_visible(parallel_dims, job_config, color)
        else:
            # Load pretrained weights if specified. The model is materialized
            # (uninitialized) before parallelization so the weights can be
            # copied in; init_weights() is then skipped below.
            # NOTE(review): skipping init_weights() also skips recomputing
            # non-checkpoint buffers such as freqs_cis — confirm that
            # load_pretrained_checkpoint fills them.
            if pretrained_path:
                model.to_empty(device=init_device)
                load_pretrained_checkpoint(
                    [model],
                    pretrained_path,
                    self.device,
                    model_args,
                )

            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(
                model, world_mesh, parallel_dims, job_config
            )

            if not pretrained_path:
                model.to_empty(device=init_device)
                with torch.no_grad():
                    model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        if (
            self.ft_manager.enabled
            and job_config.fault_tolerance.semi_sync_method is None
        ):
            self.ft_manager.set_all_reduce_hook(self.model_parts)

        # initialize device memory monitor and get peak flops for MFU calculation
        device_memory_monitor = self.metrics_processor.device_memory_monitor
        gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
        logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
        device_mem_stats = device_memory_monitor.get_peak_stats()
        logger.info(
            f"{device_type.upper()} memory usage for model: "
            f"{device_mem_stats.max_reserved_gib:.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)"
        )

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )
        # Post optimizer step model converters hook.
        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )
        self.metrics_processor.optimizers = self.optimizers

        # Initialize trainer states that will be saved in checkpoint.
        # These attributes must be initialized before checkpoint loading.
        self.step = 0

        self.grad_acc_steps = job_config.optimizer.gradient_accumulation_steps

        self.checkpointer = CheckpointManager(
            dataloader=self.dataloader,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=self.ft_manager,
        )

        self.train_context = dist_utils.get_train_context(
            parallel_dims.loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )

        # Shard each model part's freqs_cis buffer along the sequence dim once,
        # up front, so it lines up with the CP-sharded inputs. Because of this,
        # train_step / validate_step must NOT hand freqs_cis to the CP context.
        # Gated on cp_enabled to match train_step, which only shards inputs then.
        if parallel_dims.cp_enabled:
            from torchtitan.distributed.context_parallel_2d import _context_parallel_buffers_2d

            freqs_cis_buffers = [m.freqs_cis for m in self.model_parts]
            freqs_cis_buffer_seq_dims = [0 for _ in self.model_parts]
            chunks = _context_parallel_buffers_2d(
                self._cp_mesh(), freqs_cis_buffers, freqs_cis_buffer_seq_dims
            )
            for buffer, chunk in zip(freqs_cis_buffers, chunks):
                chunk = chunk.clone()
                # resize in place so the registered buffer object is kept
                buffer.resize_(chunk.shape)
                buffer.copy_(chunk)

        logger.info(
            "Trainer is initialized with "
            f"local batch size {job_config.training.batch_size}, "
            f"global batch size {job_config.training.batch_size * dp_degree}, "
            f"sequence length {job_config.training.seq_len}, "
            f"total steps {job_config.training.steps} "
            f"(warmup {job_config.lr_scheduler.warmup_steps})."
        )

    def _cp_mesh(self) -> torch.distributed.DeviceMesh:
        """Return the context-parallel mesh.

        This is the 2-D ``(cp_ulysses, cp_ring)`` mesh when Ulysses is enabled,
        otherwise the ``cp_ring`` mesh.
        """
        if self.parallel_dims.cp_ulysses > 1:
            return self.world_mesh["cp_ulysses", "cp_ring"]
        return self.world_mesh["cp_ring"]

    def batch_generator(
        self,
        data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
        *,
        record_metrics: bool = True,
    ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
        """Yield ``(input_dict, labels)`` batches moved to the training device.

        Args:
            data_iterable: iterable of ``(input_dict, labels)`` batches.
            record_metrics: when True, count the label tokens and the time spent
                fetching each batch into the training metrics. Pass False for
                non-training (e.g. validation) data.
        """
        device_type = utils.device_type
        data_iterator = iter(data_iterable)

        while True:
            # Start timing before the fetch so the dataloader's work is measured.
            data_load_start = time.perf_counter()
            try:
                batch = next(data_iterator)
            except StopIteration:
                return
            input_dict, labels = batch
            if record_metrics:
                self.metrics_processor.ntokens_since_last_log += labels.numel()
                self.metrics_processor.data_loading_times.append(
                    time.perf_counter() - data_load_start
                )

            # Move tensors to the appropriate device
            for k, v in input_dict.items():
                if isinstance(v, torch.Tensor):
                    input_dict[k] = v.to(device_type)
            labels = labels.to(device_type)

            yield input_dict, labels

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """Run one forward/backward micro-step and maybe an optimizer step.

        The optimizer (with gradient clipping and LR-scheduler step) runs every
        ``grad_acc_steps`` micro-steps. Metrics are logged, with optional
        validation, when ``metrics_processor.should_log`` says so. Expects
        ``self.step`` to have already been incremented (1-based) by the caller.
        """
        # Gradients accumulate over a window of ``grad_acc_steps`` micro-steps:
        # clear them only at the start of a window and apply them at its end.
        is_window_start = (self.step - 1) % self.grad_acc_steps == 0
        is_optimizer_step = self.step % self.grad_acc_steps == 0
        if is_window_start:
            self.optimizers.zero_grad()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        # Apply context parallelism if enabled. freqs_cis was already sharded
        # once in __init__, so only inputs/labels are sharded here.
        inputs = input_dict["input"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=self._cp_mesh(),
                cp_buffers=[inputs, labels],
                cp_seq_dims=[1, 1],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                cp_impl=self.job_config.model.attn_impl,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward / backward inside step() call.
            # NOTE(review): the schedule presumably back-propagates the unscaled
            # loss_fn output, so with grad_acc_steps > 1 PP gradients are summed
            # (not averaged) over the accumulation window.
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            # accumulate losses across pipeline microbatches
            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward / backward
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1

                if self.job_config.training.chunked_loss:
                    # The model computes the loss itself from the labels.
                    loss = model_parts[0](inputs, labels=labels)
                else:
                    pred = model_parts[0](inputs, labels=labels)
                    loss = self.loss_fn(pred, labels)

                    # need to free to before bwd to avoid peaking memory
                    del pred

                # Weight this rank's loss by its share of non-masked (!= -100)
                # tokens summed over the dp_shard_cp group.
                non_masked_tokens_current_node = (labels != -100).sum().detach()
                non_masked_tokens_total = non_masked_tokens_current_node.clone()
                torch.distributed.all_reduce(
                    non_masked_tokens_total,
                    op=torch.distributed.ReduceOp.SUM,
                    group=world_mesh.get_group(mesh_dim="dp_shard_cp"),
                )  # Modifies in-place
                non_masked_tokens_total = max(non_masked_tokens_total.item(), 1)  # Ensure no division by zero
                loss *= non_masked_tokens_current_node / non_masked_tokens_total

                # Average (rather than sum) gradients over the accumulation
                # window; the logged `loss` keeps its per-micro-step scale.
                (loss / self.grad_acc_steps).backward()

                if torch.all(labels == -100):
                    loss.data.zero_()

        if is_optimizer_step:
            # Clip the fully accumulated gradient right before it is applied.
            dist_utils.clip_grad_norm_(
                [p for m in model_parts for p in m.parameters()],
                self.job_config.training.max_norm,
                foreach=True,
                pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
            )
        self.checkpointer.maybe_wait_for_staging()
        if is_optimizer_step:
            self.optimizers.step()
            self.lr_schedulers.step()

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
            or self.ft_manager.enabled
        ):
            loss = loss.detach()
            # NOTE(review): in the non-PP path each rank's loss is pre-weighted
            # by its token share; if dist_mean averages across ranks, the logged
            # value may be scaled down by the group size — confirm.
            # Skip ft manager communication when using semi sync training
            use_ft_pg = (
                self.ft_manager.enabled
                and self.job_config.fault_tolerance.semi_sync_method is None
            )
            ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
            )
        else:
            global_avg_loss = global_max_loss = loss.detach().item()

        if os.environ.get("CLEAR_CUDA_CACHE", "0") == "1":
            torch.cuda.empty_cache()

        if self.job_config.evaluation.enable_eval and (self.step % self.job_config.evaluation.eval_freq == 0 or self.step == 1):
            val_loss = self.validate()
            self.metrics_processor.log(self.step, global_avg_loss, global_max_loss, {"loss_metrics/val_loss": val_loss})

        else:
            self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)

    def validate_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """Compute the gradient-free loss on one validation batch.

        The loss is reduced over the ``dp_cp`` mesh when any data, context or
        fault-tolerant parallelism is active; otherwise the local value is
        returned. Leaves the model in eval mode; ``validate`` restores train mode.
        """
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        # As in train_step, freqs_cis is already sharded (see __init__), so only
        # inputs/labels go to the CP context; passing freqs_cis here would
        # shard it a second time.
        inputs = input_dict["input"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=self._cp_mesh(),
                cp_buffers=[inputs, labels],
                cp_seq_dims=[1, 1],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                cp_impl=self.job_config.model.attn_impl,
            )
            if parallel_dims.cp_enabled
            else None
        )

        with self.train_context(optional_context_parallel_ctx), torch.no_grad():
            assert len(model_parts) == 1, "Only one model part is supported for validation"
            model_parts[0].eval()
            pred = model_parts[0](inputs)
            loss = self.loss_fn(pred, labels)
            if torch.all(labels == -100):
                loss.data.zero_()

            # Weight by this rank's share of non-masked tokens over dp_cp.
            non_masked_tokens_current_node = (labels != -100).sum().detach()
            non_masked_tokens_total = non_masked_tokens_current_node.clone()
            torch.distributed.all_reduce(
                non_masked_tokens_total,
                op=torch.distributed.ReduceOp.SUM,
                group=world_mesh.get_group(mesh_dim="dp_cp"),
            )  # Modifies in-place
            non_masked_tokens_total = max(non_masked_tokens_total.item(), 1)  # Ensure no division by zero
            loss *= non_masked_tokens_current_node / non_masked_tokens_total

            if (
                parallel_dims.dp_replicate_enabled
                or parallel_dims.dp_shard_enabled
                or parallel_dims.cp_enabled
                or self.ft_manager.enabled
            ):
                # Skip ft manager communication when using semi sync training
                use_ft_pg = (
                    self.ft_manager.enabled
                    and self.job_config.fault_tolerance.semi_sync_method is None
                )
                ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None
                global_avg_loss, global_max_loss = (
                    dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
                    dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
                )
            else:
                global_avg_loss = global_max_loss = loss.item()
            return global_avg_loss

    def validate(self):
        """Return the mean ``validate_step`` loss over the validation batches.

        At most ``evaluation.num_eval_samples`` batches are evaluated. Validation
        batches are excluded from training throughput metrics. The model is put
        back in train mode afterwards, even if a step raises. Returns ``nan``,
        with a warning, when no batch was evaluated.
        """
        val_loss = 0.0
        val_steps = 0
        max_steps = self.job_config.evaluation.num_eval_samples
        batches = self.batch_generator(self.valid_dataloader, record_metrics=False)
        try:
            # zip() advances range() first, so no batch is fetched past the limit.
            for _, (inputs, labels) in zip(range(max_steps), batches):
                val_loss += self.validate_step(inputs, labels)
                val_steps += 1
        finally:
            self.model_parts[0].train()

        if val_steps == 0:
            logger.warning("Validation evaluated no batches; reporting NaN val_loss.")
            return float("nan")
        return val_loss / val_steps

    @record
    def train(self):
        """Restore from checkpoint (if any) and run the training loop.

        Training runs until ``training.steps`` or until the dataloader is exhausted.
        """
        job_config = self.job_config

        self.checkpointer.load(step=job_config.checkpoint.load_step)
        logger.info(f"Training starts at step {self.step + 1}.")

        with (
            maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler,
            maybe_enable_memory_snapshot(
                job_config, global_step=self.step
            ) as memory_profiler,
            ft.maybe_semi_sync_training(
                job_config,
                ft_manager=self.ft_manager,
                model=self.model_parts[0],
                optimizer=self.optimizers,
                sync_every=job_config.fault_tolerance.sync_steps,
            ),
        ):
            for inputs, labels in self.batch_generator(self.dataloader):
                if self.step >= job_config.training.steps:
                    break
                self.step += 1
                self.gc_handler.run(self.step)
                self.train_step(inputs, labels)

                self.checkpointer.save(
                    self.step, force=(self.step == job_config.training.steps)
                )

                # signal the profiler that the next profiling step has started
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                # reduce timeout after first train step for faster signal
                # (assuming lazy init and compilation are finished)
                if self.step == 1 and job_config.training.backend == "torch":
                    dist_utils.set_pg_timeouts(
                        timeout=timedelta(
                            seconds=job_config.comm.train_timeout_seconds
                        ),
                        world_mesh=self.world_mesh,
                    )

        if torch.distributed.get_rank() == 0:
            logger.info("Sleeping 2 seconds for other ranks to complete")
            time.sleep(2)

        self.metrics_processor.close()
        logger.info("Training completed")

    def state_dict(self) -> dict[str, Any]:
        """Trainer state persisted by the checkpointer (only the step counter)."""
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Restore the step counter from a checkpoint."""
        self.step = state_dict["step"]

    def close(self) -> None:
        """Close the checkpointer; safe to call even if ``__init__`` failed early."""
        checkpointer = getattr(self, "checkpointer", None)
        if checkpointer:
            checkpointer.close()

def _log_sample_shapes(title: str, state_dict: dict[str, Any], limit: int = 10) -> None:
    """Log ``title`` followed by the first ``limit`` keys of ``state_dict`` and their shapes."""
    logger.info(title)
    for key in list(state_dict)[:limit]:
        logger.info(f"  {key}: {state_dict[key].shape}")


def _log_key_mismatch(kind: str, keys: list[str], limit: int = 10) -> None:
    """Warn about ``kind`` ("Missing"/"Unexpected") keys from ``load_state_dict``, showing at most ``limit``."""
    if not keys:
        return
    logger.warning(f"{kind} keys when loading pretrained weights ({len(keys)} total):")
    for key in keys[:limit]:
        logger.warning(f"  {kind}: {key}")
    if len(keys) > limit:
        logger.warning(f"  ... and {len(keys) - limit} more")


def load_pretrained_checkpoint(model_parts: list[torch.nn.Module], checkpoint_path: str, device: torch.device, model_args=None) -> None:
    """Load pretrained weights from a HuggingFace checkpoint into TorchTitan model parts.

    Loading is best-effort: any exception is logged and the model parts keep
    whatever weights they already had ("Continuing with random initialization").

    Args:
        model_parts: Model parts to load weights into. Each part is loaded with
            ``strict=False``, so with pipeline parallelism a part only consumes
            the keys it owns and the rest are reported as unexpected.
        checkpoint_path: Local path or HuggingFace model name passed to
            ``from_pretrained``. An empty value skips loading.
        device: Currently unused; the HF model is loaded on CPU and
            ``load_state_dict`` copies the values into the existing parameters.
            Kept for interface compatibility.
        model_args: Model arguments needed by ``Llama3StateDictAdapter.from_hf``.
            If None, a minimal stand-in is built from the HF config.
    """
    if not checkpoint_path:
        logger.warning("No pretrained checkpoint path provided. Skipping pretrained loading.")
        return

    logger.info(f"Loading pretrained weights from {checkpoint_path}")

    try:
        # Import HuggingFace components lazily so transformers is only required when used.
        try:
            from transformers import AutoModelForCausalLM, AutoConfig
        except ImportError:
            logger.error("transformers library not installed. Please install it to load HuggingFace checkpoints")
            return

        from torchtitan.models.llama3.state_dict_adapter import Llama3StateDictAdapter

        logger.info("Loading HuggingFace model...")

        # NOTE(security): trust_remote_code=True executes Python code shipped with
        # the checkpoint repository; only point this at checkpoints you trust.
        hf_config = AutoConfig.from_pretrained(checkpoint_path, trust_remote_code=True)

        # Load in float32 on CPU; load_state_dict below copies the values into the
        # model's own parameters, wherever those live.
        hf_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        logger.info("HuggingFace model loaded successfully")

        hf_state_dict = hf_model.state_dict()
        logger.info(f"HF model has {len(hf_state_dict)} parameters")
        _log_sample_shapes("Sample HF keys:", hf_state_dict)

        if model_args is None:
            logger.warning("model_args not provided, attempting to infer from HF config")

            class DummyModelArgs:
                """Minimal stand-in exposing the attention/width fields read from the HF config."""

                def __init__(self, config):
                    self.n_heads = config.num_attention_heads
                    self.n_kv_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
                    self.dim = config.hidden_size

            model_args = DummyModelArgs(hf_config)

        logger.info("Converting HF state dict to TorchTitan format...")
        converted_state_dict = Llama3StateDictAdapter.from_hf(hf_state_dict, model_args)
        logger.info(f"Converted state dict has {len(converted_state_dict)} parameters")
        _log_sample_shapes("Sample converted keys:", converted_state_dict)

        # Drop our references to the HF model. Converted tensors may still share
        # storage with it, so this is not guaranteed to release memory.
        del hf_model, hf_state_dict
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if len(model_parts) == 1:
            # Single model part (no pipeline parallelism).
            logger.info("Loading weights into single model part...")

            model_state_dict = model_parts[0].state_dict()
            logger.info(f"Model expects {len(model_state_dict)} parameters")
            _log_sample_shapes("Sample model keys:", model_state_dict)

            # strict=False: report key mismatches instead of raising on them.
            missing_keys, unexpected_keys = model_parts[0].load_state_dict(converted_state_dict, strict=False)
            _log_key_mismatch("Missing", missing_keys)
            _log_key_mismatch("Unexpected", unexpected_keys)

            # Keys actually consumed = provided keys minus those the model did not recognize.
            loaded = len(converted_state_dict) - len(unexpected_keys)
            if loaded == 0:
                logger.warning("No pretrained parameters matched the model's keys")
            logger.info(f"Successfully loaded {loaded} parameters")
        else:
            # Multiple model parts (pipeline parallelism): keys belonging to other
            # stages are expected to come back as "unexpected" for each part.
            logger.info(f"Loading weights into {len(model_parts)} model parts...")
            for i, model_part in enumerate(model_parts):
                missing_keys, unexpected_keys = model_part.load_state_dict(converted_state_dict, strict=False)
                _log_key_mismatch("Missing", missing_keys)
                loaded = len(converted_state_dict) - len(unexpected_keys)
                logger.info(f"Loaded pretrained weights for model part {i} "
                            f"({loaded} parameters loaded)")

        logger.info("Successfully loaded pretrained weights from HuggingFace checkpoint")

    except Exception as e:
        # Deliberate best-effort: log the full traceback and keep the current weights.
        logger.error(f"Failed to load pretrained checkpoint: {e}")
        logger.error(f"Exception details: {traceback.format_exc()}")
        logger.warning("Continuing with random initialization")


def debug_on_rank(rank: int = 0, label: str = "") -> None:
    """Placeholder for a rank-gated breakpoint; currently a no-op.

    The intent is to pause only on ``rank`` (optionally tagged with ``label``).
    The breakpoint is disabled, so calls return immediately. To re-enable it,
    call ``breakpoint()`` when ``torch.distributed.get_rank() == rank``.

    Args:
        rank: The rank that would break. Defaults to 0.
        label: Optional label for the debug point.
    """
    return None

if __name__ == "__main__":
    if DEBUG_MODE:
        # Pause before parsing config so a debugger can attach (enable with DEBUG_MODE=True).
        breakpoint()

    config_manager = ConfigManager()
    config = config_manager.parse_args()

    # Initialize logger with profiling folder info so logs get saved there too
    init_logger(
        save_traces_folder=config.profiling.save_traces_folder if config.profiling.enable_profiling else None,
        dump_folder=config.job.dump_folder,
    )

    trainer: Optional[Trainer] = None

    try:
        trainer = Trainer(config)

        if config.checkpoint.create_seed_checkpoint:
            # Explicit checks rather than `assert`, so they still run under `python -O`.
            if int(os.environ["WORLD_SIZE"]) != 1:
                raise ValueError(
                    "Must create seed checkpoint using a single device, to disable sharding."
                )
            if not config.checkpoint.enable_checkpoint:
                raise ValueError(
                    "Must enable checkpointing when creating a seed checkpoint."
                )
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        try:
            if trainer is not None:
                trainer.close()
        finally:
            # Tear down the process group even if closing the trainer raised.
            if torch.distributed.is_initialized():
                torch.distributed.destroy_process_group()
                logger.info("Process group destroyed.")
