#!/usr/bin/env python3
"""
miniGPT training script (JAX + Flax NNX).
Adapted from the JAX Stack miniGPT tutorial. Runs on a ('batch', 'model') device mesh
tuned for 8 GPUs (2 x 4); other device counts trigger a warning and a reshaped mesh.
"""

import jax
from jax import vmap
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
import flax.nnx as nnx
import optax
from dataclasses import dataclass
import grain.python as pygrain
import pandas as pd
import tiktoken
import time
import subprocess
import os
import orbax.checkpoint as orbax
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm
import math

from rtpt import RTPT
from aim import Run, Figure
import logging
logging.getLogger('aim.sdk.objects.figure').setLevel(logging.ERROR)
import numpy as np
import nvtx
from jax.random import PRNGKey
from typing import Optional, Tuple, Callable

from rope_attention import MultiHeadAttention, dot_product_attention as rope_dot_product_attention, apply_rope
from fma.pallas_retrieval import causal_attn as fma_causal_attn
#from flax.nnx.attention import MultiHeadAttention
from tokenizers import Tokenizer

import orbax.checkpoint as ocp

# Report the accelerators JAX can see, so a misconfigured run is obvious at startup.
_visible_devices = jax.devices()
print("JAX devices:", _visible_devices)
print("Device type:", _visible_devices[0].device_kind)

# Optional one-off download of the raw TinyStories text. Disabled: the training
# path in main() reads pre-tokenized memmap data instead.
_DOWNLOAD_TINYSTORIES = False
_TINYSTORIES_FILE = 'TinyStories-train.txt'
if _DOWNLOAD_TINYSTORIES and not os.path.exists(_TINYSTORIES_FILE):
    print("Downloading TinyStories dataset...")
    _tinystories_url = 'https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true'
    subprocess.run(['wget', _tinystories_url, '-O', _TINYSTORIES_FILE], check=True)
    print("Dataset downloaded successfully!")

# Device mesh with a 'batch' axis (data parallelism) and a 'model' axis
# (tensor / sequence sharding). Tuned for 8 devices laid out as 2 x 4; any other
# device count gets a warning and a mesh that still covers every device.
DEFAULT_BATCH_MESH = 2
DEFAULT_MODEL_MESH = 4
if (num_devices := len(jax.devices())) != DEFAULT_BATCH_MESH * DEFAULT_MODEL_MESH:
    print(f"WARNING: Unexpected device count, expected 8, got {num_devices}, training will be slower")
    # gcd(DEFAULT_BATCH_MESH, num_devices) always divides the device count, so
    # batch x model == num_devices (1 device -> 1x1, 3 -> 1x3, 4 -> 2x2, 6 -> 2x3)
    # and create_device_mesh never sees a shape that mismatches the device list.
    adjusted_batch_mesh = math.gcd(DEFAULT_BATCH_MESH, num_devices)
    adjusted_model_mesh = num_devices // adjusted_batch_mesh
else:
    adjusted_batch_mesh = DEFAULT_BATCH_MESH
    adjusted_model_mesh = DEFAULT_MODEL_MESH
mesh = Mesh(mesh_utils.create_device_mesh((adjusted_batch_mesh, adjusted_model_mesh)), ('batch', 'model'))
jax.set_mesh(mesh)

# GPT-2 BPE truncated to CULL_VOCAB_TO ids: ordinary merges keep ranks
# [0, CULL_VOCAB_TO - 2) and the two top ids are reserved for special tokens.
# NOTE: ``tokenizer`` (and ``vocab_size``) are rebound in the hyperparameter
# section to a dataset-specific tokenizers.Tokenizer loaded from tokenizer.json.
CULL_VOCAB_TO = 8192
def cull_mergeable_ranks(ranks):
    """Return the subset of ``ranks`` whose rank is below ``CULL_VOCAB_TO - 2``."""
    rank_cutoff = CULL_VOCAB_TO - 2
    return dict(filter(lambda item: item[1] < rank_cutoff, ranks.items()))
gpt2_tokenizer = tiktoken.get_encoding("gpt2")
_culled_special_tokens = {
    "<|endoftext|>": CULL_VOCAB_TO - 1,
    "<|mask|>": CULL_VOCAB_TO - 2,
}
tokenizer = tiktoken.Encoding(
    name=f"gpt2_{CULL_VOCAB_TO}",
    pat_str=gpt2_tokenizer._pat_str,
    mergeable_ranks=cull_mergeable_ranks(gpt2_tokenizer._mergeable_ranks),
    special_tokens=_culled_special_tokens,
)

# Hyperparameters
# Provisional value from the tiktoken encoding; overwritten below by the
# dataset tokenizer's vocabulary size.
vocab_size = tokenizer.n_vocab
#dataset = "tinystories8192"
#dataset = "openwebtext8192"
#dataset = "gutenberg16384"
#dataset = "longstackv2_bucket_5"
#dataset = "longstackv2_bucket_6"
dataset = "longstackv2_buckets_5678"
#dataset = "science_tech_2e15"
#dataset = "science_tech_2e16"
# Rebinds ``tokenizer`` to the dataset's own HuggingFace tokenizer; everything
# after this point uses the tokenizers.Tokenizer API (encode(...).ids, token_to_id).
tokenizer = Tokenizer.from_file(f"data/{dataset}/tokenizer.json")
vocab_size = tokenizer.get_vocab_size()
# Model size knob: width, head count, MLP width and depth are derived from SCALE.
SCALE = 6
base_scale_for_lr = 4.0
base_scale_for_residual = 2.0
FF_SCALE = 4
super_block_size = 2
#num_transformer_blocks = 4 * SCALE - 1
num_transformer_blocks = 3 * SCALE
if SCALE == 8:
    num_transformer_blocks = 20
# Round depth up to a multiple of super_block_size (MiniGPT asserts this).
if num_transformer_blocks % super_block_size != 0:
    num_transformer_blocks += super_block_size - (num_transformer_blocks % super_block_size)
maxlen = 2**16
base_maxlen_for_lr = 2**15
embed_dim = 256 * SCALE
num_heads = 4 * SCALE
global_num_heads = 4 * SCALE
head_dim = embed_dim // num_heads
global_head_dim = embed_dim // global_num_heads
feed_forward_dim = FF_SCALE * 256 * SCALE
# NOTE: the exponent 0.0 makes this always 1.0, i.e. depth/width-dependent
# residual scaling is currently disabled.
residual_scale = (base_scale_for_residual / SCALE) ** 0.0
base_batch_size = 2
batch_size = 2  # Adjusted for A100
minibatch_size = 2
assert batch_size % minibatch_size == 0, "batch_size must be multiple of minibatch_size"
accumulation_steps = batch_size // minibatch_size
num_epochs = 4
#position_encodings = "absolute"
#position_encodings = "nope"
position_encodings = "rope"
attn_dtype_str = "bfloat16"
DTYPE_LOOKUP = {
    "float32": jnp.float32,
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
}
attn_dtype = DTYPE_LOOKUP[attn_dtype_str]
backbone_dtype_str = "bfloat16"
backbone_dtype = DTYPE_LOOKUP[backbone_dtype_str]
ff_dtype_str = "bfloat16"
ff_dtype = DTYPE_LOOKUP[ff_dtype_str]
final_dtype_str = "bfloat16"
final_dtype = DTYPE_LOOKUP[final_dtype_str]
# NOTE: in the visible code every ``param_dtype=`` argument is commented out,
# so parameters keep the layers' default dtype.
param_dtype_str = "bfloat16"
param_dtype = DTYPE_LOOKUP[param_dtype_str]
generation_length = 128
# Read by sliding_window_i(); the active SuperBlock layout builds its own
# local/global pattern instead.
hybrid_attention = True # Alternates local and global attention, with second layer also local
sliding_window = 256 # Set to None for all global attention
local_temperature = 1.0  # Used for local attention in hybrid attention mode
normalize_qk = False
use_prenorm_not_postnorm = True
use_embedding_tieing = True
#intermediate_window_size = 256
intermediate_temperature = 1.0  # Used for intermediate attention in hybrid attention mode
global_window_size = 2**16  # Used for global attention in hybrid attention mode
intermediate_window_size = global_window_size
global_temperature = 1.0  # Used for global attention in hybrid attention mode
# Peak AdamW LR: scaled linearly with batch size and context length, and by
# sqrt(base_scale_for_lr / SCALE) with model size.
base_adam_lr = 3e-3 * batch_size / base_batch_size * (base_scale_for_lr / SCALE)**0.5 * (maxlen / base_maxlen_for_lr)**1.0
final_adam_lr_decay = 1e-1
# Warmup shrinks as the batch grows relative to base_batch_size.
warmup_steps = int(4000 * (base_batch_size / batch_size))
total_train_tokens = 12_000_000_000  # Total tokens to train on
#total_train_tokens = 7_000_000_000  # Total tokens to train on
adam_b2 = 0.95 # Adam beta2 value, used for AdamW optimizer
adam_b1 = 0.9 # Adam beta1 value, used for AdamW optimizer
adam_weight_decay = 0.01 # AdamW weight decay value
muon_lr = 1e-3 * batch_size / base_batch_size # Learning rate for Muon optimizer
use_muon = False  # Use Muon optimizer for some layers
rope_base = 1e5  # RoPE frequency base used by the sharded attention wrappers
dump_qkv = False
generation_is_enabled = False
record_location_loss = True

