# findprev_qkonly_prevattn64_seed.py
# QK-only EK-FAC + Previous-Token Attention (A[t, t-1]) Directional Influence
# Stage1: A/S (token expectation) + Λ fit (sequence-weighted full pass)
# Stage2: score = −<g_qk, H^{-1} v_qk>, where v_qk = ∂(sum_t A[t,t-1])/∂(W_QK)
# Probe: random synthetic sequences, fixed length=64, fixed random seed

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import contextlib
import heapq
import pickle
import random
import time
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast
from tqdm import tqdm

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.loading_from_pretrained import (
    OFFICIAL_MODEL_NAMES, MODEL_ALIASES, make_model_alias_map
)

# ================ Precision ================
# Dtype the model is loaded in and the autocast dtype used for every forward pass below.
DTYPE = torch.bfloat16

# ================ Distributed ================
def setup_distributed():
    """Initialise torch.distributed from the launcher's environment variables.

    Distributed mode is selected when both ``RANK`` and ``WORLD_SIZE`` are set.
    ``LOCAL_RANK`` is used as the GPU index when present; otherwise the index
    falls back to ``rank % visible_gpu_count`` so launchers that only export
    ``RANK``/``WORLD_SIZE`` do not crash with a ``KeyError``.

    Returns:
        Tuple ``(rank, world_size, gpu)``. In single-process mode this is
        ``(0, 1, 0)`` and no process group is created.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        print("Not running in distributed mode. Using single process.")
        return 0, 1, 0

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    if "LOCAL_RANK" in os.environ:
        gpu = int(os.environ["LOCAL_RANK"])
    else:
        gpu = rank % max(1, torch.cuda.device_count())
    print(f"[{os.getpid()}] Rank {rank}/{world_size} on GPU {gpu} init.")

    # Bind this process to its GPU *before* creating the NCCL group so that the
    # communicator is never created against the default device (cuda:0).
    torch.cuda.set_device(gpu)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank, world_size, gpu

# Process identity; in single-process mode setup_distributed() returns (0, 1, 0).
RANK, WORLD_SIZE, LOCAL_RANK = setup_distributed()
# Only rank 0 prints progress and writes checkpoints/results.
IS_MAIN_PROCESS = (RANK == 0)
# Device string used for every tensor created by this script.
DEVICE = f"cuda:{LOCAL_RANK}" if torch.cuda.is_available() else "cpu"

# ================ Config ================
# Local checkpoint directory and the alias under which it is registered with TransformerLens.
MODEL_LOCAL_PATH = "/root/trainbin1/nomask/nomask1200"
MODEL_ALIAS = "Pythia-14mlocal"
# Token-id matrix; build_dataloaders() rejects any array whose second dimension is not 2049.
DATA_NPY_PATH = "/root/indicies0-1999.npy"  # (N, 2049)

# Upper bound on the number of rows used (capped by the rows available in the file).
NUM_TRAIN_SAMPLES = 1228800
# Attention layer / head whose W_Q and W_K slices are analysed.
LAYER_IDX = 2
HEAD_IDX = 2
# Number of input tokens taken from each row; Stage 2 labels are the same row shifted by one.
SEQ_LENGTH = 2048
BATCH_SIZE_FOR_EKFAC = 10
BATCH_SIZE_FOR_INFLUENCE = 1
NUM_WORKERS = 2
# Stage 2 calls sample_indices.item() and scores one row per step, so the batch must hold one row.
if BATCH_SIZE_FOR_INFLUENCE != 1:
    raise ValueError("BATCH_SIZE_FOR_INFLUENCE must be 1.")

# EK-FAC damping
# DAMPING is the floor applied to the inverse-HVP denominator;
# DAMPING_ALPHA scales mean(Lambda), which is added to every eigenvalue.
DAMPING = 1e-5
DAMPING_ALPHA = 0.1

# empty_cache frequency
# torch.cuda.empty_cache() is called every N batches (Stage 1) / every N samples (Stage 2).
EMPTY_CACHE_EVERY_BATCH = 100
EMPTY_CACHE_EVERY_SAMPLE = 500

# Top-K
# Capacity of each per-rank heap and number of rows kept per sign in the final output.
TOP_K_RESULTS = 204800

# Probe: random synthetic, fixed length
PROBE_NUM_SEQS = 32
PROBE_LOCAL_SEQ_LEN = 64
PROBE_RANDOM_SEED = 12345  # fixed seed for reproducibility; per-rank offset applied

# Output paths (QK-only; avoid clobber)
OUTPUT_RESULTS_FILE = "/root/14m1200/influence_prevattn_qkonly.pkl"
STAGE1A_SAVE_PATH = "/root/14m1200/ekfac_stage1A_qkonly.pt"
STAGE1B_SAVE_PATH = "/root/14m1200/ekfac_stage1B_qkonly.pt"

# Echo the effective configuration once (rank 0 only).
if IS_MAIN_PROCESS:
    print("🎯 QK-only EK-FAC + Previous-Token Attention Directional Influence")
    print(f"L={LAYER_IDX}, H={HEAD_IDX}；数据: {DATA_NPY_PATH}")
    print(f"阶段1 batch={BATCH_SIZE_FOR_EKFAC}, 阶段2 batch=1；DTYPE={DTYPE}；Top-K={TOP_K_RESULTS}")
    print(f"Damping: {DAMPING}, alpha={DAMPING_ALPHA}")
    print(f"empty_cache：阶段1每 {EMPTY_CACHE_EVERY_BATCH} 个batch；阶段2每 {EMPTY_CACHE_EVERY_SAMPLE} 个样本")
    print(f"Probe: source=random-synthetic, num_seqs={PROBE_NUM_SEQS}, seq_len={PROBE_LOCAL_SEQ_LEN}, seed={PROBE_RANDOM_SEED}")
    print("Target: sum_t A[t, t-1] at the specified layer/head (post-softmax attention)")

# ================ Load Tokenizer & Model ================
# Register the local checkpoint path and its alias in TransformerLens' name tables,
# then rebuild the alias map so the alias can be passed to from_pretrained_no_processing.
OFFICIAL_MODEL_NAMES.append(MODEL_LOCAL_PATH)
MODEL_ALIASES[MODEL_LOCAL_PATH] = [MODEL_ALIAS]
make_model_alias_map()

# Tokenizer is read from the local checkpoint directory only (no hub access).
tokenizer = GPTNeoXTokenizerFast.from_pretrained(MODEL_LOCAL_PATH, local_files_only=True)
if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] ✅ Tokenizer loaded, vocab={tokenizer.vocab_size}")

# Load the model on this rank's device in DTYPE, via the "no processing" loader.
model_local = HookedTransformer.from_pretrained_no_processing(
    MODEL_ALIAS,
    n_devices=1,
    dtype=DTYPE,
    device=DEVICE
)
# We only need pattern & Q/K; keeping this on has no harm.
# NOTE(review): presumably this makes the attention layers materialise per-head results,
# which costs extra memory — confirm it is really harmless at SEQ_LENGTH tokens.
model_local.cfg.use_attn_result = True
model_local.to(DEVICE)
# Inference mode: the script never updates weights.
model_local.eval()

# Gradient bookkeeping: switch every parameter off, then turn tracking back on
# for the target layer's W_Q / W_K only. W_V / W_O are set to False explicitly
# to make it obvious that the value/output path is never differentiated.
for param in model_local.parameters():
    param.requires_grad_(False)

attn_layer = model_local.blocks[LAYER_IDX].attn
for weight, track_grad in (
    (attn_layer.W_Q, True),
    (attn_layer.W_K, True),
    (attn_layer.W_V, False),
    (attn_layer.W_O, False),
):
    weight.requires_grad = track_grad

# DDP
class _SingleProcessModel:
    """Stand-in for DDP when no process group exists.

    Exposes the three things the rest of this script uses on the wrapped
    model: ``.module``, calling it like a function, and ``no_sync()``.
    """

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def no_sync(self):
        # Nothing to synchronise in a single process.
        return contextlib.nullcontext()


if dist.is_initialized():
    model = DDP(
        model_local,
        device_ids=[LOCAL_RANK],
        output_device=LOCAL_RANK,
        find_unused_parameters=False,
        bucket_cap_mb=25
    )
else:
    # DistributedDataParallel cannot be constructed without a default process
    # group, so the single-process path announced by setup_distributed() uses a
    # thin wrapper with the same surface instead of crashing here.
    model = _SingleProcessModel(model_local)

# Model dimensions used to size every EK-FAC buffer below.
d_model = model.module.cfg.d_model
d_head = model.module.cfg.d_head
n_heads = model.module.cfg.n_heads
if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] 模型: d_model={d_model}, d_head={d_head}, n_heads={n_heads}, n_ctx={model.module.cfg.n_ctx}")

# ================ Data (NPY) ================
class NpyDataset(Dataset):
    """Map-style dataset over selected rows of a memory-mapped ``.npy`` matrix.

    Each item is ``(tokens, row_idx)`` where ``tokens`` is an int64 tensor copy
    of the row and ``row_idx`` is the row's position in the underlying array.
    """

    def __init__(self, npy_path: str, indices: List[int]):
        self.indices = indices
        # Memory-map instead of loading: rows are only read when requested.
        self.arr = np.load(npy_path, mmap_mode="r")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        row_idx = self.indices[idx]
        # np.array(...) always copies, so the tensor never aliases the read-only mmap.
        row = np.array(self.arr[row_idx], dtype=np.int64)
        return torch.from_numpy(row), row_idx

def collate_npy(batch):
    """Collate ``(tokens, row_idx)`` samples into ``(stacked_tokens, row_idx_tensor)``.

    Tokens are stacked along a new leading dimension; row indices become a
    1-D ``torch.long`` tensor in the same order.
    """
    token_rows, row_ids = [], []
    for sample in batch:
        token_rows.append(sample[0])
        row_ids.append(sample[1])
    stacked = torch.stack(token_rows, dim=0)
    return stacked, torch.tensor(row_ids, dtype=torch.long)

def build_dataloaders(
    npy_path: str,
    num_train_samples: int,
    batch_size_ekfac: int,
    batch_size_influence: int,
    rank: int,
    world_size: int
):
    """Create the two rank-sharded loaders over the same rows of the token matrix.

    Args:
        npy_path: ``.npy`` file holding an ``(N, 2049)`` matrix.
        num_train_samples: maximum number of leading rows to use.
        batch_size_ekfac: batch size of the Stage 1 (EK-FAC) loader.
        batch_size_influence: batch size of the Stage 2 (scoring) loader.
        rank: this process' rank, passed to ``DistributedSampler``.
        world_size: number of replicas, passed to ``DistributedSampler``.

    Returns:
        ``(dl_ekfac, dl_infl)``; both iterate unshuffled, without dropping the tail.

    Raises:
        ValueError: if the matrix does not have exactly 2049 columns.
    """
    if IS_MAIN_PROCESS:
        print(f"[{RANK}/{WORLD_SIZE}] 从 NPY 加载数据: {npy_path}")

    # Only the shape is needed here; each dataset opens its own memory map.
    arr = np.load(npy_path, mmap_mode="r")
    total_samples, total_len = arr.shape
    if total_len != 2049:
        raise ValueError(f"NPY 维度应为 (N, 2049)，实际: (N, {total_len})")
    all_indices = list(range(min(num_train_samples, total_samples)))

    def _sharded_loader(dataset, batch_size):
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False
        )
        return DataLoader(
            dataset, batch_size=batch_size, sampler=sampler,
            collate_fn=collate_npy, num_workers=NUM_WORKERS, pin_memory=True
        )

    dl_ekfac = _sharded_loader(NpyDataset(npy_path, all_indices), batch_size_ekfac)
    dl_infl = _sharded_loader(NpyDataset(npy_path, all_indices), batch_size_influence)
    return dl_ekfac, dl_infl

# ================ Forward Hook Caches ================
class QKActivationCache:
    """Holder for the activations captured by the Stage 1 forward hooks.

    Attributes (all ``None`` until a hook fires):
        X: output of ``ln1.hook_normalized``.
        Q: output of ``attn.hook_q``.
        K: output of ``attn.hook_k``.
    """

    def __init__(self):
        self.X = self.Q = self.K = None

    def clear(self):
        """Drop all cached activations."""
        self.X = self.Q = self.K = None

def setup_qk_forward_hooks(model_instance, layer_idx: int, cache: QKActivationCache):
    """Register forward hooks that copy X / Q / K of one block into ``cache``.

    Hooks are attached, in this order, to ``ln1.hook_normalized`` (-> ``cache.X``),
    ``attn.hook_q`` (-> ``cache.Q``) and ``attn.hook_k`` (-> ``cache.K``) of block
    ``layer_idx``. Each hook stores the activation and returns it unchanged.

    Returns:
        The non-``None`` values returned by ``add_hook`` (may be an empty list).

    NOTE(review): if ln1 applies a learned scale/bias after ``hook_normalized``,
    X is not the exact input of the Q/K projections — confirm against the
    LayerNorm implementation in use.
    """
    def _capture_into(field):
        def _hook_fn(act, hook):
            setattr(cache, field, act)
            return act
        return _hook_fn

    targets = (
        (f"blocks.{layer_idx}.ln1.hook_normalized", "X"),
        (f"blocks.{layer_idx}.attn.hook_q", "Q"),
        (f"blocks.{layer_idx}.attn.hook_k", "K"),
    )
    handles = [
        model_instance.add_hook(hook_name, _capture_into(field), dir='fwd')
        for hook_name, field in targets
    ]
    if IS_MAIN_PROCESS:
        print(f"[{RANK}/{WORLD_SIZE}] ✅ 前向钩子：X/Q/K")
    return [handle for handle in handles if handle is not None]

class PatternCache:
    """Holder for the attention pattern captured by the probe's forward hook.

    ``pattern`` is ``None`` until the hook fires; the probe indexes it as
    ``pattern[batch, head]`` to obtain a square per-head attention matrix.
    """

    def __init__(self):
        self.pattern = None

    def clear(self):
        """Drop the cached pattern."""
        self.pattern = None

def setup_pattern_hook(model_instance, layer_idx: int, cache: PatternCache):
    """Register a forward hook storing ``attn.hook_pattern`` of one block in ``cache``.

    The hook keeps a reference to the activation (graph attached) and returns it
    unchanged.

    Returns:
        A list with the value returned by ``add_hook``, or an empty list if that
        value is ``None``.
    """
    def _store_pattern(act, hook):
        cache.pattern = act
        return act

    handle = model_instance.add_hook(
        f"blocks.{layer_idx}.attn.hook_pattern", _store_pattern, dir='fwd'
    )
    if IS_MAIN_PROCESS:
        print(f"[{RANK}/{WORLD_SIZE}] ✅ 前向钩子：attention pattern (post-softmax)")
    return [] if handle is None else [handle]

# ================ EK-FAC (QK only) ================
class EKFAC_QK_Head:
    """EK-FAC state for one attention head's concatenated ``[W_Q | W_K]`` block.

    The block is treated as a single matrix with ``d_in = d_model`` rows and
    ``d_out = 2 * d_head`` columns.

    * ``A`` is the token-averaged ``X^T X`` and ``S`` the token-averaged
      ``G^T G`` with ``G = [dQ, dK]`` concatenated along the last dimension.
    * ``Q_A`` / ``Q_S`` are their eigenvectors (set by
      ``finalize_eigendecomposition``).
    * ``Lambda`` is a flattened ``d_in * d_out`` vector of eigen-basis
      curvature values; it is assigned by the caller, not by this class.
    """

    def __init__(self, d_model: int, d_head: int, damping: float = 1e-5, damping_alpha: float = 0.1):
        self.d_model = d_model
        self.d_head = d_head
        self.damping = damping
        self.damping_alpha = damping_alpha
        self.block = {'name': 'W_QK', 'd_in': d_model, 'd_out': 2 * d_head}

        # Running (un-normalised) sums and the number of tokens they cover.
        self.A_accum: Optional[torch.Tensor] = None  # [d_model, d_model]
        self.S_accum: Optional[torch.Tensor] = None  # [2*d_head, 2*d_head]
        self.token_count: int = 0

        # Filled in later: eigenvectors of A / S and the flattened eigen-basis curvature.
        self.Q_A: Optional[torch.Tensor] = None
        self.Q_S: Optional[torch.Tensor] = None
        self.Lambda: Optional[torch.Tensor] = None

    def accumulate_A_S(self, X_flat_f32, dQ_f32, dK_f32):
        """Add one batch's ``X^T X`` and ``G^T G`` to the running sums.

        Args:
            X_flat_f32: ``[tokens, d_model]`` activations.
            dQ_f32: ``[tokens, d_head]`` gradients w.r.t. Q.
            dK_f32: ``[tokens, d_head]`` gradients w.r.t. K.

        A batch with zero tokens is ignored.
        """
        n_tokens = X_flat_f32.shape[0]
        if n_tokens == 0:
            return

        acts = X_flat_f32.detach()
        grad_cat = torch.cat([dQ_f32.detach(), dK_f32.detach()], dim=-1)  # [tokens, 2*d_head]
        cov_in = acts.t() @ acts
        cov_out = grad_cat.t() @ grad_cat

        if self.A_accum is not None:
            self.A_accum.add_(cov_in)
            self.S_accum.add_(cov_out)
        else:
            # First batch: adopt the freshly computed matrices as accumulators.
            self.A_accum = cov_in
            self.S_accum = cov_out

        self.token_count += int(n_tokens)

    def finalize_eigendecomposition(self):
        """Average the sums (across ranks if distributed) and eigendecompose A and S.

        When a process group is initialised, ``A_accum``, ``S_accum`` and the
        token count are sum-reduced in place across ranks first. Each averaged
        matrix is symmetrised and given a ridge of ``1e-6 * |trace| / dim``
        before ``torch.linalg.eigh``.

        Raises:
            RuntimeError: if no tokens were accumulated on this rank.
        """
        if self.token_count == 0:
            raise RuntimeError("No tokens accumulated for EK-FAC A/S.")

        global_tokens = self.token_count
        if dist.is_initialized():
            dist.all_reduce(self.A_accum, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.S_accum, op=dist.ReduceOp.SUM)
            tok_t = torch.tensor([self.token_count], device=DEVICE, dtype=torch.long)
            dist.all_reduce(tok_t, op=dist.ReduceOp.SUM)
            global_tokens = tok_t.item()

        def _regularised_mean(accum):
            mean = accum / float(global_tokens)
            mean = 0.5 * (mean + mean.t())
            ridge = (1e-6 * torch.trace(mean).abs() / mean.shape[0]).to(mean.dtype)
            return mean + ridge * torch.eye(mean.shape[0], device=mean.device, dtype=mean.dtype)

        A = _regularised_mean(self.A_accum)
        S = _regularised_mean(self.S_accum)

        eigvals_A, self.Q_A = torch.linalg.eigh(A.float())
        eigvals_S, self.Q_S = torch.linalg.eigh(S.float())
        if IS_MAIN_PROCESS:
            print(f"[W_QK] A eig [{eigvals_A.min():.2e}, {eigvals_A.max():.2e}] "
                  f"S eig [{eigvals_S.min():.2e}, {eigvals_S.max():.2e}]")

    def inverse_hvp(self, grad_matrix: torch.Tensor) -> torch.Tensor:
        """Apply the damped EK-FAC inverse to a ``[d_in, d_out]`` matrix.

        The input is rotated into the ``Q_A`` / ``Q_S`` eigen-basis, divided
        element-wise by ``Lambda + damping_alpha * mean(Lambda)`` (floored at
        ``damping``), and rotated back.

        Raises:
            RuntimeError: if ``Q_A``, ``Q_S`` or ``Lambda`` has not been set.
        """
        if any(part is None for part in (self.Q_A, self.Q_S, self.Lambda)):
            raise RuntimeError("EK-FAC not ready for inverse_hvp")

        rotated = self.Q_A.t() @ grad_matrix.float() @ self.Q_S
        denom = torch.clamp(
            self.Lambda + self.damping_alpha * self.Lambda.mean(),
            min=self.damping,
        )
        scaled = (rotated.flatten() / denom).reshape(self.block['d_in'], self.block['d_out'])
        return self.Q_A @ scaled @ self.Q_S.t()

    @property
    def block_dims(self) -> List[Dict[str, int]]:
        """Single-element list describing the block's ``d_in`` / ``d_out``."""
        return [{'d_in': self.block['d_in'], 'd_out': self.block['d_out']}]

