#!/usr/bin/env python3
"""
miniGPT Training Script for multi-GPU (A100) setups
Modified from the JAX Stack tutorial; trains data-parallel over an 8x1 ('batch', 'model') device mesh.
"""

import jax
from jax import vmap
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
import flax.nnx as nnx
import optax
from dataclasses import dataclass
import grain.python as pygrain
import pandas as pd
import tiktoken
import time
import subprocess
import os
import orbax.checkpoint as orbax
import matplotlib.pyplot as plt
from functools import partial

from rtpt import RTPT
from aim import Run, Figure
import logging
logging.getLogger('aim.sdk.objects.figure').setLevel(logging.ERROR)
import numpy as np
import nvtx
from jax.random import PRNGKey
from typing import Optional, Tuple, Callable

from rope_attention import MultiHeadAttention, dot_product_attention as rope_dot_product_attention, apply_rope
from fma.pallas_retrieval import causal_attn as fma_causal_attn
#from flax.nnx.attention import MultiHeadAttention
from tokenizers import Tokenizer

# Verify GPU detection
print("JAX devices:", jax.devices())
print("Device type:", jax.devices()[0].device_kind)

# Download dataset if not exists
# NOTE: disabled by the leading `False and`. main() trains from the
# pre-tokenized data/<dataset>/<split>.bin files, not from this text file.
if False and not os.path.exists('TinyStories-train.txt'):
    print("Downloading TinyStories dataset...")
    subprocess.run([
        'wget', 
        'https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true',
        '-O', 'TinyStories-train.txt'
    ], check=True)
    print("Dataset downloaded successfully!")

# Device mesh: 8 devices on the data-parallel 'batch' axis and 1 on 'model'.
# create_device_mesh((8, 1)) expects 8 visible devices, so this does not run
# unchanged on a single GPU.
mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
# Make `mesh` the ambient mesh for sharding annotations used below.
jax.set_mesh(mesh)

# Initialize tokenizer
# Target vocabulary size for the culled GPT-2 encoding built below; the two
# highest ids (CULL_VOCAB_TO - 2 and - 1) are reserved for special tokens.
CULL_VOCAB_TO = 8192
def cull_mergeable_ranks(ranks, *, limit: Optional[int] = None):
    """Keep only the BPE merges whose rank leaves room for two special tokens.

    Args:
        ranks: mapping from token (presumably bytes, as in tiktoken's
            ``_mergeable_ranks``) to its integer rank.
        limit: target vocabulary size. Defaults to the module-level
            ``CULL_VOCAB_TO``. The two ids just below ``limit`` are left free.

    Returns:
        A new dict containing only the entries whose rank is < ``limit - 2``.
    """
    if limit is None:
        limit = CULL_VOCAB_TO
    max_rank = limit - 2  # two top ids are reserved for special tokens
    return {token: rank for token, rank in ranks.items() if rank < max_rank}
# GPT-2 BPE restricted to merges with rank < CULL_VOCAB_TO - 2, plus two
# special tokens placed on the two highest ids.
# NOTE(review): `tokenizer` is rebound to a HuggingFace `Tokenizer` loaded from
# data/<dataset>/tokenizer.json in the hyperparameter section below. This
# encoding therefore only feeds the first (immediately overwritten)
# `vocab_size`. tiktoken.get_encoding("gpt2") may need network access the
# first time it fetches the GPT-2 BPE files.
gpt2_tokenizer = tiktoken.get_encoding("gpt2")
tokenizer = tiktoken.Encoding(
        name = f"gpt2_{CULL_VOCAB_TO}",
        pat_str = gpt2_tokenizer._pat_str,
        mergeable_ranks = cull_mergeable_ranks(gpt2_tokenizer._mergeable_ranks),
        special_tokens = {
            "<|endoftext|>": CULL_VOCAB_TO - 1,
            "<|mask|>": CULL_VOCAB_TO - 2,
        },
    )

