#!/usr/bin/env python
"""
Train a LLaMA on sequence prediction tasks with Beta-rel. positional encoding.
"""

import os
import argparse
import random
import json
import math
from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
import matplotlib.pyplot as plt
from tqdm import tqdm
from llm import *
from dataloder import *
from transformers import LlamaConfig, LlamaForCausalLM
import math


def beta_encode(x: int) -> Optional[str]:
    """Return the β-encoding of position ``x``, or ``None`` when it is undefined.

    ``x`` is written in binary, most significant bit first and without
    leading zeros.  β(x) is the run of bits that follows the first '0' in
    that string (i.e. everything to the right of the most significant zero
    bit).  The result is the empty string when that '0' is the last bit.

    β(x) is undefined (``None``) when:
    - ``x <= 0``: positions are positive, and for negative values ``bin``
      emits a sign prefix that must not be sliced as if it were bits;
    - the binary form of ``x`` contains no '0' at all (x = 2^k - 1).

    Examples:
    - β(42): "101010" → first '0' at index 1 → "1010"
    - β(18): "10010"  → first '0' at index 1 → "010"
    - β(17): "10001"  → first '0' at index 1 → "001"
    - β(2):  "10"     → first '0' at index 1 → ""
    - β(7):  "111"    → no '0'               → None

    NOTE(review): an earlier version of this docstring listed β(42) as
    "01010", which disagrees with this implementation and with the β(18) and
    β(17) examples — confirm the intended definition against the paper.
    """
    if x <= 0:
        return None
    bits = format(x, "b")  # binary digits only, no '0b' prefix
    # Split around the first (most significant) '0'; `zero` is empty when
    # the string holds no '0', in which case β(x) is undefined.
    _, zero, suffix = bits.partition("0")
    return suffix if zero else None

def build_beta_relation(max_positions: int) -> torch.Tensor:
    """
    Materialise the β-relation as a dense 0/1 matrix.

    For every 1-indexed position j in [1, max_positions] whose β-encoding is
    defined, the i-th character of β(j) (1-indexed, left to right) decides
    the entry stored at ``R[i-1, j-1]``: 1.0 when that character is '1',
    0.0 otherwise.  Columns whose β(j) is undefined stay all-zero.

    Args:
        max_positions: Maximum sequence length to support

    Returns:
        float32 tensor of shape (max_positions, max_positions)
    """
    relation = torch.zeros((max_positions, max_positions), dtype=torch.float32)

    for position in range(1, max_positions + 1):
        code = beta_encode(position)
        if code is None:
            continue

        col = position - 1  # 0-indexed storage column for this position
        for row, bit in enumerate(code):  # row is already the 0-indexed bit slot
            # Bounds guard kept from the original; β(j) is shorter than the
            # binary form of j, so it is not expected to reject anything.
            if bit == '1' and row < max_positions and col < max_positions:
                relation[row, col] = 1.0

    return relation

def patch_llama_with_beta_pos(model, max_positions, beta_scale=1.0, enabled=True, use_beta=True):
    """Swap each decoder layer's self-attention for a β-RPE attention module.

    Layers whose ``self_attn`` is already a ``LlamaAttentionWithBetaPos`` are
    left untouched, so calling this twice does not double-patch.  For every
    other layer a new ``LlamaAttentionWithBetaPos`` is built from
    ``model.config``, moved to the device of the attention it replaces, and
    given the old q/k/v/o projection weights.

    Args:
        model: object exposing ``model.layers`` and ``config``.
        max_positions: forwarded to ``LlamaAttentionWithBetaPos``.
        beta_scale: forwarded to ``LlamaAttentionWithBetaPos``.
        enabled: when False the model is returned unchanged.
        use_beta: forwarded to ``LlamaAttentionWithBetaPos``.

    Returns:
        The same ``model`` object (patched in place).
    """
    if not enabled:
        return model

    print(f"[INFO] Patching model with β-RPE (max_positions={max_positions}, "
          f"beta_scale={beta_scale})")

    projections = ("q_proj", "k_proj", "v_proj", "o_proj")
    patched_count = 0
    for idx, block in enumerate(model.model.layers):
        current = block.self_attn
        if isinstance(current, LlamaAttentionWithBetaPos):
            # already converted on an earlier call
            continue

        replacement = LlamaAttentionWithBetaPos(
            model.config,
            max_positions=max_positions,
            layer_idx=idx,
            beta_scale=beta_scale,
            use_beta=use_beta
        )

        # Place the new module where the old one lives, then carry over the
        # four projection weights so only the positional handling changes.
        target_device = next(current.parameters()).device
        replacement = replacement.to(target_device)
        for proj_name in projections:
            source_state = getattr(current, proj_name).state_dict()
            getattr(replacement, proj_name).load_state_dict(source_state)

        block.self_attn = replacement
        patched_count += 1

    print(f"[INFO] Successfully patched {patched_count} layers with β-RPE")
    return model



