#!/usr/bin/env python3
"""
Convert concept description JSONs like:
{
  "12_4157": "desc",
  "12_5996": "desc",
  ...
}

to feature lists grouped by layer:
{
  "12": [4157, 5996, ...]
}

Input dir : SRC_DIR    (/home/dslabra5/sae4steer/saes-are-good-for-steering/concept/qwen2.5-3b)
Output dir: DST_DIR    (/home/dslabra5/sae4steer/saes-are-good-for-steering/data/features)
Output name: <name>_concept_descriptions.json -> <MODEL_PREFIX>_<name>_features.json
"""

from pathlib import Path
import json
import re

# ---- Paths (edit if needed) ----
# Directory whose "*.json" files are converted (searched non-recursively).
SRC_DIR = Path("/home/dslabra5/sae4steer/saes-are-good-for-steering/concept/qwen2.5-3b")
# Directory the converted files are written to; created at startup if missing.
DST_DIR = Path("/home/dslabra5/sae4steer/saes-are-good-for-steering/data/features")
MODEL_PREFIX = "qwen2.5_3b"  # final filenames start with this

# Regex for keys like "12_4157": group 1 is the layer, group 2 the feature id.
KEY_RE = re.compile(r"^(\d+)_(\d+)$")

def convert_one_file(src_path: Path) -> dict:
    """Group the "<layer>_<feature>" keys of a JSON file by layer.

    Args:
        src_path: Path to a UTF-8 JSON file, expected to hold an object whose
            keys look like "12_4157". The values are not inspected.

    Returns:
        dict[layer_str] -> sorted list of unique feature ids (ints). Layers
        appear in the order they are first seen in the file. Keys that do not
        match KEY_RE are ignored. An empty dict is returned when nothing
        matches or when the top-level JSON value is not an object.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with src_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    # A top-level list/string/number has no keys to parse. Treat it like a
    # file with no valid keys instead of failing on a missing .keys().
    if not isinstance(data, dict):
        return {}

    by_layer = {}  # str -> set[int]; the set drops duplicate feature ids
    for key in data:
        m = KEY_RE.match(key)
        if not m:
            # Ignore keys not matching "<layer>_<feature>"
            continue
        layer_str, feat_str = m.groups()
        by_layer.setdefault(layer_str, set()).add(int(feat_str))

    # Convert to sorted lists
    return {layer: sorted(feats) for layer, feats in by_layer.items()}

def build_output_name(src_name: str, prefix: "str | None" = None) -> str:
    """
    Build the output file name for a source concept-description file.

    The "_concept_descriptions.json" suffix is removed from src_name. For any
    other name only a trailing ".json" is removed, so the extension does not
    end up in the middle of the result. The remaining stem is wrapped as
    "<prefix>_<stem>_features.json".

    Examples (prefix "qwen2.5_3b"):
      batch_topk_50_concept_descriptions.json
        -> qwen2.5_3b_batch_topk_50_features.json
      extra.json
        -> qwen2.5_3b_extra_features.json

    Args:
        src_name: File name (not a full path) of the source JSON.
        prefix: Leading part of the output name; MODEL_PREFIX when omitted.

    Returns:
        The output file name.
    """
    if prefix is None:
        prefix = MODEL_PREFIX

    # Try the most specific suffix first; fall back to the bare extension.
    for suffix in ("_concept_descriptions.json", ".json"):
        if src_name.endswith(suffix):
            stem = src_name[: -len(suffix)]
            break
    else:
        stem = src_name
    return f"{prefix}_{stem}_features.json"

def main():
    """Convert every "*.json" file in SRC_DIR and write the result to DST_DIR.

    DST_DIR is created if needed. Each source file is converted with
    convert_one_file and written as indented JSON under the name given by
    build_output_name. A one-line summary is printed per file. A file with no
    valid keys still produces an output file (an empty dict), with a warning.
    """
    DST_DIR.mkdir(parents=True, exist_ok=True)

    src_files = sorted(SRC_DIR.glob("*.json"))
    if not src_files:
        print(f"No JSON files found in {SRC_DIR}")
        return

    for src in src_files:
        out_dict = convert_one_file(src)

        # If nothing parsed, still write an empty dict (and warn)
        if not out_dict:
            print(f"[warn] No valid '<layer>_<feature>' keys in {src.name}; writing empty dict.")

        dst = DST_DIR / build_output_name(src.name)

        with dst.open("w", encoding="utf-8") as f:
            json.dump(out_dict, f, ensure_ascii=False, indent=2)

        # Short summary. Layer ids are digit strings, so order them by numeric
        # value; a plain string sort would print "10, 12, 2".
        layers = ", ".join(sorted(out_dict, key=int)) if out_dict else "-"
        counts = sum(len(v) for v in out_dict.values())
        print(f"[ok] {src.name} -> {dst.name}  (layers: {layers}, total features: {counts})")

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
