import os
import gc
import math
import sys
from typing import Tuple
# Snapshot this script's own source text into `code` at import time.
# NOTE(review): presumably logged with the run for reproducibility — the use site is
# outside this view; confirm.
with open(sys.argv[0]) as f:
    code = f.read()
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch._inductor.config as config

# ==== scale free
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# A trivial backward pass on a CUDA tensor before anything is compiled; per the original
# note it works around a bug on some systems (the underlying cause is not visible here).
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
# Compile flex_attention / create_block_mask with static shapes; flex_attention is meant to be
# run under torch.compile so the score_mod / mask_mod callables get compiled with it.
flex_attention = torch.compile(flex_attention, dynamic=False)
create_block_mask = torch.compile(create_block_mask, dynamic=False)

import wandb

## without this, we error out
# Dynamo settings: fall back to eager instead of raising when compilation fails,
# disable DDP-aware graph splitting, and allow up to 128 compiled variants per code object
# (presumably because several sequence lengths / modes get compiled — confirm).
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.optimize_ddp = False
torch._dynamo.config.cache_size_limit = 128

import argparse
# Command-line interface. Exactly one of the boolean "mode" flags below must be passed
# (enforced right after parsing); the remaining options configure the run.
parser = argparse.ArgumentParser()
parser.add_argument('--runname', type=str, default="testing")
## testing
# --testrun skips wandb.init.
parser.add_argument('--testrun', action='store_true')

# Training context length; longer sequences count as extrapolation for NTK / YaRN scaling.
parser.add_argument("--train_seq_length", type=int, default=64*1024)
# RoPE base frequency (theta) used by Rotary and the YaRN helpers.
parser.add_argument("--base_theta", type=int, default=10_000)
# ---- mutually exclusive mode flags ----
parser.add_argument("--sfa", action='store_true')
parser.add_argument("--nousyarn", action='store_true')
parser.add_argument("--sfa_and_p_rope", action='store_true')
parser.add_argument("--sfa_and_rope", action='store_true')
parser.add_argument("--sfa_and_nope", action='store_true')
parser.add_argument("--p_rope", action='store_true')
parser.add_argument("--rope", action='store_true')
parser.add_argument("--log_n_trick", action='store_true')
parser.add_argument("--log_n_trick_and_p_rope", action='store_true')
parser.add_argument("--log_n_trick_and_ntk_aware", action='store_true')
parser.add_argument("--alibi", action='store_true')
parser.add_argument("--nope", action='store_true')
parser.add_argument("--ntk_aware", action='store_true') # section 3.1 yarn paper
parser.add_argument("--infini_rope", action='store_true') # infini-attention https://github.com/jlamprou/Infini-Attention https://arxiv.org/abs/2404.07143
parser.add_argument("--infini_p_rope", action='store_true') # infini-attention https://github.com/jlamprou/Infini-Attention https://arxiv.org/abs/2404.07143
# Temperature of the scale-invariant score_mod: log1p((q_idx - kv_idx) / tau).
parser.add_argument("--tau", default=10., type=float)
# Explicit run id; when empty, "<runname>__<4 uuid chars>" is generated.
parser.add_argument("--run_id", type=str, default="")
parser.add_argument("--seed", type=int, default=0)
# NOTE(review): not read in this view — presumably a batch-size multiplier used later.
parser.add_argument("--bs_mul", type=float, default=1.0)
parser.add_argument("--log_n_trick_and_nope", action='store_true')
sfa_args = parser.parse_args()
# When True, print_max_alloc() reports peak CUDA memory.
log_peak_memory = True

# Resolve the single selected mode into three independent switches:
#   pos_emb         - rotary variant applied to q/k: 'p_rope' | 'rope' | 'none' | 'nousyarn'
#   do_sfa          - attend with flex_attention + the scale-invariant score_mod
#   do_log_n_trick  - multiply queries by log(position + 1) (and a learned per-head factor)
# The do_*/dont_* pairs list every flag once; the XOR asserts fail if a selected flag was
# left out of both lists (e.g. a newly added mode), instead of silently defaulting.
modes = [
    sfa_args.sfa,
    sfa_args.sfa_and_p_rope,
    sfa_args.sfa_and_rope,
    sfa_args.p_rope,
    sfa_args.rope,
    sfa_args.alibi,
    sfa_args.nousyarn,
    sfa_args.log_n_trick,
    sfa_args.log_n_trick_and_p_rope,
    sfa_args.log_n_trick_and_ntk_aware,
    sfa_args.ntk_aware,
    sfa_args.sfa_and_nope,
    sfa_args.nope,
    sfa_args.log_n_trick_and_nope,
    sfa_args.infini_rope,
    sfa_args.infini_p_rope,
]
assert len([x for x in modes if x]) == 1, f"only one mode can be selected, got {modes}"
if sfa_args.sfa_and_p_rope or sfa_args.p_rope or sfa_args.log_n_trick_and_p_rope or sfa_args.infini_p_rope:
    pos_emb = 'p_rope'
elif sfa_args.sfa_and_rope or sfa_args.rope or sfa_args.ntk_aware or sfa_args.log_n_trick_and_ntk_aware or sfa_args.log_n_trick or sfa_args.infini_rope:
    pos_emb = 'rope'
elif sfa_args.alibi or sfa_args.sfa or sfa_args.sfa_and_nope or sfa_args.log_n_trick_and_nope or sfa_args.nope:
    pos_emb = 'none'
elif sfa_args.nousyarn:
    pos_emb = 'nousyarn'
else:
    raise ValueError(f"unknown positional embedding mode: {modes}")

