import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import time
import pickle
import heapq
import random
from typing import List, Tuple, Dict, Any

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast
from tqdm import tqdm

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.loading_from_pretrained import (
    OFFICIAL_MODEL_NAMES, MODEL_ALIASES, make_model_alias_map
)

# ================ Precision ================
# Dtype passed to the model loader and to every autocast() context below.
DTYPE = torch.bfloat16

# ================ Distributed ================
def setup_distributed():
    """Join the NCCL process group when launcher environment variables exist.

    Returns:
        ``(rank, world_size, local_gpu)``. When ``RANK`` or ``WORLD_SIZE`` is
        missing from the environment no process group is created and
        ``(0, 1, 0)`` is returned.

    Raises:
        KeyError: ``RANK``/``WORLD_SIZE`` are set but ``LOCAL_RANK`` is not.
    """
    env = os.environ
    launched_distributed = "RANK" in env and "WORLD_SIZE" in env
    if not launched_distributed:
        print("Not running in distributed mode. Using single process.")
        return 0, 1, 0

    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    gpu = int(env["LOCAL_RANK"])
    print(f"[{os.getpid()}] Rank {rank}/{world_size} on GPU {gpu} init.")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Bind this process to its own GPU.
    torch.cuda.set_device(gpu)
    return rank, world_size, gpu

# Distributed state is resolved once at import time; everything below reads these globals.
RANK, WORLD_SIZE, LOCAL_RANK = setup_distributed()
IS_MAIN_PROCESS = (RANK == 0)
DEVICE = f"cuda:{LOCAL_RANK}" if torch.cuda.is_available() else "cpu"

# ================ Config ================
MODEL_LOCAL_PATH = "/root/trainbin1/nomask160m800"
MODEL_ALIAS = "Pythia-160mlocal"
DATA_NPY_PATH = "/root/indicies0-999.npy"  # (N, 2049)

NUM_TRAIN_SAMPLES = 8192
LAYER_IDX = 5
HEAD_IDX = 10
SEQ_LENGTH = 2048
BATCH_SIZE_FOR_EKFAC = 4
BATCH_SIZE_FOR_INFLUENCE = 1
NUM_WORKERS = 2
# Stage 2 calls sample_indices.item() per batch, which requires exactly one sample per batch.
if BATCH_SIZE_FOR_INFLUENCE != 1:
    raise ValueError("BATCH_SIZE_FOR_INFLUENCE must be 1.")

# EK-FAC damping
DAMPING = 1e-5
DAMPING_ALPHA = 0.1
# NOTE(review): PROBE_LOCAL_SEQ_LEN is not referenced anywhere in the visible part of this file.
PROBE_LOCAL_SEQ_LEN=64

# How often torch.cuda.empty_cache() is called
EMPTY_CACHE_EVERY_BATCH = 100
EMPTY_CACHE_EVERY_SAMPLE = 500

# Top-K
TOP_K_RESULTS = 1536

# Probe settings (copy-target)
PROBE_SOURCE = "synthetic"       # "synthetic" | "dataset"
PROBE_NUM_SEQS = 32             # total number of probe sequences; each rank handles ceil(N / WORLD_SIZE)
INDUCTION_MATCH = "current"      # "current" (tokens[t]) | "previous" (tokens[t-1])
MATCH_CHOICE = "last"            # "last" | "first"
# NOTE(review): the three SYN_* values below do not influence the synthetic generator
# in the visible code (SYN_ANCHORS_PER_SEQ is passed along but never read) - confirm intent.
SYN_ANCHORS_PER_SEQ = 12         # number of (A,B) patterns embedded per synthetic sequence
SYN_GAP_MIN = 5                  # minimum gap between the second A and (A,B)
SYN_GAP_MAX = 200                # maximum gap

# Output
OUTPUT_RESULTS_FILE = "/root/160m800/influence_copytarget_qkvo124.pkl"

# Intermediate per-stage results are written to disk so a later crash does not waste the run
STAGE1A_SAVE_PATH = "/root/160m800/ekfac_stage1A_qkvo124.pt"
STAGE1B_SAVE_PATH = "/root/160m800/ekfac_stage1B_qkvo124.pt"

# Print a one-time run summary on rank 0 only.
if IS_MAIN_PROCESS:
    print("🎯 QK(合并)+V+O EK-FAC（全量Λ） + copy-target 方向影响分数")
    print(f"L={LAYER_IDX}, H={HEAD_IDX}；数据: {DATA_NPY_PATH}")
    print(f"阶段1 batch={BATCH_SIZE_FOR_EKFAC}, 阶段2 batch=1；DTYPE={DTYPE}；Top-K={TOP_K_RESULTS}")
    print(f"Damping: {DAMPING}, alpha={DAMPING_ALPHA}")
    print(f"empty_cache：阶段1每 {EMPTY_CACHE_EVERY_BATCH} 个batch；阶段2每 {EMPTY_CACHE_EVERY_SAMPLE} 个样本")
    print(f"Probe: source={PROBE_SOURCE}, PROBE_NUM_SEQS={PROBE_NUM_SEQS}, match={INDUCTION_MATCH}/{MATCH_CHOICE}, anchors/seq={SYN_ANCHORS_PER_SEQ}")

# ================ Load Tokenizer & Model ================
# Register the local checkpoint path and its alias in TransformerLens' lookup tables
# (presumably so that from_pretrained_no_processing can resolve MODEL_ALIAS - TODO confirm).
OFFICIAL_MODEL_NAMES.append(MODEL_LOCAL_PATH)
MODEL_ALIASES[MODEL_LOCAL_PATH] = [MODEL_ALIAS]
make_model_alias_map()

tokenizer = GPTNeoXTokenizerFast.from_pretrained(MODEL_LOCAL_PATH, local_files_only=True)
if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] ✅ Tokenizer loaded, vocab={tokenizer.vocab_size}")

model_local = HookedTransformer.from_pretrained_no_processing(
    MODEL_ALIAS,
    n_devices=1,
    dtype=DTYPE,
    device=DEVICE
)
# The code below hooks blocks.{LAYER_IDX}.attn.hook_result and fails if it is never filled;
# presumably this flag is what makes TransformerLens compute it - confirm against its docs.
model_local.cfg.use_attn_result = True
model_local.to(DEVICE)
model_local.eval()

# Freeze all parameters, then re-enable gradients for this layer's Q/K/V/O
# (stage 2 differentiates w.r.t. the parameters; stage 1 w.r.t. intermediate activations).
for p in model_local.parameters():
    p.requires_grad = False
attn_layer = model_local.blocks[LAYER_IDX].attn
attn_layer.W_Q.requires_grad = True
attn_layer.W_K.requires_grad = True
attn_layer.W_V.requires_grad = True
attn_layer.W_O.requires_grad = True

# DDP
# NOTE(review): DDP is constructed unconditionally. setup_distributed() skips
# init_process_group in its single-process fallback, so this presumably fails there - confirm.
model = DDP(
    model_local,
    device_ids=[LOCAL_RANK],
    output_device=LOCAL_RANK,
    find_unused_parameters=False,
    bucket_cap_mb=25
)

# Cache the model dimensions used throughout the rest of the script.
d_model = model.module.cfg.d_model
d_head = model.module.cfg.d_head
n_heads = model.module.cfg.n_heads
if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] 模型: d_model={d_model}, d_head={d_head}, n_heads={n_heads}, n_ctx={model.module.cfg.n_ctx}")