# FMA attention settings. When enabled, SuperBlock uses FMA for its global block;
# otherwise it uses full causal cuDNN attention. The fma_* values are passed
# positionally to fma_causal_attn (presumably query/key tile sizes, retrieval
# count, spatial count, block size and two structural flags — confirm against
# fma.pallas_retrieval).
use_fma_attention = True
fma_Q = 128
fma_K = 128
fma_num_retrievals = 8
fma_num_spatial = 1
fma_blk_size = 2**13
fma_bidiagonal = False
fma_dipole = False

# Run naming: the label encodes attention method, context length, local window,
# super-block size, token budget and model scale.
experiment_set = "test6"
#experiment_label = f"CUDNN_con4k_win4k_ev2_3Btokens_prenorm_scale{SCALE}"
#experiment_label = f"CUDNN_con32k_win32k_ev2_{int(total_train_tokens//1_000_000_000)}Btokens_prenorm_scale{SCALE}"
#experiment_label = f"cret_con32k_QK64_B4k_NR4_ev2_{int(total_train_tokens//1_000_000_000)}Btokens_prenorm_scale{SCALE}"
method_str = f"fma_stopped_Q{fma_Q}_K{fma_K}_NR{fma_num_retrievals}_B{fma_blk_size//2**10}k{'_bidiagonal' if fma_bidiagonal else ''}{'_dipole' if fma_dipole else ''}" if use_fma_attention else "cudnn"
experiment_label = f"{method_str}_con{maxlen // 2**10}k_win{sliding_window}_ev{super_block_size}_{int(total_train_tokens//1_000_000_000)}Btokens_scale{SCALE}"

# Toggle: comment out the second assignment to run without checkpointing.
checkpoint_save_path = None
checkpoint_save_path = f"/checkpoints/{dataset}_{experiment_set}_{experiment_label}"
checkpoint_save_interval_steps = 3_000 # Save checkpoint every N steps

if SCALE >= 5 and checkpoint_save_path is None:
    print("Warning: Running large model without checkpointing may lead to loss of progress if interrupted.")


def causal_attention_mask(seq_len):
    """Return a ``(seq_len, seq_len)`` float mask with ones on and below the diagonal.

    Entry ``[i, j]`` is 1 when query ``i`` may attend to key ``j`` (``j <= i``).
    """
    square = jnp.ones(shape=(seq_len, seq_len))
    return jnp.tril(square)

def sharded_rope_local_cudnn_attention(q,k,v, **kwargs):
    """Causal sliding-window cuDNN attention with RoPE, run per shard via ``shard_map``.

    Inputs use the ``(batch, seq, heads, head_dim)`` layout of
    ``jax.nn.dot_product_attention`` and are split over 'batch' (axis 0) and
    'model' (axis 2, heads); each shard sees the whole sequence. Each query looks
    back at most the module-level ``sliding_window`` keys (``(sliding_window, 0)``).
    Extra ``kwargs`` are accepted for interface compatibility and ignored.
    """
    bthd_spec = P('batch', None, 'model', None)

    def per_shard(q_shard, k_shard, v_shard):
        return jax.nn.dot_product_attention(
            apply_rope(q_shard, base=rope_base),
            apply_rope(k_shard, base=rope_base),
            v_shard,
            is_causal=True,
            implementation="cudnn",
            local_window_size=(sliding_window, 0),
        )

    sharded = jax.shard_map(per_shard, mesh=mesh, in_specs=bthd_spec, out_specs=bthd_spec, check_vma=False)
    return sharded(q, k, v)

def sharded_rope_causal_cudnn_attention(q,k,v, **kwargs):
    """Full causal cuDNN attention with RoPE, run per shard via ``shard_map``.

    Same sharding as the sliding-window variant (batch over 'batch', heads over
    'model', full sequence per shard), but with no local window. Extra ``kwargs``
    are accepted for interface compatibility and ignored.
    """
    bthd_spec = P('batch', None, 'model', None)

    def per_shard(q_shard, k_shard, v_shard):
        rotated_q = apply_rope(q_shard, base=rope_base)
        rotated_k = apply_rope(k_shard, base=rope_base)
        return jax.nn.dot_product_attention(
            rotated_q,
            rotated_k,
            v_shard,
            is_causal=True,
            implementation="cudnn",
            local_window_size=None,
        )

    sharded = jax.shard_map(per_shard, mesh=mesh, in_specs=bthd_spec, out_specs=bthd_spec, check_vma=False)
    return sharded(q, k, v)

@vmap
@partial(vmap, in_axes=-2, out_axes=-2)
def fma_multihead_attention(q,k,v):
    """FMA causal attention applied independently per (batch element, head).

    The outer ``vmap`` maps over axis 0 and the inner one over axis -2, so for
    ``(batch, seq, heads, head_dim)`` inputs (assumed — TODO confirm) the kernel
    sees 2-D ``(seq, head_dim)`` slices. Only the attention output is returned;
    the kernel's first output (presumably log-sum-exp statistics) is discarded.
    """
    _lse, attended = fma_causal_attn(
        fma_Q, fma_K, fma_blk_size, fma_num_retrievals,
        fma_bidiagonal, fma_dipole,
        q, k, v,
        fma_num_spatial,
    )
    return attended

def sharded_rope_fma_attention(q,k,v, **kwargs):
    """RoPE followed by per-head FMA attention, run per shard via ``shard_map``.

    Batch is split over 'batch' (axis 0) and heads over 'model' (axis 2); each
    shard sees the full sequence. Extra ``kwargs`` are accepted and ignored.
    """
    bthd_spec = P('batch', None, 'model', None)

    def rope_then_fma(q_shard, k_shard, v_shard):
        q_rot = apply_rope(q_shard, base=rope_base)
        k_rot = apply_rope(k_shard, base=rope_base)
        return fma_multihead_attention(q_rot, k_rot, v_shard)

    sharded = jax.shard_map(rope_then_fma, mesh=mesh, in_specs=bthd_spec, out_specs=bthd_spec, check_vma=False)
    return sharded(q, k, v)


def sharded_rope_dot_product_attention(q,k,v, **kwargs):
    """Run ``rope_attention.dot_product_attention`` per shard via ``shard_map``.

    Batch is split over 'batch' (axis 0) and heads over 'model' (axis 2). All
    ``kwargs`` are forwarded to the wrapped function. This wrapper does not call
    ``apply_rope`` itself; whether RoPE is applied depends on the wrapped
    function and ``kwargs``.
    """
    _batch, _seq, _heads, _dim = q.shape  # unpacking enforces rank-4 inputs
    bthd_spec = P('batch', None, 'model', None)

    def per_shard(q_shard, k_shard, v_shard):
        return rope_dot_product_attention(q_shard, k_shard, v_shard, **kwargs)

    sharded = jax.shard_map(per_shard, mesh=mesh, in_specs=bthd_spec, out_specs=bthd_spec, check_vma=False)
    return sharded(q, k, v)

