import os
from pickle import NONE
import time
import copy
import logging
import warnings
from abc import abstractmethod
from pathlib import Path
from functools import partial
from contextlib import nullcontext
from typing import Callable, Optional, Union, List, Dict, Tuple, Any
from unittest.mock import MagicMock
import shutil
import torch
from torch import nn
from torch import Tensor
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    LinearLR,
    PolynomialLR,
    StepLR,
    LRScheduler,
)
from torch.utils.data import DataLoader
from pydantic import BaseModel
import rich
from rich.table import Table
from rich.live import Live
from rich.panel import Panel
from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    Group,
)

import deepspeed
from accelerate import (
    Accelerator,
    FullyShardedDataParallelPlugin,
    DeepSpeedPlugin,
    DistributedType,
)
from accelerate.utils import (
    set_seed,
    TorchTensorParallelPlugin,
    MegatronLMPlugin,
)

from safetensors.torch import save_file

from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

from transformers import PreTrainedTokenizer
from mmengine import (
    print_log,
    DATA_SAMPLERS,
    Config,
    ConfigDict,
)
from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.dist import get_rank

from mmhug.registry import HF_MODELS, DATASETS, FUNCTIONS
from mmhug.utils.hf_hub_utils import push_to_hub
from mmhug.utils.memory_utils import get_gpu_memory_gb
from mmhug.utils import dtype_from_str

# Pin TOKENIZERS_PARALLELISM for this process (child processes inherit the environment).
# NOTE(review): presumably setting it explicitly is what silences the Hugging Face
# tokenizers parallelism/fork warning — confirm that "true" is the intended value.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Silence the bitsandbytes warning about bf16 -> fp16 casting during 8-bit matmul.
# `message` is matched against the start of the warning text.
warnings.filterwarnings(
    "ignore",
    message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
)

# A process counts as "main" when LOCAL_RANK is unset or equal to "0".
# NOTE(review): LOCAL_RANK is presumably per-node, so multi-node runs would have one
# "main" process per node — confirm this is intended for logging/checkpointing.
IS_MAIN_PROCESS = os.environ.get("LOCAL_RANK", "0") == "0"
if not IS_MAIN_PROCESS:
    # Imported lazily: only non-main processes need to turn the bars off.
    from transformers.utils.logging import disable_progress_bar

    disable_progress_bar()

# Signature of the optional per-step hook accepted by `BaseTrainer.train`.
StepCallback = Callable[
    [int, int, list[Path]], None
]  # (global_step, total_steps, sampled_video_paths) -> None

# GPU memory is polled once every this many raw training steps (see `BaseTrainer.train`).
MEMORY_CHECK_INTERVAL = 200


class TrainingStats(BaseModel):
    """
    Aggregate statistics and resource usage collected during a training run.

    Attributes:
        total_time_seconds (float):
            Wall-clock time elapsed from the start to the end of the training loop,
            including validation and checkpointing performed inside the loop.
        training_time (float):
            Time attributed to training itself. As computed by `BaseTrainer.train`
            this is currently the same value as `total_time_seconds`.
        steps_per_second (float):
            Average number of training steps completed per second.
        samples_per_second (float):
            Effective throughput in samples per second, computed as
            steps_per_second * num_processes * per-process batch_size.
        peak_gpu_memory_gb (float):
            The maximum GPU memory (in gigabytes) sampled during training.
        global_batch_size (int):
            Batch size aggregated across all processes
            (num_processes * per-process batch_size).
        num_processes (int):
            Number of distributed processes participating in the run.
        compilation_time_seconds (float | None):
            Time spent on compilation/warm-up, when it was measured. Optional and
            defaults to None; `BaseTrainer.train` passes this keyword, so it is
            declared here instead of being silently discarded as an unknown field.
    """

    total_time_seconds: float
    training_time: float
    steps_per_second: float
    samples_per_second: float
    peak_gpu_memory_gb: float
    global_batch_size: int
    num_processes: int
    # Optional with a default so existing constructors that omit it keep working.
    compilation_time_seconds: Optional[float] = None