# Hyperparameters
# NOTE: this first value comes from the tiktoken encoding built above. Both
# `tokenizer` and `vocab_size` are rebound a few lines below to the
# HuggingFace tokenizer that matches the chosen dataset.
vocab_size = tokenizer.n_vocab
#dataset = "tinystories8192"
#dataset = "openwebtext8192"
#dataset = "gutenberg16384"
#dataset = "longstackv2_bucket_5"
dataset = "science_tech_2e15"
tokenizer = Tokenizer.from_file(f"data/{dataset}/tokenizer.json")
vocab_size = tokenizer.get_vocab_size()
# Width multiplier: embed_dim, head counts and feed_forward_dim scale with it.
SCALE = 2
base_scale_for_lr = 4.0
base_scale_for_residual = 4.0
FF_SCALE = 4
# NOTE(review): create_model() builds the model with a fixed
# num_transformer_blocks=4, so this value is not used to build the model.
# Confirm which depth is intended.
num_transformer_blocks = 4 * SCALE - 1
maxlen = 2**15  # training sequence length (tokens per example)
embed_dim = 256 * SCALE
num_heads = 4 * SCALE
global_num_heads = 4 * SCALE
head_dim = embed_dim // num_heads
global_head_dim = embed_dim // global_num_heads
feed_forward_dim = FF_SCALE * 256 * SCALE
# The exponent 0.0 makes this exactly 1.0, i.e. residual scaling is disabled.
residual_scale = (base_scale_for_residual / SCALE) ** 0.0
base_batch_size = 2
batch_size = 8  # Adjusted for A100
minibatch_size = 8
assert batch_size % minibatch_size == 0, "batch_size must be multiple of minibatch_size"
# main() currently requires this to be 1 (no gradient accumulation loop).
accumulation_steps = batch_size // minibatch_size
num_epochs = 4
#position_encodings = "absolute"
#position_encodings = "nope"
position_encodings = "rope"
attn_dtype_str = "bfloat16"
DTYPE_LOOKUP = {
    "float32": jnp.float32,
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
}
attn_dtype = DTYPE_LOOKUP[attn_dtype_str]
ff_dtype_str = "bfloat16"
ff_dtype = DTYPE_LOOKUP[ff_dtype_str]
final_dtype_str = "bfloat16"
final_dtype = DTYPE_LOOKUP[final_dtype_str]
generation_length = 128
hybrid_attention = True # Per 4 blocks: local, intermediate, local, global (see sliding_window_i)
sliding_window = 256 # Window of the local blocks; None makes them full causal attention
local_temperature = 1.0  # Used for local attention in hybrid attention mode
normalize_qk = False
use_prenorm_not_postnorm = True
use_embedding_tieing = True
#intermediate_window_size = 256
intermediate_temperature = 1.0  # Used for intermediate attention in hybrid attention mode
# TransformerBlock routes a window of exactly 2**14 to the FMA retrieval kernel.
global_window_size = 2**14  # Used for global attention in hybrid attention mode
# Intermediate blocks currently use the same window (and so the same kernel) as global ones.
intermediate_window_size = global_window_size
global_temperature = 1.0  # Used for global attention in hybrid attention mode
# main() further divides this by SCALE**0.5 and FF_SCALE**0.5 to get the peak LR.
base_adam_lr = 3e-3 * batch_size / base_batch_size * (base_scale_for_lr / SCALE)**0.5
final_adam_lr_decay = 1e-1
warmup_steps = 1000
total_train_tokens = 12_000_000_000  # Total tokens to train on
#total_train_tokens = 1_500_000_000  # Total tokens to train on
adam_b2 = 0.95 # Adam beta2 value, used for AdamW optimizer
adam_b1 = 0.9 # Adam beta1 value, used for AdamW optimizer
adam_weight_decay = 0.01 # AdamW weight decay value
# NOTE: main() builds an AdamW-only optimizer; the two Muon settings are only logged.
muon_lr = 1e-3 * batch_size / base_batch_size  # Learning rate for Muon optimizer
use_muon = False  # Use Muon optimizer for some layers
rope_base = 1e5
dump_qkv = False
generation_is_enabled = False
record_location_loss = True
experiment_set = "test2"
#experiment_label = f"CUDNN_con4k_win4k_ev2_3Btokens_prenorm_scale{SCALE}"
#experiment_label = f"CUDNN_con32k_win32k_ev2_{int(total_train_tokens//1_000_000_000)}Btokens_prenorm_scale{SCALE}"
experiment_label = f"cret_con32k_QK64_B4k_NR4_ev2_{int(total_train_tokens//1_000_000_000)}Btokens_prenorm_scale{SCALE}"

def causal_attention_mask(seq_len):
    """Return a ``(seq_len, seq_len)`` lower-triangular mask of ones.

    Entry ``[i, j]`` is 1 when ``j <= i`` (query ``i`` may see key ``j``)
    and 0 otherwise; the dtype is JAX's default float.
    """
    square = (seq_len, seq_len)
    return jnp.tril(jnp.ones(square), k=0)

def sharded_rope_local_cudnn_attention(q,k,v, **kwargs):
    """Sliding-window causal attention with RoPE, computed per mesh shard via cuDNN.

    ``q``, ``k`` and ``v`` are split on axis 0 over the mesh's 'batch' axis and
    on axis 2 over its 'model' axis (presumably (batch, seq, heads, head_dim)
    -- TODO confirm). Keys are restricted to a left window of the module-level
    ``sliding_window`` positions. Extra keyword arguments are accepted and ignored.
    """
    spec = P('batch', None, 'model', None)

    def _per_shard(q_blk, k_blk, v_blk):
        q_rot = apply_rope(q_blk, base=rope_base)
        k_rot = apply_rope(k_blk, base=rope_base)
        # Debug output; runs at trace time when called under jit.
        print(f"dtypes: q: {q_rot.dtype}, k: {k_rot.dtype}, v: {v_blk.dtype}")
        return jax.nn.dot_product_attention(
            q_rot, k_rot, v_blk,
            is_causal=True,
            implementation="cudnn",
            local_window_size=(sliding_window, 0),
        )

    sharded = jax.shard_map(_per_shard, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False)
    return sharded(q, k, v)

def sharded_rope_causal_cudnn_attention(q,k,v, **kwargs):
    """Full causal attention with RoPE, computed per mesh shard via cuDNN.

    Inputs are split on axis 0 over 'batch' and on axis 2 over 'model'
    (presumably (batch, seq, heads, head_dim) -- TODO confirm). No window
    limit is applied. Extra keyword arguments are accepted and ignored.
    """
    spec = P('batch', None, 'model', None)

    def _per_shard(q_blk, k_blk, v_blk):
        rotated_q = apply_rope(q_blk, base=rope_base)
        rotated_k = apply_rope(k_blk, base=rope_base)
        return jax.nn.dot_product_attention(
            rotated_q, rotated_k, v_blk,
            is_causal=True,
            implementation="cudnn",
            local_window_size=None,
        )

    sharded = jax.shard_map(_per_shard, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False)
    return sharded(q, k, v)

@vmap
@partial(vmap, in_axes=-2, out_axes=-2)
def fma_multihead_attention(q,k,v):
    """FMA retrieval attention mapped over the leading axis and the head axis.

    The outer ``vmap`` maps axis 0 and the inner one maps axis -2, so
    ``fma_causal_attn`` receives 2-D slices (for (batch, seq, heads, head_dim)
    inputs these are (seq, head_dim) -- TODO confirm the layout).
    Only the second kernel output is returned; the first (named ``lse_out``
    originally, presumably a log-sum-exp) is discarded.
    """
    # Kernel configuration, passed positionally to fma_causal_attn.
    # NOTE(review): meanings are inferred from the argument names only.
    query_chunk, key_chunk = 64, 64
    block = 2**12
    retrievals = 4
    use_bidiagonal = False
    use_dipole = False
    _lse, attended = fma_causal_attn(
        query_chunk, key_chunk, block, retrievals, use_bidiagonal, use_dipole, q, k, v
    )
    return attended

