import os
import time
import gc
from typing import Tuple
from pathlib import Path
import math
import sys
# Read this script's own source text into `code`. `code` is not used in the visible part of
# the file — presumably it is logged/saved alongside results elsewhere (TODO confirm).
with open(sys.argv[0]) as f:
    code = f.read()
import uuid
import glob
import random
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch._inductor.config as config
import gc


# -----------------------------------------------------------------------------

# ==== scale free

def do_gc():
    """Collect unreachable Python objects and release cached CUDA memory, two rounds in a row.

    The second round presumably picks up tensors that only became collectable/uncached after
    the first one.
    """
    for _round in range(2):
        gc.collect()
        torch.cuda.empty_cache()
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# Run one trivial CUDA backward pass at import time (requires a GPU); per the original
# author this works around a bug on some systems — which bug is not recorded here.
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
# Wrap the flex-attention entry points in torch.compile with static shapes (dynamic=False),
# so every new input shape compiles its own specialization.
flex_attention = torch.compile(flex_attention, dynamic=False)
create_block_mask = torch.compile(create_block_mask, dynamic=False)

import wandb
import torch._dynamo

## without this, we error out
# suppress_errors: when Dynamo fails to compile a frame, fall back to eager instead of raising.
torch._dynamo.config.suppress_errors = True
# optimize_ddp: turn off Dynamo's DDP-specific graph splitting.
torch._dynamo.config.optimize_ddp = False
# cache_size_limit: allow up to 128 compiled variants per code object before falling back to
# eager (presumably needed because several sequence lengths are evaluated).
torch._dynamo.config.cache_size_limit = 128

import argparse
parser = argparse.ArgumentParser()
# ---- run bookkeeping
parser.add_argument('--runname', type=str, default="testing")
parser.add_argument('--testrun', action='store_true')  # when set, wandb.init is skipped
parser.add_argument("--run_id", type=str, default="")
parser.add_argument("--resume_step", type=int, default=0)  # > 0 together with --run_id loads a checkpoint
parser.add_argument("--seed", type=int, default=0)

# ---- attention / positional-embedding variant flags. When a checkpoint is loaded, all of
# these are overwritten with the values stored in the checkpoint's '_config'.
parser.add_argument("--sfa", action='store_true')
parser.add_argument("--nousyarn", action='store_true')
parser.add_argument("--sfa_and_p_rope", action='store_true')
parser.add_argument("--sfa_and_rope", action='store_true')
parser.add_argument("--sfa_and_nope", action='store_true')
parser.add_argument("--p_rope", action='store_true')
parser.add_argument("--rope", action='store_true')
parser.add_argument("--log_n_trick", action='store_true')
parser.add_argument("--log_n_trick_and_nope", action='store_true')
parser.add_argument("--log_n_trick_and_p_rope", action='store_true')
parser.add_argument("--log_n_trick_and_ntk_aware", action='store_true')
parser.add_argument("--alibi", action='store_true')
parser.add_argument("--nope", action='store_true')
parser.add_argument("--ntk_aware", action='store_true') # section 3.1 yarn paper

# ---- method hyperparameters (tau, train_seq_length and base_theta are also taken from a
# loaded checkpoint; bs_mul and train_from_scratch are not read in the visible code)
parser.add_argument("--tau", default=10., type=float)  # distance scale of the scale-free score_mod
parser.add_argument("--bs_mul", default=1., type=float)
parser.add_argument("--train_seq_length", default=4*1024, type=int)  # reference length for NTK/YaRN scaling
parser.add_argument("--base_theta", default=10_000, type=float)  # RoPE base
parser.add_argument("--train_from_scratch", action='store_true')

# ---- evaluation
parser.add_argument("--ctx_len", type=int, default=4096)  # appears in the metrics-file name
sfa_args = parser.parse_args()
log_peak_memory = True

# Name of this evaluation: the explicit --run_id, or "<runname>__<first 4 chars of a uuid4>".
if sfa_args.run_id != "":
    run_id = sfa_args.run_id
else:
    run_id = f"{sfa_args.runname}__{str(uuid.uuid4())[:4]}"


# Evaluating an existing checkpoint needs both an explicit run id and a positive step.
do_load_existing_model = sfa_args.resume_step > 0 and sfa_args.run_id != ""
if not do_load_existing_model:
    assert sfa_args.resume_step == 0, "if not loading a model, resume_step must be 0"

ctx_len = sfa_args.ctx_len
if do_load_existing_model:
    # write results into the same logs/<run_id>/ directory the checkpoint is read from
    metrics_file = f'logs/{run_id}/nih_metrics_{ctx_len}_s{sfa_args.seed}.pt'
else:
    # fresh run: make sure its scratch directory exists first
    parent = Path(f"logs_nih_scratch/{run_id}")
    parent.mkdir(exist_ok=True, parents=True)
    metrics_file = f'logs_nih_scratch/{run_id}/nih_metrics_{ctx_len}_s{sfa_args.seed}.pt'

if do_load_existing_model:
    load_path = './logs/%s/state_step%06d.pt' % (run_id, sfa_args.resume_step)
    assert os.path.exists(load_path), load_path
    state_dict = torch.load(load_path)

    ## copy arguments from the pretrain run
    if sfa_args.runname == "testing":
        sfa_args.runname = state_dict['_config']['runname']
    # Settings every checkpoint config must carry (a missing key raises KeyError).
    for _key in ('rope', 'nope', 'base_theta', 'sfa', 'sfa_and_rope', 'sfa_and_p_rope',
                 'alibi', 'train_seq_length', 'p_rope', 'tau', 'log_n_trick'):
        setattr(sfa_args, _key, state_dict['_config'][_key])
    # Settings that may be absent from a checkpoint's config (presumably added later);
    # absent ones default to False.
    for _key in ('log_n_trick_and_p_rope', 'log_n_trick_and_ntk_aware', 'ntk_aware',
                 'sfa_and_nope', 'log_n_trick_and_nope', 'nousyarn'):
        setattr(sfa_args, _key, state_dict['_config'].get(_key, False))
    del _key
# Collapse the experiment flags into one positional-embedding family; groups are checked in
# this order and the first one with any flag set wins.
if any([sfa_args.sfa_and_p_rope, sfa_args.p_rope, sfa_args.log_n_trick_and_p_rope]):
    pos_emb = 'p_rope'
elif any([sfa_args.sfa_and_rope, sfa_args.rope, sfa_args.log_n_trick,
          sfa_args.log_n_trick_and_ntk_aware, sfa_args.ntk_aware]):
    pos_emb = 'rope'
