import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import gc
import math
import datetime
import sys
from typing import Tuple
with open(sys.argv[0]) as f:
    code = f.read()
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch._inductor.config as config
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# set up DDP (distributed data parallel). torchrun sets the RANK / LOCAL_RANK /
# WORLD_SIZE environment variables that are read below.
assert torch.cuda.is_available()
# NCCL process group; the timeout passed here is 30 minutes
dist.init_process_group(backend='nccl', timeout=datetime.timedelta(minutes=30))
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
# each process drives the GPU whose index equals its local rank
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.

# ==== scale free
# flex_attention / create_block_mask are rebound to torch.compile'd versions
# (dynamic=False); the rest of the file uses these compiled names.
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
flex_attention = torch.compile(flex_attention, dynamic=False)
create_block_mask = torch.compile(create_block_mask, dynamic=False)

import wandb

## without this, we error out
# NOTE(review): suppress_errors=True presumably makes failed compilations fall
# back to eager instead of raising -- confirm against the torch._dynamo docs.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.optimize_ddp = False
torch._dynamo.config.cache_size_limit = 64

import argparse
# Command-line interface of the experiment. The attention / positional-encoding
# variant flags are mutually exclusive; this is asserted right after parsing.
parser = argparse.ArgumentParser()
parser.add_argument('--runname', type=str, default="testing")
## testing
parser.add_argument('--testrun', action='store_true')
# batch geometry: gradient-accumulation steps and sequences per process
parser.add_argument('--train_acc_steps', type=int, default=2)
parser.add_argument('--seq_per_node', type=int, default=28)

parser.add_argument("--train_seq_length", type=int, default=64*1024)
# base used to build the rotary inverse frequencies (default of Rotary.base)
parser.add_argument("--base_theta", type=int, default=10_000)
# variant selection: exactly one of the following flags must be given
parser.add_argument("--sfa", action='store_true')
parser.add_argument("--sfa_and_p_rope", action='store_true')
parser.add_argument("--sfa_and_rope", action='store_true')
parser.add_argument("--sfa_and_nope", action='store_true')
parser.add_argument("--p_rope", action='store_true')
parser.add_argument("--rope", action='store_true')
parser.add_argument("--log_n_trick", action='store_true')
parser.add_argument("--log_n_trick_and_nope", action='store_true')
parser.add_argument("--log_n_trick_and_p_rope", action='store_true')
parser.add_argument("--log_n_trick_and_ntk_aware", action='store_true')
parser.add_argument("--alibi", action='store_true')
parser.add_argument("--nope", action='store_true')
parser.add_argument("--ntk_aware", action='store_true') # section 3.1 yarn paper

# tau divides the query/key distance inside scale_free_score_mod
parser.add_argument("--tau", default=10., type=float)
parser.add_argument("--run_id", type=str, default="")
# step of the checkpoint to resume from; negative means start from scratch
parser.add_argument("--resume_step", type=int, default=-1)
parser.add_argument("--seed", type=int, default=0)
sfa_args = parser.parse_args()
# when True, print_max_alloc() reports peak CUDA memory
log_peak_memory = True

# All variant flags; exactly one must be set.
modes = [
    sfa_args.sfa,
    sfa_args.sfa_and_p_rope,
    sfa_args.sfa_and_rope,
    sfa_args.p_rope,
    sfa_args.rope,
    sfa_args.alibi,
    sfa_args.log_n_trick,
    sfa_args.log_n_trick_and_p_rope,
    sfa_args.log_n_trick_and_ntk_aware,
    sfa_args.ntk_aware,
    sfa_args.sfa_and_nope,
    sfa_args.nope,
    sfa_args.log_n_trick_and_nope,
]
assert len([x for x in modes if x]) == 1, f"only one mode can be selected, got {modes}"
# Map the selected variant to the positional embedding applied to q/k:
# 'p_rope', 'rope', or 'none'.
if sfa_args.sfa_and_p_rope or sfa_args.p_rope or sfa_args.log_n_trick_and_p_rope:
    pos_emb = 'p_rope'
elif sfa_args.sfa_and_rope or sfa_args.rope or sfa_args.ntk_aware or sfa_args.log_n_trick_and_ntk_aware or sfa_args.log_n_trick:
    pos_emb = 'rope'
elif sfa_args.alibi or sfa_args.sfa or sfa_args.sfa_and_nope or sfa_args.log_n_trick_and_nope or sfa_args.nope:
    pos_emb = 'none'
else:
    raise ValueError(f"unknown positional embedding mode: {modes}")

