import os
import sys
import time
import warnings
import json
from collections import defaultdict
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Optional, Dict, Union
from functools import partial

# Silence every Python warning for the whole process.
# NOTE(review): this runs before the third-party imports below, presumably so that
# their import-time warnings are hidden as well — confirm this is intentional,
# because it also hides deprecation warnings raised by this script's own code.
warnings.filterwarnings("ignore")

import deepspeed
import torch
import torchvision
import torch.distributed as dist
from deepspeed.runtime import lr_schedules
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from einops import rearrange

from vm.config import parse_args
from vm.constants import C_SCALE, PROMPT_TEMPLATE
from vm.dataset.video_loader import VideoDataset
from vm.diffusion import load_denoiser
from vm.ds_config import get_deepspeed_config
from vm.utils.train_utils import (
    prepare_model_inputs,
    load_state_dict,
    set_worker_seed_builder,
    get_module_kohya_state_dict,
    load_lora,
)
from vm.modules import load_model
from vm.text_encoder import TextEncoder
from vm.utils.file_utils import (
    safe_dir,
    get_experiment_max_number,
    empty_logger,
    dump_args,
    dump_codes,
    resolve_resume_path,
    logger_filter,
)
from vm.utils.helpers import (
    as_tuple,
    set_manual_seed,
    set_reproducibility,
    profiler_context,
    all_gather_sum,
    EventsMonitor,
)
from vm.vae import load_vae
from vm.constants import PRECISION_TO_TYPE

from peft import LoraConfig, get_peft_model
from safetensors.torch import save_file


def setup_distributed_training(args):
    """Initialize the distributed backend and derive the batch-size settings.

    Args:
        args: Parsed training arguments. Reads ``data_type``,
            ``micro_batch_size``, ``video_micro_batch_size``,
            ``gradient_accumulation_steps``, ``global_batch_size``,
            ``global_seed`` and ``reproduce``. ``args.video_micro_batch_size``
            is filled in place from ``micro_batch_size`` when ``data_type`` is
            exactly ``"video"`` and no video micro batch size was given.

    Returns:
        A 7-tuple ``(rank, device, world_size, micro_batch_size,
        video_micro_batch_size, grad_accu_steps, global_batch_size)``.

    Raises:
        AssertionError: If an explicit global batch size differs from
            ``micro_batch_size * world_size * grad_accu_steps``.
    """
    deepspeed.init_distributed()

    # Treat micro/global batch size as tuples for compatibility with mix-scale training.
    world_size = dist.get_world_size()
    if args.data_type == "video" and args.video_micro_batch_size is None:
        # When data_type is video and video_micro_batch_size is None, fall back to micro_batch_size.
        args.video_micro_batch_size = args.micro_batch_size

    micro_batch_size = as_tuple(args.micro_batch_size)
    video_micro_batch_size = as_tuple(args.video_micro_batch_size)
    grad_accu_steps = args.gradient_accumulation_steps
    global_batch_size = as_tuple(args.global_batch_size)
    if "video" in args.data_type:
        refer_micro_batch_size = video_micro_batch_size
    else:
        refer_micro_batch_size = micro_batch_size

    # Note: Model/Pipeline parallel is not supported yet. So, data-parallel-size equals world-size.
    expected_global_batch_size = tuple(
        mbs_i * world_size * grad_accu_steps for mbs_i in refer_micro_batch_size
    )
    if global_batch_size[0] is None:
        global_batch_size = expected_global_batch_size
    else:
        # Compare tuple against tuple: a tuple never equals a list in Python, so
        # comparing against a list comprehension would reject every explicit value.
        assert tuple(global_batch_size) == expected_global_batch_size, (
            f"Global batch size should equal micro_batch_size * world_size * grad_accu_steps "
            f"({expected_global_batch_size}), but got {global_batch_size} "
            f"with world_size={world_size} and grad_accu_steps={grad_accu_steps}."
        )

    rank = dist.get_rank()  # Rank of the current process in the cluster.
    # Device index of the current process within its node.
    device = rank % torch.cuda.device_count()
    # Set current device for the current process, otherwise dist.barrier() will occupy more memory in rank 0.
    torch.cuda.set_device(device)

    # Setup seed for reproducibility or performance.
    set_manual_seed(args.global_seed)
    set_reproducibility(args.reproduce, args.global_seed)

    return (
        rank,
        device,
        world_size,
        micro_batch_size,
        video_micro_batch_size,
        grad_accu_steps,
        global_batch_size,
    )