def sharded_rope_fma_attention(q,k,v, **kwargs):
    """Apply RoPE to q/k and run FMA retrieval attention on each mesh shard.

    Inputs are split on axis 0 over 'batch' and on axis 2 over 'model'.
    Extra keyword arguments are accepted and ignored.
    """
    spec = P('batch', None, 'model', None)

    def _per_shard(q_blk, k_blk, v_blk):
        return fma_multihead_attention(
            apply_rope(q_blk, base=rope_base),
            apply_rope(k_blk, base=rope_base),
            v_blk,
        )

    sharded = jax.shard_map(_per_shard, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False)
    return sharded(q, k, v)


def sharded_rope_dot_product_attention(q,k,v, **kwargs):
    """Run ``rope_dot_product_attention`` on each mesh shard, forwarding ``kwargs``.

    Inputs are split on axis 0 over 'batch' and on axis 2 over 'model'.

    Raises:
        ValueError: if ``q`` is not rank 4 (the shard specs index four axes).
    """
    # Explicit rank check instead of unpacking q.shape into unused names.
    if q.ndim != 4:
        raise ValueError(f"expected a rank-4 q, got shape {q.shape}")
    spec = P('batch', None, 'model', None)

    def inner(q,k,v):
        return rope_dot_product_attention(q,k,v, **kwargs)
    return jax.shard_map(inner, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False)(q,k,v)

