# eval_delta_lm_loss.py
# -*- coding: utf-8 -*-
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import json
import argparse
import warnings
import inspect
from typing import Iterable, List, Dict, Tuple, Optional, Any

import torch as t
import torch.nn.functional as F

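# NOTE: Dream's official diffusion_generate() usage is written against
# transformers; this script imports AutoModel/AutoTokenizer from modelscope
# instead (assumed to expose the same interface -- TODO confirm).
# from transformers import AutoModel, AutoTokenizer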
from modelscope import AutoModel, AutoTokenizer
from tqdm import tqdm
from modelscope.msdatasets import MsDataset

from dictionary_learning.dictionary_learning import utils


############################################
# General helpers
############################################

def is_dream_like(model) -> bool:
    """
    Guess whether ``model`` is a Dream-style diffusion language model.

    Returns True when the lower-cased class name contains 'dream' or
    'diffusion', or when the model has a ``diffusion_generate`` attribute.
    """
    cls_name = model.__class__.__name__.lower()
    if any(tag in cls_name for tag in ("dream", "diffusion")):
        return True
    return hasattr(model, "diffusion_generate")


def iter_texts(dataset_name: str, split: str = "train") -> Iterable[str]:
    """
    Return a streaming generator over a HF dataset.

    Thin wrapper that delegates to ``utils.hf_dataset_to_generator`` with
    ``streaming=True``; ``split`` defaults to 'train'. The yielded items are
    presumably raw text strings (per the return annotation) -- the exact
    element type is decided by the utils helper.
    """
    text_stream = utils.hf_dataset_to_generator(dataset_name, split=split, streaming=True)
    return text_stream


def iter_instructions(dataset_name: str, split: str = "train", streaming: bool = True) -> Iterable[Dict[str, Any]]:
    """
    Lazily yield instruction examples loaded through ``MsDataset.load``.

    Examples are expected to be Alpaca-like records (instruction / input /
    output, or prompt / query / context variants); they are yielded exactly
    as the dataset produces them. Only instruction + input are needed
    downstream for prompting.
    """
    examples = MsDataset.load(dataset_name, split=split, use_streaming=streaming)
    yield from examples


def format_alpaca_prompt(ex: Dict[str, Any]) -> str:
    """
    Render one example as an Alpaca-style prompt (raw user content).

    The instruction is the first truthy value among the keys
    instruction / prompt / query / question; the optional input is the first
    truthy value among input / context / inputs. Both are stripped. The
    '### Input:' section is emitted only when the stripped input is non-empty.
    """

    def _first_present(*keys: str) -> Any:
        # Mirrors an `a or b or c or ""` chain: first truthy value wins.
        for key in keys:
            value = ex.get(key)
            if value:
                return value
        return ""

    instr = _first_present("instruction", "prompt", "query", "question").strip()
    inp = _first_present("input", "context", "inputs").strip()

    if not inp:
        return f"### Instruction:\n{instr}\n\n### Response:\n"
    return f"### Instruction:\n{instr}\n\n### Input:\n{inp}\n\n### Response:\n"


def build_chat_inputs_if_available(tokenizer, user_content: str, device: t.device) -> Dict[str, t.Tensor]:
    """
    Tokenize ``user_content`` as a single-turn chat prompt when possible.

    Preferred path (Dream official):
        tokenizer.apply_chat_template(..., return_tensors="pt",
                                      return_dict=True, add_generation_prompt=True)
    Fallback: plain ``tokenizer(user_content, return_tensors="pt")``.

    The chat-template result may be a BatchEncoding-like object or a plain
    dict; fields are read by key for mappings and by attribute otherwise, so
    a plain-dict result no longer silently drops the chat template.

    Args:
        tokenizer: object that is callable and optionally provides
            ``apply_chat_template``.
        user_content: raw user message text.
        device: device the returned tensors are moved to.

    Returns:
        Dict with ``input_ids`` and, when available, ``attention_mask``
        (the fallback path returns every field the tokenizer produced).

    Warns:
        RuntimeWarning when ``apply_chat_template`` exists but fails, right
        before falling back to plain tokenization.
    """
    from collections.abc import Mapping

    def _field(container: Any, name: str) -> Any:
        # BatchEncoding is a Mapping too, so key access covers both cases.
        if isinstance(container, Mapping):
            return container.get(name)
        return getattr(container, name, None)

    if hasattr(tokenizer, "apply_chat_template"):
        try:
            messages = [{"role": "user", "content": user_content}]
            inputs = tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                return_dict=True,
                add_generation_prompt=True,
            )
            input_ids = _field(inputs, "input_ids")
            if input_ids is None:
                raise KeyError("apply_chat_template returned no input_ids")

            out = {"input_ids": input_ids.to(device=device)}
            attention_mask = _field(inputs, "attention_mask")
            if attention_mask is not None:
                out["attention_mask"] = attention_mask.to(device=device)
            return out
        except Exception as err:
            # Best-effort by design (some tokenizers ship no chat template),
            # but make the downgrade visible instead of swallowing it.
            warnings.warn(
                f"apply_chat_template failed ({type(err).__name__}: {err}); "
                "falling back to plain tokenization without a chat template.",
                RuntimeWarning,
                stacklevel=2,
            )

    enc = tokenizer(user_content, return_tensors="pt", add_special_tokens=True)
    return {k: v.to(device) for k, v in enc.items()}


@t.no_grad()
def generate_rollout_ids_official_dream(
    model,
    tokenizer,
    user_content: str,
    device: t.device,
    *,
    max_new_tokens: int = 256,
    steps: int = 512,
    temperature: float = 0.0,
    top_p: float = 1.0,
    alg: str = "entropy",
    alg_temp: float = 0.0,
) -> Tuple[t.Tensor, t.Tensor, t.Tensor]:
    """
    Run one Dream-style rollout for a single user message.

    Steps:
      1. Tokenize via build_chat_inputs_if_available (chat template when
         available); synthesize an all-ones attention mask if none came back.
      2. Call model.diffusion_generate(input_ids, attention_mask=..., ...)
         with output_history=True and return_dict_in_generate=True.
      3. Pull ``sequences`` out of the result (attribute first, then dict key).

    Returns:
        (prompt_input_ids, prompt_attention_mask, full_ids). ``full_ids`` is
        the sequences tensor moved to ``device``; when sequences is a
        non-empty list of tensors, only its first element is used, with a
        leading batch dim added. Shapes are expected to be (1, T_prompt) and
        (1, T_total) since callers run with batch size 1.

    Raises:
        RuntimeError: model lacks diffusion_generate, no sequences were
            returned, or sequences has an unsupported type.
    """
    chat_inputs = build_chat_inputs_if_available(tokenizer, user_content=user_content, device=device)
    prompt_input_ids = chat_inputs["input_ids"]
    prompt_attention_mask = chat_inputs.get("attention_mask", None)
    if prompt_attention_mask is None:
        prompt_attention_mask = t.ones_like(prompt_input_ids, dtype=t.long, device=device)

    # Dream official call (expects ids + attention_mask)
    if not hasattr(model, "diffusion_generate"):
        raise RuntimeError("Model has no diffusion_generate(). Make sure you loaded Dream via transformers + trust_remote_code=True.")

    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        output_history=True,
        return_dict_in_generate=True,
        steps=steps,
        temperature=float(temperature),
        top_p=float(top_p),
        alg=str(alg),
        alg_temp=float(alg_temp),
    )
    output = model.diffusion_generate(
        prompt_input_ids,
        attention_mask=prompt_attention_mask,
        **sampling_kwargs,
    )

    seq = getattr(output, "sequences", None)
    if seq is None and isinstance(output, dict):
        seq = output.get("sequences", None)
    if seq is None:
        raise RuntimeError("diffusion_generate did not return sequences.")

    # seq is usually (B, T_total). Here we run B=1.
    if isinstance(seq, t.Tensor):
        return prompt_input_ids, prompt_attention_mask, seq.to(device=device)

    is_tensor_list = isinstance(seq, list) and len(seq) > 0 and isinstance(seq[0], t.Tensor)
    if not is_tensor_list:
        raise RuntimeError(f"Unsupported sequences type: {type(seq)}")

    full_ids = seq[0].unsqueeze(0).to(device=device)
    return prompt_input_ids, prompt_attention_mask, full_ids


def build_section_mask(full_ids: t.Tensor, prompt_len: int, section: str) -> t.Tensor:
    """
    Build a boolean position mask selecting either the prompt or the rollout.

    full_ids: (1, T_total) token ids; only its length and device are used.
    prompt_len: number of leading positions that belong to the prompt.
    section: 'prompt' selects [0, prompt_len); 'rollout' selects [prompt_len, T_total).

    Returns a (1, T_total) bool tensor on full_ids' device.
    Raises ValueError for any other ``section`` value.
    """
    total_len = full_ids.shape[1]
    mask = t.zeros((1, total_len), dtype=t.bool, device=full_ids.device)

    if section not in ("prompt", "rollout"):
        raise ValueError("section must be 'prompt' or 'rollout'")

    window = slice(None, prompt_len) if section == "prompt" else slice(prompt_len, None)
    mask[:, window] = True
    return mask


def _pick_best_tensor(cands: List[t.Tensor]) -> t.Tensor:
    """
    Pick the most logits-like tensor out of ``cands``.

    Preference order: the first rank-3 tensor (B, T, V), then the first
    rank-2 tensor (T, V), then simply the first tensor. Non-tensor entries
    are ignored. Raises RuntimeError when ``cands`` holds no tensor at all.
    """
    tensors = [cand for cand in cands if isinstance(cand, t.Tensor)]
    for wanted_rank in (3, 2):
        for cand in tensors:
            if cand.ndim == wanted_rank:
                return cand
    if tensors:
        return tensors[0]
    raise RuntimeError("No tensor candidate found among provided candidates.")


def get_logits(outputs: Any) -> t.Tensor:
    """
    Extract a logits-like tensor from whatever the model returned.

    Lookup order:
      * dict outputs: 'logits', then 'last_hidden_state', then the best
        direct tensor value, then the best tensor nested one level deep in a
        tuple/list value (best = _pick_best_tensor);
      * attribute access: .logits, then .last_hidden_state;
      * tuple/list outputs: the best tensor element.

    Note that 'last_hidden_state' is accepted as a stand-in when no 'logits'
    entry is present. Raises RuntimeError if nothing matches.
    """
    preferred = ("logits", "last_hidden_state")

    if isinstance(outputs, dict):
        for key in preferred:
            if key in outputs and isinstance(outputs[key], t.Tensor):
                return outputs[key]

        direct = [value for value in outputs.values() if isinstance(value, t.Tensor)]
        if direct:
            return _pick_best_tensor(direct)

        nested = [
            item
            for value in outputs.values()
            if isinstance(value, (tuple, list))
            for item in value
            if isinstance(item, t.Tensor)
        ]
        if nested:
            return _pick_best_tensor(nested)

    for attr in preferred:
        if hasattr(outputs, attr) and isinstance(getattr(outputs, attr), t.Tensor):
            return getattr(outputs, attr)

    if isinstance(outputs, (tuple, list)):
        elements = [item for item in outputs if isinstance(item, t.Tensor)]
        if elements:
            return _pick_best_tensor(elements)

    raise RuntimeError(f"Cannot find logits in model outputs of type: {type(outputs)}")


############################################
# Helpers to deal with tuple-ish activations in hooks
############################################

# Mapping keys that are checked first (in this order) when looking for the
# activation tensor. Shared by _first_tensor and _replace_first_tensor so both
# always address the SAME entry of a dict-like structure.
_PREFERRED_TENSOR_KEYS = (
    "hidden_states",
    "last_hidden_state",
    "x",
    "resid",
    "resid_post",
    "out",
    "activations",
    "act",
    "states",
    "x_hat",
    "reconstruction",
    "decoded",
)


def _locate_tensor_entry(mapping: Dict[Any, Any]) -> Optional[Tuple[Any, t.Tensor]]:
    """Return (key, tensor) for the entry that holds "the first tensor", or None.

    A preferred key whose value is directly a tensor wins; otherwise the first
    value (in iteration order) that contains a tensor anywhere inside it.
    """
    for key in _PREFERRED_TENSOR_KEYS:
        if key in mapping:
            value = mapping[key]
            if isinstance(value, t.Tensor):
                return key, value
    for key, value in mapping.items():
        found = _first_tensor(value)
        if found is not None:
            return key, found
    return None


def _first_tensor(obj: Any) -> Optional[t.Tensor]:
    """Depth-first search for the first tensor inside ``obj``.

    Handles tensors, lists/tuples, dicts (preferred keys first, see
    _PREFERRED_TENSOR_KEYS) and arbitrary objects via their ``__dict__``.
    Returns None when no tensor is found.
    """
    if isinstance(obj, t.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for item in obj:
            found = _first_tensor(item)
            if found is not None:
                return found

    if isinstance(obj, dict):
        entry = _locate_tensor_entry(obj)
        if entry is not None:
            return entry[1]

    if hasattr(obj, "__dict__"):
        return _first_tensor(vars(obj))

    return None


def _replace_first_tensor(structure: Any, new_tensor: t.Tensor) -> Any:
    """Return a copy of ``structure`` where the tensor that ``_first_tensor``
    would find is replaced by ``new_tensor``.

    The input is never mutated. Container types are preserved where possible:
      * list / tuple (including namedtuples) are rebuilt with the same type;
      * dict subclasses are shallow-copied (falling back to a plain dict if
        the copy or item assignment fails);
      * objects with ``__dict__`` are shallow-copied and the attribute is set
        on the copy (falling back to a dict of attributes if that fails).
    Anything else is replaced by ``new_tensor`` itself.
    """
    import copy

    if isinstance(structure, t.Tensor):
        return new_tensor

    if isinstance(structure, (list, tuple)):
        items = list(structure)
        for idx, item in enumerate(items):
            if _first_tensor(item) is not None:
                items[idx] = _replace_first_tensor(item, new_tensor)
                break
        if isinstance(structure, tuple) and hasattr(structure, "_fields"):
            # namedtuple constructors take positional fields, not an iterable.
            return type(structure)(*items)
        return type(structure)(items)

    if isinstance(structure, dict):
        entry = _locate_tensor_entry(structure)
        if entry is None:
            return dict(structure)
        key = entry[0]
        replacement = _replace_first_tensor(structure[key], new_tensor)
        if type(structure) is not dict:
            try:
                clone = copy.copy(structure)
                clone[key] = replacement
                return clone
            except (TypeError, AttributeError, copy.Error):
                pass  # fall back to a plain dict below
        plain = dict(structure)
        plain[key] = replacement
        return plain

    if hasattr(structure, "__dict__"):
        attrs = vars(structure)
        entry = _locate_tensor_entry(attrs) if isinstance(attrs, dict) else None
        if entry is None:
            # Nothing to replace: hand the object back untouched.
            return structure
        key = entry[0]
        replacement = _replace_first_tensor(attrs[key], new_tensor)
        try:
            clone = copy.copy(structure)
            setattr(clone, key, replacement)
            return clone
        except (TypeError, AttributeError, copy.Error):
            # Uncopyable / read-only object: degrade to a dict of attributes.
            fallback = dict(attrs)
            fallback[key] = replacement
            return fallback

    return new_tensor


############################################
# Device / dtype alignment helpers
############################################

def _move_module_like_tensor(module: t.nn.Module, ref_like: Any):
    """
    Align ``module`` with the device/dtype of a reference tensor and put it
    in eval mode.

    The reference tensor is the first tensor found inside ``ref_like``; when
    there is none, the module is only switched to eval. The last applied
    (device, dtype) pair is remembered on the module as ``_dlm_cached_device``
    / ``_dlm_cached_dtype`` so ``module.to`` is skipped on repeated calls
    with the same target.
    """
    reference = _first_tensor(ref_like)
    if reference is not None:
        target_device = reference.device
        target_dtype = reference.dtype

        device_changed = getattr(module, "_dlm_cached_device", None) != target_device
        if device_changed or getattr(module, "_dlm_cached_dtype", None) != target_dtype:
            module.to(device=target_device, dtype=target_dtype)
            module._dlm_cached_device = target_device
            module._dlm_cached_dtype = target_dtype

    module.eval()


def reconstruct_with_dictionary(dictionary: t.nn.Module, x: t.Tensor) -> t.Tensor:
    """
    Best-effort SAE reconstruction of ``x`` under ``torch.no_grad``.

    Attempts, in order (each failure is swallowed and the next one is tried):
      1. ``dictionary.decode(dictionary.encode(x))`` when both methods exist;
      2. ``dictionary(x)``, picking the reconstruction out of a tensor,
         list/tuple or dict result (same-shape tensors preferred);
      3. encode/decode once more;
      4. identity: return ``x`` unchanged, printing a warning the first time.

    Every successful result is cast to ``x``'s dtype and device.
    """

    def _like_x(value: t.Tensor) -> t.Tensor:
        return value.to(dtype=x.dtype, device=x.device)

    def _select_from_forward(result: Any) -> Optional[t.Tensor]:
        # Choose the tensor of a forward() result that represents x_hat.
        if isinstance(result, t.Tensor):
            return result

        if isinstance(result, (list, tuple)):
            tensors = [item for item in result if isinstance(item, t.Tensor)]
            for cand in tensors:
                if cand.shape == x.shape:
                    return cand
            return tensors[0] if tensors else None

        if isinstance(result, dict):
            for key in ("x_hat", "reconstruction", "act_hat", "decoded"):
                if key in result and isinstance(result[key], t.Tensor):
                    return result[key]
            tensors = [value for value in result.values() if isinstance(value, t.Tensor)]
            for cand in tensors:
                if cand.shape == x.shape:
                    return cand
            return tensors[0] if tensors else None

        return None

    with t.no_grad():
        _move_module_like_tensor(dictionary, x)

        # Attempt 1: explicit encode -> decode round trip.
        if hasattr(dictionary, "encode") and hasattr(dictionary, "decode"):
            try:
                return _like_x(dictionary.decode(dictionary.encode(x)))
            except Exception:
                pass

        # Attempt 2: plain forward call.
        try:
            chosen = _select_from_forward(dictionary(x))
            if chosen is not None:
                return _like_x(chosen)
        except Exception:
            pass

        # Attempt 3: encode/decode again (encode runs even without decode).
        try:
            if hasattr(dictionary, "encode"):
                features = dictionary.encode(x)
                if hasattr(dictionary, "decode"):
                    return _like_x(dictionary.decode(features))
        except Exception:
            pass

        # Attempt 4: identity fallback, announced once per process.
        if not hasattr(reconstruct_with_dictionary, "_warned"):
            print(
                "[Warn] SAE reconstruction fell back to identity; "
                "check dictionary.forward/encode/decode APIs.",
                flush=True,
            )
            reconstruct_with_dictionary._warned = True
        return x


def register_sae_splice_hook(
    submodule: t.nn.Module,
    dictionary: t.nn.Module,
    section_sel: t.Tensor,
    io: str = "out",
):
    """
    Attach a hook to ``submodule`` that swaps in the SAE reconstruction, but
    ONLY at the positions flagged in ``section_sel``.

    section_sel: (B,T) bool mask.
    io: 'out' registers a forward hook acting on the module output; 'in'
        registers a forward pre-hook acting on the first positional input.

    Activations are expected to be (B,T,C) or (T,C); for any other rank, or
    when the mask shape does not line up, the activation passes through
    untouched. Returns the handle produced by the register call (call
    ``.remove()`` on it to detach). Raises ValueError for an unknown ``io``.
    """

    def _splice_tensor(act: t.Tensor) -> t.Tensor:
        # Work out the position mask matching this activation's rank.
        if act.ndim == 3:
            batch, seq_len, _ = act.shape
            if section_sel.shape[0] != batch or section_sel.shape[1] != seq_len:
                return act
            sel = section_sel.bool()
        elif act.ndim == 2:
            seq_len, _ = act.shape
            if section_sel.shape[0] != 1 or section_sel.shape[1] != seq_len:
                return act
            sel = section_sel[0].bool()
        else:
            return act

        try:
            # Fast path: reconstruct only the selected rows, shape (N,C).
            picked = act[sel]
            if picked.numel() == 0:
                return act
            picked_hat = reconstruct_with_dictionary(dictionary, picked)
            spliced = act.clone()
            spliced[sel] = picked_hat
            return spliced
        except Exception:
            # Slow path: reconstruct everything, then copy the selected rows.
            whole_hat = reconstruct_with_dictionary(dictionary, act)
            spliced = act.clone()
            spliced[sel] = whole_hat[sel]
            return spliced

    def _forward_hook(_module, _inputs, output):
        act = _first_tensor(output)
        if act is None:
            return output
        return _replace_first_tensor(output, _splice_tensor(act))

    def _forward_pre_hook(_module, inputs):
        if len(inputs) == 0:
            return inputs
        head = inputs[0]
        act = _first_tensor(head)
        if act is None:
            return inputs
        new_head = _replace_first_tensor(head, _splice_tensor(act))
        return (new_head, *inputs[1:])

    if io not in ("in", "out"):
        raise ValueError("io must be 'in' or 'out'")

    if io == "out":
        return submodule.register_forward_hook(_forward_hook)
    return submodule.register_forward_pre_hook(_forward_pre_hook)


############################################
# Tokenization / masking / timestep helpers
############################################

def build_random_mask(
    input_ids: t.Tensor,
    attention_mask: t.Tensor,
    mask_prob: float,
    exclude_first_token: bool = True,
    allowed_mask: Optional[t.Tensor] = None,
) -> t.Tensor:
    """
    Sample a random boolean corruption mask of shape (B, T).

    A position may be masked only if it is attended (``attention_mask``),
    is not column 0 when ``exclude_first_token`` is set, and lies inside
    ``allowed_mask`` when one is given. Each eligible position is then
    selected independently when a uniform draw falls below ``mask_prob``.
    """
    batch, seq_len = input_ids.shape

    eligible = attention_mask.bool().clone()
    if exclude_first_token and seq_len > 0:
        eligible[:, 0] = False
    if allowed_mask is not None:
        eligible = eligible & allowed_mask.bool()

    draws = t.rand((batch, seq_len), device=input_ids.device, dtype=t.float32)
    return (draws < mask_prob) & eligible


def _make_additive_float_mask_from_1d(attn_1d: t.Tensor, dtype: t.dtype) -> t.Tensor:
    """
    Turn a per-position attention mask into an additive float mask.

    The input is cast to ``dtype`` and indexed as [:, None, None, :] (so a
    (B, T) mask becomes (B, 1, 1, T)). Entries > 0 map to 0; everything else
    maps to the most negative finite value of ``dtype``.
    """
    keep = attn_1d.to(dtype)[:, None, None, :]
    lowest = t.finfo(keep.dtype).min
    # ~(keep > 0) rather than (keep <= 0) so NaN entries are blocked as well.
    return t.zeros_like(keep).masked_fill(~(keep > 0), lowest)


def _safe_forward_with_masks(model, inputs: Dict[str, t.Tensor], prefer_additive: bool = True):
    """
    Call ``model(**inputs)`` while trying several attention-mask encodings.

    Attempts, in order, each one swallowing any exception:
      1. additive float mask (only if ``prefer_additive`` and a mask exists);
      2. boolean mask;
      3. no attention mask at all.
    If all fail, a final call with the additive float mask is made and its
    exception is allowed to propagate. ``inputs`` itself is never modified;
    the additive mask uses the dtype of the model's first parameter
    (float32 when that cannot be determined).
    """
    try:
        param_dtype = next(model.parameters()).dtype
    except Exception:
        param_dtype = t.float32

    def _with_additive_mask() -> Dict[str, t.Tensor]:
        variant = dict(inputs)
        if "attention_mask" in variant:
            raw_mask = variant["attention_mask"]
            additive = _make_additive_float_mask_from_1d(raw_mask, dtype=param_dtype)
            variant["attention_mask"] = additive.to(device=raw_mask.device)
        return variant

    def _with_bool_mask() -> Dict[str, t.Tensor]:
        variant = dict(inputs)
        if "attention_mask" in variant:
            variant["attention_mask"] = variant["attention_mask"].to(t.bool)
        return variant

    def _without_mask() -> Dict[str, t.Tensor]:
        variant = dict(inputs)
        variant.pop("attention_mask", None)
        return variant

    builders = [_with_bool_mask, _without_mask]
    if prefer_additive and "attention_mask" in inputs:
        builders.insert(0, _with_additive_mask)

    for build in builders:
        try:
            return model(**build())
        except Exception:
            continue

    # Last resort: errors from this call are NOT swallowed.
    return model(**_with_additive_mask())


def _maybe_add_time_condition(inputs: Dict[str, t.Tensor], p_scalar: float, model) -> Dict[str, t.Tensor]:
    """
    Add a time / noise-level argument when ``model.forward`` accepts one.

    The forward signature is inspected for a known parameter name. Continuous
    names (t, time, noise_level, sigma) receive a float32 tensor of shape (B,)
    filled with ``p_scalar``; discrete names (timestep, timesteps,
    diffusion_step, time_ids) receive a long tensor of shape (B,) filled with
    the step index round(p_scalar * (T - 1)) clamped to [0, T - 1], where T
    comes from the first integer config attribute among num_diffusion_steps /
    diffusion_steps / n_timesteps / timesteps / T (default 1000).

    Returns a new dict when something was added; otherwise ``inputs`` itself
    (also when the signature cannot be inspected).
    """
    try:
        param_names = set(inspect.signature(model.forward).parameters.keys())
    except Exception:
        return inputs

    cfg = getattr(model, "config", None)
    total_steps = None
    for attr in ("num_diffusion_steps", "diffusion_steps", "n_timesteps", "timesteps", "T"):
        if hasattr(cfg, attr) and isinstance(getattr(cfg, attr), int):
            total_steps = int(getattr(cfg, attr))
            break
    if total_steps is None:
        total_steps = 1000
    last_step = total_steps - 1
    step = int(max(0, min(last_step, round(float(p_scalar) * last_step))))

    batch = inputs["input_ids"].shape[0]
    dev = inputs["input_ids"].device

    # (accepted parameter names, fill value, dtype), checked in priority order.
    schemes = (
        (("t", "time", "noise_level", "sigma"), float(p_scalar), t.float32),
        (("timestep", "timesteps", "diffusion_step", "time_ids"), step, t.long),
    )
    for names, fill_value, fill_dtype in schemes:
        for name in names:
            if name in param_names:
                extended = dict(inputs)
                extended[name] = t.full((batch,), fill_value, dtype=fill_dtype, device=dev)
                return extended

    return inputs


def ce_sum_with_mask(
    logits: t.Tensor,
    labels: t.Tensor,
    mask: t.Tensor,
    weight_scalar: Optional[float] = None,
) -> Tuple[t.Tensor, int]:
    """
    Sum the per-token cross-entropy over the positions selected by ``mask``.

    Returns (loss_sum, n) where n is the number of selected positions. Each
    per-token loss is multiplied by ``weight_scalar`` when it is given. With
    an empty selection the result is a float32 zero scalar on the labels'
    device and n == 0 (logits are not inspected in that case).

    Raises RuntimeError if ``logits`` is not a tensor.
    """
    n_selected = int(mask.sum().item())
    if n_selected == 0:
        empty_sum = t.zeros((), device=labels.device, dtype=t.float32)
        return empty_sum, 0

    if not isinstance(logits, t.Tensor):
        raise RuntimeError(f"ce_sum_with_mask expected logits as Tensor but got {type(logits)}")

    token_losses = F.cross_entropy(logits[mask], labels[mask], reduction="none")  # (n,)
    if weight_scalar is not None:
        token_losses = token_losses * float(weight_scalar)
    return token_losses.sum(), n_selected


############################################
# Main loss eval helper (mask-only, instruction section)
############################################

@t.no_grad()
def dlm_mask_only_losses_instruction(
    model,
    tokenizer,
    submodule,
    dictionary,
    full_input_ids: t.Tensor,         # (1,T)
    section_sel: t.Tensor,            # (1,T) bool
    mask_token_id: int,
    device: t.device,
    io: str = "out",
    t_min: float = 0.05,
    t_max: float = 0.50,
    fixed_t: Optional[float] = None,
    verbose: bool = False,
) -> Dict[str, t.Tensor]:
    """
    Mask-only ΔLM loss for one sequence, restricted to ``section_sel``.

    Procedure:
      - pick the corruption level t (``fixed_t``, else t ~ U[t_min, t_max]),
        clamped to [1e-6, 0.999];
      - mask tokens ONLY inside section_sel (never position 0), each with
        probability t, replacing them by ``mask_token_id``;
      - run the model once unmodified and once with the SAE reconstruction
        spliced into ``submodule`` ONLY inside section_sel (the hook is
        always removed afterwards);
      - sum the cross-entropy ONLY over masked positions, weighted by 1/t.

    ``tokenizer`` is accepted for interface symmetry but not used here.

    Returns a dict with tensors: loss_clean_mask_sum, loss_sae_mask_sum,
    n_masked_tokens and t_used.
    """
    if fixed_t is None:
        sampled_t = float(t.empty((), device=device).uniform_(t_min, t_max).item())
    else:
        sampled_t = float(fixed_t)
    t_prob = max(1e-6, min(0.999, sampled_t))
    inv_t_weight = 1.0 / t_prob

    input_ids = full_input_ids.to(device)
    attention_mask = t.ones_like(input_ids, dtype=t.long, device=device)
    section_on_device = section_sel.to(device)

    # Corrupt the input: masked positions are a random subset of the section.
    corrupt = build_random_mask(
        input_ids=input_ids,
        attention_mask=attention_mask,
        mask_prob=t_prob,
        exclude_first_token=True,
        allowed_mask=section_on_device,
    )
    masked_ids = input_ids.clone()
    masked_ids[corrupt] = mask_token_id

    model_inputs = _maybe_add_time_condition(
        {"input_ids": masked_ids, "attention_mask": attention_mask},
        p_scalar=t_prob,
        model=model,
    )

    # Pass 1: untouched model.
    logits_clean = get_logits(_safe_forward_with_masks(model, model_inputs, prefer_additive=True))

    # Pass 2: same inputs with the SAE spliced in for the section positions.
    handle = register_sae_splice_hook(submodule, dictionary, section_sel=section_on_device, io=io)
    try:
        logits_sae = get_logits(_safe_forward_with_masks(model, model_inputs, prefer_additive=True))
    finally:
        handle.remove()

    loss_clean_sum, n_mask = ce_sum_with_mask(logits_clean, input_ids, corrupt, weight_scalar=inv_t_weight)
    loss_sae_sum, _ = ce_sum_with_mask(logits_sae, input_ids, corrupt, weight_scalar=inv_t_weight)

    if verbose:
        Ttot = int(input_ids.shape[1])
        print(
            f"[Instr-ΔLoss(mask-only)] t={t_prob:.3f} masked={n_mask}/{Ttot} "
            f"clean_sum={float(loss_clean_sum.item()):.3f} sae_sum={float(loss_sae_sum.item()):.3f}",
            flush=True,
        )

    result = {
        "loss_clean_mask_sum": loss_clean_sum,
        "loss_sae_mask_sum": loss_sae_sum,
    }
    result["n_masked_tokens"] = t.tensor(int(n_mask), device=device)
    result["t_used"] = t.tensor(float(t_prob), device=device, dtype=t.float32)
    return result


############################################
# SAE folder scanning helper
############################################

def find_sae_trainer_dirs(root: str) -> List[str]:
    """
    Recursively collect every directory under ``root`` (root included) whose
    base name starts with 'trainer_', returned as a sorted list of paths.
    """
    matches = [
        dirpath
        for dirpath, _subdirs, _files in os.walk(root)
        if os.path.basename(dirpath).startswith("trainer_") and os.path.isdir(dirpath)
    ]
    matches.sort()
    return matches


############################################
# Main entry point
############################################

def main():
    """
    CLI entry point: compute the mask-only Dream ΔLM loss on instruction
    rollouts for every SAE checkpoint found under ``--ae_root``.

    Flow:
      1. Parse arguments, load tokenizer and model, resolve the mask token id.
      2. Discover SAE folders (``trainer_*`` recursively, else
         utils.get_nested_folders).
      3. Open the instruction stream and skip ``--instruction_skip`` examples.
      4. For each SAE: accumulate clean / SAE-spliced masked CE sums over up
         to ``--num_instructions`` rollouts (or until ``--token_budget``
         masked tokens) and write a JSON report into the SAE folder.

    Rollouts are generated lazily and cached, so every SAE is evaluated on
    the SAME instructions and generated sequences (the rollout does not
    depend on the SAE); without the cache each SAE would consume a different
    slice of the stream and pay for generation again. A failure while
    evaluating one SAE is printed and the loop continues with the next one.

    Raises:
        RuntimeError: the tokenizer provides no mask / unk / eos token id.
    """
    parser = argparse.ArgumentParser(
        "Compute Dream-ΔLoss (mask-only) on instruction rollouts, with SAE splice restricted to prompt/rollout."
    )
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--ae_root", type=str, required=True)
    parser.add_argument("--token_budget", type=int, default=1_000_000)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "bfloat16", "float16"],
    )

    # Instruction experiment controls (report experiment 3)
    parser.add_argument("--instruction_dataset", type=str, default="tatsu-lab/alpaca")
    parser.add_argument("--instruction_split", type=str, default="train")
    parser.add_argument("--num_instructions", type=int, default=50)
    parser.add_argument("--instruction_skip", type=int, default=0)
    parser.add_argument("--section", type=str, choices=["prompt", "rollout"], default="rollout")
    parser.add_argument("--max_new_tokens", type=int, default=256)

    # Dream sampling knobs (align official)
    parser.add_argument("--dlm_steps", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--alg", type=str, default="entropy")
    parser.add_argument("--alg_temp", type=float, default=0.0)

    # DLM timestep / masking config
    parser.add_argument("--t_min", type=float, default=0.05, help="Lower bound for per-sample t sampling.")
    parser.add_argument("--t_max", type=float, default=0.50, help="Upper bound for per-sample t sampling.")
    parser.add_argument(
        "--fixed_t",
        type=float,
        default=None,
        help="If set, use this fixed corruption level t instead of sampling t ~ U[t_min, t_max].",
    )

    parser.add_argument(
        "--io",
        type=str,
        default="out",
        choices=["in", "out"],
        help="Hook direction: splice SAE into submodule input or output.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print per-sample status messages.",
    )
    args = parser.parse_args()

    warnings.filterwarnings(
        "ignore",
        message="You are using `torch.load` with `weights_only=False`",
        category=FutureWarning,
    )

    dtype_map = {"float32": t.float32, "bfloat16": t.bfloat16, "float16": t.float16}
    load_dtype = dtype_map[args.dtype]

    device = t.device(args.device)

    ############################################
    # Tokenizer
    ############################################
    print(f"[Setup] Loading tokenizer for {args.model_name} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ############################################
    # Model (align Dream official: torch_dtype + .to(device))
    ############################################
    print("[Setup] Loading model (this can take minutes) ...", flush=True)
    model = AutoModel.from_pretrained(
        args.model_name,
        torch_dtype=load_dtype,
        trust_remote_code=True,
    ).to(device).eval()

    if not is_dream_like(model):
        print(
            "[Warn] The loaded model does not look like a Dream/diffusion LM; "
            "we will still use the DLM mask-only objective (masked CE weighted by 1/t).",
            flush=True,
        )

    eval_device = device

    mask_token_id = getattr(tokenizer, "mask_token_id", None)
    if mask_token_id is None:
        fallback_id = (
            tokenizer.unk_token_id
            if getattr(tokenizer, "unk_token_id", None) is not None
            else tokenizer.eos_token_id
        )
        mask_token_id = fallback_id
        print(
            f"[Info] tokenizer.mask_token_id not found; using fallback id = {mask_token_id}",
            flush=True,
        )
    if mask_token_id is None:
        # Without any usable id every SAE evaluation would fail later on the
        # masked assignment; fail once, early, with a clear message instead.
        raise RuntimeError(
            "Tokenizer provides no mask_token_id, unk_token_id or eos_token_id; "
            "cannot build masked inputs."
        )

    ############################################
    # Scan for all SAE checkpoints under ae_root
    ############################################
    print(f"[Scan] Scanning SAE folders under: {args.ae_root}", flush=True)
    sae_dirs: List[str] = find_sae_trainer_dirs(args.ae_root)
    if len(sae_dirs) == 0:
        print(
            "[Scan] No trainer_* folders found via recursive scan. "
            "Falling back to utils.get_nested_folders(...)",
            flush=True,
        )
        sae_dirs = utils.get_nested_folders(args.ae_root)

    print(f"[Scan] Found {len(sae_dirs)} SAE folders.", flush=True)
    if args.verbose:
        for d in sae_dirs:
            print(f"  - {d}", flush=True)

    ############################################
    # Instruction stream
    ############################################
    instr_stream = iter_instructions(args.instruction_dataset, split=args.instruction_split, streaming=True)

    skipped = 0
    while skipped < args.instruction_skip:
        try:
            next(instr_stream)
        except StopIteration:
            break
        skipped += 1

    # Cache of (prompt_len, full_ids), indexed by instruction position after
    # the skip. Filled lazily so generation only happens for examples that
    # are actually needed, and exactly once across all SAEs.
    rollout_cache: List[Tuple[int, t.Tensor]] = []

    def _rollout_for(position: int) -> Optional[Tuple[int, t.Tensor]]:
        """Return the cached rollout at ``position``, generating it on demand.

        Returns None once the instruction stream is exhausted.
        """
        while len(rollout_cache) <= position:
            try:
                ex = next(instr_stream)
            except StopIteration:
                return None

            # 1) build "user content" (alpaca formatted)
            user_content = format_alpaca_prompt(ex)

            # 2) Dream official rollout: get prompt_ids, prompt_mask, full_ids
            prompt_ids, _prompt_attn, full_ids = generate_rollout_ids_official_dream(
                model=model,
                tokenizer=tokenizer,
                user_content=user_content,
                device=eval_device,
                max_new_tokens=args.max_new_tokens,
                steps=args.dlm_steps,
                temperature=args.temperature,
                top_p=args.top_p,
                alg=args.alg,
                alg_temp=args.alg_temp,
            )
            rollout_cache.append((int(prompt_ids.shape[1]), full_ids))
        return rollout_cache[position]

    ############################################
    # Loop over each SAE
    ############################################
    for idx, d in enumerate(sae_dirs, start=1):
        print(f"\n[Eval {idx}/{len(sae_dirs)}] Loading SAE from: {d}", flush=True)

        pbar = None
        try:
            dictionary, cfg = utils.load_dictionary(d, device=args.device)
            model_dtype = next(model.parameters()).dtype
            dictionary.to(dtype=model_dtype)
            dictionary.eval()

            layer = cfg["trainer"]["layer"]
            submodule = utils.get_submodule(model, layer)
            print(f"[Eval] Target layer = {layer} | io = {args.io} | section = {args.section}", flush=True)

            sum_masked_tokens = 0
            loss_clean_mask_sum_total = 0.0
            loss_sae_mask_sum_total = 0.0
            n_used = 0

            pbar = tqdm(
                total=args.token_budget,
                unit="tok",
                dynamic_ncols=True,
                desc=f"Instr-ΔLoss(mask) tokens ({os.path.basename(d)}) [{args.section}]",
            )

            while (sum_masked_tokens < args.token_budget) and (n_used < args.num_instructions):
                rollout = _rollout_for(n_used)
                if rollout is None:
                    break
                prompt_len, full_ids = rollout

                # 3) section mask on full sequence length
                section_sel = build_section_mask(full_ids, prompt_len=prompt_len, section=args.section)

                # 4) compute mask-only ΔLoss restricted to section
                out = dlm_mask_only_losses_instruction(
                    model=model,
                    tokenizer=tokenizer,
                    submodule=submodule,
                    dictionary=dictionary,
                    full_input_ids=full_ids,
                    section_sel=section_sel,
                    mask_token_id=mask_token_id,
                    device=eval_device,
                    io=args.io,
                    t_min=args.t_min,
                    t_max=args.t_max,
                    fixed_t=args.fixed_t,
                    verbose=args.verbose,
                )

                n_mask = int(out["n_masked_tokens"].item())
                if n_mask > 0:
                    loss_clean_mask_sum_total += float(out["loss_clean_mask_sum"].item())
                    loss_sae_mask_sum_total += float(out["loss_sae_mask_sum"].item())
                    sum_masked_tokens += n_mask
                    pbar.update(n_mask)

                n_used += 1

            pbar.close()

            # Guard only the division; the report keeps the true token count.
            denom = max(sum_masked_tokens, 1)

            avg_clean_mask = loss_clean_mask_sum_total / denom
            avg_sae_mask = loss_sae_mask_sum_total / denom
            delta_mask = avg_sae_mask - avg_clean_mask

            out_mask = {
                "tokens_masked_evaluated": float(sum_masked_tokens),
                "dream_weighted_loss_clean(mask)": float(avg_clean_mask),
                "dream_weighted_loss_sae(mask)": float(avg_sae_mask),
                "delta_lm_loss(mask)": float(delta_mask),
                "weighting": "mask-only CE weighted by 1/t; masking+splice restricted to section (prompt/rollout)",
                "section": args.section,
                "instruction_dataset": args.instruction_dataset,
                "instruction_split": args.instruction_split,
                "num_instructions_used": int(n_used),
                "t_min": float(args.t_min),
                "t_max": float(args.t_max),
                "fixed_t": (None if args.fixed_t is None else float(args.fixed_t)),
                "max_new_tokens": int(args.max_new_tokens),
                "dlm_steps": int(args.dlm_steps),
                "temperature": float(args.temperature),
                "top_p": float(args.top_p),
                "alg": str(args.alg),
                "alg_temp": float(args.alg_temp),
                "io": args.io,
                "model_name": args.model_name,
            }

            out_path = os.path.join(d, f"delta_lm_loss(mask)_base_to_base_instr_{args.section}.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(out_mask, f, indent=2)

            print(
                f"[Done] {d} → Instr ΔLM(mask)={out_mask['delta_lm_loss(mask)']:.6f}  "
                f"masked_tokens={int(out_mask['tokens_masked_evaluated']):,}  "
                f"n_instr={n_used}  section={args.section}  saved={out_path}",
                flush=True,
            )

        except Exception as e:
            print(f"[Eval] Failed on {d}: {e}", flush=True)
        finally:
            # Also release the bar when an evaluation aborts mid-loop
            # (closing an already closed bar is harmless).
            if pbar is not None:
                pbar.close()


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