def xavier_uniform(scale=1.0):
    """Uniform variance-scaling initializer with variance ``scale / fan_in``.

    NOTE: despite the name, this scales by fan_in (LeCun-uniform style when
    ``scale == 1``), not Glorot/Xavier's fan_avg. Kernels are read as
    ``(..., in, out)``: fan-in axis -2, fan-out axis -1, no batch axes.
    """
    return jax.nn.initializers.variance_scaling(
        scale,
        mode="fan_in",
        distribution="uniform",
        in_axis=-2,
        out_axis=-1,
        batch_axis=(),
    )

def embedding_init(scale=1.0):
    """Normal variance-scaling initializer with variance ``scale / fan_out``.

    For a ``(num_embeddings, features)`` table (fan-in axis -2, fan-out axis -1)
    this gives each entry a standard deviation of ``sqrt(scale / features)``.
    """
    return jax.nn.initializers.variance_scaling(
        scale,
        mode="fan_out",
        distribution="normal",
        in_axis=-2,
        out_axis=-1,
        batch_axis=(),
    )

class TransformerBlock(nnx.Module):
    """One transformer layer: multi-head attention followed by a ReLU MLP.

    Each sub-layer is wrapped in a residual connection, with LayerNorm applied
    before the sub-layer (``prenorm=True``) or after the residual add
    (``prenorm=False``). Activations are constrained to ``P('batch', 'model', None)``,
    i.e. the batch axis is sharded over 'batch' and the sequence axis over 'model'.
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, sliding_window: Optional[int], *, temperature: float = 1.0, rngs: nnx.Rngs, rate: float = 0.1, residual_scale: float = 1.0, prenorm: bool = True, attention_fn: Optional[Callable] = None):
        # NOTE: ``residual_scale`` only scales the initializers of the attention
        # output projection and linear2 below; the residual branches themselves
        # are always added with a fixed multiplier of 1.0.
        self.residual_scale = 1.0
        self.prenorm = prenorm
        # Attention kernel selection (an explicit ``attention_fn`` always wins):
        #   sliding_window None or == maxlen -> full causal cuDNN attention
        #   sliding_window == 2**14          -> FMA attention
        #   anything else                    -> sliding-window cuDNN attention
        # NOTE(review): the sliding-window wrapper reads the module-level
        # ``sliding_window`` rather than this constructor argument — confirm intended.
        if attention_fn is not None:
            pass
        elif sliding_window is None or sliding_window == maxlen:
            attention_fn = sharded_rope_causal_cudnn_attention
        elif sliding_window == 2**14:
            attention_fn = sharded_rope_fma_attention
        else:
            attention_fn = sharded_rope_local_cudnn_attention
        #attention_fn = sharded_rope_dot_product_attention
        # Q/K/V projections shard their output features over 'model'; the output
        # projection shards its input features over 'model'. RoPE is disabled in
        # the MHA module because the sharded_rope_* wrappers apply it themselves.
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            kernel_init=nnx.with_partitioning(xavier_uniform(), (None, 'model')),
            out_kernel_init=nnx.with_partitioning(xavier_uniform(residual_scale), ('model', None)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
            out_bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs,
            apply_rope=False, #position_encodings == "rope",# and sliding_window is None,
            dtype=attn_dtype,
            #param_dtype=param_dtype,
            sliding_window=sliding_window,
            normalize_qk=normalize_qk,
            rope_base=rope_base,
            temperature=temperature,
            attention_fn=attention_fn,
        )

        self.dropout1 = nnx.Dropout(rate=rate)

        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs
        )

        # MLP up-projection: hidden features sharded over 'model'.
        self.linear1 = nnx.Linear(
            in_features=embed_dim,
            out_features=ff_dim,
            kernel_init=nnx.with_partitioning(xavier_uniform(), (None, 'model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
            rngs=rngs,
            dtype=ff_dtype,
            #param_dtype=param_dtype,
        )

        # MLP down-projection back to embed_dim; init scaled by residual_scale.
        self.linear2 = nnx.Linear(
            in_features=ff_dim,
            out_features=embed_dim,
            kernel_init=nnx.with_partitioning(xavier_uniform(residual_scale), ('model', None)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs,
            dtype=ff_dtype,
            #param_dtype=param_dtype,
        )

        self.dropout2 = nnx.Dropout(rate=rate)

        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
            rngs=rngs
        )

    def __call__(self, inputs, training: bool = False, dump_qkv: bool = False):
        """Apply attention and MLP sub-layers to ``inputs`` of shape ``(batch, seq, embed_dim)``.

        Dropout is active only when ``training`` is True. When ``dump_qkv`` is True,
        the attention module's extra output is passed through and ``(out, qkv)`` is
        returned instead of ``out``.
        """
        input_shape = inputs.shape
        _, seq_len, _ = input_shape
        #mask = causal_attention_mask(seq_len)
        mask = None  # causality is handled inside the attention kernels
        SHARD_INTERMEDIATES = True
        DEBUG_SHARDING = False

        if SHARD_INTERMEDIATES:
            inputs = jax.lax.with_sharding_constraint(inputs, P('batch', 'model', None))

        if DEBUG_SHARDING:
            jax.debug.inspect_array_sharding(inputs, callback=lambda sharding: print(f"Transfomer inputs shape: {inputs.shape} dtype: {inputs.dtype}, sharding: {sharding}"))
        #print(f"TransformerBlock inputs: {inputs}, sharding: {jax.debug.inspect_array_sharding(inputs)}")

        # --- attention sub-layer ---
        if self.prenorm:
            x1 = self.layer_norm1(inputs)
        else:
            x1 = inputs
        attention_output = self.mha(inputs_q=x1, mask=mask, decode=False, dump_qkv=dump_qkv)
        if dump_qkv:
            attention_output, qkv = attention_output
        attention_output = self.dropout1(attention_output, deterministic=not training)
        if SHARD_INTERMEDIATES:
            attention_output = jax.lax.with_sharding_constraint(attention_output, P('batch', 'model', None))
        #print(f"TransformerBlock attention_output: {attention_output}")
        out1_unnormed = inputs + attention_output*self.residual_scale
        if self.prenorm:
            out1 = out1_unnormed
        else:
            out1 = self.layer_norm1(out1_unnormed)

        if DEBUG_SHARDING:
            jax.debug.inspect_array_sharding(out1, callback=lambda sharding: print(f"TransformerBlock out1 shape: {out1.shape} dtype: {out1.dtype}, sharding: {sharding}"))
        #print(f"TransformerBlock out1: {out1}, sharding: {jax.debug.inspect_array_sharding(out1)}")

        # --- MLP sub-layer ---
        if self.prenorm:
            x2 = self.layer_norm2(out1)
        else:
            x2 = out1
        ffn_output = self.linear1(x2)
        ffn_output = nnx.relu(ffn_output)
        ffn_output = self.linear2(ffn_output)
        ffn_output = self.dropout2(ffn_output, deterministic=not training)
        if SHARD_INTERMEDIATES:
            ffn_output = jax.lax.with_sharding_constraint(ffn_output, P('batch', 'model', None))

        #print(f"TransformerBlock ffn_output: {ffn_output}")

        if self.prenorm:
            out2 = out1 + ffn_output*self.residual_scale
        else:
            out2 = self.layer_norm2(out1 + ffn_output*self.residual_scale)
        if DEBUG_SHARDING:
            jax.debug.inspect_array_sharding(out2, callback=lambda sharding: print(f"TransformerBlock out2 shape: {out2.shape} dtype: {out2.dtype}, sharding: {sharding}"))
        #print(f"TransformerBlock out2: {out2}, sharding: {jax.debug.inspect_array_sharding(out2)}")
        if dump_qkv:
            return out2, qkv
        return out2

class TokenAndPositionEmbedding(nnx.Module):
    """Token embedding plus an optional learned absolute-position embedding.

    ``pos_emb`` exists only when the module-level ``position_encodings`` is
    ``"absolute"``; for other settings no positional signal is added here.
    Token embeddings are divided by the module-level ``residual_scale``.
    """

    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        # Both tables shard their feature dimension over the 'model' mesh axis.
        sharded_init = nnx.with_partitioning(embedding_init(), (None, 'model'))
        self.token_emb = nnx.Embed(
            num_embeddings=vocab_size,
            features=embed_dim,
            rngs=rngs,
            embedding_init=sharded_init,
            dtype=backbone_dtype,
        )
        if position_encodings == "absolute":
            self.pos_emb = nnx.Embed(
                num_embeddings=maxlen,
                features=embed_dim,
                rngs=rngs,
                embedding_init=sharded_init,
                dtype=backbone_dtype,
            )

    def __call__(self, x):
        token_embedding = self.token_emb(x) / residual_scale
        if position_encodings != "absolute":
            return token_embedding
        positions = jnp.arange(x.shape[1])[None, :]
        return token_embedding + self.pos_emb(positions)

def sliding_window_i(i):
    """Return ``(window_size, temperature, num_heads)`` for transformer layer ``i``.

    With ``hybrid_attention`` enabled, layers follow a period-4 pattern: phase 3
    uses the global settings, phase 1 the intermediate settings, and phases 0 and
    2 the local settings. Without it, every layer uses the local settings.
    """
    if hybrid_attention:
        phase = i % 4
        if phase == 3:
            return global_window_size, global_temperature, global_num_heads
        if phase == 1:
            return intermediate_window_size, intermediate_temperature, num_heads
    return sliding_window, local_temperature, num_heads

class SuperBlock(nnx.Module):
    """A group of ``super_block_size`` transformer layers: local blocks then one global block.

    The first ``super_block_size - 1`` blocks use the module-level ``sliding_window``
    and ``num_heads``. The last block uses ``global_num_heads`` with FMA attention
    when ``use_fma_attention`` is set, otherwise full causal cuDNN attention.
    NOTE: ``prenorm`` is not forwarded, so all blocks use TransformerBlock's default (True).
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1, residual_scale: float = 1.0):
        assert super_block_size >= 1, "super_block_size must be at least 1"
        self.local_blocks = nnx.List([
            TransformerBlock(embed_dim, num_heads, ff_dim, sliding_window, rngs=rngs, temperature=local_temperature, rate=rate, residual_scale=residual_scale)
            for _ in range(super_block_size - 1)
        ])
        global_attention_fn = sharded_rope_fma_attention if use_fma_attention else sharded_rope_causal_cudnn_attention
        self.global_block = TransformerBlock(embed_dim, global_num_heads, ff_dim, None, rngs=rngs, temperature=global_temperature, rate=rate, residual_scale=residual_scale, attention_fn=global_attention_fn)

    #@partial(nnx.remat, static_argnums=(2,3))
    def __call__(self, inputs, training: bool = False, dump_qkv: bool = False):
        """Run the local blocks, then the global block.

        Only the global block receives ``dump_qkv``; when it is True the return value
        is the global block's ``(output, qkv)`` pair instead of just the output.
        """
        x = inputs
        for block in self.local_blocks:
            x = block(x, training=training)
        x = self.global_block(x, training=training, dump_qkv=dump_qkv)
        return x