# do_sfa: attention goes through flex_attention with scale_free_score_mod.
# The complementary "dont_*" lists exist only so the XOR asserts catch a flag
# that was added to the parser but not classified here.
do_sfa = any([sfa_args.sfa, sfa_args.sfa_and_p_rope, sfa_args.sfa_and_rope, sfa_args.sfa_and_nope])
dont_sfa = any([sfa_args.p_rope, sfa_args.rope, sfa_args.alibi, sfa_args.log_n_trick, sfa_args.log_n_trick_and_p_rope, sfa_args.ntk_aware, sfa_args.nope, sfa_args.log_n_trick_and_ntk_aware, sfa_args.log_n_trick_and_nope])
assert (do_sfa ^ dont_sfa), "exactly one of do_sfa and dont_sfa must be true"
# do_log_n_trick: queries are scaled by log(position + 1) and a learned per-head factor.
do_log_n_trick = any([sfa_args.log_n_trick, sfa_args.log_n_trick_and_p_rope, sfa_args.log_n_trick_and_ntk_aware, sfa_args.log_n_trick_and_nope])
dont_log_n_trick = any([sfa_args.ntk_aware, sfa_args.p_rope, sfa_args.rope, sfa_args.alibi, sfa_args.sfa, sfa_args.sfa_and_nope, sfa_args.sfa_and_rope, sfa_args.nope, sfa_args.sfa_and_p_rope])
assert (do_log_n_trick ^ dont_log_n_trick), "exactly one of do_log_n_trick and dont_log_n_trick must be true"

# Seed Python, NumPy and torch (CPU + all CUDA devices) with the same --seed value.
seed = sfa_args.seed
import random; import numpy as np
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(sfa_args.seed)
torch.cuda.manual_seed(sfa_args.seed)
torch.cuda.manual_seed_all(sfa_args.seed)

# ---- model setting
# These constants are the defaults of GPTConfig and are also read directly by
# the ALiBi score modifier (num_heads) and the log-n buffer (max_position_embeddings).
num_heads = 8
hidden_size = 1024
num_layers = 16
max_position_embeddings = 64*1024 ## maximum test sequence length

# ---- setup
# A run id is mandatory: every rank derives its log / checkpoint paths
# (logs/{run_id}/...) from it, so it has to be identical on all processes and is
# therefore taken from the command line rather than generated on one rank.
# (The former auto-generation branch was unreachable behind this assert.)
run_id = sfa_args.run_id
assert run_id != "", "run_id must be provided"

# the parsed arguments double as the run configuration that is printed / sent to wandb
_config = sfa_args.__dict__
_config['run_id'] = run_id
if master_process:
    print(_config, flush=True)


# CUDNN attention is ~4ms faster than Flash, but doesn't get selected by default in PyTorch 2.5.1
# Only the cuDNN backend is left enabled for F.scaled_dot_product_attention;
# flash, memory-efficient and math backends are all switched off.
from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
enable_cudnn_sdp(True)
enable_flash_sdp(False)
enable_mem_efficient_sdp(False)
enable_math_sdp(False)

# wandb is only initialised on the master process and skipped for --testrun.
# The run id doubles as the wandb id, with resume='allow'.
if not sfa_args.testrun and master_process:
    _id = run_id
    wandb.init(project="pretrain_gpt2_med", name=run_id, config=_config, resume='allow', id=_id)

# `logfile` is only defined on the master process; print0() is the only reader
# and guards on master_process as well.
if master_process:
    os.makedirs("logs", exist_ok=True)
    os.makedirs(f"logs/{run_id}", exist_ok=True)
    logfile = f"logs/{run_id}/log.txt"
    print(logfile, flush=True)
def print0(s, console=False):
    """Print `s` to stdout and append it to the run's log file, on the master process only.

    `console` is accepted for call-site compatibility but is not consulted.
    Other ranks return without doing anything.
    """
    if not master_process:
        return
    print(s, flush=True)
    with open(logfile, "a") as log_handle:
        print(s, file=log_handle, flush=True)

def nvidia_smi():
    """Run `nvidia-smi` and return what it wrote to stdout as text."""
    import subprocess  # avoid top level import
    completed = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    return completed.stdout
print0(nvidia_smi())

def print_max_alloc():
    """Log the peak allocated / reserved CUDA memory in MiB when log_peak_memory is set."""
    if not log_peak_memory:
        return
    allocated_mib = torch.cuda.max_memory_allocated() // 1024 // 1024
    reserved_mib = torch.cuda.max_memory_reserved() // 1024 // 1024
    print0(f"peak memory allocated: {allocated_mib} MiB reserved: {reserved_mib} MiB", console=True)

# ---- flex attn functions
def scale_free_score_mod(score, b, h, q_idx, kv_idx):
    """flex_attention score_mod for 'scale free attention'.

    With d = q_idx - kv_idx and L = log1p(d / tau), the score becomes
    sqrt(2L + 1) * score - 2L. `b` and `h` are unused.
    """
    log_distance = torch.log1p((q_idx - kv_idx)/sfa_args.tau)
    gain = (2*log_distance + 1).sqrt()
    shift = -2*log_distance
    return gain * score + shift # 'scale free attention'