def setup_experiment_directory(args, rank):
    """Pick the next numbered experiment directory and build the loggers.

    Args:
        args: Parsed training arguments; reads ``output_dir``, ``model`` and
            ``task_flag``.
        rank: Rank of the calling process. Only rank 0 attaches file sinks and
            creates the checkpoint directory; other ranks get ``empty_logger()``.

    Returns:
        ``(experiment_dir, ckpt_dir, train_logger, val_logger)``.
    """
    root_dir = safe_dir(args.output_dir)

    # The experiment number is one past the highest number already present.
    next_index = get_experiment_max_number(list(root_dir.glob("*"))) + 1
    # Drop '/' so the model name cannot introduce a sub-directory.
    sanitized_model = args.model.replace("/", "").replace("-", "_")
    experiment_dir = root_dir / f"{next_index:04d}_{sanitized_model}_{args.task_flag}"
    ckpt_dir = experiment_dir / "checkpoints"

    # Make sure all processes have computed the same experiment directory
    # before rank 0 starts creating files inside it.
    dist.barrier()

    if rank != 0:
        # Train and val share the very same logger object on non-zero ranks.
        train_logger = empty_logger()
        val_logger = train_logger
    else:
        from loguru import logger as root_logger

        sink_options = dict(
            level="DEBUG",
            colorize=False,
            backtrace=True,
            diagnose=True,
            encoding="utf-8",
        )
        # One log file per channel, each filtered by its channel name.
        for channel in ("train", "val"):
            root_logger.add(
                experiment_dir / f"{channel}.log",
                filter=logger_filter(channel),
                **sink_options,
            )
        train_logger = root_logger.bind(name="train")
        val_logger = root_logger.bind(name="val")

        ckpt_dir = safe_dir(ckpt_dir)

    train_logger.info(f"Experiment directory created at: {experiment_dir}")

    return experiment_dir, ckpt_dir, train_logger, val_logger


def get_trainable_params(model, args):
    """Return the parameters of ``model`` that have gradients enabled.

    Args:
        model: Object exposing ``parameters()`` whose items carry a
            ``requires_grad`` flag.
        args: Parsed arguments; only ``training_parts`` is read.

    Returns:
        A list of the parameters whose ``requires_grad`` is truthy, in the
        order ``model.parameters()`` yields them.

    Raises:
        ValueError: If ``args.training_parts`` is anything other than ``None``;
            no partial-training mode is implemented here.
    """
    if args.training_parts is not None:
        raise ValueError(f"Unknown training_parts {args.training_parts}")
    return [param for param in model.parameters() if param.requires_grad]


@dataclass
class ScalarStates:
    """Integer counters describing overall training progress.

    Every field is a running total that is advanced with :meth:`add`.
    """

    # Rank id of the process owning this state.
    rank: int = 0
    # Accumulated training epochs.
    epoch: int = 1
    # Accumulated training steps in the current epoch.
    epoch_train_steps: int = 0
    # Accumulated update steps in the current epoch.
    epoch_update_steps: int = 0
    # Accumulated training steps.
    train_steps: int = 0
    # Accumulated update steps.
    update_steps: int = 0
    # Update steps performed in the current run.
    current_run_update_steps: int = 0
    # Accumulated consumed samples.
    consumed_samples_total: int = 0
    # Accumulated consumed video samples.
    consumed_video_samples_total: int = 0
    # Accumulated consumed samples per data-parallel group.
    consumed_samples_per_dp: int = 0
    # Accumulated consumed video samples per data-parallel group.
    consumed_video_samples_per_dp: int = 0
    # Accumulated consumed tokens.
    consumed_tokens_total: int = 0
    # Accumulated consumed computations of attention + mlp.
    consumed_computations_attn: int = 0
    # Accumulated consumed computations of total.
    consumed_computations_total: int = 0

    def add(self, **kwargs):
        """Increase each named field by the given amount.

        Naming a field that does not exist raises ``AttributeError``.
        """
        for field_name in kwargs:
            increment = kwargs[field_name]
            setattr(self, field_name, getattr(self, field_name) + increment)