class BaseTrainer:
    """
    Base class for training diffusion/transformer models, supporting:
      - Config‐driven optimizer & scheduler
      - Dataset and DataLoader initialized in an mmengine-style
      - LoRA adapters and full‐fine tuning modes
      - Checkpoint saving/loading (with safetensors)
      - Accelerator wrapper to support multi-GPU, multi-node, etc.
    Subclasses *must* implement:
      - `_prepare_models_for_training`: wrap models/datasets for Accelerate
      - `_training_step`: compute a scalar loss from a training batch
      - `_save_lora_weights`: how to persist only LoRA adapters
      - `_sample_videos`: (optional) sample validation outputs
    """

    def __init__(self, cfg: Union[Config, ConfigDict]) -> None:
        """
        Build the trainer from a configuration object and run every setup stage.

        Args:
            cfg: an mmengine ``Config``/``ConfigDict``. ``_check_config`` requires the
              following sections to be present:
              - model: dict mapping attribute names to HF_MODELS build specs
              - data: dataset, sampler and DataLoader arguments
              - optimization: learning rate, steps, gradient accumulation, etc.
              - acceleration: accelerator-related settings
              - checkpoints: checkpointing settings (e.g. ``interval``)
              ``validation`` is optional (a warning is logged when it is missing).
              Other methods of this class additionally read ``seed``, ``output_dir``
              and the optional ``lora``, ``resume`` and ``hub`` entries.

        Raises:
            ValueError: propagated from ``_check_config`` when a required section
                is missing or ``cfg.model`` is not a dict.
        """

        self.cfg = cfg
        # Validate first so later stages can rely on the required sections.
        self._check_config()
        self._print_config()
        self._setup_accelerator()
        # Builds every entry of cfg.model and attaches LoRA adapters if configured.
        self._load_models()
        # NOTE(review): `_load_checkpoint` can read `self.optimizer` (resume entries
        # named "optimizer"), but `_init_optimizer` only runs further down — confirm
        # the optimizer exists at this point or move the optimizer restore later.
        self._load_checkpoint()
        self._prepare_models_for_training()
        self._collect_trainable_params()
        # Build optimizer, dataloader, and sampler
        self._init_optimizer()
        self._init_dataloader()
        self._prepare_accelerator()
        # -1 marks "training not started"; `train()` resets the counter to 0.
        self._global_step = -1
        self._checkpoint_paths = []

    def _check_config(self):
        if not hasattr(self.cfg, "model") or not isinstance(self.cfg.model, dict):
            raise ValueError(
                "`cfg.model` must be a dict describing each module to load."
            )
        if not hasattr(self.cfg, "optimization"):
            raise ValueError("Missing `cfg.optimization` section.")
        if not hasattr(self.cfg, "data"):
            raise ValueError("Missing `cfg.data` section.")
        if not hasattr(self.cfg, "acceleration"):
            raise ValueError("Missing `cfg.acceleration` section.")
        if not hasattr(self.cfg, "checkpoints"):
            raise ValueError("Missing `cfg.checkpoints` section.")

        if self.cfg.get("validation", None) is None:
            print_log(
                "⚠️ `cfg.validation` not specified. Validation will not run.",
                level=logging.WARNING,
            )
        elif (
            self.cfg["validation"].get("interval", None) is None
            or self.cfg["validation"]["interval"] <= 0
        ):
            print_log(
                "⚠️ `cfg.validation.interval` not specified or <= 0. Validation will not run.",
                level=logging.WARNING,
            )

    def train(  # noqa: PLR0912, PLR0915
        self,
        disable_progress_bars: bool = False,
        step_callback: Optional[StepCallback] = None,
    ) -> tuple[Optional[Path], TrainingStats]:
        """
        Run the full training loop, with optional validation, checkpointing, and progress reporting.

        Args:
            disable_progress_bars (bool): If True, suppress rich progress bars and emit
                periodic log messages instead.
            step_callback (StepCallback | None): Optional function called after each optimizer
                step with signature (global_step, total_steps, sampled_videos_paths). It is
                invoked on every process that reaches the step. ``sampled_videos_paths`` is
                whatever this loop has recorded; nothing in this method assigns it, so the
                callback currently receives ``None`` for that argument.

        Returns:
            Tuple[Path | None, TrainingStats]:
                - saved_checkpoint_path: value returned by ``_save_checkpoint`` on the main
                  process; ``None`` on every other process.
                - stats: TrainingStats object summarizing runtime, throughput, and memory usage.
        """
        # ─── 1. Preparation ────────────────────────────────────────────────────────────
        device = self.accelerator.device
        cfg = self.cfg

        # Record GPU memory baseline
        start_mem = get_gpu_memory_gb(device)

        # Start timing
        train_start_time = time.time()

        # Ensure reproducibility
        set_seed(cfg.seed)
        print_log(
            f"Process {self.accelerator.process_index} using seed: {cfg.seed}",
            level=logging.DEBUG,
        )

        data_iter = iter(self.dataloader)

        # Barrier so all processes start training together
        self.accelerator.wait_for_everyone()

        # Prepare output directory
        Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
        print_log("🚀 Starting training...", level=logging.INFO)

        # Resolve the validation cadence once. `_check_config` only *warns* when the
        # section or its interval is missing/None, so the loop must tolerate both
        # instead of failing on `None > 0` or a missing attribute.
        val_cfg = cfg.get("validation", None)
        val_interval = (val_cfg.get("interval", None) if val_cfg else None) or 0
        run_validation = val_interval > 0

        # ─── 2. Progress Bars Setup ───────────────────────────────────────────────────
        if disable_progress_bars or not IS_MAIN_PROCESS:
            # MagicMock absorbs every progress call so the loop body needs no branching.
            train_progress = MagicMock()
            sample_progress = MagicMock()
            live = nullcontext()
            if IS_MAIN_PROCESS:
                print_log(
                    "Progress bars disabled. Status messages will be printed occasionally.",
                    level=logging.WARNING,
                )
        else:
            train_progress = Progress(
                TextColumn("Step"),
                MofNCompleteColumn(),
                BarColumn(bar_width=40),
                TextColumn("Loss: {task.fields[loss]:.4f}"),
                TextColumn("LR: {task.fields[lr]:.2e}"),
                TextColumn("Time/Step: {task.fields[step_time]:.2f}s"),
                TimeElapsedColumn(),
                TimeRemainingColumn(compact=True),
            )
            sample_progress = Progress(
                TextColumn("Sampling"),
                MofNCompleteColumn(),
                BarColumn(bar_width=40),
                TimeElapsedColumn(),
                TimeRemainingColumn(compact=True),
            )
            live = Live(
                Panel(Group(train_progress, sample_progress)), refresh_per_second=2
            )

        self._global_step = 0
        compilation_time = None  # not measured by this loop; reported as-is in the stats
        peak_mem = start_mem
        sampled_videos_paths = None

        # ─── 3. Main Training Loop ────────────────────────────────────────────────────
        with live:
            # Create a task to track training steps
            task = train_progress.add_task(
                "Training",
                total=cfg.optimization.steps,
                loss=0.0,
                lr=cfg.optimization.learning_rate,
                step_time=0.0,
            )

            # Optionally do an initial validation sample
            if run_validation and IS_MAIN_PROCESS:
                self._validation(sample_progress)

            self.accelerator.wait_for_everyone()
            with torch.cuda.amp.autocast(
                dtype=dtype_from_str(self.accelerator.mixed_precision)
            ):
                for raw_step in range(cfg.optimization.steps):
                    # ── Fetch next batch ─────────────────────────
                    try:
                        batch = next(data_iter)
                    except StopIteration:
                        # Dataloader exhausted: start a new pass over the data.
                        data_iter = iter(self.dataloader)
                        batch = next(data_iter)

                    step_start = time.time()
                    # `_global_step` counts optimizer steps: it advances once every
                    # `gradient_accumulation_steps` raw iterations.
                    is_opt_step = (
                        raw_step + 1
                    ) % cfg.optimization.gradient_accumulation_steps == 0
                    if is_opt_step:
                        self._global_step += 1

                    # ── Forward / Backward / Optimization ────────
                    loss = self._training_step(batch)

                    self.accelerator.backward(loss)

                    if (
                        self.accelerator.sync_gradients
                        and cfg.optimization.max_grad_norm > 0
                    ):
                        self.accelerator.clip_grad_norm_(
                            self._trainable_params, cfg.optimization.max_grad_norm
                        )
                    # NOTE(review): the optimizer/scheduler are stepped on every raw
                    # iteration and no `accelerator.accumulate(...)` context is entered
                    # here — confirm gradient accumulation is actually taking effect.
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if self.lr_scheduler:
                        self.lr_scheduler.step()

                    self.accelerator.wait_for_everyone()
                    # ── Validation Sampling ───────────────────────
                    if (
                        run_validation
                        and is_opt_step
                        and self._global_step % val_interval == 0
                        and IS_MAIN_PROCESS
                    ):
                        self._validation(sample_progress)

                    self.accelerator.wait_for_everyone()
                    # ── Checkpointing ─────────────────────────────
                    if (
                        cfg.checkpoints.interval > 0
                        and is_opt_step
                        and self._global_step % cfg.checkpoints.interval == 0
                        and IS_MAIN_PROCESS
                    ):
                        self._save_checkpoint()

                    self.accelerator.wait_for_everyone()
                    # ── Step Callback ─────────────────────────────
                    # Honour the documented contract: notify the caller after each
                    # optimizer step (previously the argument was accepted but ignored).
                    if step_callback and is_opt_step:
                        step_callback(
                            self._global_step,
                            cfg.optimization.steps,
                            sampled_videos_paths,
                        )

                    # ── Progress Update ──────────────────────────
                    if IS_MAIN_PROCESS:
                        current_lr = self.optimizer.param_groups[0]["lr"]
                        elapsed = time.time() - train_start_time
                        pct = self._global_step / cfg.optimization.steps
                        total_est = elapsed / pct if pct > 0 else 0
                        eta = total_est - elapsed if pct > 0 else 0
                        step_time = time.time() - step_start

                        train_progress.update(
                            task,
                            advance=1,
                            loss=loss.item(),
                            lr=current_lr,
                            step_time=step_time,
                            total_time=f"{int(total_est // 3600)}h {int((total_est % 3600) // 60)}m",
                        )

                        if disable_progress_bars and self._global_step % 20 == 0:
                            print_log(
                                f"Step {self._global_step}/{cfg.optimization.steps} "
                                f"Loss={loss.item():.4f} LR={current_lr:.2e} "
                                f"Time/Step={step_time:.2f}s ETA={int(eta)}s",
                                level=logging.INFO,
                            )

                        # Track peak memory (sampled, not continuously monitored)
                        if raw_step % MEMORY_CHECK_INTERVAL == 0:
                            current_mem = get_gpu_memory_gb(device)
                            peak_mem = max(peak_mem, current_mem)

        # ─── 4. Finalization ──────────────────────────────────────────────────────────
        train_end = time.time()
        end_mem = get_gpu_memory_gb(device)
        peak_mem = max(peak_mem, end_mem)

        total_train = train_end - train_start_time
        # Guard against a zero-length run (e.g. zero configured steps on a coarse clock).
        steps_per_sec = (
            cfg.optimization.steps / total_train if total_train > 0 else 0.0
        )

        samples_per_sec = (
            steps_per_sec * self.accelerator.num_processes * cfg.data.batch_size
        )

        stats = TrainingStats(
            total_time_seconds=train_end - train_start_time,
            training_time=total_train,
            compilation_time_seconds=compilation_time,
            steps_per_second=steps_per_sec,
            samples_per_second=samples_per_sec,
            peak_gpu_memory_gb=peak_mem,
            num_processes=self.accelerator.num_processes,
            global_batch_size=cfg.data.batch_size
            * self.accelerator.num_processes,
        )

        # Shut down Accelerator
        self.accelerator.end_training()

        # Save final checkpoint, push to hub, and log stats (main process only)
        saved_path = None
        if IS_MAIN_PROCESS:
            saved_path = self._save_checkpoint()
            # `hub` is not a required section, so a missing section or flag means
            # "do not push" rather than an AttributeError after training finished.
            hub_cfg = cfg.get("hub", None)
            if hub_cfg and hub_cfg.get("push_to_hub", False):
                push_to_hub(saved_path, sampled_videos_paths, cfg)
            self._log_training_stats(stats)

        return saved_path, stats

    @abstractmethod
    def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor:
        raise NotImplementedError("Subclasses must implement _training_step")

    def _print_config(self) -> None:
        """
        Render the training configuration (self.cfg) as a formatted table using Rich.

        - Only the main process prints to avoid duplicated output in distributed runs.
        - Nested dicts and pydantic models are flattened with dot-separated keys.
        - Lists, tuples, and sets are rendered as comma-separated strings.
        - Rows are sorted by parameter name.
        """
        if not IS_MAIN_PROCESS:
            return

        # Build the Rich table
        table = Table(
            title="⚙️ Training Configuration",
            show_header=True,
            header_style="bold green",
        )
        table.add_column("Parameter", style="bold white", no_wrap=True)
        table.add_column("Value", style="bold cyan")

        def flatten_config(
            cfg_obj: Union[Config, ConfigDict, BaseModel, dict], prefix: str = ""
        ) -> List[Tuple[str, str]]:
            """
            Recursively flatten a config object into (key, value) string pairs.

            Args:
                cfg_obj: an mmengine ``Config``, a dict (including ``ConfigDict``),
                    or a pydantic ``BaseModel``.
                prefix (str): Dot-separated prefix for nested keys.

            Returns:
                List of tuples: (full_parameter_name, stringified_value)
            """
            items: List[Tuple[str, str]] = []
            # Normalise every supported container to a plain mapping first.
            if isinstance(cfg_obj, Config):
                raw_dict = cfg_obj.to_dict()
            elif isinstance(cfg_obj, BaseModel):
                # A BaseModel has no `.items()`; iterating it yields (field, value)
                # pairs, so `dict(...)` gives a shallow mapping of its fields.
                raw_dict = dict(cfg_obj)
            else:
                raw_dict = cfg_obj

            for key, val in raw_dict.items():
                # Keys are stringified so Rich always receives text and the rows sort.
                full_key = f"{prefix}.{key}" if prefix else str(key)

                # Nested model/dict: recurse
                if isinstance(val, (BaseModel, dict)):
                    items.extend(flatten_config(val, full_key))

                # Sequence types: join into comma-separated string
                elif isinstance(val, (list, tuple, set)):
                    joined = ", ".join(str(elem) for elem in val)
                    items.append((full_key, joined))

                # Primitive/other types: stringify directly
                else:
                    items.append((full_key, str(val)))

            return items

        # Gather and sort rows for consistent ordering
        rows = flatten_config(self.cfg)
        rows.sort(key=lambda row: row[0])

        # Add each parameter/value pair to the table
        for param_name, param_value in rows:
            table.add_row(param_name, param_value)

        # Render to stdout
        rich.print(table)

    def _load_models(self) -> None:
        """Load model components according to config file.
        Model config example:
        model=dict(
            vae=dict(
                type="xxxxxxxs",
                from_pretrained=dict(
                    pretrained_model_name_or_path='xxxx',
                    subfolder='xxxx'
                )
            ),
            transformer=dict(
                type="xxxxxxx",
                from_pretrained=dict(
                    pretrained_model_name_or_path='xxxx',
                    subfolder='xxxx'
                )
            )
        )
        Then the trainer will automaticlly initialize and set self.vae and self.transformer to the initialized modules.
        """
        if not hasattr(self.cfg, "model") or not isinstance(self.cfg.model, dict):
            raise AttributeError(
                "`cfg` must have a `model` dict for _load_models() to work."
            )
        self.module_names = []
        model_cfg = self.cfg.model
        for name, module_config in model_cfg.items():
            self.module_names.append(name)
            module = HF_MODELS.build(module_config)
            # Attach it to self under the given name
            setattr(self, name, module)

            # Optional: log what you’ve loaded
            print_log(f"Loaded {name!r} -> {module.__class__.__name__}")
        # setup lora
        self._setup_lora()

    @abstractmethod
    def _prepare_models_for_training(self) -> None:
        """Prepare models for training with Accelerate."""
        raise NotImplementedError(
            "Subclasses must implement _prepare_models_for_training"
        )

    def _collect_trainable_params(self) -> None:
        """
        Gather every parameter with ``requires_grad=True`` from the trainer's modules.

        For each name in ``self.module_names`` whose attribute is an ``nn.Module``,
        the trainable parameters are appended (in module order) to one flat list that
        is stored in ``self._trainable_params``. Attributes that are missing or are
        not ``nn.Module`` instances are ignored.

        Logging:
          - a WARNING when no trainable parameter exists at all (and nothing else);
          - otherwise a DEBUG line per module with its count and share of the total,
            followed by a DEBUG line with the overall count.
        """
        counts_by_module = {}
        flat_params = []

        for module_name in self.module_names:
            candidate = getattr(self, module_name, None)
            if not isinstance(candidate, nn.Module):
                continue
            # Frozen weights are skipped; only these parameters are meant to be optimized.
            trainable = [
                param for param in candidate.parameters() if param.requires_grad
            ]
            counts_by_module[module_name] = sum(param.numel() for param in trainable)
            flat_params.extend(trainable)

        self._trainable_params = flat_params

        grand_total = sum(counts_by_module.values())
        if grand_total == 0:
            print_log(
                "Warning: No trainable parameters found in any listed module.",
                level=logging.WARNING,
            )
            return

        # Per-module breakdown first, overall figure last.
        for module_name, n_params in counts_by_module.items():
            share = 100.0 * n_params / grand_total
            print_log(
                f"Module '{module_name}': {n_params:,} trainable params "
                f"({share:.2f}% of total {grand_total:,})",
                level=logging.DEBUG,
            )

        print_log(
            f"Total trainable params count: {grand_total:,}", level=logging.DEBUG
        )

    def _setup_lora(self) -> None:
        """
        Configure and attach LoRA adapters to the specified submodules of the transformer.

        This method is only invoked when `cfg.model.training_mode == "lora"`.
        It reads the `lora` section of the configuration to determine:
        - Which modules to target (e.g., `"transformer.layers.5.self_attn"`).
        - The LoRA hyperparameters (rank, alpha, dropout, etc.).

        Steps:
        1. Retrieve and validate the `lora` config from `self.cfg`.
        2. Normalize `target_models` into a list of attribute names.
        3. Instantiate a `LoraConfig` with the given parameters.
        4. For each target module:
            a. Fetch the module attribute via `getattr`.
            b. Ensure it supports `add_adapter` (or similar) and raise otherwise.
            c. Attach the LoRA adapter.
        5. Log a summary of the adapters added.

        Raises:
            ValueError: If `target_models` is missing or not a list/str.
            AttributeError: If a specified module does not exist on `self`.
            NotImplementedError: If a target module does not expose an adapter interface.
        """
        # 1. Fetch the LoRA config dictionary (may be None if full fine-tuning)
        lora_cfg_dict = self.cfg.get("lora", None)
        if not lora_cfg_dict:
            # Nothing to do if LoRA is not configured
            return

        # 2. Extract and normalize target module names
        target_models = lora_cfg_dict.pop("target_modules", None)
        if target_models is None:
            raise ValueError(
                "`lora.target_modules` must be specified when using LoRA mode"
            )
        if isinstance(target_models, str):
            target_models = [target_models]
        if not isinstance(target_models, (list, tuple)):
            raise ValueError(
                "`lora.target_modules` must be a string or list of strings"
            )

        # 3. Instantiate the Peft LoraConfig from the dict
        #    This will validate rank, alpha, dropout, etc.
        lora_config = LoraConfig(**lora_cfg_dict)

        # 4. Attach adapters to each named submodule
        for module_name in target_models:
            # a. Retrieve submodule, e.g. self.transformer.encoder.layers[5].self_attn
            try:
                module = getattr(self, module_name)
            except AttributeError as e:
                raise AttributeError(
                    f"Module `{module_name}` not found on trainer"
                ) from e

            # b. Ensure the module supports adding a LoRA adapter
            if not hasattr(module, "add_adapter"):
                raise NotImplementedError(
                    f"Module `{module_name}` does not implement `add_adapter` for LoRA"
                )

            # c. Add the LoRA adapter
            module.add_adapter(lora_config)
            print_log(
                f"🔌 LoRA adapter added to `{module_name}` "
                f"(rank={lora_config.rank}, alpha={lora_config.alpha}, dropout={lora_config.dropout})",
                level=logging.DEBUG,
            )

    def _load_checkpoint(self) -> None:
        """
        Load pretrained model modules and optimizer states per `self.cfg.resume.load_from`.

        Resume modes:

        **A. Per-module override (dict)**
        ```python
        cfg.resume.load_from = {
            "transformer": "/path/to/transformer",
            "vae": "/path/to/vae",
            "optimizer": "/my/opt_chkpt.safetensors"
        }
        ```
        -> Uses `.from_pretrained(...)` for model parts and `torch.load(...) + load_state_dict()` for optimizer.

        **B. Auto/latest (string "auto" or "latest")**
        ```python
        cfg.resume.load_from = "auto"
        ```
        -> Reads `<output_dir>/checkpoints/latest_checkpoint`, then loads all modules and optimizer from that subdir.

        **C. Explicit directory path (string or Path)**
        ```python
        cfg.resume.load_from = "/custom/checkpoint_dir"
        ```
        -> Same as B but uses provided directory.

        Supported optimizer checkpoint extensions: `.safetensors`, `.pth`, `.pt`, `.bin`.

        Raises:
            ValueError: Unsupported `load_from` type.
            FileNotFoundError: Missing checkpoint folders or files.
            AttributeError: Referencing missing modules.
        """
        resume_cfg: Dict[str, Any] = self.cfg.get("resume", {}) or {}
        load_from = resume_cfg.get("load_from")
        if not load_from:
            return  # nothing to resume

        base_ckpts = Path(self.cfg.output_dir) / "checkpoints"

        def _load_module(name: str, ckpt_path: Path):
            module = getattr(self, name, None)
            if module is None:
                raise AttributeError(f"No module `{name}` found on trainer.")
            if hasattr(module, "from_pretrained"):
                module.from_pretrained(str(ckpt_path))
            else:
                print_log(
                    f"Skipping `{name}`: no from_pretrained()", level=logging.DEBUG
                )

        # --- Mode A: Per-module overrides ---
        if isinstance(load_from, dict):
            for name, path_str in load_from.items():
                ckpt_path = Path(path_str)
                if name == "optimizer":
                    state = torch.load(str(ckpt_path), map_location="cpu")
                    self.optimizer.load_state_dict(state)
                    print_log(f"Optimizer loaded from {ckpt_path}", level=logging.INFO)
                else:
                    _load_module(name, ckpt_path)
                    print_log(f"Loaded `{name}` from {ckpt_path}", level=logging.INFO)
            return

        # --- Determine checkpoint directory ---
        if isinstance(load_from, str) and load_from in ("auto", "latest"):
            record = base_ckpts / "latest_checkpoint"
            if record.is_file():
                latest_dir = Path(base_ckpts / record.read_text().strip())
            else:
                print_log(
                    f"Latest checkpoint record missing: {record}", level=logging.WARNING
                )
                return
        else:
            latest_dir = Path(load_from)

        print_log(f"Loading checkpoint from {latest_dir}", level=logging.INFO)

        # --- Load all modules from checkpoint directory ---
        for name in self.module_names:
            if name == "optimizer":
                opt_dir = latest_dir / "optimizer"
                if not opt_dir.is_dir():
                    raise FileNotFoundError(f"Missing optimizer folder: {opt_dir}")

                # Find first valid file among supported extensions
                opt_file = None
                for ext in ("*.safetensors", "*.pth", "*.pt", "*.bin"):
                    opt_file = next(opt_dir.glob(ext), None)
                    if opt_file:
                        break

                if opt_file is None:
                    print_log(
                        f"No optimizer checkpoint found in {opt_dir}, skip loading optimizer checkpoint",
                        level=logging.WARNING,
                    )

                state = torch.load(str(opt_file), map_location="cpu")
                self.optimizer.load_state_dict(state)
                print_log(f"Optimizer loaded from {opt_file}", level=logging.INFO)

            else:
                ckpt_path = latest_dir / name
                if not ckpt_path.exists():
                    print_log(
                        f"No checkpoint for `{name}` at {ckpt_path}`, skipping.",
                        level=logging.WARNING,
                    )
                    continue
                _load_module(name, ckpt_path)
                print_log(f"Loaded `{name}` from {ckpt_path}", level=logging.INFO)

    @abstractmethod
    def _prepare_models_for_training(self) -> None:
        """Set the training mode for each module and do necessary device placement."""
        raise NotImplementedError("_prepare_models_for_training must be implemented")

    @staticmethod
    def _find_checkpoint(checkpoint_path: str | Path) -> Path | None:
        """Find the checkpoint file to load, handling both file and directory paths."""
        checkpoint_path = Path(checkpoint_path)

        if checkpoint_path.is_file():
            if not checkpoint_path.suffix == ".safetensors":
                raise ValueError(
                    f"Checkpoint file must have a .safetensors extension: {checkpoint_path}"
                )
            return checkpoint_path

        if checkpoint_path.is_dir():
            # Look for checkpoint files in the directory
            checkpoints = list(checkpoint_path.rglob("*step_*.safetensors"))

            if not checkpoints:
                return None

            # Sort by step number and return the latest
            def _get_step_num(p: Path) -> int:
                try:
                    return int(p.stem.split("step_")[1])
                except (IndexError, ValueError):
                    return -1

            latest = max(checkpoints, key=_get_step_num)
            return latest

        else:
            raise ValueError(
                f"Invalid checkpoint path: {checkpoint_path}. Must be a file or directory."
            )

    def _init_dataloader(
        self,
    ) -> DataLoader:
        """Build the dataset, sampler(s) and dataloader described by ``self.cfg.data``.

        Adapted from ``mmengine.FlexibleRunner``. A deep copy of ``self.cfg.data``
        is consumed key by key; every key that is left after the special ones
        below have been popped is forwarded verbatim to ``DataLoader``.

        Special keys:

        - ``dataset``: config built through ``DATASETS``.
        - ``sampler``: a dict is built through ``DATA_SAMPLERS`` (with ``dataset``
          and ``seed`` as default args); any other value is used as-is.
        - ``batch_sampler``: optional dict built through ``DATA_SAMPLERS``. When
          present it consumes ``batch_size`` and is passed to ``DataLoader``
          instead of ``sampler``.
        - ``worker_init_fn``: dict whose ``type`` is looked up in ``FUNCTIONS``;
          the remaining keys are bound as keyword arguments.
        - ``disable_subprocess_warning``: bool handed to the default worker init
          function; only used when ``self.cfg.seed`` is set and no custom
          ``worker_init_fn`` is configured.
        - ``collate_fn``: dict (``type`` plus kwargs) or callable; defaults to
          ``dict(type="pseudo_collate")``.

        An example of ``self.cfg.data``::

            dataloader = dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=9
            )

        Returns:
            DataLoader: the dataloader built from ``self.cfg.data``; it is also
            stored on ``self.data_loader``.

        Raises:
            TypeError: if ``collate_fn`` is neither a dict nor a callable.
        """
        seed = self.cfg.seed
        # Work on a copy so the pops below never alter the user's config.
        dataloader_cfg = copy.deepcopy(self.cfg.data)

        # build dataset
        dataset_cfg = dataloader_cfg.pop("dataset")
        dataset = DATASETS.build(dataset_cfg)

        # build sampler
        sampler_cfg = dataloader_cfg.pop("sampler", None)
        if isinstance(sampler_cfg, dict):
            sampler = DATA_SAMPLERS.build(
                sampler_cfg, default_args=dict(dataset=dataset, seed=seed)
            )
        else:
            sampler = sampler_cfg

        # build batch sampler
        batch_sampler_cfg = dataloader_cfg.pop("batch_sampler", None)
        if batch_sampler_cfg is None:
            batch_sampler = None
        else:
            assert isinstance(batch_sampler_cfg, dict)
            batch_sampler = DATA_SAMPLERS.build(
                batch_sampler_cfg,
                default_args=dict(
                    sampler=sampler, batch_size=dataloader_cfg.pop("batch_size")
                ),
            )

        # `disable_subprocess_warning` is not a DataLoader argument, so it is
        # always removed here; otherwise it would reach `DataLoader(**cfg)`
        # whenever the default seeded worker_init_fn is not the one in use.
        disable_subprocess_warning = dataloader_cfg.pop(
            "disable_subprocess_warning", False
        )

        # build worker_init_fn
        init_fn: Optional[partial]
        if "worker_init_fn" in dataloader_cfg:
            worker_init_fn_cfg = dataloader_cfg.pop("worker_init_fn")
            worker_init_fn_type = worker_init_fn_cfg.pop("type")
            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
            assert callable(worker_init_fn)
            init_fn = partial(worker_init_fn, **worker_init_fn_cfg)  # type: ignore
        elif seed is not None:
            assert isinstance(disable_subprocess_warning, bool), (
                "disable_subprocess_warning should be a bool, but got "
                f"{type(disable_subprocess_warning)}"
            )
            init_fn = partial(
                default_worker_init_fn,
                num_workers=dataloader_cfg.get("num_workers"),
                rank=get_rank(),
                seed=seed,
                disable_subprocess_warning=disable_subprocess_warning,
            )
        else:
            init_fn = None

        # The default behavior of `collate_fn` in dataloader is to
        # merge a list of samples to form a mini-batch of Tensor(s).
        # However, in mmengine, if `collate_fn` is not defined in
        # dataloader_cfg, `pseudo_collate` will only convert the list of
        # samples into a dict without stacking the batch tensor.
        collate_fn_cfg = dataloader_cfg.pop("collate_fn", dict(type="pseudo_collate"))
        if isinstance(collate_fn_cfg, dict):
            collate_fn_type = collate_fn_cfg.pop("type")
            if isinstance(collate_fn_type, str):
                collate_fn = FUNCTIONS.get(collate_fn_type)
            else:
                collate_fn = collate_fn_type
            collate_fn = partial(collate_fn, **collate_fn_cfg)  # type: ignore
        elif callable(collate_fn_cfg):
            collate_fn = collate_fn_cfg
        else:
            raise TypeError(
                "collate_fn should be a dict or callable object, but got "
                f"{collate_fn_cfg}"
            )

        # `sampler` and `batch_sampler` are mutually exclusive in DataLoader.
        self.data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg,
        )
        # Honour the declared return type; existing callers that ignore the
        # return value and read `self.data_loader` are unaffected.
        return self.data_loader

    def _init_optimizer(self) -> None:
        """
        Initialize the optimizer and its learning rate scheduler according to config.

        Reads from `self.cfg.optimization`:
          - optimizer_type: str, one of ["adamw", "adamw8bit"]
          - learning_rate: float

        Raises:
            ValueError: if `optimizer_type` is unrecognized.
        """
        opt_cfg = self.cfg.optimization
        lr = opt_cfg.learning_rate
        opt_type = opt_cfg.optimizer_type.lower()

        # --- 1. Instantiate the optimizer ---
        if opt_type == "adamw":
            self.optimizer: Optimizer = AdamW(self._trainable_params, lr=lr)
        elif opt_type == "adamw8bit":
            # 8‑bit AdamW from bitsandbytes
            from bitsandbytes.optim import AdamW8bit  # type: ignore

            self.optimizer = AdamW8bit(self._trainable_params, lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer type: '{opt_cfg.optimizer_type}'")

        # --- 2. Create the LR scheduler (may be None) ---
        self.lr_scheduler: Optional[LRScheduler] = self._create_scheduler(
            self.optimizer
        )

    def _create_scheduler(self, optimizer: Optimizer) -> Optional[LRScheduler]:
        """
        Instantiate and return an LR scheduler according to configuration.

        The configuration is read from `self.cfg.optimization`:
          - scheduler_type: str | None
          - steps: int
          - scheduler_params: dict[str, Any] | None

        Supported scheduler types:
          * "linear"               → LinearLR
          * "cosine"               → CosineAnnealingLR
          * "cosine_with_restarts" → CosineAnnealingWarmRestarts
          * "polynomial"           → PolynomialLR
          * "step"                 → StepLR
          * "constant"             → No scheduler (returns None)

        Returns:
            An instance of a torch LR scheduler, or None if no scheduler.

        Raises:
            ValueError: if `scheduler_type` is unrecognized.
        """
        cfg = self.cfg.optimization
        scheduler_type: Optional[str] = cfg.scheduler_type
        total_steps: int = cfg.steps
        params: dict = cfg.scheduler_params or {}

        # If no scheduler requested, return None immediately
        if scheduler_type is None or scheduler_type.lower() == "constant":
            return None

        scheduler_type = scheduler_type.lower()

        if scheduler_type == "linear":
            # Linear decay from start_factor→end_factor over total_iters
            return LinearLR(
                optimizer,
                start_factor=params.pop("start_factor", 1.0),
                end_factor=params.pop("end_factor", 0.1),
                total_iters=total_steps,
                **params,
            )

        if scheduler_type == "cosine":
            # Single-cycle cosine annealing
            return CosineAnnealingLR(
                optimizer,
                T_max=total_steps,
                eta_min=params.pop("eta_min", 0),
                **params,
            )

        if scheduler_type == "cosine_with_restarts":
            # Cosine schedule with periodic restarts
            return CosineAnnealingWarmRestarts(
                optimizer,
                T_0=params.pop("T_0", max(1, total_steps // 4)),
                T_mult=params.pop("T_mult", 1),
                eta_min=params.pop("eta_min", 5e-5),
                **params,
            )

        if scheduler_type == "polynomial":
            # Polynomial decay to zero (or other factor) over total_iters
            return PolynomialLR(
                optimizer,
                total_iters=total_steps,
                power=params.pop("power", 1.0),
                **params,
            )

        if scheduler_type == "step":
            # Decay the LR by `gamma` every `step_size` steps
            return StepLR(
                optimizer,
                step_size=params.pop("step_size", max(1, total_steps // 2)),
                gamma=params.pop("gamma", 0.1),
                **params,
            )

        # If we get here, the scheduler_type was invalid
        raise ValueError(f"Unknown scheduler type: '{scheduler_type}'")

    def _setup_accelerator(self) -> None:
        """
        Create `self.accelerator` with at most one parallelism plugin.

        The settings are taken from a shallow copy of `self.cfg.acceleration`.
        Four optional keys select a plugin, each holding that plugin's
        keyword arguments:
        - `fsdp_plugin`        → FullyShardedDataParallelPlugin
        - `deepspeed_plugin`   → DeepSpeedPlugin
        - `torch_tp_plugin`    → TorchTensorParallelPlugin
        - `megatron_lm_plugin` → MegatronLMPlugin

        Every other key (e.g. mixed_precision, log_with, project_dir) is
        passed unchanged to the `Accelerator` constructor. When more than one
        process is in use, the process count and the local/global batch sizes
        are logged.

        Raises:
            ValueError: if more than one parallelism plugin is configured.
        """
        # Config key -> plugin class; the key doubles as the Accelerator kwarg.
        plugin_classes = (
            ("fsdp_plugin", FullyShardedDataParallelPlugin),
            ("deepspeed_plugin", DeepSpeedPlugin),
            ("torch_tp_plugin", TorchTensorParallelPlugin),
            ("megatron_lm_plugin", MegatronLMPlugin),
        )

        remaining_cfg = dict(self.cfg.acceleration)  # shallow copy
        raw_cfgs = {key: remaining_cfg.pop(key, None) for key, _ in plugin_classes}

        # An empty or missing entry counts as "not configured".
        configured = sum(1 for value in raw_cfgs.values() if value)
        if configured > 1:
            raise ValueError(
                "Only one of 'fsdp_plugin', 'deepspeed_plugin', "
                "'torch_tp_plugin', or 'megatron_lm_plugin' may be set."
            )

        plugins = {
            key: (plugin_cls(**raw_cfgs[key]) if raw_cfgs[key] else None)
            for key, plugin_cls in plugin_classes
        }

        # Whatever is left in the config goes straight to the constructor.
        self.accelerator = Accelerator(**plugins, **remaining_cfg)

        if self.accelerator.num_processes <= 1:
            return

        # num_processes covers every process on every node.
        print_log(
            f"Distributed training enabled: {self.accelerator.num_processes} processes."
        )
        local_bs = self.cfg.data.batch_size
        global_bs = local_bs * self.accelerator.num_processes
        print_log(f"Local batch size:  {local_bs}")
        print_log(f"Global batch size: {global_bs}")

    def _prepare_accelerator(self) -> None:
        """Prepare the accelerator for training. Including prepare model, optimizer, scheduler, dataloader."""
        # We collect all nn.Module from self.module_names to build a temporary ModuleDict.
        # The key is the module name, and the value is the module itself.
        # We don't separately prepare each module or other components because it is not allowed when using DeepspeedPlugin.
        # When do so using DeepspeedPlugin, a "model" is required to be passed everytime calling self.accelerator.prepare.
        module_dict = nn.ModuleDict(
            {
                name: getattr(self, name)
                for name in self.module_names
                if isinstance(getattr(self, name), nn.Module)
            }
        )
        # we don't need
        _, self.optimizer, self.lr_scheduler, self.dataloader = (
            self.accelerator.prepare(
                module_dict, self.optimizer, self.lr_scheduler, self.data_loader
            )
        )

    @torch.no_grad()
    def _validation(self, progress: Progress) -> list[Path] | None:
        """Run the validation process (placeholder in this base class).

        The method is wrapped in ``torch.no_grad()``. The base implementation
        only raises, so a subclass has to provide the real logic.

        Args:
            progress: progress object made available to the implementation.

        Returns:
            A list of paths or ``None`` according to the annotation; this base
            implementation never returns.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # The message names this method; it previously referred to
        # `_sample_videos`, which is not the method being called.
        raise NotImplementedError("_validation must be implemented")

    def _setup_pipeline(self) -> None:
        """
        Setup the pipeline for inference.

        The base implementation is a placeholder that only raises; a subclass
        is expected to override it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("_setup_pipeline must be implemented")

    @staticmethod
    def _log_training_stats(stats: TrainingStats) -> None:
        """
        Emit a multi-line summary of the finished training run via `print_log`.

        Parameters:
            stats (TrainingStats): the fields read here are
                total_time_seconds — reported in minutes (value / 60)
                training_time — reported in minutes (value / 60)
                steps_per_second — reported as steps/s
                samples_per_second
                peak_gpu_memory_gb — reported in GB
                compilation_time_seconds — line added only when not None
                num_processes — when > 1, this and global_batch_size are added
                global_batch_size
        """
        report = [
            "📊 Training Statistics:",
            f" - Total time:           {stats.total_time_seconds / 60:.1f} minutes",
            f" - Training time:        {stats.training_time / 60:.1f} minutes",
            f" - Training speed:       {stats.steps_per_second:.2f} steps/s",
            f" - Samples per second:   {stats.samples_per_second:.2f}",
            f" - Peak GPU memory:      {stats.peak_gpu_memory_gb:.2f} GB",
        ]

        # Compilation time is optional.
        compile_seconds = stats.compilation_time_seconds
        if compile_seconds is not None:
            report.append(f" - Compilation time:     {compile_seconds:.1f} seconds")

        # Distributed details only make sense with more than one process.
        if stats.num_processes > 1:
            report.append(f" - Number of processes:  {stats.num_processes}")
            report.append(f" - Global batch size:    {stats.global_batch_size}")

        print_log("\n".join(report))

    def _save_checkpoint(self) -> Path:
        """
        Save each component listed in self.module_names—and the optimizer state—
        into a single step-specific folder under output_dir/checkpoints.

        Directory layout example for step 00010 and prefix "model":

        output_dir/
        └── checkpoints/
            ├── model_weights_step_00010/
            │   ├── transformer/
            │   ├── text_encoder/
            │   ├── tokenizer/
            │   ├── ...            (one subfolder per entry of module_names)
            │   └── optimizer/
            │       └── model_weights_step_00010.safetensors
            └── latest_checkpoint   # text file with the newest step folder name

        - We create ONE folder per step (`model_weights_step_xxxxx`, or
          `lora_weights_step_xxxxx` when `self.cfg` has a truthy `lora` entry).
        - Inside it, each module/component gets its own subfolder; what ends up
          in that subfolder is decided by the component's own `save_pretrained`.
        - `save_pretrained` is used when available; otherwise we fall back to
          `.config.save_pretrained`; components offering neither are skipped
          with a warning.
        - The optimizer state lives in the same step folder under `optimizer/`.
        - The step folder is appended to `self._checkpoint_paths`, the
          `latest_checkpoint` file is rewritten, and old checkpoints are pruned
          via `self._cleanup_checkpoints()`.

        Returns:
            Path: the step directory that was just written.
        """
        # 1) Prepare base checkpoint directory
        base_ckpt_dir = Path(self.cfg.output_dir) / "checkpoints"
        base_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # 2) Determine prefix ("model" vs "lora") and step folder
        prefix = "lora" if self.cfg.get("lora", None) else "model"
        step_str = f"{self._global_step:05d}"
        step_dir = base_ckpt_dir / f"{prefix}_weights_step_{step_str}"
        step_dir.mkdir(exist_ok=True)

        # 3) Announce the save.
        # NOTE(review): no cross-process synchronisation happens in this
        # method; presumably the caller is responsible for it — confirm.
        print_log(f"💾 Saving {prefix} weights step {step_str} to {step_dir}")

        # 4) Save each module/component into its own subfolder
        for name in self.module_names:
            module = getattr(self, name, None)
            if module is None:
                # Skip if module not present on self
                continue

            subfolder = step_dir / name
            subfolder.mkdir(exist_ok=True)
            print_log(f"Saving {name} to {subfolder}", level=logging.INFO)
            # 4a) Unwrap model if wrapped by Accelerator (DDP/DeepSpeed/FSDP/etc.)
            if isinstance(module, torch.nn.Module) and hasattr(
                self.accelerator, "unwrap_model"
            ):
                module_to_save = self.accelerator.unwrap_model(module)
            else:
                module_to_save = module

            # 4b) If the module supports save_pretrained, use it
            if hasattr(module_to_save, "save_pretrained"):
                module_to_save.save_pretrained(str(subfolder))

            # 4c) Else if only its .config can be saved, save that
            elif hasattr(module_to_save, "config") and hasattr(
                module_to_save.config, "save_pretrained"
            ):
                module_to_save.config.save_pretrained(str(subfolder))

            # 4d) Otherwise, warn and skip
            else:
                print_log(
                    f"Module {name} has neither save_pretrained nor .config.save_pretrained. Skipping.",
                    level=logging.WARNING,
                )
                continue
            print_log(f"{name} has been saved to {subfolder}", level=logging.INFO)

        # 5) Save optimizer state alongside the modules
        opt_folder = step_dir / "optimizer"
        opt_folder.mkdir(exist_ok=True)
        opt_filename = f"{prefix}_weights_step_{step_str}.safetensors"
        opt_path = opt_folder / opt_filename

        # NOTE(review): `Accelerator.save` is called without
        # `safe_serialization`, so despite the `.safetensors` suffix this file
        # may be a regular torch pickle — confirm against the loading code.
        self.accelerator.save(self.optimizer.state_dict(), opt_path)
        print_log(
            f"💾 Optimizer state saved in {opt_path.relative_to(self.cfg.output_dir)}"
        )
        # 6) Register this checkpoint path
        self._checkpoint_paths.append(step_dir)

        # 7) Update the 'latest_checkpoint' log file
        latest_file = base_ckpt_dir / "latest_checkpoint"
        latest_file.write_text(step_dir.name)

        # 8) Prune checkpoints beyond the configured retention
        self._cleanup_checkpoints()

        # Honour the declared `-> Path` contract; the method used to fall off
        # the end and return None.
        return step_dir

    def _cleanup_checkpoints(self) -> None:
        """
        Remove checkpoint folders or files older than the most recent N to keep.

        Reads self._checkpoint_paths (a list of Path objects, ordered from oldest to newest)
        and deletes any entries beyond the last `keep_last_n`. After removal, updates
        self._checkpoint_paths to contain only the kept checkpoints.

        Behavior:
        - If keep_last_n <= 0, nothing is deleted.
        - If there are no excess checkpoints, does nothing.
        - Supports both files and directories (using unlink() or shutil.rmtree()).
        """
        keep_n = self.cfg.checkpoints.keep_last_n

        # Nothing to do if we aren't keeping a positive number of checkpoints
        if keep_n <= 0:
            return

        total = len(self._checkpoint_paths)
        # Determine how many old checkpoints to remove
        num_to_remove = total - keep_n
        if num_to_remove <= 0:
            # We have equal or fewer checkpoints than we need to keep
            return

        # Split list into to-remove and to-keep slices
        old_ckpts = self._checkpoint_paths[:num_to_remove]
        remaining_ckpts = self._checkpoint_paths[num_to_remove:]

        for ckpt_path in old_ckpts:
            if not ckpt_path.exists():
                continue

            try:
                if ckpt_path.is_dir():
                    # Remove entire directory tree
                    shutil.rmtree(ckpt_path)
                else:
                    # Remove single file
                    ckpt_path.unlink()
                print_log(f"Removed old checkpoint: {ckpt_path}", level=logging.DEBUG)
            except Exception as e:
                print_log(f"Failed to remove {ckpt_path}: {e}", level=logging.WARNING)

        # Keep only the most recent `keep_n` checkpoints
        self._checkpoint_paths = remaining_ckpts