# ================ Data (NPY) ================
class NpyDataset(Dataset):
    """Map-style dataset exposing selected rows of a memory-mapped ``.npy`` array.

    Each item is ``(row_as_int64_tensor, source_row_index)``.
    """

    def __init__(self, npy_path: str, indices: List[int]):
        # mmap_mode="r" opens the array as a read-only memory map instead of
        # loading the whole file up front.
        self.arr = np.load(npy_path, mmap_mode="r")
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_row = self.indices[idx]
        # astype() produces a fresh int64 copy of the read-only mapped row.
        row_int64 = self.arr[source_row].astype(np.int64)
        return torch.from_numpy(row_int64), source_row

def collate_npy(batch):
    """Collate ``(tokens, row_index)`` samples into batch tensors.

    Returns:
        ``(tokens, row_indices)`` where ``tokens`` stacks the per-sample
        tensors along a new leading dimension and ``row_indices`` is a 1-D
        ``torch.long`` tensor.
    """
    token_rows, row_ids = [], []
    for sample in batch:
        token_rows.append(sample[0])
        row_ids.append(sample[1])
    stacked = torch.stack(token_rows, dim=0)
    return stacked, torch.tensor(row_ids, dtype=torch.long)

def build_dataloaders(
    npy_path: str,
    num_train_samples: int,
    batch_size_ekfac: int,
    batch_size_influence: int,
    rank: int,
    world_size: int
):
    """Create the EK-FAC and influence-scoring loaders over the same rows.

    Both loaders cover the first ``min(num_train_samples, N)`` rows of the
    array at ``npy_path`` and shard them with a non-shuffling
    ``DistributedSampler``; they differ only in batch size.

    Returns:
        ``(ekfac_loader, influence_loader)``.

    Raises:
        ValueError: the array's second dimension is not 2049.
    """
    if IS_MAIN_PROCESS:
        print(f"[{RANK}/{WORLD_SIZE}] 从 NPY 加载数据: {npy_path}")

    # Only the shape is needed here; each dataset opens its own memory map.
    n_rows, row_len = np.load(npy_path, mmap_mode="r").shape
    if row_len != 2049:
        raise ValueError(f"NPY 维度应为 (N, 2049)，实际: (N, {row_len})")
    selected_rows = list(range(min(num_train_samples, n_rows)))

    def _make_loader(batch_size):
        dataset = NpyDataset(npy_path, selected_rows)
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False
        )
        return DataLoader(
            dataset, batch_size=batch_size, sampler=sampler,
            collate_fn=collate_npy, num_workers=NUM_WORKERS, pin_memory=True
        )

    return _make_loader(batch_size_ekfac), _make_loader(batch_size_influence)

# ================ 前向钩子缓存（不 detach） ================
class QKVOActivationCache:
    """Mutable holder for activations written by the forward hooks.

    Slots (shapes as annotated by the original author):
      X      [B,S,d_model]
      Q/K/V  [B,S,n_heads,d_head]
      Z      [B,S,n_heads,d_head]
      result [B,S,n_heads,d_model]
    """

    _FIELDS = ("X", "Q", "K", "V", "Z", "result")

    def __init__(self):
        self.clear()

    def clear(self):
        """Reset every slot to ``None``, dropping references to cached tensors."""
        for field in self._FIELDS:
            setattr(self, field, None)

def setup_qkvo_forward_hooks(model_instance, layer_idx: int, cache: QKVOActivationCache):
    """Register forward hooks that copy one layer's activations into ``cache``.

    The hooks store the activation object itself (no detach) and return it
    unchanged, so the cached tensors remain attached to the autograd graph.

    Returns:
        The non-``None`` values returned by ``model_instance.add_hook``, in
        registration order.
    """
    def _recorder(field):
        # Factory so each hook closes over its own field name.
        def _hook(act, hook):
            setattr(cache, field, act)
            return act
        return _hook

    hook_points = (
        (f"blocks.{layer_idx}.ln1.hook_normalized", "X"),
        (f"blocks.{layer_idx}.attn.hook_q", "Q"),
        (f"blocks.{layer_idx}.attn.hook_k", "K"),
        (f"blocks.{layer_idx}.attn.hook_v", "V"),
        (f"blocks.{layer_idx}.attn.hook_z", "Z"),
        (f"blocks.{layer_idx}.attn.hook_result", "result"),
    )
    handles = [
        model_instance.add_hook(point_name, _recorder(field), dir='fwd')
        for point_name, field in hook_points
    ]

    if IS_MAIN_PROCESS:
        print(f"[{RANK}/{WORLD_SIZE}] ✅ 前向钩子：X/Q/K/V/Z/result")
    return [handle for handle in handles if handle is not None]