# ================ Utils ================
def compute_pseudo_labels(logits: torch.Tensor) -> torch.Tensor:
    """Sample one label per position from the model's own predictive distribution.

    Logits are flattened to ``[-1, vocab]``, converted to float32 probabilities
    with softmax, and one index per row is drawn with ``torch.multinomial``.

    Returns:
        1-D int64 tensor with one sampled vocabulary index per flattened position.
    """
    vocab_size = logits.shape[-1]
    probs = logits.reshape(-1, vocab_size).float().softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

def gather_heap_as_tensors(heap_data, world_size, device):
    """Flatten a local top-K heap to ``[score, index, loss]`` rows and gather them.

    Args:
        heap_data: iterable of ``(score, uid, (idx, loss))`` entries.
        world_size: number of ranks taking part in the gather.
        device: device on which the exchange tensors are created.

    Returns:
        * Not distributed: a float64 NumPy array of shape ``[n, 3]`` with the local rows.
        * Distributed, rank 0: the concatenated valid rows of all ranks.
        * Distributed, other ranks: ``None``.
    """
    rows = [
        [float(score), float(int(idx)), float(loss)]
        for score, _, (idx, loss) in heap_data
    ]
    if rows:
        tensor_local = torch.tensor(rows, device=device, dtype=torch.float64)
    else:
        # Keep the column count even when there is nothing to send.
        tensor_local = torch.zeros(0, 3, device=device, dtype=torch.float64)

    if not dist.is_initialized():
        return tensor_local.cpu().numpy()

    # Exchange row counts first: all_gather needs equally shaped tensors.
    local_rows = tensor_local.shape[0]
    size_local = torch.tensor([local_rows], device=device, dtype=torch.long)
    sizes_gather = [torch.zeros_like(size_local) for _ in range(world_size)]
    dist.all_gather(sizes_gather, size_local)
    counts = [s.item() for s in sizes_gather]

    deficit = max(counts) - local_rows
    if deficit > 0:
        filler = torch.full((deficit, 3), -1e18, device=device, dtype=torch.float64)
        tensor_padded = torch.cat([tensor_local, filler], dim=0)
    else:
        tensor_padded = tensor_local

    gathered_tensors = [torch.zeros_like(tensor_padded) for _ in range(world_size)]
    dist.all_gather(gathered_tensors, tensor_padded)

    if RANK != 0:
        return None

    # Strip each rank's padding before concatenating.
    valid_chunks = [chunk[:n] for chunk, n in zip(gathered_tensors, counts) if n > 0]
    if not valid_chunks:
        return np.zeros((0, 3))
    return torch.cat(valid_chunks, dim=0).cpu().numpy()