class MiniGPT(nnx.Module):
    """Decoder-only language model: embeddings -> SuperBlocks -> vocabulary logits.

    ``num_transformer_blocks`` layers are grouped into
    ``num_transformer_blocks // super_block_size`` SuperBlocks. With ``prenorm`` a
    final LayerNorm is applied before the output head. With ``tie_weights`` the
    logits come from the transposed token-embedding table instead of a separate
    output projection.
    """

    # Debug switch for the activation prints in ``__call__``. Off by default: under
    # ``nnx.jit`` they only show tracers, and in eager calls (e.g. ``generate_text``)
    # they would print large activation arrays on every step.
    debug_prints = False

    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int,
                 feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs, *, residual_scale: float = 1.0, prenorm: bool = True, tie_weights=False):
        """Build the model.

        Args:
            maxlen: maximum sequence length (sizes absolute position embeddings).
            vocab_size: number of token ids.
            embed_dim: model width.
            num_heads: attention heads for the local (non-global) blocks.
            feed_forward_dim: hidden width of each block's MLP.
            num_transformer_blocks: total layers; must be a multiple of ``super_block_size``.
            rngs: NNX RNG streams used for parameter initialization.
            residual_scale: forwarded to every SuperBlock (scales output-projection init).
            prenorm: if True, apply a final LayerNorm before the output head. The
                blocks themselves use TransformerBlock's default (prenorm=True).
            tie_weights: if True, reuse the token embedding as the output projection.
        """
        self.prenorm = prenorm
        self.tie_weights = tie_weights
        # Submodules draw from ``rngs`` in construction order; keep the order stable
        # so parameter initialization does not change between runs of this code.
        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)

        assert num_transformer_blocks % super_block_size == 0, "num_transformer_blocks must be multiple of super_block_size"
        num_super_blocks = num_transformer_blocks // super_block_size
        self.super_blocks = nnx.List([
            SuperBlock(embed_dim, num_heads, feed_forward_dim, rngs=rngs, residual_scale=residual_scale)
            for _ in range(num_super_blocks)
        ])

        if self.prenorm:
            self.output_layer_norm = nnx.LayerNorm(
                epsilon=1e-6,
                num_features=embed_dim,
                scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)),
                bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
                rngs=rngs
            )

        if not tie_weights:
            self.output_layer = nnx.Linear(
                in_features=embed_dim,
                out_features=vocab_size,
                kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None)),
                bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)),
                rngs=rngs,
                dtype=final_dtype,
                #param_dtype=param_dtype,
            )

    #@partial(nnx.remat, static_argnums=(2,), policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
    def __call__(self, inputs, training: bool = False, dump_qkv: bool = False):
        """Compute next-token logits for integer token ids ``inputs``.

        Args:
            inputs: token ids, ``(batch, seq)``.
            training: enables dropout in the transformer blocks.
            dump_qkv: if True, also collect each SuperBlock's global-attention qkv
                and return ``(logits, (qs, ks, vs))`` with the per-block tensors
                concatenated along axis -2.
        """
        x = self.embedding_layer(inputs)
        if self.debug_prints:
            print(f"MiniGPT post_embedding x: {x}")
        if dump_qkv:
            qs, ks, vs = [], [], []
        for super_block in self.super_blocks:
            x = super_block(x, training=training, dump_qkv=dump_qkv)
            if dump_qkv:
                x, (q, k, v) = x
                qs.append(q)
                ks.append(k)
                vs.append(v)

        if dump_qkv:
            qs = jnp.concatenate(qs, axis=-2)
            ks = jnp.concatenate(ks, axis=-2)
            vs = jnp.concatenate(vs, axis=-2)

        #@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        #def apply_super_blocks(x, super_block):
        #    return super_block(x, training=training)
        #x = apply_super_blocks(x, self.super_blocks)
        if self.debug_prints:
            print(f"MiniGPT post_transformer x: {x}")

        if self.prenorm:
            x = self.output_layer_norm(x)
        if self.tie_weights:
            outputs = x @ self.embedding_layer.token_emb.embedding.T * residual_scale
        else:
            outputs = self.output_layer(x)
        if dump_qkv:
            return outputs, (qs, ks, vs)
        return outputs

    def generate_text(self, max_tokens: int, start_tokens: list, top_k=10, seed: int = 0):
        """Autoregressively sample up to ``max_tokens`` tokens after ``start_tokens``.

        Each step runs the full model on the context (right-padded with id 0 to
        ``maxlen``, or cut to its last ``maxlen`` tokens) and samples the next token
        from the renormalized top-``top_k`` distribution. Generation stops early at
        ``<|endoftext|>``.

        Args:
            max_tokens: maximum number of new tokens to generate.
            start_tokens: prompt token ids.
            top_k: number of highest-scoring tokens to sample from.
            seed: PRNG seed; a fresh subkey is split off for every sampled token.

        Returns:
            The decoded prompt followed by the generated text.
        """
        eos_id = tokenizer.token_to_id("<|endoftext|>")
        key = jax.random.PRNGKey(seed)

        def sample_from(logits, sample_key):
            top_logits, indices = jax.lax.top_k(logits, k=top_k)
            probs = nnx.softmax(top_logits)
            return jax.random.choice(sample_key, indices, p=probs)

        def generate_step(tokens, sample_key):
            pad_len = maxlen - len(tokens)
            sample_index = len(tokens) - 1

            if pad_len < 0:
                # Keep the most recent maxlen tokens so the prediction follows the
                # newest context rather than a frozen prefix.
                x = jnp.array(tokens[-maxlen:])
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = jnp.array(tokens + [0] * pad_len)
            else:
                x = jnp.array(tokens)

            logits = self(x[None, :])
            return sample_from(logits[0][sample_index], sample_key)

        generated = []
        for _ in range(max_tokens):
            key, sample_key = jax.random.split(key)
            next_token = int(generate_step(start_tokens + generated, sample_key))
            if next_token == eos_id:
                break
            generated.append(next_token)

        return tokenizer.decode(start_tokens + generated)