# ================ EK-FAC（QK 合并 + V + O） ================
class EKFAC_QKVO_Head:
    """EK-FAC state for a single attention head, split into three blocks.

    Blocks:
      0: W_QK -> d_in=d_model, d_out=2*d_head   (Q and K concatenated)
      1: W_V  -> d_in=d_model, d_out=d_head
      2: W_O  -> d_in=d_head,  d_out=d_model

    A/S are token-level second moments (sums divided by the token count at
    finalisation). Lambda is fitted outside this class and assigned to
    ``self.Lambda`` as a flattened ``[d_in * d_out]`` vector per block.
    """

    def __init__(self, d_model: int, d_head: int, damping: float = 1e-5, damping_alpha: float = 0.1):
        self.d_model = d_model
        self.d_head = d_head
        self.blocks = [
            {'name': 'W_QK', 'd_in': d_model, 'd_out': 2 * d_head},
            {'name': 'W_V',  'd_in': d_model, 'd_out': d_head},
            {'name': 'W_O',  'd_in': d_head,  'd_out': d_model},
        ]
        self.damping = damping
        self.damping_alpha = damping_alpha

        # Running (un-normalised) sums of the input / output-gradient Gram matrices.
        self.A_accum = [None, None, None]
        self.S_accum = [None, None, None]
        self.token_count = 0

        # Eigenbases of A and S, and the per-coordinate eigenvalue estimates.
        self.Q_A = [None, None, None]
        self.Q_S = [None, None, None]
        self.Lambda = [None, None, None]

    def _add_factors(self, block_idx: int, A: torch.Tensor, S: torch.Tensor) -> None:
        """Add one batch's Gram matrices into the accumulators of ``block_idx``."""
        if self.A_accum[block_idx] is None:
            # Clone A on first store: blocks 0 and 1 receive the same X^T X
            # tensor and must not share storage, since later updates are in-place.
            self.A_accum[block_idx] = A.clone()
            self.S_accum[block_idx] = S
        else:
            self.A_accum[block_idx].add_(A)
            self.S_accum[block_idx].add_(S)

    def accumulate_A_S(self, X_flat_f32, dQ_f32, dK_f32, dV_f32, Z_flat_f32, dR_f32):
        """Accumulate the A/S Gram matrices of all three blocks for one batch.

        Every argument is a 2-D tensor with one row per token. The number of
        rows of ``X_flat_f32`` (or of ``Z_flat_f32`` when X is ``None``) is
        added to ``token_count``. Nothing happens when that row count is zero.
        """
        BStok = X_flat_f32.shape[0] if X_flat_f32 is not None else (Z_flat_f32.shape[0] if Z_flat_f32 is not None else 0)
        if BStok == 0:
            return

        # Detach so the accumulators never keep the autograd graph alive.
        X = X_flat_f32.detach()
        Z = Z_flat_f32.detach()
        dQ = dQ_f32.detach()
        dK = dK_f32.detach()
        dV = dV_f32.detach()
        dR = dR_f32.detach()

        # Blocks 0 (QK) and 1 (V) share the same layer input, so X^T X is
        # computed once instead of once per block.
        gram_x = X.t() @ X

        # Block 0: QK
        G_qk = torch.cat([dQ, dK], dim=-1)  # [B*S, 2*d_head]
        self._add_factors(0, gram_x, G_qk.t() @ G_qk)

        # Block 1: V
        self._add_factors(1, gram_x, dV.t() @ dV)

        # Block 2: O
        self._add_factors(2, Z.t() @ Z, dR.t() @ dR)

        self.token_count += int(BStok)

    def finalize_eigendecomposition(self):
        """Normalise A/S by the global token count and eigendecompose them.

        When a process group is initialised the accumulators and the token
        count are summed across ranks first (the accumulators are modified
        in place by the all-reduce). Fills ``Q_A`` and ``Q_S``.

        Raises:
            RuntimeError: no tokens were accumulated on this rank.
        """
        if self.token_count == 0:
            raise RuntimeError("No tokens accumulated for EK-FAC A/S.")

        if dist.is_initialized():
            for i in range(3):
                dist.all_reduce(self.A_accum[i], op=dist.ReduceOp.SUM)
                dist.all_reduce(self.S_accum[i], op=dist.ReduceOp.SUM)
            tok_t = torch.tensor([self.token_count], device=DEVICE, dtype=torch.long)
            dist.all_reduce(tok_t, op=dist.ReduceOp.SUM)
            global_tokens = tok_t.item()
        else:
            global_tokens = self.token_count

        for i in range(3):
            A = self.A_accum[i] / float(global_tokens)
            S = self.S_accum[i] / float(global_tokens)
            # Symmetrise, then add a small trace-scaled ridge before eigh.
            A = 0.5 * (A + A.t()); S = 0.5 * (S + S.t())
            eps_A = (1e-6 * torch.trace(A).abs() / A.shape[0]).to(A.dtype)
            eps_S = (1e-6 * torch.trace(S).abs() / S.shape[0]).to(S.dtype)
            A = A + eps_A * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
            S = S + eps_S * torch.eye(S.shape[0], device=S.device, dtype=S.dtype)

            eigvals_A, self.Q_A[i] = torch.linalg.eigh(A.float())
            eigvals_S, self.Q_S[i] = torch.linalg.eigh(S.float())
            if IS_MAIN_PROCESS:
                print(f"[Block {i} {self.blocks[i]['name']}] A eig [{eigvals_A.min():.2e}, {eigvals_A.max():.2e}] "
                      f"S eig [{eigvals_S.min():.2e}, {eigvals_S.max():.2e}]")

    def inverse_hvp(self, block_idx: int, grad_matrix: torch.Tensor) -> torch.Tensor:
        """Apply the damped EK-FAC inverse to a ``[d_in, d_out]`` gradient.

        The gradient is rotated into the (Q_A, Q_S) eigenbasis, divided
        element-wise by ``Lambda + damping_alpha * mean(Lambda)`` (clamped
        from below at ``damping``), and rotated back.

        Raises:
            RuntimeError: Q_A, Q_S or Lambda has not been set for this block.
        """
        assert 0 <= block_idx < 3
        if self.Q_A[block_idx] is None or self.Q_S[block_idx] is None or self.Lambda[block_idx] is None:
            raise RuntimeError("EK-FAC not ready for block %d" % block_idx)
        G = grad_matrix.float()
        QA, QS = self.Q_A[block_idx], self.Q_S[block_idx]
        ge = QA.t() @ G @ QS
        lam = self.Lambda[block_idx]
        denom = lam + self.damping_alpha * lam.mean()
        denom = torch.clamp(denom, min=self.damping)
        ihvp_eig_flat = ge.flatten() / denom
        d_in = self.blocks[block_idx]['d_in']
        d_out = self.blocks[block_idx]['d_out']
        ihvp = QA @ ihvp_eig_flat.reshape(d_in, d_out) @ QS.t()
        return ihvp

    @property
    def block_dims(self) -> List[Dict[str, int]]:
        """Return ``d_in``/``d_out`` of each block, without the block names."""
        return [{'d_in': b['d_in'], 'd_out': b['d_out']} for b in self.blocks]

# ================ 工具 ================
def compute_pseudo_labels(logits: torch.Tensor) -> torch.Tensor:
    """Draw one label per position from the softmax of ``logits``.

    The leading dimensions are flattened, so the result is a 1-D integer
    tensor with one sampled vocabulary index per row of
    ``logits.reshape(-1, vocab)``.
    """
    vocab_size = logits.shape[-1]
    flat_logits = logits.reshape(-1, vocab_size).float()
    token_probs = flat_logits.softmax(dim=-1)
    draws = torch.multinomial(token_probs, num_samples=1)
    return draws.squeeze(-1)

def gather_heap_as_tensors(heap_data, world_size, device):
    """Collect every rank's heap entries as one ``[n, 3]`` float64 array.

    Each heap entry is unpacked as ``(score, _, (idx, loss))`` and becomes a
    ``[score, idx, loss]`` row.

    Returns:
        A NumPy array of the local rows when no process group is
        initialised. Otherwise rank 0 gets the concatenation of all ranks'
        rows (in rank order) and every other rank gets ``None``.
    """
    rows = [
        [float(score), float(int(idx)), float(loss)]
        for score, _, (idx, loss) in heap_data
    ]
    if rows:
        local = torch.tensor(rows, device=device, dtype=torch.float64)
    else:
        # torch.tensor([]) would be 1-D; keep the explicit (0, 3) shape.
        local = torch.zeros(0, 3, device=device, dtype=torch.float64)

    if not dist.is_initialized():
        return local.cpu().numpy()

    # Exchange row counts first so every rank can pad to a common length.
    n_local = local.shape[0]
    count_local = torch.tensor([n_local], device=device, dtype=torch.long)
    counts = [torch.zeros_like(count_local) for _ in range(world_size)]
    dist.all_gather(counts, count_local)

    n_max = max(c.item() for c in counts)
    if n_max > n_local:
        filler = torch.full((n_max - n_local, 3), -1e18, device=device, dtype=torch.float64)
        padded = torch.cat([local, filler], dim=0)
    else:
        padded = local

    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    if RANK != 0:
        return None

    # Strip each rank's padding using the counts exchanged above.
    kept = [chunk[:n.item()] for chunk, n in zip(gathered, counts) if n.item() > 0]
    if not kept:
        return np.zeros((0, 3))
    return torch.cat(kept, dim=0).cpu().numpy()

# =============== DataLoaders ===============
# Two loaders over the same rows: batched for EK-FAC statistics, batch size 1 for scoring.
train_as_dataloader, train_influence_dataloader = build_dataloaders(
    DATA_NPY_PATH,
    NUM_TRAIN_SAMPLES,
    BATCH_SIZE_FOR_EKFAC,
    BATCH_SIZE_FOR_INFLUENCE,
    RANK, WORLD_SIZE
)

