#!/usr/bin/env python3

import argparse
import os
import re
import time
import yaml

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from audiotools import AudioSignal
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from .utils import build_codec_model, build_dataloader
from .modules.discriminator import Discriminator
from .modules.loss import GANLoss, MelSpectrogramLoss, MultibandMelSpectrogramLoss
from .utils.lr_scheduler import WarmupLR
from torch.cuda.amp import GradScaler

# Module-level switch read by main(). When True, the MultibandMelSpectrogramLoss
# term (weighted by train.semantic_spec_loss_weight) is added to the generator's
# training objective for batches whose model output contains "distill_loss"
# and a truthy "bypassed_quantize". When False that term stays a zero tensor.
ADD_SEMANTIC_LOSS = False

def find_latest_ckpt(log_dir):
    """Return the path of the highest-numbered ``<step>.pth`` checkpoint in *log_dir*.

    A file is a candidate when its name ends with ``.pth`` and starts with
    ``<digits>.pth``. Candidates are ordered by that leading integer, and the
    ordered list of names is printed. When there is no candidate, the sentinel
    path ``"dummy-non-existing-file"`` is returned so that callers can test
    the result with ``os.path.exists``.
    """
    step_pattern = re.compile(r"^\d+\.pth")
    candidates = []
    for name in os.listdir(log_dir):
        if name.endswith(".pth") and step_pattern.search(name):
            candidates.append(name)
    candidates.sort(key=lambda name: int(name.split(".")[0]))
    print(candidates)
    if not candidates:
        return "dummy-non-existing-file"
    return os.path.join(log_dir, candidates[-1])


def get_config():
    """Parse the command line and load the YAML training configuration.

    Command-line options:
        --config / -c: path to the YAML training configuration (required).
        --log_dir: output directory, created if missing (default ``tmp/log/``).

    Returns:
        ``(config, log_dir)``: the parsed YAML document and the log directory path.
    """
    parser = argparse.ArgumentParser()
    # required=True makes argparse report a missing --config with a usage
    # error instead of a TypeError from open(None).
    parser.add_argument(
        "--config", "-c", type=str, required=True, help="Training configuration file"
    )
    parser.add_argument("--log_dir", type=str, default="tmp/log/", help="Log directory")
    args = parser.parse_args()

    # NOTE: config files are assumed to be trusted. Prefer yaml.safe_load if
    # they may ever come from an untrusted source.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    os.makedirs(args.log_dir, exist_ok=True)
    return config, args.log_dir


def _unwrap_model(model):
    """Return the module wrapped by DistributedDataParallel, or ``model`` itself.

    Checkpoints store the unwrapped module's ``state_dict``. That matches how
    they are loaded, which happens before DDP wrapping.
    """
    return model.module if isinstance(model, DDP) else model


def _build_lr_scheduler(optimizer, train_config):
    """Create the LR scheduler selected by ``train_config["lr_type"]``.

    ``lr_type`` defaults to ``"ExponentialLR"``, which reads ``gamma``.
    ``"WarmupLR"`` reads ``warmup_step``, ``down_step``, ``max_lr`` and ``min_lr``.

    Raises:
        ValueError: for any other ``lr_type``.
    """
    lr_type = train_config.get("lr_type", "ExponentialLR")
    if lr_type == "ExponentialLR":
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=train_config["gamma"]
        )
    if lr_type == "WarmupLR":
        return WarmupLR(
            optimizer,
            warmup_step=train_config["warmup_step"],
            down_step=train_config["down_step"],
            max_lr=train_config["max_lr"],
            min_lr=train_config["min_lr"],
        )
    raise ValueError(f"Unsupported lr_type: {lr_type!r}")