elif any([sfa_args.alibi, sfa_args.sfa, sfa_args.nope, sfa_args.sfa_and_nope,
          sfa_args.log_n_trick_and_nope]):
    pos_emb = 'none'
elif sfa_args.nousyarn:
    pos_emb = 'nousyarn'
else:
    raise ValueError("unknown positional embedding mode")
print("config:", sfa_args.__dict__)

# Feature switches read by the attention code: scale-free score_mod and log-n query scaling.
do_sfa = any(getattr(sfa_args, _flag) for _flag in ('sfa', 'sfa_and_p_rope', 'sfa_and_rope', 'sfa_and_nope'))
do_log_n_trick = any(getattr(sfa_args, _flag) for _flag in ('log_n_trick', 'log_n_trick_and_p_rope', 'log_n_trick_and_ntk_aware', 'log_n_trick_and_nope'))

# ---- seeds
# Seed every RNG this script can draw from (torch CPU/CUDA, Python, NumPy) with --seed.
torch.manual_seed(sfa_args.seed)
torch.cuda.manual_seed(sfa_args.seed)
torch.cuda.manual_seed_all(sfa_args.seed)
import random; import numpy as np  # redundant: both are already imported at the top of the file
random.seed(sfa_args.seed)
np.random.seed(sfa_args.seed)

# ---- model setting
num_heads = 6  # default GPTConfig.n_head
hidden_size = 768  # default GPTConfig.n_embd
# Length of the log-n query-scale table and the default NousYaRN max_position_embeddings.
max_position_embeddings = 64*1024

# ---- setup


# NOTE: `_config` is the same dict object as `sfa_args.__dict__` (not a copy), so the next
# line also overwrites `sfa_args.run_id` with the resolved run id.
_config = sfa_args.__dict__
_config['run_id'] = run_id

# Log to Weights & Biases unless this is a --testrun.
if not sfa_args.testrun:
    runname = sfa_args.runname
    wandb.init(project="ft_nih_gpt2", config=_config, name=runname)

# CUDNN attention is ~4ms faster than Flash, but doesn't get selected by default in PyTorch 2.5.1
from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
# Leave cuDNN as the only enabled backend of F.scaled_dot_product_attention.
# NOTE(review): CausalSelfAttention computes attention with flex_attention, which these flags
# do not configure; they only matter if SDPA is called somewhere else — confirm.
enable_cudnn_sdp(True)
enable_flash_sdp(False)
enable_mem_efficient_sdp(False)
enable_math_sdp(False)

# The text log sits in the same run directory as the metrics file: next to the checkpoint
# when evaluating one, otherwise in the scratch directory.
logfile = (f"logs/{run_id}/log_nih.txt" if do_load_existing_model
           else f"logs_nih_scratch/{run_id}/log_nih_scratch.txt")
print(logfile)
def print(*args, console=False):
    """Print ``args`` to stdout and append the same line to ``logfile``; both are flushed.

    This shadows the builtin ``print`` for the rest of the module. ``console`` is accepted
    for call-site compatibility but has no effect: every call goes to both destinations.
    """
    # `__builtins__` is the builtins *module* only in __main__; in an imported module it is a
    # plain dict, so `__builtins__.print` would raise AttributeError. The `builtins` module
    # is the documented way to reach the real print.
    import builtins
    builtins.print(*args, flush=True)
    with open(logfile, "a") as f:
        builtins.print(*args, file=f, flush=True)

# Re-emit the config through the module's logging print so it also lands in `logfile`.
print("config:", sfa_args.__dict__)
def print_max_alloc():
    """Log the peak CUDA memory allocated and reserved by PyTorch so far, in whole MiB."""
    bytes_per_mib = 1024 * 1024
    peak_allocated = torch.cuda.max_memory_allocated() // bytes_per_mib
    peak_reserved = torch.cuda.max_memory_reserved() // bytes_per_mib
    print(f"peak memory allocated: {peak_allocated} MiB reserved: {peak_reserved} MiB", console=True)

if os.path.exists(metrics_file):
    # A metrics file from an earlier attempt exists. The last logged step decides the outcome:
    #   < 300  -> that attempt presumably died early: warn and run again;
    #   >= 300 -> the run already finished: abort.
    # Only loading/inspecting the file sits inside the try; any failure there means the file
    # is unusable. The "already finished" ValueError is raised outside the try so that the
    # corruption handler cannot catch it and mis-report a completed run as a corrupted file.
    try:
        _existing_metrics_data = torch.load(metrics_file)
        _last_step = _existing_metrics_data[-1]['step']
        _stopped_early = bool(_last_step < 300)
        _already_finished = bool(_last_step >= 300)
    except Exception as _err:
        print("WARNING: metrics file is corrupted!")
        print("metrics file:")
        print(metrics_file)
        raise ValueError("WARNING: metrics file is corrupted!") from _err
    if _stopped_early:
        print("-"*100)
        print("WARNING: run seems to fail...")
        print("metrics file exists, but last step is less than 300")
        print("metrics file:")
        print(metrics_file)
        print("last executed step:", _last_step)
        print("full metrics data:", _existing_metrics_data)
        print("continuing anyway...")
        print("-"*100)
    elif _already_finished:
        print("-"*100)
        print("metrics file exists already!")
        print("metrics file:")
        print(metrics_file)
        raise ValueError("WARNING: already finished this run!, aborting")

# ---- flex attn functions
def mk_scale_free_score_mod(attn_mask):
    """Build a flex_attention ``score_mod`` for 'scale free attention'.

    With ``d = q_idx - kv_idx`` and ``l = log1p(d / sfa_args.tau)``, the returned function maps
    ``score -> sqrt(2*l + 1) * score - 2*l``, then adds ``attn_mask[b, q_idx]`` when
    ``attn_mask`` is not None.
    """
    def score_modifier(score, b, h, q_idx, kv_idx):
        log_dist = torch.log1p((q_idx - kv_idx) / sfa_args.tau)
        modified = torch.sqrt(2*log_dist + 1) * score - 2*log_dist
        if attn_mask is None:
            return modified
        return modified + attn_mask[b, q_idx]
    return score_modifier