do_sfa = any([sfa_args.sfa, sfa_args.sfa_and_p_rope, sfa_args.sfa_and_rope, sfa_args.sfa_and_nope])
dont_sfa = any([sfa_args.p_rope, sfa_args.rope, sfa_args.alibi, sfa_args.nousyarn, sfa_args.log_n_trick, sfa_args.log_n_trick_and_p_rope, sfa_args.ntk_aware, sfa_args.nope, sfa_args.log_n_trick_and_ntk_aware, sfa_args.log_n_trick_and_nope, sfa_args.infini_rope, sfa_args.infini_p_rope])
assert (do_sfa ^ dont_sfa), "exactly one of do_sfa and dont_sfa must be true"
do_log_n_trick = any([sfa_args.log_n_trick, sfa_args.log_n_trick_and_p_rope, sfa_args.log_n_trick_and_ntk_aware, sfa_args.log_n_trick_and_nope])
dont_log_n_trick = any([sfa_args.ntk_aware, sfa_args.p_rope, sfa_args.rope, sfa_args.alibi, sfa_args.nousyarn, sfa_args.sfa, sfa_args.sfa_and_nope, sfa_args.sfa_and_rope, sfa_args.nope, sfa_args.sfa_and_p_rope, sfa_args.infini_rope, sfa_args.infini_p_rope])
assert (do_log_n_trick ^ dont_log_n_trick), "exactly one of do_log_n_trick and dont_log_n_trick must be true"

# Infini-attention (compressive memory mixed into SDPA) for either rotary variant.
do_infini = any([sfa_args.infini_rope, sfa_args.infini_p_rope])

# Seed every RNG the script touches (Python, NumPy, torch CPU and all CUDA devices) from --seed.
import random
import numpy as np

seed = sfa_args.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# ---- model setting
# Transformer width and head count (defaults of GPTConfig).
hidden_size = 768
num_heads = 6
# Longest sequence length the model is evaluated at (64 Ki tokens); also the size of the
# log-n table and NousYaRN's default max_position_embeddings.
max_position_embeddings = 65_536  # == 64 * 1024 ## maximum test sequence length

# ---- setup
# Resolve the run id. Without --run_id, generate "<runname>__<first 4 chars of a uuid4>" and
# refuse to start if logs from an earlier run with the same --runname already exist, so two
# runs never share a name prefix. An explicit --run_id is used verbatim (e.g. to resume).
if sfa_args.run_id == "":
    run_id = f"{sfa_args.runname}__{str(uuid.uuid4())[:4]}"
    # escape the runname so glob metacharacters in it ([, *, ?) are matched literally
    existing_runs = glob.glob(f"logs/{glob.escape(sfa_args.runname)}__*")
    if existing_runs:
        # report the runname that actually conflicts (the fresh run_id cannot exist yet)
        raise ValueError(
            f"run with runname {sfa_args.runname} already exists: {sorted(existing_runs)}"
        )
else:
    run_id = sfa_args.run_id


# Run configuration to log. vars() returns the namespace's own __dict__ (not a copy), so
# storing run_id here also sets sfa_args.run_id to the resolved id.
_config = vars(sfa_args)
_config["run_id"] = run_id
print(_config, flush=True)

# Restrict F.scaled_dot_product_attention to the cuDNN backend only (flash, memory-efficient
# and math backends are all disabled).
# NOTE(review): with no fallback enabled, an SDPA call cuDNN cannot serve presumably errors
# out rather than silently degrading — confirm on the target GPU / dtype.
# CUDNN attention is ~4ms faster than Flash, but doesn't get selected by default in PyTorch 2.5.1
from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
enable_cudnn_sdp(True)
enable_flash_sdp(False)
enable_mem_efficient_sdp(False)
enable_math_sdp(False)

# Real runs are tracked in wandb; --testrun skips it.
if not sfa_args.testrun:
    wandb.init(project="pretrain_gpt2_sia", name=run_id, config=_config)

# Per-run log file appended to by print0: logs/<run_id>/log.txt
os.makedirs("logs", exist_ok=True)
os.makedirs(f"logs/{run_id}", exist_ok=True)
logfile = f"logs/{run_id}/log.txt"
print(logfile, flush=True)
def print0(s, console=False):
    """Echo `s` to stdout and append it as one line to the module-level `logfile`.

    `console` is accepted for call-site compatibility only; output always goes to both.
    """
    line = str(s)
    print(line, flush=True)
    with open(logfile, "a") as log_fh:
        log_fh.write(line + "\n")
        log_fh.flush()
def print_max_alloc():
    """Log peak CUDA memory allocated / reserved (MiB, rounded down) via print0.

    No-op unless the module-level flag `log_peak_memory` is truthy.
    """
    if not log_peak_memory:
        return
    allocated_mib = torch.cuda.max_memory_allocated() // 1024 // 1024
    reserved_mib = torch.cuda.max_memory_reserved() // 1024 // 1024
    print0(f"peak memory allocated: {allocated_mib} MiB reserved: {reserved_mib} MiB", console=True)

# ---- flex attn functions
def scale_invariant_score_mod(score, b, h, q_idx, kv_idx):
    """flex_attention score_mod implementing 'scale invariant attention'.

    With d = q_idx - kv_idx and l = log1p(d / tau) (tau from --tau), the logit becomes
    sqrt(2 l + 1) * score - 2 l. `b` and `h` are unused: every batch/head gets the same map.
    NOTE: for future keys (d / tau < -1) log1p is NaN; those positions are expected to be
    removed by the causal block mask.
    """
    log_dist = torch.log1p((q_idx - kv_idx) / sfa_args.tau)
    gain = (2 * log_dist + 1).sqrt()
    offset = -2 * log_dist
    return gain * score + offset

def alibi_score_modifier(score, b, h, q_idx, kv_idx):
    """flex_attention score_mod implementing ALiBi.

    Head h (0-based) uses slope 2^(-8 (h + 1) / num_heads) and adds slope * (kv_idx - q_idx),
    i.e. a penalty that grows linearly with distance into the past.
    """
    slope = torch.exp2(-((h + 1) * 8.0 / num_heads))
    return score + (kv_idx - q_idx) * slope