def _extract_aux_losses(out_dict):
    """Return the auxiliary terms reported by the codec forward pass.

    Returns:
        ``(commitment_loss, codebook_loss, diffusion_loss, kl_loss,
        distill_loss, token_ratio)``.

        - Optional losses that are absent default to a scalar zero tensor.
        - ``token_ratio`` is ``None`` when absent.
        - The KL term is read from ``"kl_divergence"`` and falls back to
          ``"kl_loss"``, so training and validation use the same lookup.
    """
    zero = torch.tensor(0.0)
    if "kl_divergence" in out_dict:
        kl_loss = out_dict["kl_divergence"]
    else:
        kl_loss = out_dict.get("kl_loss", zero)
    return (
        out_dict["vq/commitment_loss"],
        out_dict["vq/codebook_loss"],
        out_dict.get("vq/diffusion_loss", zero),
        kl_loss,
        out_dict.get("distill_loss", zero),
        out_dict.get("token_ratio"),
    )


def main():
    """Run distributed adversarial training of the codec model ("soundstream").

    Setup:
        - Reads the YAML config and log directory from the command line
          (see ``get_config``).
        - Joins an NCCL process group configured from the ``WORLD_SIZE``,
          ``LOCAL_RANK``, ``RANK``, ``MASTER_ADDR`` and ``MASTER_PORT``
          environment variables.
        - Optionally resumes from ``model.resume_ckpt`` or from the newest
          ``<step>.pth`` in the log directory.

    Training loop:
        - Each step first updates the discriminators, then the generator.
        - Rank 0 writes TensorBoard scalars and saves ``<global_step>.pth``
          checkpoints every ``save_interval`` steps.
        - Training stops after ``train.n_epochs`` epochs, or once
          ``train.max_steps`` is exceeded.
    """
    # autocast wraps every forward pass below. With enabled=False it is a
    # no-op, so it must be available even when AMP is off.
    from torch.cuda.amp import autocast

    torch.manual_seed(777)

    # Read the config first: get_config() also creates log_dir, which the
    # SummaryWriter needs.
    config, log_dir = get_config()

    # logger setting
    writer = SummaryWriter(log_dir)

    # basic training setting
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_rank = int(os.environ.get("RANK", 0))

    # Ensure that each process uses its designated GPU as the default device. This avoids
    # accidental allocations on GPU0 (rank 0) and helps keep the memory usage balanced
    # across all ranks.
    torch.cuda.set_device(local_rank)

    master_uri = "tcp://%s:%s" % (
        os.environ.get("MASTER_ADDR", "localhost"),
        os.environ.get("MASTER_PORT", 1239),
    )
    dist.init_process_group(
        backend="nccl",
        init_method=master_uri,
        world_size=world_size,
        rank=world_rank,
    )
    torch.backends.cudnn.benchmark = True

    # Only the global rank-0 process writes the config copy, so that several
    # local processes never write the same file concurrently.
    if world_rank == 0:
        with open(os.path.join(log_dir, "config.yaml"), "w") as yaml_file:
            yaml.dump(config, yaml_file)

    # decoder (emb --> audio)
    codec_model_config = config["model"]
    soundstream = build_codec_model(codec_model_config=codec_model_config)
    enable_torch_compile = config["train"].get("enable_torch_compile", False)
    if enable_torch_compile:
        soundstream = torch.compile(soundstream)
    soundstream.to(local_rank)

    # discriminator
    discriminator_config = config.get("discriminator", {})
    discriminators = Discriminator(**discriminator_config)
    # NOTE: torch.compile on discriminators failed
    discriminators.to(local_rank)

    # Checkpoint selection.
    # - An explicit model.resume_ckpt takes precedence over the newest
    #   <step>.pth in log_dir.
    # - Entries of an explicit checkpoint are optional (it may hold only weights).
    # - An auto-discovered checkpoint is expected to contain every entry this
    #   script saves.
    # The file is loaded once and reused for the model weights (before DDP
    # wrapping) and for the optimizer/scheduler state.
    resume_ckpt = codec_model_config.get("resume_ckpt")
    latest_ckpt = find_latest_ckpt(log_dir)
    ckpt_info = None
    ckpt_is_partial = False
    if resume_ckpt is not None:
        print(f"loading model from {resume_ckpt}")
        ckpt_info = torch.load(resume_ckpt, map_location="cpu")
        ckpt_is_partial = True
    elif os.path.exists(latest_ckpt):
        print(f"loading model from {latest_ckpt}")
        ckpt_info = torch.load(latest_ckpt, map_location="cpu")

    def _should_restore(key):
        # Partial checkpoints skip missing entries; complete ones raise KeyError.
        return key in ckpt_info or not ckpt_is_partial

    if ckpt_info is not None:
        soundstream.load_state_dict(ckpt_info["soundstream"])
        if _should_restore("discriminators"):
            discriminators.load_state_dict(ckpt_info["discriminators"])

    if config["train"]["distributed"]:
        soundstream = DDP(
            soundstream,
            device_ids=[local_rank],
            find_unused_parameters=True,
        )
        discriminators = DDP(
            discriminators,
            device_ids=[local_rank],
            find_unused_parameters=True,
        )

    print("Build dataloader")
    train_loader, valid_loader = build_dataloader(config)

    print("Build optimizers and lr-schedulers")
    generator_train_config = config["train"]["generator"]
    discriminator_train_config = config["train"]["discriminator"]
    optimizer_g = torch.optim.AdamW(
        soundstream.parameters(),
        lr=generator_train_config["lr"],
        betas=generator_train_config["betas"],
    )
    lr_type_g = generator_train_config.get("lr_type", "ExponentialLR")
    print(f"lr_type_g: {lr_type_g}")
    lr_scheduler_g = _build_lr_scheduler(optimizer_g, generator_train_config)

    optimizer_d = torch.optim.AdamW(
        discriminators.parameters(),
        lr=discriminator_train_config["lr"],
        betas=discriminator_train_config["betas"],
    )
    lr_type_d = discriminator_train_config.get("lr_type", "ExponentialLR")
    print(f"lr_type_d: {lr_type_d}")
    lr_scheduler_d = _build_lr_scheduler(optimizer_d, discriminator_train_config)

    spec_loss_config = config["train"].get(
        "spec_loss_config", {"mag_weight": 0.0, "log_weight": 2.0}
    )
    spec_loss_calculator = MelSpectrogramLoss(**spec_loss_config)

    # Semantic band-specific Spectrogram loss (optional)
    semantic_spec_loss_calculator = MultibandMelSpectrogramLoss(
        bands=[(0.0, 0.1)],
        band_weights=[1.0],
        loss_fn=nn.MSELoss(),
        pow=2,
        mag_weight=1,
        log_weight=1,
        n_mels=[80, 160, 320],
        window_lengths=[512, 1024, 2048],
    )

    # AMP setting
    use_amp = config["train"].get("use_amp", False)
    print(f"use_amp: {use_amp}")
    # With enabled=False, GradScaler.scale/unscale_/update are no-ops and
    # step() simply calls optimizer.step(), so one update path serves both modes.
    scaler_g = GradScaler(enabled=use_amp)
    scaler_d = GradScaler(enabled=use_amp)

    # set initial status
    global_step = 1
    soundstream.train()
    discriminators.train()

    if ckpt_info is not None:
        for key, stateful in (
            ("lr_scheduler_g", lr_scheduler_g),
            ("optimizer_g", optimizer_g),
            ("lr_scheduler_d", lr_scheduler_d),
            ("optimizer_d", optimizer_d),
        ):
            if _should_restore(key):
                stateful.load_state_dict(ckpt_info[key])
        if _should_restore("global_step"):
            global_step = ckpt_info["global_step"]
        # Scaler state is only saved, and only restored, when AMP is on.
        if use_amp and "scaler_g_state_dict" in ckpt_info:
            scaler_g.load_state_dict(ckpt_info["scaler_g_state_dict"])
        if use_amp and "scaler_d_state_dict" in ckpt_info:
            scaler_d.load_state_dict(ckpt_info["scaler_d_state_dict"])
        # Release the CPU copy of the checkpoint before training starts.
        ckpt_info = None

    prev_time = time.time()
    max_steps = config["train"].get("max_steps", torch.inf)
    save_interval = config["train"].get("save_interval", 50000)
    spec_loss_weight = config["train"].get("spec_loss_weight", 15.0)
    adv_g_loss_weight = config["train"].get("adv_g_loss_weight", 1.0)
    feat_loss_weight = config["train"].get("feat_loss_weight", 2.0)
    codebook_loss_weight = config["train"].get("codebook_loss_weight", 1.0)
    commitment_loss_weight = config["train"].get("commitment_loss_weight", 0.25)
    kl_loss_weight = config["train"].get("kl_loss_weight", 0.0)  # FIXME: not implemented yet.
    diffusion_loss_weight = config["train"].get(
        "diffusion_loss_weight", 0.0
    )  # FIXME: not implemented yet.
    distill_loss_weight = config["train"].get("distill_loss_weight", 1.0)
    semantic_spec_loss_weight = config["train"].get("semantic_spec_loss_weight", 15.0)
    print(f"max_steps: {max_steps}")
    print(f"save_interval: {save_interval}")
    print(f"spec_loss_weight: {spec_loss_weight}")
    print(f"adv_g_loss_weight: {adv_g_loss_weight}")
    print(f"feat_loss_weight: {feat_loss_weight}")
    print(f"codebook_loss_weight: {codebook_loss_weight}")
    print(f"commitment_loss_weight: {commitment_loss_weight}")
    print(f"kl_loss_weight: {kl_loss_weight}")
    print(f"diffusion_loss_weight: {diffusion_loss_weight}")
    print(f"distill_loss_weight: {distill_loss_weight}")
    print(f"semantic_spec_loss_weight: {semantic_spec_loss_weight}")
    print(f"data config: {config['data']}")

    # Loop-invariant config values.
    loss_print_freq = config["train"]["loss_print_freq"]
    eval_interval = config["train"]["eval_interval"]
    sampling_rate = config["data"]["sampling_rate"]
    use_hinge_loss = config["train"]["use_hinge_loss"]

    # NOTE: This is not a perfect solution, but set the random seed based on
    #       global_step so that different data will be consumed when resuming training.
    torch.manual_seed(global_step)
    for epoch in range(1, config["train"]["n_epochs"] + 1):
        gan_loss = GANLoss(discriminators)
        train_disc_loss = 0.0
        train_adv_g_loss = 0.0
        train_feat_loss = 0.0
        train_spec_loss = 0.0
        train_total_loss = 0.0
        train_codebook_loss = 0.0
        train_commitment_loss = 0.0
        train_diffusion_loss = 0.0
        train_kl_loss = 0.0
        train_distill_loss = 0.0
        train_semantic_spec_loss = 0.0
        train_token_ratio = 0.0
        k_iter = 0
        if config["train"]["distributed"]:
            train_loader.sampler.set_epoch(epoch)
        for data in train_loader:
            features = None
            if isinstance(data, dict):
                x = data["audio"]  # [b,1,t]
                features = data["features"]  # [b,t,c]
            else:
                x = data

            # NOTE(review): if only some ranks skip a batch, DDP's gradient
            # all-reduce is desynchronized and the job can hang. Consider
            # agreeing on the skip across ranks.
            if torch.isnan(x).any():
                continue
            x = x.to(local_rank)
            if features is not None:
                features = features.to(local_rank)
            k_iter += 1  # local iter
            global_step += 1  # record the global step
            x_wav = x
            if global_step > max_steps:
                break

            with autocast(enabled=use_amp):
                if features is not None:
                    out_dict = soundstream({"audio": x, "x": features})
                else:
                    out_dict = soundstream(x)

            generator_out = out_dict["audio"]
            (
                commitment_loss,
                codebook_loss,
                diffusion_loss,
                kl_loss,
                distill_loss,
                token_ratio,
            ) = _extract_aux_losses(out_dict)

            semantic_spec_loss = torch.tensor(0.0, device=generator_out.device)
            if (
                ADD_SEMANTIC_LOSS
                and "distill_loss" in out_dict
                and out_dict.get("bypassed_quantize", False)
            ):
                semantic_spec_loss = semantic_spec_loss_calculator(
                    AudioSignal(x_wav, sampling_rate),
                    AudioSignal(generator_out, sampling_rate),
                )

            # ---- discriminator update ----
            with autocast(enabled=use_amp):
                if use_hinge_loss:
                    disc_loss = gan_loss.discriminator_hinge_loss(generator_out, x_wav)
                else:
                    disc_loss = gan_loss.discriminator_loss(generator_out, x_wav)
            train_disc_loss += disc_loss.item()
            optimizer_d.zero_grad()
            scaler_d.scale(disc_loss).backward()

            scaler_d.unscale_(optimizer_d)
            torch.nn.utils.clip_grad_norm_(discriminators.parameters(), 1.0)
            scaler_d.step(optimizer_d)
            scaler_d.update()
            lr_scheduler_d.step()

            # FIXME: accumulate stats rather than print out instantaneous stats
            if world_rank == 0 and (global_step % loss_print_freq) == 0:
                print(f"train/disc_loss: {disc_loss}")
                writer.add_scalar("train/disc_loss", disc_loss, global_step)
            del disc_loss

            # ---- generator objective ----
            with autocast(enabled=use_amp):
                if use_hinge_loss:
                    adv_g_loss, feat_loss = gan_loss.generator_hinge_loss(generator_out, x_wav)
                else:
                    adv_g_loss, feat_loss = gan_loss.generator_loss(generator_out, x_wav)

                spec_loss = spec_loss_calculator(
                    AudioSignal(x_wav, sampling_rate),
                    AudioSignal(generator_out, sampling_rate),
                )

            total_loss = (
                commitment_loss_weight * commitment_loss
                + spec_loss_weight * spec_loss
                + adv_g_loss_weight * adv_g_loss
                + feat_loss_weight * feat_loss
                + codebook_loss_weight * codebook_loss
                + diffusion_loss_weight * diffusion_loss
                + kl_loss_weight * kl_loss
                + distill_loss_weight * distill_loss
                + semantic_spec_loss_weight * semantic_spec_loss
            )

            # FIXME: accumulate stats rather than print out instantaneous stats
            if world_rank == 0 and (global_step % loss_print_freq) == 0:
                cur_time = time.time()
                proc_time = cur_time - prev_time
                training_sample_sec_per_sec = (
                    config["data"]["batch_size"]
                    * config["data"]["seg_len"]
                    * loss_print_freq
                    / proc_time
                )
                prev_time = cur_time
                log_msg = f"train/total_loss: {total_loss}, train/spec_loss: {spec_loss}, train/commitment_loss: {commitment_loss}, train/adv_g_loss: {adv_g_loss}, train/feat_loss: {feat_loss}, train/codebook_loss: {codebook_loss}, train/diffusion_loss: {diffusion_loss}, train/kl_loss: {kl_loss}, train/distill_loss: {distill_loss}, train/semantic_spec_loss: {semantic_spec_loss}, train/global_step: {global_step}, training_sample_sec_per_sec: {training_sample_sec_per_sec}"
                if token_ratio is not None:
                    log_msg += f", train/token_ratio: {token_ratio}"
                print(log_msg)

                for tag, value in (
                    ("train/total_loss", total_loss),
                    ("train/spec_loss", spec_loss),
                    ("train/commitment_loss", commitment_loss),
                    ("train/adv_g_loss", adv_g_loss),
                    ("train/feat_loss", feat_loss),
                    ("train/codebook_loss", codebook_loss),
                    ("train/diffusion_loss", diffusion_loss),
                    ("train/kl_loss", kl_loss),
                    ("train/distill_loss", distill_loss),
                    ("train/semantic_spec_loss", semantic_spec_loss),
                ):
                    writer.add_scalar(tag, value, global_step)
                if token_ratio is not None:
                    writer.add_scalar("train/token_ratio", token_ratio, global_step)
                writer.add_scalar(
                    "train/training_sample_sec_per_sec", training_sample_sec_per_sec, global_step
                )
            train_adv_g_loss += adv_g_loss.item()
            train_feat_loss += feat_loss.item()
            train_total_loss += total_loss.item()
            train_spec_loss += spec_loss.item()
            train_commitment_loss += commitment_loss.item()
            train_codebook_loss += codebook_loss.item()
            train_diffusion_loss += diffusion_loss.item()
            train_kl_loss += kl_loss.item()
            train_distill_loss += distill_loss.item()
            train_semantic_spec_loss += semantic_spec_loss.item()
            if token_ratio is not None:
                train_token_ratio += token_ratio

            # ---- generator update ----
            optimizer_g.zero_grad()
            scaler_g.scale(total_loss).backward()

            scaler_g.unscale_(optimizer_g)
            torch.nn.utils.clip_grad_norm_(soundstream.parameters(), 1.0)
            scaler_g.step(optimizer_g)
            scaler_g.update()
            lr_scheduler_g.step()

            if global_step % save_interval == 0 and world_rank == 0:
                # Save unwrapped modules so the checkpoint loads with or without DDP.
                save_state = {
                    "soundstream": _unwrap_model(soundstream).state_dict(),
                    "discriminators": _unwrap_model(discriminators).state_dict(),
                    "optimizer_g": optimizer_g.state_dict(),
                    "lr_scheduler_g": lr_scheduler_g.state_dict(),
                    "optimizer_d": optimizer_d.state_dict(),
                    "lr_scheduler_d": lr_scheduler_d.state_dict(),
                    "epoch": epoch,
                    "global_step": global_step,
                }
                if use_amp:
                    save_state["scaler_g_state_dict"] = scaler_g.state_dict()
                    save_state["scaler_d_state_dict"] = scaler_d.state_dict()
                save_path = os.path.join(log_dir, f"{global_step}.pth")
                torch.save(save_state, save_path)
                print(f"model saved to {save_path}")

            # NOTE(review): global_step has already been incremented past its
            # initial value of 1 here, so `global_step == 1` never holds.
            # Confirm whether an evaluation before the first update was intended.
            if global_step % eval_interval == 0 or global_step == 1:
                print("evaluating")
                total_valid_loss = 0.0
                valid_spec_loss = 0.0
                valid_feat_loss = 0.0
                valid_adv_g_loss = 0.0
                valid_commitment_loss = 0.0
                valid_codebook_loss = 0.0
                valid_diffusion_loss = 0.0
                valid_kl_loss = 0.0
                valid_distill_loss = 0.0
                valid_semantic_spec_loss = 0.0
                valid_token_ratio = 0.0
                valid_token_ratio_count = 0
                soundstream.eval()
                for data in valid_loader:
                    features = None
                    with torch.no_grad():
                        if isinstance(data, dict):
                            x = data["audio"]
                            features = data["features"]
                        else:
                            x = data
                        x = x.to(local_rank)
                        if features is not None:
                            features = features.to(local_rank)

                        x_wav = x
                        with autocast(enabled=use_amp):
                            if features is not None:
                                out_dict = soundstream({"audio": x, "x": features})
                            else:
                                out_dict = soundstream(x)
                        generator_out = out_dict["audio"]
                        (
                            commitment_loss,
                            codebook_loss,
                            diffusion_loss,
                            kl_loss,
                            distill_loss,
                            token_ratio,
                        ) = _extract_aux_losses(out_dict)

                        # Same gating and GAN-loss variant as the training
                        # objective, so valid/total_loss is comparable to
                        # train/total_loss.
                        semantic_spec_loss = torch.tensor(0.0, device=generator_out.device)
                        if (
                            ADD_SEMANTIC_LOSS
                            and "distill_loss" in out_dict
                            and out_dict.get("bypassed_quantize", False)
                        ):
                            semantic_spec_loss = semantic_spec_loss_calculator(
                                AudioSignal(x_wav, sampling_rate),
                                AudioSignal(generator_out, sampling_rate),
                            )

                        with autocast(enabled=use_amp):
                            if use_hinge_loss:
                                adv_g_loss, feat_loss = gan_loss.generator_hinge_loss(
                                    generator_out, x_wav
                                )
                            else:
                                adv_g_loss, feat_loss = gan_loss.generator_loss(
                                    generator_out, x_wav
                                )
                            spec_loss = spec_loss_calculator(
                                AudioSignal(x_wav, sampling_rate),
                                AudioSignal(generator_out, sampling_rate),
                            )
                        total_loss = (
                            commitment_loss_weight * commitment_loss
                            + spec_loss_weight * spec_loss
                            + adv_g_loss_weight * adv_g_loss
                            + feat_loss_weight * feat_loss
                            + codebook_loss_weight * codebook_loss
                            + diffusion_loss_weight * diffusion_loss
                            + kl_loss_weight * kl_loss
                            + distill_loss_weight * distill_loss
                            + semantic_spec_loss_weight * semantic_spec_loss
                        )
                        total_valid_loss += total_loss.item()
                        valid_spec_loss += spec_loss.item()
                        valid_feat_loss += feat_loss.item()
                        valid_adv_g_loss += adv_g_loss.item()
                        valid_commitment_loss += commitment_loss.item()
                        valid_codebook_loss += codebook_loss.item()
                        valid_diffusion_loss += diffusion_loss.item()
                        valid_kl_loss += kl_loss.item()
                        valid_distill_loss += distill_loss.item()
                        valid_semantic_spec_loss += semantic_spec_loss.item()
                        if token_ratio is not None:
                            valid_token_ratio += token_ratio
                            valid_token_ratio_count += 1

                # Guard against ZeroDivisionError on an empty validation loader.
                n_valid = max(len(valid_loader), 1)
                mean_total = total_valid_loss / n_valid
                mean_spec = valid_spec_loss / n_valid
                mean_feat = valid_feat_loss / n_valid
                mean_adv_g = valid_adv_g_loss / n_valid
                mean_commitment = valid_commitment_loss / n_valid
                mean_codebook = valid_codebook_loss / n_valid
                mean_diffusion = valid_diffusion_loss / n_valid
                mean_kl = valid_kl_loss / n_valid
                mean_distill = valid_distill_loss / n_valid
                mean_semantic_spec = valid_semantic_spec_loss / n_valid
                avg_token_ratio = (
                    (valid_token_ratio / valid_token_ratio_count)
                    if valid_token_ratio_count > 0
                    else 0.0
                )
                message = "<epoch:{:d}, step:{:d}, valid_diffusion_loss:{:.4f}, valid_total_loss:{:.4f}, valid_adv_g_loss:{:.4f}, valid_feat_loss:{:.6f}, valid_spec_loss:{:.4f}, valid_commit_loss:{:.4f}, valid_codebook_loss:{:.4f}, valid_kl_loss:{:.4f}, valid_distill_loss:{:.4f}, valid_semantic_spec_loss:{:.4f}, valid_token_ratio:{:.4f}>".format(
                    epoch,
                    global_step,
                    mean_diffusion,
                    mean_total,
                    mean_adv_g,
                    mean_feat,
                    mean_spec,
                    mean_commitment,
                    mean_codebook,
                    mean_kl,
                    mean_distill,
                    mean_semantic_spec,
                    avg_token_ratio,
                )
                print(message)

                soundstream.train()
                if world_rank == 0:
                    print(
                        f"valid/total_loss: {mean_total}, valid/spec_loss: {mean_spec}, valid/commitment_loss: {mean_commitment}, valid/adv_g_loss: {mean_adv_g}, valid/feat_loss: {mean_feat}, valid/codebook_loss: {mean_codebook}, valid/diffusion_loss: {mean_diffusion}, valid/global_step: {global_step}, valid/kl_loss: {mean_kl}, valid/distill_loss: {mean_distill}, valid/semantic_spec_loss: {mean_semantic_spec}, valid/token_ratio: {avg_token_ratio}"
                    )
                    # All validation scalars share the global_step x-axis.
                    for tag, value in (
                        ("valid/total_loss", mean_total),
                        ("valid/spec_loss", mean_spec),
                        ("valid/commitment_loss", mean_commitment),
                        ("valid/adv_g_loss", mean_adv_g),
                        ("valid/feat_loss", mean_feat),
                        ("valid/codebook_loss", mean_codebook),
                        ("valid/diffusion_loss", mean_diffusion),
                        ("valid/kl_loss", mean_kl),
                        ("valid/distill_loss", mean_distill),
                        ("valid/semantic_spec_loss", mean_semantic_spec),
                    ):
                        writer.add_scalar(tag, value, global_step)
                    if valid_token_ratio_count > 0:
                        writer.add_scalar("valid/token_ratio", avg_token_ratio, global_step)

                prev_time = time.time()

        if global_step > max_steps:
            break

    # exit training: flush TensorBoard event files, then leave the process group
    writer.close()
    destroy_process_group()


# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    main()