# =============== DataLoaders ===============
# One loader for the two Stage 1 passes (EK-FAC statistics) and one, with
# batch size 1, for per-sample scoring in Stage 2.
train_as_dataloader, train_influence_dataloader = build_dataloaders(
    npy_path=DATA_NPY_PATH,
    num_train_samples=NUM_TRAIN_SAMPLES,
    batch_size_ekfac=BATCH_SIZE_FOR_EKFAC,
    batch_size_influence=BATCH_SIZE_FOR_INFLUENCE,
    rank=RANK,
    world_size=WORLD_SIZE,
)

if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] 数据加载完成：EK-FAC批数={len(train_as_dataloader)}, 影响评估批数={len(train_influence_dataloader)}")

# ================ Stage 1A: accumulate A/S (QK only; per-token expectation) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段1A：累积 A/S（autograd.grad 到 Q/K；token 期望）")
    print("="*70)

ekfac = EKFAC_QK_Head(d_model, d_head, damping=DAMPING, damping_alpha=DAMPING_ALPHA)
qk_cache = QKActivationCache()
hooks_qk = setup_qk_forward_hooks(model.module, LAYER_IDX, qk_cache)

start_time_as = time.time()
_as_progress = tqdm(train_as_dataloader, desc=f"[{RANK}] Accum A/S", disable=not IS_MAIN_PROCESS)
for batch_idx, (batch_tokens, _) in enumerate(_as_progress):
    qk_cache.clear()
    input_ids = batch_tokens[:, :SEQ_LENGTH].to(DEVICE, non_blocking=True)

    with model.no_sync(), torch.enable_grad():
        # Only the forward pass runs under autocast; sampling and the loss use float32.
        with autocast(dtype=DTYPE):
            logits = model(input_ids)
        Vocab = logits.shape[-1]
        # Labels are sampled from the model's own predictions.
        sampled_labels = compute_pseudo_labels(logits)
        loss = F.cross_entropy(logits.reshape(-1, Vocab).float(), sampled_labels, reduction='sum')

        if any(act is None for act in (qk_cache.Q, qk_cache.K, qk_cache.X)):
            raise RuntimeError("X/Q/K activations not captured.")

        # Differentiate w.r.t. the hooked Q/K activations rather than the weights.
        grads_Q, grads_K = torch.autograd.grad(
            outputs=loss,
            inputs=[qk_cache.Q, qk_cache.K],
            retain_graph=False, create_graph=False, allow_unused=False
        )

    # Keep only the target head and flatten batch x position into a token axis.
    ekfac.accumulate_A_S(
        qk_cache.X.reshape(-1, d_model).float(),
        grads_Q[:, :, HEAD_IDX, :].reshape(-1, d_head).float(),
        grads_K[:, :, HEAD_IDX, :].reshape(-1, d_head).float(),
    )
    qk_cache.clear()

    if (batch_idx + 1) % EMPTY_CACHE_EVERY_BATCH == 0 and DEVICE.startswith("cuda"):
        torch.cuda.empty_cache()