def mk_alibi_score_modifier(attn_mask):
    """Build a flex_attention ``score_mod`` that adds an ALiBi linear distance penalty.

    Head ``h`` (0-based) uses slope ``2 ** (-(h + 1) * 8 / num_heads)``; the bias is
    ``(kv_idx - q_idx) * slope`` (non-positive for causal keys). ``attn_mask[b, q_idx]`` is
    added on top when ``attn_mask`` is not None.
    """
    def score_modifier(score, b, h, q_idx, kv_idx):
        slope = torch.exp2(-((h + 1) * 8.0 / num_heads))
        biased = score + (kv_idx - q_idx) * slope
        return biased if attn_mask is None else biased + attn_mask[b, q_idx]
    return score_modifier
# ====
# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_svd(G, steps=None):
    """Return the orthogonal polar factor ``U @ V^T`` of the 2-D matrix ``G`` (from its SVD).

    ``steps`` is ignored; it only gives this backend the same call signature as
    ``zeropower_via_newtonschulz5``. Uses ``torch.linalg.svd`` (reduced form), which
    replaces the deprecated ``torch.svd``; it returns ``Vh = V^T`` directly.
    """
    U, _S, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \\sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    G must be 2-D; the result is bfloat16 with G's shape. G itself is never modified.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    # Out-of-place on purpose: when G is already bfloat16, `G.bfloat16()` returns G itself,
    # and an in-place `/=` would overwrite the caller's tensor.
    X = X / (X.norm() + eps) # ensure top singular value <= 1
    # Iterate on the wide orientation so X @ X.T is the smaller Gram matrix.
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X