# ====

# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_svd(G, steps=None):
    """Return the orthogonal polar factor U @ V^T of the 2-D matrix G via an exact SVD.

    `steps` is ignored; it exists so this backend shares the call signature of
    zeropower_via_newtonschulz5 (Muon calls backends with steps=...).
    """
    # torch.svd is deprecated; torch.linalg.svd returns V^H directly (reduced form).
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \\sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    G must be 2-D. Returns a bfloat16 tensor with G's shape; G itself is never modified.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    # Out-of-place divide: if G is already bfloat16, .bfloat16() returns G itself and an
    # in-place `/=` would rescale the caller's tensor (in Muon: p.grad when nesterov=False).
    X = X / (X.norm() + eps) # ensure top singular value <= 1
    # Iterate on the wide orientation so X @ X.T is the smaller Gram matrix.
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X

# Orthogonalization backends selectable through Muon(backend=...).
zeropower_backends = {
    "svd": zeropower_via_svd,
    "newtonschulz5": zeropower_via_newtonschulz5,
}

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one Muon update to every parameter in every group.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss
                (standard torch.optim.Optimizer contract); run with grad enabled.

        Returns:
            The closure's loss, or None when no closure is given.

        Raises:
            AssertionError: if a parameter has no gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]

            for p in group['params']:
                g = p.grad
                assert g is not None
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                if group['nesterov']:
                    g = g.add(buf, alpha=momentum)
                g = zeropower_backend(g, steps=group['backend_steps'])
                # tall matrices: scale by sqrt(rows / cols) (no-op for square / wide ones)
                g *= max(1, g.size(0)/g.size(1))**0.5
                # Round through bfloat16 (the precision the update was always staged in), then
                # cast to the parameter's own dtype/device; no hard-coded 'cuda' buffer.
                update = g.bfloat16().reshape_as(p.data).type_as(p.data)
                p.data.add_(update, alpha=-lr)
        return loss

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class CastedLinear(nn.Linear):
    """nn.Linear whose parameters are cast to the input's dtype on every forward.

    The stored weight keeps its own dtype while the matmul runs in the activation dtype.
    Constructor arguments are exactly those of nn.Linear.
    """
    def forward(self, x):
        weight = self.weight.to(x.dtype)
        # include the bias when the layer was built with bias=True (it used to be dropped)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, weight, bias)

"""for standard RoPE"""
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of x's last dimension by the angles encoded in cos/sin.

    x must be 4-D. With x1, x2 = first/second half of dim 3, returns
    cat(x1*cos + x2*sin, -x1*sin + x2*cos) along dim 3, cast back to x's dtype.
    cos/sin must broadcast against each half (Rotary.forward yields (1, T, 1, D/2)).
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = (
        first * cos + second * sin,
        first * (-sin) + second * cos,
    )
    return torch.cat(rotated, 3).type_as(x)
class Rotary(torch.nn.Module):
    """Standard RoPE cos/sin table generator, optionally with NTK-aware base scaling.

    forward(x) reads only x.shape[1] (sequence length T; callers pass q of shape
    (B, T, H, D_head)) and x.device, and returns (cos, sin), each of shape
    (1, T, 1, dim // 2) in bfloat16, for use with apply_rotary_emb.
    Tables are cached and rebuilt only when T or the device changes.

    NTK-aware scaling (YaRN paper, sec. 3.1): in the --ntk_aware and
    --log_n_trick_and_ntk_aware modes, when T > --train_seq_length the base becomes
    base * s^(D/(D-2)) with s = T / train_seq_length.
    """
    def __init__(self, dim, base=sfa_args.base_theta):
        super().__init__()
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.device_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        # rebuild on a new length OR a new device (a stale cache would live on the old device)
        if seq_len != self.seq_len_cached or x.device != self.device_cached:
            self.seq_len_cached = seq_len
            self.device_cached = x.device
            # both NTK modes must scale; previously only --ntk_aware did, so
            # --log_n_trick_and_ntk_aware silently ran plain RoPE
            use_ntk = sfa_args.ntk_aware or sfa_args.log_n_trick_and_ntk_aware
            if use_ntk and seq_len > sfa_args.train_seq_length:
                s = seq_len / sfa_args.train_seq_length
                D = self.dim
                scaled_base = self.base * (s ** (D/(D-2)))
                inv_freq = 1.0 / (scaled_base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            else:
                inv_freq = self.inv_freq

            t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
            freqs = torch.outer(t, inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

"""'p-RoPE' https://arxiv.org/abs/2410.06205, with p=0.5"""
class pRoPE(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)
    def forward(self, x_BTHD):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

"""--------------------------------------------------------------------------"""
"""use https://huggingface.co/NousResearch/Yarn-Llama-2-13b-128k/blob/main/modeling_llama_together_yarn.py"""
from einops import rearrange, repeat

def rotate_half(x, interleaved=False):
    """Map each rotary pair (a, b) of x's last dimension to (-b, a).

    interleaved=False pairs the first half with the second half (GPT-NeoX layout);
    interleaved=True pairs adjacent even/odd elements (GPT-J layout).
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # (..., d, 2) -> (..., 2d): merging the last two dims restores the even/odd interleave
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Rotate the leading rotary_dim channels of x; the remaining channels pass through.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    interleaved: GPT-J (even/odd pairs) instead of GPT-NeoX (first/second half) layout.
    """
    rot_dim = 2 * cos.shape[-1]
    assert rot_dim <= x.shape[-1]

    def _expand(table):
        # (..., d) -> (..., 1, 2d): insert a singleton heads axis and duplicate every angle,
        # either as [t, t] (half/half layout) or [t0, t0, t1, t1, ...] (interleaved layout).
        if interleaved:
            doubled = table.repeat_interleave(2, dim=-1)
        else:
            doubled = torch.cat((table, table), dim=-1)
        return doubled.unsqueeze(-2)

    cos, sin = _expand(cos), _expand(sin)
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x_pass], dim=-1)


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations, dim, base=sfa_args.base_theta, max_position_embeddings=sfa_args.train_seq_length):
    """Return the (fractional) rotary-pair index whose wavelength 2*pi*base^(2i/dim) fits
    exactly `num_rotations` times into `max_position_embeddings` positions.
    """
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return (dim * math.log(ratio)) / (2 * math.log(base))

# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot, high_rot, dim, base=sfa_args.base_theta, max_position_embeddings=sfa_args.train_seq_length):
    """Return integer bounds (low, high) of the rotary dims lying between `low_rot` and
    `high_rot` rotations: low is floored, high is ceiled, both clamped to [0, dim - 1].
    """
    low_dim = _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high_dim = _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    # Clamp values just in case
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)

def _yarn_linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func

def _yarn_get_mscale(scale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

class NousYaRN(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    This implements the YaRN extension method.

    Frequency selection in dynamic mode (the default), for a sequence of length `seqlen`:
      * finetuned and seqlen <= max_position_embeddings -> fixed `scaling_factor`;
      * seqlen > original_max_position_embeddings -> s = seqlen / original_max_position_embeddings
        (dynamic YaRN: interpolate only once we run past the training length);
      * otherwise -> plain RoPE frequencies with mscale = attn_factor.
    forward(q, k) returns (q, k) rotated with cos/sin tables multiplied by `mscale`.
    """

    def __init__(self, dim: int, base=sfa_args.base_theta, interleaved=True,
                 scaling_factor=1.0, pos_idx_in_fp32=True,
                 max_position_embeddings=max_position_embeddings,
                 original_max_position_embeddings=sfa_args.train_seq_length, extrapolation_factor=1,
                 attn_factor=1, beta_fast=32, beta_slow=1,
                 dynamic=True, finetuned=False, device=None):
        """
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style). Defaults to True (GPT-J layout).
            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
                otherwise they might be in lower precision.
                This option was added because previously (before 2023-07-02), when we construct
                the position indices, we use the dtype of self.inv_freq. In most cases this would
                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
                self.inv_freq would be bf16, and the position indices are also in bf16.
                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
                embeddings for some positions will coincide.
                To maintain compatibility with models previously trained in pure bf16,
                we add this option.
            scaling_factor: RotaryEmbedding extended with YaRN scaling.
            max_position_embeddings: end of the window in which a finetuned model uses its
                fixed `scaling_factor`.
            original_max_position_embeddings: training context length L; in dynamic mode,
                sequences longer than L are interpolated with s = seqlen / L.
        """
        super().__init__()

        self.dim = dim
        self.base = float(base)
        self.interleaved = interleaved
        self.scaling_factor = scaling_factor
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings if original_max_position_embeddings else max_position_embeddings
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) # Get n-d magnitude scaling corrected for interpolation
        self.dynamic = dynamic
        self.finetuned = finetuned

        # Generate and save the inverse frequency buffer (non trainable)
        if not dynamic:
            self._compute_inv_freq(scaling_factor, device)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _compute_inv_freq(self, scaling_factor, device=None):
        """Register YaRN inverse frequencies for `scaling_factor` as the `inv_freq` buffer:
        interpolated (divided by the factor) for low-frequency dims, untouched for
        high-frequency dims, blended by a linear ramp between the beta_fast/beta_slow bounds."""
        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _compute_inv_freq_original(self, device=None):
        """Register plain (unscaled) RoPE inverse frequencies as the `inv_freq` buffer."""
        inv_freq = 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                                 dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if none exist yet, if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (self._cos_cached is None or seqlen != self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen

            if self.dynamic:
                scaling_factor = None
                if self.finetuned and seqlen <= self.max_position_embeddings:
                    scaling_factor = self.scaling_factor
                elif seqlen > self.original_max_position_embeddings:
                    # Compare against the *training* length: max_position_embeddings is the
                    # largest evaluated length here, so using it meant YaRN never scaled.
                    scaling_factor = seqlen / self.original_max_position_embeddings
                if scaling_factor:
                    self._compute_inv_freq(scaling_factor, device)
                    self.mscale = float(_yarn_get_mscale(scaling_factor) * self.attn_factor)
                else:
                    self._compute_inv_freq_original(device)
                    # No interpolation at this length: drop any magnitude scale left over
                    # from an earlier, longer sequence.
                    self.mscale = float(_yarn_get_mscale(1) * self.attn_factor)

            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        # only the sequence lengths this script trains/evaluates at are allowed
        assert q.shape[1] in [1024, 4096, 16*1024, 64*1024]
        self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
        cos = self._cos_cached[seqlen_offset:]
        sin = self._sin_cached[seqlen_offset:]
        # honour the configured layout (it used to be hard-coded to interleaved=True)
        return (
            apply_rotary_emb_torch(q, cos, sin, self.interleaved),
            apply_rotary_emb_torch(k, cos, sin, self.interleaved),
        )
"""--------------------------------------------------------------------------"""

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention whose variant is chosen by module-level mode flags.

    * pos_emb ('rope' | 'p_rope' | 'nousyarn' | 'none') selects the rotary module for q/k;
    * do_log_n_trick scales query t by log(t + 1) and a learned per-head factor;
    * do_sfa / --alibi run flex_attention with the scale-invariant / ALiBi score_mod;
    * do_infini mixes causal SDPA output with a linear-attention compressive memory;
    * otherwise plain causal F.scaled_dot_product_attention.
    forward: x of shape (B, T, n_embd) -> (B, T, n_embd).
    """
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_k = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_v = CastedLinear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
        if pos_emb == 'rope':
            self.rotary = Rotary(self.head_dim)
        elif pos_emb == 'p_rope':
            # tables cover the longest evaluated sequence (same 64 Ki bound as the log-n table)
            self.rotary = pRoPE(self.head_dim, max_position_embeddings)
        elif pos_emb == 'nousyarn':
            self.rotary = NousYaRN(self.head_dim)
        else:
            self.rotary = None
        if do_log_n_trick:
            # learned per-head multiplier on top of the fixed log(t + 1) factor
            self.logn_s = torch.nn.Parameter(torch.ones(self.n_head))
            logn_list = [math.log(i+1, math.e) for i in range(max_position_embeddings)] # (T,)
            self.register_buffer("logn", torch.tensor(logn_list))

        if do_infini:
            self.infini_beta = nn.Parameter(torch.randn(1))
            self.register_buffer("M", torch.zeros(self.n_head, self.head_dim, self.head_dim))
            self.register_buffer("z", torch.zeros(self.n_head, self.head_dim))
            # NOTE(review): stored but not used by forward — no per-segment processing happens here.
            self.segment_size = 2048

    def _rotary_emb(self, q, k):
        """Apply the configured positional rotation to q and k of shape (B, T, H, head_dim)."""
        if pos_emb == 'rope':
            cos, sin = self.rotary(q)
            q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        elif pos_emb == 'p_rope':
            q, k = self.rotary(q), self.rotary(k)
        elif pos_emb == 'nousyarn':
            q,k = self.rotary(q,k)
        elif pos_emb == 'none':
            pass
        else:
            # report the module-level setting (sfa_args has no `pos_emb` attribute)
            raise ValueError(f"unknown positional embedding type: {pos_emb}")
        return q, k

    def forward(self, x, block_mask=None):
        """Attend over x (B, T, n_embd); block_mask is only used by the flex_attention paths."""
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm
        q, k = self._rotary_emb(q, k)
        if do_log_n_trick:
            q = self.logn[None,:T,None,None] * q * self.logn_s[None,None,:,None]
        q = q.bfloat16(); k = k.bfloat16(); v = v.bfloat16()
        ## pick attention
        if do_sfa:
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=scale_invariant_score_mod)
        elif sfa_args.alibi:
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=alibi_score_modifier)
        elif do_infini:
            q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)
            # now q, k, v are (B, H, T, D)
            # see https://github.com/jlamprou/Infini-Attention/blob/3a86dd5c528f3cf9ecd231314338c5fac6f88dd1/infiniAttention.py#L98
            # Memory is read before this call's keys/values are written, so it only holds
            # content from earlier forward calls (since the last infini_reset_memory).
            memory_output = self._infini_retrieve_from_memory(q, self.M, self.z)
            self.M, self.z  = self._infini_update_memory(k, v, self.M, self.z)
            attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            y = self._infini_long_term_injection(attn_output, memory_output)
        else:
            y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

    """infini-attention methods, see https://github.com/jlamprou/Infini-Attention/blob/main/infiniAttention.py"""
    def _infini_retrieve_from_memory(self, Q, M, z):
        """Linear-attention read of the compressive memory: (elu(Q)+1) M / ((elu(Q)+1) z + 1e-8).

        NOTE(review): M/z start as float32 buffers while Q is bfloat16 in forward; torch.matmul
        rejects mixed dtypes unless an autocast context is active — confirm how this is run.
        """
        assert do_infini
        # Retrieve context from compressive memory using linear attention (Eq. 3)
        M_s_1 = torch.matmul(F.elu(Q) + 1, M)
        Z_s_1 = torch.matmul(F.elu(Q) + 1, z.unsqueeze(-1)) + 1e-8
        A_mem = M_s_1 / Z_s_1
        return A_mem

    def _infini_update_memory(self, K, V, M, z, use_delta=False):
        """Return the memory (M, z) after writing K, V (linear or delta-rule update).

        With K, V of shape (B, H, T, D) the results gain a leading batch dim
        ((B, H, D, D) and (B, H, D)) and carry this call's autograd history.
        """
        assert do_infini
        if use_delta:
            V_retrieved = torch.matmul(F.elu(K) + 1, M) / (torch.matmul(F.elu(K) + 1, z.unsqueeze(-1)) + 1e-8)
            updated_M = M + torch.matmul(F.elu(K).transpose(-2, -1) + 1, V - V_retrieved)
        else:
            updated_M = M + torch.matmul(F.elu(K).transpose(-2, -1) + 1, V)

        updated_z = z + (F.elu(K) + 1).sum(dim=-2)
        M = updated_M
        z = updated_z
        return M, z

    def _infini_long_term_injection(self, A_dot, A_mem):
        """Blend memory and local attention with a learned gate: s*A_mem + (1-s)*A_dot, s=sigmoid(beta)."""
        assert do_infini
        beta = torch.sigmoid(self.infini_beta)
        A = beta * A_mem + (1 - beta) * A_dot
        return A
    def infini_reset_memory(self):
        """Replace M and z with fresh zero tensors (dropping any batch dim and autograd history)."""
        assert do_infini
        self.M = torch.zeros(self.n_head, self.head_dim, self.head_dim, device=self.M.device, dtype=self.M.dtype)
        self.z = torch.zeros(self.n_head, self.head_dim, device=self.z.device, dtype=self.z.dtype)