class TransformerBlock(nnx.Module):
    """One transformer layer: multi-head self-attention followed by a ReLU MLP.

    Supports pre-norm (LayerNorm on each sublayer input) and post-norm
    (LayerNorm after each residual add). The attention kernel is chosen from
    ``sliding_window``:

    * ``None``  -> full causal cuDNN attention with RoPE
    * ``2**14`` -> FMA retrieval attention (``fma_causal_attn``) with RoPE
    * otherwise -> cuDNN causal attention limited to a left window of
      ``sliding_window`` positions, with RoPE

    Initializers carry ``nnx.with_partitioning`` specs that place some
    parameter axes (e.g. the MLP hidden dimension) on the mesh's 'model' axis.
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, sliding_window: int, *, temperature: float = 1.0, rngs: nnx.Rngs, rate: float = 0.1, residual_scale: float = 1.0, prenorm: bool = False):
        """Build the attention, LayerNorm and MLP sublayers.

        Args:
            embed_dim: model width (input and output feature size).
            num_heads: number of attention heads.
            ff_dim: hidden width of the MLP.
            sliding_window: attention window; selects the kernel (see class docstring).
            temperature: forwarded to MultiHeadAttention.
            rngs: random streams used to initialise parameters.
            rate: dropout rate for both sublayer outputs.
            residual_scale: multiplier applied to each sublayer output before the residual add.
            prenorm: use pre-norm instead of post-norm.
        """
        self.residual_scale = residual_scale
        self.prenorm = prenorm
        # RoPE is applied inside each attention_fn wrapper, which is why
        # MultiHeadAttention is built with apply_rope=False below.
        if sliding_window is None:
            attention_fn = sharded_rope_causal_cudnn_attention
        elif sliding_window == 2**14:
            # NOTE(review): magic number; equals the module-level
            # global_window_size today -- confirm they should stay in sync.
            attention_fn = sharded_rope_fma_attention
        else:
            attention_fn = sharded_rope_local_cudnn_attention
        # NOTE: submodules draw init keys from the shared `rngs` in
        # construction order; reordering them changes the initial weights.
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
            out_kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ('model', None)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
            out_bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs,
            apply_rope=False, #position_encodings == "rope",# and sliding_window is None,
            dtype=attn_dtype,
            sliding_window=sliding_window,
            normalize_qk=normalize_qk,
            rope_base=rope_base,
            temperature=temperature,
            attention_fn=attention_fn,
        )
        
        # NOTE(review): the Dropout layers are built without rngs and are only
        # active with training=True; loss_fn calls the model with the default
        # training=False, so dropout is currently never applied.
        self.dropout1 = nnx.Dropout(rate=rate)
        
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs
        )
        
        # MLP up-projection; its output (hidden) axis is sharded on 'model'.
        self.linear1 = nnx.Linear(
            in_features=embed_dim,
            out_features=ff_dim,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
            rngs=rngs,
            dtype=ff_dtype,
        )

        # MLP down-projection; its input (hidden) axis is sharded on 'model'.
        self.linear2 = nnx.Linear(
            in_features=ff_dim,
            out_features=embed_dim,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ('model', None)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs,
            dtype=ff_dtype,
        )
        
        self.dropout2 = nnx.Dropout(rate=rate)
        
        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs
        )

    def __call__(self, inputs, training: bool = False, dump_qkv: bool = False):
        """Apply the block to ``inputs``.

        Args:
            inputs: rank-3 activations whose middle axis is read as the
                sequence length (presumably (batch, seq, embed_dim)).
            training: enables dropout when True.
            dump_qkv: if True, return whatever ``self.mha(..., dump_qkv=True)``
                returns immediately, skipping the rest of the block
                (save_qkv treats it as a (q, k, v) triple).

        Returns:
            Activations with the same shape as ``inputs``, or the attention
            dump when ``dump_qkv`` is True.
        """
        input_shape = inputs.shape
        # Unpacking also enforces a rank-3 input; seq_len is otherwise unused.
        _, seq_len, _ = input_shape
        #mask = causal_attention_mask(seq_len)
        # Causality is enforced inside attention_fn (causal kernels), so no mask.
        mask = None
        
        if self.prenorm:
            x1 = self.layer_norm1(inputs)
        else:
            x1 = inputs
        attention_output = self.mha(inputs_q=x1, mask=mask, decode=False, dump_qkv=dump_qkv)
        if dump_qkv:
            return attention_output
        attention_output = self.dropout1(attention_output, deterministic=not training)
        # Residual add: pre-norm leaves the stream un-normalized, post-norm normalizes it.
        if self.prenorm:
            out1 = inputs + attention_output*self.residual_scale
        else:
            out1 = self.layer_norm1(inputs + attention_output*self.residual_scale)
        
        if self.prenorm:
            x2 = self.layer_norm2(out1)
        else:
            x2 = out1
        ffn_output = self.linear1(x2)
        ffn_output = nnx.relu(ffn_output)
        ffn_output = self.linear2(ffn_output)
        ffn_output = self.dropout2(ffn_output, deterministic=not training)

        if self.prenorm:
            out2 = out1 + ffn_output*self.residual_scale
        else:
            out2 = self.layer_norm2(out1 + ffn_output*self.residual_scale)
        return out2

class TokenAndPositionEmbedding(nnx.Module):
    """Token embedding, plus a learned absolute position embedding when enabled.

    The position table exists only when the module-level
    ``position_encodings == "absolute"``; otherwise (e.g. RoPE) the token
    embedding is returned unchanged.
    """
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        # Keep construction order: the RNG stream is consumed token table first.
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        if position_encodings == "absolute":
            self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)

    def __call__(self, x):
        embedded = self.token_emb(x)
        if position_encodings != "absolute":
            return embedded
        # Positions 0..seq_len-1 (axis 1 of x), broadcast over the batch axis.
        seq_positions = jnp.arange(x.shape[1])[None, :]
        return embedded + self.pos_emb(seq_positions)

def sliding_window_i(i):
    """Return ``(window, temperature, num_heads)`` for transformer block ``i``.

    With ``hybrid_attention`` enabled, blocks cycle with period 4:
    ``i % 4 == 3`` -> global settings, ``i % 4 == 1`` -> intermediate
    settings, anything else -> local settings. Without hybrid attention every
    block uses the local settings.
    """
    if hybrid_attention:
        phase = i % 4
        if phase == 3:
            return global_window_size, global_temperature, global_num_heads
        if phase == 1:
            return intermediate_window_size, intermediate_temperature, num_heads
    return sliding_window, local_temperature, num_heads

class MiniGPT(nnx.Module):
    """Decoder-only transformer language model.

    The model consists of:

    * a token embedding (plus learned absolute positions when
      ``position_encodings == "absolute"``);
    * a stack of ``TransformerBlock``s whose window, temperature and head
      count come from ``sliding_window_i(i)``;
    * a final LayerNorm (pre-norm only);
    * an output projection to vocabulary logits: either a separate
      ``nnx.Linear`` or, with ``tie_weights``, the transposed token embedding.
    """
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, 
                 feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs, *, residual_scale: float = 1.0, prenorm: bool = False, tie_weights=False):
        """Build the model.

        Args:
            maxlen: size of the absolute position table (if used).
            vocab_size: number of token ids.
            embed_dim: model width.
            num_heads: unused; per-block head counts come from sliding_window_i.
            feed_forward_dim: MLP hidden width of each block.
            num_transformer_blocks: number of blocks.
            rngs: random streams for initialisation (consumed in construction order).
            residual_scale: residual branch multiplier passed to every block.
            prenorm: pre-norm blocks plus a final LayerNorm when True.
            tie_weights: reuse the token embedding as the output projection.
        """
        
        self.prenorm = prenorm
        self.tie_weights = tie_weights
        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)
        
        self.transformer_blocks = nnx.List([
            TransformerBlock(embed_dim, sliding_window_i(i)[2], feed_forward_dim, sliding_window_i(i)[0], rngs=rngs, temperature=sliding_window_i(i)[1], residual_scale=residual_scale, prenorm=prenorm)
            for i in range(num_transformer_blocks)
        ])


        if self.prenorm:
            self.output_layer_norm = nnx.LayerNorm(
                epsilon=1e-6,
                num_features=embed_dim,
                scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)),
                bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
                rngs=rngs
            )


        if not tie_weights:
            self.output_layer = nnx.Linear(
                in_features=embed_dim,
                out_features=vocab_size,
                kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None)),
                bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
                rngs=rngs,
                dtype=final_dtype,
            )

    def __call__(self, inputs, training: bool = False, dump_qkv_at: Optional[int] = None):
        """Return vocabulary logits for ``inputs`` (token ids).

        If ``dump_qkv_at`` equals a block index, that block's attention dump
        is returned instead, and later blocks are not run.
        """
        x = self.embedding_layer(inputs)
        for i, transformer_block in enumerate(self.transformer_blocks):
            if dump_qkv_at is not None and dump_qkv_at == i:
                qkv = transformer_block(x, training=training, dump_qkv=True)
                return qkv
            x = transformer_block(x, training=training)
        if self.prenorm:
            x = self.output_layer_norm(x)
        if self.tie_weights:
            outputs = x @ self.embedding_layer.token_emb.embedding.T
        else:
            outputs = self.output_layer(x)
        return outputs

    def generate_text(self, max_tokens: int, start_tokens: list, top_k=10, seed: int = 0):
        """Sample up to ``max_tokens`` tokens after ``start_tokens`` and decode.

        Each step runs the full model on a ``maxlen`` window (right-padded
        with id 0, or the most recent ``maxlen`` tokens once the context is
        longer). It then samples from the top-``top_k`` softmax at the last
        real position. Sampling stops early at ``<|endoftext|>``, which is not
        included in the output.

        Args:
            max_tokens: maximum number of new tokens.
            start_tokens: prompt token ids.
            top_k: number of highest-scoring candidates to sample from.
            seed: PRNG seed; a fresh subkey is split off for every step.

        Returns:
            The decoded prompt plus generated text.
        """
        def sample_from(logits, key):
            logits, indices = jax.lax.top_k(logits, k=top_k)
            logits = nnx.softmax(logits)
            return jax.random.choice(key, indices, p=logits)

        def generate_step(start_tokens, key):
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1
            
            if pad_len < 0:
                # Context overflow: condition on the most recent maxlen tokens
                # (taking the first maxlen would freeze the context forever).
                x = jnp.array(start_tokens[-maxlen:])
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = jnp.array(start_tokens + [0] * pad_len)
            else:
                x = jnp.array(start_tokens)
            
            x = x[None, :]
            logits = self(x)
            return sample_from(logits[0][sample_index], key)

        eos_id = tokenizer.token_to_id("<|endoftext|>")
        key = jax.random.PRNGKey(seed)
        generated = []
        for _ in range(max_tokens):
            # Independent randomness per step (reusing one key would correlate samples).
            key, step_key = jax.random.split(key)
            next_token = int(generate_step(start_tokens + generated, step_key))
            if next_token == eos_id:
                break
            generated.append(next_token)
        
        return tokenizer.decode(start_tokens + generated)

def create_model(rngs):
    """Build a MiniGPT from the module-level hyperparameters.

    NOTE(review): the depth is fixed at 4 blocks here. The module-level
    ``num_transformer_blocks`` (``4 * SCALE - 1``) is not used; confirm
    which depth is intended.
    """
    return MiniGPT(
        maxlen=maxlen,
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        feed_forward_dim=feed_forward_dim,
        num_transformer_blocks=4,
        rngs=rngs,
        residual_scale=residual_scale,
        prenorm=use_prenorm_not_postnorm,
        tie_weights=use_embedding_tieing,
    )

@dataclass
class TextDataset:
    """Random-access dataset of raw text documents, tokenized on demand.

    Each item is the token-id list of one document, truncated to ``maxlen``
    and right-padded with id 0 to exactly ``maxlen`` entries.
    """
    data: list
    maxlen: int
    # Optional text -> token ids function. None uses the module-level
    # HuggingFace `tokenizer`.
    encode_fn: Optional[Callable[[str], list]] = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        text = self.data[idx]
        if self.encode_fn is not None:
            ids = list(self.encode_fn(text))
        else:
            # The module `tokenizer` is a `tokenizers.Tokenizer`, whose encode()
            # returns an Encoding (ids on `.ids`) and takes no `allowed_special`.
            # NOTE(review): "<|endoftext|>" is only treated as a single token if
            # tokenizer.json registers it as a special token -- confirm.
            ids = tokenizer.encode(text).ids
        encoding = ids[:self.maxlen]
        return encoding + [0] * (self.maxlen - len(encoding))

def load_and_preprocess_data(file_path, batch_size, maxlen):
    """Build a shuffled grain DataLoader over ``<|endoftext|>``-separated documents.

    Args:
        file_path: UTF-8 text file with documents separated by ``<|endoftext|>``.
        batch_size: documents per batch; a trailing partial batch is dropped.
        maxlen: tokens per document (truncated / zero-padded).

    Returns:
        A ``pygrain.DataLoader`` over ``TextDataset`` items, iterating
        ``num_epochs`` (module-level) times with seed 42.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Drop whitespace-only fragments and re-attach the separator so every
    # document ends with <|endoftext|>. (All entries are str, so the former
    # pandas dropna() pass was a no-op.)
    stories = [
        story + '<|endoftext|>'
        for story in text.split('<|endoftext|>')
        if story.strip()
    ]
    dataset = TextDataset(stories, maxlen)
    
    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=True,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )
    
    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )
    
    return dl