# Stage 1A wall time covers the accumulation loop only (measured before the eigendecomposition).
ekfac_time_as = time.time() - start_time_as
# Reduces A/S across ranks (if distributed) and computes Q_A / Q_S on every rank.
ekfac.finalize_eigendecomposition()
if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] ✓ 阶段1A完成，耗时: {ekfac_time_as:.2f}s")
    # Checkpoint: eigenvectors, the raw accumulated sums, and the settings that produced them.
    # Note: token_count is this rank's local count.
    stage1A_payload = {
        'Q_A': [ekfac.Q_A.detach().cpu() if ekfac.Q_A is not None else None],
        'Q_S': [ekfac.Q_S.detach().cpu() if ekfac.Q_S is not None else None],
        'A_accum': [ekfac.A_accum.detach().cpu() if ekfac.A_accum is not None else None],
        'S_accum': [ekfac.S_accum.detach().cpu() if ekfac.S_accum is not None else None],
        'token_count': int(ekfac.token_count),
        'block_dims': ekfac.block_dims,
        'd_model': ekfac.d_model,
        'd_head': ekfac.d_head,
        'damping': ekfac.damping,
        'damping_alpha': ekfac.damping_alpha,
        'layer_idx': LAYER_IDX,
        'head_idx': HEAD_IDX,
        'seq_length': SEQ_LENGTH,
    }
    # Create the output directory if needed so a missing folder cannot throw
    # away the full-dataset pass that just finished.
    os.makedirs(os.path.dirname(STAGE1A_SAVE_PATH) or ".", exist_ok=True)
    torch.save(stage1A_payload, STAGE1A_SAVE_PATH)
    print(f"[{RANK}/{WORLD_SIZE}] 已保存 Stage1A 结果到: {STAGE1A_SAVE_PATH}")

