"""
This script is originally adapted from and inspired by the tinyllama.py and 
redpajama.py scripts in the lit-gpt/pretrain directory.

The lit-gpt authors designed this such that setup -> train reads ~linearly.
"""

####################################################################################################
# Imports.
####################################################################################################

import time

global_start_time = time.time()
import math
import os

from functools import partial
from pathlib import Path
from typing import Tuple, Union
from contextlib import nullcontext
import json
import random

import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.loggers import Logger, CSVLogger
from lightning.fabric.strategies import FSDPStrategy, DDPStrategy, SingleDeviceStrategy
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchmetrics.aggregation import RunningMean

# Choose the scaled-dot-product-attention backends at import time: the flash and
# math kernels are switched off and only the memory-efficient kernel stays enabled.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

# Echo the resulting backend flags once, from the process whose SLURM_PROCID is 0
# (this also matches when the variable is unset, because the default is "0").
if int(os.getenv("SLURM_PROCID", "0")) == 0:
    print("Torch SDPA settings in model.py:")
    print("torch.backends.cuda.flash_sdp_enabled(): ", torch.backends.cuda.flash_sdp_enabled())
    print("torch.backends.cuda.mem_efficient_sdp_enabled(): ", torch.backends.cuda.mem_efficient_sdp_enabled())
    print("torch.backends.cuda.math_sdp_enabled(): ", torch.backends.cuda.math_sdp_enabled())

from axonn_fabric.fabric import AxoNNFabric  # local fabric wrapper
import axonn  # package

from litgpt.settings import CLISettings
from litgpt import optim
from litgpt.model import CausalSelfAttention, LLaMAMLP
from litgpt.retrieval_lm import PrefixNet

from litgpt.tokenizer import Tokenizer
from litgpt.packed_cycle_dataset import CombinedDataset, PackedDataset
from litgpt.huggingface_dataset import HuggingfaceDataset
from litgpt.data_loading_utils import generic_collate_fn
import litgpt.utils
from litgpt.data_scheduler_utils import DataSchedulerTracker, DataScheduler
from litgpt.doc_block_utils import get_ltor_masks_and_position_ids, get_cache_attn_masks

from litgpt.multiple_negative_ranking_loss import MultipleNegativesRankingLoss

from dataclasses import asdict, is_dataclass
from jsonargparse import CLI
import re

from transformers import AutoModelForCausalLM, AutoConfig
from axonn.models.transformers import parallelize

# Report how long the imports above took, measured from `global_start_time`
# (taken right after `import time`). Printed only by SLURM process 0.
end_time = time.time()
if int(os.getenv("SLURM_PROCID", "0")) == 0:
    print(f"Time to load libraries: {end_time - global_start_time:.02f} seconds.")

####################################################################################################
# Setup functions.
####################################################################################################


def divide(a, b):
    """Return ``a // b`` after asserting that ``b`` divides ``a`` without remainder."""
    quotient, remainder = divmod(a, b)
    assert remainder == 0, f"{a} is not divisible by {b}"
    return quotient


def set_torch_flags(cfg):
    """Apply the process-wide torch numeric/backend switches used for training.

    Reads ``cfg.matmul_precision`` for the float32 matmul precision and then turns on
    cuDNN autotuning, TF32 for cuDNN and CUDA matmuls, and reduced-precision fp16
    reductions. Returns nothing; everything here is a global side effect.
    """
    # Keep the precision call first: the explicit backend toggles below follow it,
    # exactly as in the original statement order.
    torch.set_float32_matmul_precision(cfg.matmul_precision)

    # Do the AMD cards pick up on any of this? :
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = True
    cudnn_backend.allow_tf32 = True

    cuda_matmul = torch.backends.cuda.matmul
    cuda_matmul.allow_tf32 = True
    cuda_matmul.allow_fp16_reduced_precision_reduction = True  # Should be true anyway


def setup_fabric(cfg: CLISettings) -> L.Fabric:
    """Create the logger and the fabric selected by ``cfg.fabric_strategy``, then launch the fabric.

    Two families of fabric are supported:

    * ``"axonn_tp"`` builds an ``AxoNNFabric``. In this branch ``cfg.micro_batch_size`` is
      overwritten with ``world_batch_size / dataloader world size / gradient_accumulation_steps``
      (both divisions must be exact).
    * ``"fsdp"``, ``"ddp"`` and ``"single"`` build a regular ``L.Fabric``. The extra attributes
      the rest of this script reads from the fabric (``global_rank_for_creating_dataloader``,
      ``global_world_size_for_creating_dataloader``, ``get_prefix_for_checkpoint``,
      ``optimize_communication``) are attached to the instance here.

    Args:
        cfg: Run settings. Mutated in place in the AxoNN branch (``micro_batch_size``).

    Returns:
        The launched fabric. Only the fabric is returned; the logger is handed to the fabric
        through ``loggers=[logger]``.

    Raises:
        ValueError: If ``cfg.fabric_strategy`` is not one of the supported options.
        AssertionError: From ``divide`` when the batch sizes do not divide evenly (AxoNN branch).
    """
    # Instantiate the logger.
    logger = choose_logger(
        logger_name=cfg.logger_name,
        project=cfg.logger_project,
        name=cfg.run_name,
        resume=cfg.resume,
        save_dir=cfg.out_dir,
    )
    # Instantiate the fabric.
    if cfg.fabric_strategy == "axonn_tp":
        # Default grid puts all tensor parallelism on the last dimension unless one is configured.
        grid_setup = (
            [1, 1, cfg.tensor_parallel_size]
            if cfg.fabric["tensor_parallel_grid"] is None
            else cfg.fabric["tensor_parallel_grid"]
        )
        fabric = AxoNNFabric(
            tensor_parallel_grid=grid_setup,
            precision=cfg.fabric_precision,
            loggers=[logger],
            depth_first=cfg.fabric["depth_first"],
            all_gather_dtype=cfg.fabric["all_gather_dtype"],
            reduce_scatter_dtype=cfg.fabric["all_reduce_dtype"],
        )
        batch_size_per_gpu = divide(cfg.world_batch_size, fabric.global_world_size_for_creating_dataloader)
        cfg.micro_batch_size = divide(batch_size_per_gpu, cfg.gradient_accumulation_steps)
        # Report the grid that is actually in use (identical to the old "1x1x<tp>" text for the default grid).
        grid_text = "x".join(str(dim) for dim in grid_setup)
        fabric.print(f"Using AxoNNFabric with tensor parallelism = {grid_text}")
        fabric.launch()
    else:
        if cfg.fabric_strategy == "fsdp":
            precision_strategy = derive_precision(cfg.fabric_precision, cfg.fabric)
            strategy = FSDPStrategy(
                auto_wrap_policy={cfg.model_config.Block},
                mixed_precision=precision_strategy,
                activation_checkpointing_policy={cfg.model_config.Block} if cfg.gradient_checkpointing else None,
                state_dict_type="full",
                sharding_strategy="HYBRID_SHARD",  # choose FULL_SHARD if oom
                # `fabric` is looked up when the lambda is called, i.e. after it is assigned below.
                param_init_fn=(
                    (lambda x: x.to_empty(device=fabric.device, recurse=False))
                    if cfg.model_impl == "huggingface"
                    else None
                ),
            )
        elif cfg.fabric_strategy == "ddp":
            strategy = DDPStrategy()
        elif cfg.fabric_strategy == "single":
            strategy = SingleDeviceStrategy(
                device=torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
            )
        else:
            raise ValueError(f"`fabric_strategy={cfg.fabric_strategy}` is not a valid option.")

        # Instantiate and launch/initialize the fabric distributed environment management.
        fabric = L.Fabric(
            devices=cfg.devices,
            strategy=strategy,
            precision=cfg.fabric_precision,
            loggers=[logger],
            num_nodes=cfg.num_nodes,
        )
        # Attach the attributes that the rest of the script reads from the fabric object.
        fabric.global_rank_for_creating_dataloader = fabric.global_rank
        fabric.global_world_size_for_creating_dataloader = fabric.world_size
        fabric.get_prefix_for_checkpoint = lambda: f"checkpoints-{cfg.fabric_strategy}"
        # Returns the `nullcontext` class itself; call sites instantiate it with a trailing `()`.
        fabric.optimize_communication = lambda *args, **kwargs: nullcontext
        fabric.print(f"Using Lightning Fabric with strategy {cfg.fabric_strategy} ")
        fabric.launch()

    fabric.print(f"> global_batch_size = {cfg.world_batch_size}")
    fabric.print(f"> gradient_accumulation_steps = {cfg.gradient_accumulation_steps}")
    fabric.print(f"> micro_batch_size = {cfg.micro_batch_size}")
    fabric.print(f"> global_world_size_for_creating_dataloader = {fabric.global_world_size_for_creating_dataloader}")

    return fabric