def alibi_score_modifier(score, b, h, q_idx, kv_idx):
    """flex_attention score_mod adding a per-head linear bias (ALiBi).

    Head `h` gets slope 2 ** (-(h + 1) * 8 / num_heads); the bias is
    slope * (kv_idx - q_idx). `b` is unused.
    """
    head_slope = torch.exp2(-((h + 1) * 8.0 / num_heads))
    return score + (kv_idx - q_idx) * head_slope
# ====

# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_svd(G, steps=None):
    """Return U @ V^T for the reduced SVD G = U S V^T (the orthogonalization of G).

    `steps` is ignored; it exists so this backend has the same call signature
    as the iterative one.
    """
    # torch.linalg.svd replaces the deprecated Tensor.svd(); it returns V^H
    # directly, and full_matrices=False matches Tensor.svd()'s reduced default.
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \\sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    X /= (X.norm() + eps) # ensure top singular value <= 1
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X

# Lookup table from Muon's `backend` option to the orthogonalization function.
zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    def step(self):
        """Apply one update to every parameter group.

        Parameters are assigned round-robin to ranks (index % WORLD_SIZE == RANK).
        Each rank orthogonalizes only its own share, writes it into a zero-filled
        flat buffer, and a SUM all-reduce gives every rank the full set of updates.
        """
        for group in self.param_groups:

            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]

            # generate weight updates in distributed fashion
            total_params = sum(p.numel() for p in group['params'])
            updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(group['params']):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']):
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    if group['nesterov']:
                        g = g.add(buf, alpha=momentum)
                    else:
                        # plain momentum: the update is the buffer itself. Previously the raw
                        # gradient was used here, so the momentum had no effect at all.
                        # Cloned so the backend can never modify the stored buffer.
                        g = buf.clone()
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    # scale tall matrices by sqrt(rows / cols)
                    g *= max(1, g.size(0)/g.size(1))**0.5
                    updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                # the offset advances for every parameter so all ranks agree on the layout
                curr_idx += p.numel()

            # sync updates across devices. we are not memory-constrained so can do this simple deserialization
            dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # deserialize and apply updates
            curr_idx = 0
            for p in group['params']:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class CastedLinear(nn.Linear):
    """nn.Linear whose weight is cast to the input's dtype on every call.

    The forward pass multiplies by the weight only; a bias, if the layer was
    created with one, is not applied.
    """
    def forward(self, x):
        weight = self.weight.to(x.dtype)
        return F.linear(x, weight)

"""for standard RoPE"""
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of a 4-D tensor by (cos, sin).

    The last dimension is split in the middle into (x1, x2); the result is
    cat(x1*cos + x2*sin, -x1*sin + x2*cos), cast back to x's type.
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3]//2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = first * (-sin) + second * cos
    return torch.cat([rotated_first, rotated_second], 3).type_as(x)
class Rotary(torch.nn.Module):
    """Produces (cos, sin) tables for standard RoPE, cached per sequence length.

    With --ntk_aware and a sequence longer than the training length, the base is
    rescaled by s ** (D / (D - 2)), where s = seq_len / train_seq_length.
    """
    def __init__(self, dim, base=sfa_args.base_theta):
        super().__init__()
        self.dim = dim
        self.base = base
        exponents = torch.arange(0, dim, 2).float() / dim
        self.inv_freq = 1.0 / (base ** exponents)
        # tables are rebuilt only when the sequence length changes
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """Return cos and sin, each indexed as [None, :, None, :], for x.shape[1] positions."""
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            inv_freq = self.inv_freq
            if sfa_args.ntk_aware and seq_len > sfa_args.train_seq_length:
                stretch = seq_len / sfa_args.train_seq_length
                D = self.dim
                scaled_base = self.base * (stretch ** (D/(D-2)))
                inv_freq = 1.0 / (scaled_base ** (torch.arange(0, self.dim, 2).float() / self.dim))

            positions = torch.arange(seq_len, device=x.device).type_as(inv_freq)
            angles = torch.outer(positions, inv_freq).to(x.device)
            self.cos_cached = angles.cos().bfloat16()
            self.sin_cached = angles.sin().bfloat16()
        cos = self.cos_cached[None, :, None, :]
        sin = self.sin_cached[None, :, None, :]
        return cos, sin