@dataclass
class CycleStates:
    """Running statistics accumulated between two logging points.

    Scalar fields are advanced with :meth:`add`; the two dict fields are
    ``defaultdict`` instances that callers update by key. :meth:`reset`
    returns everything to its initial value.
    """

    log_steps: int = 0
    running_loss: float = 0
    running_tokens: int = 0
    running_samples: int = 0
    running_video_samples: int = 0
    running_grad_norm: float = 0
    running_loss_dict: Dict[int, float] = field(
        default_factory=lambda: defaultdict(float)
    )
    log_steps_dict: Dict[int, int] = field(default_factory=lambda: defaultdict(int))

    def add(self, **kwargs):
        """Increase each named field by the given amount."""
        for field_name in kwargs:
            increment = kwargs[field_name]
            setattr(self, field_name, getattr(self, field_name) + increment)

    def reset(self):
        """Zero the scalar counters and replace both dicts with fresh ones."""
        scalar_fields = (
            "log_steps",
            "running_loss",
            "running_tokens",
            "running_samples",
            "running_video_samples",
            "running_grad_norm",
        )
        for field_name in scalar_fields:
            setattr(self, field_name, 0)
        # The loss dict must default to float to avoid an all_reduce type error.
        self.running_loss_dict = defaultdict(float)
        self.log_steps_dict = defaultdict(int)


def save_checkpoint(
    args,
    rank: int,
    logger,
    model_engine: DeepSpeedEngine,
    ema,
    scalar_state: ScalarStates,
    ckpt_dir: Path,
):
    """Save a checkpoint tagged with the current update step.

    Args:
        args: Training arguments, stored in the checkpoint's client state.
        rank: Rank of the calling process (currently not used).
        logger: Logger used to report success or failure.
        model_engine: Engine whose ``save_checkpoint`` performs the save.
        ema: Optional EMA object; when given, its ``state_dict()`` and
            ``config`` are stored in the client state.
        scalar_state: Progress counters of this rank.
        ckpt_dir: Directory the checkpoint is written under.

    Returns:
        A one-element list holding the checkpoint path, or ``None`` in its
        place when saving raised an exception (the error is logged, not
        re-raised).
    """
    # Collect the counters of every rank and index the snapshots by rank id.
    per_rank_snapshots = [None] * dist.get_world_size()
    dist.all_gather_object(per_rank_snapshots, asdict(scalar_state))
    states_by_rank = {
        snapshot["rank"]: snapshot for snapshot in per_rank_snapshots
    }

    client_state = {"args": args, "scalar_state": states_by_rank}
    if ema is not None:
        client_state.update(ema=ema.state_dict(), ema_config=ema.config)

    # The tag is the zero-padded update step, e.g. "0001000".
    tag = f"{scalar_state.update_steps:07d}"
    checkpoint_path = ckpt_dir / tag
    try:
        model_engine.save_checkpoint(
            str(ckpt_dir),
            client_state=client_state,
            tag=tag,
        )
        logger.info(f"Saved checkpoint to {checkpoint_path}")
    except Exception as exc:
        # Best effort: a failed save is reported but does not stop training.
        logger.error(f"Saved failed to {checkpoint_path}. {type(exc)}: {exc}")
        checkpoint_path = None

    return [checkpoint_path]