# ================ Stage 1B: second full pass fitting Λ (averaged over sequences; QK only) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段1B：全量二次遍历拟合 Λ（按序列数；QK-only）")
    print("="*70)

# Running sum of squared eigen-basis gradients and the number of sequences seen.
lambda_sum = torch.zeros(ekfac.block['d_in'], ekfac.block['d_out'], device=DEVICE, dtype=torch.float32)
weight_sum = torch.tensor(0.0, device=DEVICE, dtype=torch.float64)

start_time_lam = time.time()
_lam_progress = tqdm(train_as_dataloader, desc=f"[{RANK}] Fit Λ (full pass)", disable=not IS_MAIN_PROCESS)
for batch_idx, (batch_tokens, _) in enumerate(_lam_progress):
    qk_cache.clear()
    input_ids = batch_tokens[:, :SEQ_LENGTH].to(DEVICE, non_blocking=True)
    B = int(input_ids.shape[0])

    with model.no_sync(), torch.enable_grad():
        with autocast(dtype=DTYPE):
            logits = model(input_ids)
        Vocab = logits.shape[-1]
        sampled_labels = compute_pseudo_labels(logits)
        loss = F.cross_entropy(logits.reshape(-1, Vocab).float(), sampled_labels, reduction='sum')

        if any(act is None for act in (qk_cache.Q, qk_cache.K, qk_cache.X)):
            raise RuntimeError("X/Q/K activations not captured in Λ pass.")

        grads_Q, grads_K = torch.autograd.grad(
            outputs=loss,
            inputs=[qk_cache.Q, qk_cache.K],
            retain_graph=False, create_graph=False, allow_unused=False
        )

    X_flat = qk_cache.X.detach().reshape(-1, d_model).float()
    dQ = grads_Q.detach()[:, :, HEAD_IDX, :].reshape(-1, d_head).float()
    dK = grads_K.detach()[:, :, HEAD_IDX, :].reshape(-1, d_head).float()

    # Batch-summed weight gradient of [W_Q | W_K], rotated into the Q_A / Q_S basis.
    grad_cat = torch.cat([dQ, dK], dim=-1)
    dW0 = X_flat.t() @ grad_cat  # [d_model, 2*d_head]
    ge0 = ekfac.Q_A.t() @ dW0 @ ekfac.Q_S

    lambda_sum.add_(ge0.pow(2))
    weight_sum += float(B)

    qk_cache.clear()
    if (batch_idx + 1) % EMPTY_CACHE_EVERY_BATCH == 0 and DEVICE.startswith("cuda"):
        torch.cuda.empty_cache()

