#!/usr/bin/env python3
# causal_transformer_extended.py
"""
Causal-Transformer toy task (x, y, x′, y′) with:

• Original batch generation (4-token x y x′ y′ sequence; 5 tokens with the optional <bos> prefix)
• CLI hyper-parameters
• Optional plots (AUC, PCA, attention maps)
• Periodic metric snapshots and checkpoints
• Original per-batch printout
"""

from __future__ import annotations
import argparse, math, os, random, shutil, multiprocessing as mp
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pickle

# ───────────────────────── Config ───────────────────────── #
@dataclass
class Config:
    """Hyper-parameters for one run.

    Build with keyword overrides (the CLI passes every parsed flag), then call
    :meth:`finalise` once to fill in ``M``, ``vocab_size`` and ``seq_len`` and
    to convert ``device`` into a :class:`torch.device`.
    """

    # ---- data / task -------------------------------------------------
    N: int = 128                  # input symbols use token ids 1..N
    M: int | None = None          # output symbols use ids N+1..N+M (None -> N)
    p_train: float = 0.8          # probability a training sequence is "true"
    p_eval:  float = 0.5          # probability an evaluation sequence is "true"
    use_bos: bool = False         # prepend token id 0 to every sequence

    # ---- model -------------------------------------------------------
    d_model: int = 64
    num_heads: int = 1
    num_layers: int = 3
    first_layer_mlp: bool = False # layer 0 is MLP-only instead of attention-only
    dropout: float = 0.0
    use_rms: bool = False         # RMSNorm instead of LayerNorm
    tie: bool = False             # share token-embedding / unembedding weights
    normalize_embeddings: bool = False
    freeze_kv: bool = False

    # ---- optimisation ------------------------------------------------
    batch_size: int = 128
    lr: float = 2e-3
    weight_decay: float = 1e-2
    num_steps: int = 20_000

    # ---- logging cadence ---------------------------------------------
    plot_interval: int = 250
    save_interval: int = 2000
    num_eval_samples: int = 2000

    # ---- misc --------------------------------------------------------
    seed: int = 2025
    num_workers: int = max(1, mp.cpu_count() - 1)   # evaluated once, at import time
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # torch.device after finalise()
    output_path: str = "outputs"

    # ---- feature flags -----------------------------------------------
    freeze_embeddings: bool = False
    one_hot: bool = False
    no_norm: bool = False
    plot_attention_maps: bool = False
    plot_classification: bool = False
    plot_pca: bool = False

    # ---- derived values (set by finalise) ----------------------------
    vocab_size: int = 0
    seq_len: int = 0

    def finalise(self):
        """Populate derived fields in place; returns ``None``."""
        self.M = self.N if self.M is None else self.M
        # id 0 (<bos>) + N input ids + M output ids
        self.vocab_size = self.N + self.M + 1
        # model input is (bos,) x y x' — the trailing y' only appears as a target
        self.seq_len = 3 + int(self.use_bos)
        self.device = torch.device(self.device)