"""'p-RoPE' https://arxiv.org/abs/2410.06205, with p=0.5"""
class pRoPE(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)
    def forward(self, x_BTHD):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with QK-norm and a selectable attention variant.

    The positional embedding (module-level `pos_emb`) and the attention kernel
    (`do_sfa`, `sfa_args.alibi`, or plain SDPA) are chosen by the command-line flags.
    """
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_k = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_v = CastedLinear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
        if pos_emb == 'rope':
            self.rotary = Rotary(self.head_dim)
        elif pos_emb == 'p_rope':
            max_seq_len = 64*1024
            self.rotary = pRoPE(self.head_dim, max_seq_len)
        else:
            self.rotary = None
        if do_log_n_trick:
            # learned per-head scale, plus a fixed table of log(position + 1)
            self.logn_s = torch.nn.Parameter(torch.ones(self.n_head))
            logn_list = [math.log(i+1, math.e) for i in range(max_position_embeddings)] # (T,)
            self.register_buffer("logn", torch.tensor(logn_list))
    def _rotary_emb(self, q, k):
        """Apply the configured positional embedding to q and k.

        Raises ValueError if `pos_emb` is not 'rope', 'p_rope' or 'none'.
        """
        if pos_emb == 'rope':
            cos, sin = self.rotary(q)
            q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        elif pos_emb == 'p_rope':
            q, k = self.rotary(q), self.rotary(k)
        elif pos_emb == 'none':
            pass
        else:
            # pos_emb is a module-level name; sfa_args has no `pos_emb` attribute,
            # so formatting it here would raise AttributeError instead.
            raise ValueError(f"unknown positional embedding type: {pos_emb}")
        return q, k

    def forward(self, x, block_mask=None):
        """Attend over x (B, T, C) and return a tensor shaped like x.

        `block_mask` is only passed to the flex_attention paths; the SDPA path
        uses is_causal=True instead.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm
        q, k = self._rotary_emb(q, k)
        if do_log_n_trick:
            # scale each query by log(position + 1) and by its head's learned factor
            q = self.logn[None,:T,None,None] * q * self.logn_s[None,None,:,None]
        q = q.bfloat16(); k = k.bfloat16(); v = v.bfloat16()
        ## pick attention
        if do_sfa:
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=scale_free_score_mod)
        elif sfa_args.alibi:
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=alibi_score_modifier)
        else:
            y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    """Feed-forward block: expand to 4x width, apply relu(x)^2, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.c_fc    = CastedLinear(width, 4 * width, bias=False)
        self.c_proj  = CastedLinear(4 * width, width, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)

class Block(nn.Module):
    """One transformer layer: pre-RMSNorm attention and MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, block_mask=None):
        normed = F.rms_norm(x, (x.size(-1),))
        x = x + self.attn(normed, block_mask=block_mask)
        normed = F.rms_norm(x, (x.size(-1),))
        return x + self.mlp(normed)

# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    """Shape of the GPT model; defaults come from the module-level model settings."""
    vocab_size : int = 50304
    n_layer : int = num_layers  # number of transformer Blocks
    n_head : int = num_heads    # attention heads per block
    n_embd : int = hidden_size  # embedding / residual width

class GPT(nn.Module):
    """Token embedding, a stack of Blocks, and a zero-initialised LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = layers,
        ))
        self.lm_head = CastedLinear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight.data.zero_()
        # holds the mask of the most recently requested sequence length only
        self._block_mask_cache = dict()

    def _get_block_mask(self, T):
        """Return the causal block mask for length T, or None when flex attention is unused."""
        cached = self._block_mask_cache
        if T in cached: # use cache
            return cached[T]
        self._block_mask_cache = dict() # reset cache
        mask_heads = num_heads if sfa_args.alibi else None
        if not (do_sfa or sfa_args.alibi):
            self._block_mask_cache[T] = None
            return None

        def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx
        mask = create_block_mask(causal_mask, None, mask_heads, T, T, device='cuda', _compile=True)
        self._block_mask_cache[T] = mask
        return mask
    def forward(self, idx, targets=None, return_logits=True):
        """Run the model on token ids `idx`.

        With `targets`, logits cover every position and the cross-entropy loss
        (ignore_index=-1) is returned. Without, only the last position's logits
        are computed and the loss is None. Returns (logits, loss); logits is
        None when return_logits is False.
        """
        _, T = idx.size()
        block_mask = self._get_block_mask(T)
        x = F.rms_norm(self.transformer.wte(idx), (self.config.n_embd,))
        for layer in self.transformer.h:
            x = layer(x, block_mask=block_mask)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is None:
            # note: using list [-1] to preserve the time dim
            logits = self.lm_head(x[:, [-1], :]).float() # use tf32/fp32 for logits
            loss = None
        else:
            logits = self.lm_head(x).float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return (logits if return_logits else None), loss

# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens

def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2] # number of tokens (claimed)
        # the rest of it are tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens

class DataLoader:
    """Sequential, rank-strided reader over token shards.

    Each process starts at offset rank * B * T inside a shard and moves forward
    by B * T * num_processes tokens per batch, switching to the next shard
    (wrapping around) when the current one cannot supply another full step.
    """
    def __init__(self, filename_pattern, B, T, num_steps):
        self.B = B
        self.T = T
        self.num_steps = num_steps
        self.process_rank = int(os.environ.get('RANK', 0))
        self.num_processes = int(os.environ.get('WORLD_SIZE', 1))

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # validate every shard header and add up the token counts
        running_total = 0
        for shard_path in self.files:
            shard_ntok = _peek_data_shard(shard_path)
            assert shard_ntok >= self.num_processes * B * T + 1
            running_total += int(shard_ntok)
        self.ntok_total = running_total

        # kick things off
        self.reset()

    def _load_current_shard(self):
        # read the tokens of the shard that current_shard points at
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def reset(self):
        """Go back to the first shard at this rank's starting offset."""
        self.current_shard = 0
        self.current_position = self.process_rank * self.B * self.T
        self._load_current_shard()

    def get_state(self):
        """Return the shard index and position as a dict for checkpointing."""
        return {
            'current_shard': self.current_shard,
            'current_position': self.current_position,
        }

    def set_state(self, state_dict):
        """Restore a position produced by get_state() and reload that shard."""
        self.current_shard = state_dict['current_shard']
        self.current_position = state_dict['current_position']
        self._load_current_shard()

    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self._load_current_shard()

    def next_batch(self):
        """Return (inputs, targets) as (B, T) long tensors on the GPU; targets are shifted by one."""
        B, T = self.B, self.T
        start = self.current_position
        window = self.tokens[start : start+B*T+1]
        window = torch.tensor(window.astype(np.int32), dtype=torch.long)
        inputs = (window[:-1]).view(B, T)
        targets = (window[1:]).view(B, T)
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return inputs.cuda(), targets.cuda()

