# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import os
import gc
import uuid
import time
import json
import torch
import wandb
import shutil
import argparse
import numpy as np
import bitsandbytes as bnb
import torch.distributed as dist

from tqdm import tqdm
from datetime import datetime, timedelta

from model.model import Model
from config.get_config import get_config
from data.get_dataloader import get_dataloader
from evaluation.evaluation import evaluation

from utils.synchronize import synchronize
from utils.set_random_seeds import set_random_seeds
from utils.to_cpu_recursive import to_cpu_recursive
from utils.breakpoint_on_rank import breakpoint_on_rank
from utils.report_parameter_count import report_parameter_count

from training.validation import validation
from training.should_validate import should_validate
from training.gradient_accumulation import gradient_accumulation
from training.save_checkpoint import process_iter_folders, save_checkpoint

from optimization.get_param_flags import get_param_flags
from optimization.get_learning_rate import get_learning_rate



def _resolve_auxfree_type(ffwd_name):
    """Map `config.ffwd_name` to the aux-free bias layout tag used in this file.

    Returns one of "OFF", "SINGLE_HEAD", "MULTI_HEAD_NO_HP", "MULTI_HEAD_WITH_HP".

    Raises:
        Exception: If `ffwd_name` is not one of the names listed below.
    """
    if ffwd_name in {"MLP"}:
        return "OFF"
    if ffwd_name in {"MoE", "MoEEP", "MHMoETied", "MHMoETiedHP", "LatentMoE"}:
        return "SINGLE_HEAD"
    if ffwd_name in {"MHMoE", "MHMoENaive"}:
        return "MULTI_HEAD_NO_HP"
    if ffwd_name in {"MHMoEHP", "MHMoEHPNRT", "MHMoENaiveHP"}:
        return "MULTI_HEAD_WITH_HP"
    raise Exception("Unexpected config.ffwd_name")


def _init_auxfree_runtime(config, auxfree_type, world_size):
    """Populate the aux-free / expert-load entries of `config.runtime`.

    Writes the keys `expert_load_all`, `auxfree_update_ratio`, `auxfree_enabled`,
    `expert_load_no_share`, `auxfree_bias_all` and `auxfree_shape`. When aux-free
    is enabled, `auxfree_bias_all` is a zero-filled float32 CUDA tensor whose
    shape depends on `auxfree_type`.

    Raises:
        Exception: If `auxfree_type` is not a known tag.
        AssertionError: For "MULTI_HEAD_WITH_HP" when `config.ffwd_num_head` is
            not divisible by `world_size`.
    """
    config.runtime["expert_load_all"] = None
    config.runtime["auxfree_update_ratio"] = None

    if auxfree_type == "OFF":
        config.runtime["auxfree_enabled"] = False
        config.runtime["expert_load_no_share"] = None
        config.runtime["auxfree_shape"] = None
        config.runtime["auxfree_bias_all"] = None
        return
    if auxfree_type not in {"SINGLE_HEAD", "MULTI_HEAD_NO_HP", "MULTI_HEAD_WITH_HP"}:
        raise Exception("Unexpected auxfree_type")

    config.runtime["auxfree_enabled"] = True
    config.runtime["expert_load_no_share"] = auxfree_type == "MULTI_HEAD_WITH_HP"

    if auxfree_type == "SINGLE_HEAD":
        # (num_block, num_expert)
        bias_shape = (config.num_block, config.ffwd_num_expert)
    elif auxfree_type == "MULTI_HEAD_NO_HP":
        # (num_block, num_head, num_expert)
        bias_shape = (config.num_block, config.ffwd_num_head, config.ffwd_num_expert)
    else:
        # (num_block, num_head_per_rank, num_expert): heads are split evenly across ranks
        assert config.ffwd_num_head % world_size == 0
        ffwd_num_head_per_rank = config.ffwd_num_head // world_size
        bias_shape = (config.num_block, ffwd_num_head_per_rank, config.ffwd_num_expert)

    # float32; contiguous; detached
    config.runtime["auxfree_bias_all"] = torch.zeros(
        size=bias_shape,
        dtype=torch.float32,
        device="cuda",
        requires_grad=False,
    )
    config.runtime["auxfree_shape"] = config.runtime["auxfree_bias_all"].shape


