#!/usr/bin/env python
"""
Train a LLaMA on sequence prediction tasks with relative positional encodings (RPE).

This script:
  * Loads KCM-based datasets (mul, dvd, exp, prime, gcd, add)
  * Builds a small LLaMA
  * Optionally patches its attention with generic RPE as in the paper
    (via LlamaAttentionWithRPE from rpe_llama.py)
  * Trains with early stopping and logs results
"""

import os
import argparse
import random
import json
import math
from typing import Dict, Optional, Tuple

import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
import matplotlib.pyplot as plt
from tqdm import tqdm

from dataloder import *    # expects create_dataloaders, IGNORE_INDEX or dataset classes
from transformers import LlamaConfig, LlamaForCausalLM

# --- NEW: import generic RPE machinery (no β-specific stuff) -----------------
from rpe_llama import (
    build_relation_matrix,
    LlamaAttentionWithRPE,
    replace_llama_attn_with_rpe,
)

# Defaults for names that the star-import from `dataloder` may or may not have
# provided. A default is installed only when the name is not already defined at
# module level, so values exported by `dataloder` always win.
if "IGNORE_INDEX" not in globals():
    # Label value that the accuracy helpers below treat as "not supervised".
    IGNORE_INDEX = -100

if "SPECIAL" not in globals():
    # Role -> literal token string for the special tokens used by this script.
    SPECIAL = dict(
        BOS="$",
        EOS="#",
        PAD="<PAD>",
        SEP="|",
    )


# ===================== Utils =====================