def main(args):
    """Run a complete training job.

    The function sets up distributed training and the experiment directory,
    builds the main model (optionally wrapped with LoRA adapters), the
    DeepSpeed engine, the denoiser, the VAE and the text encoder(s), creates
    the video data loader and then runs the training loop with periodic
    logging and checkpointing.

    Args:
        args: Parsed training arguments, as produced by
            ``parse_args(mode="train")``.
    """
    # ============================= Setup ==============================
    # Setup distributed training environment and reproducibility.
    (
        rank,
        device,
        world_size,
        micro_batch_size,
        video_micro_batch_size,
        grad_accu_steps,
        global_batch_size,
    ) = setup_distributed_training(args)
    # Setup experiment directory
    exp_dir, ckpt_dir, logger, val_logger = setup_experiment_directory(args, rank)
    # Load deepspeed config
    deepspeed_config = get_deepspeed_config(
        args,
        video_micro_batch_size[0],
        global_batch_size[0],
        args.output_dir,
        exp_dir.name,
    )
    # Log and dump the arguments and codes.
    logger.info(sys.argv)
    logger.info(str(args))
    if rank == 0:
        # Dump the arguments to a file.
        extra_args = {"world_size": world_size, "global_batch_size": global_batch_size}
        dump_args(args, exp_dir / "args.json", extra_args)
        # Dump codes to the experiment directory.
        dump_codes(
            exp_dir / "codes.tar.gz",
            root=Path(__file__).parent.parent,
            sub_dirs=["hymm", "jobs"],
            save_prefix=args.task_flag,
        )

    # =========================== Build main model ===========================
    logger.info("Building model...")
    factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
    # The image-to-video conditioning type decides the input channel count and
    # the interleave factor handed to the text encoder.
    if args.i2v_mode and args.i2v_condition_type == "latent_concat":
        in_channels = args.latent_channels * 2 + 1
        image_embed_interleave = 2
    elif args.i2v_mode and args.i2v_condition_type == "token_replace":
        in_channels = args.latent_channels
        image_embed_interleave = 4
    else:
        in_channels = args.latent_channels
        image_embed_interleave = 1
    out_channels = args.latent_channels

    if args.embedded_cfg_scale:
        factor_kwargs["guidance_embed"] = True

    model = load_model(
        args,
        in_channels=in_channels,
        out_channels=out_channels,
        factor_kwargs=factor_kwargs,
    )
    model = load_state_dict(args, model, logger)

    if args.use_lora:
        # Freeze the base weights; only the LoRA adapters added below stay trainable.
        for param in model.parameters():
            param.requires_grad_(False)

        target_modules = [
            "linear",
            "fc1",
            "fc2",
            "img_attn_qkv",
            "img_attn_proj",
            "txt_attn_qkv",
            "txt_attn_proj",
            "linear1",
            "linear2",
        ]

        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)

        if args.lora_path != "":
            model = load_lora(model, args.lora_path, device=device)

    logger.info(model)

    if args.reproduce:
        model.enable_deterministic()

    # After model initialization, we set different seed for each process.
    if args.same_data_batch:
        set_manual_seed(args.global_seed)
    else:
        set_manual_seed(args.global_seed + rank)

    ema = None
    ss = ScalarStates(rank=rank)

    # ========================== Initialize model_engine, optimizer =========================
    if args.warmup_num_steps > 0:
        logger.info(
            f"Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_max_lr={args.lr}, "
            f"warmup_num_steps={args.warmup_num_steps}."
        )
        # A partial is passed so that DeepSpeed can build the scheduler itself.
        lr_scheduler = partial(
            lr_schedules.WarmupLR,
            warmup_min_lr=args.warmup_min_lr,
            warmup_max_lr=args.lr,
            warmup_num_steps=args.warmup_num_steps,
        )
    else:
        lr_scheduler = None

    logger.info("Initializing optimizer (using deepspeed)...")
    model_engine, opt, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=get_trainable_params(model, args),
        config_params=deepspeed_config,
        lr_scheduler=lr_scheduler,
    )

    # ====================== Build denoise scheduler ========================
    logger.info("Building denoise scheduler...")
    denoiser = load_denoiser(args)

    # ============================= Build extra models =========================
    # 2d/3d VAE
    vae, vae_path, s_ratio, t_ratio = load_vae(
        args.vae, args.vae_precision, logger=logger, device=device
    )

    # Text encoder
    # max_length is extended by the template's "crop_start"; the video template
    # takes precedence over the image template when both are configured.
    text_encoder = TextEncoder(
        text_encoder_type=args.text_encoder,
        max_length=args.text_len
        + (
            PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
            if args.prompt_template_video is not None
            else PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
            if args.prompt_template is not None
            else 0
        ),
        text_encoder_precision=args.text_encoder_precision,
        tokenizer_type=args.tokenizer,
        i2v_mode=args.i2v_mode,
        prompt_template=(
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        ),
        prompt_template_video=(
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        ),
        hidden_state_skip_layer=args.hidden_state_skip_layer,
        apply_final_norm=args.apply_final_norm,
        reproduce=args.reproduce,
        logger=logger,
        device=device,
        image_embed_interleave=image_embed_interleave,
    )
    if args.text_encoder_2 is not None:
        text_encoder_2 = TextEncoder(
            text_encoder_type=args.text_encoder_2,
            max_length=args.text_len_2,
            text_encoder_precision=args.text_encoder_precision_2,
            tokenizer_type=args.tokenizer_2,
            reproduce=args.reproduce,
            logger=logger,
            device=device,
        )
    else:
        text_encoder_2 = None

    # ================== Define dtype and forward autocast ===============
    # Autocast follows the precision mode the DeepSpeed engine was configured with.
    target_dtype = None
    autocast_enabled = False
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
        autocast_enabled = True
    elif model_engine.fp16_enabled():
        target_dtype = torch.half
        autocast_enabled = True

    # ============================== Load dataset ==============================
    if "video" in args.data_type:
        video_dataset = VideoDataset(
            data_jsons_path=args.data_jsons_path,
            sample_n_frames=args.sample_n_frames,
            sample_stride=args.sample_stride,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            uncond_p=args.uncond_p,
            args=args,
            logger=logger,
        )
        video_sampler = DistributedSampler(
            video_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.global_seed,
            drop_last=False,
        )
        video_loader = DataLoader(
            video_dataset,
            batch_size=video_micro_batch_size[0],
            shuffle=False,
            sampler=video_sampler,
            batch_sampler=None,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
            prefetch_factor=None if args.num_workers == 0 else args.prefetch_factor,
            worker_init_fn=set_worker_seed_builder(rank),
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes, so only enable it with num_workers > 0.
            persistent_workers=args.num_workers > 0,
        )
        num_video_samples = len(video_dataset)
    else:
        video_dataset = None
        video_loader = None
        num_video_samples = 0

    loader = video_loader

    # ============================= Print key info =============================
    print(f"[{rank}] Worker ready.")
    dist.barrier()

    # Micro-batches per epoch, and how many optimizer updates they amount to.
    try:
        iters_per_epoch = len(loader)
    except (NotImplementedError, TypeError):
        # The loader has no usable length (or there is no loader at all).
        iters_per_epoch = 0
    updates_per_epoch = iters_per_epoch // grad_accu_steps

    params_count = model.params_count()
    logger.info("****************************** Running training ******************************")
    logger.info(f"  Number GPUs:               {world_size}")
    logger.info(f"  Training video samples(total):   {num_video_samples:,}")
    for k, v in params_count.items():
        logger.info(f"  Number {k} parameters:   {v:,}")
    logger.info(f"  Number trainable params:   {sum(p.numel() for p in get_trainable_params(model, args)):,}")
    logger.info("------------------------------------------------------------------------------")
    logger.info(f"  Iters per epoch:           {iters_per_epoch:,}")
    logger.info(f"  Updates per epoch:         {updates_per_epoch:,}")

    logger.info(f"  Batch size per device:     {video_micro_batch_size}")
    logger.info(f"  Batch size all device:     {global_batch_size:}")

    logger.info(f"  Gradient Accu steps:       {args.gradient_accumulation_steps}")
    logger.info(f"  Training epochs:           {ss.epoch}/{args.epochs}")
    logger.info(f"  Training total steps:      {ss.update_steps:,}/{args.max_training_steps:,}")
    logger.info("------------------------------------------------------------------------------")

    logger.info(f"  Path type:                 {args.flow_path_type}")
    logger.info(f"  Predict type:              {args.flow_predict_type}")
    logger.info(f"  Loss weight:               {args.flow_loss_weight}")
    logger.info(f"  Flow reverse:              {args.flow_reverse}")
    logger.info(f"  Flow shift:                {args.flow_shift}")
    logger.info(f"  Train eps:                 {args.flow_train_eps}")
    logger.info(f"  Sample eps:                {args.flow_sample_eps}")
    logger.info(f"  Timestep type:             {args.flow_snr_type}")

    logger.info("------------------------------------------------------------------------------")
    logger.info(f"  Main model precision:      {args.precision}")
    logger.info("------------------------------------------------------------------------------")
    logger.info(f"  VAE:                       {args.vae} ({args.vae_precision}) - {vae_path}")
    logger.info(f"  Text encoder:              {text_encoder}")
    if text_encoder_2 is not None:
        logger.info(f"  Text encoder 2:            {text_encoder_2}")
    logger.info(f"  Experiment directory:      {ckpt_dir}")
    logger.info("*******************************************************************************")

    # ============================= Start training =============================
    model_engine.train()

    if args.init_save:
        save_checkpoint(args, rank, logger, model_engine, ema, ss, ckpt_dir)

    # Training loop
    start_epoch = ss.epoch
    finished = False

    ss.current_run_update_steps = 0
    # NOTE(review): ss.epoch starts at 1 and range() excludes args.epochs, so this
    # loop runs at most args.epochs - 1 epochs — confirm that this is intended.
    for epoch in range(start_epoch, args.epochs):

        if video_dataset is not None:
            logger.info(f"Start video random shuffle(seed={args.global_seed + epoch})")
            video_sampler.set_epoch(epoch)  # epoch start from 1
            logger.info("End of video random shuffle")

        logger.info(f"Beginning epoch {epoch}...")
        with profiler_context(
            args.profile, exp_dir, worker_name=f"Rank_{rank}"
        ) as prof:
            # Define cycle states, which accumulate the training information between log_steps.
            cs = CycleStates()
            start_time = time.time()

            for batch in loader:
                # broadcast a zero size tensor to indicate starting of step
                start_flag_tensor = torch.cuda.FloatTensor([])
                if torch.distributed.is_initialized():
                    torch.distributed.broadcast(start_flag_tensor, 0, async_op=True)

                # main diff
                (
                    latents,
                    model_kwargs,
                    n_tokens,
                    cond_latents,
                ) = prepare_model_inputs(
                    args,
                    batch,
                    device,
                    model,
                    vae,
                    text_encoder,
                    text_encoder_2,
                    rope_theta_rescale_factor=args.rope_theta_rescale_factor,
                    rope_interpolation_factor=args.rope_interpolation_factor,
                )
                cur_batch_size = latents.shape[0]

                # Key under which the per-size loss statistics are accumulated.
                cur_anchor_size = max(args.video_size)

                # A forward-backward step
                with torch.autocast(
                    device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
                ):
                    _, loss_dict = denoiser.training_losses(
                        model_engine,
                        latents,
                        model_kwargs,
                        n_tokens=n_tokens,
                        i2v_mode=args.i2v_mode,
                        cond_latents=cond_latents,
                        args=args,
                    )
                loss = loss_dict["loss"].mean()
                model_engine.backward(loss)

                # Update model parameters at the step of gradient accumulation.
                model_engine.step(lr_kwargs={"last_batch_iteration": ss.update_steps})

                # Update accumulated states
                ss.add(
                    train_steps=1,
                    epoch_train_steps=1,
                    consumed_samples_per_dp=cur_batch_size,
                )
                ss.add(consumed_video_samples_per_dp=cur_batch_size)

                # We enable `is_update_step` if the current step is the gradient accumulation boundary.
                is_update_step = ss.train_steps % grad_accu_steps == 0
                if is_update_step:
                    ss.add(
                        update_steps=1, epoch_update_steps=1, current_run_update_steps=1
                    )

                if ss.update_steps >= args.max_training_steps:
                    # Enter stopping routine if max steps reached after this step.
                    finished = True

                # Log training information:
                # NOTE(review): running_grad_norm is only ever increased by 0 here,
                # so the "Grad Norm" value logged below is always 0.
                cs.add(
                    log_steps=1,
                    running_loss=loss.item(),
                    running_samples=cur_batch_size,
                    running_tokens=cur_batch_size * n_tokens,
                    running_grad_norm=0,
                )
                cs.add(running_video_samples=cur_batch_size)

                cs.running_loss_dict[cur_anchor_size] += loss.item()
                cs.log_steps_dict[cur_anchor_size] += 1
                if is_update_step and ss.update_steps % args.log_every == 0:
                    # Reduce loss history over all processes:
                    avg_loss = (
                        all_gather_sum(cs.running_loss / cs.log_steps, device)
                        / world_size
                    )
                    avg_grad_norm = (
                        all_gather_sum(cs.running_grad_norm / cs.log_steps, device)
                        / world_size
                    )
                    cum_samples = all_gather_sum(cs.running_samples, device)
                    cum_video_samples = all_gather_sum(cs.running_video_samples, device)
                    cum_tokens = all_gather_sum(cs.running_tokens, device)
                    # Measure training speed:
                    torch.cuda.synchronize()
                    end_time = time.time()
                    steps_per_sec = (
                        cs.log_steps / (end_time - start_time) / grad_accu_steps
                    )
                    samples_per_sec = cum_samples / (end_time - start_time)
                    sec_per_step = (end_time - start_time) / cs.log_steps
                    ss.add(
                        consumed_samples_total=cum_samples,
                        consumed_video_samples_total=cum_video_samples,
                        consumed_tokens_total=cum_tokens,
                        consumed_computations_attn=6
                        * params_count["attn+mlp"]
                        * cum_tokens
                        / C_SCALE,
                        consumed_computations_total=6
                        * params_count["total"]
                        * cum_tokens
                        / C_SCALE,
                    )
                    log_events = [
                        f"Train Loss: {avg_loss:.4f}",
                        f"Grad Norm: {avg_grad_norm:.4f}",
                        f"Lr: {opt.param_groups[0]['lr']:.6g}",
                        f"Sec/Step: {sec_per_step:.2f}, "
                        f"Steps/Sec: {steps_per_sec:.2f}",
                        f"Samples/Sec: {int(samples_per_sec):d}",
                        f"Consumed Samples: {ss.consumed_samples_total:,}",
                        f"Consumed Video Samples: {ss.consumed_video_samples_total:,}",
                        f"Consumed Tokens: {ss.consumed_tokens_total:,}",
                    ]
                    # Each event is a (tag, value, x-axis position) triple.
                    summary_events = [
                        ("Train/Steps/train_loss", avg_loss, ss.update_steps),
                        ("Train/Steps/grad_norm", avg_grad_norm, ss.update_steps),
                        ("Train/Steps/steps_per_sec", steps_per_sec, ss.update_steps),
                        (
                            "Train/Steps/samples_per_sec",
                            int(samples_per_sec),
                            ss.update_steps,
                        ),
                        ("Train/Tokens/train_loss", avg_loss, ss.consumed_tokens_total),
                        (
                            "Train/ComputationsAttn/train_loss",
                            avg_loss,
                            ss.consumed_computations_attn,
                        ),
                        (
                            "Train/ComputationsTotal/train_loss",
                            avg_loss,
                            ss.consumed_computations_total,
                        ),
                    ]
                    # Log the training information to the logger.
                    logger.info(
                        f"(step={ss.update_steps:07d}) " + ", ".join(log_events)
                    )
                    if model_engine.monitor.enabled and rank == 0:
                        model_engine.monitor.write_events(summary_events)

                    # Reset monitoring variables:
                    cs.reset()
                    start_time = time.time()

                # Save checkpoint:
                if (is_update_step and ss.update_steps % args.ckpt_every == 0) or (
                    finished and args.final_save
                ):
                    if args.use_lora:
                        # With LoRA only the adapter weights are exported, by rank 0 alone.
                        if rank == 0:
                            output_dir = os.path.join(
                                ckpt_dir, f"global_step{ss.update_steps}"
                            )
                            os.makedirs(output_dir, exist_ok=True)

                            lora_kohya_state_dict = get_module_kohya_state_dict(
                                model, "Hunyuan_video_I2V_lora", dtype=torch.bfloat16
                            )
                            save_file(
                                lora_kohya_state_dict,
                                f"{output_dir}/pytorch_lora_kohaya_weights.safetensors",
                            )
                    else:
                        save_checkpoint(
                            args, rank, logger, model_engine, ema, ss, ckpt_dir
                        )

                if prof:
                    prof.step()

                if finished:
                    logger.info(
                        f"Finished and breaking loop at step={ss.update_steps}."
                    )
                    break

            if finished:
                logger.info(f"Finished and breaking loop at epoch={epoch}.")
                break

            # Reset epoch states
            ss.epoch += 1
            ss.epoch_train_steps = 0
            ss.epoch_update_steps = 0

    logger.info("Training Finished!")


if __name__ == "__main__":
    # Script entry point: parse the command line in training mode, then train.
    train_args = parse_args(mode="train")
    main(train_args)