def _log_expert_load_quantiles(wandb_log, expert_load_all, num_block, num_head=None):
    """Write min / quartile / max of the expert load into `wandb_log`.

    Args:
        wandb_log: Dict updated in place.
        expert_load_all: Tensor indexed as `[block]` when `num_head` is None,
            otherwise as `[block, head]`; each indexed slice is passed to
            `torch.quantile`.
        num_block: Number of blocks to iterate over.
        num_head: Number of heads to iterate over, or None for the per-block
            layout (keys then omit the `head_*` component).
    """
    expert_load_all_cpu = expert_load_all.cpu()
    percentile_all = torch.tensor([0.00, 0.25, 0.50, 0.75, 1.00], dtype=torch.float32, device="cpu")
    labels = ("P000", "P025", "P050", "P075", "P100")
    for idx_block in range(num_block):
        if num_head is None:
            slices = [(f"expert_load/block_{idx_block}", expert_load_all_cpu[idx_block])]
        else:
            slices = [
                (f"expert_load/block_{idx_block}/head_{idx_head}", expert_load_all_cpu[idx_block, idx_head])
                for idx_head in range(num_head)
            ]
        for prefix, current in slices:
            # (5,) tensor, one entry per percentile in `percentile_all`
            quantiles = torch.quantile(current, percentile_all, interpolation="linear")
            for label, value in zip(labels, quantiles.tolist()):
                wandb_log[f"{prefix}/{label}"] = value