####################################################################################################
# Main driver functions.
####################################################################################################


def startup(fabric: L.Fabric, cfg: CLISettings):
    """Build every object the training loop needs and return them.

    In order, this function: optionally reads the total job time from SLURM and broadcasts it,
    creates the output directories and dumps the run/model configs (rank 0 only), loads the
    tokenizer, validates experiment flags, creates and wraps the dataloaders and the data
    scheduler, instantiates the model and objective, sets up the optimizer, assembles the
    training ``state`` dict, and loads a checkpoint when resuming.

    Args:
        fabric: The launched fabric returned by ``setup_fabric``.
        cfg: Run settings. Several derived fields are written back onto it
            (``global_total_time``, ``loader_block_size``, and, when ``max_iters`` is ``None``,
            ``max_tokens_per_device``, ``tokens_per_iter`` and ``max_iters``).

    Returns:
        A tuple ``(state, train_dataloader, val_dataloader, data_scheduler, resume_ckpt)``.

    Raises:
        ValueError: If ``cfg.model_impl`` is not ``"litgpt"``, ``"dynamic"`` or ``"huggingface"``.
        AssertionError: If a required special token id is missing from the tokenizer, or the
            block sizes disagree while ``cfg.ignore_block_size_mismatch`` is false.
    """
    start_time = time.time()

    # Get job remaining time
    if cfg.save_n_min_before_job_done is not None:
        if fabric.global_rank == 0:
            global_total_time = _get_time_from_slurm()
            fabric.print(f"Total job time: {global_total_time:.02f} seconds.")
        else:
            global_total_time = None

        global_total_time = fabric.broadcast(global_total_time, 0)  # does this have to be a broadcast?
        cfg.global_total_time = global_total_time

    # Prepare directories for logging
    if fabric.global_rank == 0:
        Path(cfg.out_dir).mkdir(parents=True, exist_ok=True)
        (Path(cfg.out_dir) / fabric.get_prefix_for_checkpoint()).mkdir(parents=True, exist_ok=True)
        # Last step before we move on is to dump the cfg to a file in the out_dir.
        # This is itself loadable as a config by passing like train.py --config run_config.json
        with open(f"{cfg.out_dir}/run_config.json", "w") as f:
            json.dump(asdict(cfg), f, indent=4)
        with open(f"{cfg.out_dir}/model_config.json", "w") as f:
            json.dump(asdict(cfg.model_config) if is_dataclass(cfg.model_config) else cfg.model_config, f, indent=4)
    # Load tokenizer
    tokenizer = Tokenizer(cfg.tokenizer_path)
    if tokenizer.pad_id is None:
        # Use -1 as the pad sentinel when the tokenizer defines none.
        tokenizer.pad_id = -1
    # The special token ids are only kept when the matching feature is switched on.
    if cfg.cache_attn:
        assert tokenizer.cache_token_id is not None
    else:
        tokenizer.cache_token_id = None
    if cfg.doc_block_attn:
        assert tokenizer.eod_token_id is not None
    else:
        tokenizer.eod_token_id = None

    # Validate science expt flags
    science_startup(fabric, cfg)

    # Create data objects
    t0 = time.time()
    # On block size, moved this here to be more explicit that this is happening ...
    if not cfg.ignore_block_size_mismatch:
        assert cfg.block_size == cfg.model_config.block_size, "cfg.block_size must match config.block_size"
    # Increase by one to actually be supervising "block_size" tokens in every update after rshift.
    cfg.loader_block_size = cfg.block_size + 1 if cfg.max_seq_len is None else cfg.max_seq_len
    train_dataloader, val_dataloader, data_scheduler_tracker = create_dataloaders(
        batch_size=cfg.micro_batch_size,
        block_size=cfg.loader_block_size,
        fabric=fabric,
        seed=(cfg.seed + fabric.global_rank_for_creating_dataloader),
        cfg=cfg,
        tokenizer=tokenizer,
    )
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
    data_scheduler = DataScheduler(data_scheduler_tracker, cfg.data_config["train_data"], cfg)
    data_scheduler.step(0)
    fabric.print(f"Time to instantiate and setup dataloaders: {time.time() - t0:.02f} seconds.")

    # Construct the model
    fabric.seed_everything(cfg.seed)  # same seed for every process to init model (FSDP)
    if cfg.model_checkpoint is not None:
        litgpt.utils.check_valid_checkpoint_dir(Path(cfg.model_checkpoint))
    fabric.print(f"Loading model with {cfg.model_config.__dict__}")

    # Set the objective
    objective = MultipleNegativesRankingLoss(loss_type=cfg.loss_type, reduce=cfg.reduce)

    # Initialize the model
    t0 = time.time()
    with fabric.init_module(empty_init=cfg.fabric_strategy == "fsdp"):
        if cfg.model_impl in ["litgpt", "dynamic"]:
            # Under FSDP, activation checkpointing is configured on the strategy instead
            # (see `setup_fabric`), so it is not requested from the model here.
            model = PrefixNet(
                cfg.model_config,
                objective=objective,
                gradient_checkpointing=cfg.gradient_checkpointing and cfg.fabric_strategy != "fsdp",
                checkpoint_dir=cfg.model_checkpoint,
                keep_eos=cfg.keep_eos,
            )
            # model = cfg.model_config.construct_model(
            #     objective=objective, gradient_checkpointing=cfg.gradient_checkpointing and cfg.fabric_strategy != "fsdp"
            # )
        elif cfg.model_impl == "huggingface":
            source = cfg.model_checkpoint or cfg.model_name
            with parallelize(source) if cfg.fabric_strategy == "axonn_tp" else nullcontext():
                model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(source))
        else:
            # Fail here with a clear message instead of hitting an unbound `model` further down.
            raise ValueError(f"`model_impl={cfg.model_impl}` is not a valid option.")

    if cfg.compile_model:
        model = torch.compile(model, dynamic=False)  # error on dynamic shape
    fabric.print(f"Time to instantiate model: {time.time() - t0:.02f} seconds.")
    fabric.print(f"Total parameters: {litgpt.utils.num_parameters(model):,}")
    # With fabric and the model up, we can compute a few last derived cfg
    if cfg.max_iters is None:
        cfg.max_tokens_per_device = cfg.max_tokens // fabric.world_size
        cfg.tokens_per_iter = cfg.micro_batch_size * cfg.block_size
        cfg.max_iters = cfg.max_tokens_per_device // cfg.tokens_per_iter

    # Set up the final fabric+model details
    t0 = time.time()
    model = fabric.setup(model)
    fabric.print(f"Model with full setup is {model}")
    fabric.print(f"Time to setup model: {time.time() - t0:.02f} seconds.")

    t0 = time.time()
    # Set up the optimizer and training state object.
    param_groups = optim.get_param_groups(model.named_parameters(), cfg.no_weight_decay_for_bias_and_norm_params)
    if cfg.optimizer == "SlimAdamW":
        # SlimAdamW receives the model through its config; the entry is removed again right after.
        cfg.optim_config["model_object"] = model
    optimizer = optim.get_optimizer(cfg.optimizer)(param_groups, **cfg.optim_config, foreach=False)
    if cfg.optimizer == "SlimAdamW":
        del cfg.optim_config["model_object"]  # this messes with a print statement later
    optimizer = fabric.setup_optimizers(optimizer)
    fabric.print(f"Time to instantiate and setup optimizers: {time.time() - t0:.02f} seconds.")

    state = {
        "model": model,
        "optimizer": optimizer,
        "tokenizer": tokenizer,
        "data_scheduler": data_scheduler,
        "microbatch_step": 0,  # mbs steps
        "optimizer_step": 0,  # optimizer updates taken
        # tracking each rank's token count (non-padding actual tokens seen)
        "real_token_count": torch.zeros(fabric.world_size, dtype=torch.long),
    }

    t0 = time.time()
    # If resuming, determine the checkpoint to resume from.
    resume_ckpt = load_checkpoint(
        fabric, cfg, state, cfg.out_dir, cfg.run_name, cfg.model_checkpoint, cfg.model_impl, cfg.resume
    )
    fabric.print(f"Time to load model checkpoint: {time.time() - t0:.02f} seconds.")

    # Report the full cfg set for the run.
    fabric.print(f"cmdline + derived cfg:\n{json.dumps(cfg.__dict__, default=lambda x:x.__dict__, indent=4)}")
    if cfg.logger_name in ("tensorboard", "wandb") and fabric.global_rank == 0:
        fabric.logger.log_hyperparams(cfg.__dict__)

    end_time = time.time()
    fabric.print(f"Total time to run main func setups: {end_time - start_time:.02f} seconds.")

    return state, train_dataloader, val_dataloader, data_scheduler, resume_ckpt