# Orthogonalization backends, selected by name through Muon's `backend` argument.
zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one Muon update to every parameter group.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss
                (standard ``torch.optim`` protocol); it is run with gradients enabled.

        Returns:
            The loss returned by ``closure``, or None when no closure is given.

        Raises:
            AssertionError: if a parameter in a group has no gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group['params']
            if not params:
                continue
            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]

            # Pass 1: update momentum and orthogonalize each gradient into one flat buffer.
            # The buffer is bfloat16, so every update is rounded to bf16 before it is applied
            # (this keeps the optimizer's existing numerics). It is allocated on the
            # parameters' own device rather than the default CUDA device.
            total_params = sum(p.numel() for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
            curr_idx = 0
            for p in params:
                g = p.grad
                assert g is not None
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                if group['nesterov']:
                    g = g.add(buf, alpha=momentum)
                g = zeropower_backend(g, steps=group['backend_steps'])
                # Tall matrices (rows > cols) get their update scaled up by sqrt(rows / cols).
                g *= max(1, g.size(0)/g.size(1))**0.5
                updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                curr_idx += p.numel()

            # Pass 2: deserialize and apply updates
            curr_idx = 0
            for p in params:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()

        return loss

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class CastedLinear(nn.Linear):
    """``nn.Linear`` whose weight (and bias, if any) is cast to the input's dtype on each call.

    The stored parameters keep their own dtype; only the copies used in ``forward`` follow
    ``x.dtype``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # Include the bias when the layer was built with bias=True; it must be cast to match.
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, self.weight.to(x.dtype), bias)

"""for standard RoPE"""
def apply_rotary_emb(x, cos, sin):
    """Apply standard (half-split) RoPE to a 4-D tensor.

    ``x`` is split along dim 3 at ``x.shape[3] // 2`` into ``x1 | x2`` and the result is
    ``[x1*cos + x2*sin | x2*cos - x1*sin]``, cast back to ``x``'s dtype. ``cos``/``sin`` must
    broadcast against each half (``Rotary`` produces them as ``(1, T, 1, D/2)``).
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    lo, hi = x[..., :half], x[..., half:]
    rotated = torch.cat([lo * cos + hi * sin, hi * cos - lo * sin], dim=3)
    return rotated.type_as(x)
class Rotary(torch.nn.Module):
    """Standard RoPE cos/sin tables, with optional NTK-aware base rescaling.

    ``forward(x)`` takes the sequence length T from ``x.shape[1]`` and returns bfloat16
    ``(cos, sin)`` of shape ``(1, T, 1, dim // 2)``; the tables are cached until T changes.

    NTK-aware scaling (section 3.1 of the YaRN paper) is used when T exceeds
    ``sfa_args.train_seq_length`` and either ``--ntk_aware`` or ``--log_n_trick_and_ntk_aware``
    is active: the base becomes ``base * s ** (dim / (dim - 2))`` with
    ``s = T / train_seq_length``.
    """
    def __init__(self, dim, base=sfa_args.base_theta):
        super().__init__()
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            # Both flags that name NTK-aware scaling enable it, including the log-n combination.
            ntk_enabled = sfa_args.ntk_aware or sfa_args.log_n_trick_and_ntk_aware
            if ntk_enabled and seq_len > sfa_args.train_seq_length:
                s = seq_len / sfa_args.train_seq_length
                D = self.dim
                scaled_base = self.base * (s ** (D/(D-2)))
                inv_freq = 1.0 / (scaled_base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            else:
                inv_freq = self.inv_freq

            # NOTE: type_as also moves t to inv_freq's device (CPU); the table is moved to
            # x.device afterwards.
            t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
            freqs = torch.outer(t, inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

"""'p-RoPE' https://arxiv.org/abs/2410.06205, with p=0.5"""
class pRoPE(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)
    def forward(self, x_BTHD):
        assert self.cos.size(0) >= x_BTHD.size(-3), f"cos.size(0) = {self.cos.size(0)}, x_BTHD.size(-3) = {x_BTHD.size(-3)}"
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

"""--------------------------------------------------------------------------"""
"""use https://huggingface.co/NousResearch/Yarn-Llama-2-13b-128k/blob/main/modeling_llama_together_yarn.py"""
from einops import rearrange, repeat

def rotate_half(x, interleaved=False):
    """Rotate each feature pair of the last dim by 90 degrees: (a, b) -> (-b, a).

    Pairs are (first half, second half) by default (GPT-NeoX style), or adjacent
    (even, odd) channels when ``interleaved`` is True (GPT-J style).
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # stack to (..., d, 2) then merge the last two axes back into (..., 2d)
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Rotates the first ``rotary_dim`` channels of ``x`` and passes the rest through unchanged.
    """
    rot_dim = 2 * cos.shape[-1]
    assert rot_dim <= x.shape[-1]

    def _expand(table):
        # (..., rot_dim/2) -> (..., 1, rot_dim): give both members of each rotation pair the
        # same angle (tiled halves, or each entry doubled in place when interleaved), plus a
        # singleton heads axis for broadcasting.
        if interleaved:
            widened = table.repeat_interleave(2, dim=-1)
        else:
            widened = torch.cat((table, table), dim=-1)
        return widened.unsqueeze(-2)

    cos, sin = _expand(cos), _expand(sin)
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    return torch.cat([x_rot * cos + rotate_half(x_rot, interleaved) * sin, x_pass], dim=-1)


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations, dim, base=sfa_args.base_theta, max_position_embeddings=sfa_args.train_seq_length):
    """Return the (fractional) rotary pair index whose wavelength fits ``num_rotations`` full
    turns into ``max_position_embeddings`` positions: ``dim * ln(L / (2*pi*r)) / (2 * ln(base))``.
    """
    full_turn_angle = num_rotations * 2 * math.pi
    numerator = dim * math.log(max_position_embeddings / full_turn_angle)
    return numerator / (2 * math.log(base))

# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot, high_rot, dim, base=sfa_args.base_theta, max_position_embeddings=sfa_args.train_seq_length):
    """Return integer ``(low, high)`` pair-index bounds of the YaRN ramp for the two rotation
    counts: low is floored, high is ceiled, and both are clamped into ``[0, dim - 1]``.
    """
    low_dim, high_dim = (_yarn_find_correction_dim(rot, dim, base, max_position_embeddings)
                         for rot in (low_rot, high_rot))
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)  # Clamp values just in case

def _yarn_linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func

def _yarn_get_mscale(scale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

class NousYaRN(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    This implements the YaRN extension method.

    In this file it is constructed as ``NousYaRN(head_dim)``, i.e. with the defaults below:
    dynamic scaling, ``max_position_embeddings`` = the module-level constant (64*1024) and
    ``original_max_position_embeddings`` = ``sfa_args.train_seq_length``.
    """

    def __init__(self, dim: int, base=sfa_args.base_theta, interleaved=False,
                 scaling_factor=1.0, pos_idx_in_fp32=True,
                 max_position_embeddings=max_position_embeddings,
                 original_max_position_embeddings=sfa_args.train_seq_length, extrapolation_factor=1,
                 attn_factor=1, beta_fast=32, beta_slow=1,
                 dynamic=True, finetuned=False, device=None):
        """
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
                NOTE(review): forward() currently ignores this and always uses interleaved pairs.
            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
                otherwise they might be in lower precision.
                This option was added because previously (before 2023-07-02), when we construct
                the position indices, we use the dtype of self.inv_freq. In most cases this would
                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
                self.inv_freq would be bf16, and the position indices are also in bf16.
                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
                embeddings for some positions will coincide.
                To maintain compatibility with models previously trained in pure bf16,
                we add this option.
            scaling_factor: RotaryEmbedding extended with YaRN scaling.
            max_position_embeddings: with dynamic=True, lengths up to this value use unscaled
                RoPE (or `scaling_factor` if `finetuned`); only longer lengths are rescaled.
            original_max_position_embeddings: the length the dynamic scaling factor is measured
                against (scaling_factor = seqlen / original_max_position_embeddings).
            dynamic: if True, inv_freq is (re)computed lazily from the sequence length seen in
                forward(); if False, it is computed once here from `scaling_factor`.
        """
        super().__init__()

        self.dim = dim
        self.base = float(base)
        self.interleaved = interleaved
        self.scaling_factor = scaling_factor
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings if original_max_position_embeddings else max_position_embeddings
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) # Get n-d magnitude scaling corrected for interpolation
        self.dynamic = dynamic
        self.finetuned = finetuned

        # Generate and save the inverse frequency buffer (non trainable)
        if not dynamic:
            self._compute_inv_freq(scaling_factor, device)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _compute_inv_freq(self, scaling_factor, device=None):
        """Register the YaRN-blended `inv_freq` buffer for the given scaling factor.

        Each rotary pair gets a mix of its original (extrapolated) inverse frequency and that
        frequency divided by `scaling_factor` (interpolated). A linear ramp over the pair
        index between the beta_fast/beta_slow correction bounds selects the mix: pairs below
        the low bound keep the original frequency, pairs above the high bound are fully
        interpolated (when extrapolation_factor == 1).
        """
        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _compute_inv_freq_original(self, device=None):
        """Register the plain (unscaled) RoPE inverse frequencies as the `inv_freq` buffer."""
        inv_freq = 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                                 dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cached cos/sin tables (shape (seqlen, dim // 2), scaled by mscale)."""
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (seqlen != self._seq_len_cached or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen

            if self.dynamic:
                # NOTE(review): with the defaults used in this file, max_position_embeddings is
                # 64*1024 and forward() only accepts lengths up to 64*1024, so (finetuned=False)
                # scaling_factor stays None and plain RoPE tables are used at every evaluated
                # length. If YaRN was meant to kick in beyond the training length, this
                # comparison presumably should use original_max_position_embeddings — confirm
                # before changing, since it alters results.
                scaling_factor = None
                if seqlen <= self.max_position_embeddings:
                    if self.finetuned:
                        scaling_factor = self.scaling_factor
                else:
                    scaling_factor = seqlen / self.original_max_position_embeddings
                if scaling_factor:
                    self._compute_inv_freq(scaling_factor, device)
                    self.mscale = float(_yarn_get_mscale(scaling_factor) * self.attn_factor)
                else:
                    # NOTE(review): self.mscale is not reset here, so after a scaled call a
                    # later unscaled length still multiplies the tables by the old mscale.
                    self._compute_inv_freq_original(device)

            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)


    def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.

        Returns the rotated (q, k), cast/scaled per the cached tables.
        """
        # Only these sequence lengths are accepted (presumably the evaluated context lengths).
        assert q.shape[1] in [1024, 4096, 16*1024, 64*1024]
        self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
        # NOTE(review): the 4th positional argument of apply_rotary_emb_torch is `interleaved`,
        # so this always uses GPT-J (interleaved) pairing and ignores self.interleaved. The
        # commented-out flash-attn call below passed self.interleaved there and True as
        # `inplace`, so the True here looks like a porting leftover. Left unchanged because
        # checkpoints trained with this code presumably depend on this pairing.
        return apply_rotary_emb_torch(
            q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], True
        ), apply_rotary_emb_torch(
            k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], True
        )
        # return apply_rotary_emb_func(
        #     q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
        #     self.interleaved, True # inplace=True
        # ), apply_rotary_emb_func(
        #     k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
        #     self.interleaved, True # inplace=True
        # )
"""--------------------------------------------------------------------------"""

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention computed with ``flex_attention``.

    Module-level switches select the variant:
    - ``pos_emb`` picks the positional scheme ('rope', 'p_rope', 'nousyarn' or 'none');
    - ``do_sfa`` / ``sfa_args.alibi`` pick the score modifier;
    - ``do_log_n_trick`` multiplies queries by log(position + 1) times a learned per-head
      factor.

    A per-layer KV cache supports decoding one token at a time at explicit positions
    ``q_ixs``.
    """
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_k = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_v = CastedLinear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = CastedLinear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
        if pos_emb == 'rope':
            self.rotary = Rotary(self.head_dim)
        elif pos_emb == 'p_rope':
            max_seq_len = 64*1024
            self.rotary = pRoPE(self.head_dim, max_seq_len)
        elif pos_emb == 'nousyarn':
            self.rotary = NousYaRN(self.head_dim)
        else:
            self.rotary = None
        if do_log_n_trick:
            # learned per-head multiplier and a fixed table of log(position + 1) query scales
            self.logn_s = torch.nn.Parameter(torch.ones(self.n_head))
            logn_list = [math.log(i+1, math.e) for i in range(max_position_embeddings)] # (T,)
            self.register_buffer("logn", torch.tensor(logn_list))
        # k/v are cached *before* the rotary embedding; rotary is re-applied on every call
        self.kv_cache = {'k': None, 'v': None}

    def _rotary_emb(self, q, k):
        """Apply the configured positional embedding to q and k, both (B, T, H, D)."""
        if pos_emb == 'rope':
            cos, sin = self.rotary(k)
            k = apply_rotary_emb(k, cos, sin)
            q = apply_rotary_emb(q, cos, sin)
        elif pos_emb == 'p_rope':
            q, k = self.rotary(q), self.rotary(k)
        elif pos_emb == 'nousyarn':
            q, k = self.rotary(q, k)
        elif pos_emb == 'none':
            pass
        else:
            # `sfa_args` has no `pos_emb` attribute; report the module-level setting instead.
            raise ValueError(f"unknown positional embedding type: {pos_emb}")
        return q, k

    def forward(self, x, block_mask=None, attention_mask=None, use_cache=False, q_ixs=None):
        """Attend over ``x`` of shape (B, T, C) and return the projected output, same shape as x.

        Args:
            block_mask: flex_attention BlockMask, passed straight through.
            attention_mask: optional additive bias indexed as ``attention_mask[b, q_idx]``;
                every entry must be <= 0. None means no extra bias.
            use_cache: store (first call) or update (later T == 1 calls) this layer's k/v
                cache; T == 1 calls then attend against the whole cache.
            q_ixs: per-batch positions written/read when decoding a single token (T == 1).
        """
        assert attention_mask is None or (attention_mask <= 0).all(), "attention mask must be negative"
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm

        # Handle KV caching for incremental generation
        if use_cache:
            if self.kv_cache['k'] is None:
                # First pass: initialize cache with full k,v (before rotary)
                self.kv_cache['k'], self.kv_cache['v'] = k, v
            else:
                # Incremental pass: update cache at specific positions
                assert T == 1, "cache update only works for T=1"
                # Update cache in-place at positions specified by q_ixs
                self.kv_cache['k'][torch.arange(B), q_ixs, :, :] = k[:, 0, :, :]
                self.kv_cache['v'][torch.arange(B), q_ixs, :, :] = v[:, 0, :, :]

                # Use cached k,v for attention computation
                k, v = self.kv_cache['k'], self.kv_cache['v']

                # Expand q to match cached sequence length for attention computation
                q_expanded = torch.zeros_like(k)
                q_expanded[torch.arange(B), q_ixs, :, :] = q[:, 0, :, :]
                q = q_expanded

        # Apply rotary embeddings to full sequences (after caching)
        q, k = self._rotary_emb(q, k)
        if do_log_n_trick:
            q = self.logn[None,:q.size(1),None,None] * q * self.logn_s[None,None,:,None]
        q = q.bfloat16(); k = k.bfloat16(); v = v.bfloat16()
        # ## pick attention
        if do_sfa:
            score_mod = mk_scale_free_score_mod(attention_mask)
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=score_mod)
        elif sfa_args.alibi:
            score_mod = mk_alibi_score_modifier(attention_mask)
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=score_mod)
        else:
            if attention_mask is None:
                default_score_mod = None  # flex_attention then uses the raw scores
            else:
                def default_score_mod(score, b, _h, q_idx, kv_idx): return score + attention_mask[b, q_idx]
            y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, score_mod=default_score_mod)
        if T == 1: # Extract attention output for single token generation
            # Select outputs at the query positions for each batch element
            # NOTE(review): assumes q_ixs is given (cached decoding); confirm no caller does a
            # plain T == 1 forward without it.
            y = y[torch.arange(B), :, q_ixs, :].unsqueeze(2)  # (B, H, 1, D)

        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    """Feed-forward sublayer n_embd -> 4*n_embd -> n_embd with a squared-ReLU activation.

    The output projection starts at zero.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.c_fc    = CastedLinear(width, 4 * width, bias=False)
        self.c_proj  = CastedLinear(4 * width, width, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)

class Block(nn.Module):
    """Pre-norm transformer layer: x += attn(rms_norm(x)); then x += mlp(rms_norm(x))."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, block_mask=None, attention_mask=None, use_cache=False, q_ixs=None):
        normed = F.rms_norm(x, (x.size(-1),))
        attn_out = self.attn(normed, block_mask=block_mask, attention_mask=attention_mask,
                             use_cache=use_cache, q_ixs=q_ixs)
        x = x + attn_out
        return x + self.mlp(F.rms_norm(x, (x.size(-1),)))

# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    """Shape hyperparameters consumed by `GPT`.

    The `n_head` / `n_embd` defaults read the module-level `num_heads` / `hidden_size`
    (defined outside this section) once, when this class is defined.
    """
    vocab_size : int = 50304 # GPT-2's 50257 tokens padded up to a multiple of 128
    n_layer : int = 12 # number of transformer blocks
    n_head : int = num_heads # attention heads per block
    n_embd : int = hidden_size # residual-stream (embedding) width

class GPT(nn.Module):
    """Decoder-only language model: token embedding, a stack of `Block`s, untied LM head.

    Masking and cached single-token decoding are driven by a 0/1 `attention_mask` whose
    valid tokens are expected to form a prefix of each row (padding only on the right).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = CastedLinear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight.data.zero_()  # zero-init the LM head weight

    @staticmethod
    def _last_unmasked_index(attention_mask):
        """Return, per row, the index of the last valid token before the first padding slot.

        Rows without any padding get the last column (size(1) - 1) instead of -1, so the
        single-token block mask (`q_ixs[b] == q_idx`) can still match a query row.
        Assumes right padding; a row whose first entry is padding yields -1.
        """
        is_pad = attention_mask == 0
        first_pad = is_pad.to(torch.int).argmax(dim=1)  # argmax returns the FIRST maximal index
        no_pad = torch.full_like(first_pad, attention_mask.size(1) - 1)
        return torch.where(is_pad.any(dim=1), first_pad - 1, no_pad)

    def _get_block_mask(self, B, T, max_seq_len, q_ixs):
        """Create block mask for causal attention with flex_attention.

        T > 1: a plain causal (T x T) mask shared across the batch.
        T == 1 (cached decoding): a per-row (max_seq_len x max_seq_len) causal mask that keeps
        only query row `q_ixs[b]` for batch element b.
        The mask is per-head (`num_heads`) when ALiBi is enabled, otherwise broadcast over heads.
        """
        H = num_heads if sfa_args.alibi else None

        if T > 1:
            # Full sequence processing - standard causal mask
            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx
            return create_block_mask(causal_mask, None, H, T, T, device='cuda', _compile=True)
        else:
            # Single token generation - mask to specific query positions
            def causal_mask(b, h, q_idx, kv_idx):
                return (q_idx >= kv_idx) & (q_ixs[b] == q_idx)
            return create_block_mask(causal_mask, B, H, max_seq_len, max_seq_len, device='cuda', _compile=True)

    def reset_kv_cache(self):
        """Drop every block's cached keys/values and release freed GPU memory."""
        for x in self.transformer.h:
            x.attn.kv_cache['k'] = None
            x.attn.kv_cache['v'] = None
        do_gc()

    def forward(self, idx, targets=None, return_logits=True, attention_mask=None, return_last_logit=False, use_cache=False):
        """Run the model on token ids `idx` of shape (B, T).

        Args:
            idx: (B, T) token ids.
            targets: optional labels for cross-entropy (positions equal to -100 are ignored).
            return_logits: when False, the returned logits are None.
            attention_mask: optional 0/1 mask, right-padded; converted internally to an additive
                0/-inf bf16 mask. NOTE(review): some attention variants index this mask in their
                score_mod, so they may still require a mask to be passed.
            return_last_logit: without targets, compute logits only for the last position.
            use_cache: forwarded to every block's attention (KV cache).

        Returns:
            (logits, loss): float32 logits (or None) and the mean cross-entropy (or None).
        """
        B, T = idx.size()
        if attention_mask is not None: # convert booleans to -inf/0 for causal attention
            assert attention_mask.min() >= 0 and attention_mask.max() == 1, "attention mask must be 0/1"
            q_ixs = self._last_unmasked_index(attention_mask)
            max_seq_len = attention_mask.size(1)
            attention_mask = torch.where(attention_mask.bool(),
                                         torch.tensor(0.0),
                                         torch.tensor(float('-inf'))).to(device=idx.device, dtype=torch.bfloat16)
        else:
            q_ixs = None
            max_seq_len = T  # previously left undefined here, which raised UnboundLocalError

        block_mask = self._get_block_mask(B, T, max_seq_len, q_ixs)
        x = self.transformer.wte(idx)
        x = F.rms_norm(x, (x.size(-1),))
        for block in self.transformer.h:
            x = block(x, block_mask=block_mask, attention_mask=attention_mask, use_cache=use_cache, q_ixs=q_ixs)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is not None:
            logits = self.lm_head(x)
            logits = logits.float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)
        else:
            if return_last_logit:
                logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            else:
                logits = self.lm_head(x) # logits for every position
            logits = logits.float() # use tf32/fp32 for logits
            loss = None
        if not return_logits:
            logits = None

        return logits, loss

    @staticmethod
    def _append_next_tokens(input_ids, attention_mask, current_pos, finished,
                            next_tokens, eos_token_id, block_size):
        """Write each active row's predicted token into its next slot, for all rows at once.

        A row advances when it is not finished, its prediction is not `eos_token_id`, and the
        next slot is inside `block_size`; every other unfinished row becomes finished.
        `input_ids` and `attention_mask` are updated in place.

        Returns:
            (current_pos, finished) after the update.
        """
        next_pos = current_pos + 1
        advance = ~finished & (next_tokens != eos_token_id) & (next_pos < block_size)
        rows = advance.nonzero(as_tuple=True)[0]
        cols = next_pos[rows]
        input_ids[rows, cols] = next_tokens[rows].to(input_ids.dtype)
        attention_mask[rows, cols] = True
        current_pos = torch.where(advance, next_pos, current_pos)
        finished = finished | ~advance
        return current_pos, finished

    @torch.no_grad()
    def generate_greedy(model, input_ids, max_new_tokens=100, eos_token_id=50256, attention_mask=None, use_cache=False):
        """
        Greedy generation with fixed block size.

        Strategy:
        1. Pad input to next_power_of_2(T + max_new_tokens) with padding tokens
        2. Use attention_mask to track valid vs padding positions
        3. In each step, replace the next padding token with the predicted token
        4. Use KV cache when use_cache=True for faster incremental generation

        Args:
            input_ids: (B, T) prompt token ids.
            max_new_tokens: maximum number of generation steps.
            eos_token_id: stop token; also used as the padding filler (GPT-2's end-of-text id by default).
            attention_mask: (B, T) 0/1 mask of valid prompt tokens (right padding); all-valid if None.
            use_cache: after the first full pass, feed only the newest token per row.

        Returns:
            (B, block_size) token ids: prompt, generated tokens, then eos padding. A row stops
            when it predicts eos (which is not written) or runs out of room.
        """
        model.eval()
        device = next(model.parameters()).device
        input_ids = input_ids.to(device)
        B, T = input_ids.size()
        if attention_mask is None:
            attention_mask = torch.ones((B, T), dtype=torch.bool, device=device)
        attention_mask = attention_mask.to(device)

        # Pad to power of 2 block size for efficient attention
        block_size = next_power_of_2(T + max_new_tokens)
        amount_padding = block_size - T

        # Pad input_ids with eos tokens (acts as padding)
        padding = torch.full((B, amount_padding), eos_token_id, dtype=input_ids.dtype, device=device)
        input_ids = torch.cat([input_ids, padding], dim=-1)

        # Extend attention mask - True for valid tokens, False for padding
        attention_mask_padding = torch.zeros((B, amount_padding), dtype=torch.bool, device=device)
        attention_mask = torch.cat([attention_mask, attention_mask_padding], dim=-1)

        # Track current position for each sequence (last valid token index)
        rows = torch.arange(B, device=device)
        current_pos = attention_mask.long().sum(dim=1) - 1
        finished = torch.zeros(B, dtype=torch.bool, device=device)

        # Reset caches
        model.reset_kv_cache()
        for step in range(max_new_tokens):
            # Forward pass - full sequence first time, single tokens if using cache
            if use_cache and step > 0:
                current_tokens = input_ids[rows, current_pos].unsqueeze(1)  # (B, 1)
                logits, _ = model(current_tokens, return_logits=True, attention_mask=attention_mask, use_cache=True)
                next_tokens = logits[:, 0, :].argmax(dim=-1)  # (B,)
            else:
                logits, _ = model(input_ids, return_logits=True, attention_mask=attention_mask, use_cache=use_cache)
                next_tokens = logits[rows, current_pos, :].argmax(dim=-1)  # (B,)

            # Vectorized update (one device op per step instead of B host syncs)
            current_pos, finished = model._append_next_tokens(
                input_ids, attention_mask, current_pos, finished, next_tokens, eos_token_id, block_size)

            # Stop if all sequences finished
            if finished.all():
                break

        # Cleanup (do_gc also empties the CUDA cache)
        model.reset_kv_cache()
        do_gc()
        return input_ids

# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    """Optimization and evaluation settings for the training loop.

    Defaults that read `sfa_args` are evaluated once, when this class is defined.
    """
    # optimization hyperparams
    batch_size : int = 8 * int(sfa_args.bs_mul * 64*1024 / sfa_args.train_seq_length)  # total batch size in sequences across 8 (notional) devices; split by 8 into device_batch_size below
    num_iterations : int = 300 # number of iterations to run
    warmup_iters : int = 100 # iterations of linear LR warmup at the start
    warmdown_iters : int = 100 # iterations of linear LR warmdown to zero at the end (trapezoidal schedule)
    weight_decay : float = 0 # NOTE(review): not passed to any optimizer in this part of the file — appears unused
    # evaluation and logging hyperparams
    val_loss_every : int = 50 # every how many steps to evaluate val loss? 0 for only at the end (NOTE(review): the training loop computes `step % val_loss_every`, so 0 must be special-cased there — confirm)
args = Hyperparameters()
# per-device share of the batch, in sequences
args.device_batch_size = args.batch_size // 8

# Pin this process to GPU 0 (single-device run).
device = f'cuda:0'
torch.cuda.set_device(device)

# Create the tiktoken wrapper
# (project-local `dataset` module; imported here rather than with the other imports at the top)
from dataset import TiktokenWrapper, next_power_of_2, C4NeedleTrainEval
tokenizer = TiktokenWrapper(encoding_name="gpt2")

"""train"""
# Training split of the C4 needle task at the configured context length.
train_needle = C4NeedleTrainEval(
    tokenizer=tokenizer,
    ctx_len=sfa_args.ctx_len,
    batch_size=int(8 * sfa_args.bs_mul),
    do_train=True,
)


"""val"""
# Evaluation splits at 4k / 16k / 64k context lengths (do_train=False).
val_needle_4k = C4NeedleTrainEval(
    tokenizer=tokenizer,
    ctx_len=4096,
    batch_size=int(8 * sfa_args.bs_mul),
    do_train=False,
)

val_needle_16k = C4NeedleTrainEval(
    tokenizer=tokenizer,
    ctx_len=16384,
    batch_size=int(8 * sfa_args.bs_mul),
    do_train=False,
)

# Smaller batch (2x instead of 8x bs_mul) at 64k, presumably to bound memory at the longest context.
val_needle_64k = C4NeedleTrainEval(
    tokenizer=tokenizer,
    ctx_len=65536,
    batch_size=int(2 * sfa_args.bs_mul),
    do_train=False,
)

# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
# 12 blocks of width 768; the head count comes from the module-level `num_heads`.
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=num_heads, n_embd=768))
model = model.cuda().bfloat16()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model, dynamic=False)

if do_load_existing_model:
    # NOTE(review): this loads into the torch.compile wrapper, whose state_dict keys carry an
    # `_orig_mod.` prefix — confirm the checkpoint was saved from the compiled model too.
    model_state_dict = state_dict['model']
    model.load_state_dict(model_state_dict, strict=True)

# init the optimizer(s)
# Only the transformer blocks' parameters are split by rank here; wte and lm_head are assigned explicitly below.
params = list(model.transformer.h.parameters())
matrix_params = [p for p in params if p.ndim == 2]
vector_and_scalar_params = [p for p in params if p.ndim < 2]
other_params = [p for p in params if p.ndim > 2]
# every block parameter must land in exactly one of the two groups above
assert len(other_params) == 0
# the log-n trick is expected to add at least one vector/scalar parameter to the blocks — verify against the attention code
if do_log_n_trick: assert len(vector_and_scalar_params) > 0
# Adam for the embedding (lr 0.3) and for lm_head + 0/1-D block params (lr 0.002);
# Muon (defined elsewhere in the file) for all 2-D block weight matrices.
optimizer1 = torch.optim.Adam([model.transformer.wte.weight], lr=0.3,   betas=(0.9, 0.95), fused=True)
optimizer2 = torch.optim.Adam([model.lm_head.weight] + vector_and_scalar_params, lr=0.002, betas=(0.9, 0.95), fused=True)
optimizer3 = Muon(matrix_params,           lr=0.02,  momentum=0.95)
optimizers = [optimizer1, optimizer2, optimizer3]
# # learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it, num_iterations=None, warmup_iters=None, warmdown_iters=None):
    """LR multiplier for LambdaLR: linear warmup, constant plateau, linear warmdown to 0.

    Args:
        it: scheduler step, expected to satisfy 0 <= it <= num_iterations.
        num_iterations, warmup_iters, warmdown_iters: schedule lengths; each one that is
            not given falls back to the matching field of the global ``args``.

    Returns:
        Factor multiplying each optimizer's base learning rate.
    """
    if num_iterations is None:
        num_iterations = args.num_iterations
    if warmup_iters is None:
        warmup_iters = args.warmup_iters
    if warmdown_iters is None:
        warmdown_iters = args.warmdown_iters

    assert it <= num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    # 2) constant lr for a while
    if it < num_iterations - warmdown_iters:
        return 1.0
    # 3) linear warmdown
    if warmdown_iters <= 0:
        # Only reachable at it == num_iterations (LambdaLR's final step). The warmdown formula
        # is 0 there for any length, and evaluating it would divide by zero.
        return 0.0
    return (num_iterations - it) / warmdown_iters
# One LambdaLR per optimizer, all using the same get_lr multiplier.
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
if do_load_existing_model:
    # Assumes state_dict['optimizers'] was saved in the same order as `optimizers`.
    # NOTE(review): scheduler state is not restored here, so a resumed run's LR schedule
    # restarts from step 0 unless that is handled elsewhere — confirm.
    for opt, opt_state in zip(optimizers, state_dict['optimizers']):
        opt.load_state_dict(opt_state)
# bf16 autocast context (used by the evaluation loop).
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

def eval_c4_needle(fineweb, length=None, niter=None):
    assert length is not None
    print("fineweb needle test -- ", length)
    loader = fineweb.get_dataloader()
    model.eval()
    total = 0
    correct_cities = 0
    correct_numbers = 0
    correct_numbers_and_cities = 0
    t = time.time()
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= niter: break
            with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
                ## find second instance of eos
                start_ixs = batch['start_ixs']
                answers = batch['expected_raw_output']
                input_ids = batch['input_ids'].cuda()

                # we can make input_ids smaller by using start_ixs
                # the start_ixs indicate where the Answer starts
                Tmax = max(start_ixs)
                input_ids = input_ids[:, :Tmax]

                # setup attention mask
                attention_mask = torch.ones_like(input_ids)
                for j, start_ix in enumerate(start_ixs):
                    attention_mask[j, start_ix:] = False

                MAX_NEW_TOKENS = 41
                outputs = model.generate_greedy(input_ids,
                                                max_new_tokens=MAX_NEW_TOKENS,
                                                attention_mask=attention_mask,
                                                use_cache=True)
                for b, start_ix in enumerate(start_ixs):
                    if len(outputs[b]) <= len(input_ids[b]): raise ValueError("output must be longer than input!")
                    true_answers = answers[b]

                    try:
                        eof_ix = outputs[b].tolist().index(tokenizer.eos_token_id) + 1
                    except:
                        eof_ix = len(outputs[b]) - 1
                    answer_chunk = outputs[b, start_ix:eof_ix]
                    decoded = tokenizer.decode(answer_chunk.tolist())

                    for city, number in zip(batch['cities'][b], batch['numbers'][b]):
                        total += 1
                        if str(city) in (decoded):
                            correct_cities += 1
                        if str(number) in (decoded):
                            correct_numbers += 1
                        if f"{city}={number}" in str(decoded):
                            correct_numbers_and_cities += 1
                    print("OUTPUT:", decoded, "| CORRECT ANSWER:", true_answers)
                print(f"val cities:{100*correct_cities/total:.1f}%, numbers:{100*correct_numbers/total:.1f}%, numbers_and_cities:{100*correct_numbers_and_cities/total:.1f}, time: {time.time()-t:.1f}")
                t = time.time()
        print()
        metrics = {f'val_acc_cities_{length}': correct_cities/total, f'val_acc_numbers_{length}': correct_numbers/total, f'val_acc_numbers_and_cities_{length}': correct_numbers_and_cities/total}
        print(f"fineweb needle {length} accuracy_cities: {correct_cities/total*100:.1f}, accuracy_numbers: {correct_numbers/total*100:.1f}, accuracy_both: {correct_numbers_and_cities/total*100:.1f}")
        do_gc()
    return metrics

def train_iter(train_loader):
    ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
    model.reset_kv_cache()
    model.train()
    train_accumulation_steps = int(8 / sfa_args.bs_mul)
    _train_loss = 0.
    for i in range(1, train_accumulation_steps+1):
        # forward pass
        try:
            batch = next(train_loader)
        except:
            raise Exception("train_loader exhausted")
        input_ids = batch['input_ids'].cuda()
        labels = batch['labels'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        seq_len = input_ids.size(-1)
        log2_seq_len = math.log(seq_len, 2)
        assert abs(log2_seq_len - int(log2_seq_len)) <= 0, f"sequence length must be a power of 2, got {seq_len}"
        with ctx:
            _, loss = model(input_ids, labels, return_logits=False, attention_mask=attention_mask, use_cache=False)
            train_loss = loss.detach()
            _train_loss += float(train_loss.item() / train_accumulation_steps)
        # backward pass
        loss.backward()
    for p in model.parameters():
        p.grad /= train_accumulation_steps

    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    model.zero_grad(set_to_none=True)
    return _train_loss

def train_model():
    acc_train_loss = 0.
    nbatches = 0
    training_time_ms = 0
    # start the clock
    torch.cuda.synchronize()
    t0 = time.time()
    all_metrics = []
    train_loader = None

    for step in range(args.num_iterations + 1):
        last_step = (step == args.num_iterations)
        if step == 10:
            training_time_ms = 0
            t0 = time.time()
        timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val

        if last_step or (step > 0 and step % args.val_loss_every == 0):
            print_max_alloc()
            val_metrics_4k = eval_c4_needle(val_needle_4k, "4k", niter=50)
            val_metrics_16k = eval_c4_needle(val_needle_16k, "16k", niter=50)
            val_metrics_64k = eval_c4_needle(val_needle_64k, "64k", niter=50)
            print_max_alloc()
            metrics = val_metrics_4k | val_metrics_16k | val_metrics_64k
            metrics['step'] = step
            if acc_train_loss > 0.:
                metrics[f'train_loss'] = float(acc_train_loss / nbatches) # put val/train on same scale
                nbatches = 0
                acc_train_loss = 0.
            print(metrics)
            all_metrics.append(metrics)
            if not sfa_args.testrun:
                wandb.log(metrics)
                torch.save(all_metrics, metrics_file)
        do_gc()

        if last_step:
            break

        ## acc train loss
        if train_loader is None: train_loader = iter(train_needle.get_dataloader())
        train_loss = train_iter(train_loader)
        acc_train_loss += train_loss
        nbatches += 1

        # everything that follows now is just diagnostics, prints, logging, etc.
        approx_time = training_time_ms + 1000 * (time.time() - t0)
        print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss:.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")


# Entry point: run training with periodic needle evaluation as soon as the script executes.
train_model()