class MLP(nn.Module):
    """Feed-forward sublayer: n_embd -> 4*n_embd -> n_embd with squared-ReLU activation.

    The output projection is zero-initialised, so a fresh MLP outputs zeros.
    """

    def __init__(self, config):
        super().__init__()
        width = 4 * config.n_embd
        self.c_fc    = CastedLinear(config.n_embd, width, bias=False)
        self.c_proj  = CastedLinear(width, config.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)

class Block(nn.Module):
    """Pre-norm transformer layer: x += attn(rms_norm(x)); then x += mlp(rms_norm(x))."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, block_mask=None):
        attn_in = F.rms_norm(x, (x.size(-1),))
        x = x + self.attn(attn_in, block_mask=block_mask)
        mlp_in = F.rms_norm(x, (x.size(-1),))
        return x + self.mlp(mlp_in)

# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    """Model hyper-parameters; head count and width default to the module-level settings."""
    # presumably GPT-2's 50257-token vocabulary padded up to a multiple of 128 — confirm
    vocab_size: int = 50_304
    n_layer: int = 12
    n_head: int = num_heads
    n_embd: int = hidden_size

class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings are RMS-normalised, passed through ``config.n_layer``
    residual blocks, RMS-normalised again, and projected to vocabulary logits
    by a zero-initialised ``lm_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Construction order (embedding, then blocks, then head) is kept as-is so
        # parameter initialisation happens in the same sequence.
        embedding = nn.Embedding(config.vocab_size, config.n_embd)
        layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.transformer = nn.ModuleDict({'wte': embedding, 'h': layers})
        self.lm_head = CastedLinear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight.data.zero_()
        # sequence length -> block mask (or None); only the most recent length is kept
        self._block_mask_cache = {}

    def _get_block_mask(self, T):
        """Return the causal flex-attention block mask for length ``T``.

        The mask is built only for SFA or ALiBi runs (per-head when ALiBi is on);
        otherwise the result is ``None``. The result is cached, but requesting a
        new length evicts whatever was cached before.
        """
        previous = self._block_mask_cache
        if T in previous:
            return previous[T]
        self._block_mask_cache = {}
        mask = None
        if do_sfa or sfa_args.alibi:
            heads = num_heads if sfa_args.alibi else None
            def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx
            mask = create_block_mask(causal_mask, None, heads, T, T, device='cuda', _compile=True)
        self._block_mask_cache[T] = mask
        return mask

    def forward(self, idx, targets=None, return_logits=True):
        """Run the model on token ids ``idx`` of shape (B, T).

        With ``targets`` the logits cover every position and the mean
        cross-entropy loss (ignoring target -1) is returned. Without targets,
        only the last position's logits are computed and ``loss`` is None.
        Logits are returned in fp32, or as None when ``return_logits`` is False.
        """
        _, seq_len = idx.size()
        block_mask = self._get_block_mask(seq_len)
        x = self.transformer.wte(idx)
        x = F.rms_norm(x, (x.size(-1),))
        for layer in self.transformer.h:
            x = layer(x, block_mask=block_mask)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is None:
            # list index [-1] keeps the time dimension: shape (B, 1, vocab)
            logits = self.lm_head(x[:, [-1], :]).float()
            loss = None
        else:
            logits = self.lm_head(x).float()  # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)
        return (logits if return_logits else None), loss

    def infini_reset_memory(self):
        """Reset the Infini-attention memory of every block's attention module."""
        assert do_infini
        for layer in self.transformer.h:
            layer.attn.infini_reset_memory()

