#!/usr/bin/env python
# build_gemmascope_sae.py

"""
Merge selected latents from multiple ae.pt checkpoints (as listed in metadata.jsonl)
into a single GemmaScopeSAE.pt file for AxBench evaluation.

Output:
    <dump_dir>/train/GemmaScopeSAE.pt
"""

import json
import argparse
import pathlib
import sys
import torch

def parse_ref(ref: str):
    """
    Parse a reference string into (checkpoint_path, latent_id).

    Supported formats:
      • /abs/path/ae.pt/1234
      • file:///abs/path/ae.pt#1234

    The latent id is always taken from the text after the LAST separator
    ('#' for file:// references, '/' otherwise), so a checkpoint path may
    itself contain that separator character.

    Raises:
        ValueError: if the separator is missing or the latent id is not an
            integer.
    """
    scheme = "file://"
    if ref.startswith(scheme):
        # Split on the last '#': directory names such as "run#3" are legal.
        path, sep, latent = ref[len(scheme):].rpartition("#")
    else:
        path, sep, latent = ref.rpartition("/")

    if not sep:
        raise ValueError(f"Malformed latent reference (no latent id found): {ref!r}")
    return path, int(latent)

def load_cached(path, cache):
    """
    Return the object stored at `path`, deserializing it at most once.

    `cache` is a dict keyed by path that the caller owns; on a miss the file
    is read with `torch.load` (tensors mapped to CPU) and the result is
    stored in `cache` before being returned.
    """
    if path in cache:
        return cache[path]

    loaded = torch.load(path, map_location="cpu")
    cache[path] = loaded
    return loaded

def _is_latent_sized(t, state):
    n_latent = max(state["encoder.weight"].shape)
    return t.numel() == n_latent


def extract_vectors(state, latent_id):
    """
    Pull the parameters belonging to one latent out of an SAE state dict.

    Two key-naming schemes are recognised for the weights and encoder bias:
    "W_dec" / "W_enc" / "b_enc" and "decoder.weight" / "encoder.weight" /
    "encoder.bias".

    Returns a tuple (w_dec, w_enc, b_enc, thr):
      - w_dec: row `latent_id` of "W_dec", or column `latent_id` of
               "decoder.weight"
      - w_enc: column `latent_id` of "W_enc", or row `latent_id` of
               "encoder.weight"
      - b_enc: entry `latent_id` of the encoder bias, or a zero scalar when
               no encoder bias key exists
      - thr:   the "threshold" value (used as-is when it is 0-dim), otherwise
               entry `latent_id` of "gate_bias" or "bias" when that tensor is
               latent-sized, otherwise a zero scalar

    Raises:
        ValueError: if no decoder weight or no encoder weight key is present.
    """
    available = state.keys()

    # --- decoder direction ---
    if "W_dec" not in available and "decoder.weight" not in available:
        raise ValueError(f"Cannot find decoder weight among keys {available}")
    dec_vec = (
        state["W_dec"][latent_id]
        if "W_dec" in available
        else state["decoder.weight"][:, latent_id]
    )

    # --- encoder direction ---
    if "W_enc" not in available and "encoder.weight" not in available:
        raise ValueError(f"Cannot find encoder weight among keys {available}")
    enc_vec = (
        state["W_enc"][:, latent_id]
        if "W_enc" in available
        else state["encoder.weight"][latent_id]
    )

    # --- encoder bias: first matching key wins, default is zero ---
    for bias_key in ("b_enc", "encoder.bias"):
        if bias_key in available:
            enc_bias = state[bias_key][latent_id]
            break
    else:
        enc_bias = torch.tensor(0.)

    # --- activation threshold (or a per-latent bias standing in for it) ---
    if "threshold" in available:
        raw_thr = state["threshold"]
        # A 0-dim threshold is shared by all latents; otherwise pick ours.
        cutoff = raw_thr[latent_id] if raw_thr.dim() != 0 else raw_thr
    else:
        # 'gate_bias' is tried before 'bias'; either is only accepted when
        # it has one entry per latent.
        for thr_key in ("gate_bias", "bias"):
            if thr_key in available and _is_latent_sized(state[thr_key], state):
                cutoff = state[thr_key][latent_id]
                break
        else:
            cutoff = torch.tensor(0.)

    return dec_vec, enc_vec, enc_bias, cutoff