# Combine the per-rank sums so every rank ends up with the same Λ.
if dist.is_initialized():
    dist.all_reduce(lambda_sum, op=dist.ReduceOp.SUM)
    dist.all_reduce(weight_sum, op=dist.ReduceOp.SUM)

# Guard against division by zero when no sequence was processed.
w_total = max(1.0, weight_sum.item())
ekfac.Lambda = (lambda_sum / w_total).flatten()

ekfac_time_lam = time.time() - start_time_lam
ekfac_time_total = ekfac_time_as + ekfac_time_lam

if IS_MAIN_PROCESS:
    lam = ekfac.Lambda
    print(f"[W_QK] Λ range: [{lam.min().item():.2e}, {lam.max().item():.2e}], mean={lam.mean().item():.2e}")
    print(f"[{RANK}/{WORLD_SIZE}] ✓ 阶段1B完成，耗时: {ekfac_time_lam:.2f}s")
    print(f"[{RANK}/{WORLD_SIZE}] ⏱️ 阶段1总耗时: {ekfac_time_total:.2f}s")
    stage1B_payload = {
        'Lambda': [ekfac.Lambda.detach().cpu()],
        'weight_sum': float(w_total),
        'block_dims': ekfac.block_dims,
        'd_model': ekfac.d_model,
        'd_head': ekfac.d_head,
        'damping': ekfac.damping,
        'damping_alpha': ekfac.damping_alpha,
        'layer_idx': LAYER_IDX,
        'head_idx': HEAD_IDX,
        'seq_length': SEQ_LENGTH,
    }
    # Create the output directory if needed so a missing folder cannot throw
    # away the second full-dataset pass.
    os.makedirs(os.path.dirname(STAGE1B_SAVE_PATH) or ".", exist_ok=True)
    torch.save(stage1B_payload, STAGE1B_SAVE_PATH)
    print(f"[{RANK}/{WORLD_SIZE}] 已保存 Stage1B 结果到: {STAGE1B_SAVE_PATH}")

# —— Remove Stage1 hooks to free memory ——
# Best-effort cleanup: try reset_hooks first and fall back to remove_all_hook_fns;
# failures of either call are deliberately ignored.
# NOTE(review): hooks_qk is passed as the first positional argument of reset_hooks —
# confirm that this parameter is meant to receive a list of hook handles.
try:
    model.module.reset_hooks(hooks_qk)
except Exception:
    try:
        model.module.remove_all_hook_fns()
    except Exception:
        pass
# Drop references to the Stage 1 hooks and cached activations before the probe stage.
hooks_qk = None
qk_cache.clear()
if DEVICE.startswith("cuda"):
    torch.cuda.empty_cache()

# ================ Probe v (previous-token attention) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 预计算探针 v（sum_t A[t,t-1] at layer {LAYER_IDX}, head {HEAD_IDX}）与 p=H^-1 v")
    print("="*70)

def generate_random_sequences(num_seqs: int, seq_len: int, vocab_size: int, avoid_ids: Optional[set], seed: int) -> torch.Tensor:
    """Draw reproducible random token sequences, uniform over the permitted vocabulary.

    Args:
        num_seqs: number of sequences (rows).
        seq_len: tokens per sequence (columns).
        vocab_size: token ids are drawn from ``range(vocab_size)``.
        avoid_ids: ids to exclude; ``None`` or empty excludes nothing.
        seed: seed for a dedicated NumPy generator (CPU-side).

    Returns:
        int64 tensor of shape ``[num_seqs, seq_len]``.

    Raises:
        ValueError: if no token id remains after the exclusions.
    """
    excluded = np.fromiter(avoid_ids or (), dtype=np.int64)
    # Sorted ids in range(vocab_size) that are not excluded.
    allowed = np.setdiff1d(np.arange(vocab_size, dtype=np.int64), excluded)
    if allowed.size == 0:
        raise ValueError("Allowed vocab after exclusions is empty.")

    generator = np.random.default_rng(seed)
    drawn = generator.choice(allowed, size=(num_seqs, seq_len), replace=True)
    return torch.from_numpy(drawn.astype(np.int64))