# ─────────────────────── Utilities ─────────────────────── #
def seed_everything(seed:int):
    """Seed the Python, NumPy, PyTorch-CPU and all-CUDA-device RNGs with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def create_random_mapping(cfg:Config) -> Dict[int,int]:
    """Return a seeded random assignment of input ids to output ids.

    Output ids ``N+1 .. N+M`` are shuffled with ``default_rng(cfg.seed)`` and
    paired, in order, with input ids ``1 .. N``. Pairing stops at the shorter
    of the two ranges.

    NOTE(review): if ``M < N`` only inputs ``1..M`` are mapped, and
    ``generate_single_sample`` can then raise KeyError — confirm M >= N is intended.
    """
    generator = np.random.default_rng(cfg.seed)
    shuffled_outputs = generator.permutation(np.arange(cfg.N + 1, cfg.N + cfg.M + 1))
    mapping: Dict[int, int] = {}
    for in_id, out_id in zip(range(1, cfg.N + 1), shuffled_outputs):
        mapping[in_id] = int(out_id)
    return mapping

def generate_single_sample(rng:np.random.Generator,
                           cfg:Config,
                           f_map:Dict[int,int],
                           *,
                           is_train:bool):
    """Draw one sequence ``[(0,) x, y, x', y']`` and its truth label.

    With probability ``p_train`` (training) or ``p_eval`` (evaluation) the
    sequence is "true": ``y = f_map[x]`` and ``y' = f_map[x']``. Otherwise
    ``y`` and ``y'`` are drawn independently and uniformly from the output ids.
    A random draw can still coincide with ``f_map`` by chance.

    Returns ``(tokens, is_true)``.
    """
    out_lo, out_hi = cfg.N + 1, cfg.N + cfg.M + 1
    x, x_ = rng.integers(1, cfg.N + 1, 2)
    threshold = cfg.p_train if is_train else cfg.p_eval
    is_true = rng.random() < threshold
    if not is_true:
        y = rng.integers(out_lo, out_hi)
        y_ = rng.integers(out_lo, out_hi)
    else:
        y, y_ = f_map[x], f_map[x_]
    prefix = [0] if cfg.use_bos else []
    return prefix + [x, y, x_, y_], is_true

# ORIGINAL batching ------------------------------------------------- #
def gen_batch(rng:np.random.Generator,
              batch_size:int,
              cfg:Config,
              f_map:Dict[int,int],
              is_train:bool=True):
    """Draw ``batch_size`` sequences with ``generate_single_sample``.

    Returns ``(full, labels)``. ``full`` is an int64 array of shape
    ``(batch_size, cfg.seq_len + 1)``: 4 tokens per row, or 5 with <bos>.
    ``labels`` is an int64 array holding 1 for true sequences and 0 otherwise.
    Requires ``cfg.finalise()`` to have set ``seq_len``.
    """
    samples = [generate_single_sample(rng, cfg, f_map, is_train=is_train)
               for _ in range(batch_size)]
    flat_tokens = [tok for seq, _ in samples for tok in seq]
    label_list = [int(bool(flag)) for _, flag in samples]
    row_len = cfg.seq_len + 1
    full = np.asarray(flat_tokens, dtype=np.int64).reshape(batch_size, row_len)
    return full, np.asarray(label_list, dtype=np.int64)

# def iterate_batches(cfg:Config,
#                     f_map:Dict[int,int],
#                     seed:int = 42,
#                     is_train:bool=True):
#     def worker(q, seed_vec):
#         g = np.random.default_rng(seed_vec)
#         while True:
#             full, lab = gen_batch(g, cfg.batch_size, cfg, f_map, is_train)
#             q.put((full, lab))

#     q = mp.Queue(maxsize=10000)
#     for i in range(cfg.num_workers):
#         mp.Process(target=worker,
#                    args=(q, [seed, i]),
#                    daemon=True).start()
#     while True:
#         full, lab = q.get()
#         yield full[:, :-1], full[:, 1:], lab   # inputs, targets, labels

def _batch_worker(q, cfg: Config, f_map: Dict[int, int], is_train: bool, seed_vec) -> None:
    """Child-process loop: generate batches forever and push them onto ``q``.

    Defined at module level rather than nested inside ``iterate_batches`` so
    it can be pickled when the multiprocessing start method is ``spawn``
    (the default on macOS and Windows).
    """
    rng = np.random.default_rng(seed_vec)
    while True:
        q.put(gen_batch(rng, cfg.batch_size, cfg, f_map, is_train))


def iterate_batches(cfg: Config,
                    f_map: Dict[int, int],
                    seed: int = 42,
                    is_train: bool = True,
                    *,
                    queue_size: int = 10000):
    """
    Infinite generator that yields ``(inp, tgt, lab)`` triples of numpy arrays.

    ``inp`` is ``full[:, :-1]`` and ``tgt`` is ``full[:, 1:]``, the
    next-token-shifted views of each batch from ``gen_batch``. ``lab`` holds
    the per-sequence labels.

    With ``cfg.num_workers >= 1``, that many background processes stream
    batches into a bounded ``multiprocessing.Queue`` (``queue_size`` batches).
    Worker ``i`` is seeded with ``[seed, i]``. Arrival order across workers
    is not deterministic. With ``cfg.num_workers < 1``, batches are generated
    in-process from a ``[seed, 0]`` RNG instead of blocking forever on an
    empty queue.

    When the generator is closed (explicitly or on garbage collection), all
    workers are terminated and joined and the queue is closed.

    NOTE(review): the default ``seed`` is independent of ``cfg.seed``, so the
    data stream does not change with ``--seed``.
    """
    if cfg.num_workers < 1:
        rng = np.random.default_rng([seed, 0])
        while True:
            full, lab = gen_batch(rng, cfg.batch_size, cfg, f_map, is_train)
            yield full[:, :-1], full[:, 1:], lab
    else:
        q = mp.Queue(maxsize=queue_size)
        procs = []
        try:
            # start inside the try so a failed start still reaps earlier workers
            for i in range(cfg.num_workers):
                p = mp.Process(target=_batch_worker,
                               args=(q, cfg, f_map, is_train, [seed, i]),
                               daemon=True)
                p.start()
                procs.append(p)
            while True:
                full, lab = q.get()                   # blocking get
                # split (bos x y x' y') -> (inp, tgt)
                yield full[:, :-1], full[:, 1:], lab
        finally:
            # terminate everyone first so shutdown happens in parallel
            for p in procs:
                p.terminate()
            for p in procs:
                p.join()
            q.close()

# ───────────────────── Model definition ───────────────────── #
class CausalSelfAttention(nn.Module):
    """Bias-free multi-head causal self-attention that caches its last weights.

    After each forward pass ``self.attn_weights`` holds the detached softmax
    attention probabilities, shape ``(B, h, T, T)``.
    """

    def __init__(self,d:int,h:int):
        super().__init__()
        assert d % h == 0
        self.h = h
        self.dh = d // h
        # creation order q, k, v, o is kept so seeded initialisation is reproducible
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.o = nn.Linear(d, d, bias=False)
        # NOTE(review): never used in forward(); kept so state_dict keys stay compatible
        self.bias = nn.Parameter(torch.zeros(1, 1, d))

    def _to_heads(self, t, batch, length):
        """Reshape (B, T, D) -> (B, h, T, dh)."""
        return t.view(batch, length, self.h, self.dh).transpose(1, 2)

    def forward(self,x):
        batch, length, width = x.shape
        q, k, v = (self._to_heads(proj(x), batch, length)
                   for proj in (self.q, self.k, self.v))
        future = torch.ones(length, length, device=x.device, dtype=torch.bool).triu(diagonal=1)
        logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        logits = logits.masked_fill(future, float('-inf'))
        weights = F.softmax(logits, dim=-1)
        self.attn_weights = weights.detach()
        merged = (weights @ v).transpose(1, 2).reshape(batch, length, width)
        return self.o(merged)

class TransformerBlock(nn.Module):
    """Post-norm residual block: optional causal attention, then optional MLP.

    Each enabled sub-layer computes ``x = norm(x + dropout(sublayer(x)))``.
    The norms are non-affine LayerNorm, non-affine RMSNorm when ``use_rms``,
    or identity when ``no_norm``.
    """

    def __init__(self,d:int,h:int,*,no_attn=False,no_mlp=False,no_norm=False,dropout=0., use_rms=False):
        super().__init__()
        # registration order (attn, mlp, n1, n2, drop) kept for reproducible init
        self.attn = CausalSelfAttention(d, h) if not no_attn else None
        if no_mlp:
            self.mlp = None
        else:
            self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

        def make_norm():
            if no_norm:
                return nn.Identity()
            # nn.RMSNorm is only looked up when actually requested
            norm_cls = nn.RMSNorm if use_rms else nn.LayerNorm
            return norm_cls(d, elementwise_affine=False)

        self.n1 = make_norm()
        self.n2 = make_norm()
        self.drop = nn.Dropout(dropout)

    def forward(self,x):
        for sublayer, norm in ((self.attn, self.n1), (self.mlp, self.n2)):
            if sublayer is not None:
                x = norm(x + self.drop(sublayer(x)))
        return x

class CausalTransformer(nn.Module):
    """Decoder-only Transformer over the toy vocabulary.

    Token embeddings plus a learned, zero-initialised positional table feed
    ``cfg.num_layers`` post-norm blocks. A final linear layer maps back to
    vocabulary logits. With ``cfg.first_layer_mlp`` block 0 is MLP-only and
    the rest are attention-only; otherwise every block is attention-only.
    """

    def __init__(self,cfg:Config):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos = nn.Parameter(torch.zeros(1, cfg.seq_len, cfg.d_model))
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.num_heads,
                             no_attn=(cfg.first_layer_mlp and i == 0),
                             no_mlp=(not cfg.first_layer_mlp or i > 0),
                             no_norm=cfg.no_norm, use_rms=cfg.use_rms,
                             # forward cfg.dropout; it was previously never passed
                             dropout=cfg.dropout)
            for i in range(cfg.num_layers)])
        self.out = nn.Linear(cfg.d_model, cfg.vocab_size)
        # N(0, 0.02) for every Linear/Embedding weight; Linear biases keep torch's default init
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0., 0.02)

    def forward(self,x,*,return_activations=False):
        """Return logits of shape (B, T, vocab_size) for token ids ``x`` of shape (B, T).

        ``T`` must not exceed ``cfg.seq_len``. With ``return_activations`` the
        result is ``(logits, acts)``, where ``acts`` holds the embedding output
        followed by each block's output (``num_layers + 1`` tensors).
        """
        h = self.tok(x) + self.pos[:, :x.size(1)]
        acts = [h]
        for blk in self.blocks:
            h = blk(h)
            acts.append(h)
        logits = self.out(h)
        return (logits, acts) if return_activations else logits

# ────────────── Plot / evaluation helpers ────────────── #
def plot_auc(auc_dict, step_idx=None, output_path="outputs"):
    """Plot per-layer linear-probe AUC (mean ± std over CV folds) for each token position.

    ``auc_dict`` maps token position -> ``(means, stds)``, each a per-layer
    sequence, as returned by ``eval_linear_probe_auc``. The figure is saved to
    ``<output_path>/classification_plots/auc_step<step_idx>.png``. An empty
    dict is a no-op.
    """
    if not auc_dict:
        return
    layers = range(len(next(iter(auc_dict.values()))[0]))
    plt.figure()
    for tok, (mean, std) in auc_dict.items():
        plt.errorbar(layers, mean, yerr=std, fmt='-o', capsize=3, label=f"token {tok}")
    plt.xlabel("Layer"); plt.ylabel("AUC"); plt.ylim(0, 1); plt.grid(alpha=.3)
    if step_idx is not None:      # step 0 is a real step and gets a title too
        plt.title(f"Linear separability (step {step_idx})")
    plt.xticks(layers); plt.legend(); plt.tight_layout()
    fname = os.path.join(output_path, "classification_plots", f"auc_step{step_idx}.png")
    plt.savefig(fname, dpi=120)
    plt.close()

@torch.no_grad()
def eval_linear_probe_auc(model,cfg,f_map,num_samples=2000):
    """Measure how linearly separable true/false sequences are at each layer.

    Draws ``num_samples`` evaluation sequences from a ``cfg.seed + 1`` RNG and
    records the hidden state at the x, y and x′ positions for every layer. It
    then fits a logistic ``SGDClassifier`` under 5-fold stratified CV.

    Returns ``(out, (acts, y))``:

    * ``out`` maps token position -> ``(mean_auc_per_layer, std_auc_per_layer)``.
    * ``acts`` maps token position -> a list of per-layer arrays of shape
      ``(num_samples, d_model)``.
    * ``y`` is the label array.

    The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    token_positions = (1, 2, 3) if cfg.use_bos else (0, 1, 2)
    # tok -> list over batches of (list over layers of (B, d) arrays)
    chunks = {tok: [] for tok in token_positions}
    labels = []
    try:
        rng = np.random.default_rng(cfg.seed + 1)
        collected = 0
        while collected < num_samples:
            cur = min(cfg.batch_size, num_samples - collected)
            full, is_true = gen_batch(rng, cur, cfg, f_map, False)
            inp = torch.as_tensor(full[:, :-1], device=cfg.device)
            _, hidden = model(inp, return_activations=True)
            for tok in token_positions:
                chunks[tok].append([layer[:, tok, :].cpu().numpy() for layer in hidden])
            labels.append(is_true)
            collected += cur
    finally:
        model.train(was_training)

    # concatenate once per layer instead of re-concatenating every batch (was O(n^2))
    acts = {tok: [np.concatenate(per_layer, 0) for per_layer in zip(*batches)]
            for tok, batches in chunks.items()}
    y = np.concatenate(labels)

    out = {}
    for tok, vecs in acts.items():
        m, s = [], []
        for X in vecs:
            skf = StratifiedKFold(5, shuffle=True, random_state=cfg.seed)
            fold = [roc_auc_score(y[dv],
                                  SGDClassifier(loss="log_loss").fit(X[tr], y[tr]).predict_proba(X[dv])[:, 1])
                    for tr, dv in skf.split(X, y)]
            m.append(np.mean(fold)); s.append(np.std(fold))
        out[tok] = (m, s)
    return out, (acts, y)