if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] 数据加载完成：EK-FAC批数={len(train_as_dataloader)}, 影响评估批数={len(train_influence_dataloader)}")

# ================ Stage 1: accumulate A/S (QK+V+O; token-level expectation) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段1A：累积 A/S（autograd.grad 到 Q/K/V/result；token 期望）")
    print("="*70)

ekfac = EKFAC_QKVO_Head(d_model, d_head, damping=DAMPING, damping_alpha=DAMPING_ALPHA)
qkvo_cache = QKVOActivationCache()
hooks_qkvo = setup_qkvo_forward_hooks(model.module, LAYER_IDX, qkvo_cache)

start_time_as = time.time()
for batch_idx, (batch_tokens, _) in enumerate(tqdm(train_as_dataloader, desc=f"[{RANK}] Accum A/S", disable=not IS_MAIN_PROCESS)):
    qkvo_cache.clear()
    # Rows hold SEQ_LENGTH + 1 tokens; only the first SEQ_LENGTH are fed to the model.
    input_ids = batch_tokens[:, :SEQ_LENGTH].to(DEVICE, non_blocking=True)
    # NOTE(review): B is not used again inside this loop.
    B = int(input_ids.shape[0])

    with model.no_sync():
        with torch.enable_grad():
            with autocast(dtype=DTYPE):
                logits = model(input_ids)
            # Labels are sampled from the model's own output distribution, not taken from the data.
            sampled_labels = compute_pseudo_labels(logits)
            Vocab = logits.shape[-1]
            loss = F.cross_entropy(logits.reshape(-1, Vocab).float(), sampled_labels, reduction='sum')

            if (qkvo_cache.Q is None or qkvo_cache.K is None or qkvo_cache.V is None or
                qkvo_cache.Z is None or qkvo_cache.result is None or qkvo_cache.X is None):
                raise RuntimeError("X/Q/K/V/Z/result activations not captured.")

            # Differentiate w.r.t. the cached (non-detached) activations rather than the weights.
            grads_Q, grads_K, grads_V, grads_R = torch.autograd.grad(
                outputs=loss,
                inputs=[qkvo_cache.Q, qkvo_cache.K, qkvo_cache.V, qkvo_cache.result],
                retain_graph=False, create_graph=False, allow_unused=False
            )

    # Select the target head and flatten (batch, position) into one token axis, in fp32.
    dQ = grads_Q[:, :, HEAD_IDX, :].reshape(-1, d_head).float()
    dK = grads_K[:, :, HEAD_IDX, :].reshape(-1, d_head).float()
    dV = grads_V[:, :, HEAD_IDX, :].reshape(-1, d_head).float()
    dR = grads_R[:, :, HEAD_IDX, :].reshape(-1, d_model).float()
    X_flat = qkvo_cache.X.reshape(-1, d_model).float()
    Z_flat = qkvo_cache.Z[:, :, HEAD_IDX, :].reshape(-1, d_head).float()

    ekfac.accumulate_A_S(X_flat, dQ, dK, dV, Z_flat, dR)
    qkvo_cache.clear()

    if DEVICE.startswith("cuda") and ((batch_idx + 1) % EMPTY_CACHE_EVERY_BATCH == 0):
        torch.cuda.empty_cache()

# The timer stops before finalisation, so the eigendecomposition is not part of ekfac_time_as.
ekfac_time_as = time.time() - start_time_as
ekfac.finalize_eigendecomposition()
if IS_MAIN_PROCESS:
    print(f"[{RANK}/{WORLD_SIZE}] ✓ 阶段1A完成，耗时: {ekfac_time_as:.2f}s")
    # Save Stage1A to disk immediately so the work is not lost if a later stage fails.
    # NOTE(review): finalize_eigendecomposition() all-reduces A_accum/S_accum in place, so in
    # distributed runs the saved accumulators are global sums while token_count is rank-local - confirm intent.
    stage1A_payload = {
        'Q_A': [q.detach().cpu() if q is not None else None for q in ekfac.Q_A],
        'Q_S': [q.detach().cpu() if q is not None else None for q in ekfac.Q_S],
        'A_accum': [m.detach().cpu() if m is not None else None for m in ekfac.A_accum],
        'S_accum': [m.detach().cpu() if m is not None else None for m in ekfac.S_accum],
        'token_count': int(ekfac.token_count),
        'block_dims': ekfac.block_dims,
        'd_model': ekfac.d_model,
        'd_head': ekfac.d_head,
        'damping': ekfac.damping,
        'damping_alpha': ekfac.damping_alpha,
        'layer_idx': LAYER_IDX,
        'head_idx': HEAD_IDX,
        'seq_length': SEQ_LENGTH,
    }
    torch.save(stage1A_payload, STAGE1A_SAVE_PATH)
    print(f"[{RANK}/{WORLD_SIZE}] 已保存 Stage1A 结果到: {STAGE1A_SAVE_PATH}")

# ================ Stage 1B: second full pass to fit Λ (mean over sequences; QK+V+O) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段1B：全量二次遍历拟合 Λ（按序列数）")
    print("="*70)

# One [d_in, d_out] fp32 accumulator per block for the squared eigenbasis gradients.
lambda_sum = [
    torch.zeros(ekfac.blocks[0]['d_in'], ekfac.blocks[0]['d_out'], device=DEVICE, dtype=torch.float32),
    torch.zeros(ekfac.blocks[1]['d_in'], ekfac.blocks[1]['d_out'], device=DEVICE, dtype=torch.float32),
    torch.zeros(ekfac.blocks[2]['d_in'], ekfac.blocks[2]['d_out'], device=DEVICE, dtype=torch.float32),
]
# Running count of sequences processed (kept as a tensor so it can be all-reduced).
weight_sum = torch.tensor(0.0, device=DEVICE, dtype=torch.float64)