# -----------------------------------------------------------------------------
# int main

#nnodes = 4; train_acc_steps = 4; seq_per_node = 16 #seq_per_node = 4
# NOTE(review): nnodes is hard-coded to 4; the assert on args.batch_size further
# down only holds if (nnodes * train_acc_steps) is a multiple of WORLD_SIZE.
nnodes = 4; train_acc_steps = sfa_args.train_acc_steps; seq_per_node = sfa_args.seq_per_node #seq_per_node = 4
# want the product of these numbers to be ~ 313 ~= 128 * sqrt(6)

#num_iterations = 10900 #19070
# number of optimizer steps = token budget / tokens per global batch, rounded up
ntokens = 10_000_000_000 # 10B tokens
num_iterations = ntokens / (nnodes * train_acc_steps * sfa_args.train_seq_length * seq_per_node)
num_iterations = math.ceil(num_iterations)

@dataclass
class Hyperparameters:
    """Training-run settings derived from the command line and the constants above."""
    # data hyperparams
    input_bin : str = './data/fineweb100B/fineweb_train_*.bin' # input .bin to train on
    input_val_bin : str = './data/fineweb100B/fineweb_val_*.bin' # input .bin to eval validation loss on
    # optimization hyperparams
    # batch_size has no annotation, so it is a plain class attribute, not a dataclass field
    batch_size = int(nnodes * seq_per_node * train_acc_steps)
    sequence_length : int = sfa_args.train_seq_length
    num_iterations : int = num_iterations # number of iterations to run
    # evaluation and logging hyperparams
    val_loss_every : int = 250 # every how many steps to evaluate val loss? 0 for only at the end
    val_tokens : int = (64 * 1024 * 2 * 8) * 10  # how many tokens of validation data ## make sure it's divisible by 64 * 6 * 4
    save_every : int = 500 # every how many steps to save the checkpoint? 0 for only at the end
args = Hyperparameters()

# per-process batch size; by construction this must equal --seq_per_node
args.device_batch_size = args.batch_size // (nnodes * train_acc_steps)
assert args.device_batch_size == seq_per_node, f"device batch size should be {seq_per_node}"

# convenience variables
B, T = args.device_batch_size, args.sequence_length

# calculate the steps of gradient accumulation required to attain the desired global batch size.
assert args.batch_size % (B * ddp_world_size) == 0
train_accumulation_steps = args.batch_size // (B * ddp_world_size)
print0(f"train_accumulation_steps: {train_accumulation_steps}")

def get_val_loader(Tval):
    """Build a validation DataLoader for sequence length Tval.

    The per-process batch is max(int(32 * T / Tval), 1) sequences, and the
    number of steps is chosen so that exactly args.val_tokens tokens are read
    across all processes.
    """
    Bval = max(int(32 * T / Tval), 1)
    tokens_per_step = Bval * Tval * ddp_world_size
    assert args.val_tokens % tokens_per_step == 0
    return DataLoader(args.input_val_bin, Bval, Tval, args.val_tokens // tokens_per_step)

# load tokens
# the training loader gets num_steps=None; only the validation loaders use num_steps
train_loader = DataLoader(args.input_bin, B, T, None)
# val_loader_1k = get_val_loader(1*1024)
val_loader_4k = get_val_loader(4*1024)
val_loader_16k = get_val_loader(16*1024)
# val_loader_32k = get_val_loader(32*1024)
val_loader_64k = get_val_loader(64*1024)
if master_process:
    print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files", console=True)

# same value as the GPTConfig.vocab_size default
num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=num_layers, n_head=num_heads, n_embd=hidden_size))