def pca_layers_plot(model, cfg, f_map, *, step_idx: int, output_path="outputs"):
    """Scatter the first two principal components of hidden states, coloured by label.

    Uses ``cfg.num_eval_samples`` evaluation sequences from a fixed-seed (0)
    RNG. It plots the hidden state at the y and x′ positions for every layer
    and saves the figure to ``<output_path>/pca_plots/pca_step<step_idx>.png``.
    The model's train/eval mode is restored on exit.
    """
    token_indices = (2, 3) if cfg.use_bos else (1, 2)
    was_training = model.training
    model.eval()
    rng = np.random.default_rng(0)

    # tok -> list over batches of (list over layers of (B, d) arrays)
    chunks = {tok: [] for tok in token_indices}
    labels = []
    collected = 0
    try:
        with torch.no_grad():
            while collected < cfg.num_eval_samples:
                cur = min(cfg.batch_size, cfg.num_eval_samples - collected)
                full, is_true = gen_batch(rng, cur, cfg, f_map, False)
                inp = torch.as_tensor(full[:, :-1], device=cfg.device)
                _, hids = model(inp, return_activations=True)
                for tok in token_indices:
                    chunks[tok].append([lay[:, tok, :].cpu().numpy() for lay in hids])
                labels.append(is_true)
                collected += cur
    finally:
        model.train(was_training)

    # concatenate once per layer instead of re-concatenating every batch
    acts = {tok: [np.concatenate(per_layer, 0) for per_layer in zip(*batches)]
            for tok, batches in chunks.items()}
    labels = np.concatenate(labels)

    n_tok = len(token_indices)
    n_lay = len(next(iter(acts.values())))
    n_cols = min(4, n_lay)
    rows_per_tok = math.ceil(n_lay / n_cols)
    n_rows = n_tok * rows_per_tok
    plt.figure(figsize=(4 * n_cols, 3 * n_rows))
    plt.suptitle(f"PCA coloured by truth (step {step_idx})")

    for r, tok in enumerate(token_indices):
        pcs_layers = [PCA(2).fit_transform(X) for X in acts[tok]]
        for i, pcs in enumerate(pcs_layers):
            row = r * rows_per_tok + (i // n_cols)
            col = i % n_cols
            ax = plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
            # subsample very large sets so the scatter stays readable
            idx = rng.choice(pcs.shape[0], 2000, False) \
                  if pcs.shape[0] > 2000 else slice(None)
            ax.scatter(
                pcs[idx, 0], pcs[idx, 1],
                c=labels[idx], s=6, alpha=.7, cmap="coolwarm"
            )
            ax.set_xticks([]); ax.set_yticks([])
            ax.set_title(f"tok {tok}  L{i}", fontsize=8)

    plt.tight_layout(rect=[0, 0.03, 1, 0.96])
    plt.savefig(os.path.join(output_path, "pca_plots", f"pca_step{step_idx}.png"), dpi=120)
    plt.close()

@torch.no_grad()
def compute_attn(model: nn.Module,
                 cfg,
                 f_map: Dict[int, int],
                 num_samples: int
                 ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Collect self-attention maps for ``num_samples`` evaluation examples and
    return the **mean** attention over examples, separately for true and false
    sequences.

    Returns
    -------
    mean_attn_true  : np.ndarray of shape (L, H, T, T), the mean over label==1,
                      or ``np.nan`` if no such example was seen
    mean_attn_false : the same for label==0

    ``L`` counts only blocks that have an attention sub-layer. The model's
    train/eval mode is restored, and the batch-worker processes are reaped,
    before returning.
    """
    was_training = model.training
    model.eval()                       # disable dropout while collecting

    batches = iterate_batches(cfg, f_map, is_train=False)
    attn_true_batches: List[torch.Tensor] = []    # each (L, B_true, H, T, T)
    attn_false_batches: List[torch.Tensor] = []   # each (L, B_false, H, T, T)

    try:
        processed = 0
        while processed < num_samples:
            inp, _tgt, lab = next(batches)                      # numpy arrays
            take = min(inp.shape[0], num_samples - processed)
            inp = torch.as_tensor(inp[:take], device=cfg.device)
            lab = torch.as_tensor(lab[:take], device=cfg.device)

            model(inp)   # populates blk.attn.attn_weights; logits unused

            attn_tensors = [blk.attn.attn_weights
                            for blk in model.blocks
                            if getattr(blk, "attn", None) is not None]
            if not attn_tensors:
                break    # no attention layers at all: nothing will ever be collected

            attn = torch.stack(attn_tensors)                    # (L, B, H, T, T)
            assert attn.size(1) == lab.size(0), (
                f"Batch mismatch: attn B={attn.size(1)} vs lab={lab.size(0)}")

            is_pos = lab == 1
            is_neg = lab == 0
            if is_pos.any():
                attn_true_batches.append(attn[:, is_pos].cpu())
            if is_neg.any():
                attn_false_batches.append(attn[:, is_neg].cpu())
            processed += take
    finally:
        batches.close()               # terminate worker processes now, not at GC time
        model.train(was_training)

    def _mean_over_examples(chunks):
        if not chunks:
            return np.nan
        return torch.cat(chunks, dim=1).mean(dim=1).numpy()    # (L, H, T, T)

    return _mean_over_examples(attn_true_batches), _mean_over_examples(attn_false_batches)


def plot_attn_map(attn_map:np.ndarray, step_idx:int, title:str,
                  output_path: str = ".", suffix: str = ""):
    """Save one heat-map per layer (head 0) of an attention array.

    Parameters
    ----------
    attn_map    : array of shape (L, H, T, T), e.g. one output of ``compute_attn``
    step_idx    : training step, used in the file name
    title       : figure super-title (skipped if empty)
    output_path : root directory; the file goes to
                  ``<output_path>/attention_maps/attn_step<step_idx><suffix>.png``
    suffix      : optional file-name suffix such as ``"_true"``

    Raises
    ------
    ValueError if ``attn_map`` is not 4-D (e.g. the ``np.nan`` placeholder
    ``compute_attn`` returns when a class was absent).
    """
    attn_map = np.asarray(attn_map)
    if attn_map.ndim != 4:
        raise ValueError(f"expected attn_map of shape (L, H, T, T), got {attn_map.shape}")
    n_layers, _, T, _ = attn_map.shape
    if T == 4:
        lbl = ['bos', 'x', 'y', "x'"]
    elif T == 3:
        lbl = ['x', 'y', "x'"]
    else:
        lbl = [str(t) for t in range(T)]

    plt.figure(figsize=(3 * n_layers, 3))
    for i in range(n_layers):
        plt.subplot(1, n_layers, i + 1)
        plt.imshow(attn_map[i, 0])
        plt.xticks(range(T), lbl); plt.yticks(range(T), lbl)
    if title:
        plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(os.path.join(output_path, "attention_maps", f"attn_step{step_idx}{suffix}.png"))
    plt.close()


@torch.no_grad()
def eval_prob_and_loss(model: nn.Module,
                       cfg: Config,
                       f_map: Dict[int, int],
                       num_samples: int = 2_000
                       ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluate the model on dev-set *false* sequences.

    For the first ``num_samples`` sequences with label == 0 (in stream order),
    at the step that predicts y₂ (input position of x₂), record:

    • prob_correct : probability the model assigns to the **correct** token
                     f_map(x₂)
    • nll_obs      : negative log-likelihood of the token **actually present**
                     at y₂ in the sequence

    Parameters
    ----------
    model        : trained CausalTransformer
    cfg          : experiment configuration
    f_map        : ground-truth mapping dict {x → y}
    num_samples  : how many *false* examples to keep

    Returns
    -------
    probs  : np.ndarray  shape (n_false,)
    losses : np.ndarray  shape (n_false,)

    The model's train/eval mode is restored, and the batch-worker processes
    are reaped, before returning.
    """
    was_training = model.training
    model.eval()

    # input position of x₂: (bos,) x₁ y₁ x₂ — its output predicts y₂
    x2_pos = (1 if cfg.use_bos else 0) + 2

    batches = iterate_batches(cfg, f_map, is_train=False)
    prob_recorded: List[float] = []
    loss_recorded: List[float] = []
    try:
        while len(prob_recorded) < num_samples:
            inp, tgt, lab = next(batches)                       # numpy arrays
            # false rows in batch order, capped at what is still needed
            false_rows = np.flatnonzero(lab == 0)[: num_samples - len(prob_recorded)]
            if false_rows.size == 0:
                continue

            logits = model(torch.as_tensor(inp, device=cfg.device))   # (B, T, V)
            logp = torch.log_softmax(logits[:, x2_pos], dim=-1)       # (B, V)

            rows = torch.as_tensor(false_rows, device=cfg.device)
            correct = torch.as_tensor([f_map[int(t)] for t in inp[false_rows, x2_pos]],
                                      device=cfg.device)
            # tgt is shifted by one, so tgt[:, x2_pos] is the observed y₂
            observed = torch.as_tensor(tgt[false_rows, x2_pos], device=cfg.device)

            prob_recorded.extend(logp[rows, correct].exp().tolist())
            loss_recorded.extend((-logp[rows, observed]).tolist())
    finally:
        batches.close()
        model.train(was_training)

    return np.asarray(prob_recorded), np.asarray(loss_recorded)


@torch.no_grad()
def eval_prob(model: nn.Module,
              cfg: Config,
              f_map: Dict[int, int],
              num_samples: int = 2_000) -> np.ndarray:
    """
    Iterate over dev-set samples and keep only *false* sequences (label == 0).
    For each, record the probability the model assigns to the **correct**
    token f_map(x₂) at the step that predicts y₂.

    Parameters
    ----------
    model        : the trained CausalTransformer
    cfg          : experiment configuration
    f_map        : the ground-truth mapping dict {x → y}
    num_samples  : how many *false* examples to evaluate

    Returns
    -------
    probs : np.ndarray of shape (n_false,), one entry per retained example
            (first ``num_samples`` false sequences in stream order)

    The model's train/eval mode is restored, and the batch-worker processes
    are reaped, before returning.
    """
    was_training = model.training
    model.eval()

    # input position of x₂: (bos,) x₁ y₁ x₂ — its output predicts y₂
    x2_pos = (1 if cfg.use_bos else 0) + 2

    batches = iterate_batches(cfg, f_map, is_train=False)
    recorded: List[float] = []
    try:
        while len(recorded) < num_samples:
            inp, _tgt, lab = next(batches)                      # numpy arrays
            false_rows = np.flatnonzero(lab == 0)[: num_samples - len(recorded)]
            if false_rows.size == 0:
                continue                                        # no false cases in this batch

            logits = model(torch.as_tensor(inp, device=cfg.device))   # (B, T, V)
            probs = torch.softmax(logits[:, x2_pos], dim=-1)          # (B, V)

            rows = torch.as_tensor(false_rows, device=cfg.device)
            correct = torch.as_tensor([f_map[int(t)] for t in inp[false_rows, x2_pos]],
                                      device=cfg.device)
            recorded.extend(probs[rows, correct].tolist())
    finally:
        batches.close()
        model.train(was_training)

    return np.asarray(recorded)


# ──────────────────── Training loop ──────────────────── #
def _normalise_embedding_rows(model) -> None:
    """Rescale every row of the token-embedding and unembedding matrices to unit L2 norm."""
    with torch.no_grad():
        model.tok.weight /= model.tok.weight.norm(dim=-1, keepdim=True)
        model.out.weight /= model.out.weight.norm(dim=-1, keepdim=True)


def _apply_embedding_options(model, cfg: Config) -> None:
    """Apply the one_hot, freeze_embeddings, freeze_kv, normalize_embeddings and tie options, in that order."""
    if cfg.one_hot:
        V, T = cfg.vocab_size, cfg.seq_len
        if cfg.d_model < V + T:
            raise ValueError(
                f"--one_hot needs d_model >= vocab_size + seq_len = {V + T}, got {cfg.d_model}")
        model.tok.weight.data.zero_()
        model.tok.weight.data[torch.arange(V), torch.arange(V)] = 1
        model.pos.data.zero_()
        # positions use the dimensions just after the token one-hots
        model.pos.data[0, torch.arange(T), V + torch.arange(T)] = 1
        model.out.weight.data.zero_()
        model.out.weight.data[torch.arange(V), torch.arange(V)] = 1
    if cfg.freeze_embeddings:
        model.tok.requires_grad_(False)
        model.pos.requires_grad_(False)
        model.out.requires_grad_(False)
    if cfg.freeze_kv:
        # NOTE(review): despite the flag name, this zeroes and freezes the *query*
        # projection of layer 0 — confirm this is intended.
        first_attn = model.blocks[0].attn
        if first_attn is None:
            raise ValueError("--freeze_kv needs attention in layer 0 (incompatible with --first_layer_mlp)")
        first_attn.q.weight.data.zero_()
        first_attn.q.weight.requires_grad_(False)
    if cfg.normalize_embeddings:
        _normalise_embedding_rows(model)
    if cfg.tie:
        model.tok.weight = model.out.weight


def _masked_mean(values: torch.Tensor) -> float:
    """Mean of a 1-D tensor as a Python float, or NaN when the tensor is empty."""
    return values.mean().item() if values.numel() else float("nan")


def _print_batch_stats(step, loss, loss_tok, logits, tgt, lab, cfg: Config) -> None:
    """Print the loss and the loss/probability of the first and last predictions on true sequences."""
    bs = tgt.size(0)
    bos = 1 if cfg.use_bos else 0
    with torch.no_grad():
        pos = torch.as_tensor(lab == 1, device=tgt.device)       # true-sequence rows
        lt = loss_tok.view(bs, -1)[pos]                          # (K, T)
        probs = torch.softmax(logits.view(bs, -1, cfg.vocab_size), dim=-1)[pos]   # (K, T, V)
        tgt_r = tgt.view(bs, -1)[pos]                            # (K, T)
        rows = torch.arange(tgt_r.size(0), device=tgt.device)
        l1 = _masked_mean(lt[:, bos])
        l2 = _masked_mean(lt[:, bos + 2])
        p1 = _masked_mean(probs[rows, bos, tgt_r[:, bos]])
        p2 = _masked_mean(probs[rows, bos + 2, tgt_r[:, bos + 2]])
    print(
        f"Batch {step}, Loss: {loss.item():.4f}, "
        f"First Pred Loss: {l1:.4f}, "
        f"Last Pred Loss: {l2:.4f}, "
        f"First Pred Prob: {p1:.4f}, "
        f"Last Pred Prob: {p2:.4f}"
    )


def _plot_attention_row(maps, cfg: Config, path: str) -> None:
    """Save one (T, T) heat-map per block; blocks mapped to None get an empty panel.

    Nothing is written if every entry is None.
    """
    if all(m is None for m in maps):
        return
    lbl = ['bos', 'x', 'y', "x'"] if cfg.use_bos else ['x', 'y', "x'"]
    plt.figure(figsize=(3 * cfg.num_layers, 3))
    for i, attn_map in enumerate(maps):
        plt.subplot(1, cfg.num_layers, i + 1)
        if attn_map is None:
            plt.axis("off")
            continue
        plt.imshow(attn_map)
        plt.xticks(range(cfg.seq_len), lbl); plt.yticks(range(cfg.seq_len), lbl)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def _per_block_maps(model, mean_attn):
    """Align an (L_attn, H, T, T) array from compute_attn with model.blocks (head 0).

    Blocks without attention, or a non-array ``mean_attn`` (NaN placeholder), map to None.
    """
    maps, j = [], 0
    for blk in model.blocks:
        if blk.attn is None or not isinstance(mean_attn, np.ndarray):
            maps.append(None)
        else:
            maps.append(mean_attn[j, 0])
            j += 1
    return maps


def _plot_attention_maps(model, cfg: Config, f_map, step: int) -> None:
    """Save the latest-forward attention (sample 0, head 0) and the mean true/false attention."""
    out_dir = os.path.join(cfg.output_path, "attention_maps")
    # attn_weights come from the most recent forward pass through the model
    live = [None if blk.attn is None else blk.attn.attn_weights[0, 0].cpu()
            for blk in model.blocks]
    _plot_attention_row(live, cfg, os.path.join(out_dir, f"attn_step{step}.png"))

    mean_attn_true, mean_attn_false = compute_attn(model, cfg, f_map, 100)
    for tag, mean_attn in (("true", mean_attn_true), ("false", mean_attn_false)):
        _plot_attention_row(_per_block_maps(model, mean_attn), cfg,
                            os.path.join(out_dir, f"attn_step{step}_{tag}.png"))


def _save_checkpoint(model, cfg: Config, f_map, step: int, path: str) -> None:
    """Write step, CPU copy of the state_dict, f_map and cfg (as a dict) to ``path``."""
    torch.save(
        {
            'step': step,
            'state': {k: v.cpu() for k, v in model.state_dict().items()},
            'f_map': f_map,
            'cfg': asdict(cfg),
        },
        path,
    )


def train(cfg:Config):
    """Train a CausalTransformer on the toy task and write plots, metrics and checkpoints.

    Outputs go under ``cfg.output_path``. The sub-directories
    classification_plots, pca_plots, attention_maps, metrics and checkpoints
    are wiped and recreated at start. ``cfg.finalise()`` must already have
    been called.
    """
    seed_everything(cfg.seed)
    f_map = create_random_mapping(cfg)
    model = CausalTransformer(cfg).to(cfg.device)
    print(model)
    _apply_embedding_options(model, cfg)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    crit = nn.CrossEntropyLoss(reduction='none')

    os.makedirs(cfg.output_path, exist_ok=True)
    for d in ['classification_plots', 'pca_plots', 'attention_maps', 'metrics', 'checkpoints']:
        sub = os.path.join(cfg.output_path, d)
        if os.path.isdir(sub):
            shutil.rmtree(sub)
        os.makedirs(sub)

    aucs = {}
    probs_true_token = {}
    loss_false = {}

    batches = iterate_batches(cfg, f_map, is_train=True)
    try:
        for step in range(cfg.num_steps):
            model.train()   # evaluation helpers may have switched the model to eval mode
            inp, tgt, lab = next(batches)
            inp = torch.as_tensor(inp, device=cfg.device)
            tgt = torch.as_tensor(tgt, device=cfg.device)
            opt.zero_grad()
            logits = model(inp).view(-1, cfg.vocab_size)
            loss_tok = crit(logits, tgt.view(-1))
            loss = loss_tok.mean()
            loss.backward()
            opt.step()
            if cfg.normalize_embeddings:
                _normalise_embedding_rows(model)

            # console print
            if step % 100 == 0:
                _print_batch_stats(step, loss, loss_tok, logits, tgt, lab, cfg)

            # evaluation + plots
            if step % cfg.plot_interval == 0:
                probs_on_false, loss_on_false = eval_prob_and_loss(model, cfg, f_map, 300)
                probs_true_token[step] = np.mean(probs_on_false)
                loss_false[step] = np.mean(loss_on_false)

                if cfg.plot_classification:
                    auc, _ = eval_linear_probe_auc(model, cfg, f_map, cfg.num_eval_samples)
                    aucs[step] = auc
                    plot_auc(auc, step_idx=step, output_path=cfg.output_path)
                if cfg.plot_pca:
                    pca_layers_plot(model, cfg, f_map, step_idx=step, output_path=cfg.output_path)
                if cfg.plot_attention_maps:
                    _plot_attention_maps(model, cfg, f_map, step)

            # periodic checkpoint
            if step % cfg.save_interval == 0:
                _save_checkpoint(model, cfg, f_map, step,
                                 os.path.join(cfg.output_path, "checkpoints", f"ckpt_step{step}.pt"))
    finally:
        batches.close()   # terminate the data-worker processes deterministically

    # save metrics as pickles
    metrics_dir = os.path.join(cfg.output_path, "metrics")
    for fname, obj in (("aucs.pkl", aucs),
                       ("probs_true_token.pkl", probs_true_token),
                       ("loss_on_false.pkl", loss_false)):
        with open(os.path.join(metrics_dir, fname), "wb") as f:
            pickle.dump(obj, f)

    # final save
    _save_checkpoint(model, cfg, f_map, cfg.num_steps,
                     os.path.join(cfg.output_path, "checkpoints", "final.pt"))

# ───────────────────────── CLI ───────────────────────── #
def build_parser():
    """Return the command-line parser; each option name matches a ``Config`` field.

    ``dropout``, ``num_workers``, ``num_eval_samples`` and ``device`` are
    exposed with the same defaults ``Config`` uses, so omitting them changes
    nothing.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--N', type=int, default=512)
    p.add_argument('--M', type=int, default=None)
    p.add_argument('--weight_decay', type=float, default=1e-5)
    p.add_argument('--p_train', type=float, default=0.95)
    p.add_argument('--p_eval', type=float, default=0.5)
    p.add_argument('--d_model', type=int, default=128)
    p.add_argument('--num_heads', type=int, default=1)
    p.add_argument('--num_layers', type=int, default=2)
    p.add_argument('--batch_size', type=int, default=128)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--first_layer_mlp', action='store_true')
    p.add_argument('--num_steps', type=int, default=20000)
    p.add_argument('--plot_interval', type=int, default=250)
    p.add_argument('--save_interval', type=int, default=2000)
    p.add_argument('--plot_classification', action='store_true')
    p.add_argument('--plot_pca', action='store_true')
    p.add_argument('--plot_attention_maps', action='store_true')
    p.add_argument('--freeze_embeddings', action='store_true')
    p.add_argument('--one_hot', action='store_true')
    p.add_argument('--no_norm', action='store_true')
    p.add_argument('--use_rms', action='store_true')
    p.add_argument('--output_path', type=str, default='outputs')
    p.add_argument('--seed', type=int, default=1)
    p.add_argument('--use_bos', action='store_true', default=False)
    p.add_argument('--normalize_embeddings', action='store_true')
    p.add_argument('--tie', action='store_true')
    p.add_argument('--freeze_kv', action='store_true')
    # Config fields that previously had no CLI switch (defaults mirror Config)
    p.add_argument('--dropout', type=float, default=0.0,
                   help='dropout probability inside each block')
    p.add_argument('--num_workers', type=int, default=max(1, mp.cpu_count() - 1),
                   help='background batch-generation processes (0 = generate in-process)')
    p.add_argument('--num_eval_samples', type=int, default=2000,
                   help='sequences used by the probe / PCA evaluations')
    p.add_argument('--device', type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    return p

# ───────────────────────── main ───────────────────────── #
if __name__ == "__main__":
    # Parse flags into a Config, fill in its derived fields, then train.
    # The global name `cfg` is kept on purpose: plot_attn_map refers to it.
    cli_args = build_parser().parse_args()
    cfg = Config(**vars(cli_args))
    cfg.finalise()
    train(cfg)