start_time_lam = time.time()
for batch_idx, (batch_tokens, _) in enumerate(tqdm(train_as_dataloader, desc=f"[{RANK}] Fit Λ (full pass)", disable=not IS_MAIN_PROCESS)):
    qkvo_cache.clear()
    input_ids = batch_tokens[:, :SEQ_LENGTH].to(DEVICE, non_blocking=True)
    B = int(input_ids.shape[0])

    with model.no_sync():
        with torch.enable_grad():
            with autocast(dtype=DTYPE):
                logits = model(input_ids)
            # Pseudo-labels are re-sampled here; they are independent of the Stage 1A draws.
            sampled_labels = compute_pseudo_labels(logits)
            Vocab = logits.shape[-1]
            loss = F.cross_entropy(logits.reshape(-1, Vocab).float(), sampled_labels, reduction='sum')

            if (qkvo_cache.Q is None or qkvo_cache.K is None or qkvo_cache.V is None or
                qkvo_cache.Z is None or qkvo_cache.result is None or qkvo_cache.X is None):
                raise RuntimeError("X/Q/K/V/Z/result activations not captured in Λ pass.")

            grads_Q, grads_K, grads_V, grads_R = torch.autograd.grad(
                outputs=loss,
                inputs=[qkvo_cache.Q, qkvo_cache.K, qkvo_cache.V, qkvo_cache.result],
                retain_graph=False, create_graph=False, allow_unused=False
            )

    dQ = grads_Q[:, :, HEAD_IDX, :].reshape(-1, d_head).float().detach()
    dK = grads_K[:, :, HEAD_IDX, :].reshape(-1, d_head).float().detach()
    dV = grads_V[:, :, HEAD_IDX, :].reshape(-1, d_head).float().detach()
    dR = grads_R[:, :, HEAD_IDX, :].reshape(-1, d_model).float().detach()
    X_flat = qkvo_cache.X.reshape(-1, d_model).float().detach()
    Z_flat = qkvo_cache.Z[:, :, HEAD_IDX, :].reshape(-1, d_head).float().detach()

    # Pseudo weight-gradients: input^T @ output-gradient, summed over every token in the batch.
    # NOTE(review): the sum runs over all B sequences before squaring, while the normaliser
    # below counts B sequences - confirm this matches the intended per-sequence estimator.
    dW0 = X_flat.t() @ torch.cat([dQ, dK], dim=-1)   # [d_model, 2*d_head]
    dW1 = X_flat.t() @ dV                            # [d_model, d_head]
    dW2 = Z_flat.t() @ dR                            # [d_head,  d_model]

    # Rotate into the Kronecker eigenbasis obtained in Stage 1A.
    ge0 = ekfac.Q_A[0].t() @ dW0 @ ekfac.Q_S[0]
    ge1 = ekfac.Q_A[1].t() @ dW1 @ ekfac.Q_S[1]
    ge2 = ekfac.Q_A[2].t() @ dW2 @ ekfac.Q_S[2]

    lambda_sum[0].add_(ge0.pow(2))
    lambda_sum[1].add_(ge1.pow(2))
    lambda_sum[2].add_(ge2.pow(2))
    weight_sum += float(B)

    qkvo_cache.clear()
    if DEVICE.startswith("cuda") and ((batch_idx + 1) % EMPTY_CACHE_EVERY_BATCH == 0):
        torch.cuda.empty_cache()

# Combine the per-rank sums and sequence counts.
if dist.is_initialized():
    for i in range(3):
        dist.all_reduce(lambda_sum[i], op=dist.ReduceOp.SUM)
    dist.all_reduce(weight_sum, op=dist.ReduceOp.SUM)

# max(1.0, ...) avoids a division by zero when no batch was processed.
w_total = max(1.0, weight_sum.item())
# Flattened row-major from [d_in, d_out]; inverse_hvp flattens its rotated gradient the same way.
ekfac.Lambda[0] = (lambda_sum[0] / w_total).flatten()
ekfac.Lambda[1] = (lambda_sum[1] / w_total).flatten()
ekfac.Lambda[2] = (lambda_sum[2] / w_total).flatten()

ekfac_time_lam = time.time() - start_time_lam
ekfac_time_total = ekfac_time_as + ekfac_time_lam

if IS_MAIN_PROCESS:
    for i in range(3):
        lam = ekfac.Lambda[i]
        print(f"[Block {i} {ekfac.blocks[i]['name']}] Λ range: [{lam.min().item():.2e}, {lam.max().item():.2e}], mean={lam.mean().item():.2e}")
    print(f"[{RANK}/{WORLD_SIZE}] ✓ 阶段1B完成，耗时: {ekfac_time_lam:.2f}s")
    print(f"[{RANK}/{WORLD_SIZE}] ⏱️ 阶段1总耗时: {ekfac_time_total:.2f}s")
    # Save Stage1B to disk immediately.
    stage1B_payload = {
        'Lambda': [ekfac.Lambda[0].detach().cpu(), ekfac.Lambda[1].detach().cpu(), ekfac.Lambda[2].detach().cpu()],
        'weight_sum': float(w_total),
        'block_dims': ekfac.block_dims,
        'd_model': ekfac.d_model,
        'd_head': ekfac.d_head,
        'damping': ekfac.damping,
        'damping_alpha': ekfac.damping_alpha,
        'layer_idx': LAYER_IDX,
        'head_idx': HEAD_IDX,
        'seq_length': SEQ_LENGTH,
    }
    torch.save(stage1B_payload, STAGE1B_SAVE_PATH)
    print(f"[{RANK}/{WORLD_SIZE}] 已保存 Stage1B 结果到: {STAGE1B_SAVE_PATH}")

# Important: remove the hooks left over from Stage 1 so that the probe / Stage 2
# forward passes do not keep caching large activations.
# NOTE(review): hooks_qkvo is passed positionally to reset_hooks; confirm against the
# TransformerLens signature what that first argument means. Both failure paths are
# swallowed silently, so a failed removal would go unnoticed.
try:
    model.module.reset_hooks(hooks_qkvo)
except Exception:
    try:
        model.module.remove_all_hook_fns()
    except Exception:
        pass
hooks_qkvo = None
qkvo_cache.clear()
if DEVICE.startswith("cuda"):
    torch.cuda.empty_cache()

# ================ Probe v (copy-target): two possible sources, synthetic / dataset ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 预计算探针 v（copy-target）与 p=H^-1 v")
    print("="*70)

def find_match_p(tokens_1d: torch.Tensor, t: int) -> int:
    """Locate an earlier occurrence of the key token for position ``t``.

    With ``INDUCTION_MATCH == "previous"`` the key is ``tokens_1d[t-1]`` and
    the search covers ``tokens_1d[:t-1]``; otherwise the key is
    ``tokens_1d[t]`` and the search covers ``tokens_1d[:t]``.

    Returns:
        The matching index (the last one when ``MATCH_CHOICE == "last"``,
        else the first), or ``-1`` when there is no match.
    """
    if INDUCTION_MATCH == "previous":
        if t == 0:
            return -1
        anchor = t - 1
    else:  # "current"
        anchor = t

    # The key sits at `anchor`; only positions strictly before it are searched.
    key = int(tokens_1d[anchor].item())
    hits = (tokens_1d[:anchor] == key).nonzero(as_tuple=True)[0]
    if hits.numel() == 0:
        return -1
    chosen = hits[-1] if MATCH_CHOICE == "last" else hits[0]
    return int(chosen.item())

