"""
Hidden-State guided probing for LLaMA-architecture models.

Usage:
  python test_cf.py \
      --model-id TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
      --prompt "Ali likes the" \
      --max-new-tokens 60

If your primary model is gated (e.g., Llama-3.*), pass --hf-token or set HUGGINGFACE_HUB_TOKEN (or HF_TOKEN). If loading the primary model fails, the model given by --fallback is loaded instead.

Dependencies: torch, transformers, sentencepiece, accelerate (optional)
"""

import os
import math
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# ----------------------- RNG replay -----------------------

class RNGReplayTorch:
    """Replayable source of per-step seeds and Gumbel noise.

    A CPU ``torch.Generator`` seeded once with ``master_seed`` yields a deterministic
    stream of per-step seeds. Every seed handed out is recorded in ``self.seeds`` so the
    exact noise sequence can be regenerated later.
    """

    def __init__(self, vocab_size: int, master_seed: int = 42):
        self.V = vocab_size
        self.seeds: List[int] = []
        master = torch.Generator(device="cpu")
        master.manual_seed(int(master_seed))
        self._master_gen = master

    def next_seed(self) -> int:
        """Draw the next per-step seed from [0, 2**31 - 2], record it and return it."""
        draw = torch.randint(low=0, high=2**31 - 1, size=(1,), generator=self._master_gen)
        seed = int(draw.item())
        self.seeds.append(seed)
        return seed

    @staticmethod
    def gumbel_from_seed(seed: int, size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return ``size`` standard Gumbel(0, 1) samples determined entirely by ``seed``.

        Uniforms are drawn on CPU in float32 (so the values do not depend on the target
        device), clamped away from 0 and 1, transformed with -log(-log(u)), and finally
        moved to ``device`` / ``dtype``.
        """
        local_gen = torch.Generator(device="cpu")
        local_gen.manual_seed(int(seed))
        uniform = torch.rand(size, generator=local_gen, dtype=torch.float32).clamp_(min=1e-9, max=1 - 1e-9)
        noise = -torch.log(-torch.log(uniform))
        return noise.to(device=device, dtype=dtype)


# ----------------------- HGC configuration -----------------------

@dataclass
class HGCConfig:
    """Hyper-parameters for hidden-state guided concept (HGC) steering.

    Fields:
        S_words: source-concept words to steer away from, e.g. ["farm", "farms", "insects"].
        T_words: target-concept words to steer toward, e.g. ["city", "entertainment"].
        theta: q_S level above which the steering intensity starts to grow.
        alpha0: scale applied to the excess ``q_S - theta`` when updating the intensity.
        gamma: per-step decay factor of the persistent intensity.
        kappa: concentration of the cosine-based soft preference q_S.
        beta_S: logit penalty subtracted from S tokens when steering is triggered.
        beta_T: logit bonus added to T tokens when steering is triggered.
        eos_bias_while_active: intended EOS logit bias while the intensity is > 0
            (negative, e.g. -5.0, to suppress EOS). NOTE(review): not read by any code
            visible in this module.
    """

    S_words: List[str]
    T_words: List[str]
    theta: float = 0.55
    alpha0: float = 1.2
    gamma: float = 0.93
    kappa: float = 12.0
    beta_S: float = 1.5
    beta_T: float = 1.5
    eos_bias_while_active: float = 0.0


# ----------------------- HGC steerer for LLaMA -----------------------

class HGCSteererLLM:
    """Hidden-state guided concept (HGC) steering helper for LLaMA-style models.

    Builds unit-norm concept prototypes mu_S / mu_T from the mean ``lm_head`` rows of the
    single-piece tokens for ``cfg.S_words`` / ``cfg.T_words``, and a steering direction
    v = normalize(mu_T - mu_S). ``q_S`` scores how S-like a hidden state is,
    ``edit_hidden`` pushes a hidden state along v with a persistent, decaying intensity,
    and ``apply_logit_tweaks`` biases S/T token logits.
    """

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, cfg: HGCConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.cfg = cfg

        # Only words that tokenize to exactly one begin-of-word piece are used.
        self.S_ids = self._single_piece_ids(cfg.S_words)
        self.T_ids = self._single_piece_ids(cfg.T_words)
        if len(self.S_ids) == 0 or len(self.T_ids) == 0:
            print("[HGC] Warning: Empty S or T id sets (multi-piece words were skipped).")

        # Prototypes live in the output-projection (lm_head) space, on the model's device.
        W = self.model.lm_head.weight.detach().to(device=self.model.device)
        self.dtype_lm = W.dtype

        # Normalized prototypes; an empty id set yields an all-zero prototype.
        self.mu_S = self._safe_normalize(self._mean_rows(W, self.S_ids))
        self.mu_T = self._safe_normalize(self._mean_rows(W, self.T_ids))

        # Concept direction v = mu_T - mu_S, normalized (lm dtype/device).
        self.v = self._safe_normalize(self.mu_T - self.mu_S)

        # EOS ids (best-effort, with out-of-vocabulary candidates filtered out).
        self.eos_ids = self._eos_token_ids()

        # Period ids (best-effort; depends on tokenizer). Shares the module-level helper.
        self.period_ids = _period_token_ids(self.tokenizer)

        def decode_ids(ids):
            return [self.tokenizer.decode([i], skip_special_tokens=True) for i in ids]

        print(f"[HGC] S_ids ({len(self.S_ids)}):", decode_ids(self.S_ids))
        print(f"[HGC] T_ids ({len(self.T_ids)}):", decode_ids(self.T_ids))

    def _eos_token_ids(self) -> set:
        """Best-effort set of end-of-sequence token ids for ``self.tokenizer``.

        For Hugging Face tokenizers, ``convert_tokens_to_ids`` maps a string that is not
        in the vocabulary to the unknown-token id (or to None). Such results are discarded
        so the unk id is not mistaken for EOS. Empty candidates are skipped for the same
        reason.
        """
        tok = self.tokenizer
        unk_id = getattr(tok, "unk_token_id", None)
        unk_tok = getattr(tok, "unk_token", None)
        ids = set()
        for cand in [tok.eos_token, "</s>", "<|eot_id|>", "<|endoftext|>"]:
            if not cand:
                continue
            tid = tok.convert_tokens_to_ids(cand)
            if not isinstance(tid, int) or tid < 0:
                continue
            if unk_id is not None and tid == unk_id and cand != unk_tok:
                continue  # candidate not in the vocabulary
            ids.add(int(tid))
        return ids

    def _single_piece_ids(self, words: List[str]) -> List[int]:
        """Return ids of words whose ``" " + word`` tokenizes to exactly one piece.

        Multi-piece words are reported and skipped.
        """
        ids: List[int] = []
        for w in words:
            toks = self.tokenizer.tokenize(" " + w)  # leading space: prefer begin-of-word piece
            if len(toks) == 1:
                tid = self.tokenizer.convert_tokens_to_ids(toks[0])
                if isinstance(tid, int) and tid >= 0:
                    ids.append(int(tid))
            else:
                print(f"[HGC] Skipping multi-piece word: '{w}' -> {toks}")
        return ids

    @staticmethod
    def _safe_normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Divide ``x`` by its global L2 norm (+eps, so an all-zero input stays zero)."""
        return x / (x.norm(p=2) + eps)

    @staticmethod
    def _mean_rows(W: torch.Tensor, ids: List[int]) -> torch.Tensor:
        """Mean of rows ``ids`` of ``W``; a zero vector of width W.shape[1] if ``ids`` is empty."""
        if not ids:
            return torch.zeros(W.shape[1], dtype=W.dtype, device=W.device)
        rows = W[torch.tensor(ids, device=W.device, dtype=torch.long)]
        return rows.mean(dim=0)

    @staticmethod
    def _cos(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Cosine similarity computed in float32 on ``a``'s device."""
        a32 = a.to(dtype=torch.float32)
        b32 = b.to(device=a32.device, dtype=torch.float32)
        return (a32 @ b32) / (a32.norm(p=2) * b32.norm(p=2) + eps)

    def q_S(self, h: torch.Tensor) -> float:
        """Soft preference of ``h`` for S over T prototypes, in [0, 1].

        Mathematically exp(k*cs) / (exp(k*cs) + exp(k*ct)) with cs/ct the cosines to
        mu_S/mu_T and k = cfg.kappa. It is evaluated as the equivalent logistic
        sigmoid(k * (cs - ct)), branch-stabilised so large ``kappa`` cannot overflow
        ``math.exp``.
        """
        k = self.cfg.kappa
        cs = float(self._cos(h, self.mu_S))
        ct = float(self._cos(h, self.mu_T))
        x = k * (cs - ct)
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        ex = math.exp(x)
        return ex / (1.0 + ex)

    def edit_hidden(self, h: torch.Tensor, qS: float, s_intensity: float) -> Tuple[torch.Tensor, float]:
        """Update the persistent intensity and, when positive, push ``h`` along v.

        s_new = gamma * s_intensity + alpha0 * max(0, qS - theta).
        If s_new > 0, returns (normalize(h + s_new * v) along the last dim, s_new). The
        edited state has unit L2 norm, unlike the incoming hidden state. Otherwise returns
        (h, s_new) unchanged.
        """
        s_new = self.cfg.gamma * s_intensity + self.cfg.alpha0 * max(0.0, qS - self.cfg.theta)
        if s_new > 0:
            v_like = self.v.to(device=h.device, dtype=h.dtype)
            h_edit = torch.nn.functional.normalize(h + s_new * v_like, p=2, dim=-1)
            return h_edit, s_new
        return h, s_new

    def apply_logit_tweaks(self, logits: torch.Tensor, triggered: bool):
        """When ``triggered``, subtract beta_S from S-token logits and add beta_T to T-token logits.

        NOTE: ``logits`` is modified in place and also returned.
        """
        if not triggered:
            return logits
        if self.cfg.beta_S > 0 and len(self.S_ids):
            logits[..., self.S_ids] -= self.cfg.beta_S
        if self.cfg.beta_T > 0 and len(self.T_ids):
            logits[..., self.T_ids] += self.cfg.beta_T
        return logits


# ----------------------- Low-level step function -----------------------

@torch.no_grad()
def llama_step_llm(model: AutoModelForCausalLM,
                   input_ids: torch.Tensor,
                   past_key_values=None,
                   attention_mask: Optional[torch.Tensor] = None):
    """Run one cached forward pass of the base decoder and project the final position.

    Calls ``model.model`` (the decoder without the LM head) with ``use_cache=True``,
    takes the hidden state at the last sequence position and applies ``model.lm_head``
    to it. For Hugging Face LlamaModel, ``last_hidden_state`` is produced after the final
    norm.

    Returns:
        (logits [batch, V], h_last [batch, H], past_key_values)
    """
    outputs = model.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
        return_dict=True,
    )
    final_hidden = outputs.last_hidden_state[:, -1, :]
    projected = model.lm_head(final_hidden)
    return projected, final_hidden, outputs.past_key_values



@torch.no_grad()
def generate_factual(model, tokenizer, prompt: str, max_new_tokens: int, master_seed: int = 2025):
    """Sample a continuation of ``prompt`` with replayable Gumbel-max noise.

    Each step draws a fresh seed from ``RNGReplayTorch(master_seed)``, adds the
    corresponding Gumbel noise to the logits and takes the argmax. Generation stops after
    ``max_new_tokens`` tokens or when an EOS/special token is emitted (that token is kept).

    The perturbed scores are formed in float32. In fp16/bf16 the Gumbel perturbation
    would be coarsely quantised against logits of typical magnitude, which biases the
    Gumbel-max sample.

    Returns:
        (token_ids, seeds): prompt ids followed by the generated ids, and the per-step
        seeds that reproduce the noise.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    device = model.device
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = enc.input_ids.to(device)
    if input_ids.shape[-1] == 0:
        raise ValueError("prompt must encode to at least one token")
    attn = enc.attention_mask.to(device)

    # Prime with the full prompt.
    logits, _, pkv = llama_step_llm(model, input_ids=input_ids, past_key_values=None, attention_mask=attn)

    vocab_size = logits.shape[-1]
    rng = RNGReplayTorch(vocab_size=vocab_size, master_seed=master_seed)

    generated_ids = [int(t) for t in input_ids[0].tolist()]
    for _ in range(max_new_tokens):
        # Gumbel-Max with a per-step, replayable seed; scores kept in float32.
        seed = rng.next_seed()
        g = RNGReplayTorch.gumbel_from_seed(seed, vocab_size, device=logits.device, dtype=torch.float32)
        scores = logits[0].float() + g
        next_id = int(torch.argmax(scores).item())

        generated_ids.append(next_id)
        if _token_is_eos(tokenizer, next_id):
            break

        # Advance with the cache; the mask covers past + the new token.
        inp = torch.tensor([[next_id]], device=device, dtype=torch.long)
        attn = torch.ones(1, len(generated_ids), device=device, dtype=torch.long)
        logits, _, pkv = llama_step_llm(model, input_ids=inp, past_key_values=pkv, attention_mask=attn)

    return generated_ids, rng.seeds




@dataclass
class ReplaceRule:
    """Token-replacement rule: tokens in ``src_ids`` are to be replaced by ``dst_id``.

    Fields:
        src_ids: token ids to replace, e.g. [id("▁farms")].
        dst_id: replacement token id, e.g. id("▁city").
        max_hits: intended cap on the number of replacements (not enforced by code
            visible in this module).
        hits: internal counter of replacements performed.
    """

    src_ids: List[int]
    dst_id: int
    max_hits: int = 10
    hits: int = 0


def single_piece_id(tokenizer: AutoTokenizer, word: str) -> Optional[int]:
    """Return the id of ``" " + word`` if it tokenizes to exactly one piece, else None.

    The leading space selects the begin-of-word piece. None is also returned when the
    tokenizer cannot map that piece back to a valid (non-negative) id. This matches the
    validation in ``HGCSteererLLM._single_piece_ids`` and avoids the previous
    ``int(None)`` crash.
    """
    pieces = tokenizer.tokenize(" " + word)
    if len(pieces) != 1:
        return None
    tid = tokenizer.convert_tokens_to_ids(pieces[0])
    if tid is None:
        return None
    tid = int(tid)
    return tid if tid >= 0 else None


@torch.no_grad()
def _last_hidden_of_text(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, text: str) -> torch.Tensor:
    """Return the final-position hidden state after encoding all of ``text``.

    The text is tokenized without special tokens and run once through ``model.model``
    without a KV cache. This is the same base path as ``llama_step_llm``. The result is
    the hidden vector at batch index 0, last position (1-D, size H).
    """
    dev = model.device
    batch = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    hidden = model.model(
        input_ids=batch.input_ids.to(dev),
        attention_mask=batch.attention_mask.to(dev),
        use_cache=False,
        return_dict=True,
    ).last_hidden_state
    return hidden[0, -1, :]


@torch.no_grad()
def _hidden_sequence_of_text(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    text: str
) -> Tuple[torch.Tensor, List[int]]:
    """Encode ``text`` (no special tokens) and return per-position hidden states and ids.

    Runs ``model.model`` once without a KV cache, the same base path as ``llama_step_llm``.

    Returns:
        hidden_seq: [L, H] ``last_hidden_state`` of batch item 0.
        token_ids: the L input token ids as Python ints.
    """
    dev = model.device
    batch = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    ids_cpu = batch.input_ids
    result = model.model(
        input_ids=ids_cpu.to(dev),
        attention_mask=batch.attention_mask.to(dev),
        use_cache=False,
        return_dict=True,
    )
    token_ids = list(map(int, ids_cpu[0].tolist()))
    return result.last_hidden_state[0], token_ids


def _period_token_ids(tokenizer: AutoTokenizer) -> set:
    """Collect ids of tokens that spell a period (best effort, tokenizer-dependent).

    A candidate spelling contributes only if it tokenizes to a single piece whose id is a
    non-negative int.
    """
    found = set()
    for spelling in (".", " .", "▁.", " . "):
        pieces = tokenizer.tokenize(spelling)
        if len(pieces) != 1:
            continue
        tid = tokenizer.convert_tokens_to_ids(pieces[0])
        if isinstance(tid, int) and tid >= 0:
            found.add(int(tid))
    return found



def _gaussian_mask_around_diffs(
    fact_ids: List[int],
    edit_ids: List[int],
    Lmin: int,
    sigma: float = 1.5,
    window: int = 4
) -> torch.Tensor:
    """
    Build a per-position weight vector w in [0,1] of length Lmin, emphasizing
    positions where tokens differ and softly decaying around them.
    """
    device = torch.device("cpu")
    if Lmin <= 0:
        return torch.zeros(0)
    diff_positions = [i for i in range(Lmin) if fact_ids[i] != edit_ids[i]]
    if not diff_positions:
        return torch.zeros(Lmin)  # no change needed
    idx = torch.arange(Lmin, dtype=torch.float32)
    w = torch.zeros(Lmin, dtype=torch.float32)
    for p in diff_positions:
        d = (idx - float(p)).abs()
        contrib = torch.exp(-0.5 * (d / max(sigma, 1e-6)) ** 2)
        if window is not None and window > 0:
            contrib = torch.where(d <= window, contrib, torch.zeros_like(contrib))
        w = torch.maximum(w, contrib)
    # Move to CUDA if needed during use; keep CPU here.
    return w


# ----------------------- Utilities -----------------------

def load_tokenizer_model(model_id: str, hf_token: Optional[str] = None):
    """Load a tokenizer and causal LM from ``model_id``.

    The model is loaded with ``torch_dtype="auto"`` and eager attention. When
    ``hf_token`` is given it is first passed as ``token=``. If that raises TypeError
    (presumably an older transformers release that only knows ``use_auth_token``),
    loading is retried with ``use_auth_token=``. Previously the try/except wrapped only
    dict assignments, so the fallback could never trigger.

    Returns:
        (tokenizer, model)
    """
    kw_tok = {"use_fast": True}
    kw_mdl = {"torch_dtype": "auto", "attn_implementation": "eager"}
    if not hf_token:
        tok = AutoTokenizer.from_pretrained(model_id, **kw_tok)
        mdl = AutoModelForCausalLM.from_pretrained(model_id, **kw_mdl)
        return tok, mdl
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=hf_token, **kw_tok)
        mdl = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, **kw_mdl)
    except TypeError:
        # Older transformers name the auth kwarg 'use_auth_token'.
        tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token, **kw_tok)
        mdl = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_token, **kw_mdl)
    return tok, mdl

def _token_ids_of_text(tokenizer: AutoTokenizer, text: str) -> List[int]:
    """Tokenize ``text`` without special tokens and return the ids as Python ints."""
    row = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids[0]
    return list(map(int, row.tolist()))


def _build_diff_map(fact_ids: List[int], edit_ids: List[int]) -> Dict[int, Tuple[int, int]]:
    """
    Returns a mapping j -> (target_id, source_id) for positions where the next token differs.
    j indexes the token position in the teacher-forced sequences (0-based).
    """
    Lmin = min(len(fact_ids), len(edit_ids))
    diffs = {}
    for j in range(Lmin):
        if fact_ids[j] != edit_ids[j]:
            diffs[j] = (edit_ids[j], fact_ids[j])
    return diffs



# ----------------------- NEW: Soft relaxation helpers and generators -----------------------

def _llama_step_inputs_embeds(model: AutoModelForCausalLM,
                              inputs_embeds: torch.Tensor,
                              past_key_values=None,
                              attention_mask: Optional[torch.Tensor] = None,
                              output_attentions: bool = True):
    """One cached forward step of the base decoder, fed embeddings instead of token ids.

    Args:
        inputs_embeds: embeddings for the new position(s); callers in this module pass
            [1, 1, H].
        past_key_values: cache from a previous step (legacy tuple or Cache object). It is
            passed through ``_ensure_cache_compatible`` before the call.
        attention_mask: mask covering past + new positions.
        output_attentions: whether to request per-layer attention maps. Defaults to True
            to preserve previous behaviour. Pass False when the maps are unused, to save
            memory.

    Returns:
        (logits [batch, V] at the last position, h_last [batch, H], past_key_values,
        attentions). ``attentions`` is whatever the model returns (None if not requested).
    """
    # Ensure the cache object matches what the installed Transformers version expects.
    pkv_compat = _ensure_cache_compatible(past_key_values)

    out = model.model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=pkv_compat,
        use_cache=True,
        return_dict=True,
        output_attentions=output_attentions,
    )
    h_last = out.last_hidden_state[:, -1, :]   # [batch, H]
    logits = model.lm_head(h_last)             # [batch, V]
    return logits, h_last, out.past_key_values, out.attentions



def _softmax_expected_embedding(
    logits_1xV: torch.Tensor,
    E_in: torch.Tensor,
    tau: float,
    add_gumbel: bool,
    topk_expected: Optional[int] = None,   # set e.g. 2048 to reduce compute/memory
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    From logits [1,V], return:
      - y: [V] soft probabilities softmax((logits + g) / tau) in float32
      - e: [H] expected embedding sum_v y[v] * E_in[v], computed without allocating a float32 copy of E_in.
    If topk_expected is set (< V), approximate using only the top-k probabilities (renormalized).
    """
    V = logits_1xV.shape[-1]
    z32 = logits_1xV[0].float()  # [V]
    y32 = torch.softmax(z32/float(tau), dim=-1)  # [V], float32

    # Exact expected embedding in model dtype (no E_in.float() to avoid huge copy)
    if topk_expected is not None and 0 < int(topk_expected) < V:
        k = int(topk_expected)
        idx = torch.topk(y32, k, dim=-1, sorted=False).indices  # [k]
        yk = y32.index_select(0, idx)
        yk = yk / (yk.sum() + 1e-12)  # renormalize mass within top-k
        E_sub = E_in.index_select(0, idx)                         # [k, H], model dtype
        e = torch.matmul(yk.to(dtype=E_in.dtype), E_sub)          # [H], model dtype
    else:
        # y32 @ E_in in mixed precision; kernel promotes without duplicating E_in to fp32
        e = torch.matmul(y32.to(dtype=E_in.dtype), E_in)          # [H], model dtype

    return y32, e

def _make_attn(length: int, device: torch.device) -> torch.Tensor:
    return torch.ones(1, length, dtype=torch.long, device=device)


def _detach_pkv(pkv):
    if pkv is None:
        return None
    return tuple(tuple(t.detach() for t in pair) for pair in pkv)


def _token_is_eos(tokenizer: AutoTokenizer, tid: int) -> bool:
    """Best-effort stop check: is ``tid`` any special token id, or the tokenizer's eos id?

    NOTE: every special token (not only EOS) counts as a stop token here.
    """
    if tid in tokenizer.all_special_ids:
        return True
    return hasattr(tokenizer, "eos_token_id") and tid == tokenizer.eos_token_id


@torch.no_grad()
def generate_factual_soft(model, tokenizer, prompt: str, max_new_tokens: int,
                          tau: float = 1.0, add_gumbel: bool = False,
                          choose: str = "greedy", master_seed: int = 2025):
    """Factual generation with a soft (expected-embedding) state update.

    At each step y = softmax((logits + g) / tau). Here g is replayable Gumbel noise from a
    fresh ``RNGReplayTorch`` seed when ``add_gumbel`` is True, and 0 otherwise. The model
    is advanced with the expected embedding E_in^T y. The emitted token is argmax(y)
    ("greedy") or a draw from y ("sample"). Generation stops after ``max_new_tokens`` or
    at an EOS/special token.

    Previously this function passed ``seed=`` to ``_softmax_expected_embedding``, which
    did not accept it, so every call raised TypeError. The noise is now added here.

    Returns:
        (token_ids, seeds): prompt ids + emitted ids, and the Gumbel seeds used (empty
        unless ``add_gumbel``).

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    device = model.device
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = enc.input_ids.to(device)
    if input_ids.shape[-1] == 0:
        raise ValueError("prompt must encode to at least one token")
    attn = enc.attention_mask.to(device)

    # Prime with the full prompt (hard tokens).
    logits, _, pkv = llama_step_llm(model, input_ids=input_ids, past_key_values=None, attention_mask=attn)

    E_in = model.model.embed_tokens.weight.detach()  # [V, H]
    vocab_size = int(E_in.shape[0])
    rng = RNGReplayTorch(vocab_size=vocab_size, master_seed=master_seed)

    out_ids = [int(t) for t in input_ids[0].tolist()]
    for _ in range(max_new_tokens):
        # Add Gumbel noise from a recorded seed (replayable), then take the plain
        # tempered softmax of the perturbed logits.
        if add_gumbel:
            seed = rng.next_seed()
            g = RNGReplayTorch.gumbel_from_seed(seed, logits.shape[-1], device=logits.device, dtype=torch.float32)
            step_logits = logits.float() + g
        else:
            step_logits = logits
        y, e = _softmax_expected_embedding(step_logits, E_in, tau=tau, add_gumbel=False)

        # Choose the token for display only; the state update below stays soft.
        if choose == "sample":
            gen = torch.Generator(device="cpu")
            gen.manual_seed(rng.seeds[-1] if add_gumbel else int(torch.randint(0, 2**31 - 1, (1,)).item()))
            next_id = int(torch.multinomial(y.cpu(), 1, generator=gen).item())
        else:
            next_id = int(torch.argmax(y).item())

        out_ids.append(next_id)
        if _token_is_eos(tokenizer, next_id):
            break

        # Advance the model with the expected embedding.
        attn = _make_attn(len(out_ids), device)
        logits, _, pkv, _ = _llama_step_inputs_embeds(
            model, inputs_embeds=e.view(1, 1, -1), past_key_values=pkv, attention_mask=attn
        )

    return out_ids, rng.seeds

def _clone_cache_for_rollout(pkv):
    if pkv is None:
        return None
    try:
        from transformers.cache_utils import DynamicCache
    except Exception:
        DynamicCache = None

    # New-style cache (>= 4.41)
    if DynamicCache is not None and hasattr(pkv, "get_seq_length"):
        if hasattr(pkv, "to_legacy_cache"):
            legacy = pkv.to_legacy_cache()  # list/tuple of (k, v) per layer
            legacy = tuple((k.detach().clone(), v.detach().clone()) for (k, v) in legacy)
            return DynamicCache.from_legacy_cache(legacy)  # fresh object
        else:
            # Fallback – if your version lacks to_legacy_cache, upgrade Transformers or implement a manual deep copy.
            raise RuntimeError("DynamicCache clone path unavailable (no to_legacy_cache). Upgrade transformers.")
    # Legacy tuple/list
    return tuple((k.detach().clone(), v.detach().clone()) for (k, v) in pkv)


def generate_counterfactual_hdmi(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    edited_text: str,          # teacher-forced edited text (to define targets)
    factual_text: str,         # teacher-forced factual text (to define sources)
    max_new_tokens: int,
    alpha: float = 1.0,        # default step size if inner_alpha is None
    tau: float = 1.0,          # temperature for softmax relaxation in the rollout
    use_margin: bool = True,   # optimize z[target] - z[source]; else z[target]
    normalize_grad: bool = True,
    grad_clip_norm: Optional[float] = None,
    discount: float = 1.0,     # weight a term at rollout distance s by discount**s
    max_lookahead_steps: Optional[int] = None,  # cap soft rollout horizon (in steps)
    add_gumbel: bool = True,   # forwarded as add_gumbel to _softmax_expected_embedding
    choose: str = "greedy",    # "greedy" | "sample" for displayed tokens
    topk_expected: Optional[int] = None,        # speed up expected-embedding computation
    inner_steps: int = 1,      # number of gradient steps per decode step
    inner_alpha: Optional[float] = None,        # per-step lr; if None, uses alpha
    f_reg: Optional[float] = 1,                 # weight of the factual-token term (None -> 0)
    seeds: Optional[List[int]] = None,          # per-step sampling seeds for choose="sample"
    advance_tau: float = 0.01,                  # temperature of the real-state advance
    ) -> List[int]:
    """Counterfactual decoding by gradient ascent on the current hidden state.

    Position bookkeeping: the hidden state after ``len(out_ids)`` tokens predicts the
    token at teacher-forced position ``j_start = len(out_ids)`` of ``factual_text`` /
    ``edited_text``. ``_build_diff_map`` gives the positions where the two texts differ.

    At each decode step:
      1. Repeat ``inner_steps`` times (at least once): build a differentiable soft rollout
         from the current hidden state h and compute J(h). Then update
         h <- h + lr * dJ/dh, with optional norm clipping / normalisation. J sums, each
         scaled by ``discount**s`` at rollout distance s:
           * at edit positions: z[target] - z[source] (or z[target] if not use_margin);
           * at j_start, if it is not an edit position: f_reg * z[factual token].
         Non-edit positions beyond j_start contribute nothing. The loop stops early when
         j_start is past the factual text.
      2. From the final h, compute y = softmax(lm_head(h) / advance_tau), emit a token
         (greedy or sampled), and advance the real model state with the expected
         embedding. A small ``advance_tau`` makes this nearly a hard token.

    Fixes relative to the previous version:
      * the loss term is no longer re-used (stale) at non-edit rollout positions;
      * ``seeds`` is now a parameter (previously an undefined name);
      * the real-state advance runs under ``torch.no_grad``, so autograd graphs no longer
        chain across decode steps through the KV cache;
      * the unused final rollout forward step is skipped;
      * ``f_reg=None`` is treated as 0.

    Returns:
        Prompt token ids followed by the emitted token ids.
    """
    device = model.device
    E_in = model.model.embed_tokens.weight.detach()  # [V, H]

    # Teacher-forced targets/sources (tokenization only).
    fact_ids = _token_ids_of_text(tokenizer, factual_text)
    edit_ids = _token_ids_of_text(tokenizer, edited_text)
    diff_map = _build_diff_map(fact_ids, edit_ids)  # j -> (target_id, source_id)
    f_weight = 0.0 if f_reg is None else float(f_reg)

    # Prime with prompt.
    enc_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = enc_prompt.input_ids.to(device)
    attn = enc_prompt.attention_mask.to(device)
    _, h_last, pkv = llama_step_llm(model, input_ids=input_ids, past_key_values=None, attention_mask=attn)

    out_ids = [int(t) for t in input_ids[0].tolist()]
    step_idx = 0
    W_dtype = model.lm_head.weight.dtype
    step_lr = alpha if (inner_alpha is None) else inner_alpha

    def _build_objective(h_start: torch.Tensor, j_start: int, pkv_base, attn_len0: int):
        """Differentiable soft-rollout objective J(h_start), or None if it has no terms."""
        future_positions = [j for j in diff_map if j >= j_start]
        S = (max(future_positions) - j_start) if future_positions else 0
        if max_lookahead_steps is not None:
            S = min(S, int(max_lookahead_steps))

        # Clone the cache so the rollout graph is isolated from the real state.
        pkv_s = _clone_cache_for_rollout(pkv_base) if S > 0 else None
        attn_len = attn_len0
        logits_s = model.lm_head(h_start[None, :])  # [1, V]
        J_local = None
        for s in range(S + 1):
            j_abs = j_start + s
            z32 = logits_s[0].float()

            L_s = None
            if j_abs in diff_map:
                tgt, src = diff_map[j_abs]
                L_s = (z32[int(tgt)] - z32[int(src)]) if use_margin else z32[int(tgt)]
            elif j_abs == j_start and j_abs < len(fact_ids):
                L_s = f_weight * z32[int(fact_ids[j_abs])]

            if L_s is not None:
                if discount != 1.0:
                    L_s = (discount ** s) * L_s
                J_local = L_s if J_local is None else J_local + L_s

            if s == S:
                break  # a further forward step would not contribute to J

            # Soft step forward with the expected embedding.
            _, e_s = _softmax_expected_embedding(
                logits_s, E_in, tau=tau, add_gumbel=add_gumbel, topk_expected=topk_expected
            )
            attn_len += 1
            logits_s, _, pkv_s, _ = _llama_step_inputs_embeds(
                model, inputs_embeds=e_s.view(1, 1, -1), past_key_values=pkv_s,
                attention_mask=_make_attn(attn_len, device),
            )
        return J_local

    for _ in range(max_new_tokens):
        j_start = len(out_ids)  # first teacher-forced position predicted by h_last
        h_work = h_last[0].detach().to(dtype=W_dtype)

        # Inner gradient-ascent loop (grad explicitly enabled in case the caller disabled it).
        for _inner in range(max(1, int(inner_steps))):
            h0 = h_work.detach().requires_grad_(True)
            with torch.enable_grad():
                J = None
                if j_start < len(fact_ids):
                    J = _build_objective(h0, j_start=j_start, pkv_base=pkv, attn_len0=len(out_ids))
                if J is None:
                    # Nothing to steer toward; keep the current hidden state.
                    h_work = h0.detach()
                    break
                grad_h = torch.autograd.grad(J, h0, retain_graph=False, create_graph=False)[0]  # [H]

            if grad_clip_norm is not None:
                n = float(grad_h.norm(p=2))
                if n > 0 and n > grad_clip_norm:
                    grad_h = grad_h * (grad_clip_norm / n)
            if normalize_grad:
                n = float(grad_h.norm(p=2))
                if n > 0:
                    grad_h = grad_h / n

            h_work = (h0 + step_lr * grad_h).detach()

        with torch.no_grad():
            # h_work is the final steered hidden state for this decode step.
            logits0p = model.lm_head(h_work[None, :])  # [1, V]
            y0p, e0p = _softmax_expected_embedding(
                logits0p, E_in, tau=advance_tau, add_gumbel=add_gumbel, topk_expected=topk_expected
            )

        # Choose a token for display (the state update uses the expected embedding).
        if choose == "sample":
            gen = torch.Generator(device="cpu")
            if add_gumbel and seeds is not None and step_idx < len(seeds):
                gen.manual_seed(int(seeds[step_idx]))
            else:
                gen.manual_seed(int(torch.randint(0, 2**31 - 1, (1,)).item()))
            next_id = int(torch.multinomial(y0p.cpu(), 1, generator=gen).item())
        else:
            next_id = int(torch.argmax(y0p).item())

        out_ids.append(next_id)
        if _token_is_eos(tokenizer, next_id):
            break

        # Advance the real model state; no graph is kept across decode steps.
        with torch.no_grad():
            _, h_last, pkv, _ = _llama_step_inputs_embeds(
                model, inputs_embeds=e0p.view(1, 1, -1), past_key_values=pkv,
                attention_mask=_make_attn(len(out_ids), device),
            )
        step_idx += 1

    return out_ids
# Cache compatibility helpers for newer Transformers (Cache/DynamicCache).
# DynamicCache is optional: if transformers.cache_utils cannot be imported (presumably an
# older release), it is set to None and callers keep using legacy tuple caches.
try:
    from transformers.cache_utils import DynamicCache  # >= 4.41+
except ImportError:
    DynamicCache = None

def _ensure_cache_compatible(pkv):
    """
    Ensure past_key_values is a new-style Cache object when required by the model.
    - If pkv already has get_seq_length(), return as-is.
    - If it's a legacy tuple/list and DynamicCache is available, convert.
    - Otherwise return pkv unchanged (older Transformers still accept tuples).
    """
    if pkv is None:
        return None
    if hasattr(pkv, "get_seq_length"):
        return pkv
    if DynamicCache is not None:
        try:
            return DynamicCache.from_legacy_cache(pkv)
        except Exception:
            # If conversion fails, fall back to legacy pkv (older versions)
            return pkv
    return pkv

def _detach_pkv(pkv):
    """
    Do NOT change the cache type. Just return it. For legacy tuples, we can
    detach tensors but keep the same structure; for new-style Cache, leave as-is.
    """
    if pkv is None:
        return None
    if hasattr(pkv, "get_seq_length"):
        # New-style Cache objects don't require detaching here; they won't get grads.
        return pkv
    # Legacy tuple/list of (k, v) per layer: detach the tensors
    try:
        return tuple(tuple(t.detach() for t in pair) for pair in pkv)
    except Exception:
        return pkv

# ----------------------- Main -----------------------

def main():
    """CLI entry point: load a model and print an hdmi counterfactual continuation.

    Changes from the previous version:
      * ``--alpha`` is a float (it is a gradient step size);
      * ``--seed`` now seeds torch's global RNG;
      * model weights are frozen, because only the hidden state is differentiated;
      * dead local copies of the factual/edited texts were removed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct",
                        help="HF repo id of the model (LLaMA-architecture recommended).")
    parser.add_argument("--prompt", type=str, default="Tell me a story",
                        help="Prompt to start decoding from.")
    parser.add_argument("--factual-text", type=str, default="Tell me a story about a girl who loves the sun.",
                        help="Primary text.")
    parser.add_argument("--alpha", type=float, default=50,
                        help="Gradient-ascent step size on the hidden state.")
    parser.add_argument("--f-reg", type=float, default=0.2)

    parser.add_argument("--edited-text", type=str, default="Tell me a story about a owl who loves the sun.")

    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--seed", type=int, default=2025)
    parser.add_argument("--device", type=str, default=None, help="cuda or cpu (auto if None)")
    parser.add_argument("--hf-token", type=str, default=os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN"))
    parser.add_argument("--fallback", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
    args = parser.parse_args()

    # Make any use of torch's global RNG reproducible.
    torch.manual_seed(args.seed)

    # Device
    if args.device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    # Load model (with fallback if the primary fails, e.g. gated)
    try:
        tokenizer, model = load_tokenizer_model(args.model_id, args.hf_token)
        print(f"[Load] Loaded primary: {args.model_id}")
    except Exception as e:
        print(f"[Load] Primary failed: {e}")
        print(f"[Load] Falling back to: {args.fallback}")
        tokenizer, model = load_tokenizer_model(args.fallback, args.hf_token)
        print(f"[Load] Loaded fallback: {args.fallback}")

    model.to(device)
    model.eval()
    # Only the hidden state needs gradients in the hdmi inner loop; freezing the weights
    # stops autograd from tracking every parameter.
    model.requires_grad_(False)

    # hdmi: soft relaxation + aggregated future gradients (no argmax in the state update)
    cf_hdmi_ids = generate_counterfactual_hdmi(
        model, tokenizer,
        edited_text=args.edited_text,
        factual_text=args.factual_text,
        max_new_tokens=args.max_new_tokens,
        prompt=args.prompt,
        alpha=args.alpha,
        tau=.9,
        f_reg=args.f_reg,
        use_margin=True,
        normalize_grad=False,
        grad_clip_norm=None,
        discount=1,                    # 1 = no discounting of far-future edits
        max_lookahead_steps=5,         # cap horizon
        add_gumbel=False,
        choose="greedy",               # displayed token choice; state update is always soft
        topk_expected=None,
        inner_steps=1
        )
    print("\n=== Counterfactual (hdmi, soft relaxation) ===")
    print(tokenizer.decode(cf_hdmi_ids, skip_special_tokens=True))


if __name__ == "__main__":
    main()