def compute_probe_prevattn_random(
    model_ddp: DDP,
    layer_idx: int,
    head_idx: int,
    probe_num_seqs: int,
    probe_seq_len: int,
    seed_base: int
) -> torch.Tensor:
    """Build the probe direction for previous-token attention.

    Computes ``v_qk = d/d[W_Q | W_K] sum_{seqs} sum_{t>=1} A[t, t-1]`` for the
    head ``(layer_idx, head_idx)`` on random synthetic sequences, where ``A``
    is the post-softmax attention pattern.

    Exactly ``probe_num_seqs`` sequences are used in total: they are split
    across ranks as evenly as possible (the first ``probe_num_seqs % WORLD_SIZE``
    ranks take one extra). Each rank seeds its generator with
    ``seed_base + RANK``; pad / eos token ids (when the tokenizer defines them)
    are never sampled.

    Args:
        model_ddp: wrapped model exposing ``.module``.
        layer_idx: block whose attention pattern is probed.
        head_idx: head within that block.
        probe_num_seqs: total number of probe sequences over all ranks.
        probe_seq_len: length of every probe sequence.
        seed_base: base random seed.

    Returns:
        float32 tensor ``[d_model, 2*d_head]``: gradients w.r.t. the head's
        W_Q and W_K slices concatenated along the last dimension, summed over
        all probe sequences (and over ranks when distributed).

    Raises:
        RuntimeError: if the pattern hook did not fire during a forward pass.
    """
    # Accumulator (FP32)
    v_acc = torch.zeros(d_model, 2 * d_head, device=DEVICE, dtype=torch.float32)

    # Exact per-rank share. Rounding every rank up would run more than
    # probe_num_seqs probes whenever WORLD_SIZE does not divide it.
    base_share, remainder = divmod(int(probe_num_seqs), WORLD_SIZE)
    local_target = base_share + (1 if RANK < remainder else 0)

    # Ranks may now run different numbers of forward passes, so call the
    # wrapped module directly: no DDP-side collective can be left unmatched.
    net = model_ddp.module
    vocab = net.W_U.shape[1]

    avoid = set()
    for attr in ("pad_token_id", "eos_token_id"):
        token_id = getattr(tokenizer, attr, None)
        if token_id is not None:
            avoid.add(int(token_id))
    seed_rank = seed_base + RANK  # distinct per rank, still reproducible

    # Generate random tokens on CPU for determinism; move one sequence at a time.
    synth = generate_random_sequences(local_target, probe_seq_len, int(vocab), avoid, seed_rank)

    attn = net.blocks[layer_idx].attn
    pat_cache = PatternCache()
    hooks_pat = setup_pattern_hook(net, layer_idx, pat_cache)

    try:
        for b in range(synth.shape[0]):
            pat_cache.clear()
            tokens = synth[b:b + 1].to(DEVICE, non_blocking=True)  # [1, S]

            with torch.enable_grad():
                with autocast(dtype=DTYPE):
                    _ = net(tokens)

            if pat_cache.pattern is None:
                raise RuntimeError("hook_pattern missing.")

            # Sum of the offset -1 diagonal of the chosen head's pattern,
            # i.e. the attention each position pays to its predecessor.
            head_pat = pat_cache.pattern[0, head_idx]  # [S, S]
            f = torch.diagonal(head_pat, offset=-1).sum()

            grads = torch.autograd.grad(
                outputs=f,
                inputs=[attn.W_Q, attn.W_K],
                retain_graph=False, create_graph=False, allow_unused=False
            )
            gQ = grads[0][head_idx].float()  # [d_model, d_head]
            gK = grads[1][head_idx].float()  # [d_model, d_head]
            v_acc.add_(torch.cat([gQ, gK], dim=-1))

            pat_cache.clear()
    finally:
        # Always detach the pattern hook, including when an iteration raised.
        # Cleanup stays best-effort so it can never mask the original error.
        try:
            net.reset_hooks(hooks_pat)
        except Exception:
            try:
                net.remove_all_hook_fns()
            except Exception:
                pass

    # All-reduce sum across ranks
    if dist.is_initialized():
        dist.all_reduce(v_acc, op=dist.ReduceOp.SUM)

    return v_acc

# Compute probe v (QK only, attention-based)
# v_qk is the gradient of the previous-token attention sum w.r.t. [W_Q | W_K] of the target head.
v_qk = compute_probe_prevattn_random(
    model, LAYER_IDX, HEAD_IDX, PROBE_NUM_SEQS, PROBE_LOCAL_SEQ_LEN, PROBE_RANDOM_SEED
)

# p = H^{-1} v (QK only)
# Computed once; Stage 2 then needs only an inner product per training sample.
p_qk = ekfac.inverse_hvp(v_qk)

if DEVICE.startswith("cuda"):
    torch.cuda.empty_cache()

if IS_MAIN_PROCESS:
    print(f"[{RANK}] ✓ p 就绪：||p_qk||={torch.norm(p_qk, p='fro').item():.6f}；进入阶段2。")

# ================ Stage 2: directional influence score (score = −<g_qk, p_qk>) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段2：方向影响分数（每样本 ∇_θℓ 到 W_Q/W_K）")
    print("="*70)

def compute_sample_grads_qkonly(npy_batch: torch.Tensor, model_ddp: DDP) -> Tuple[torch.Tensor, float]:
    """Gradient of the summed next-token loss w.r.t. the target head's W_Q / W_K.

    Inputs are the first ``SEQ_LENGTH`` tokens of each row and labels are the
    same rows shifted by one position. The forward pass runs under autocast;
    the cross-entropy (``reduction='sum'``) is computed in float32.

    Args:
        npy_batch: token-id matrix with at least ``SEQ_LENGTH + 1`` columns.
        model_ddp: wrapped model exposing ``.module`` and ``no_sync()``.

    Returns:
        ``(g0, loss)`` where ``g0`` is the float32 ``[d_model, 2*d_head]``
        concatenation of the head's W_Q and W_K gradients and ``loss`` is the
        summed loss as a Python float.
    """
    input_ids = npy_batch[:, :SEQ_LENGTH].to(DEVICE, non_blocking=True)
    labels = npy_batch[:, 1:SEQ_LENGTH + 1].to(DEVICE, non_blocking=True)
    attn = model_ddp.module.blocks[LAYER_IDX].attn

    with torch.enable_grad():
        with model_ddp.no_sync(), autocast(dtype=DTYPE):
            logits = model_ddp(input_ids)
        flat_logits = logits.reshape(-1, logits.shape[-1]).float()
        loss = F.cross_entropy(flat_logits, labels.reshape(-1), reduction='sum')

        grad_WQ, grad_WK = torch.autograd.grad(
            outputs=loss,
            inputs=[attn.W_Q, attn.W_K],
            retain_graph=False, create_graph=False, allow_unused=False
        )
        # Slice out the target head and join Q | K along the output dimension.
        g0 = torch.cat(
            [grad_WQ[HEAD_IDX].float(), grad_WK[HEAD_IDX].float()], dim=-1
        )  # [d_model, 2*d_head]

    return g0, float(loss.item())

# Two bounded min-heaps per rank: one keeps the largest scores, the other the
# largest negated scores (i.e. the most negative scores).
local_pos_heap: List[Tuple[float, int, Tuple[int, float]]] = []
local_neg_heap: List[Tuple[float, int, Tuple[int, float]]] = []
heap_cap = TOP_K_RESULTS
uid_counter = 0  # unique tie-breaker so heap comparisons never reach the payload
processed_samples = 0