@torch.no_grad()
def validate(
    fabric: L.Fabric,
    model: nn.Module,
    val_dataloader: DataLoader,
    max_iters: int,
    tokenizer: Tokenizer,
    cfg,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` on up to ``max_iters`` validation batches.

    Each batch is either a dict with ``"prefix_input"`` and ``"suffix_input"`` entries or a
    single tensor, which is treated as prefix-only input. The last position of every sequence
    is dropped, an attention mask is derived from the non-pad positions, and pad ids are then
    replaced by ``tokenizer.eos_id`` so they are valid embedding indices.

    The model is put into eval mode for the duration and switched back to train mode before
    returning.

    Args:
        fabric: Fabric used for printing, device placement and the communication context.
        model: The (fabric-wrapped) model; its output must provide ``"loss"`` and ``"accuracy"``.
        val_dataloader: Validation loader, or ``None`` to skip validation.
        max_iters: Maximum number of batches to evaluate.
        tokenizer: Supplies ``pad_id`` and ``eos_id``.
        cfg: Run settings (``cache_attn``, ``doc_block_attn``, ``fabric``, ``fabric_strategy``).

    Returns:
        ``(mean_loss, mean_accuracy)`` over the batches that were actually evaluated. Both are
        ``-inf`` when ``val_dataloader`` is ``None`` and ``nan`` when the loader yields no batch.
        A batch whose output value is ``None`` contributes ``-1.0``.
    """
    if val_dataloader is None:
        # Always return a pair so callers can unpack `(loss, accuracy)` unconditionally.
        return torch.as_tensor(float("-Inf")), torch.as_tensor(float("-Inf"))
    fabric.print(f"Validating for {max_iters} steps ...")
    model.eval()

    losses = torch.zeros(max_iters, device=fabric.device)
    accuracies = torch.zeros(max_iters, device=fabric.device)
    num_batches = 0
    for k, val_data in enumerate(val_dataloader):
        if k >= max_iters:
            break
        if isinstance(val_data, dict):
            prefix_input_ids, suffix_input_ids = val_data["prefix_input"], val_data["suffix_input"]
        else:
            prefix_input_ids, suffix_input_ids = val_data, None
        _, seq_len = prefix_input_ids.shape
        prefix_input_ids = prefix_input_ids[:, 0 : (seq_len - 1)].contiguous().long()
        # for the input we need to replace any pad ids with the eos token
        # knowing that they're trailing, and wont contrib to activations/loss
        # but that they need to be valid indices in the model's embedding layer
        prefix_attn_mask = (prefix_input_ids != tokenizer.pad_id)
        prefix_input_ids[prefix_input_ids == tokenizer.pad_id] = tokenizer.eos_id

        prefix_input_ids = prefix_input_ids.to(fabric.device, non_blocking=True)
        # Only the position ids are consumed below; the returned mask is not used.
        _, positions = get_attention_mask(prefix_input_ids, tokenizer, cfg.cache_attn, cfg.doc_block_attn)

        suffix_attn_mask = None
        if suffix_input_ids is not None:
            _, seq_len = suffix_input_ids.shape
            suffix_input_ids = suffix_input_ids[:, 0 : (seq_len - 1)].contiguous().long()
            suffix_attn_mask = (suffix_input_ids != tokenizer.pad_id)
            suffix_input_ids[suffix_input_ids == tokenizer.pad_id] = tokenizer.eos_id
            suffix_input_ids = suffix_input_ids.to(fabric.device, non_blocking=True)

        with fabric.optimize_communication(model=model, enabled=cfg.fabric["optimize_communication"])():
            outputs = model(
                (prefix_input_ids, suffix_input_ids),
                attn_mask=(prefix_attn_mask, suffix_attn_mask),
                position_ids=positions,
                output_hidden_states=True,
            )
            if cfg.fabric_strategy == "axonn_tp":
                axonn.intra_layer.clear_weights_cache()
        losses[k] = outputs["loss"] if outputs["loss"] is not None else -1.0
        accuracies[k] = outputs["accuracy"] if outputs["accuracy"] is not None else -1.0
        num_batches = k + 1

    model.train()
    # Average only over the slots that were filled, so a loader shorter than `max_iters`
    # does not pull the means towards zero.
    return losses[:num_batches].mean(), accuracies[:num_batches].mean()


def train_step(train_data, fabric, state, running_loss, running_accuracy, cfg):
    """Run one micro-batch: forward, backward and, when not accumulating, an optimizer update.

    Args:
        train_data: Either a dict with ``"prefix_input"`` / ``"suffix_input"`` tensors or a single
            tensor that is treated as prefix-only input.
        fabric: Fabric used for device placement, backward, gradient clipping and printing.
        state: Training state dict; ``"microbatch_step"``, ``"optimizer_step"`` and
            ``"real_token_count"`` are updated in place.
        running_loss: Running-mean metric updated with the detached loss.
        running_accuracy: Running-mean metric updated with the detached accuracy.
        cfg: Run settings.

    Returns:
        ``(grad_norm, next_step_lr, is_accumulating)``. ``grad_norm`` and ``next_step_lr`` are
        ``None`` on micro-batches that only accumulate gradients.
    """
    model = state["model"]
    optimizer = state["optimizer"]
    data_scheduler = state["data_scheduler"]
    tokenizer = state["tokenizer"]

    # switch NEPTune on/off based on config before iter++
    if cfg.neptune_noise_alpha:
        total_tokens_seen = state["microbatch_step"] * cfg.micro_batch_size * cfg.block_size * fabric.world_size
        if cfg.neptune_from_tokens <= total_tokens_seen <= cfg.neptune_till_tokens:
            model.config.neptune_noise_alpha = cfg.neptune_noise_alpha
        else:
            model.config.neptune_noise_alpha = None

    state["microbatch_step"] += 1

    if isinstance(train_data, dict):
        prefix_input_ids, suffix_input_ids = train_data["prefix_input"], train_data["suffix_input"]
    else:
        prefix_input_ids, suffix_input_ids = train_data, None

    # Realize the input and labels tensors.
    bsz, seq_len = prefix_input_ids.shape
    prefix_input_ids = prefix_input_ids[:, 0 : (seq_len - 1)].contiguous().long()
    # Count non-padding tokens
    state["real_token_count"][fabric.global_rank] += int((prefix_input_ids != tokenizer.pad_id).sum().item())
    # for the input we need to replace any pad ids with the eos token
    # knowing that they're trailing so they wont contrib to activations
    # but that they do need to be valid indices in the model's embedding layer
    prefix_attn_mask = (prefix_input_ids != tokenizer.pad_id)
    prefix_input_ids[prefix_input_ids == tokenizer.pad_id] = tokenizer.eos_id
    prefix_input_ids = prefix_input_ids.to(fabric.device, non_blocking=True)
    # Only the position ids are consumed below; the returned mask is not used.
    _, positions = get_attention_mask(prefix_input_ids, tokenizer, cfg.cache_attn, cfg.doc_block_attn)

    suffix_attn_mask = None
    if suffix_input_ids is not None:
        bsz, seq_len = suffix_input_ids.shape
        suffix_input_ids = suffix_input_ids[:, 0 : (seq_len - 1)].contiguous().long()
        # Count non-padding tokens
        state["real_token_count"][fabric.global_rank] += int((suffix_input_ids != tokenizer.pad_id).sum().item())
        suffix_attn_mask = (suffix_input_ids != tokenizer.pad_id)
        suffix_input_ids[suffix_input_ids == tokenizer.pad_id] = tokenizer.eos_id
        suffix_input_ids = suffix_input_ids.to(fabric.device, non_blocking=True)

    # Print tensor shapes for the first few micro-batches, then announce once that this stops.
    if state["microbatch_step"] < cfg.shape_watching_iters:
        fabric.print(f"bsz: {bsz} | seq_len: {seq_len}")
        fabric.print(f"prefix_input_ids.shape: {prefix_input_ids.shape}")
        fabric.print(f"suffix_input_ids.shape: {None if suffix_input_ids is None else suffix_input_ids.shape}")
    elif state["microbatch_step"] == cfg.shape_watching_iters and cfg.shape_watching_iters > 0:
        fabric.print("Silencing shape watching ...")

    # Forward, loss, and backward computation.
    is_accumulating = state["microbatch_step"] % cfg.gradient_accumulation_steps != 0

    with fabric.optimize_communication(
        model=model.module,
        enabled=cfg.fabric["optimize_communication"],
        opt_level=cfg.fabric["opt_level"],
    )():
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # NOTE(review): anomaly detection is enabled on every step; it is normally a
            # debugging aid with a runtime cost - confirm this is intended for production runs.
            with torch.autograd.set_detect_anomaly(True):
                outputs = model(
                    (prefix_input_ids, suffix_input_ids),
                    attn_mask=(prefix_attn_mask, suffix_attn_mask),
                    position_ids=positions,
                    output_hidden_states=True,
                )
                # h1 = model(prefix_input_ids, position_ids=positions, output_hidden_states=True)
                # h2 = model(suffix_input_ids, position_ids=positions, output_hidden_states=True)
                # breakpoint()
                # objective = MultipleNegativesRankingLoss(loss_type=cfg.loss_type, reduce=cfg.reduce, k_diag=cfg.mask_k_diags, drop_k=cfg.drop_k, k_pos_labels=cfg.k_pos_labels, decay_factor=cfg.decay_factor)
                # outputs = objective.get_cross_batch_negative_loss(h1, h2)
                # Scale so the accumulated gradient matches the mean over the accumulation window.
                fabric.backward(outputs["loss"] / cfg.gradient_accumulation_steps)

    running_loss.update(outputs["loss"].detach())
    running_accuracy.update(outputs["accuracy"].detach())

    # Take an optimization step if not accumulating.
    if not is_accumulating:
        # checking nan in gradients
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                # Parameters without a gradient (e.g. frozen or unused) have nothing to check.
                continue
            if not torch.isfinite(grad).all():
                fabric.print(f"param {name} is nan or inf")
        grad_norm = fabric.clip_gradients(model, optimizer, max_norm=cfg.grad_clip).detach()
        optimizer.step()
        optimizer.zero_grad(set_to_none=cfg.fabric_strategy != "axonn_tp")
        if cfg.fabric_strategy == "axonn_tp":
            axonn.intra_layer.clear_weights_cache()
        state["optimizer_step"] += 1
        # Update learning rate (post-increment since we init it before the first step).
        next_step_lr = get_lr(it=state["microbatch_step"], lr_decay_iters=cfg.max_iters, cfg=cfg)
        for param_group in optimizer.param_groups:
            param_group["lr"] = next_step_lr
        data_scheduler.step(state["optimizer_step"])
    else:
        grad_norm = None
        next_step_lr = None
    return grad_norm, next_step_lr, is_accumulating


def train(fabric, state, train_dataloader, val_dataloader, resume_ckpt=None, cfg=None, data_scheduler=None):
    """Drive training until ``cfg.max_iters - 1`` micro-batch steps are reached or the data runs out.

    Per micro-batch this calls ``train_step``, logs every ``cfg.log_iter_interval`` steps,
    validates every ``cfg.eval_step_interval`` optimizer steps (and once more on the final
    step), and hands over to ``maybe_save_checkpoint``. When ``resume_ckpt`` is truthy, the
    batches already consumed before the checkpoint are read and discarded first.
    """
    # Optional smoke test of the validation path before any training happens.
    if cfg.sanity_validate:
        validate(fabric, state["model"], val_dataloader, max_iters=2, tokenizer=state["tokenizer"], cfg=cfg)

    initial_iter = state["microbatch_step"]
    train_iterator = iter(train_dataloader)

    if resume_ckpt:
        # Fast-forward the loader past every batch seen before the checkpoint, stepping the
        # data scheduler alongside. A streaming dataset could make this replay unnecessary.
        resume_t0 = time.time()
        for resume_iter in range(initial_iter):
            next(train_iterator)
            if resume_iter % 1000 == 0:
                fabric.print(f"Resuming dataset: {resume_iter} / {initial_iter}")
            data_scheduler.step(resume_iter + 1)

        fabric.barrier()
        fabric.print(f"Resuming data loader finished. Took {time.time() - resume_t0:.1f} seconds to reach iteration")

    # Running means over one gradient-accumulation window, kept rank-local.
    window = cfg.gradient_accumulation_steps
    running_loss = RunningMean(window=window, sync_on_compute=False).to(fabric.device)
    running_accuracy = RunningMean(window=window, sync_on_compute=False).to(fabric.device)
    fabric.barrier()
    total_t0 = time.time()

    # Seed every param group with the learning rate for the current step.
    initial_lr = get_lr(it=state["microbatch_step"], lr_decay_iters=cfg.max_iters, cfg=cfg)
    for group in state["optimizer"].param_groups:
        group["lr"] = initial_lr

    for train_data in train_iterator:
        iter_t0 = time.time()
        grad_norm, next_step_lr, is_accumulating = train_step(
            train_data, fabric, state, running_loss, running_accuracy, cfg=cfg
        )

        # Periodic logging.
        if state["microbatch_step"] % cfg.log_iter_interval == 0:
            log_iter(
                fabric=fabric,
                state=state,
                running_loss=running_loss,
                running_accuracy=running_accuracy,
                lr=next_step_lr,
                initial_iter=initial_iter,
                total_t0=total_t0,
                iter_t0=iter_t0,
                grad_norm=grad_norm,
                is_accumulating=is_accumulating,
                data_scheduler=data_scheduler,
                cfg=cfg,
            )

        # Validation: on the optimizer-step interval, and on the last iteration.
        validate_regular = (not is_accumulating) and state["optimizer_step"] % cfg.eval_step_interval == 0
        validate_at_the_end = state["microbatch_step"] >= cfg.max_iters - 1
        if validate_regular or validate_at_the_end:
            t0 = time.time()
            val_loss, val_accuracy = validate(
                fabric, state["model"], val_dataloader, cfg.eval_iters, state["tokenizer"], cfg=cfg
            )
            val_loss = val_loss.item()
            val_accuracy = val_accuracy.item()
            td = time.time() - t0

            fabric.print(f"iter {state['microbatch_step']}: val loss {val_loss:.4f}, val accuracy {val_accuracy:.4f}, val time: {td * 1000:.2f} ms")
            fabric.log_dict(
                {
                    "val_loss": val_loss,
                    "val_accuracy": val_accuracy,
                    "optimizer_step": state["optimizer_step"],
                },
                step=state["microbatch_step"],
            )
            fabric.barrier()

        # Checkpointing decision is delegated entirely to the helper.
        maybe_save_checkpoint(fabric, state, cfg, is_accumulating=is_accumulating)

        if state["microbatch_step"] >= cfg.max_iters - 1:
            break


####################################################################################################
# Train loop sub-routines.
####################################################################################################


def log_iter(
    fabric: L.Fabric = None,
    state: dict = None,
    running_loss: RunningMean = None,
    running_accuracy: RunningMean = None,
    lr: float = None,
    initial_iter: int = None,
    total_t0: float = None,
    iter_t0: float = None,
    grad_norm: float = None,
    is_accumulating: bool = None,
    data_scheduler: dict = None,
    cfg: CLISettings = None,
):
    """Log one iteration's metrics through the fabric logger and print a console summary.

    Args:
        fabric: Fabric used for ``all_reduce``, ``log_dict`` and ``print``.
        state: Training state; ``"microbatch_step"``, ``"optimizer_step"`` and
            ``"real_token_count"`` are read.
        running_loss: Running-mean loss metric (rank-local until reduced here).
        running_accuracy: Running-mean accuracy metric (rank-local until reduced here).
        lr: Learning rate for the next step, or ``None`` on accumulation-only micro-batches.
        initial_iter: Micro-batch step at which this process started training.
        total_t0: Wall-clock time at the start of training, used for the remaining-time estimate.
        iter_t0: Wall-clock time at the start of the current iteration.
        grad_norm: Gradient norm of the last optimizer step, or ``None`` while accumulating.
        is_accumulating: Whether the current micro-batch only accumulated gradients.
        data_scheduler: Object exposing ``get_data_weights``, ``get_sample_count`` and
            ``get_epoch_count`` (despite the ``dict`` annotation), or ``None`` to skip that part.
        cfg: Run settings.
    """
    loss = running_loss.compute().item()  # expensive device-to-host synchronization # NOTE not sure how true this is.
    accuracy = (
        running_accuracy.compute().item()
    )  # expensive device-to-host synchronization # NOTE not sure how true this is.
    t1 = time.time()

    # Log the metrics.
    metrics = {
        "loss": loss,
        "accuracy": accuracy,
        "microbatch_step": state["microbatch_step"],
        "optimizer_step": state["optimizer_step"],
        "iter_time": t1 - iter_t0,
        # Linear extrapolation from the average time per micro-batch since this process started.
        "remaining_time": (
            (t1 - total_t0) / (state["microbatch_step"] - initial_iter) * (cfg.max_iters - state["microbatch_step"])
        ),
        "tokens": state["microbatch_step"] * cfg.micro_batch_size * cfg.block_size,
        "total_tokens": state["microbatch_step"] * cfg.micro_batch_size * cfg.block_size * fabric.world_size,
        "learning_rate": lr,
        "max_iters": cfg.max_iters,
        "grad_norm": grad_norm,
    }

    # Update loss and grad_norm with all_reduce
    # FIXME _these_ could be expensive if the topo is large, so do we need to always report
    # world reduced loss or is rank-local loss sufficient? Maybe add a flag option.
    loss = fabric.all_reduce(loss)
    accuracy = fabric.all_reduce(accuracy)
    grad_norm = fabric.all_reduce(grad_norm)
    real_token_count = fabric.all_reduce(state["real_token_count"].clone(), reduce_op="sum").sum().item()

    metrics["loss"] = loss
    metrics["accuracy"] = accuracy
    metrics["grad_norm"] = grad_norm
    metrics["real_token_count"] = real_token_count

    if data_scheduler is not None:
        curr_data_weights = data_scheduler.get_data_weights()
        curr_data_weights = dict(zip(cfg.dataset_names, curr_data_weights))

        curr_sample_count = data_scheduler.get_sample_count()
        curr_sample_count = fabric.all_reduce(curr_sample_count, reduce_op="sum")

        curr_epoch_count = data_scheduler.get_epoch_count()
        curr_epoch_count = fabric.all_reduce(curr_epoch_count, reduce_op="mean")

        # The normaliser is the same for every dataset, so compute it once.
        weight_total = sum(curr_data_weights.values())
        for i, dataset_name in enumerate(curr_data_weights):
            weight = curr_data_weights[dataset_name]
            metrics["data_scheduler_weight/" + dataset_name] = weight
            metrics["data_scheduler_norm_weight/" + dataset_name] = weight / weight_total

            metrics["data_scheduler_sample_count/" + dataset_name] = curr_sample_count[i]
            metrics["data_scheduler_epoch_count/" + dataset_name] = curr_epoch_count[i]

    fabric.log_dict(metrics, step=state["microbatch_step"])

    # grad_norm and the learning rate are None on accumulation-only micro-batches, where a
    # float format spec would raise; print a placeholder instead.
    grad_norm_text = "n/a" if metrics["grad_norm"] is None else f"{metrics['grad_norm']:.4f}"
    lr_text = "n/a" if metrics["learning_rate"] is None else f"{metrics['learning_rate']:2.4e}"

    # Log some metrics to the console.
    fabric.print(
        f" Iteration {metrics['microbatch_step']} | Optim step {metrics['optimizer_step']}: loss {metrics['loss']:.4f}, accuracy {metrics['accuracy']:.4f}, iter time:"
        f" {metrics['iter_time'] * 1000:.2f} ms{' (optimizer.step),' if not is_accumulating else ','}"
        f" remaining time: {metrics['remaining_time'] / 3600 / 24:.2f} days"
        f" max iters: {metrics['max_iters']}"
        f" grad norm: {grad_norm_text}"
        f" learning rate: {lr_text}"
        f" total tokens: {metrics['total_tokens']}"
        f" real token count: {metrics['real_token_count']}"
    )

####################################################################################################
# Data utility functions.
####################################################################################################


def create_dataloader(
    data_config: dict,
    batch_size: int,
    block_size: int,
    n_chunks: int,
    data_dir: Path,
    fabric: L.Fabric,
    shuffle: bool = True,
    seed: int = 1337,
    cfg: CLISettings = None,
    tokenizer: Tokenizer = None,
) -> Tuple[DataLoader, "DataSchedulerTracker"]:
    """Build one DataLoader over a weighted combination of datasets.

    ``data_config`` is iterated as a sequence of per-dataset config mappings. Each entry
    must carry a ``"type"`` and a ``"weight"`` key; the supported types are:

    * ``"hfds"``  - a ``HuggingfaceDataset`` read from ``entry["data_dir"]``.
    * ``"pkds"``  - a ``PackedDataset`` built from files matching ``entry["prefix"]`` inside
      ``entry["data_dir"]`` (or ``data_dir`` when the entry has none). The file list is
      globbed on global rank 0 only and broadcast to all other ranks.
    * ``"rngds"`` - a debugging option: a seeded random int32 tensor of shape
      ``(1e6, block_size)`` with values below ``tokenizer.vocab_size``.

    Args:
        data_config: per-dataset config entries (see above).
        batch_size: batch size handed to the ``DataLoader``.
        block_size: sequence length used by the packed and random datasets.
        n_chunks: forwarded to ``PackedDataset``.
        data_dir: default directory for ``"pkds"`` entries without their own ``data_dir``.
        fabric: provides ranks, world size, ``broadcast`` and ``print``.
        shuffle: forwarded to the individual datasets (the ``DataLoader`` itself never shuffles).
        seed: seed for dataset shuffling, filename shuffling and the random dataset.
        cfg: run settings; read unconditionally, so it must not be ``None`` in practice.
        tokenizer: required for ``"hfds"`` and ``"rngds"`` entries and used by the collate function.

    Returns:
        A ``(dataloader, data_scheduler_tracker)`` tuple.

    Raises:
        FileNotFoundError: rank 0 found no files for a ``"pkds"`` entry.
        ValueError: an entry has an unsupported ``"type"``.
        RuntimeError: ``data_config`` produced no datasets at all.
    """
    global_data_dir = data_dir
    datasets = []
    for curr_config in data_config:

        if curr_config["type"] == "hfds":
            assert tokenizer is not None, "tokenizer must be provided for HuggingfaceDataset"
            assert "data_dir" in curr_config, "data_dir must be provided for HuggingfaceDataset"
            dataset = HuggingfaceDataset(
                ds_name_or_path=curr_config["data_dir"],  # this is a path to a previously save_to_disk'd hfds
                seed=seed,
                shuffle=shuffle,
                num_processes=fabric.global_world_size_for_creating_dataloader,
                process_rank=fabric.global_rank_for_creating_dataloader,
                shortname=curr_config["prefix"],  # this is provided for logging, and schedule purposes
                text_key=curr_config.get("text_key", cfg.text_key),  # key for the field in dataset to return
                repetitions=curr_config.get("repetitions"),  # repeat the dataset a number of times
            )

        elif curr_config["type"] == "pkds":
            prefix = curr_config["prefix"]

            # A per-dataset directory overrides the directory passed by the caller.
            data_dir = curr_config["data_dir"] if "data_dir" in curr_config else global_data_dir

            # Only rank 0 touches the filesystem; everyone else receives the list via broadcast.
            if fabric.global_rank == 0:
                filenames = [str(pth) for pth in sorted(Path(data_dir).glob(f"{prefix}*"))]
                if cfg.shuffle_filenames:
                    random.seed(seed)
                    random.shuffle(filenames)  # inplace
                if not filenames:
                    raise FileNotFoundError(f"No files found at {str(data_dir)} with prefix {prefix}.")
            else:
                filenames = None

            filenames = fabric.broadcast(filenames, 0)  # this is a blocking op from rank 0 to all other ranks

            # log after broadcast so we know we passed it.
            if fabric.global_rank == 0:
                # Plain values here: wrapping them in one-element tuples garbles the message.
                num_processes = fabric.global_world_size_for_creating_dataloader
                process_rank = fabric.global_rank_for_creating_dataloader
                fabric.print(
                    f"Rank ({process_rank}/{num_processes}) glob'd {len(filenames)} files"
                    f" from {data_dir}{f' w/ prefix {prefix}' if prefix not in ['','*'] else ''},"
                    f" files[:3]: {filenames[:3]}"
                )

            dataset = PackedDataset(
                filenames,
                n_chunks=n_chunks,
                block_size=block_size,
                shuffle=shuffle,
                seed=seed,
                num_processes=fabric.global_world_size_for_creating_dataloader,
                process_rank=fabric.global_rank_for_creating_dataloader,
                shortname=prefix,
            )
        elif curr_config["type"] == "rngds":
            # Debugging option: reproducible random token ids instead of real data.
            generator = torch.Generator()
            generator.manual_seed(seed)
            dataset = torch.randint(
                0,
                tokenizer.vocab_size,
                (int(1e6), block_size),
                dtype=torch.int32,
                generator=generator,
            )
        else:
            raise ValueError(f"Unsupported dataset type: {curr_config['type']}")

        datasets.append(dataset)

    if not datasets:
        raise RuntimeError(
            f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
        )

    # The tracker is shared with the combined dataset and handed back to the caller.
    weights = [curr_config["weight"] for curr_config in data_config]
    data_scheduler_tracker = DataSchedulerTracker(weights)

    combined_dataset = CombinedDataset(
        datasets=datasets, seed=seed, data_scheduler_tracker=data_scheduler_tracker, data_telemetry=cfg.data_telemetry
    )

    parametrized_collate_fn = partial(
        generic_collate_fn,
        tokenizer=tokenizer,
        block_size=cfg.loader_block_size,
        pad_to_block_size=cfg.pad_to_block_size,
        add_bos=cfg.add_bos,
        add_eos=cfg.add_eos,
        collate_checks_enabled=cfg.collate_checks_enabled,
        all_block_size_tensors=cfg.all_block_size_tensors,
    )

    dataloader = DataLoader(
        combined_dataset,
        batch_size=batch_size,
        shuffle=False,  # shuffling is delegated to the datasets themselves
        pin_memory=True,
        collate_fn=parametrized_collate_fn,
        num_workers=cfg.dataloader_num_workers,
    )
    return dataloader, data_scheduler_tracker


def create_dataloaders(
    batch_size: int,
    block_size: int,
    fabric: L.Fabric,
    seed: int = 1337,
    cfg: CLISettings = None,
    tokenizer: Tokenizer = None,
) -> Tuple[DataLoader, Union[DataLoader, None], "DataSchedulerTracker"]:
    """Create the training dataloader and, if configured, the validation dataloader.

    The training loader is built from ``cfg.data_config["train_data"]`` with shuffling
    enabled. A validation loader is built from ``cfg.data_config["val_data"]`` (without
    shuffling) only when that key is present.

    Args:
        batch_size: batch size for both loaders.
        block_size: sequence length forwarded to ``create_dataloader``.
        fabric: forwarded to ``create_dataloader``; also used for the startup message.
        seed: seed shared by both loaders.
        cfg: run settings; supplies ``data_config``, ``n_chunks`` and the data directories.
        tokenizer: forwarded to ``create_dataloader``.

    Returns:
        ``(train_dataloader, val_dataloader, data_scheduler_tracker)``. ``val_dataloader``
        is ``None`` when no validation data is configured; the tracker belongs to the
        training loader.
    """
    fabric.print(f"Creating dataloaders with seed: {seed}")

    # Arguments that are identical for the train and validation loaders.
    shared_kwargs = dict(
        batch_size=batch_size,
        block_size=block_size,
        n_chunks=cfg.n_chunks,
        fabric=fabric,
        seed=seed,
        cfg=cfg,
        tokenizer=tokenizer,
    )

    train_dataloader, data_scheduler_tracker = create_dataloader(
        cfg.data_config["train_data"],
        data_dir=cfg.train_data_dir,
        shuffle=True,
        **shared_kwargs,
    )

    val_dataloader = None
    if "val_data" in cfg.data_config:
        # The validation tracker is not needed by callers and is discarded.
        val_dataloader, _ = create_dataloader(
            cfg.data_config["val_data"],
            data_dir=cfg.val_data_dir,
            shuffle=False,
            **shared_kwargs,
        )

    return train_dataloader, val_dataloader, data_scheduler_tracker


####################################################################################################
# Train utility functions.
####################################################################################################


def derive_precision(precision, strategy_details):
    """Build the torch FSDP ``MixedPrecision`` policy for a precision string.

    Args:
        precision: precision description; ``"fp16"`` / ``"bf16"`` substrings select the
            parameter dtype (anything else means float32) and a ``"mixed"`` substring
            makes gradient reduction default to float32.
        strategy_details: mapping with an ``"all_reduce_dtype"`` entry. When that entry
            is not ``None`` it overrides the reduction dtype, using the same
            ``"fp16"`` / ``"bf16"`` substring matching.

    Returns:
        A ``torch.distributed.fsdp.MixedPrecision`` instance with float32 buffers.
    """
    import torch.distributed.fsdp

    def _dtype_from(spec):
        # Same substring matching for both the parameter and the reduction dtype.
        if "fp16" in spec:
            return torch.float16
        if "bf16" in spec:
            return torch.bfloat16
        return torch.float32

    param_dtype = _dtype_from(precision)
    reduce_dtype = torch.float32 if "mixed" in precision else param_dtype
    # The parentheses matter: without them the walrus binds the result of `is not None`
    # (a bool) and the substring checks below raise a TypeError.
    if (reduc := strategy_details["all_reduce_dtype"]) is not None:
        reduce_dtype = _dtype_from(reduc)
    return torch.distributed.fsdp.MixedPrecision(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        buffer_dtype=torch.float32,
        keep_low_precision_grads=False,
        cast_forward_inputs=False,
    )


def get_attention_mask(input_ids, tokenizer, cache_attn=True, doc_block_attn=True):
    """Return ``(mask, position_ids)`` for ``input_ids``.

    ``doc_block_attn`` takes precedence: when set, the mask is built by
    ``get_ltor_masks_and_position_ids`` around ``tokenizer.eod_token_id``. Otherwise,
    when ``cache_attn`` is set, ``get_cache_attn_masks`` is used with
    ``tokenizer.cache_token_id``. With both flags off, ``(None, None)`` is returned and
    neither the inputs nor the tokenizer are touched.
    """
    if doc_block_attn:
        mask_builder = get_ltor_masks_and_position_ids
        boundary_token_id = tokenizer.eod_token_id
    elif cache_attn:
        mask_builder = get_cache_attn_masks
        boundary_token_id = tokenizer.cache_token_id
    else:
        return None, None

    # Both builders are called with position ids and attention mask reset enabled.
    attn_mask, pos_ids = mask_builder(
        input_ids, boundary_token_id, reset_position_ids=True, reset_attention_mask=True
    )
    return attn_mask, pos_ids


def science_startup(fabric, cfg):
    """Apply the NEPTune setting to the model config and announce NEPTune/TLD usage.

    When ``cfg.neptune_noise_alpha`` is truthy it is copied onto ``cfg.model_config`` and
    the token window in which NEPTune is active is printed. The token-loss-dropout (TLD)
    setting ``cfg.k_token_loss_dropout`` is only reported, not modified.

    Args:
        fabric: object whose ``print`` method is used for all messages.
        cfg: run settings; ``cfg.model_config.neptune_noise_alpha`` may be written.
    """
    # Set NEPTune config and print behavior to be expected for this run.
    if cfg.neptune_noise_alpha:
        cfg.model_config.neptune_noise_alpha = cfg.neptune_noise_alpha
        start_from = "start of training" if cfg.neptune_from_tokens == 0 else str(cfg.neptune_from_tokens) + " tokens"
        # The end of the window is governed by neptune_till_tokens (not neptune_from_tokens).
        end_on = (
            "end of training."
            if cfg.neptune_till_tokens == cfg.max_tokens
            else str(cfg.neptune_till_tokens) + " tokens are seen."
        )
        fabric.print(f"NEPTune will be used from {start_from} till {end_on}")
    else:
        fabric.print("NEPTune is NOT used for this run.")

    if cfg.k_token_loss_dropout is not None:
        fabric.print(f"TLD will be used with k={cfg.k_token_loss_dropout}.")
        fabric.print(f"Every {cfg.k_token_loss_dropout}-th token will be dropped from loss.")
    else:
        fabric.print("TLD is NOT used for this run.")


# Learning-rate schedule: linear warmup followed by linear, constant, or cosine decay.
def get_lr(it: int, lr_decay_iters: int, cfg: CLISettings) -> float:
    """Return the learning rate for iteration ``it``.

    The schedule has three phases:

    1. ``it < cfg.warmup_iters``: linear warmup from 0 to the base learning rate
       ``cfg.optim_config["lr"]``.
    2. ``it > lr_decay_iters``: the constant ``cfg.min_lr``.
    3. In between: decay selected by ``cfg.lr_schedule`` - ``"linear"``, ``"constant"``
       (stays at the base learning rate) or ``"cosine"``.

    Args:
        it: current iteration.
        lr_decay_iters: iteration at which the decay phase ends.
        cfg: settings providing ``optim_config``, ``warmup_iters``, ``min_lr`` and ``lr_schedule``.

    Raises:
        ValueError: ``cfg.lr_schedule`` is not one of the supported names (only checked
            once the decay phase is reached).
    """
    base_lr = cfg.optim_config["lr"]
    # 1) linear warmup for warmup_iters steps
    if it < cfg.warmup_iters:
        return base_lr * it / cfg.warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return cfg.min_lr
    # 3) in between, decay down to min learning rate
    decay_span = lr_decay_iters - cfg.warmup_iters
    # A zero-length decay phase (warmup ends exactly where decay ends) would divide by
    # zero; treat it as the very start of the decay, i.e. the peak learning rate.
    decay_ratio = (it - cfg.warmup_iters) / decay_span if decay_span > 0 else 0.0
    assert 0 <= decay_ratio <= 1
    if cfg.lr_schedule == "linear":
        return base_lr - decay_ratio * (base_lr - cfg.min_lr)
    elif cfg.lr_schedule == "constant":
        return base_lr
    elif cfg.lr_schedule == "cosine":
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
        return cfg.min_lr + coeff * (base_lr - cfg.min_lr)
    else:
        raise ValueError(f"Unsupported lr_schedule: {cfg.lr_schedule}")


def load_checkpoint(fabric, cfg, state, out_dir, run_name, model_checkpoint, model_impl="litgpt", resume=True):
    """Restore training state or initialize the prefix model, in order of preference.

    1. If ``resume`` is set and checkpoints named ``*-{run_name}.pth`` exist under
       ``out_dir / fabric.get_prefix_for_checkpoint()``, the one with the highest step
       number is loaded into ``state``.
    2. Otherwise, if ``model_checkpoint`` is given, pretrained weights are loaded into
       ``state["model"].prefix_model`` (non-strict).
    3. If nothing was resumed and ``cfg.pretrained_prefix_model`` is falsy, the prefix
       model is (re-)initialized with ``init_weights``.

    Args:
        fabric: provides ``print``, ``load``, ``global_rank`` and the checkpoint prefix.
        cfg: run settings (``finetune_checkpoint``, ``pretrained_prefix_model``, ``model_config``).
        state: training state dict; modified in place.
        out_dir: root directory containing the per-prefix checkpoint folders.
        run_name: run identifier embedded in checkpoint file names.
        model_checkpoint: directory of a pretrained model, or ``None``.
        model_impl: ``"litgpt"`` or ``"dynamic"``; selects the pretrained file name.
        resume: whether to look for a checkpoint to resume from.

    Returns:
        The path of the checkpoint that was resumed from, or ``None``.

    Raises:
        ValueError: ``model_impl`` is unknown and no ``cfg.finetune_checkpoint`` is set.
    """

    def _print_first_prefix_param():
        # Debug aid: show name, shape, dtype and values of the first prefix-model parameter.
        for name, param in state["model"].prefix_model.named_parameters():
            fabric.print(name, param.size(), param.dtype)
            fabric.print(param)
            break

    resume_ckpt = None
    if resume:
        ckpt_paths = list((Path(out_dir) / fabric.get_prefix_for_checkpoint()).glob(f"*-{run_name}.pth"))
        if ckpt_paths:
            # File names look like "step-<number>-<run_name>.pth"; resume from the latest step.
            resume_ckpt = max(
                ckpt_paths,
                key=(lambda p: int(p.name.split("-")[1].split(f"-{run_name}.pth")[0])),
            )
            fabric.print(f"Resuming training from {resume_ckpt}")
            fabric.load(resume_ckpt, state)
            # HACK: make sure every rank starts with the correct token count by zeroing out
            # all entries except the one belonging to the current rank.
            mask = torch.zeros_like(state["real_token_count"], dtype=torch.bool)
            mask[fabric.global_rank] = True
            state["real_token_count"] *= mask

    if resume_ckpt is None and model_checkpoint is not None:
        if cfg.finetune_checkpoint:
            checkpoint_path = cfg.finetune_checkpoint
        elif model_impl == "litgpt":
            checkpoint_path = f"{model_checkpoint}/lit_model.pth"
        elif model_impl == "dynamic":
            checkpoint_path = f"{model_checkpoint}/lit_model_dynamic.pth"
        else:
            raise ValueError(f"Invalid checkpoint loader for model implementation {model_impl}.")
        fabric.print(f"Loading pretrained model checkpoint from {checkpoint_path}")
        # NOTE(review): relies on the name `litgpt` being bound at module level - confirm
        # that the file has a plain `import litgpt` (or `import litgpt.utils`).
        litgpt.utils.load_checkpoint(fabric, state["model"].prefix_model, checkpoint_path, strict=False)

    _print_first_prefix_param()
    if resume_ckpt is None and not cfg.pretrained_prefix_model:
        fabric.print("initializing *Prefix Model* with random weights")
        state["model"].prefix_model.apply(
            partial(init_weights, n_layer=cfg.model_config.n_layer, n_embd=cfg.model_config.n_embd)
        )
    # Printed a second time so a re-initialization is visible in the log.
    _print_first_prefix_param()

    return resume_ckpt


def maybe_save_checkpoint(fabric, state, cfg, is_accumulating=False):
    """Save ``state`` when any of three save conditions holds.

    The conditions are:

    * interval - not accumulating gradients and ``optimizer_step`` is a multiple of
      ``cfg.save_step_interval``;
    * timeout  - ``cfg.save_n_min_before_job_done`` is set and the rank-averaged
      remaining job time (minutes) has dropped to that value or below; the setting is
      then cleared so this fires only once;
    * last step - ``cfg.save_last_step`` is set and ``microbatch_step`` has reached
      ``cfg.max_iters - 1``.

    Before saving, ``state["real_token_count"]`` is sum-reduced across ranks; afterwards
    every entry except the current rank's is zeroed again.
    """
    ckpt_subdir = fabric.get_prefix_for_checkpoint()
    ckpt_path = f"{cfg.out_dir}/{ckpt_subdir}/step-{state['optimizer_step']:08d}-{cfg.run_name}.pth"

    # Condition 1: regular interval, evaluated on optimizer steps only.
    hit_interval = (not is_accumulating) and state["optimizer_step"] % cfg.save_step_interval == 0

    # Condition 2: the job is about to run out of wall-clock time.
    near_timeout = False
    if cfg.save_n_min_before_job_done is not None:
        elapsed_seconds = time.time() - global_start_time
        minutes_left = (cfg.global_total_time - elapsed_seconds) / 60.0
        minutes_left = fabric.all_reduce(minutes_left, reduce_op="mean")
        near_timeout = minutes_left <= cfg.save_n_min_before_job_done
        if near_timeout:
            fabric.print(f"Saving at {minutes_left:.02f} minutes left")
            cfg.save_n_min_before_job_done = None  # reset

    # Condition 3: final iteration of the run.
    at_last_step = cfg.save_last_step and (state["microbatch_step"] >= (cfg.max_iters - 1))

    if not (hit_interval or at_last_step or near_timeout):
        return

    fabric.print(f"Saving checkpoint to {ckpt_path!r}")
    state["real_token_count"] = fabric.all_reduce(state["real_token_count"].clone(), reduce_op="sum")
    fabric.save(ckpt_path, state)
    # HACK: the all_reduce above put the global totals on every rank; keep only this
    # rank's entry so later iterations do not double count.
    own_rank_only = torch.zeros_like(state["real_token_count"], dtype=torch.bool)
    own_rank_only[fabric.global_rank] = True
    state["real_token_count"] *= own_rank_only


def _get_time_from_slurm():
    """Return the remaining wall-clock time of the current SLURM job in seconds.

    The value is obtained by running ``squeue -h -j $SLURM_JOBID -o %L`` and parsing its
    ``[days-][hours:][minutes:]seconds`` output. This is best-effort: on any failure
    (no ``squeue``, no job id, unparsable output such as a non-numeric value) the error
    is printed and a very large sentinel (9999999999999999) is returned instead.
    """
    import re  # local import so the helper does not depend on a module-level `import re`

    try:
        # Close the pipe deterministically instead of relying on garbage collection.
        with os.popen("squeue -h -j $SLURM_JOBID -o %L") as pipe:  # this is slow
            time_left = pipe.read()
        fields = [int(i) for i in re.split(":|-", time_left.strip("\n"))]
        if len(fields) > 4:
            raise ValueError(f"Unexpected time-left format from squeue: {time_left!r}")
        # Fields are ordered from the largest unit to seconds, so weight them from the right.
        seconds_per_unit = (1, 60, 3600, 24 * 3600)
        global_total_time = sum(value * unit for value, unit in zip(reversed(fields), seconds_per_unit))
    except Exception as e:
        # Deliberately broad: failing to query SLURM must never abort the run.
        print(e)
        global_total_time = 9999999999999999
    return global_total_time


def init_weights(module: nn.Module, n_layer: int, n_embd: int, axonn_tp: bool = False):
    """Initialize ``module`` in place, following GPT-NeoX (https://arxiv.org/abs/2204.06745).

    Embedding and linear weights are drawn from a normal distribution with
    ``std = sqrt(2 / (5 * n_embd))``; linear biases are zeroed. Unless ``axonn_tp`` is
    set, the ``proj.weight`` parameter of ``LLaMAMLP`` / ``CausalSelfAttention`` modules
    is additionally drawn with ``std = 1 / sqrt(n_embd) / n_layer``.
    """
    is_embedding = isinstance(module, nn.Embedding)
    if is_embedding or isinstance(module, nn.Linear):
        small_init_std = math.sqrt(2.0 / 5 / n_embd)
        nn.init.normal_(module.weight, mean=0.0, std=small_init_std)
        if not is_embedding and module.bias is not None:
            nn.init.zeros_(module.bias)

    if axonn_tp:
        # AxoNN does the inits internally for its linear layers, so the projection pass is skipped.
        return

    for param_name, tensor in module.named_parameters():
        if param_name == "proj.weight" and isinstance(module, (LLaMAMLP, CausalSelfAttention)):
            nn.init.normal_(tensor, mean=0.0, std=(1 / math.sqrt(n_embd) / n_layer))


def choose_logger(logger_name: str, project: str, name: str, resume: Union[bool, Path], save_dir: str, *args, **kwargs):
    """Instantiate the experiment logger selected by ``logger_name``.

    * ``"csv"``            - ``CSVLogger`` writing below ``save_dir + "/logs"``.
    * ``"wandb"``          - ``WandbLogger``; resuming is enabled unless ``resume`` is ``False``.
    * ``"frontier_wandb"`` - ``WandbLogger`` in offline mode; ``resume`` and any extra
      arguments are ignored.

    Extra positional and keyword arguments are forwarded to the ``"csv"`` and
    ``"wandb"`` loggers.

    Raises:
        ValueError: ``logger_name`` is none of the names above.
    """
    if logger_name == "csv":
        csv_root = save_dir + "/logs"
        return CSVLogger(*args, root_dir=csv_root, name="csv", **kwargs)
    elif logger_name == "wandb":
        should_resume = resume is not False
        return WandbLogger(
            *args,
            project=project,
            entity="XXXX-6",
            name=name,
            resume=should_resume,
            save_dir=save_dir,
            **kwargs,
        )
    elif logger_name == "frontier_wandb":
        return WandbLogger(project=project, entity="XXXX-6", name=name, save_dir=save_dir, offline=True)
    else:
        raise ValueError(f"`logger={logger_name}` is not a valid option.")

####################################################################################################
# Main control loop
####################################################################################################


def main():
    """Run the whole pipeline: parse settings, set up fabric, start up, train, shut down.

    Keeping this in a function keeps the work out of import time.
    """
    cfg = CLI(CLISettings)

    # Torch flags have to be set before the fabric (and its logger) is created.
    set_torch_flags(cfg)
    fabric = setup_fabric(cfg)

    # Build the training state and dataloaders, resuming from a checkpoint if one is found.
    state, train_dataloader, val_dataloader, data_scheduler, resume_ckpt = startup(fabric, cfg)

    # Time only the training loop itself.
    train_time = time.time()
    train(
        fabric,
        state,
        train_dataloader,
        val_dataloader,
        resume_ckpt=resume_ckpt,
        cfg=cfg,
        data_scheduler=data_scheduler,
    )
    fabric.print(f"Training time: {(time.time()-train_time):.2f}s")

    # Report peak GPU memory when running on CUDA.
    if fabric.device.type == "cuda":
        bytes_per_gb = float(1024**3)
        max_alloc = f"{torch.cuda.max_memory_allocated(fabric.device)/bytes_per_gb:,.3f} GB"
        max_reserved = f"{torch.cuda.max_memory_reserved(fabric.device)/bytes_per_gb:,.3f} GB"
        fabric.print(f"XXXX-13. Mem allocated: {max_alloc}. XXXX-13. Mem reserved: {max_reserved}.")

    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()  # Force a clean exit
    torch.distributed.destroy_process_group()


# Script entry point: run the training pipeline only when this file is executed directly.
if __name__ == "__main__":
    main()


########## Misc Notes ######################

# 1)
# the lr schedule is computed as a function of iters not optim steps, but only evaluated after an optim step,
# so that the optim step lr lags a bit behind the current lr
# These are different if gradient_accumulation_steps > 1.
# There doesn't seem to be anything _incorrect_ about this, but it might
# not be very intuitive when picking schedule params.

# 2)
# unless prohibitively slow, we should be able to call the
# scripts.convert_pretrained_checkpoint.convert_checkpoint function in save_checkpoint
# which would turn the training checkpoint into a final saved model.
# Could even call the lit-to-hf conversion process as well.
# TODO: can this be offloaded to a separate thread?

# 3)
# Saving and validation are triggered on optimizer_step, while the main training loop runs
# on microbatch_step. This can be problematic if the two get out of sync
# or if the gradient accumulation frequency is not the right divisor.
# In addition, the learning rate (see note 1 above) follows the microbatch-step (MBS) schedule.
# Couldn't we put everything on the MBS schedule?

# 4)
# No tokens should be added in train, this just mucks up the tokenizer internals and reproducibility,
# either some token (like <cache>) exists, or it does not. This should not be discovered/changed in train.py.
# Also we kill performance by doing 2**16+1 tokens. Tokenizer should be entirely constant

# 5)
# FIXME, token counting logic assumes fixed microbatch size w/ no padding.
# This is fine for pretraining style data, but this might not always be true.