def make_induction_sequences(
    num_seqs: int,
    seq_len: int,
    vocab_size: int,
    avoid_ids: set = None,
    anchors_per_seq: int = 0,
    gap_min: int = 0,
    gap_max: int = 0,
    induction_match: str = "current",
    cycle_len: int = 16,
) -> torch.Tensor:
    """Build token sequences made of one short random block repeated end to end.

    Each sequence is ``[block] + [block] + ...`` where ``block`` holds
    ``cycle_len`` random token ids, truncated to ``seq_len + 1`` tokens.
    Tokens are drawn from NumPy's global random state.

    Args:
        num_seqs: number of sequences to generate.
        seq_len: model input length; the output has ``seq_len + 1`` columns.
        vocab_size: token ids are drawn from ``[0, vocab_size)``.
        avoid_ids: ids that must not appear. Each one is replaced by the next
            id (cyclically) that is not itself avoided.
        anchors_per_seq, gap_min, gap_max, induction_match: accepted for
            signature compatibility only; not used by this generator.
        cycle_len: length of the repeated block.

    Returns:
        An int64 tensor of shape ``[num_seqs, seq_len + 1]``.

    Raises:
        ValueError: ``cycle_len`` or ``vocab_size`` is not positive, or
            ``avoid_ids`` covers the whole vocabulary.
    """
    if cycle_len < 1:
        raise ValueError(f"cycle_len must be >= 1, got {cycle_len}")
    if vocab_size < 1:
        raise ValueError(f"vocab_size must be >= 1, got {vocab_size}")

    # Only avoided ids that can actually be sampled matter.
    banned = {int(a) for a in (avoid_ids or set()) if 0 <= int(a) < vocab_size}
    if len(banned) >= vocab_size:
        raise ValueError("avoid_ids excludes every id in the vocabulary")

    # 1. Base block [N, cycle_len].
    block = np.random.randint(0, vocab_size, size=(num_seqs, cycle_len))

    # 2. Replace avoided ids. The replacement is never an avoided id, so the
    #    result does not depend on the order in which ids are processed.
    for bad_id in sorted(banned):
        replacement = (bad_id + 1) % vocab_size
        while replacement in banned:
            replacement = (replacement + 1) % vocab_size
        block[block == bad_id] = replacement

    # 3. Tile enough copies to cover seq_len + 1 tokens, then cut to length.
    target_len = seq_len + 1
    num_repeats = -(-target_len // cycle_len)  # ceiling division
    tiled = np.tile(block, (1, num_repeats))

    return torch.from_numpy(tiled[:, :target_len].astype(np.int64))



def compute_probe_copytarget_dataset(model_ddp: DDP,
                                     cache: QKVOActivationCache,
                                     probe_num_seqs: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the copy-target probe gradient from real dataset sequences.

    For each sequence the scalar ``f = sum_t <result_head[t], W_U[:, target_t]>``
    is formed over the positions ``t`` where ``find_match_p`` finds a match ``p``,
    with ``target_t = tokens[p + 1]``. ``f`` is differentiated w.r.t. the layer's
    W_Q/W_K/W_V/W_O and the HEAD_IDX slices are accumulated.

    Each rank reads up to ``ceil(probe_num_seqs / WORLD_SIZE)`` batches from
    ``train_influence_dataloader``; sequences without any match still count
    toward that quota.

    Returns:
        ``(v_qk, v_v, v_o)``: fp32 sums (not means) over sequences and, when a
        process group is initialised, over ranks. ``v_qk`` is the Q and K
        gradients concatenated along the last dimension.

    Raises:
        RuntimeError: the ``hook_result`` activation was not captured.
    """
    v0_acc = torch.zeros(d_model, 2*d_head, device=DEVICE, dtype=torch.float32)
    v1_acc = torch.zeros(d_model, d_head,   device=DEVICE, dtype=torch.float32)
    v2_acc = torch.zeros(d_head,  d_model,  device=DEVICE, dtype=torch.float32)

    local_target = int(np.ceil(probe_num_seqs / WORLD_SIZE))
    processed_local = 0
    hooks_local = setup_qkvo_forward_hooks(model_ddp.module, LAYER_IDX, cache)

    for batch_idx, (batch_tokens, _) in enumerate(train_influence_dataloader):
        if processed_local >= local_target:
            break
        cache.clear()
        tokens_full = batch_tokens.to(DEVICE, non_blocking=True)          # [1, S+1]
        input_ids = tokens_full[:, :SEQ_LENGTH]                           # [1, S]
        S = input_ids.shape[1]

        # NOTE(review): unlike compute_sample_grads_qkvo, this forward pass is not
        # wrapped in model_ddp.no_sync() - confirm that is intentional.
        with torch.enable_grad():
            with autocast(dtype=DTYPE):
                _ = model_ddp(input_ids)
        if cache.result is None:
            try: model_ddp.module.reset_hooks(hooks_local)
            except Exception: pass
            raise RuntimeError("hook_result missing for probe (dataset).")

        result_head = cache.result[:, :, HEAD_IDX, :].reshape(S, -1).float()  # [S,d_model]
        W_U = model_ddp.module.W_U

        contribs = []
        for t in range(S):
            p = find_match_p(tokens_full[0], t)
            if p < 0 or (p + 1) >= (SEQ_LENGTH + 1):
                continue
            # The token that followed the matched earlier occurrence is the copy target.
            target_token = int(tokens_full[0, p+1].item())
            # dtype fix: cast the W_U column to result_head's dtype before the dot product
            u_col = W_U[:, target_token]
            if u_col.dtype != result_head.dtype:
                u_col = u_col.to(result_head.dtype)
            contribs.append(torch.dot(result_head[t], u_col))

        if contribs:
            f = torch.stack(contribs).sum()
            attn = model_ddp.module.blocks[LAYER_IDX].attn
            grads = torch.autograd.grad(
                outputs=f,
                inputs=[attn.W_Q, attn.W_K, attn.W_V, attn.W_O],
                retain_graph=False, create_graph=False, allow_unused=False
            )
            gQ = grads[0][HEAD_IDX].float()
            gK = grads[1][HEAD_IDX].float()
            gV = grads[2][HEAD_IDX].float()
            gO = grads[3][HEAD_IDX].float()
            v0_acc.add_(torch.cat([gQ, gK], dim=-1))
            v1_acc.add_(gV)
            v2_acc.add_(gO)

        cache.clear()
        processed_local += 1

    # Best-effort hook removal; failures are swallowed silently.
    try: model_ddp.module.reset_hooks(hooks_local)
    except Exception:
        try: model_ddp.module.remove_all_hook_fns()
        except Exception: pass

    if dist.is_initialized():
        dist.all_reduce(v0_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(v1_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(v2_acc, op=dist.ReduceOp.SUM)

    return v0_acc, v1_acc, v2_acc

def compute_probe_copytarget_synthetic(model_ddp: DDP,
                                       cache: QKVOActivationCache,
                                       probe_num_seqs: int,
                                       anchors_per_seq: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the copy-target probe gradient from synthetic repeating sequences.

    Each rank generates ``ceil(probe_num_seqs / WORLD_SIZE)`` sequences with
    ``make_induction_sequences``. For every sequence the scalar
    ``f = sum_t <result_head[t], W_U[:, target_t]>`` is formed over positions
    ``t >= CYCLE_LEN`` that have a match ``p``, with ``target_t = tokens[p + 1]``,
    and differentiated w.r.t. the layer's W_Q/W_K/W_V/W_O.

    ``anchors_per_seq`` is accepted but not used in this function.

    Returns:
        ``(v_qk, v_v, v_o)``: fp32 sums (not means) over sequences and, when a
        process group is initialised, over ranks.

    Raises:
        RuntimeError: the ``hook_result`` activation was not captured.
    """
    # 1. Initialise the accumulators (FP32)
    v0_acc = torch.zeros(d_model, 2*d_head, device=DEVICE, dtype=torch.float32)
    v1_acc = torch.zeros(d_model, d_head,   device=DEVICE, dtype=torch.float32)
    v2_acc = torch.zeros(d_head,  d_model,  device=DEVICE, dtype=torch.float32)

    local_target = int(np.ceil(probe_num_seqs / WORLD_SIZE))
    vocab = model_ddp.module.W_U.shape[1]

    # Keep pad / eos ids out of the synthetic sequences when the tokenizer defines them.
    avoid = set()
    if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
        avoid.add(int(tokenizer.pad_token_id))
    if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
        avoid.add(int(tokenizer.eos_token_id))

    # Must stay equal to the cycle length used inside make_induction_sequences
    CYCLE_LEN = 16
    real_seq_len = SEQ_LENGTH  # 2048

    # Generate the short-cycle repeating data
    synth = make_induction_sequences(
        num_seqs=local_target,
        seq_len=real_seq_len,
        vocab_size=int(vocab),
        avoid_ids=avoid
    ).to(DEVICE)

    hooks_local = setup_qkvo_forward_hooks(model_ddp.module, LAYER_IDX, cache)

    for b in range(synth.shape[0]):
        cache.clear()
        tokens_full = synth[b:b+1]                             # [1, S+1]
        input_ids = tokens_full[:, :real_seq_len]              # [1, S]
        S = input_ids.shape[1]

        # NOTE(review): unlike compute_sample_grads_qkvo, this forward pass is not
        # wrapped in model_ddp.no_sync() - confirm that is intentional.
        with torch.enable_grad():
            with autocast(dtype=DTYPE):
                _ = model_ddp(input_ids)

        if cache.result is None:
            try: model_ddp.module.reset_hooks(hooks_local)
            except Exception: pass
            raise RuntimeError("hook_result missing.")

        # Force FP32
        result_head = cache.result[:, :, HEAD_IDX, :].reshape(S, -1).float()
        W_U = model_ddp.module.W_U

        contribs = []

        # Key point: start from the second cycle.
        # The first CYCLE_LEN tokens appear for the first time, so there is no "past" to copy from.
        for t in range(CYCLE_LEN, S):

            # Use find_match_p to look up the match.
            # Presumably this is position t - CYCLE_LEN when MATCH_CHOICE == "last" and
            # no token repeats inside a cycle - not guaranteed in general.
            p = find_match_p(tokens_full[0], t)

            if p < 0 or (p + 1) >= (real_seq_len + 1):
                continue

            target_token = int(tokens_full[0, p+1].item())

            # Force the dot product to run in FP32
            u_col = W_U[:, target_token].float()
            contribs.append(torch.dot(result_head[t], u_col))

        if contribs:
            f = torch.stack(contribs).sum()
            attn = model_ddp.module.blocks[LAYER_IDX].attn
            grads = torch.autograd.grad(
                outputs=f,
                inputs=[attn.W_Q, attn.W_K, attn.W_V, attn.W_O],
                retain_graph=False, create_graph=False, allow_unused=False
            )
            # Force FP32 accumulation
            gQ = grads[0][HEAD_IDX].float()
            gK = grads[1][HEAD_IDX].float()
            gV = grads[2][HEAD_IDX].float()
            gO = grads[3][HEAD_IDX].float()
            v0_acc.add_(torch.cat([gQ, gK], dim=-1))
            v1_acc.add_(gV)
            v2_acc.add_(gO)

        cache.clear()

    # Best-effort hook removal; failures are swallowed silently.
    try: model_ddp.module.reset_hooks(hooks_local)
    except Exception:
        pass

    if dist.is_initialized():
        dist.all_reduce(v0_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(v1_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(v2_acc, op=dist.ReduceOp.SUM)

    return v0_acc, v1_acc, v2_acc


# Compute the probe vector v (one matrix per block: QK, V, O)
if PROBE_SOURCE == "synthetic":
    v_qk, v_v, v_o = compute_probe_copytarget_synthetic(model, qkvo_cache, PROBE_NUM_SEQS, SYN_ANCHORS_PER_SEQ)
else:
    v_qk, v_v, v_o = compute_probe_copytarget_dataset(model, qkvo_cache, PROBE_NUM_SEQS)

# p = H^{-1} v, computed once here and reused for every training sample in Stage 2
p_qk = ekfac.inverse_hvp(0, v_qk)
p_v  = ekfac.inverse_hvp(1, v_v)
p_o  = ekfac.inverse_hvp(2, v_o)

qkvo_cache.clear()
if DEVICE.startswith("cuda"):
    torch.cuda.empty_cache()

if IS_MAIN_PROCESS:
    print(f"[{RANK}] ✓ p 就绪：||p_qk||={torch.norm(p_qk, p='fro').item():.6f}, "
          f"||p_v||={torch.norm(p_v, p='fro').item():.6f}, ||p_o||={torch.norm(p_o, p='fro').item():.6f}；进入阶段2。")

# ================ Stage 2: directional influence score (score = −Σ_blocks <g, p>) ================
if IS_MAIN_PROCESS:
    print("\n" + "="*70)
    print(f"[{RANK}/{WORLD_SIZE}] 阶段2：方向影响分数（每样本 ∇_θℓ 到 W_Q/W_K/W_V/W_O）")
    print("="*70)

def compute_sample_grads_qkvo(npy_batch: torch.Tensor, model_ddp: DDP) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Compute the loss gradient w.r.t. one attention head's Q/K/V/O weights.

    The batch is split into inputs (columns ``[0, SEQ_LENGTH)``) and
    next-token labels (columns ``[1, SEQ_LENGTH + 1)``). The forward pass runs
    under autocast with ``DTYPE``; the cross-entropy is then evaluated on
    float32 logits and summed over every position in the batch.

    Args:
        npy_batch: Token tensor indexed as ``[batch, position]``; it must hold
            at least ``SEQ_LENGTH + 1`` positions for the labels to line up.
        model_ddp: DDP-wrapped model; the attention weights are read from
            ``model_ddp.module.blocks[LAYER_IDX].attn``.

    Returns:
        ``(g0, g1, g2, loss)`` where ``g0`` is the W_Q and W_K gradient slices
        at ``HEAD_IDX`` concatenated along the last dimension, ``g1`` is the
        W_V slice, ``g2`` is the W_O slice (all float32), and ``loss`` is the
        summed cross-entropy as a Python float.
    """
    tokens = npy_batch[:, :SEQ_LENGTH].to(DEVICE, non_blocking=True)
    targets = npy_batch[:, 1:SEQ_LENGTH + 1].to(DEVICE, non_blocking=True)

    with torch.enable_grad():
        # Forward inside DDP's no_sync() context; gradients are pulled
        # explicitly with autograd.grad below instead of via backward().
        with model_ddp.no_sync(), autocast(dtype=DTYPE):
            logits = model_ddp(tokens)

        # Loss is computed outside autocast on float32 logits.
        flat_logits = logits.reshape(-1, logits.shape[-1]).float()
        loss = F.cross_entropy(flat_logits, targets.reshape(-1), reduction='sum')

        head_attn = model_ddp.module.blocks[LAYER_IDX].attn
        # Index each weight gradient at HEAD_IDX along its first dimension
        # (presumably the head axis — confirm against the model's layout)
        # and promote to float32.
        grad_q, grad_k, grad_v, grad_o = (
            weight_grad[HEAD_IDX].float()
            for weight_grad in torch.autograd.grad(
                outputs=loss,
                inputs=[head_attn.W_Q, head_attn.W_K, head_attn.W_V, head_attn.W_O],
                retain_graph=False, create_graph=False, allow_unused=False,
            )
        )
        grad_qk = torch.cat([grad_q, grad_k], dim=-1)

    return grad_qk, grad_v, grad_o, float(loss.item())

# One heap entry: (ranking key, insertion counter, (sample index, summed loss)).
_HeapEntry = Tuple[float, int, Tuple[int, float]]


def _offer_to_bounded_heap(heap: List[_HeapEntry], entry: _HeapEntry, capacity: int) -> None:
    """Keep the ``capacity`` entries with the largest ranking key in a min-heap.

    While the heap is below capacity the entry is always pushed. Once full,
    the entry replaces the current minimum only when its key (``entry[0]``)
    is strictly greater than the minimum's key.
    """
    if len(heap) < capacity:
        heapq.heappush(heap, entry)
    elif entry[0] > heap[0][0]:
        heapq.heapreplace(heap, entry)


# Per-rank bounded heaps: one keyed by the score (most positive kept), one
# keyed by the negated score (most negative kept).
local_pos_heap: List[_HeapEntry] = []
local_neg_heap: List[_HeapEntry] = []
heap_cap = TOP_K_RESULTS
uid_counter = 0          # tie-breaker so equal scores never compare payloads first
processed_samples = 0

start_time_phase2 = time.time()
scoring_progress = tqdm(train_influence_dataloader, desc=f"[{RANK}] Scoring", disable=not IS_MAIN_PROCESS)
for sample_idx, (sample_batch_tokens, sample_indices) in enumerate(scoring_progress):
    processed_samples += 1

    g0, g1, g2, loss_val = compute_sample_grads_qkvo(sample_batch_tokens, model)

    # Inner products <g, p> for the QK, V and O blocks.
    dot_qk = (g0 * p_qk).sum().item()
    dot_v = (g1 * p_v).sum().item()
    dot_o = (g2 * p_o).sum().item()
    projection_score = -(dot_qk + dot_v + dot_o)

    payload_compact = (int(sample_indices.item()), float(loss_val))

    _offer_to_bounded_heap(local_pos_heap, (projection_score, uid_counter, payload_compact), heap_cap)
    _offer_to_bounded_heap(local_neg_heap, (-projection_score, uid_counter, payload_compact), heap_cap)

    uid_counter += 1

    # Periodically release cached CUDA memory.
    if DEVICE.startswith("cuda") and (processed_samples % EMPTY_CACHE_EVERY_SAMPLE == 0):
        torch.cuda.empty_cache()

phase2_time = time.time() - start_time_phase2

# ================ Aggregation and saving ================
if IS_MAIN_PROCESS:
    # Single print call; emits the same three lines (preceded by a blank line).
    print(
        "\n" + "="*70,
        f"[{RANK}/{WORLD_SIZE}] 阶段3：Top-K 聚合与保存",
        "="*70,
        sep="\n",
    )

# Collect the per-rank heaps (positive first, then negative).
pos_matrix, neg_matrix = (
    gather_heap_as_tensors(rank_heap, WORLD_SIZE, DEVICE)
    for rank_heap in (local_pos_heap, local_neg_heap)
)

# Number of samples this rank scored; in distributed runs the counts are
# summed onto rank 0, which is the only rank that reads the total afterwards.
total_samples_tensor = torch.tensor([processed_samples], device=DEVICE, dtype=torch.long)
if dist.is_initialized():
    dist.reduce(total_samples_tensor, dst=0, op=dist.ReduceOp.SUM)

def _take_top_rows(matrix):
    """Return up to TOP_K_RESULTS rows of ``matrix``, largest column-0 value first.

    A missing (``None``) or empty matrix yields an empty ``(0, 3)`` array.
    """
    if matrix is None or matrix.shape[0] <= 0:
        return np.zeros((0, 3))
    descending = np.argsort(matrix[:, 0])[::-1]
    return matrix[descending[:TOP_K_RESULTS]]


if RANK == 0:
    # Rows are read as (ranking key, sample index, summed loss).
    # NOTE(review): sample_index is recovered from a numeric matrix; if
    # gather_heap_as_tensors stores float32, indices above 2**24 would lose
    # precision — confirm against that helper.
    top_pos_data = _take_top_rows(pos_matrix)
    top_neg_data = _take_top_rows(neg_matrix)

    final_pos_topk = [
        {
            'projection_score': float(row[0]),
            'sample_index': int(row[1]),
            'sum_loss': float(row[2]),
        }
        for row in top_pos_data
    ]

    # The negative heap was keyed by the negated score, so flip the sign back
    # and list the most negative scores first.
    final_neg_topk = sorted(
        (
            {
                'projection_score': -float(row[0]),
                'sample_index': int(row[1]),
                'sum_loss': float(row[2]),
            }
            for row in top_neg_data
        ),
        key=lambda item: item['projection_score'],
    )

    total_samples = int(total_samples_tensor.item())
    if phase2_time > 0:
        throughput = total_samples / phase2_time
    else:
        throughput = 0.0

    print(f"✅ 阶段2完成。Time: {phase2_time:.2f}s")
    print(f"Global Throughput: {throughput:.2f} samples/s (Total {total_samples})")

    # CPU copies of the EK-FAC factors for the three parameter blocks.
    block_ids = range(3)
    ekfac_params = {
        'Q_A': [ekfac.Q_A[i].detach().cpu() for i in block_ids],
        'Q_S': [ekfac.Q_S[i].detach().cpu() for i in block_ids],
        'Lambda': [ekfac.Lambda[i].detach().cpu() for i in block_ids],
        'damping': ekfac.damping,
        'damping_alpha': ekfac.damping_alpha,
        'd_model': ekfac.d_model,
        'd_head': ekfac.d_head,
        'param_names': ['W_QK', 'W_V', 'W_O'],
        'block_dims': ekfac.block_dims,
    }

    run_config = {
        'MODEL_ALIAS': MODEL_ALIAS,
        'DATA_NPY_PATH': DATA_NPY_PATH,
        'TARGET': 'Copy-Target Directional Influence (QK+V+O EK-FAC)',
        'TOP_K': TOP_K_RESULTS,
        'NUM_TRAIN_SAMPLES': NUM_TRAIN_SAMPLES,
        'LAYER_IDX': LAYER_IDX,
        'HEAD_IDX': HEAD_IDX,
        'SEQ_LENGTH': SEQ_LENGTH,
        'WORLD_SIZE': WORLD_SIZE,
        'PRECISION': 'BF16_FP32',
        'DAMPING': DAMPING,
        'DAMPING_ALPHA': DAMPING_ALPHA,
        'PROBE_SOURCE': PROBE_SOURCE,
        'PROBE_NUM_SEQS': PROBE_NUM_SEQS,
        'INDUCTION_MATCH': INDUCTION_MATCH,
        'MATCH_CHOICE': MATCH_CHOICE,
        'SYN_ANCHORS_PER_SEQ': SYN_ANCHORS_PER_SEQ,
        'SYN_GAP_MIN': SYN_GAP_MIN,
        'SYN_GAP_MAX': SYN_GAP_MAX,
    }

    run_performance = {
        'phase1_time_total': ekfac_time_total,
        'phase2_time': phase2_time,
        'throughput_samples_per_s': throughput,
        'total_samples_scored': total_samples,
    }

    analysis_results = {
        'config': run_config,
        'performance': run_performance,
        'ekfac_params': ekfac_params,
        'positive_influencers': final_pos_topk,
        'negative_influencers': final_neg_topk,
    }

    with open(OUTPUT_RESULTS_FILE, 'wb') as results_file:
        pickle.dump(analysis_results, results_file)
    print(f"结果已保存到: {OUTPUT_RESULTS_FILE}")

# Tear down the process group on every rank that initialised one.
if dist.is_initialized():
    dist.destroy_process_group()

print(f"[{RANK}/{WORLD_SIZE}] ✅ 程序执行完毕。")