"""
concept_load.py

Read a 2-column CSV where:
  - Column 1: concept description (may contain commas and quotes)
  - Column 2: a filesystem path whose tail contains the SAE layer and the feature id,
              e.g. .../resid_post_layer_12/trainer_6/ae.pt/5504

Produce a JSON mapping required by eval_sae_steering.py:
  {
    "<layer>_<feature>": "<concept description>",
    ...
  }

Default output path:
  /home/dslabra5/sae4steer/saes-are-good-for-steering/concept/concept_descriptions.json

Usage:
  python concept_load.py \
    --csv /home/dslabra5/sae4steer/axbench/axbench/concept10_gemma2_2b_data/batch_topk_80_0.8357.csv
"""

from __future__ import annotations
import argparse
import csv
import json
import os
import re
import sys
from typing import Dict, Tuple


# Tried in order; the first pattern that matches supplies the layer number.
LAYER_PATTERNS = [
    re.compile(r"resid_post_layer_(\d+)", re.IGNORECASE),
    re.compile(r"\blayer_(\d+)\b", re.IGNORECASE),
]
FEATURE_PATTERN = re.compile(r"(\d+)\s*$")  # last integer at the end of the path


def parse_layer_and_feature(path_str: str) -> Tuple[int, int]:
    """
    Extract (layer, feature) from the given path string.

    The layer comes from the first entry of ``LAYER_PATTERNS`` that matches
    (so ``resid_post_layer_<N>`` takes precedence over a bare ``layer_<N>``).
    The feature id is the integer at the very end of the path, after trailing
    whitespace and trailing path separators ("/" or "\\") are ignored.

    Example:
        ".../resid_post_layer_12/trainer_6/ae.pt/5504" -> (12, 5504)

    Raises:
        ValueError: if no layer can be found, if the path does not end in an
            integer, or if that trailing integer is the layer number itself
            (i.e. the path carries no separate feature id).
    """
    layer_match = None
    for pattern in LAYER_PATTERNS:
        layer_match = pattern.search(path_str)
        if layer_match:
            break
    if layer_match is None:
        raise ValueError(f"Could not find layer in path: {path_str}")
    layer = int(layer_match.group(1))

    # Only trailing characters are removed, so `tail` is a prefix of path_str
    # and match offsets in it are directly comparable with layer_match's.
    tail = path_str.rstrip().rstrip("/\\")
    feature_match = FEATURE_PATTERN.search(tail)
    if not feature_match:
        raise ValueError(f"Could not find feature id at end of path: {path_str}")
    # Without this check a path like ".../resid_post_layer_12" (no feature
    # component) would silently be parsed as feature == layer == 12.
    if feature_match.start(1) < layer_match.end(1):
        raise ValueError(f"Path has no feature id after the layer: {path_str}")
    feature = int(feature_match.group(1))

    return layer, feature


def normalize_description(text: str) -> str:
    """
    Return the description with surrounding whitespace removed.

    A missing value (``None`` or any other falsy input) becomes an empty
    string. Everything between the first and last non-whitespace character,
    including embedded quotes, is preserved verbatim; the csv reader has
    already undone any CSV quoting.
    """
    if not text:
        return ""
    return text.strip()


def read_csv_to_mapping(csv_path: str) -> Dict[str, str]:
    """
    Read the CSV and build the mapping {"<layer>_<feature>": "<description>"}.

    Every non-empty record must have at least two columns: the concept
    description (column 1) and a path encoding the layer and feature id
    (column 2). Any further columns are ignored. Blank lines are skipped.
    When two records resolve to the same key, the later one wins and a
    warning is printed to stderr.

    Raises:
        ValueError: if a record has fewer than two columns, or if its path
            cannot be parsed into (layer, feature).
    """
    mapping: Dict[str, str] = {}

    # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading byte
    # order mark (common in spreadsheet exports); with plain "utf-8" the BOM
    # would stay glued to the first description, since str.strip() keeps it.
    # newline="" is what the csv module requires for correct handling of
    # quoted fields that contain newlines.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        # Default dialect: delimiter=',' and quotechar='"'.
        reader = csv.reader(f)
        for i, row in enumerate(reader, start=1):
            if not row:
                continue
            if len(row) < 2:
                raise ValueError(f"Row {i} has fewer than 2 columns: {row}")

            desc = normalize_description(row[0])
            path_str = row[1].strip()

            try:
                layer, feature = parse_layer_and_feature(path_str)
            except ValueError as e:
                raise ValueError(f"Failed to parse layer/feature at CSV row {i}: {e}") from e

            key = f"{layer}_{feature}"
            if key in mapping:
                # If duplicates appear, last one wins but warn the user
                print(f"[WARN] Duplicate key '{key}' at row {i}; overwriting previous description.", file=sys.stderr)
            mapping[key] = desc

    return mapping


def write_json(mapping: Dict[str, str], out_path: str) -> None:
    """
    Write ``mapping`` to ``out_path`` as UTF-8 JSON indented by 2 spaces.

    Missing parent directories are created. Non-ASCII characters are written
    literally (``ensure_ascii=False``) instead of being escaped.
    """
    out_dir = os.path.dirname(out_path)
    # A bare filename has dirname "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path has one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(mapping, f, ensure_ascii=False, indent=2)


def main():
    """Command-line entry point: parse arguments, convert the CSV, write the JSON."""
    fallback_dir = "/home/dslabra5/sae4steer/saes-are-good-for-steering/concept"
    fallback_name = "concept_descriptions.json"

    parser = argparse.ArgumentParser(
        description="Convert concept CSV to concept_descriptions.json for eval_sae_steering.py"
    )
    parser.add_argument("--csv", required=True, help="Path to the input CSV file.")
    parser.add_argument(
        "--out_dir",
        default=fallback_dir,
        help=f"Output directory (default: {fallback_dir})",
    )
    parser.add_argument(
        "--out_name",
        default=fallback_name,
        help=f"Output file name (default: {fallback_name})",
    )
    opts = parser.parse_args()

    concepts = read_csv_to_mapping(opts.csv)
    destination = os.path.join(opts.out_dir, opts.out_name)
    write_json(concepts, destination)

    print(f"[OK] Wrote {len(concepts)} concepts to: {destination}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