# When resuming, every rank loads the model weights from the rank-0 checkpoint
# (on CPU) before the model is moved to the GPU.
if sfa_args.resume_step >= 0:
    ## make sure to load rank == 0 (master node) state_dict!
    ## we only save the model weights in this file
    fname = f"logs/{run_id}/state_rank_0_step%06d.pt" % sfa_args.resume_step
    print0(f"loading model from {fname}")
    state_dict = torch.load(fname, weights_only=False, map_location=torch.device('cpu'))['model']
    #model.load_state_dict(state_dict['model'])
    # NOTE(review): parameters whose name is missing from the checkpoint are
    # silently left at their fresh initialisation -- confirm the saved key names
    # match model.named_parameters() (the saving code is not visible here).
    for name, param in model.named_parameters():
        if name in state_dict:
            param.data.copy_(state_dict[name])
    del state_dict
    torch.cuda.empty_cache()
    gc.collect()

# move to the GPU and cast all parameters to bfloat16
model = model.cuda().bfloat16()

# keep a handle on the plain nn.Module before compile / DDP wrapping
uncompiled_model = model
print0("=====")
print0(f"model size: {sum(p.numel() for p in model.parameters())}")
print0("=====")
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
#model = torch.compile(model, dynamic=True)
model = torch.compile(model, dynamic=False)
# here we wrap model into DDP container
ddp_model = DDP(model, device_ids=[ddp_local_rank])
model = ddp_model
# NOTE(review): model.module is the object handed to DDP, i.e. the result of
# torch.compile, not `uncompiled_model`.
raw_model = model.module # always contains the "raw" unwrapped model

# autocast context used around forward passes
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

# init the optimizer(s)
# NOTE(review): 786 here and in the comment below is probably a typo for 768;
# changing it alters the Muon learning rate, so confirm before touching it.
linear_lr_mult = 786./hidden_size # existing LRs are relative to 786 hidden size
                                  # this is muParam (simple) from https://arxiv.org/pdf/2309.14322
# Split the transformer-block parameters by rank: 2-D matrices go to Muon,
# everything with fewer dimensions goes to Adam together with the LM head.
params = list(raw_model.transformer.h.parameters())
matrix_params = [p for p in params if p.ndim == 2]
vector_and_scalar_params = [p for p in params if p.ndim < 2]
other_params = [p for p in params if p.ndim > 2]
assert len(other_params) == 0
# the log-n trick adds the 1-D `logn_s` parameter to every attention layer
if do_log_n_trick: assert len(vector_and_scalar_params) > 0
optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight], lr=0.3,   betas=(0.9, 0.95), fused=True)
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight] + vector_and_scalar_params, lr=0.002, betas=(0.9, 0.95), fused=True)
optimizer3 = Muon(matrix_params,  lr=0.02 * linear_lr_mult,  momentum=0.95)
optimizers = [optimizer1, optimizer2, optimizer3]
def get_lr(it): # cosine schedule
    """Learning-rate multiplier for step `it`: cosine decay from 1.0 down to 0.1 at the last step."""
    assert it <= args.num_iterations
    max_lr, min_lr = 1.0, 0.1
    phase = math.pi * it / args.num_iterations
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(phase))
# every optimizer shares the same multiplier schedule
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# begin logging
# The master process appends a header to the log file: this script's own source
# (read into `code` at the top of the file), the torch / CUDA versions, and the
# output of nvidia-smi.
if master_process:
    with open(logfile, "a") as f:
        # begin the log by printing this file (the Python code)
        f.write('='*100 + '\n')
        f.write(code)
        f.write('='*100 + '\n')
        f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
        import subprocess
        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        f.write(f'{result.stdout}\n')
        f.write('='*100 + '\n')

# running counters and metric history, initialised before the training loop
_train_loss = 0.
_pickle_metrics = []
nbatches = 0
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
train_loader.reset()

