import os
import sys
import json
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

# === Force use of the local sparsify source tree ===
# The "sparsify" directory next to this script's directory is put at the front
# of sys.path so that it is searched first when `sparsify` is imported below.
SCRIPT_DIR = Path(__file__).resolve().parent
SPARSIFY_DIR = SCRIPT_DIR.parent / "sparsify"
sys.path.insert(0, str(SPARSIFY_DIR))

from sparsify import Sae

# === Configuration ===
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
HOOKPOINT = "layers.14"
SAE_REPO = "EleutherAI/sae-llama-3-8b-32x"

# Local directories all live under the parent of this script's directory.
_BASE_DIR = SCRIPT_DIR.parent
LOCAL_MODEL_PATH = _BASE_DIR / "models" / "llama3-8b"
SAE_SAVE_DIR = _BASE_DIR / "sae"
SAE_SAVE_DIR.mkdir(exist_ok=True)

# === Output paths ===
# Weights and metadata share one file stem; dots in the hookpoint name are
# replaced with underscores.
_SAE_FILE_STEM = "sae_llama3b_" + HOOKPOINT.replace(".", "_")
SAE_WEIGHTS_PATH = SAE_SAVE_DIR / f"{_SAE_FILE_STEM}.pth"
SAE_META_PATH = SAE_SAVE_DIR / f"{_SAE_FILE_STEM}.json"

# === Load tokenizer and model (from the local path) ===
# NOTE(review): neither `tokenizer` nor `model` is referenced again in the
# visible part of this script; loading the full model may be unnecessary for
# exporting the SAE — confirm before removing.
print("🔧 Loading local LLaMA tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_PATH)

# === Download the SAE into memory ===
# Fetches the SAE for the configured hookpoint from SAE_REPO and prints its
# config for inspection.
print(f"⬇️  Downloading SAE from {SAE_REPO} | hookpoint = {HOOKPOINT}")
sae = Sae.load_from_hub(SAE_REPO, hookpoint=HOOKPOINT)
print(sae.cfg)

# === Extract structural parameters ===
cfg = sae.cfg  # SparseCoderConfig object

# Dimensions and bias presence are read off the encoder layer itself.
d_in = sae.encoder.in_features
d_hidden = sae.encoder.out_features
k = cfg.k
bias = sae.encoder.bias is not None

# === Save weights ===
# Build the state dict once and reuse it for both the save and the key listing.
state_dict = sae.state_dict()
print(f"💾 Saving SAE weights to {SAE_WEIGHTS_PATH}")
torch.save(state_dict, SAE_WEIGHTS_PATH)


print(list(state_dict.keys()))

# === Save structural metadata ===
# activation / normalize_decoder / skip_connection are taken from the loaded
# config so the JSON describes the checkpoint that was actually downloaded.
# The getattr fallbacks are the values this script previously hard-coded, used
# only when the config object does not expose the field.
metadata = {
    "hookpoint": HOOKPOINT,
    "d_in": d_in,
    "num_latents": d_hidden,  # field name expected by consumers of this file
    "expansion_factor": d_hidden // d_in,  # optional but useful
    "k": k,
    "bias": bias,
    "activation": getattr(cfg, "activation", "topk"),
    "normalize_decoder": getattr(cfg, "normalize_decoder", True),
    "skip_connection": getattr(cfg, "skip_connection", False),
}

print(f"📝 Saving SAE metadata to {SAE_META_PATH}")
# Explicit encoding keeps the output independent of the platform default.
with open(SAE_META_PATH, "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2)

print("✅ SAE download & save complete.")