def _offer_to_heap(heap, entry, capacity):
    """Insert ``entry`` while the heap has room; afterwards only replace the
    current minimum when ``entry`` has a strictly larger key."""
    if len(heap) < capacity:
        heapq.heappush(heap, entry)
    elif entry[0] > heap[0][0]:
        heapq.heapreplace(heap, entry)


start_time_phase2 = time.time()
_score_progress = tqdm(train_influence_dataloader, desc=f"[{RANK}] Scoring", disable=not IS_MAIN_PROCESS)
for sample_idx, (sample_batch_tokens, sample_indices) in enumerate(_score_progress):
    processed_samples += 1

    g0, loss_val = compute_sample_grads_qkonly(sample_batch_tokens, model)
    # score = -<g, H^{-1} v>
    projection_score = -torch.sum(g0 * p_qk).item()
    neg_score = -projection_score

    payload_compact = (int(sample_indices.item()), float(loss_val))
    _offer_to_heap(local_pos_heap, (projection_score, uid_counter, payload_compact), heap_cap)
    _offer_to_heap(local_neg_heap, (neg_score, uid_counter, payload_compact), heap_cap)
    uid_counter += 1

    if processed_samples % EMPTY_CACHE_EVERY_SAMPLE == 0 and DEVICE.startswith("cuda"):
        torch.cuda.empty_cache()

phase2_time = time.time() - start_time_phase2

# ================ Aggregate and save ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段3：Top-K 聚合与保存")
    print("="*70)

# Collect every rank's local heaps (only rank 0 receives the merged rows when distributed).
pos_matrix = gather_heap_as_tensors(local_pos_heap, WORLD_SIZE, DEVICE)
neg_matrix = gather_heap_as_tensors(local_neg_heap, WORLD_SIZE, DEVICE)

# Number of samples scored; summed onto rank 0 when running distributed.
total_samples_tensor = torch.tensor([processed_samples], device=DEVICE, dtype=torch.long)
if dist.is_initialized():
    dist.reduce(total_samples_tensor, dst=0, op=dist.ReduceOp.SUM)

# Rank 0 merges the gathered rows, keeps the global top-K per sign and writes the result file.
if RANK == 0:
    def _top_rows(matrix):
        """Return up to TOP_K_RESULTS rows of ``matrix``, sorted by column 0 descending."""
        if matrix is None or matrix.shape[0] == 0:
            return np.zeros((0, 3))
        order = np.argsort(matrix[:, 0])[::-1]
        return matrix[order[:TOP_K_RESULTS]]

    top_pos_data = _top_rows(pos_matrix)
    top_neg_data = _top_rows(neg_matrix)

    final_pos_topk = [
        {
            'projection_score': float(row[0]),
            'sample_index': int(row[1]),
            'sum_loss': float(row[2])
        }
        for row in top_pos_data
    ]

    # The negative heap stores negated scores; flip the sign back for reporting.
    final_neg_topk = [
        {
            'projection_score': -float(row[0]),
            'sample_index': int(row[1]),
            'sum_loss': float(row[2])
        }
        for row in top_neg_data
    ]
    final_neg_topk.sort(key=lambda x: x['projection_score'])

    total_samples = int(total_samples_tensor.item())
    throughput = total_samples / phase2_time if phase2_time > 0 else 0.0

    print(f"✅ 阶段2完成。Time: {phase2_time:.2f}s")
    print(f"Global Throughput: {throughput:.2f} samples/s (Total {total_samples})")

    ekfac_params = {
        'Q_A': [ekfac.Q_A.detach().cpu()],
        'Q_S': [ekfac.Q_S.detach().cpu()],
        'Lambda': [ekfac.Lambda.detach().cpu()],
        'damping': ekfac.damping,
        'damping_alpha': ekfac.damping_alpha,
        'd_model': ekfac.d_model,
        'd_head': ekfac.d_head,
        'param_names': ['W_QK'],
        'block_dims': ekfac.block_dims,
    }

    analysis_results = {
        'config': {
            'MODEL_ALIAS': MODEL_ALIAS,
            'DATA_NPY_PATH': DATA_NPY_PATH,
            'TARGET': 'Previous-Token Attention Directional Influence (QK-only EK-FAC)',
            'TOP_K': TOP_K_RESULTS,
            'NUM_TRAIN_SAMPLES': NUM_TRAIN_SAMPLES,
            'LAYER_IDX': LAYER_IDX,
            'HEAD_IDX': HEAD_IDX,
            'SEQ_LENGTH': SEQ_LENGTH,
            'WORLD_SIZE': WORLD_SIZE,
            'PRECISION': 'BF16_FP32',
            'DAMPING': DAMPING,
            'DAMPING_ALPHA': DAMPING_ALPHA,
            'PROBE_SOURCE': 'random-synthetic',
            'PROBE_NUM_SEQS': PROBE_NUM_SEQS,
            'PROBE_LOCAL_SEQ_LEN': PROBE_LOCAL_SEQ_LEN,
            'PROBE_RANDOM_SEED': PROBE_RANDOM_SEED,
        },
        'performance': {
            'phase1_time_total': ekfac_time_total,
            'phase2_time': phase2_time,
            'throughput_samples_per_s': throughput,
            'total_samples_scored': total_samples,
        },
        'ekfac_params': ekfac_params,
        'positive_influencers': final_pos_topk,
        'negative_influencers': final_neg_topk,
    }

    # Create the output directory if needed so the final write cannot fail on a
    # missing folder after all stages have completed.
    os.makedirs(os.path.dirname(OUTPUT_RESULTS_FILE) or ".", exist_ok=True)
    with open(OUTPUT_RESULTS_FILE, 'wb') as f:
        pickle.dump(analysis_results, f)
    print(f"结果已保存到: {OUTPUT_RESULTS_FILE}")

# Tear down the process group (if one was created) before exiting.
if dist.is_initialized():
    dist.destroy_process_group()

# Printed by every rank, not only rank 0.
print(f"[{RANK}/{WORLD_SIZE}] ✅ 程序执行完毕。")