start_step = 0
resumed = False
# Restore the per-rank training state (step, loader position, optimizers,
# schedulers, RNG states) and the metric history when --resume_step is given.
if sfa_args.resume_step >= 0:
    # print0's second positional parameter is `console`, so the value has to be
    # formatted into the message itself or it is never printed.
    print0(f"resuming from step {sfa_args.resume_step}")
    # now load the state dict for this node
    state_dict = torch.load(f"logs/{run_id}/state_rank_{ddp_rank}_step%06d.pt" % sfa_args.resume_step, weights_only=False)
    resumed = True
    start_step = state_dict['step']
    train_loader.set_state(state_dict['train_loader_state'])
    for opt, opt_state in zip(optimizers, state_dict['optimizers']):
        opt.load_state_dict(opt_state)

    for sched, sched_state in zip(schedulers, state_dict['scheduler']):
        sched.load_state_dict(sched_state)

    # Restore random states
    if 'random_state' in state_dict:
        random.setstate(state_dict['random_state'])
    if 'numpy_random_state' in state_dict:
        np.random.set_state(state_dict['numpy_random_state'])
    if 'torch_random_state' in state_dict:
        torch.random.set_rng_state(state_dict['torch_random_state'])
    if 'torch_cuda_random_state' in state_dict and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state_dict['torch_cuda_random_state'])

    ## load pickle step
    # drop metrics recorded after the checkpointed step so the history is not duplicated
    _pickle_metrics = torch.load(f"logs/{run_id}/metrics.pt", weights_only=False)
    _pickle_metrics = [x for x in _pickle_metrics if x['step'] <= start_step]
    print0(f"loaded state from step {start_step}")
    del state_dict
    gc.collect()
    torch.cuda.empty_cache()
    print0(f"resume GC done {start_step}")

def eval_model(val_loader):
    """Return the validation loss over ``val_loader``, averaged across ranks.

    Puts ``model`` into eval mode (it is not switched back here), rewinds the
    loader, runs ``val_loader.num_steps`` batches without gradient tracking,
    averages the summed loss across processes with a single all-reduce, and
    divides by the number of steps. The elapsed wall-clock time is reported
    through ``print0``.
    """
    # Synchronize first so the timer measures validation work only.
    torch.cuda.synchronize()
    started_at = time.time()
    model.eval()
    val_loader.reset()
    running = 0.0

    # Free cached blocks once up front rather than per batch.
    torch.cuda.empty_cache()
    gc.collect()

    with torch.no_grad():
        for _ in range(val_loader.num_steps):
            inputs, targets = val_loader.next_batch()
            with ctx:
                _, batch_loss = model(inputs, targets, return_logits=False)
                running += batch_loss.detach()
    # One collective for the whole pass instead of one per batch.
    dist.all_reduce(running, op=dist.ReduceOp.AVG)
    running /= val_loader.num_steps

    torch.cuda.synchronize()
    elapsed = time.time() - started_at
    print0(f"time_val: {elapsed:.4f}")

    return running

torch.cuda.synchronize()