def set_seeds(seed: int = 42):
    """Seed Python's ``random`` and PyTorch (CPU and every CUDA device).

    The seed is also exported as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

  
def causal_token_acc(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Count next-token hits for a causal LM.

    The logits at step t are compared with the label at step t+1 (last logit
    and first label are dropped).  Labels equal to ``ignore_index`` are not
    scored.

    Returns:
        (num_correct, num_scored) as Python numbers.
    """
    targets = labels[:, 1:]
    scored = targets != ignore_index
    predictions = logits[:, :-1, :].argmax(dim=-1)
    hits = (predictions == targets) & scored
    return hits.sum().item(), scored.sum().item()

def causal_seq_acc(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Count sequences whose every scored next-token prediction is right.

    Uses the same one-step shift as token accuracy: logits at step t are
    matched against the label at step t+1.  A row only takes part when it
    has at least one label different from ``ignore_index``; it counts as
    correct when all of its scored positions are predicted exactly.

    Returns:
        (num_correct_sequences, num_scored_sequences) as Python ints.
    """
    targets = labels[:, 1:]
    predictions = logits[:, :-1, :].argmax(dim=-1)
    scored = targets.ne(ignore_index)

    # Rows without any scored position are excluded from both counts.
    row_has_target = scored.any(dim=1)
    # Ignored positions are treated as matches so they cannot fail a row.
    row_all_right = (predictions.eq(targets) | ~scored).all(dim=1)

    num = int((row_all_right & row_has_target).sum().item())
    den = int(row_has_target.sum().item())
    return num, den


# ----------------------
# Eval
# ----------------------
@torch.no_grad()
def evaluate(model, dl, device, use_amp: bool = True):
    """Run ``model`` over ``dl`` in eval mode and aggregate loss and accuracies.

    Each batch is a dict of tensors that is moved to ``device`` and passed to
    the model as keyword arguments; the model output must expose ``loss`` and
    ``logits``.  The reported loss is the batch loss averaged with batch-size
    weights.

    Returns:
        (avg_loss, token_accuracy, sequence_accuracy); each is 0.0 when there
        was nothing to average over.
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    tok_hits = tok_total = 0
    seq_hits = seq_total = 0

    for raw_batch in dl:
        batch = {name: tensor.to(device, non_blocking=True) for name, tensor in raw_batch.items()}
        n_in_batch = batch["input_ids"].size(0)
        with torch.amp.autocast("cuda", enabled=use_amp):
            outputs = model(**batch)
            loss, logits = outputs.loss, outputs.logits

        loss_sum += loss.item() * n_in_batch
        sample_count += n_in_batch

        targets = batch["labels"]
        hits, scored = causal_token_acc(logits, targets, IGNORE_INDEX)
        tok_hits += hits
        tok_total += scored
        hits, scored = causal_seq_acc(logits, targets, IGNORE_INDEX)
        seq_hits += hits
        seq_total += scored

    avg_loss = loss_sum / sample_count if sample_count else 0.0
    tok_acc = tok_hits / tok_total if tok_total else 0.0
    seq_acc = seq_hits / seq_total if seq_total else 0.0
    return avg_loss, tok_acc, seq_acc


def save_checkpoint(model, ds, save_dir: str, save_name: str = "model"):
    """Write the model and its vocabulary under ``save_dir/save_name``.

    The directory is created if needed, ``model.save_pretrained`` is called
    on it, and ``ds.itos`` (looked up by index 0..len-1) is stored next to
    the weights as the JSON list ``itos.json``.
    """
    target_dir = os.path.join(save_dir, save_name)
    os.makedirs(target_dir, exist_ok=True)
    model.save_pretrained(target_dir)

    vocab_size = len(ds.itos)
    tokens_by_id = [ds.itos[idx] for idx in range(vocab_size)]
    vocab_path = os.path.join(target_dir, "itos.json")
    with open(vocab_path, "w", encoding="utf-8") as fh:
        json.dump(tokens_by_id, fh, ensure_ascii=False)

def load_vocab(save_dir: str, save_name: str = "model") -> Tuple[Dict[str,int], Dict[int,str]]:
    """Load the (stoi, itos) vocabulary stored with a checkpoint.

    Prefers ``itos.json`` (a list of tokens indexed by id) inside
    ``save_dir/save_name``; when that file is absent, falls back to the
    ``stoi.pt`` / ``itos.pt`` pair written by older checkpoints.

    Returns:
        (stoi, itos): token → id and id → token mappings.
    """
    model_dir = os.path.join(save_dir, save_name)
    json_path = os.path.join(model_dir, "itos.json")

    if not os.path.exists(json_path):
        # Legacy layout.  NOTE: torch.load is pickle-based — only point this
        # at checkpoints from a trusted source.
        stoi = torch.load(os.path.join(model_dir, "stoi.pt"))
        itos = torch.load(os.path.join(model_dir, "itos.pt"))
        return stoi, itos

    with open(json_path, "r", encoding="utf-8") as fh:
        tokens = json.load(fh)
    itos = dict(enumerate(tokens))
    stoi = {token: idx for idx, token in itos.items()}
    return stoi, itos

def refresh_ds_special_ids(ds):
    """Recompute ``ds.pad_id/sep_id/eos_id/bos_id`` from ``ds.stoi``.

    Raises:
        ValueError: if any token in ``SPECIAL`` is absent from ``ds.stoi``.
    """
    for token in SPECIAL.values():
        if token in ds.stoi:
            continue
        raise ValueError(f"Special token {repr(token)} missing from vocab.")

    id_attributes = (
        ("pad_id", "PAD"),
        ("sep_id", "SEP"),
        ("eos_id", "EOS"),
        ("bos_id", "BOS"),
    )
    for attr_name, special_key in id_attributes:
        setattr(ds, attr_name, ds.stoi[SPECIAL[special_key]])
    
def train_gpt2(
    n_embd=128, n_layer=2, n_head=4, batch_size=32, n_epochs=10, lr=3e-4, model='mul',
    save_dir="checkpoints", save_name="saved_model", seed=42, use_amp=True, resume=False,
    ddp=False, local_rank=0, world_size=1, max_steps: int = 30000, pos_bias: bool = True,
    pos_mode: str = "beta", beta_scale: float = 1.0, gradient_checkpointing: bool = False,
    use_beta: Optional[bool] = None
):
    """Train (or resume) a LLaMA model on one task, checkpointing the best epoch.

    Builds the dataloaders for ``model`` (the task name), creates or loads a
    ``LlamaForCausalLM``, optionally patches it with β-RPE attention, then
    trains for up to ``n_epochs`` epochs / ``max_steps`` optimizer steps.
    After every epoch the three validation splits are evaluated; the
    checkpoint is saved whenever the mean sequence accuracy over them
    improves.  On rank 0 a summary line is appended to ``logs/{model}.txt``
    and a loss curve is written to ``logs/{model}_loss_curve.png``.

    Training stops early when the mean validation sequence accuracy reaches
    1.0, when val0 loss has not improved for 50 epochs, or when ``max_steps``
    is reached.

    Args:
        n_embd, n_layer, n_head: hidden size, layer count and head count of a
            freshly created model (ignored when resuming from a checkpoint).
        batch_size: forwarded to ``create_dataloaders``.
        n_epochs: maximum number of epochs.
        lr: NOTE — only written to the summary log.  The optimizer uses
            task-specific learning rates / weight decay chosen from ``model``.
        model: task name; selects the data, the optimizer hyper-parameters
            and the log file names.
        save_dir, save_name: checkpoint location (``save_dir/save_name``).
        seed: passed to ``set_seeds``.
        use_amp: enable CUDA autocast and gradient scaling.
        resume: load ``save_dir/save_name`` if it exists instead of creating
            a new model.
        ddp: wrap the model in DistributedDataParallel (NCCL backend).
        local_rank: device index / rank used for logging; overridden by the
            ``LOCAL_RANK`` environment variable when ``ddp`` is set.
        world_size: kept for interface compatibility; not used.
        max_steps: hard cap on optimizer steps.
        pos_bias: accepted for interface compatibility; not used.
        pos_mode: ``'beta'`` patches the model with β-RPE attention;
            ``'zero'`` and ``'none'`` keep the stock attention.
        beta_scale: forwarded to ``patch_llama_with_beta_pos``.
        gradient_checkpointing: enable gradient checkpointing on the model.
        use_beta: forwarded to ``patch_llama_with_beta_pos``.  When ``None``
            the value is taken from the script-level ``args.beta_predicate``
            if the module is running as a script, else defaults to ``True``.

    Returns:
        (model_ref, train_ds): the trained model (the DDP wrapper when
        ``ddp`` is set) and the training dataset.
    """
    if use_beta is None:
        # Historical behaviour: read the CLI namespace created under
        # ``__main__``.  When this function is imported and called directly
        # there is no such global, so fall back to the CLI default ('on').
        try:
            use_beta = (args.beta_predicate == 'on')
        except NameError:
            use_beta = True

    # DDP setup
    if ddp:
        if 'LOCAL_RANK' in os.environ:
            local_rank = int(os.environ['LOCAL_RANK'])
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    set_seeds(seed)

    train_dl, val_dls, train_ds, (val0_ds, val1_ds, val2_ds) = create_dataloaders(
        batch_size, model, ddp=ddp, local_rank=local_rank
    )
    val0_dl, val1_dl, val2_dl = val_dls

    # compute max sequence length across all splits
    max_len = 0
    for s in train_ds.samples + val0_ds.samples + val1_ds.samples + val2_ds.samples:
        max_len = max(max_len, len(s))
    print('Max fused length across splits:', max_len)

    # Context length: longest sample rounded up to a multiple of 64.
    ctx = int(math.ceil(max_len / 64) * 64)
    model_path = os.path.join(save_dir, save_name)

    if resume and os.path.exists(model_path):
        if local_rank == 0:
            print(f"[INFO] Loading pretrained model from {model_path}")
        model_obj = LlamaForCausalLM.from_pretrained(model_path).to(device)

        # If user asks for beta but the loaded checkpoint wasn't patched, patch now.
        if pos_mode == 'beta':
            any_beta = any(isinstance(l.self_attn, LlamaAttentionWithBetaPos)
                        for l in model_obj.model.layers)
            if not any_beta:
                patch_llama_with_beta_pos(
                    model_obj,
                    max_positions=ctx,
                    beta_scale=beta_scale,
                    enabled=True,
                    use_beta=use_beta
                )
    else:
        # Determine positioning configuration based on mode
        config = LlamaConfig(
            vocab_size=len(train_ds.stoi),
            max_position_embeddings=ctx,
            hidden_size=n_embd,
            num_hidden_layers=n_layer,
            num_attention_heads=n_head,
            intermediate_size=4 * n_embd,
            pad_token_id=train_ds.pad_id,
            eos_token_id=train_ds.eos_id,
            bos_token_id=train_ds.bos_id,
            hidden_dropout_prob=0,
            attention_probs_dropout_prob=0,
        )

        model_obj = LlamaForCausalLM(config).to(device)

        if pos_mode == 'beta':
            patch_llama_with_beta_pos(
                model_obj,
                max_positions=ctx,
                beta_scale=beta_scale,
                enabled=True,
                use_beta = use_beta
            )
            if local_rank == 0:
                print(f"[INFO] Using β-RPE positional encoding with scale {beta_scale}")
        elif pos_mode == 'zero':
            if local_rank == 0:
                print("[INFO] Using neutralized positioning (no positional bias)")
            pass  # keep original LLaMA attention with neutralized positioning
        elif pos_mode == 'none':
            if local_rank == 0:
                print("[INFO] Using no positional encoding (vanilla attention)")
            pass  # no positional encoding at all - keep vanilla attention

    # Enable gradient checkpointing if requested
    if gradient_checkpointing:
        model_obj.gradient_checkpointing_enable()
        if local_rank == 0:
            print("[INFO] Gradient checkpointing enabled to save memory")

    if ddp:
        model_ddp = DDP(model_obj, device_ids=[local_rank], output_device=local_rank)
        model_ref = model_ddp
        # Evaluate / save the unwrapped module so no collectives are involved.
        model_for_eval = model_ddp.module
    else:
        model_ref = model_obj
        model_for_eval = model_obj

    # NOTE: hyper-parameters are chosen per task; the ``lr`` argument is not
    # consulted here.
    optimizer = AdamW(model_ref.parameters(), 
                     lr=1e-4 if model == 'mul' else 5e-5 if model == 'dvd' else 3e-4,
                     weight_decay=0.03 if model == 'mul' else 0.05 if model == 'dvd' else 0.01)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    torch.backends.cudnn.benchmark = True
    # Best-effort TF32 enablement; failures are deliberately ignored.
    try:
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        pass

    best_val_avg_seq_acc = 0.0
    best_epoch = -1
    best_stats = {}
    train_losses = []
    val0_losses = []
    global_step = 0
    early_stop_flag = False
    
    # Early stopping parameters
    patience = 50  # Number of epochs to wait for improvement  
    patience_counter = 0
    best_val_loss = float('inf')

    for epoch in tqdm(range(n_epochs), disable=(ddp and local_rank != 0)):
        if ddp:
            train_dl.sampler.set_epoch(epoch)

        model_ref.train()
        run_loss = 0.0
        run_cnt = 0
        tok_n = tok_d = 0
        seq_n = seq_d = 0

        for batch in train_dl:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            bsz = batch["input_ids"].size(0)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                out = model_ref(**batch)
                loss, logits = out.loss, out.logits
            # DDP already averages gradients across ranks during backward, so
            # the loss is NOT divided by world_size here; doing so would
            # shrink the gradients a second time and under-report the
            # training loss that is logged below.
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_ref.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            if global_step >= max_steps:
                if (not ddp) or local_rank == 0:
                    print(f"Early stopping: Reached max_steps {max_steps} at epoch {epoch}, step {global_step}")
                early_stop_flag = True
                break

            run_loss += loss.item() * bsz
            run_cnt += bsz
            n, d = causal_token_acc(logits.detach(), batch["labels"], IGNORE_INDEX)
            tok_n += n; tok_d += d
            n, d = causal_seq_acc(logits.detach(), batch["labels"], IGNORE_INDEX)
            seq_n += n; seq_d += d

        train_loss = run_loss / run_cnt if run_cnt else 0.0
        train_tok = (tok_n / tok_d) if tok_d else 0.0
        train_seq = (seq_n / seq_d) if seq_d else 0.0

        val0_loss, val0_tok, val0_seq = evaluate(model_for_eval, val0_dl, device, use_amp=use_amp)
        val1_loss, val1_tok, val1_seq = evaluate(model_for_eval, val1_dl, device, use_amp=use_amp)
        val2_loss, val2_tok, val2_seq = evaluate(model_for_eval, val2_dl, device, use_amp=use_amp)

        # Set by the logging rank when an early-stop condition fires.
        stop_training = False

        if (not ddp) or local_rank == 0:
            train_losses.append(train_loss)
            val0_losses.append(val0_loss)
            
            avg_val_seq_acc = (val0_seq + val1_seq + val2_seq) / 3.0
            
            # Early stopping based on val0 loss (primary validation set)
            if val0_loss < best_val_loss:
                best_val_loss = val0_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if avg_val_seq_acc > best_val_avg_seq_acc:
                best_val_avg_seq_acc = avg_val_seq_acc
                best_epoch = epoch
                best_stats = {
                    'train_loss': train_loss, 'train_tok': train_tok, 'train_seq': train_seq,
                    'val0_loss': val0_loss, 'val0_tok': val0_tok, 'val0_seq': val0_seq,
                }
                save_checkpoint(model_for_eval, train_ds, save_dir, save_name=save_name)
                print(f"Epoch {epoch} | step {global_step} | *** NEW BEST AVG SEQ ACC: {avg_val_seq_acc:.4f} ***")
                print(f"Train | loss {train_loss:.4f} | tok_acc {train_tok:.4f} | seq_acc {train_seq:.4f}")
                print(f"Val0/1/2 | loss {val0_loss:.4f}/{val1_loss:.4f}/{val2_loss:.4f} | seq_acc {val0_seq:.4f}/{val1_seq:.4f}/{val2_seq:.4f}")
                print(f"[INFO] Saved checkpoint to {os.path.join(save_dir, save_name)}")
            else:
                print(f"Epoch {epoch} | step {global_step} | Avg seq acc: {avg_val_seq_acc:.4f} (best: {best_val_avg_seq_acc:.4f})")
                print(f"Train | loss {train_loss:.4f} | tok_acc {train_tok:.4f} | seq_acc {train_seq:.4f}")
                print(f"Val0/1/2 | loss {val0_loss:.4f}/{val1_loss:.4f}/{val2_loss:.4f} | seq_acc {val0_seq:.4f}/{val1_seq:.4f}/{val2_seq:.4f}")
                print(f"[INFO] No improvement (best avg: {best_val_avg_seq_acc:.4f} at epoch {best_epoch})")
            # Early stopping conditions
            if avg_val_seq_acc >= 1.0:
                print(f"Early stopping: Average validation seq accuracy reached 100% at epoch {epoch}.")
                stop_training = True
            elif patience_counter >= patience:
                print(f"Early stopping: No improvement in val0_loss for {patience} epochs. Best val0_loss: {best_val_loss:.4f}")
                stop_training = True

        if ddp:
            # Only the logging rank tracks patience / accuracy.  Share rank
            # 0's decision so every process leaves the loop together;
            # otherwise the remaining ranks would keep training and hang in
            # their next gradient all-reduce.
            stop_signal = torch.tensor([int(stop_training)], device=device)
            dist.broadcast(stop_signal, src=0)
            stop_training = bool(stop_signal.item())

        if stop_training or early_stop_flag or (global_step >= max_steps):
            break

    # Post-training logging and plotting
    if (not ddp) or local_rank == 0:
        summary_path = os.path.join('logs', f"{model}.txt")
        os.makedirs(os.path.dirname(summary_path), exist_ok=True)

        # Re-evaluate val1 and val2 on final model
        val1_loss, val1_tok, val1_seq = evaluate(model_for_eval, val1_dl, device, use_amp=use_amp)
        val2_loss, val2_tok, val2_seq = evaluate(model_for_eval, val2_dl, device, use_amp=use_amp)

        with open(summary_path, "a") as f:
            # NOTE: val0 comes from the best epoch while val1/val2 come from
            # the final model evaluated just above.
            avg_final_seq_acc = (best_stats.get('val0_seq', 0) + val1_seq + val2_seq) / 3.0
            # Save individual validation accuracies for all three validation sets
            f.write(
                f'{n_embd}_{n_layer}_{n_head}_{lr}: '
                f"Best epoch: {best_epoch} | Best avg seq_acc: {best_val_avg_seq_acc:.4f} | "
                f"Train loss: {best_stats.get('train_loss', 0):.4f} | "
                f"Train tok_acc: {best_stats.get('train_tok', 0):.4f} | "
                f"Train seq_acc: {best_stats.get('train_seq', 0):.4f} | "
                f"Val0 seq_acc: {best_stats.get('val0_seq', 0):.4f} | "
                f"Val1 seq_acc: {val1_seq:.4f} | "
                f"Val2 seq_acc: {val2_seq:.4f} | "
                f"Final avg seq_acc: {avg_final_seq_acc:.4f}\n"
            )

        plt.figure()
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val0_losses, label='Val0 Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'Training and Validation Loss ({model})')
        plt.legend()
        plt.grid(True)
        plot_path = os.path.join('logs', f"{model}_loss_curve.png")
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        plt.savefig(plot_path)
        plt.close()
        print(f"[INFO] Saved loss curve to {plot_path}")

    if ddp:
        dist.destroy_process_group()
    return model_ref, train_ds

# ----------------------
# Inference (greedy)
# ----------------------
@torch.no_grad()
def greedy_decode(
    model, ds, input_str: str,
    max_new_tokens: Optional[int] = None,
    stop_on_eos: bool = True,
    min_tokens: int = 1
):
    """Greedily generate the continuation of ``BOS + input_str + SEP``.

    One token is produced per step by taking the argmax of the logits at the
    last position and re-running the model on the extended sequence.
    Generation ends when EOS is predicted (only once ``min_tokens`` tokens
    have been produced and ``stop_on_eos`` is set), after ``max_new_tokens``
    steps, or when the sequence length reaches the model's
    ``max_position_embeddings`` (2048 if the config does not define it).

    Args:
        model: causal LM exposing ``config`` and returning ``logits``.
        ds: dataset object providing ``stoi`` / ``itos``; its special-token
            ids are refreshed before decoding.
        input_str: raw input characters; each must be in ``ds.stoi``.
        max_new_tokens: step budget; defaults to the room left in the context.
        stop_on_eos: stop (without appending) when EOS is predicted.
        min_tokens: number of tokens to emit before EOS may stop decoding.

    Returns:
        (full_decoded, pred_output): the whole decoded sequence, and the part
        between the first SEP and the first following EOS (if any).

    Raises:
        ValueError: on out-of-vocabulary characters or a prefix that already
            fills the context.
    """
    model.eval()
    device = next(model.parameters()).device
    refresh_ds_special_ids(ds)

    # Reject characters the vocabulary does not know
    unknown = [ch for ch in input_str if ch not in ds.stoi]
    if unknown:
        raise ValueError(f"OOV chars {unknown}. Not in training vocab: {unknown}")

    # Prompt layout: BOS, the input, then SEP
    prompt = SPECIAL["BOS"] + input_str + SPECIAL["SEP"]
    prompt_ids = [ds.stoi[ch] for ch in prompt]
    x = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    attn = torch.ones_like(x, device=device)

    max_ctx = getattr(model.config, 'max_position_embeddings', 2048)
    if x.size(1) >= max_ctx:
        raise ValueError(f"Input prefix length {x.size(1)} exceeds model capacity max_ctx={max_ctx}.")

    budget = max_ctx - x.size(1) if max_new_tokens is None else max_new_tokens

    # ``step`` equals the number of tokens appended so far, because every
    # iteration that does not break appends exactly one token.
    for step in range(budget):
        logits = model(input_ids=x, attention_mask=attn).logits
        next_id = logits[:, -1, :].argmax(dim=-1)

        if stop_on_eos and step >= min_tokens and next_id.item() == ds.eos_id:
            break

        x = torch.cat([x, next_id[:, None]], dim=1)
        # The mask is all ones throughout, so rebuild it at the new length.
        attn = torch.ones_like(x)

        if x.size(1) >= max_ctx:
            break

    full_decoded = "".join(ds.itos[int(tok)] for tok in x[0].tolist())
    # Text after the first SEP (empty when SEP is absent) ...
    after_sep = full_decoded.partition(SPECIAL["SEP"])[2]
    # ... truncated at the first EOS (unchanged when EOS is absent).
    pred_output = after_sep.partition(SPECIAL["EOS"])[0]
    return full_decoded, pred_output


if __name__ == "__main__":
    # ---- command line -------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='mul', choices=['mul','dvd', 'exp', 'prime', 'gcd', 'add'])
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'inference'])
    parser.add_argument('--save_dir', type=str, default='checkpoints')
    parser.add_argument('--save_name', type=str, default='saved_model')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--nlayer', type=int, default=6)
    parser.add_argument('--nembd', type=int, default=384)  # larger default aimed at the harder tasks
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision')
    parser.add_argument('--resume', action='store_true', help='Continue training from checkpoints/{name}')
    parser.add_argument('--ddp', action='store_true', help='Enable DistributedDataParallel training')
    parser.add_argument('--local_rank', type=int, default=0, help='Local rank for DDP')
    parser.add_argument('--world_size', type=int, default=1, help='World size for DDP')
    parser.add_argument('--max_steps', type=int, default=30000, help='Maximum training steps before stopping')
    # Positional encoding controls
    parser.add_argument('--pos_bias', type=str, default='on', choices=['on','off'],
                        help='Add a learned NxN positional bias to attention logits')
    parser.add_argument('--pos_mode', type=str, default='beta',
                    choices=['zero','beta','none'],
                    help='zero (neutralize positioning), beta (β-RPE bias), or none (no positional encoding)')
    parser.add_argument('--beta_scale', type=float, default=1.0,
                        help='Initial scale λ for the β positional bias')
    parser.add_argument('--beta_predicate', type=str, default='on', choices=['on','off'],
                        help='Add λ·[Rβ] bias (on) or not (off)')
    parser.add_argument('--longer_training', action='store_true',
                        help='Use longer training with more steps for complex problems')
    parser.add_argument('--gradient_checkpointing', action='store_true',
                        help='Enable gradient checkpointing to save memory')
    # ``args`` must remain a module-level name: train_gpt2 reads it.
    args = parser.parse_args()

    use_amp = not args.no_amp
    set_seeds(args.seed)

    # ---- per-task input alphabet and oracle builder --------------------
    slash_alphabet_tasks = ('mul', 'dvd', 'gcd', 'exp')
    model_config = {task: (['1','0','/'], f'build_{task}_binary_kcm') for task in slash_alphabet_tasks}
    model_config['prime'] = (['1','0'], 'build_prime_binary_kcm')
    model_config['add'] = (['1','0','/'], 'build_add_binary_kcm')

    alp, kcm_func_name = model_config.get(args.model, ([], None))

    import kcm_binary as kcm
    # The oracle is used to print the gold answer next to each prediction.
    auto = getattr(kcm, kcm_func_name)() if kcm_func_name else None

    if args.mode == 'train':
        # --longer_training raises the step cap to at least 120000.
        max_steps = max(args.max_steps, 120000) if args.longer_training else args.max_steps
        model_ref, ds = train_gpt2(
            n_embd=args.nembd, n_layer=args.nlayer, n_head=args.nhead,
            batch_size=args.batch_size, n_epochs=args.epochs, lr=args.lr, model=args.model,
            save_dir=args.save_dir, save_name=args.save_name,
            seed=args.seed, use_amp=use_amp, resume=args.resume,
            ddp=args.ddp, local_rank=args.local_rank, world_size=args.world_size,
            pos_bias=(args.pos_bias == 'on'), max_steps=max_steps,
            pos_mode=args.pos_mode, beta_scale=args.beta_scale,
            gradient_checkpointing=args.gradient_checkpointing
        )
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = os.path.join(args.save_dir, args.save_name)

        # NOTE(review): unlike the resume path in train_gpt2, this branch
        # never calls patch_llama_with_beta_pos after loading — confirm that
        # checkpoints trained with --pos_mode beta are restored with their
        # β-RPE attention here.
        model = LlamaForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
        model.eval()

        stoi, itos = load_vocab(args.save_dir, args.save_name)
        ds = ARLMGptDataset([""], [""], stoi, itos)

        # Interactive loop: read a line, validate it, decode, print.
        prompt_text = f"Input string (only {','.join(alp)}). Empty to quit: "
        while True:
            try:
                s = input(prompt_text).strip()
            except EOFError:
                break
            if not s:
                break
            if any(ch not in alp for ch in s):
                print(f"Invalid input. Only {alp} allowed.")
                continue
            full, yhat = greedy_decode(model, ds, s)
            if auto is None:
                print("pred:", yhat)
            else:
                gold = auto.output_generator(s, alp)
                print("pred:", yhat, "\ngold:", gold)