# -----------------------------------------------------------------------------
# Chunking function for Infini-Attention

def chunk_sequence(idx, targets, segment_size=2048):
    """Chunk a sequence into segments for Infini-Attention processing.

    Args:
        idx: input token ids of shape (B, T).
        targets: matching targets of shape (B, T), or None.
        segment_size: maximum segment length; the final segment may be shorter.

    Returns:
        A list of ``(idx_chunk, targets_chunk, weight)`` tuples in time order.
        ``targets_chunk`` is None when ``targets`` is None. ``weight`` is the
        segment's fraction of T, so the weights sum to 1. A weighted sum of
        per-segment mean losses therefore equals the full-sequence mean,
        assuming no ignored targets.
    """
    _, seq_len = idx.size()
    bounds = [(lo, min(lo + segment_size, seq_len)) for lo in range(0, seq_len, segment_size)]
    return [
        (idx[:, lo:hi], None if targets is None else targets[:, lo:hi], (hi - lo) / seq_len)
        for lo, hi in bounds
    ]

# -----------------------------------------------------------------------------
# Our own simple data loader (single process; reads token shards sequentially)

def _peek_data_shard(filename):
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens

def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2] # number of tokens (claimed)
        # the rest of it are tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens

class DataLoader:
    """Sequential, single-process loader over a glob of uint16 token shards.

    Each ``next_batch`` returns ``B * T`` input tokens and the same span shifted
    by one token as targets. Both are reshaped to (B, T) and moved to the current
    CUDA device. When the current shard cannot supply another full batch, the
    loader moves to the next shard, wrapping around after the last one.

    Args:
        filename_pattern: glob pattern for the ``.bin`` shard files.
        B: sequences per batch.
        T: tokens per sequence.
        num_steps: number of batches a caller intends to draw. It is stored for
            callers and not enforced here; None means an open-ended loader.
    """

    def __init__(self, filename_pattern, B, T, num_steps):
        self.B = B
        self.T = T
        self.num_steps = num_steps

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # Validate all shards from their headers and count tokens in total.
        # Reading starts at _start_position(), so a shard needs that many tokens,
        # plus one full batch, plus the extra target token.
        min_shard_ntok = self._start_position() + B * T + 1
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            assert shard_ntok >= min_shard_ntok, (
                f"shard {fname} has {shard_ntok} tokens, need at least {min_shard_ntok}"
            )
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total

        # kick things off
        self.reset()

    def _start_position(self):
        """Token offset where reading begins in every shard (B * T, not 0)."""
        return self.B * self.T

    def reset(self):
        """Rewind to the start position of the first shard."""
        self.current_shard = 0
        self.current_position = self._start_position()
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def get_state(self):
        """Return the current shard index and read position without changing them."""
        return dict(
            current_shard=self.current_shard,
            current_position=self.current_position,
        )

    def set_state(self, state_dict):
        """Restore a position previously returned by ``get_state``."""
        self.current_shard = state_dict['current_shard']
        self.current_position = state_dict['current_position']
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self):  # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self._start_position()
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return ``(x, y)`` of shape (B, T) on the current CUDA device; y is x shifted by one."""
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = (buf[:-1]).view(B, T)  # inputs
        y = (buf[1:]).view(B, T)  # targets
        # advance current position and load next shard if necessary
        self.current_position += B * T
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()

# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    """Run configuration. Several defaults are derived from ``sfa_args``.

    NOTE(review): ``sfa_args`` is presumably the parsed command line defined
    above this chunk — confirm.
    """
    # --- data ---
    input_bin: str = './data/fineweb10B/fineweb_train_*.bin'  # glob of training shards
    input_val_bin: str = './data/fineweb10B/fineweb_val_*.bin'  # glob of validation shards
    # --- optimization ---
    # sequences per optimizer step, expressed as 8 x the per-"device" micro-batch
    batch_size: int = 8 * int(sfa_args.bs_mul * 64*1024 / sfa_args.train_seq_length)
    sequence_length: int = sfa_args.train_seq_length
    num_iterations: int = 4578  # total optimizer steps
    warmup_iters: int = 0  # steps of linear LR warmup
    warmdown_iters: int = 1308  # steps of linear LR decay at the end of training
    weight_decay: float = 0
    # --- evaluation / logging ---
    val_loss_every: int = 125  # evaluate every N steps; 0 = only at the end
    val_tokens: int = 10485760  # fixed validation budget so runs stay comparable
args = Hyperparameters()
# one eighth of the global batch; processed as a micro-batch on the single GPU below
args.device_batch_size = int(args.batch_size / 8)

# this script runs on one GPU
assert torch.cuda.is_available()
device = 'cuda:0'
torch.cuda.set_device(device)

# shorthand used throughout: micro-batch size (sequences) and sequence length
B = args.device_batch_size
T = args.sequence_length
# the global batch must be a whole number of micro-batches
assert args.batch_size % B == 0
# micro-batches whose gradients are accumulated before each optimizer step
train_accumulation_steps = int(8 / sfa_args.bs_mul)


def get_val_loader(Tval):
    """Build a validation DataLoader with sequence length ``Tval``.

    The batch size is chosen so each batch holds about as many tokens as a
    training micro-batch (``device_batch_size * T``), with at least one
    sequence. ``args.val_tokens`` must split into whole batches.
    """
    val_batch = int(args.device_batch_size * T / Tval)
    if val_batch < 1:
        val_batch = 1
    tokens_per_step = val_batch * Tval
    assert args.val_tokens % tokens_per_step == 0
    return DataLoader(args.input_val_bin, val_batch, Tval, args.val_tokens // tokens_per_step)

train_loader = DataLoader(args.input_bin, B, T, None)
# validation loaders at three context lengths (built in the order 4k, 16k, 64k)
val_loader_4k, val_loader_16k, val_loader_64k = (get_val_loader(k * 1024) for k in (4, 16, 64))
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files", console=True)
x, y = train_loader.next_batch()

num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=num_heads, n_embd=768)).cuda().bfloat16()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True  # suggested by @Chillee
model = torch.compile(model, dynamic=False)

ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

# init the optimizer(s): split the transformer-block parameters by tensor rank
params = list(model.transformer.h.parameters())
matrix_params = list(filter(lambda t: t.ndim == 2, params))
vector_and_scalar_params = list(filter(lambda t: t.ndim < 2, params))
other_params = list(filter(lambda t: t.ndim > 2, params))
assert not other_params
if do_log_n_trick: assert len(vector_and_scalar_params) > 0
# embedding -> Adam(lr=0.3); head + vectors/scalars -> Adam(lr=0.002); 2-D block weights -> Muon
optimizer1 = torch.optim.Adam([model.transformer.wte.weight], lr=0.3, betas=(0.9, 0.95), fused=True)
optimizer2 = torch.optim.Adam([model.lm_head.weight, *vector_and_scalar_params], lr=0.002, betas=(0.9, 0.95), fused=True)
optimizer3 = Muon(matrix_params, lr=0.02, momentum=0.95)
optimizers = [optimizer1, optimizer2, optimizer3]
# learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it):
    """Learning-rate multiplier for ``LambdaLR`` at scheduler step ``it``.

    Trapezoidal schedule over ``args.num_iterations`` steps:
    1. linear warmup for ``args.warmup_iters`` steps;
    2. a constant 1.0;
    3. a linear decay to 0 over the final ``args.warmdown_iters`` steps.
    With ``warmdown_iters == 0`` the multiplier stays at 1.0 through the last
    step instead of raising ZeroDivisionError at ``it == num_iterations``.
    """
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    if it < args.num_iterations - args.warmdown_iters:
        return 1.0
    # 3) linear warmdown; a zero-length warmdown would otherwise be 0/0 at the final step
    if args.warmdown_iters <= 0:
        return 1.0
    return (args.num_iterations - it) / args.warmdown_iters
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# begin logging
with open(logfile, "w") as f:
    # begin the log by printing this file (the Python code)
    f.write('='*100 + '\n')
    f.write(code)
    f.write('='*100 + '\n')
    f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
    import subprocess
    result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    f.write(f'{result.stdout}\n')
    f.write('='*100 + '\n')

_train_loss = 0.
_pickle_metrics = []
nbatches = 0
training_time_ms = 0
# Begin training: rewind the training data and re-fetch the first batch.
# A batch was already drawn right after the loader was built. Resetting
# without re-fetching would leave that stale (x, y) in place and then serve
# the same batch again as the first post-reset batch, training on it twice.
train_loader.reset()
x, y = train_loader.next_batch()
# start the clock
torch.cuda.synchronize()
t0 = time.time()

start_step = 0

def eval_model(val_loader):
    """Compute the mean validation loss of the global ``model`` over ``val_loader``.

    Runs ``val_loader.num_steps`` batches under autocast and ``torch.no_grad``.
    With Infini-attention enabled, memory is reset before each batch. Batches
    longer than 2048 tokens are then processed as consecutive segments from
    ``chunk_sequence``. Each segment's loss is weighted by its share of the
    sequence length, and memory is not reset between segments of one batch.

    Returns:
        A 0-dim tensor: the per-batch loss averaged over all validation batches.
    """
    time_val = time.time()
    model.eval()
    val_loader.reset()
    val_loss = 0.0
    for _ in range(val_loader.num_steps):
        # reset infini attention memory before each validation batch
        if do_infini:
            model.infini_reset_memory()
        x_val, y_val = val_loader.next_batch()

        # Handle chunking for infini attention on long sequences
        if do_infini and x_val.size(1) > 2048:
            chunk_losses = []
            for idx_chunk, targets_chunk, weight in chunk_sequence(x_val, y_val):
                with ctx, torch.no_grad():
                    _, loss = model(idx_chunk, targets_chunk, return_logits=False)
                    chunk_losses.append(loss.detach() * weight)
            # stack on-device: avoids a host sync per chunk and keeps the result on
            # the same device as the non-chunked branch
            val_loss += torch.stack(chunk_losses).sum()
        else:
            with ctx, torch.no_grad():
                _, loss = model(x_val, y_val, return_logits=False)
                val_loss += loss.detach()
        del loss
    val_loss /= val_loader.num_steps
    print0(f"time_val: {(time.time() - time_val):.4f}")
    return val_loss


# Main loop. Steps run from start_step through args.num_iterations inclusive.
# Validation (4k/16k/64k) runs at every multiple of val_loss_every, including
# step 0 when start_step is 0, and at the final step. The final iteration
# (last_step) only validates, saves a checkpoint, and breaks; it does no training.
for step in range(start_step, args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
    if do_val:
        # stop the clock (validation time is excluded from training_time_ms)
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        print_max_alloc()
        print0("begin 4k val")
        val_loss_4k = eval_model(val_loader_4k)
        print0("4k done, begin 16k")
        val_loss_16k = eval_model(val_loader_16k)
        print0("16k done, begin 64k")
        val_loss_64k = eval_model(val_loader_64k)
        print0("64k done")
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()
        print0(f'step:{step}/{args.num_iterations} {val_loss_4k:.4f},{val_loss_16k:.4f},{val_loss_64k:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')

        metrics = {
            f'val_loss_4k': val_loss_4k.item(),
            f'val_loss_16k': val_loss_16k.item(),
            f'val_loss_64k': val_loss_64k.item(),
            'step':step,
        }
        # _train_loss sums the per-step mean training loss since the last eval;
        # dividing by nbatches (steps since the last eval) yields their average
        assert isinstance(_train_loss, float), f"train_loss is not a float, but {_train_loss}"
        if _train_loss > 0.:
            metrics[f'train_loss_{sfa_args.train_seq_length}'] = float(_train_loss / nbatches) # put val/train on same scale
        # reset
        _train_loss = 0.
        nbatches = 0
        if not sfa_args.testrun:
            wandb.log(metrics)
            ## also put metrics in a pickle file using torch.save
            _pickle_metrics.append(metrics)
            torch.save(_pickle_metrics, f"logs/{run_id}/metrics.pt")

        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if last_step:
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(step=step,
                   code=code,
                   _config=_config,
                   model=model.state_dict(),
                   optimizers=[opt.state_dict() for opt in optimizers],
                   scheduler=[sched.state_dict() for sched in schedulers],
                   train_loader_state=train_loader.get_state())
        torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    ntraintokens_seen = 0
    # gradient accumulation over train_accumulation_steps micro-batches; (x, y)
    # always holds the batch fetched at the end of the previous micro-step
    for i in range(1, train_accumulation_steps+1):
        # reset infini attention memory before each micro-batch
        if do_infini:
            model.infini_reset_memory()

        ntraintokens_seen += x.numel()

        # Handle chunking for infini attention on long sequences
        if do_infini and x.size(1) > 2048:
            chunks = chunk_sequence(x, y)
            # the weighted chunk losses are summed into one graph and backpropagated
            # once, so the autograd graphs of all chunks are held until backward()
            total_loss = 0
            for idx_chunk, targets_chunk, weight in chunks:
                with ctx:
                      _, loss = model(idx_chunk, targets_chunk, return_logits=False)
                      total_loss += loss * weight
            train_loss = total_loss.detach()
            _train_loss += float(train_loss.item() / train_accumulation_steps)
            total_loss.backward()
        else:
            # forward pass
            with ctx:
                _, loss = model(x, y, return_logits=False)
                train_loss = loss.detach()
                _train_loss += float(train_loss.item() / train_accumulation_steps)
            loss.backward()
        # advance the dataset for the next batch
        x, y = train_loader.next_batch()
    # turn summed micro-batch gradients into their mean
    # NOTE(review): assumes every parameter received a gradient; a parameter unused
    # in forward would have p.grad None and this division would fail — confirm
    for p in model.parameters():
        p.grad /= train_accumulation_steps
    nbatches += 1
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.
    # (tr_loss below is the loss of the last micro-batch only)
    approx_time = training_time_ms + 1000 * (time.time() - t0)
    print0(f"st {step+1}/{args.num_iterations}|tr_loss:{train_loss.item():.4f}|tr_tokens {ntraintokens_seen}|tr_time:{approx_time:.0f}ms avg_step_time:{approx_time/timed_steps:.2f}ms", console=True)

print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

# -------------------------------------------------------------------------