# Main loop. Each iteration optionally validates, optionally checkpoints, and
# then (unless it is the extra final iteration) runs one training step.
# `training_time_ms` accumulates wall-clock time between "start the clock" and
# "stop the clock" points, so validation and checkpointing are excluded.
for step in range(start_step, args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    # NOTE(review): when resuming with start_step > 10 this reset never runs, so
    # training_time_ms only covers time since the resume while timed_steps still
    # counts from step 10 -- the reported step averages are then too small.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
    #do_val = do_val and step > 0
    if do_val:
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        print_max_alloc()


        # Evaluate on the three validation loaders in turn, releasing cached
        # memory between them.
        print0("begin 4k val")
        val_loss_4k = eval_model(val_loader_4k)
        gc.collect()
        torch.cuda.empty_cache()
        print0("4k done, begin 16k")

        val_loss_16k = eval_model(val_loader_16k)
        gc.collect()
        torch.cuda.empty_cache()
        print0("16k done, begin 64k")
        val_loss_64k = eval_model(val_loader_64k)
        gc.collect()
        torch.cuda.empty_cache()
        print0("64k done")

        # Clear memory before switching back to training
        torch.cuda.empty_cache()
        gc.collect()

        # step_avg is NaN while timed_steps is NaN (step <= 11).
        print0(f'step:{step}/{args.num_iterations} {val_loss_4k:.4f},{val_loss_16k:.4f},{val_loss_64k:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')

        metrics = {
            # f'val_loss_1k': val_loss_1k.item(),
            f'val_loss_4k': val_loss_4k.item(),
            f'val_loss_16k': val_loss_16k.item(),
            f'val_loss_64k': val_loss_64k.item(),
            'step':step,
        }
        assert isinstance(_train_loss, float), f"train_loss is not a float, but {_train_loss}"
        # Only report a train loss if something was accumulated since the last
        # validation (nothing has been at the first iteration of a run).
        if _train_loss > 0.:
            metrics[f'train_loss_{sfa_args.train_seq_length}'] = float(_train_loss / nbatches) # put val/train on same scale
        # reset
        _train_loss = 0.
        nbatches = 0
        # Only the master process of a non-test run logs and persists metrics.
        if not sfa_args.testrun and master_process:
            wandb.log(metrics)
            ## also put metrics in a pickle file using torch.save
            _pickle_metrics.append(metrics)
            torch.save(_pickle_metrics, f"logs/{run_id}/metrics.pt")

        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if (last_step or (args.save_every > 0 and step % args.save_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        # Every rank writes its own file (the filename contains ddp_rank) holding
        # optimizer, scheduler, data-loader and RNG state.
        log = dict(step=step,
                    code=code,
                    _config=_config,
                    optimizers=[opt.state_dict() for opt in optimizers],
                    scheduler=[sched.state_dict() for sched in schedulers],
                    train_loader_state=train_loader.get_state(),
                    random_state=random.getstate(),
                    numpy_random_state=np.random.get_state(),
                    torch_random_state=torch.random.get_rng_state(),
                    torch_cuda_random_state=torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None)
        # Model weights are stored only in the master process's file.
        if master_process:
            log['model'] = uncompiled_model.state_dict()
        fname = f"logs/{run_id}/state_rank_{ddp_rank}_step%06d.pt" % step
        # Never overwrite an existing checkpoint (e.g. the one just resumed from).
        if not os.path.exists(fname):
            torch.save(log, fname)
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()

    # Clear memory before starting training
    torch.cuda.empty_cache()
    gc.collect()
    def train_iter():
        """Run one optimizer step with gradient accumulation and log it.

        Accumulates gradients over ``train_accumulation_steps`` micro-batches
        (syncing gradients across ranks only on the last one), averages them,
        steps every optimizer/scheduler pair and clears the gradients.

        Side effects on module globals: adds this step's mean micro-batch loss
        to ``_train_loss`` and increments ``nbatches``.
        """
        global _train_loss
        global nbatches
        # Per-micro-batch tracing is emitted only on the first step this process runs.
        trace = (step == start_step)
        ntoken = 0
        for i in range(1, train_accumulation_steps+1):
            # Traced times are seconds since the clock `t0` was last (re)started;
            # time.time() already returns seconds, so no further scaling is applied.
            if trace: print0(f"acc={i}. about to run forward pass. {time.time() - t0:.3f}s")
            # forward pass
            x, y = train_loader.next_batch()
            ntoken += x.numel()
            with ctx:
                _, loss = model(x, y, return_logits=False)
                train_loss = loss.detach()
                _train_loss += float(train_loss.item() / train_accumulation_steps)
            # backward pass
            if trace: print0(f"acc={i}. forward pass done. {time.time() - t0:.3f}s")
            if i < train_accumulation_steps:
                with model.no_sync(): # there's no need to sync gradients every accumulation step
                    loss.backward()
            else:
                loss.backward() # just sync on the last step
            if trace: print0(f"acc={i}. backward pass done. about to GC. {time.time() - t0:.3f}s")
            # Clear memory after each accumulation step
            del loss, x, y
            torch.cuda.empty_cache()
            gc.collect()
            if trace: print0(f"acc={i}. GC done. {time.time() - t0:.3f}s")
        # Turn the summed micro-batch gradients into their mean.
        for p in model.parameters():
            # Parameters that received no gradient (e.g. frozen ones) have
            # grad None; skip them instead of raising a TypeError.
            if p.grad is not None:
                p.grad /= train_accumulation_steps
        nbatches += 1
        # step the optimizers and schedulers
        for opt, sched in zip(optimizers, schedulers):
            opt.step()
            sched.step()
        model.zero_grad(set_to_none=True)
        # --------------- TRAINING SECTION END -------------------
        # everything that follows now is just diagnostics, prints, logging, etc.
        if trace: print0("all_reduce train:")
        # Average the last micro-batch's loss across ranks for the log line.
        dist.all_reduce(train_loss, op=dist.ReduceOp.AVG)
        if master_process:
            if trace: print0("syncrhonize:")
            torch.cuda.synchronize()
            approx_time = training_time_ms + 1000 * (time.time() - t0)
            print0(f"st {step+1}/{args.num_iterations}|ntoken {ntoken}|tr_loss:{train_loss.item():.4f}|tr_time:{approx_time:.0f}ms avg_step_time:{approx_time/timed_steps:.2f}ms", console=True)
    train_iter()

    # Periodically log the output of the nvidia_smi() helper.
    if step % 50 == 0:
        print0(nvidia_smi())

    torch.cuda.empty_cache()
    gc.collect()

# Report the peak CUDA allocation from torch.cuda.max_memory_allocated(), in whole MiB.
peak_mib = torch.cuda.max_memory_allocated() // 1024 // 1024
print0(f"peak memory consumption: {peak_mib} MiB")

# -------------------------------------------------------------------------
# Tear down the distributed process group before the script exits.
dist.destroy_process_group()