def create_model(rngs):
    """Build a MiniGPT configured from the module-level hyperparameters."""
    return MiniGPT(
        maxlen,
        vocab_size,
        embed_dim,
        num_heads,
        feed_forward_dim,
        num_transformer_blocks=num_transformer_blocks,
        rngs=rngs,
        residual_scale=residual_scale,
        prenorm=use_prenorm_not_postnorm,
        tie_weights=use_embedding_tieing,
    )

@dataclass
class TextDataset:
    """Random-access dataset of raw text documents, tokenized on access.

    Item ``idx`` is the token ids of ``data[idx]``, truncated to ``maxlen`` and
    right-padded with id 0 to exactly ``maxlen`` tokens.
    """
    data: list  # raw document strings
    maxlen: int  # fixed output length in tokens

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        # ``tokenizer`` is the tokenizers.Tokenizer loaded from tokenizer.json. Its
        # ``encode`` takes no tiktoken-style ``allowed_special`` argument and returns
        # an ``Encoding`` whose token ids are in ``.ids``. Special tokens such as
        # <|endoftext|> are matched in the text if the tokenizer registers them as
        # added tokens.
        encoding = tokenizer.encode(self.data[idx]).ids[:self.maxlen]
        return encoding + [0] * (self.maxlen - len(encoding))