def launch_training(remaining_args):
    """Run one distributed training job from argument parsing to teardown.

    Stages, in order: parse arguments, initialise the NCCL process group, load
    the config, create or reuse the vault directory, start wandb on rank 0,
    build dataloaders / model / optimizers, optionally resume from the latest
    checkpoint in the vault, run the training loop (learning-rate schedule,
    gradient accumulation, optional clipping, optional validation, wandb
    logging, periodic checkpointing), optionally save iteration-time
    statistics, run evaluation, and destroy the process group.

    Args:
        remaining_args: List of command-line tokens. Recognised options:
            `--config_file` (required), `--vault_path` (resume from this vault
            when given), `--skip_training`, `--estimate_training_time`,
            `--enable_training_time_validation`.

    Environment:
        `LOCAL_RANK` must be set; it selects the CUDA device for this process.

    Raises:
        SystemExit: From argparse when the arguments are invalid.
        AssertionError: When `config.num_gpu` differs from the world size, or
            when resuming from a vault that has no checkpoint folder.
        Exception: When `config.ffwd_name` is not a recognised name.
    """
    # ----- #
    # Prologue
    # ----- #
    # region Stage: Prologue
    # Parse `remaining_args`
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, required=True)
    parser.add_argument("--vault_path", type=str, required=False, default=None)
    parser.add_argument("--skip_training", action="store_true")
    parser.add_argument("--estimate_training_time", action="store_true")
    parser.add_argument("--enable_training_time_validation", action="store_true")
    args = parser.parse_args(remaining_args)

    # Set up torch.distributed
    torch.cuda.set_device(torch.device(int(os.environ["LOCAL_RANK"])))
    dist.init_process_group(
        backend="nccl",
        device_id=torch.device(int(os.environ["LOCAL_RANK"])),
        timeout=timedelta(minutes=120),
    )
    # NOTE(review): `synchronize` is a project helper; presumably a barrier across ranks.
    synchronize()

    # Initialize config
    config = get_config(args.config_file)
    if dist.get_rank() == 0:
        print("\n\n", "Loaded config from:", args.config_file, "\n\n", config, "\n\n")
    # Additional validation
    assert config.num_gpu == dist.get_world_size(), "config.num_gpu must equal the world size"

    # Override configs if estimate training time
    if args.estimate_training_time:
        config.num_batch_override = 120
        config.lrsched_warmup_steps = 2
        config.lrsched_decay_steps = 2
        config.eval_enable_validation = False
        config.ckpt_enabled = False
        config.runtime["enforce_random_routing"] = True
        runtime_buff = []

    # Define variables
    rank, world_size = dist.get_rank(), dist.get_world_size()
    vault_path = args.vault_path
    is_new_run = args.vault_path is None
    # These variants get one checkpoint file per rank (see the resume and checkpointing stages)
    has_distributed_weights = config.ffwd_name in {"MHMoEHP", "MHMoEHPNRT", "MHMoETiedHP", "MHMoENaiveHP", "MoEEP", "LatentMoE"}
    auxfree_type = _resolve_auxfree_type(config.ffwd_name)
    # endregion
    # ----- #


    # ----- #
    # Random Seed
    # ----- #
    # region Stage: Random Seed
    if config.repro_use_random_seed:
        set_random_seeds(config)
    # endregion
    # ----- #


    # ----- #
    # PyTorch Settings
    # ----- #
    # region Stage: PyTorch Settings
    # CUDA
    torch.backends.cuda.matmul.allow_tf32 = True
    # CUDNN
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.allow_tf32 = True
    # Dtype
    torch.set_default_dtype(torch.float32)
    torch.set_float32_matmul_precision("high")
    # torch.compile
    if config.use_diagnostic_mode:
        torch._dynamo.config.suppress_errors = False
        torch._dynamo.config.recompile_limit = 8
        torch._dynamo.config.fail_on_recompile_limit_hit = True
    # endregion
    # ----- #


    # ----- #
    # Vault
    # ----- #
    # region Stage: Vault
    if vault_path is None:
        if rank == 0:
            # Define `vault_path`
            timestamp  = datetime.now().strftime("%y%m%d_%H%M%S")
            vault_uuid = str(uuid.uuid4())
            vault_name = f"{config.run_name}_{timestamp}_{vault_uuid}"
            vault_path = os.path.join(config.project_directory, vault_name)
            os.makedirs(vault_path, exist_ok=False)
            os.makedirs(os.path.join(vault_path, "checkpoints"), exist_ok=False)
            shutil.copy2(args.config_file, os.path.join(vault_path, "config.yaml"))
        synchronize()
        # Only rank 0 knows the new path; share it with every rank
        vault_path = [vault_path]
        dist.broadcast_object_list(vault_path, src=0)
        vault_path = vault_path[0]
        if rank == 0:
            print(f"\n\nCreated a new vault: {vault_path}\n\n")
    else:
        if rank == 0:
            print(f"\n\nUsing an existing vault: {vault_path}\n\n")
    # endregion
    # ----- #


    # ----- #
    # Wandb Initialization
    # ----- #
    # region Stage: Wandb Initialization
    if rank == 0:
        # Initialize wandb
        run = wandb.init(
            entity=config.project_entity,
            project=config.project_name,
            dir=vault_path,
            name=config.run_name,
            id=None,
        )
    else:
        run = None
    synchronize()
    # endregion
    # ----- #


    # ----- #
    # Data
    # ----- #
    dataloader_train = get_dataloader(config, mode="training")
    dataloader_val   = get_dataloader(config, mode="validation")
    # ----- #


    # ----- #
    # Model
    # ----- #
    # region Stage: Model
    _init_auxfree_runtime(config, auxfree_type, world_size)

    if rank == 0:
        print("\n\n")
    # Ranks build their model one at a time, with a synchronize() after each turn
    for idx_rank in tqdm(range(world_size), desc="Model Init", disable=rank != 0):
        if rank == idx_rank:
            model = Model(config).to("cuda")
        synchronize()
    if rank == 0:
        print("\n\n")

    # Broadcast parameters and buffers from rank 0 (except _no_share ones)
    for n, p in model.named_parameters():
        if "_no_share" not in n:
            dist.broadcast(p.data, src=0)
    for n, b in model.named_buffers():
        if "_no_share" not in n:
            dist.broadcast(b.data, src=0)

    # Report parameter count and print the model
    if rank == 0:
        print("\n\n")
        report_parameter_count(config, model, verbose=config.use_diagnostic_mode)
        print("\n\n")
        print(model)
        print("\n\n")
    # endregion
    # ----- #


    # ----- #
    # Optimization
    # ----- #
    # region Stage: Optimization
    # Trainable parameters are split by two flags from `get_param_flags`:
    # "decay" (apply weight decay) and "8bit" (use the bitsandbytes optimizer).
    params_decay = list()
    params_no_decay = list()
    params_decay_8bit = list()
    params_no_decay_8bit = list()
    for name, p in model.named_parameters():
        if p.requires_grad:
            param_flags = get_param_flags(model, p)
            if param_flags["decay"]:
                if param_flags["8bit"]:
                    params_decay_8bit.append(p)
                    if config.use_diagnostic_mode and rank == 0:
                        print(f"{name} is in `params_decay_8bit`")
                else:
                    params_decay.append(p)
                    if config.use_diagnostic_mode and rank == 0:
                        print(f"{name} is in `params_decay`")
            else:
                if param_flags["8bit"]:
                    params_no_decay_8bit.append(p)
                    if config.use_diagnostic_mode and rank == 0:
                        print(f"{name} is in `params_no_decay_8bit`")
                else:
                    params_no_decay.append(p)
                    if config.use_diagnostic_mode and rank == 0:
                        print(f"{name} is in `params_no_decay`")
    # Define `optimizer_1` and `optimizer_2`
    optimizer_1 = torch.optim.AdamW(
        params=[
            {"params": params_decay,    "weight_decay": config.adamw_weight_decay},
            {"params": params_no_decay, "weight_decay": 0.0},
        ],
        betas=(config.adamw_beta_1, config.adamw_beta_2),
        eps=config.adamw_eps,
        fused=False,
    )
    optimizer_2 = bnb.optim.AdamW(
        params=[
            {"params": params_decay_8bit,    "weight_decay": config.adamw_weight_decay},
            {"params": params_no_decay_8bit, "weight_decay": 0.0},
        ],
        betas=(config.adamw_beta_1, config.adamw_beta_2),
        eps=config.adamw_eps,
        optim_bits=8,
        min_8bit_size=4096,
    )
    # endregion
    # ----- #


    # ----- #
    # Resume from a previous run
    # ----- #
    # region Stage: Resume from a previous run
    if not is_new_run:
        if rank == 0:
            print("\n\n", "Resuming from a previous run", "\n\n")
        # Synchronize before proceeding
        synchronize()
        # Define `iter_folder_all` and sort it
        iter_folder_all = os.listdir(os.path.join(vault_path, "checkpoints"))
        iter_folder_all = sorted(iter_folder_all, key=int)  # Smallest idx_iter first
        assert len(iter_folder_all) > 0
        # Ranks load their checkpoint one at a time
        for idx_rank in tqdm(range(world_size), disable=rank != 0):
            if rank == idx_rank:
                print(f"\n\nProcessing rank {rank}\n\n")
                # Define `checkpoint_path`
                checkpoint_path = os.path.join(vault_path, "checkpoints", iter_folder_all[-1])  # Get the last one
                if has_distributed_weights:
                    checkpoint_path = os.path.join(checkpoint_path, f"rank_{idx_rank}.pt")
                else:
                    checkpoint_path = os.path.join(checkpoint_path, "rank_0.pt")
                # Load `checkpoint_dict`, on each rank
                checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
                idx_iter_previous = checkpoint_dict["idx_iter"]
                world_size_previous = checkpoint_dict["world_size"]
                if rank == 0:
                    print("\n\n")
                    print(f"world_size_previous is {world_size_previous}")
                    print(f"world_size          is {world_size}")
                    print(f"idx_iter_previous   is {idx_iter_previous}")
                    print(f"wandb run.step      is {run.step}")
                    print("\n\n")
                config.runtime["auxfree_bias_all"] = checkpoint_dict["auxfree_bias_all"]
                if config.runtime["auxfree_bias_all"] is not None:
                    config.runtime["auxfree_bias_all"] = config.runtime["auxfree_bias_all"].cuda()
                load_state_dict_result = model.load_state_dict(checkpoint_dict["state_dict_model"], strict=False)
                print("Missing keys:", load_state_dict_result.missing_keys)  # Possible excessive printing
                print("Unexpected keys:", load_state_dict_result.unexpected_keys)
                optimizer_1.load_state_dict(checkpoint_dict["state_dict_optimizer_1"])
                optimizer_2.load_state_dict(checkpoint_dict["state_dict_optimizer_2"])
                # Release `checkpoint_dict`
                del checkpoint_dict
                gc.collect()
            synchronize()
    # endregion
    # ----- #


    # Iteration-time statistics are bound up front so the "Save Training Time
    # Estimation Results" stage below never reads an unbound name when no
    # iteration was sampled (e.g. `--skip_training` or fewer than 21 iterations).
    mean_time, std_time, n = float("nan"), float("nan"), 0


    # ----- #
    # Training Loop
    # ----- #
    # region Stage: Training loop
    if not args.skip_training:
        for idx_iter, (inputs, targets) in enumerate(tqdm(
            iterable=dataloader_train,
            desc="Training model",
            total=len(dataloader_train),
            disable=rank != 0,  # Enable tqdm only for rank 0
        )):
            # When resuming, skip every batch up to and including the checkpointed iteration
            if not is_new_run:
                if idx_iter <= idx_iter_previous:
                    continue


            # Reset wandb_log
            wandb_log = dict()
            wandb_log["step"] = idx_iter


            # Start the timer
            synchronize()
            t1 = time.perf_counter()


            # Set the model to training mode
            model.train()


            # ----- #
            # Set learning rate
            # ----- #
            config.runtime["auxfree_update_ratio"] = 50.0 * config.lrsched_max_lr
            wandb_log["auxfree_update_ratio"] = config.runtime["auxfree_update_ratio"]

            # Get learning rate
            lr = get_learning_rate(
                idx_iter=idx_iter,
                max_lr=config.lrsched_max_lr,
                min_lr=config.lrsched_min_lr,
                warmup_steps=config.lrsched_warmup_steps,
                decay_steps=config.lrsched_decay_steps,
                num_iter=len(dataloader_train),
            )
            # Apply learning rate
            for param_group in optimizer_1.param_groups:
                param_group["lr"] = lr
            for param_group in optimizer_2.param_groups:
                param_group["lr"] = lr
            # Update wandb_log
            wandb_log["lr"] = lr
            # ----- #


            # ----- #
            # Take one gradient step
            # ----- #
            # Reset the optimizers
            optimizer_1.zero_grad(set_to_none=True)
            optimizer_2.zero_grad(set_to_none=True)
            # Perform gradient accumulation
            loss_lm = gradient_accumulation(config, inputs, targets, model)
            # Update wandb_log
            wandb_log["loss_lm"] = loss_lm
            # Apply gradient clipping
            if config.gradclip_enabled:
                torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(),
                    max_norm=config.gradclip_max_norm,
                    norm_type=config.gradclip_norm_type,
                )
            # Update the parameters
            optimizer_1.step()
            optimizer_2.step()
            # ----- #


            # ----- #
            # Stop the timer
            # ----- #
            synchronize()
            t2 = time.perf_counter()
            # Get iter_time
            iter_time = t2 - t1
            # Update wandb_log
            wandb_log["iter_time"] = iter_time
            # ----- #


            # ----- #
            # Calculate and handle estimating training time if enabled
            # ----- #
            if args.estimate_training_time and (rank == 0):
                # The first 20 iterations are excluded from the statistics
                if idx_iter >= 20:
                    runtime_buff.append(iter_time)
                    # Calculate statistics
                    runtime_arr = np.array(runtime_buff)
                    n = len(runtime_buff)
                    mean_time = float(np.mean(runtime_arr))
                    if n >= 2:
                        # Sample standard deviation (ddof=1) needs at least two samples
                        std_time = float(np.std(runtime_arr, ddof=1))
                    else:
                        std_time = float("nan")
                    print(f"Runtime stats: mean={mean_time:.4f}s, std={std_time:.4f}s, n={n}")
            # ----- #


            # ----- #
            # Validation
            # ----- #
            if config.eval_enable_validation:
                if args.enable_training_time_validation:
                    _should_validate = should_validate(idx_iter=idx_iter, num_iter=len(dataloader_train), num_val=32)
                else:
                    # Otherwise validate only after the last iteration
                    _should_validate = (idx_iter + 1) == len(dataloader_train)
                if _should_validate:
                    if rank == 0:
                        print("\n\n")
                    # Synchronize before validation
                    synchronize()
                    # Start validation
                    perplexity_val = validation(config, model, dataloader_val)
                    # Update wandb_log
                    wandb_log["perplexity_val"] = perplexity_val
                    # Synchronize before proceeding
                    synchronize()
            # ----- #


            # ----- #
            # Optionally visualize `expert_load_all`
            # ----- #
            if rank == 0:
                if config.runtime["expert_load_all"] is not None:
                    if auxfree_type == "OFF":
                        pass
                    elif auxfree_type == "SINGLE_HEAD":
                        _log_expert_load_quantiles(
                            wandb_log,
                            config.runtime["expert_load_all"],
                            num_block=config.num_block,
                        )
                    elif auxfree_type == "MULTI_HEAD_NO_HP":
                        _log_expert_load_quantiles(
                            wandb_log,
                            config.runtime["expert_load_all"],
                            num_block=config.num_block,
                            num_head=config.ffwd_num_head,
                        )
                    elif auxfree_type == "MULTI_HEAD_WITH_HP":
                        # Only the heads indexed 0..num_head_per_rank-1 of rank 0's tensor are logged
                        assert config.ffwd_num_head % world_size == 0
                        _log_expert_load_quantiles(
                            wandb_log,
                            config.runtime["expert_load_all"],
                            num_block=config.num_block,
                            num_head=config.ffwd_num_head // world_size,
                        )
                    else:
                        raise Exception("Unexpected auxfree_type")
            # ----- #


            # ----- #
            # Visualize `watchdog`
            # ----- #
            # Runs on every rank: all_reduce is a collective call
            for name, module in model.named_modules():
                if hasattr(module, "watchdog"):
                    # The root module has an empty name
                    module_label = name if name else "Model"
                    for key, value in module.watchdog.items():
                        torch.distributed.all_reduce(value, op=torch.distributed.ReduceOp.AVG)
                        wandb_log[f"{key}/{module_label}"] = value.item()
            # ----- #


            # ----- #
            # Submit wandb_log
            # ----- #
            if rank == 0:
                run.log(wandb_log, step=idx_iter)
            # ----- #


            # ----- #
            # Checkpointing
            # ----- #
            # Checkpoint every 4000 iterations and after the final iteration
            condition_1 = (idx_iter + 1) % 4000 == 0
            condition_2 = idx_iter == len(dataloader_train) - 1
            if (condition_1 or condition_2) and config.ckpt_enabled:
                if rank == 0:
                    print("\n\nCheckpointing - Start\n\n")

                if rank == 0:
                    process_iter_folders(vault_path, idx_iter)
                synchronize()

                # Define `idx_rank_all`
                if has_distributed_weights:
                    idx_rank_all = list(range(world_size))
                else:
                    # Note: If using Data Parallel, we save everything on rank 0
                    idx_rank_all = [0]

                flag_no_timeout_yet = True
                for idx_rank in tqdm(idx_rank_all, disable=rank != 0):
                    if (rank == idx_rank) and flag_no_timeout_yet:
                        # Define `checkpoint_dict`, on rank `idx_rank`
                        checkpoint_dict = dict()
                        # Save `idx_iter` and `world_size`, on rank `idx_rank`
                        checkpoint_dict["idx_iter"] = idx_iter
                        checkpoint_dict["world_size"] = world_size
                        # Save `auxfree_bias_all`, on rank `idx_rank`
                        checkpoint_dict["auxfree_bias_all"] = config.runtime["auxfree_bias_all"]
                        # Save `state_dict_model`, on rank `idx_rank`
                        checkpoint_dict["state_dict_model"] = model.state_dict()
                        # Save `state_dict_optimizer_1` and `state_dict_optimizer_2`, on rank `idx_rank`
                        checkpoint_dict["state_dict_optimizer_1"] = optimizer_1.state_dict()
                        checkpoint_dict["state_dict_optimizer_2"] = optimizer_2.state_dict()
                        # Move `checkpoint_dict` to cpu, on rank `idx_rank`
                        checkpoint_dict = to_cpu_recursive(checkpoint_dict)
                        # Save into the new iter folder, on rank `idx_rank`
                        try:
                            save_checkpoint(vault_path, checkpoint_dict, idx_iter, rank=rank)
                        except TimeoutError as e:
                            print("\n\n")
                            print(f"Checkpointing timed out on rank {rank}!")
                            print(e)
                            print("\n\n")
                            # Cancel checkpointing on all ranks
                            flag_no_timeout_yet = False
                        # Release `checkpoint_dict`, on rank `idx_rank`
                        del checkpoint_dict
                        gc.collect()
                    synchronize()
                    # Broadcast the flag from rank `idx_rank` to all ranks
                    flag_no_timeout_yet = [flag_no_timeout_yet]
                    dist.broadcast_object_list(flag_no_timeout_yet, src=idx_rank)
                    flag_no_timeout_yet = flag_no_timeout_yet[0]
                if rank == 0:
                    print("\n\nCheckpointing - End\n\n")
            # ----- #


            # ----- #
            # Clean up to end this iteration
            # ----- #
            synchronize()
            # ----- #
    # endregion
    # ----- #


    # ----- #
    # Save Training Time Estimation Results
    # ----- #
    if args.estimate_training_time and (rank == 0):
        if n == 0:
            print("No iteration times were collected; saving NaN statistics with n_samples=0")
        # Create folder if not exists (relative to the current working directory)
        estimation_folder = "iter_time_estimations"
        os.makedirs(estimation_folder, exist_ok=True)
        # Save results
        estimation_file = os.path.join(estimation_folder, f"{config.run_name}.json")
        estimation_data = {
            "run_name": config.run_name,
            "mean_time": mean_time,
            "std_time": std_time,
            "n_samples": n,
        }
        with open(estimation_file, "w", encoding="utf-8") as f:
            json.dump(estimation_data, f, indent=2)
        print(f"Training time estimation saved to: {estimation_file}")
    # ----- #


    # ----- #
    # Evaluation
    # ----- #
    # region Stage: Evaluation
    if not args.estimate_training_time:
        # Get `evaluation_results`
        evaluation_results = evaluation(config, model)
        # Present `evaluation_results` and save to file
        if dist.get_rank() == 0:
            print("\n\n\n\nEvaluation Results:")
            for k, v in evaluation_results.items():
                print(k)
                print(v)
                print("\n")
            print("\n\n\n\n")
            # Save evaluation results to vault (JSON content despite the .txt extension)
            eval_results_path = os.path.join(vault_path, "eval_results.txt")
            with open(eval_results_path, "w", encoding="utf-8") as f:
                json.dump(evaluation_results, f, indent=2)
            print(f"Evaluation results saved to: {eval_results_path}\n")
        synchronize()
    # endregion
    # ----- #


    # ----- #
    # Clean up to end this training run
    # ----- #
    # End the wandb run
    if rank == 0:
        run.finish()
    # Ensure all processes finish
    synchronize()
    # Final clean-up
    dist.destroy_process_group()
    # ----- #