def _global_params(state, d_model):
    """Return (b_dec, extra_params): the checkpoint-wide tensors of `state`."""
    if "b_dec" in state:
        b_dec = state["b_dec"]
    elif "decoder_bias" in state:
        b_dec = state["decoder_bias"]
    else:
        # No decoder bias stored: fall back to zeros of the model dimension.
        b_dec = torch.zeros(d_model)

    extras = {}
    # gated-specific
    # NOTE(review): r_mag / mag_bias are copied whole, not indexed per
    # selected latent — confirm the consumer expects the full tensors.
    if "r_mag" in state and "mag_bias" in state:
        extras["r_mag"] = state["r_mag"].clone()
        extras["mag_bias"] = state["mag_bias"].clone()
    # top-k–specific
    if "k" in state:
        extras["k"] = state["k"].clone()

    return b_dec.clone(), extras


def main(metadata_dir, dump_dir):
    """
    Merge the latents referenced in <metadata_dir>/metadata.jsonl into
    <dump_dir>/train/GemmaScopeSAE.pt.

    Each non-blank line of metadata.jsonl must be a JSON object with a "ref"
    field understood by `parse_ref`. For every reference the corresponding
    decoder row, encoder column, encoder bias and threshold are extracted and
    stacked in file order. The decoder bias and the optional "r_mag",
    "mag_bias" and "k" entries are taken from the first referenced checkpoint
    only.

    The saved dict has keys "b_dec", "W_dec", "W_enc", "b_enc", "threshold"
    plus any of the optional entries above.

    Exits the process (via `sys.exit`) when metadata.jsonl is missing or
    yields no latents.
    """
    meta_path = pathlib.Path(metadata_dir) / "metadata.jsonl"
    if not meta_path.exists():
        sys.exit(f"metadata.jsonl not found in {metadata_dir}")

    W_dec_rows = []
    W_enc_cols = []
    b_encs = []
    thrs = []

    # Global (non-per-latent) parameters, captured from the first checkpoint.
    b_dec_saved = None
    extra_params = {}

    # path -> loaded checkpoint, so each file is read only once
    cache = {}

    with open(meta_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines, which json.loads would reject.
            if not line.strip():
                continue

            entry = json.loads(line)
            path, latent_id = parse_ref(entry["ref"])
            state = load_cached(path, cache)

            w_dec, w_enc, b_enc, thr = extract_vectors(state, latent_id)

            if b_dec_saved is None:
                # Size the zero fallback from the extracted decoder vector so
                # it is correct for both "W_dec" and "decoder.weight" layouts.
                b_dec_saved, extra_params = _global_params(state, w_dec.shape[0])

            W_dec_rows.append(w_dec.unsqueeze(0))   # (1, D)
            W_enc_cols.append(w_enc.unsqueeze(1))   # (D, 1)
            b_encs.append(b_enc.unsqueeze(0))       # (1,)
            thrs.append(thr.unsqueeze(0))           # (1,)

    if not W_dec_rows:
        sys.exit("No latents collected – please check metadata.jsonl references")

    sae_out = {
        "b_dec":     b_dec_saved,
        "W_dec":     torch.cat(W_dec_rows, dim=0),  # (N, D)
        "W_enc":     torch.cat(W_enc_cols, dim=1),  # (D, N)
        "b_enc":     torch.cat(b_encs, dim=0),      # (N,)
        "threshold": torch.cat(thrs, dim=0),        # (N,)
        **extra_params,
    }

    out_dir = pathlib.Path(dump_dir) / "train"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "GemmaScopeSAE.pt"
    torch.save(sae_out, out_path)
    print(f"✓ Saved merged GemmaScopeSAE.pt → {out_path}")

if __name__ == "__main__":
    # Command-line entry point: both directory flags are mandatory.
    parser = argparse.ArgumentParser(description="Build a GemmaScopeSAE.pt from metadata.jsonl")
    for flag, help_text in (
        ("--metadata_dir", "Directory containing metadata.jsonl with latent references"),
        ("--dump_dir", "Root dump directory (will create a 'train/' subfolder)"),
    ):
        parser.add_argument(flag, required=True, help=help_text)

    args = parser.parse_args()
    main(args.metadata_dir, args.dump_dir)