def load_and_preprocess_data(file_path, batch_size, maxlen):
    """Build a shuffled grain DataLoader over documents from a raw text file.

    The file is read as UTF-8 and split on ``<|endoftext|>``. Whitespace-only pieces
    are dropped, and every remaining document gets its ``<|endoftext|>`` terminator
    re-appended. Items are tokenized by ``TextDataset`` and batched with
    ``drop_remainder=True`` for ``num_epochs`` epochs (shuffle seed 42).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    stories = [story + '<|endoftext|>' for story in text.split('<|endoftext|>') if story.strip()]
    text_dataset = TextDataset(stories, maxlen)

    sampler = pygrain.IndexSampler(
        len(text_dataset),
        shuffle=True,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dl = pygrain.DataLoader(
        data_source=text_dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dl

@dataclass
class MemMapDataset:
    """Pre-tokenized uint16 token stream, loaded fully into RAM.

    Reads ``data/<name>/<split>.bin`` as a flat uint16 array of token ids and
    exposes it as fixed-length next-token-prediction examples. Example ``i`` spans
    tokens ``[i*maxlen, (i+1)*maxlen]`` inclusive: ``maxlen`` inputs plus one extra
    token, so the targets are the inputs shifted by one.
    """
    split: str  # file stem, e.g. "train"
    name: str  # dataset directory under ./data
    maxlen: int  # input tokens per example

    def __post_init__(self):
        data_dir = os.path.join('data', self.name, f"{self.split}.bin")
        mmap_data = np.memmap(data_dir, dtype=np.uint16, mode='r')
        # Copy to RAM, then drop the memmap so its file mapping can be released.
        self._data = np.array(mmap_data, dtype=np.uint16)
        del mmap_data

    def get_data(self):
        """Return the full in-RAM token array."""
        return self._data

    def __len__(self):
        # The final token can only ever be a target, so only len - 1 tokens are
        # available to fill complete input windows (guarded against empty data).
        return max(len(self.get_data()) - 1, 0) // self.maxlen

    def __getitem__(self, idx: int):
        """Return example ``idx`` as ``maxlen + 1`` int32 tokens (inputs plus one target)."""
        data = self.get_data()
        start = idx * self.maxlen
        end = start + self.maxlen + 1
        return np.array(data[start:end], dtype=np.int32)

    def batches_iter(self, batch_size: int, shuffle_key: Optional[PRNGKey]=None):
        """Yield ``(inputs, targets)`` uint16 arrays of shape ``(batch_size, maxlen)``.

        Targets are the inputs shifted by one token. Examples are visited in order,
        or in ``jax.random.permutation(shuffle_key, len(self))`` order when a key is
        given. A trailing partial batch is dropped so every batch is full.
        """
        N = len(self)
        # int64 offsets: example_index * maxlen exceeds the int32 range once the
        # dataset holds more than ~2**31 tokens.
        if shuffle_key is not None:
            order = np.asarray(jax.random.permutation(shuffle_key, N), dtype=np.int64)
        else:
            order = np.arange(N, dtype=np.int64)
        data = self.get_data()
        seqidx = np.arange(self.maxlen, dtype=np.int64)
        for i in range(0, N - batch_size + 1, batch_size):
            batch_indices = order[i:i + batch_size]
            combidx = batch_indices[:, None] * self.maxlen + seqidx[None, :]
            inputs = data[combidx]
            targets = data[combidx + 1]
            yield inputs, targets

def load_memmap_data(name, split, batch_size, maxlen):
    """Build an unshuffled grain DataLoader over a pre-tokenized ``MemMapDataset``.

    Items are batched with ``drop_remainder=True`` and iterated for ``num_epochs``
    epochs in index order (``shuffle=False``; the seed is 42 for completeness).
    """
    source = MemMapDataset(split=split, name=name, maxlen=maxlen)
    index_sampler = pygrain.IndexSampler(
        len(source),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )
    batching = [pygrain.Batch(batch_size=batch_size, drop_remainder=True)]
    return pygrain.DataLoader(data_source=source, sampler=index_sampler, operations=batching)

    

def loss_fn(model, batch):
    """Mean next-token cross-entropy plus per-position loss statistics.

    ``batch`` is ``(inputs, labels)`` with integer labels of shape ``(B, T)``.

    Returns:
        ``(loss, (position_loss_sum, position_count))``. ``loss`` is the mean
        per-token cross-entropy. The aux arrays have shape ``(T,)``: summed
        per-token loss and row counts at each position, over only the rows whose
        labels contain ``<|endoftext|>`` somewhere.
    """
    inputs, labels = batch[0], batch[1]
    logits = model(inputs)
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    _rows, _positions = per_token_loss.shape  # enforces rank-2 (B, T) losses
    eos_id = tokenizer.token_to_id("<|endoftext|>")
    row_has_eos = jnp.any(labels == eos_id, axis=1, keepdims=True)
    location_mask = jnp.broadcast_to(row_has_eos, labels.shape)
    masked_loss_sum = jnp.where(location_mask, per_token_loss, 0.).sum(axis=0)
    return per_token_loss.mean(), (masked_loss_sum, location_mask.sum(axis=0))

def segexpmeanlog(values, segment_ids, num_segments):
    """Geometric mean of ``values`` within each segment.

    Computes ``exp(mean(log(values)))`` per segment id in ``[0, num_segments)``.
    Empty segments yield NaN (0 / 0).
    """
    log_total = jax.ops.segment_sum(jnp.log(values), segment_ids, num_segments)
    member_count = jax.ops.segment_sum(jnp.ones_like(values), segment_ids, num_segments)
    mean_log = log_total / member_count
    return jnp.exp(mean_log)

def fit_irreducible(x, y):
    """Fit ``y ~ a * x**b + c`` and return the asymptotic offset ``c``.

    Uses ``scipy.optimize.curve_fit`` starting from ``(a, b, c) = (1, -0.5, 0)``
    with up to 10000 function evaluations. Fitting errors from scipy propagate.
    """
    from scipy.optimize import curve_fit

    def shifted_power_law(t, scale, exponent, offset):
        return scale * (t ** exponent) + offset

    best_params, _covariance = curve_fit(shifted_power_law, x, y, p0=[1.0, -0.5, 0.0], maxfev=10000)
    _scale, _exponent, offset = best_params
    return offset

def process_loc_loss(loc_loss, *, num_buckets=1):
    """Bucket per-position losses on a log2 position scale and remove an irreducible floor.

    Positions ``1..T`` go into bucket ``ceil(num_buckets * log2(pos))``, i.e.
    ``num_buckets`` buckets per doubling of position. Each bucket's loss and
    position are summarized by their geometric mean (``segexpmeanlog``), and
    buckets whose loss is not finite are dropped.

    The irreducible loss is estimated by fitting ``a * pos**b + c``
    (``fit_irreducible``) to the middle third of the buckets. If the fit fails, it
    falls back to ``min(loss) - 0.01``. In either case it is capped at
    ``min(loss) - 0.01``, so every returned reducible loss is at least 0.01.

    Args:
        loc_loss: 1-D array of mean loss per token position.
        num_buckets: buckets per doubling of the position.

    Returns:
        ``(bucket_positions, reducible_loss, irreducible_loss)``. The first two are
        1-D arrays of equal length. If no bucket is finite, both are empty and
        ``irreducible_loss`` is NaN.
    """
    T, = loc_loss.shape
    positions = jnp.arange(T) + 1
    seg_labels = jnp.ceil(num_buckets * jnp.log2(positions)).astype(jnp.int32)
    num_segs = int(jnp.max(seg_labels)) + 1
    loc_loss_buckets = segexpmeanlog(loc_loss, seg_labels, num_segs)
    positions_buckets = segexpmeanlog(positions, seg_labels, num_segs)
    non_nan = jnp.isfinite(loc_loss_buckets)
    safe_pos_buck = positions_buckets[non_nan]
    safe_loc_loss_buck = loc_loss_buckets[non_nan]
    N = len(safe_pos_buck)
    if N == 0:
        # Nothing to fit or plot (e.g. no position has a nonzero count yet).
        return safe_pos_buck, safe_loc_loss_buck, float("nan")
    # Fit only the middle third: trim N // 3 buckets symmetrically from each end.
    third = N // 3
    fit_slice = slice(third, N - third)
    try:
        irreducible_loss = fit_irreducible(np.array(safe_pos_buck[fit_slice]), np.array(safe_loc_loss_buck[fit_slice]))
    except Exception as e:
        # Best-effort diagnostic: a failed fit must not break metric logging.
        print(f"Warning: fit_irreducible failed with error {e}, using min loc_loss as irreducible_loss")
        irreducible_loss = jnp.min(safe_loc_loss_buck) - 0.01
    irreducible_loss = jnp.minimum(irreducible_loss, jnp.min(safe_loc_loss_buck) - 0.01)
    safe_loc_loss_buck = safe_loc_loss_buck - irreducible_loss
    return safe_pos_buck, safe_loc_loss_buck, irreducible_loss
    #return np.array(positions_buckets)[non_nan], np.array(loc_loss_buckets)[non_nan], irreducible_loss

def record_loc_loss(run, loc_loss, step, **kwargs):
    """Log two log-log plots of per-position reducible loss to an Aim ``run``.

    Tracks ``loss_by_position`` (1 bucket per doubling of position) and
    ``loss_by_position_fine`` (4 buckets per doubling) at ``step``. Extra
    ``kwargs`` are forwarded to ``run.track``. Each figure is closed after logging.
    """
    plot_specs = (
        (1, "loss_by_position"),
        (4, "loss_by_position_fine"),
    )
    for buckets_per_doubling, metric_name in plot_specs:
        fig, ax = plt.subplots()
        bucket_positions, reducible_loss, irred = process_loc_loss(loc_loss, num_buckets=buckets_per_doubling)
        ax.plot(bucket_positions, reducible_loss)
        ax.set_title(f"Location-wise Reducible Loss (irred: {irred:.3f})")
        ax.set_xlabel('Token Position')
        ax.set_ylabel('Loss')
        ax.set_yscale('log')
        ax.set_xscale('log')
        run.track(Figure(fig), name=metric_name, step=step, **kwargs)
        plt.close(fig)


@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """One jitted optimizer step: compute loss and grads, update metrics, apply grads."""
    loss_and_grad = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, aux), grads = loss_and_grad(model, batch)
    loc_loss, loc_count = aux
    metrics.update(loss=loss, loc_loss=loc_loss, loc_count=loc_count)
    optimizer.update(model, grads)

@nnx.jit
def eval_step(model: MiniGPT, metrics: nnx.MultiMetric, batch):
    """Accumulate evaluation metrics for one batch (no parameter update)."""
    loss, aux = loss_fn(model, batch)
    loc_loss, loc_count = aux
    metrics.update(loss=loss, loc_loss=loc_loss, loc_count=loc_count)

@nnx.jit
def get_all_qkv(model: MiniGPT, batch):
    """Run ``model`` with ``dump_qkv=True`` and return only its ``(q, k, v)`` tensors."""
    _logits, qkv = model(batch, dump_qkv=True)
    queries, keys, values = qkv
    return queries, keys, values

@nnx.jit
def get_qkv_at_3(model: MiniGPT, batch):
    """Return whatever the model produces when called with ``dump_qkv_at=3``.

    NOTE(review): presumably a ``(q, k, v)`` triple captured at index 3, judging
    by how it was once used in ``save_qkv``; confirm against ``MiniGPT.__call__``.
    """
    outputs = model(batch, dump_qkv_at=3)
    return outputs

def save_qkv(model: MiniGPT, batch, step, out_dir="out"):
    """Capture q/k/v for ``batch`` via ``get_all_qkv`` and save them to disk.

    The tensors are cast to float32 and written with ``np.savez`` to
    ``<out_dir>/qkv_step_<step>.npz`` under the keys ``q``, ``k`` and ``v``.
    ``out_dir`` is created if it does not exist yet.

    Args:
        model: model passed to ``get_all_qkv``.
        batch: input token batch passed to ``get_all_qkv``.
        step: training step; used only to build the file name.
        out_dir: destination directory (defaults to ``"out"``).
    """
    # np.savez does not create parent directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)
    filename = os.path.join(out_dir, f"qkv_step_{step}.npz")
    # Cast on the host to float32, presumably so the archive never holds
    # non-standard numpy dtypes such as bfloat16.
    q, k, v = (np.array(t, dtype=np.float32) for t in get_all_qkv(model, batch))
    np.savez(filename, q=q, k=k, v=v)
    print(f"QKV saved to {filename}")

class VectorAverage(nnx.Metric):
    """Element-wise running mean of a vector-valued metric.

    ``update`` reads two keyword arguments: the summed values (named by
    ``argname``) and the matching per-element counts (named by ``cntname``).
    ``compute`` returns ``total / count`` element-wise. Entries whose
    accumulated count is zero divide by zero (NaN when their total is also zero).
    """

    def __init__(self, argname, cntname, shape):
        self.argname = argname
        self.cntname = cntname
        # These attribute names appear in nnx.state(metrics), which main()
        # checkpoints, so they must stay stable.
        self.total = nnx.Variable(jnp.zeros(shape))
        self.count = nnx.Variable(jnp.zeros(shape, dtype=jnp.int32))

    def update(self, **kwargs):
        increment, increment_count = kwargs[self.argname], kwargs[self.cntname]
        self.total[...] += increment
        self.count[...] += increment_count

    def compute(self):
        accumulated, seen = self.total[...], self.count[...]
        return accumulated / seen

    def reset(self):
        for variable in (self.total, self.count):
            variable[...] = jnp.zeros_like(variable[...])

def main():
    """Train MiniGPT, log progress to Aim, checkpoint, then evaluate on the test split.

    Configuration comes from module-level settings (``batch_size``, ``maxlen``,
    ``dataset``, optimizer hyper-parameters, ``checkpoint_save_path``, ...).
    If ``checkpoint_save_path`` already holds a checkpoint, the model, optimizer
    and metric state are restored and the recorded Aim run is resumed. After
    training, the test split is evaluated twice: once with the cuDNN global
    attention function and once with the FMA global attention function.
    """
    assert accumulation_steps == 1, "need to implement grad accum as an inner loop for consistent behaviour"
    print("Starting miniGPT training...")

    # Debug switch: dump train_step HLO + XLA memory analysis at step 0.
    DEBUG_MEMORY_ANALYSIS = False
    # Step window wrapped in the NVTX "training" range for profiling.
    NVTX_RANGE_START_STEP = 600
    NVTX_RANGE_END_STEP = 699
    MAX_EVAL_STEPS = 3000

    # Load data
    text_dl = load_memmap_data(name=dataset, split="train", batch_size=batch_size, maxlen=maxlen)
    text_dataset = MemMapDataset(split="train", name=dataset, maxlen=maxlen)
    start_prompt = "Once upon a time" if dataset == "tinystories8192" else "The main reason"

    dataset_num_sequences = len(text_dataset)
    dataset_num_tokens = len(text_dataset.get_data())
    print(f"Dataset '{dataset}' has {dataset_num_sequences} sequences, {dataset_num_tokens} tokens.")

    batches_per_epoch = len(text_dl._data_source) // batch_size
    tokens_per_batch = batch_size * maxlen
    # Every (input, target) pair is split along the 'batch' mesh axis.
    batch_sharding = NamedSharding(mesh, P('batch', None))

    def _split_batch(batch):
        """Split a token batch into next-token (input, target) arrays.

        The assertion below requires ``batch`` to have ``maxlen + 1`` columns.
        """
        input_batch = jnp.array(batch[:, :-1])
        target_batch = jnp.array(batch[:, 1:])
        assert input_batch.shape == (batch_size, maxlen), \
            f"Expected input shape {(batch_size, maxlen)}, got {input_batch.shape}"
        return input_batch, target_batch

    rtpt = RTPT(
        name_initials="XX",
        experiment_name="miniGPT",
        max_iterations=batches_per_epoch * num_epochs,
    )

    # Create model
    model = create_model(rngs=nnx.Rngs(0))
    total_params = sum(x.size for x in jax.tree.leaves(nnx.state(model, nnx.Param)))
    total_Mparams = total_params / 1_000_000
    print(f"Total model parameters: {total_Mparams:.1f}M")

    # Base LR divided by sqrt(SCALE) and sqrt(FF_SCALE) (presumably the
    # model / feed-forward width multipliers).
    adam_lr = base_adam_lr / SCALE ** 0.5 / FF_SCALE ** 0.5
    adam_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=adam_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_train_tokens // tokens_per_batch,
        end_value=adam_lr * final_adam_lr_decay,
    )

    adamw = optax.adamw(adam_schedule, b1=adam_b1, b2=adam_b2, eps=1e-6, weight_decay=adam_weight_decay)
    optax_optimizer = optax.chain(optax.clip_by_global_norm(1.0), adamw)
    optax_optimizer = optax.MultiSteps(
        optax_optimizer,
        every_k_schedule=lambda step: jnp.where(step < 0, 1, accumulation_steps),
    )
    optimizer = nnx.Optimizer(model, optax_optimizer, wrt=nnx.Param)
    metrics_kwargs = dict(loss=nnx.metrics.Average('loss'))
    if record_location_loss:
        metrics_kwargs |= dict(loc_loss=VectorAverage('loc_loss', 'loc_count', shape=(maxlen,)))
    metrics = nnx.MultiMetric(**metrics_kwargs)

    # Initial text generation
    if generation_is_enabled:
        start_tokens = tokenizer.encode(start_prompt).ids[:maxlen]
        generated_text = model.generate_text(min(generation_length, maxlen), start_tokens)
        print(f"Initial generated text:\n{generated_text}\n")

    metrics_history = {'train_loss': []}
    step = 0

    def octave_offset_and_width(step):
        step = int(step)
        msb = step.bit_length() - 1
        width = 1 << msb
        offset = step - (1 << msb)
        return offset, width

    def should_eval(step, per_octave_count, min_stride=1):
        # Roughly `per_octave_count` evaluations per power-of-two range of
        # (step + 1), but never more often than every `min_stride` steps.
        step = int(step)
        if step <= 0:
            return False
        step = step + 1
        offset, width = octave_offset_and_width(step)
        stride = max(min_stride, width // int(per_octave_count))
        return offset % stride == 0 and step >= min_stride

    def should_eval_hfreq(step):
        return should_eval(step, 8, min_stride=32)

    def should_eval_lfreq(step):
        return should_eval(step, 8, min_stride=2048)

    checkpoint_manager = None
    if checkpoint_save_path is not None:
        options = ocp.CheckpointManagerOptions(
            preservation_policy=ocp.checkpoint_managers.preservation_policy.PreserveAll(),
        )
        checkpoint_manager = ocp.CheckpointManager(
            checkpoint_save_path,
            options=options,
        )

    def save_checkpoint(step, run_hash):
        """Save model, optimizer and metric state under ``step`` plus run metadata.

        The caller saves after incrementing ``step``, so ``step`` steps (and
        ``step * tokens_per_batch`` tokens) have been consumed at this point.
        """
        metadata = dict(
            total_Mtokens_so_far=(step * tokens_per_batch) // 1_000_000,
            run_hash=run_hash,
        )
        checkpoint_manager.save(
            step,
            args=ocp.args.Composite(
                model=ocp.args.StandardSave(nnx.state(model)),
                optimizer=ocp.args.StandardSave(nnx.state(optimizer)),
                nnx_metrics=ocp.args.StandardSave(nnx.state(metrics)),
            ),
            custom_metadata=metadata,
        )

    existing_run_hash = None
    if checkpoint_save_path is not None and checkpoint_manager.latest_step() is not None:
        # NOTE: nothing here restores the data-loader position; a resumed run
        # iterates text_dl from its first batch again.
        step = checkpoint_manager.latest_step()
        checkpoint = checkpoint_manager.restore(
            step,
            args=ocp.args.Composite(
                model=ocp.args.StandardRestore(nnx.state(model)),
                optimizer=ocp.args.StandardRestore(nnx.state(optimizer)),
                nnx_metrics=ocp.args.StandardRestore(nnx.state(metrics)),
            ),
        )
        nnx.update(model, checkpoint.model)
        nnx.update(optimizer, checkpoint.optimizer)
        nnx.update(metrics, checkpoint.nnx_metrics)
        existing_run_hash = checkpoint_manager.metadata(step).custom_metadata.get('run_hash')
        print(f"Restored checkpoint from step {step} with run_hash {existing_run_hash}")

    if existing_run_hash is None:
        run = Run(repo="/aim", experiment="miniGPT")
    else:
        run = Run(repo="/aim", experiment="miniGPT", run_hash=existing_run_hash)
    run["hparams"] = {
        "vocab_size": vocab_size,
        "num_transformer_blocks": num_transformer_blocks,
        "maxlen": maxlen,
        "embed_dim": embed_dim,
        "num_heads": num_heads,
        "global_num_heads": global_num_heads,
        "head_dim": head_dim,
        "global_head_dim": global_head_dim,
        "feed_forward_dim": feed_forward_dim,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "dataset": dataset,
        "position_encodings": position_encodings,
        "attn_dtype": attn_dtype_str,
        "param_dtype": param_dtype_str,
        "ff_dtype": ff_dtype_str,
        "final_dtype": final_dtype_str,
        "hybrid_attention": hybrid_attention,
        "sliding_window": sliding_window,
        "normalize_qk": normalize_qk,
        "intermediate_window_size": intermediate_window_size,
        "global_window_size": global_window_size,
        "base_adam_lr": base_adam_lr,
        "generation_length": generation_length,
        "total_train_tokens": total_train_tokens,
        "warmup_steps": warmup_steps,
        "adam_b1": adam_b1,
        "adam_b2": adam_b2,
        "adam_weight_decay": adam_weight_decay,
        "final_adam_lr_decay": final_adam_lr_decay,
        "muon_lr": muon_lr,
        "use_muon": use_muon,
        "rope_base": rope_base,
        "local_temperature": local_temperature,
        "intermediate_temperature": intermediate_temperature,
        "global_temperature": global_temperature,
        "accumulation_steps": accumulation_steps,
        "dump_qkv": dump_qkv,
        "experiment_set": experiment_set,
        "experiment_label": experiment_label,
        "generation_is_enabled": generation_is_enabled,
        "dataset_total_size_Mtokens": dataset_num_tokens // 1_000_000,
        "residual_scale": residual_scale,
        "use_prenorm_not_postnorm": use_prenorm_not_postnorm,
        "use_fma_attention": use_fma_attention,
        "fma_Q": fma_Q,
        "fma_K": fma_K,
        "fma_num_retrievals": fma_num_retrievals,
        "fma_num_spatial": fma_num_spatial,
        "fma_blk_size": fma_blk_size,
        "fma_bidiagonal": fma_bidiagonal,
        "fma_dipole": fma_dipole,
        "super_block_size": super_block_size,
        "total_Mparams": total_Mparams,
    }

    def _dump_train_step_memory_analysis(sharded_batch, at_step):
        """Debug helper: lower a pure train_step, write its HLO, print XLA's memory analysis."""
        model_graph, model_state = nnx.split(model)
        opt_graph, opt_state = nnx.split(optimizer)
        metrics_graph, metrics_state = nnx.split(metrics)

        def pure_train_step(model_state, opt_state, metrics_state, batch):
            model = nnx.merge(model_graph, model_state)
            optimizer = nnx.merge(opt_graph, opt_state)
            metrics = nnx.merge(metrics_graph, metrics_state)
            train_step(model, optimizer, metrics, batch)
            _, new_model_state = nnx.split(model)
            _, new_opt_state = nnx.split(optimizer)
            _, new_metrics_state = nnx.split(metrics)
            return new_model_state, new_opt_state, new_metrics_state

        lowered = jax.jit(pure_train_step).lower(model_state, opt_state, metrics_state, sharded_batch)
        with open("/run/determined/train_step_hlo.txt", "w") as f:
            f.write(lowered.as_text())
        compiled = lowered.compile()
        mem_analysis = compiled.memory_analysis()
        print(f"Memory analysis at step {at_step}:\n{mem_analysis}")

    # Training loop
    terminate_training = False
    nvtx_range_open = False
    rtpt.start()
    for _epoch in range(num_epochs):
        if terminate_training:
            break
        start_time = time.time()
        start_step = step
        for batch in text_dl:
            if terminate_training:
                break
            input_batch, target_batch = _split_batch(batch)
            sharded_batch = jax.device_put((input_batch, target_batch), batch_sharding)

            if DEBUG_MEMORY_ANALYSIS and step == 0:
                _dump_train_step_memory_analysis(sharded_batch, step)

            with nvtx.annotate("train_step", color="blue"):
                train_step(model, optimizer, metrics, sharded_batch)

            # Open the profiling range once and close it exactly once; popping
            # on every later step would unbalance the NVTX range stack.
            if step == NVTX_RANGE_START_STEP:
                nvtx.push_range("training")
                nvtx_range_open = True
            elif nvtx_range_open and step >= NVTX_RANGE_END_STEP:
                nvtx.pop_range()
                nvtx_range_open = False

            if should_eval_hfreq(step):
                computed = metrics.compute()
                # loc_loss is a per-position vector: plotted, not kept in history.
                loc_loss = computed.pop('loc_loss', None)
                for metric, value in computed.items():
                    metrics_history.setdefault(f'train_{metric}', []).append(value)
                metrics.reset()

                old_start_time = start_time
                start_time = time.time()
                elapsed_time = start_time - old_start_time
                elapsed_steps = step - start_step
                time_per_step = elapsed_time / max(elapsed_steps, 1)
                mega_tokens_per_second = tokens_per_batch / time_per_step / 1.0e6
                train_loss = metrics_history['train_loss'][-1]
                print(f"[{step + 1} ({elapsed_time:.2f}s)] "
                      f"Loss: {train_loss:.4f}, "
                      f"{1e3*time_per_step:.1f}ms/step, "
                      f"{mega_tokens_per_second:.2f}Mtok/s")
                total_Mtokens_so_far = ((step + 1) * tokens_per_batch) // 1_000_000
                if loc_loss is not None:
                    record_loc_loss(run, loc_loss, step=total_Mtokens_so_far)
                run.track(float(train_loss), name='train_loss', step=total_Mtokens_so_far)
                run.track(step * tokens_per_batch, name='train_tokens', step=total_Mtokens_so_far)
                run.track(step, name='train_steps', step=total_Mtokens_so_far)
                current_lr = adam_schedule(optimizer.step[...])
                run.track(float(current_lr), name='learning_rate', step=total_Mtokens_so_far)
                start_step = step

                if should_eval_lfreq(step) and generation_is_enabled:
                    generated_text = model.generate_text(min(generation_length, maxlen), start_tokens)
                    elapsed_time = time.time() - start_time
                    print(f"Generated text ({elapsed_time:.2f}s):\n{generated_text}\n")
                    start_time = time.time()

            # Only stop on a logging step so the final metrics are recorded.
            if should_eval_hfreq(step) and (step + 1) * tokens_per_batch >= total_train_tokens:
                terminate_training = True
            if dump_qkv and step % 200 == 0:
                save_qkv(model, input_batch, step)

            step += 1
            rtpt.step()
            if checkpoint_save_path is not None and step % checkpoint_save_interval_steps == 0:
                print(f"Saving checkpoint at step {step}...")
                save_checkpoint(step, run_hash=run.hash)

    # Do eval pass
    print("Loading eval data...")
    eval_dl = load_memmap_data(name=dataset, split="test", batch_size=batch_size, maxlen=maxlen)

    def _evaluate(attention_fn, attn_label):
        """Evaluate at most MAX_EVAL_STEPS test batches with ``attention_fn`` on every global block.

        Prints the computed metrics and tracks them in Aim under
        ``context={"attn": attn_label}``.
        """
        metrics.reset()
        print(f"Starting {attn_label} evaluation...")
        for sblk in model.super_blocks:
            sblk.global_block.mha.attention_fn = attention_fn
        for i, batch in enumerate(tqdm(eval_dl)):
            input_batch, target_batch = _split_batch(batch)
            eval_step(model, metrics, jax.device_put((input_batch, target_batch), batch_sharding))
            if i + 1 >= MAX_EVAL_STEPS:
                break
        computed = metrics.compute()
        loc_loss = computed.pop('loc_loss', None)
        for metric, value in computed.items():
            print(f"Eval {metric}: {value:.4f}")
        if loc_loss is not None:
            record_loc_loss(run, loc_loss, step=None, context={"attn": attn_label})
        run.track(float(computed['loss']), name='final_loss', step=None, context={"attn": attn_label})

    _evaluate(sharded_rope_causal_cudnn_attention, "cudnn")
    _evaluate(sharded_rope_fma_attention, "fma")

    # Final generation
    if generation_is_enabled:
        generated_text = model.generate_text(min(generation_length, maxlen), start_tokens)
        print(f"Final generated text:\n{generated_text}")

    # Close aim run
    run.close()
    if checkpoint_save_path is not None:
        checkpoint_manager.close()

# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