@dataclass
class MemMapDataset:
    """Fixed-length windows over a flat token file ``data/<name>/<split>.bin``.

    The file is memory-mapped read-only as uint16 token ids. Item ``idx`` is
    the ``maxlen + 1`` tokens starting at ``idx * maxlen``, i.e. the inputs
    plus one extra token so targets can be the inputs shifted by one.
    """
    split: str
    name: str
    maxlen: int

    def __post_init__(self):
        data_dir = os.path.join('data', self.name, f"{self.split}.bin")
        self._data = np.memmap(data_dir, dtype=np.uint16, mode='r')

    def get_data(self):
        return self._data

    def __len__(self):
        # Reserve one token so the last window still has a target for its
        # final position. Note the parentheses: subtracting 1 from the array
        # itself would copy the whole memmap and leave the length unchanged.
        return max(0, (len(self.get_data()) - 1) // self.maxlen)

    def __getitem__(self, idx: int):
        data = self.get_data()
        start = idx * self.maxlen
        end = start + self.maxlen + 1
        return np.array(data[start:end], dtype=np.int32)

    def batches_iter(self, batch_size: int, shuffle_key: Optional[PRNGKey]=None):
        """Yield ``(inputs, targets)`` uint16 arrays of shape ``(batch_size, maxlen)``.

        Windows are visited in order, or in a permutation drawn from
        ``shuffle_key``. Only full batches are yielded; a trailing partial
        batch is dropped.
        """
        N = len(self)
        # Host-side int64 indices: int32 jax indices overflow once
        # idx * maxlen exceeds 2**31 on large token files.
        if shuffle_key is not None:
            indices = np.asarray(jax.random.permutation(shuffle_key, N), dtype=np.int64)
        else:
            indices = np.arange(N, dtype=np.int64)
        data = self.get_data()
        seqidx = np.arange(self.maxlen, dtype=np.int64)
        for i in range(0, N - batch_size + 1, batch_size):
            batch_indices = indices[i:i + batch_size]
            combidx = batch_indices[:, None] * self.maxlen + seqidx[None, :]
            inputs = data[combidx]
            targets = data[combidx + 1]
            assert inputs.shape == (batch_size, self.maxlen), f"Expected input shape {(batch_size, self.maxlen)}, got {inputs.shape}"
            assert targets.shape == (batch_size, self.maxlen), f"Expected target shape {(batch_size, self.maxlen)}, got {targets.shape}"
            yield np.array(inputs), np.array(targets)

def load_memmap_data(name, split, batch_size, maxlen):
    """Build an unshuffled grain DataLoader over ``MemMapDataset`` windows.

    Each batch stacks ``batch_size`` windows of ``maxlen + 1`` tokens. A
    trailing partial batch is dropped, and the sampler repeats the data
    ``num_epochs`` (module-level) times.
    """
    source = MemMapDataset(split=split, name=name, maxlen=maxlen)
    index_sampler = pygrain.IndexSampler(
        len(source),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )
    batching = [pygrain.Batch(batch_size=batch_size, drop_remainder=True)]
    return pygrain.DataLoader(data_source=source, sampler=index_sampler, operations=batching)

    

def loss_fn(model, batch):
    """Mean next-token cross-entropy plus per-position loss statistics.

    Args:
        model: callable mapping token ids to logits.
        batch: ``(inputs, labels)``; labels are rank 2 (batch, seq).

    Returns:
        ``(loss, (masked_loss_sum, masked_count))``. ``loss`` is the mean over
        all tokens. The two aux vectors have length seq: per position, the
        loss summed over rows and the number of rows, counting only positions
        strictly before the first ``<|endoftext|>`` label in each row.
    """
    inputs, labels = batch[0], batch[1]
    logits = model(inputs)
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    _rows, _cols = per_token_loss.shape  # enforces a rank-2 loss
    eos_id = tokenizer.token_to_id("<|endoftext|>")
    # True until (and excluding) the first EOS label in each row.
    before_first_eos = jnp.cumsum(labels == eos_id, axis=1) == 0
    masked_sum = jnp.where(before_first_eos, per_token_loss, 0.).sum(axis=0)
    masked_count = before_first_eos.sum(axis=0)
    return per_token_loss.mean(), (masked_sum, masked_count)

def segexpmeanlog(values, segment_ids, num_segments):
    """Geometric mean of ``values`` within each segment.

    Computes ``exp(mean(log(values)))`` per segment id. Segments with no
    members come out as NaN (0 / 0).
    """
    log_sum = jax.ops.segment_sum(jnp.log(values), segment_ids, num_segments)
    member_count = jax.ops.segment_sum(jnp.ones_like(values), segment_ids, num_segments)
    mean_log = log_sum / member_count
    return jnp.exp(mean_log)

def fit_irreducible(x, y):
    """Fit ``y ~ a * x**b + c`` and return the asymptotic offset ``c``.

    Uses ``scipy.optimize.curve_fit`` with initial guess (1, -0.5, 0) and up
    to 10000 function evaluations. Fitting errors propagate to the caller.
    """
    from scipy.optimize import curve_fit

    def model(x, a, b, c):
        return a * (x ** b) + c

    popt, _pcov = curve_fit(model, x, y, p0=[1.0, -0.5, 0.0], maxfev=10000)
    _scale, _exponent, offset = popt
    return offset

def process_loc_loss(loc_loss, *, num_buckets=1):
    """Bucket per-position loss on a log2 position scale and remove a loss floor.

    Position ``p`` (1-based) goes to bucket ``ceil(num_buckets * log2(p))``,
    so ``num_buckets`` buckets cover each doubling of position. Within a
    bucket, positions and losses are combined by geometric mean. Buckets whose
    mean is non-finite (empty, or containing a NaN position loss) are dropped.

    An "irreducible" floor is estimated by fitting ``a * p**b + c`` to the
    middle third of the remaining buckets. It falls back to
    ``min(bucket loss) - 0.01`` if the fit fails, and is capped at that value
    so every reducible loss stays positive.

    Args:
        loc_loss: 1-D per-position mean loss (any length).
        num_buckets: buckets per doubling of position.

    Returns:
        ``(bucket_positions, reducible_bucket_loss, irreducible_loss)``.
    """
    T, = loc_loss.shape
    positions = jnp.arange(T)+1
    seg_labels = jnp.ceil(num_buckets*jnp.log2(positions)).astype(jnp.int32)
    num_segs = jnp.max(seg_labels) + 1
    loc_loss_buckets = segexpmeanlog(loc_loss, seg_labels, num_segs)
    positions_buckets = segexpmeanlog(positions, seg_labels, num_segs)
    non_nan = jnp.isfinite(loc_loss_buckets)
    safe_pos_buck = positions_buckets[non_nan]
    safe_loc_loss_buck = loc_loss_buckets[non_nan]

    # Middle third of the buckets: [N//3, N - N//3). Written out explicitly
    # because `-N//3` parses as `(-N)//3` and trims one extra bucket.
    N = len(safe_pos_buck)
    lo, hi = N // 3, N - N // 3
    try:
        irreducible_loss = fit_irreducible(np.array(safe_pos_buck[lo:hi]), np.array(safe_loc_loss_buck[lo:hi]))
    except Exception as e:
        # Best-effort diagnostic: a failed fit must not abort training.
        print(f"Warning: fit_irreducible failed with error {e}, using min loc_loss as irreducible_loss")
        irreducible_loss = jnp.min(safe_loc_loss_buck) - 0.01
    irreducible_loss = jnp.minimum(irreducible_loss, jnp.min(safe_loc_loss_buck) - 0.01)
    safe_loc_loss_buck = safe_loc_loss_buck - irreducible_loss
    return safe_pos_buck, safe_loc_loss_buck, irreducible_loss

def _plot_loc_loss(run, loc_loss, step, *, num_buckets, name):
    """Plot one log-log reducible-loss-vs-position curve and track it in Aim under ``name``."""
    fig, ax = plt.subplots()
    try:
        x, y, irred = process_loc_loss(loc_loss, num_buckets=num_buckets)
        ax.plot(x, y)
        ax.set_title(f"Location-wise Reducible Loss (irred: {irred:.3f})")
        ax.set_xlabel('Token Position')
        ax.set_ylabel('Loss')
        ax.set_yscale('log')
        ax.set_xscale('log')
        run.track(Figure(fig), name=name, step=step)
    finally:
        # Always release the figure, even if processing or tracking fails.
        plt.close(fig)


def record_loc_loss(run, loc_loss, step):
    """Track coarse (1 bucket/octave) and fine (4 buckets/octave) loss-by-position plots.

    Args:
        run: Aim run to track the figures on.
        loc_loss: 1-D per-position mean loss.
        step: Aim step to record the figures at.
    """
    _plot_loc_loss(run, loc_loss, step, num_buckets=1, name="loss_by_position")
    _plot_loc_loss(run, loc_loss, step, num_buckets=4, name="loss_by_position_fine")


@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Run one jitted optimizer step on ``batch`` and accumulate training metrics.

    ``batch`` is ``(inputs, labels)`` as consumed by ``loss_fn``. The mean
    loss and the per-position loss sums and counts are fed to ``metrics``.
    ``model`` and ``optimizer`` are updated in place.
    """
    value_and_grads = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, aux), grads = value_and_grads(model, batch)
    loc_loss, loc_count = aux
    metrics.update(loss=loss, loc_loss=loc_loss, loc_count=loc_count)
    optimizer.update(model, grads)

@nnx.jit
def get_qkv_at_3(model: MiniGPT, batch):
    """Jitted forward pass that returns block 3's attention dump for ``batch``."""
    dumped = model(batch, dump_qkv_at=3)
    return dumped

def save_qkv(model: MiniGPT, batch, step):
    """Save block 3's q, k, v (as float32) to ``qkv_step_<step>.npz``."""
    out_path = f"qkv_step_{step}.npz"
    dumped = get_qkv_at_3(model, batch)
    q, k, v = (np.array(dumped[i], dtype=np.float32) for i in range(3))
    np.savez(out_path, q=q, k=k, v=v)
    print(f"QKV saved to {out_path}")

class VectorAverage(nnx.Metric):
    """Element-wise running average of a vector metric with its own counts.

    ``update`` reads the sum vector from ``kwargs[argname]`` and the count
    vector from ``kwargs[cntname]``. ``compute`` returns total / count, which
    is NaN wherever the count is still 0.
    """
    def __init__(self, argname, cntname, shape):
        self.argname = argname
        self.cntname = cntname
        self.total = nnx.Variable(jnp.zeros(shape))
        self.count = nnx.Variable(jnp.zeros(shape, dtype=jnp.int32))

    def update(self, **kwargs):
        # Look up both inputs before mutating, so a missing key leaves state untouched.
        value, count = kwargs[self.argname], kwargs[self.cntname]
        self.total[...] = self.total[...] + value
        self.count[...] = self.count[...] + count

    def compute(self):
        return self.total[...] / self.count[...]

    def reset(self):
        self.total[...] = jnp.zeros_like(self.total[...])
        self.count[...] = jnp.zeros_like(self.count[...])

def main():
    """Train MiniGPT on the memory-mapped ``dataset`` and log progress to Aim.

    All configuration comes from the module-level hyperparameters. Metrics are
    printed and tracked at the steps selected by ``should_eval_hfreq``.
    Training returns, after closing the Aim run, at the first such step at
    which ``total_train_tokens`` tokens have been consumed, or when the data
    loader is exhausted.
    """
    assert accumulation_steps == 1, "need to implement grad accum as an inner loop for consistent behaviour"
    print("Starting miniGPT training...")

    # Load data
    #text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)
    text_dl = load_memmap_data(name=dataset, split="train", batch_size=batch_size, maxlen=maxlen)
    # Same name/split/maxlen as the loader's data source; used for sizes below.
    text_dataset = MemMapDataset(split="train", name=dataset, maxlen=maxlen)
    start_prompt = "Once upon a time" if dataset == "tinystories8192" else "The main reason"

    dataset_num_sequences = len(text_dataset)
    dataset_num_tokens = len(text_dataset.get_data())
    print(f"Dataset '{dataset}' has {dataset_num_sequences} sequences, {dataset_num_tokens} tokens.")

    # Use our own dataset handle rather than the loader's private `_data_source`.
    batches_per_epoch = dataset_num_sequences // batch_size
    tokens_per_batch = batch_size * maxlen
    
    # Initialize RTPT
    rtpt = RTPT(
        name_initials="XX",
        experiment_name="miniGPT",
        max_iterations= batches_per_epoch * num_epochs,
    )

    # Initialize Aim run
    run = Run(repo="/aim", experiment="miniGPT")
    hparams = {
        "vocab_size": vocab_size,
        "num_transformer_blocks": num_transformer_blocks,
        "maxlen": maxlen,
        "embed_dim": embed_dim,
        "num_heads": num_heads,
        "global_num_heads": global_num_heads,
        "head_dim": head_dim,
        "global_head_dim": global_head_dim,
        "feed_forward_dim": feed_forward_dim,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "dataset": dataset,
        "position_encodings": position_encodings,
        "attn_dtype": attn_dtype_str,
        "ff_dtype": ff_dtype_str,
        "final_dtype": final_dtype_str,
        "hybrid_attention": hybrid_attention,
        "sliding_window": sliding_window,
        "normalize_qk": normalize_qk,
        "intermediate_window_size": intermediate_window_size,
        "global_window_size": global_window_size,
        "base_adam_lr": base_adam_lr,
        "generation_length": generation_length,
        "total_train_tokens": total_train_tokens,
        "warmup_steps": warmup_steps,
        "adam_b1": adam_b1,
        "adam_b2": adam_b2,
        "adam_weight_decay": adam_weight_decay,
        "final_adam_lr_decay": final_adam_lr_decay,
        "muon_lr": muon_lr,
        "use_muon": use_muon,
        "rope_base": rope_base,
        "local_temperature": local_temperature,
        "intermediate_temperature": intermediate_temperature,
        "global_temperature": global_temperature,
        "accumulation_steps": accumulation_steps,
        "dump_qkv": dump_qkv,
        "experiment_set": experiment_set,
        "experiment_label": experiment_label,
        "generation_is_enabled": generation_is_enabled,
        "dataset_total_size_Mtokens": dataset_num_tokens//1_000_000,
        "residual_scale": residual_scale,
        "use_prenorm_not_postnorm": use_prenorm_not_postnorm,
    }
    run["hparams"] = hparams
    
    # Create model
    #with jax.set_mesh(mesh):
    model = create_model(rngs=nnx.Rngs(0))
    total_params = sum(x.size for x in jax.tree.leaves(nnx.state(model, nnx.Param)))
    total_Mparams = total_params / 1_000_000
    print(f"Total model parameters: {total_Mparams:.1f}M")
    # Update the local dict and re-assign it as a whole. The value read back
    # via run["hparams"] may be a detached copy, so mutating it in place is
    # not guaranteed to reach the Aim run.
    hparams["total_Mparams"] = total_Mparams
    # Log the depth actually built; create_model() does not necessarily use
    # the module-level num_transformer_blocks.
    hparams["num_transformer_blocks"] = sum(1 for _ in model.transformer_blocks)
    run["hparams"] = hparams

    adam_lr = base_adam_lr / SCALE ** 0.5 / FF_SCALE ** 0.5
    adam_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=adam_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_train_tokens // tokens_per_batch,
        end_value=adam_lr * final_adam_lr_decay,
    )

    adamw = optax.adamw(adam_schedule, b1=adam_b1, b2=adam_b2, eps=1e-6, weight_decay=adam_weight_decay)
    optax_optimizer = optax.chain(optax.clip_by_global_norm(1.0), adamw)
    optax_optimizer = optax.MultiSteps(optax_optimizer, every_k_schedule=lambda step: jnp.where(step < 0, 1, accumulation_steps))
    optimizer = nnx.Optimizer(model, optax_optimizer, wrt=nnx.Param)
    metrics_kwargs = dict(loss=nnx.metrics.Average('loss'))
    if record_location_loss:
        metrics_kwargs |= dict(loc_loss=VectorAverage('loc_loss', 'loc_count', shape=(maxlen,)))
    metrics = nnx.MultiMetric(**metrics_kwargs)
    
    # Initial text generation
    if generation_is_enabled:
        start_tokens = tokenizer.encode(start_prompt).ids[:maxlen]
        generated_text = model.generate_text(min(generation_length, maxlen), start_tokens)
        print(f"Initial generated text:\n{generated_text}\n")
    
    # Training setup
    metrics_history = {'train_loss': []}
    step = 0

    def octave_offset_and_width(step):
        # Split step into (offset within its power-of-two octave, octave width).
        step = int(step)
        msb = step.bit_length() - 1
        width = 1 << msb
        offset = step - (1 << msb)
        return offset, width
    def should_eval(step, per_octave_count, min_stride=1):
        # Roughly `per_octave_count` evaluations per doubling of the step count,
        # never closer together than `min_stride` steps.
        step = int(step)
        if step <= 0:
            return False
        step = step + 1
        offset, width = octave_offset_and_width(step)
        stride = max(min_stride, width // int(per_octave_count))
        return offset % stride == 0 and step >= min_stride

    def should_eval_hfreq(step):
        return should_eval(step, 8, min_stride=32)
    def should_eval_lfreq(step):
        return should_eval(step, 8, min_stride=2048)
    
    # Training loop
    rtpt.start()
    # NOTE(review): text_dl's IndexSampler already repeats the data num_epochs
    # times, so this outer loop makes num_epochs**2 passes in total (training
    # normally stops earlier on the token budget). Confirm which is intended.
    for _epoch in range(num_epochs):
        start_time = time.time()
        start_step = step
        #for batch in text_dataset.batches_iter(batch_size):
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
                continue
                
            # Each loaded row holds maxlen + 1 tokens: inputs and next-token targets.
            input_batch = jnp.array(batch[:,:-1])
            target_batch = jnp.array(batch[:,1:])
            assert input_batch.shape == (batch_size, maxlen), \
                f"Expected input shape {(batch_size, maxlen)}, got {input_batch.shape}"
            
            with nvtx.annotate("train_step", color="blue"):
                train_step(model, optimizer, metrics, 
                      jax.device_put((input_batch, target_batch), 
                                   NamedSharding(mesh, P('batch', None))))

            # Profile a fixed window of steps: push the range once at 600 and
            # pop it exactly once at 699 (not on every later step).
            if step == 600:
                nvtx.push_range("training")
            if step == 699:
                nvtx.pop_range()
            
            if should_eval_hfreq(step):
                computed = metrics.compute()
                # The per-position vector is plotted separately, not stored in history.
                loc_loss = computed.pop('loc_loss', None)
                for metric, value in computed.items():
                    metrics_history[f'train_{metric}'].append(value)
                metrics.reset()
                
                old_start_time = start_time
                start_time = time.time()
                elapsed_time = start_time - old_start_time
                elapsed_steps = step - start_step
                time_per_step = elapsed_time / max(elapsed_steps, 1)
                mega_tokens_per_second = tokens_per_batch / time_per_step / 1.0e6
                train_loss = metrics_history['train_loss'][-1]
                print(f"[{step + 1} ({elapsed_time:.2f}s)] "
                      f"Loss: {train_loss:.4f}, "
                      f"{1e3*time_per_step:.1f}ms/step, "
                      f"{mega_tokens_per_second:.2f}Mtok/s")
                # step + 1 batches have been trained at this point.
                total_Mtokens_so_far = ((step + 1) * tokens_per_batch) // 1_000_000
                if loc_loss is not None:
                    record_loc_loss(run, loc_loss, step=total_Mtokens_so_far)
                run.track(float(metrics_history['train_loss'][-1]), name='train_loss', step=total_Mtokens_so_far)
                run.track((step + 1) * tokens_per_batch, name='train_tokens', step=total_Mtokens_so_far)
                run.track(step, name='train_steps', step=total_Mtokens_so_far)
                current_lr = adam_schedule(optimizer.step[...])
                run.track(float(current_lr), name='learning_rate', step=total_Mtokens_so_far)
                start_step = step
                
                if should_eval_lfreq(step) and generation_is_enabled:
                    generated_text = model.generate_text(min(generation_length, maxlen), start_tokens)
                    elapsed_time = time.time() - start_time
                    print(f"Generated text ({elapsed_time:.2f}s):\n{generated_text}\n")
                    start_time = time.time()
            if False and step >= 1000:
                run.close()
                return
            if should_eval_hfreq(step) and (step+1) * tokens_per_batch >= total_train_tokens:
                run.close()
                return
            if dump_qkv and step % 200 == 0:
                save_qkv(model, input_batch, step)
            
            step += 1
            rtpt.step()
    run.close()
    return

    # NOTE: unreachable -- main() always returns above. Kept for reference;
    # `start_tokens` is only defined when generation_is_enabled is True.
    # Final generation
    generated_text = model.generate_text(min(generation_length, maxlen), start_tokens)
    print(f"Final generated text:\n{generated_text}")

    # Close aim run
    run.close()
    
    # Plot training loss
    plt.figure(figsize=(10, 6))
    plt.plot(metrics_history['train_loss'])
    plt.title('Training Loss')
    plt.xlabel('Step (x200)')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.savefig('training_loss.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("Training loss plot saved as 'training_loss.png'")
    
    # Save model checkpoint
    state = nnx.state(model)
    checkpointer = orbax.PyTreeCheckpointer()
    checkpoint_dir = './minigpt_checkpoint'
    checkpointer.save(checkpoint_dir, state)
    print(f"Model checkpoint saved to {checkpoint_dir}")

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