def set_seeds(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, the PyTorch CPU generator and the
    generators of all CUDA devices, and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Exported for child processes; the hash seed of the running interpreter
    # is fixed at start-up and is not changed by this assignment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


def causal_token_acc(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Count correct next-token predictions for a causal language model.

    The logits at position ``t`` are compared against the label at position
    ``t + 1``; positions whose (shifted) label equals ``ignore_index`` are
    excluded from both counts.

    Args:
        logits: Model outputs, indexed as ``[batch, position, vocab]``.
        labels: Target ids, indexed as ``[batch, position]``.
        ignore_index: Label value marking positions that must not be scored.

    Returns:
        ``(num_correct, num_scored)`` as Python ints.
    """
    targets = labels[:, 1:]                       # label for the *next* position
    predicted = logits[:, :-1, :].argmax(dim=-1)  # last position has no target
    scored = targets != ignore_index
    hits = torch.logical_and(predicted == targets, scored)
    return int(hits.sum()), int(scored.sum())


def causal_seq_acc(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Count sequences whose supervised tokens are all predicted correctly.

    The logits at position ``t`` are compared against the label at position
    ``t + 1``. A sequence counts as correct only if *every* position whose
    shifted label differs from ``ignore_index`` is predicted exactly.
    Sequences with no supervised position are skipped entirely (they add to
    neither count).

    Args:
        logits: Model outputs, indexed as ``[batch, position, vocab]``.
        labels: Target ids, indexed as ``[batch, position]``.
        ignore_index: Label value marking positions that must not be scored.

    Returns:
        ``(num_correct_sequences, num_scored_sequences)`` as Python ints.
    """
    preds = logits[:, :-1, :].argmax(dim=-1)
    targets = labels[:, 1:]
    mask = targets.ne(ignore_index)

    # Rows that contain at least one supervised position.
    scored_rows = mask.any(dim=1)
    # A row is right when each position is either ignored or predicted
    # correctly. Reducing along the sequence axis replaces the per-row Python
    # loop, which forced one host synchronisation per sample.
    row_correct = (preds.eq(targets) | ~mask).all(dim=1)

    num = int((row_correct & scored_rows).sum().item())
    den = int(scored_rows.sum().item())
    return num, den


# ===================== Eval =====================

@torch.no_grad()
def evaluate(model, dl, device, use_amp: bool = True):
    """Run one evaluation pass over ``dl`` and aggregate loss and accuracies.

    The model is switched to eval mode and is *not* switched back; callers
    that keep training must call ``.train()`` themselves.

    Args:
        model: Callable invoked as ``model(**batch)``; its result must expose
            ``.loss`` and ``.logits``.
        dl: Iterable of batches; each batch is a dict of tensors containing at
            least ``"input_ids"`` and ``"labels"``.
        device: Device the batch tensors are moved to.
        use_amp: Request CUDA autocast for the forward pass. Ignored when
            ``device`` is not a CUDA device.

    Returns:
        ``(avg_loss, token_accuracy, sequence_accuracy)``. The loss is the
        batch-size-weighted mean of the per-batch losses; each value is 0.0
        when there was nothing to average over.
    """
    model.eval()

    # Only ask for CUDA autocast when actually running on CUDA. On a CPU-only
    # run the hard-coded "cuda" autocast is disabled by PyTorch anyway, but
    # with a warning.
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"

    total_loss = 0.0
    total_count = 0
    tok_n = tok_d = 0
    seq_n = seq_d = 0

    for batch in dl:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        bsz = batch["input_ids"].size(0)
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            out = model(**batch)
            loss, logits = out.loss, out.logits

        # Weight each batch loss by its size so a short final batch does not
        # skew the average.
        total_loss += loss.item() * bsz
        total_count += bsz

        n, d = causal_token_acc(logits, batch["labels"], IGNORE_INDEX)
        tok_n += n
        tok_d += d
        n, d = causal_seq_acc(logits, batch["labels"], IGNORE_INDEX)
        seq_n += n
        seq_d += d

    avg_loss = (total_loss / total_count) if total_count else 0.0
    tok_acc = (tok_n / tok_d) if tok_d else 0.0
    seq_acc = (seq_n / seq_d) if seq_d else 0.0
    return avg_loss, tok_acc, seq_acc


# ===================== Checkpoints & vocab =====================

def save_checkpoint(model, ds, save_dir: str, save_name: str = "model"):
    """Write the model and its vocabulary to ``<save_dir>/<save_name>``.

    The directory is created if needed. The model is stored through its own
    ``save_pretrained`` method; the vocabulary is stored next to it as
    ``itos.json``, a JSON list whose position ``i`` holds ``ds.itos[i]``.

    Args:
        model: Object exposing ``save_pretrained(directory)``.
        ds: Object whose ``itos`` maps the ids ``0 .. len(itos) - 1`` to tokens.
        save_dir: Parent directory for checkpoints.
        save_name: Sub-directory name for this checkpoint.
    """
    target_dir = os.path.join(save_dir, save_name)
    os.makedirs(target_dir, exist_ok=True)
    model.save_pretrained(target_dir)

    # Materialise the id -> token mapping in id order so the JSON list index
    # doubles as the token id.
    vocab_size = len(ds.itos)
    ordered_tokens = [ds.itos[idx] for idx in range(vocab_size)]

    vocab_path = os.path.join(target_dir, "itos.json")
    with open(vocab_path, "w", encoding="utf-8") as fh:
        json.dump(ordered_tokens, fh, ensure_ascii=False)


def load_vocab(save_dir: str, save_name: str = "model") -> Tuple[Dict[str, int], Dict[int, str]]:
    """Load the vocabulary stored with a checkpoint.

    Prefers ``itos.json`` (a JSON list of tokens in id order). When that file
    does not exist, falls back to the ``stoi.pt`` / ``itos.pt`` files in the
    same directory, read with ``torch.load``.

    Args:
        save_dir: Parent directory for checkpoints.
        save_name: Sub-directory name of the checkpoint.

    Returns:
        ``(stoi, itos)``: token -> id and id -> token mappings.
    """
    ckpt_dir = os.path.join(save_dir, save_name)
    json_path = os.path.join(ckpt_dir, "itos.json")

    if not os.path.exists(json_path):
        # Old checkpoints: both mappings were serialised with torch.save.
        return (
            torch.load(os.path.join(ckpt_dir, "stoi.pt")),
            torch.load(os.path.join(ckpt_dir, "itos.pt")),
        )

    with open(json_path, "r", encoding="utf-8") as fh:
        tokens = json.load(fh)

    # List position is the token id.
    itos = dict(enumerate(tokens))
    stoi = {tok: idx for idx, tok in enumerate(tokens)}
    return stoi, itos


def refresh_ds_special_ids(ds):
    """Recompute the special-token ids of ``ds`` from its current vocabulary.

    Sets ``ds.pad_id``, ``ds.sep_id``, ``ds.eos_id`` and ``ds.bos_id`` by
    looking the ``SPECIAL`` tokens up in ``ds.stoi``.

    Raises:
        ValueError: If any token listed in ``SPECIAL`` is absent from
            ``ds.stoi``; no attribute is modified in that case.
    """
    # Validate everything first so a partial update can never happen.
    absent = next((tok for tok in SPECIAL.values() if tok not in ds.stoi), None)
    if absent is not None:
        raise ValueError(f"Special token {repr(absent)} missing from vocab.")

    for attr_name, role in (
        ("pad_id", "PAD"),
        ("sep_id", "SEP"),
        ("eos_id", "EOS"),
        ("bos_id", "BOS"),
    ):
        setattr(ds, attr_name, ds.stoi[SPECIAL[role]])


# ===================== Training =====================

def train_gpt2(
    n_embd=128, n_layer=2, n_head=4, batch_size=32, n_epochs=10, lr=3e-4, model='mul',
    save_dir="checkpoints", save_name="saved_model", seed=42, use_amp=True, resume=False,
    ddp=False, local_rank=0, world_size=1, max_steps: int = 30000, pos_bias: bool = True,
    pos_mode: str = "beta", beta_scale: float = 1.0, gradient_checkpointing: bool = False,
    rpe_window: int = 3, use_rpe: Optional[bool] = None,
):
    """Build (or resume) a small LLaMA, train it, and log the best epoch.

    Despite the name, the model is a ``LlamaForCausalLM``. Each epoch runs one
    pass over the training loader followed by evaluation on three validation
    splits. The checkpoint is saved whenever the mean validation sequence
    accuracy improves. Training stops when ``max_steps`` optimizer steps were
    taken, the mean validation sequence accuracy reaches 1.0, the first
    validation loss has not improved for 50 epochs, or ``n_epochs`` is
    exhausted.

    Args:
        n_embd: Hidden size; the MLP intermediate size is ``4 * n_embd``.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        batch_size: Passed to ``create_dataloaders``.
        n_epochs: Maximum number of epochs.
        lr: Only written to the summary log. The optimizer's learning rate
            and weight decay are selected from ``model`` (task name).
        model: Task name; passed to ``create_dataloaders`` and used for the
            optimizer settings and log file names.
        save_dir: Parent directory for checkpoints.
        save_name: Checkpoint sub-directory name.
        seed: Seed passed to ``set_seeds``.
        use_amp: Enable CUDA autocast and gradient scaling.
        resume: Load ``<save_dir>/<save_name>`` if it exists instead of
            creating a fresh model.
        ddp: Wrap the model in ``DistributedDataParallel`` (NCCL backend).
        local_rank: Local rank; overridden by ``$LOCAL_RANK`` under DDP.
        world_size: Under DDP the loss is divided by this value.
        max_steps: Hard cap on optimizer steps.
        pos_bias: Accepted for compatibility; not used by this function.
        pos_mode: ``"beta"`` patches attention with RPE; ``"zero"`` and
            ``"none"`` leave the model unpatched.
        beta_scale: Initial value passed as ``lambda_init`` to the RPE patch.
        gradient_checkpointing: Enable gradient checkpointing on the model.
        rpe_window: Window ``W`` of the relation ``|i - j| <= W``.
        use_rpe: Forwarded to the RPE patch. When ``None`` it is derived from
            a module-level ``args.beta_predicate`` if such a global exists
            (the command-line path), and defaults to ``True`` otherwise.

    Returns:
        ``(model_ref, train_ds)`` where ``model_ref`` is the DDP wrapper when
        ``ddp`` is set and the bare model otherwise.
    """
    # Historically this value was read from the module-global `args`, which
    # only exists when the file runs as a script; importing and calling this
    # function raised NameError. Keep the old source of truth as a fallback.
    if use_rpe is None:
        cli_args = globals().get("args")
        use_rpe = getattr(cli_args, "beta_predicate", "on") == "on"

    # DDP setup
    if ddp:
        if 'LOCAL_RANK' in os.environ:
            local_rank = int(os.environ['LOCAL_RANK'])
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    set_seeds(seed)

    train_dl, val_dls, train_ds, (val0_ds, val1_ds, val2_ds) = create_dataloaders(
        batch_size, model, ddp=ddp, local_rank=local_rank
    )
    val0_dl, val1_dl = val_dls[0], val_dls[1]
    val2_dl = val_dls[2]

    # Longest sample across all splits (0 when every split is empty).
    max_len = max(
        (len(s) for split in (train_ds, val0_ds, val1_ds, val2_ds) for s in split.samples),
        default=0,
    )
    print('Max fused length across splits:', max_len)

    # round context up to multiple of 64
    ctx = int(math.ceil(max_len / 64) * 64)
    model_path = os.path.join(save_dir, save_name)

    # Build relation matrix R once (for this ctx)
    # Default: local window |i-j| <= rpe_window
    relation = build_relation_matrix(
        ctx,
        lambda i, j, W=rpe_window: abs(i - j) <= W,
    )

    # Create model or resume
    if resume and os.path.exists(model_path):
        if local_rank == 0:
            print(f"[INFO] Loading pretrained model from {model_path}")
        model_obj = LlamaForCausalLM.from_pretrained(model_path).to(device)

        # If user asks for RPE but the loaded checkpoint wasn't patched, patch now.
        if pos_mode == 'beta':  # "beta" now means "generic RPE" for compatibility
            any_rpe = any(
                isinstance(l.self_attn, LlamaAttentionWithRPE)
                for l in model_obj.model.layers
            )
            if not any_rpe:
                replace_llama_attn_with_rpe(
                    model_obj,
                    relation=relation,
                    lambda_init=beta_scale,
                    use_rpe=use_rpe,
                    use_log_n_scaling=True,
                )
    else:
        # Determine positioning configuration based on mode
        config = LlamaConfig(
            vocab_size=len(train_ds.stoi),
            max_position_embeddings=ctx,
            hidden_size=n_embd,
            num_hidden_layers=n_layer,
            num_attention_heads=n_head,
            intermediate_size=4 * n_embd,
            pad_token_id=train_ds.pad_id,
            eos_token_id=train_ds.eos_id,
            bos_token_id=train_ds.bos_id,
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )

        model_obj = LlamaForCausalLM(config).to(device)

        if pos_mode == 'beta':
            # Use the generic RPE model from rpe_llama (no beta-specific encoding)
            replace_llama_attn_with_rpe(
                model_obj,
                relation=relation,
                lambda_init=beta_scale,
                use_rpe=use_rpe,
                use_log_n_scaling=True,
            )
            if local_rank == 0:
                print(f"[INFO] Using generic RPE with λ init {beta_scale}, window {rpe_window}")
        elif pos_mode == 'zero':
            if local_rank == 0:
                print("[INFO] Using neutralized positioning (no extra positional bias)")
            # (If you also neutralize rotary embeddings, do that in your Llama config / model.)
        elif pos_mode == 'none':
            if local_rank == 0:
                print("[INFO] Using vanilla attention (no RPE patch)")

    # Enable gradient checkpointing if requested
    if gradient_checkpointing:
        model_obj.gradient_checkpointing_enable()
        if local_rank == 0:
            print("[INFO] Gradient checkpointing enabled to save memory")

    if ddp:
        model_ddp = DDP(model_obj, device_ids=[local_rank], output_device=local_rank)
        model_ref = model_ddp
        # Evaluate and save through the unwrapped module.
        model_for_eval = model_ddp.module
    else:
        model_ref = model_obj
        model_for_eval = model_obj

    # LR and weight decay depending on task.
    # NOTE(review): the `lr` argument is not used here; it only appears in the
    # summary log below. Confirm whether the per-task values should win.
    optimizer = AdamW(
        model_ref.parameters(),
        lr=(1e-4 if model == 'mul' else 5e-5 if model == 'dvd' else 3e-4),
        weight_decay=(0.03 if model == 'mul' else 0.05 if model == 'dvd' else 0.01)
    )
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        # Best effort: these switches are optional speed-ups.
        pass

    best_val_avg_seq_acc = 0.0
    best_epoch = -1
    best_stats: Dict[str, float] = {}
    train_losses = []
    val0_losses = []
    global_step = 0
    early_stop_flag = False

    # Early stopping parameters
    patience = 50  # Number of epochs to wait for improvement
    patience_counter = 0
    best_val_loss = float('inf')

    for epoch in tqdm(range(n_epochs), disable=(ddp and local_rank != 0)):
        if ddp:
            train_dl.sampler.set_epoch(epoch)

        model_ref.train()
        run_loss = 0.0
        run_cnt = 0
        tok_n = tok_d = 0
        seq_n = seq_d = 0

        for batch in train_dl:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            bsz = batch["input_ids"].size(0)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                out = model_ref(**batch)
                loss, logits = out.loss, out.logits
            if ddp:
                # NOTE(review): DistributedDataParallel already averages
                # gradients across ranks, so this extra division looks like
                # double scaling (and it also shrinks the logged train loss).
                # Left unchanged to keep training numerics as they were.
                loss = loss / world_size
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_ref.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            if global_step >= max_steps:
                if (not ddp) or local_rank == 0:
                    print(f"Early stopping: Reached max_steps {max_steps} at epoch {epoch}, step {global_step}")
                early_stop_flag = True
                break

            run_loss += loss.item() * bsz
            run_cnt += bsz
            n, d = causal_token_acc(logits.detach(), batch["labels"], IGNORE_INDEX)
            tok_n += n
            tok_d += d
            n, d = causal_seq_acc(logits.detach(), batch["labels"], IGNORE_INDEX)
            seq_n += n
            seq_d += d

        train_loss = run_loss / run_cnt if run_cnt else 0.0
        train_tok = (tok_n / tok_d) if tok_d else 0.0
        train_seq = (seq_n / seq_d) if seq_d else 0.0

        val0_loss, val0_tok, val0_seq = evaluate(model_for_eval, val0_dl, device, use_amp=use_amp)
        val1_loss, val1_tok, val1_seq = evaluate(model_for_eval, val1_dl, device, use_amp=use_amp)
        val2_loss, val2_tok, val2_seq = evaluate(model_for_eval, val2_dl, device, use_amp=use_amp)

        # Collected here and acted on after the rank-0 block so that every
        # rank leaves the epoch loop together.
        stop_training = early_stop_flag or (global_step >= max_steps)

        if (not ddp) or local_rank == 0:
            train_losses.append(train_loss)
            val0_losses.append(val0_loss)

            avg_val_seq_acc = (val0_seq + val1_seq + val2_seq) / 3.0

            # Early stopping based on val0 loss (primary validation set)
            if val0_loss < best_val_loss:
                best_val_loss = val0_loss
                patience_counter = 0
            else:
                patience_counter += 1

            improved = avg_val_seq_acc > best_val_avg_seq_acc
            if improved:
                best_val_avg_seq_acc = avg_val_seq_acc
                best_epoch = epoch
                best_stats = {
                    'train_loss': train_loss,
                    'train_tok': train_tok,
                    'train_seq': train_seq,
                    'val0_loss': val0_loss,
                    'val0_tok': val0_tok,
                    'val0_seq': val0_seq,
                    'val1_loss': val1_loss,
                    'val1_tok': val1_tok,
                    'val1_seq': val1_seq,
                    'val2_loss': val2_loss,
                    'val2_tok': val2_tok,
                    'val2_seq': val2_seq,
                    'avg_val_seq_acc': avg_val_seq_acc,
                }
                save_checkpoint(model_for_eval, train_ds, save_dir, save_name=save_name)
                print(f"Epoch {epoch} | step {global_step} | *** NEW BEST AVG SEQ ACC: {avg_val_seq_acc:.4f} ***")
            else:
                print(
                    f"Epoch {epoch} | step {global_step} | "
                    f"Avg seq acc: {avg_val_seq_acc:.4f} (best: {best_val_avg_seq_acc:.4f})"
                )

            print(f"Train | loss {train_loss:.4f} | tok_acc {train_tok:.4f} | seq_acc {train_seq:.4f}")
            print(
                "Val0/1/2 | loss "
                f"{val0_loss:.4f}/{val1_loss:.4f}/{val2_loss:.4f} | "
                f"seq_acc {val0_seq:.4f}/{val1_seq:.4f}/{val2_seq:.4f}"
            )

            if improved:
                print(f"[INFO] Saved checkpoint to {os.path.join(save_dir, save_name)}")
            else:
                print(
                    "[INFO] No improvement "
                    f"(best avg: {best_val_avg_seq_acc:.4f} at epoch {best_epoch})"
                )

            # Early stopping conditions
            if avg_val_seq_acc >= 1.0:
                print(f"Early stopping: Average validation seq accuracy reached 100% at epoch {epoch}.")
                stop_training = True
            elif patience_counter >= patience:
                print(
                    "Early stopping: No improvement in val0_loss "
                    f"for {patience} epochs. Best val0_loss: {best_val_loss:.4f}"
                )
                stop_training = True

        if ddp:
            # Only rank 0 evaluates the accuracy/patience criteria. Share its
            # decision; otherwise the remaining ranks would start the next
            # epoch and block forever in the gradient all-reduce.
            stop_flag = torch.tensor([int(stop_training)], device=device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    # Post-training logging and plotting
    if (not ddp) or local_rank == 0:
        summary_path = os.path.join('logs', f"{model}.txt")
        os.makedirs(os.path.dirname(summary_path), exist_ok=True)

        # Use the stored best stats
        avg_best_seq_acc = best_stats.get('avg_val_seq_acc', best_val_avg_seq_acc)
        avg_best_seq_acc_from_splits = (
            best_stats.get('val0_seq', 0.0)
            + best_stats.get('val1_seq', 0.0)
            + best_stats.get('val2_seq', 0.0)
        ) / 3.0 if best_stats else 0.0

        with open(summary_path, "a") as f:
            f.write(
                f'{n_embd}_{n_layer}_{n_head}_{lr}: '
                f"Best epoch: {best_epoch} | "
                f"Best avg seq_acc: {avg_best_seq_acc:.4f} | "
                f"Train loss: {best_stats.get('train_loss', 0.0):.4f} | "
                f"Train tok_acc: {best_stats.get('train_tok', 0.0):.4f} | "
                f"Train seq_acc: {best_stats.get('train_seq', 0.0):.4f} | "
                f"Val0 loss: {best_stats.get('val0_loss', 0.0):.4f} | "
                f"Val1 loss: {best_stats.get('val1_loss', 0.0):.4f} | "
                f"Val2 loss: {best_stats.get('val2_loss', 0.0):.4f} | "
                f"Val0 seq_acc: {best_stats.get('val0_seq', 0.0):.4f} | "
                f"Val1 seq_acc: {best_stats.get('val1_seq', 0.0):.4f} | "
                f"Val2 seq_acc: {best_stats.get('val2_seq', 0.0):.4f} | "
                f"Avg-best seq_acc (from splits): {avg_best_seq_acc_from_splits:.4f}\n"
            )

        # Loss curves
        plt.figure()
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val0_losses, label='Val0 Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'Training and Validation Loss ({model})')
        plt.legend()
        plt.grid(True)
        plot_path = os.path.join('logs', f"{model}_loss_curve.png")
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        plt.savefig(plot_path)
        plt.close()
        print(f"[INFO] Saved loss curve to {plot_path}")

    if ddp:
        dist.destroy_process_group()
    return model_ref, train_ds


# ===================== Inference (greedy) =====================

@torch.no_grad()
def greedy_decode(
    model, ds, input_str: str,
    max_new_tokens: Optional[int] = None,
    stop_on_eos: bool = True,
    min_tokens: int = 1
):
    """Greedily continue the prompt ``BOS + input_str + SEP`` one token at a time.

    The whole sequence is fed through the model again at every step (no
    key/value cache is used).

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)``
            and returning an object with ``.logits``; ``model.config`` may
            provide ``max_position_embeddings`` (2048 is assumed otherwise).
        ds: Vocabulary holder with ``stoi`` / ``itos``; its special-token ids
            are refreshed from ``SPECIAL`` before decoding.
        input_str: Raw input; every character must be a key of ``ds.stoi``.
        max_new_tokens: Cap on generated tokens; defaults to the remaining
            context capacity.
        stop_on_eos: Stop when EOS is predicted (EOS is then not appended).
        min_tokens: Number of tokens that must be generated before an EOS
            prediction is allowed to stop decoding.

    Returns:
        ``(full_decoded, pred_output)``: the entire decoded sequence, and the
        text between the first SEP and the following EOS (or the end).

    Raises:
        ValueError: On out-of-vocabulary characters, on a missing special
            token, or when the prompt already fills the model context.
    """
    model.eval()
    device = next(model.parameters()).device
    refresh_ds_special_ids(ds)

    # Reject characters the vocabulary cannot encode.
    oov = [ch for ch in input_str if ch not in ds.stoi]
    if oov:
        raise ValueError(f"OOV chars {oov}. Not in training vocab: {oov}")

    # Prompt layout: $ INPUT |
    prompt = SPECIAL["BOS"] + input_str + SPECIAL["SEP"]
    token_ids = [ds.stoi[ch] for ch in prompt]
    x = torch.tensor([token_ids], dtype=torch.long, device=device)

    max_ctx = getattr(model.config, 'max_position_embeddings', 2048)
    prompt_len = x.size(1)
    if prompt_len >= max_ctx:
        raise ValueError(
            f"Input prefix length {prompt_len} exceeds model capacity max_ctx={max_ctx}."
        )

    budget = (max_ctx - prompt_len) if max_new_tokens is None else max_new_tokens

    # `step` equals the number of tokens generated so far.
    for step in range(budget):
        # Every position is real (no padding), so the mask is all ones.
        attn = torch.ones_like(x)
        step_logits = model(input_ids=x, attention_mask=attn).logits
        next_id = step_logits[:, -1, :].argmax(dim=-1)

        if stop_on_eos and step >= min_tokens and next_id.item() == ds.eos_id:
            break

        x = torch.cat((x, next_id[:, None]), dim=1)

        if x.size(1) >= max_ctx:
            break

    full_decoded = "".join(ds.itos[int(tok)] for tok in x[0].tolist())
    # partition() yields an empty tail when SEP is absent and the whole string
    # as head when EOS is absent.
    after_sep = full_decoded.partition(SPECIAL["SEP"])[2]
    pred_output = after_sep.partition(SPECIAL["EOS"])[0]
    return full_decoded, pred_output


# ===================== Main =====================

# Command-line entry point: parse options, then either train a model or start
# an interactive inference loop on a saved checkpoint.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='mul',
                        choices=['mul', 'dvd', 'exp', 'prime', 'gcd', 'add'])
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'inference'])
    parser.add_argument('--save_dir', type=str, default='checkpoints')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--nlayer', type=int, default=2)
    parser.add_argument('--nembd', type=int, default=384)  # default for harder tasks
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision')
    parser.add_argument('--resume', action='store_true',
                        help='Continue training from checkpoints/{name}')
    parser.add_argument('--ddp', action='store_true',
                        help='Enable DistributedDataParallel training')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='Local rank for DDP')
    parser.add_argument('--world_size', type=int, default=1,
                        help='World size for DDP')
    parser.add_argument('--max_steps', type=int, default=30000,
                        help='Maximum training steps before stopping')
    # Positional encoding controls
    parser.add_argument('--pos_bias', type=str, default='on',
                        choices=['on', 'off'],
                        help='(Kept for compatibility) Learned NxN positional bias in llm.py')
    parser.add_argument('--pos_mode', type=str, default='beta',
                        choices=['zero', 'beta', 'none'],
                        help='zero (neutralize positioning), '
                             'beta (generic RPE bias), or none (no positional encoding)')
    parser.add_argument('--beta_scale', type=float, default=1.0,
                        help='Initial scale λ for the RPE bias')
    parser.add_argument('--beta_predicate', type=str, default='on',
                        choices=['on', 'off'],
                        help='Add λ·[[R]](i,j) bias (on) or not (off)')
    parser.add_argument('--rpe_window', type=int, default=3,
                        help='Window size W for relation |i-j| ≤ W')
    parser.add_argument('--longer_training', action='store_true',
                        help='Use longer training with more steps for complex problems')
    parser.add_argument('--gradient_checkpointing', action='store_true',
                        help='Enable gradient checkpointing to save memory')
    # NOTE: `args` is deliberately a module-level name; train_gpt2 reads
    # `args.beta_predicate` from it.
    args = parser.parse_args()

    use_amp = not args.no_amp
    set_seeds(args.seed)

    # Decide checkpoint name based on positional mode
    if args.pos_mode == 'none':
        save_name = f"{args.model}_nrpe"   # e.g. mul_nrpe
    else:
        save_name = args.model

    # Input alphabet and name of the oracle builder in `kcm_binary` per task
    model_config = {
        'mul': (['1', '0', '/'], 'build_mul_binary_kcm'),
        'dvd': (['1', '0', '/'], 'build_dvd_binary_kcm'),
        'gcd': (['1', '0', '/'], 'build_gcd_binary_kcm'),
        'exp': (['1', '0'], 'build_exp_binary_kcm'),
        'prime': (['1', '0'], 'build_prime_binary_kcm'),
        'add': (['1', '0', '/'], 'build_add_binary_kcm'),
    }

    alp, kcm_func_name = model_config.get(args.model, ([], None))

    import kcm_binary as kcm
    auto = getattr(kcm, kcm_func_name)() if kcm_func_name else None

    if args.mode == 'train':
        max_steps = args.max_steps
        if args.longer_training:
            # --longer_training raises the step budget to at least 120k.
            max_steps = max(max_steps, 120000)
        model_ref, ds = train_gpt2(
            n_embd=args.nembd, n_layer=args.nlayer, n_head=args.nhead,
            batch_size=args.batch_size, n_epochs=args.epochs, lr=args.lr,
            model=args.model, save_dir=args.save_dir, save_name=save_name,
            seed=args.seed, use_amp=use_amp, resume=args.resume,
            ddp=args.ddp, local_rank=args.local_rank, world_size=args.world_size,
            pos_bias=(args.pos_bias == 'on'), max_steps=max_steps,
            pos_mode=args.pos_mode, beta_scale=args.beta_scale,
            gradient_checkpointing=args.gradient_checkpointing,
            rpe_window=args.rpe_window,
        )
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_dir = os.path.join(args.save_dir, save_name)

        # NOTE(review): the checkpoint is loaded as a plain LlamaForCausalLM and
        # no RPE patch is applied here, unlike in training with pos_mode 'beta'.
        # Confirm that inference is meant to run without the RPE attention.
        model = LlamaForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
        model.eval()

        # Use the same save_name as in training
        stoi, itos = load_vocab(args.save_dir, save_name)
        ds = ARLMGptDataset([""], [""], stoi, itos)

        while True:
            try:
                s = input(f"Input string (only {','.join(alp)}). Empty to quit: ").strip()
            except EOFError:
                break
            if not s:
                break
            if not all(ch in alp for ch in s):
                print(f"Invalid input. Only {alp} allowed.")
                continue
            try:
                full, yhat = greedy_decode(model, ds, s)
            except ValueError as err:
                # greedy_decode rejects prompts it cannot handle (for example
                # one that already fills the model context). Report it and
                # keep the session alive instead of crashing.
                print(f"[ERROR] {err}")
                continue
            if auto is not None:
                print("pred:", yhat, "\ngold:", auto.output_generator(s, alp))
            else:
                print("pred:", yhat)
