import argparse
import json
import os
from typing import Any, Dict, List, Optional

import torch

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.graph import Graph, compute_node_influence


def _load_graph(path: str, device: str) -> Graph:
    """Load a saved ``Graph`` from ``path`` and place it on ``device``.

    Args:
        path: Location of the ``.pt`` file handed to ``Graph.from_pt``.
        device: Device string used both as ``map_location`` while loading and
            as the target of the follow-up ``Graph.to`` call.

    Returns:
        The loaded graph object.
    """
    loaded = Graph.from_pt(path, map_location=device)
    # The result of ``to`` is discarded, so this relies on it updating the
    # graph in place -- NOTE(review): confirm against circuit_tracer.
    loaded.to(device)
    return loaded


def _build_graph_from_prompt(
    *,
    prompt: str,
    model_name: str,
    transcoder_name: str,
    dtype: str,
    device: str,
    max_n_logits: int,
    desired_logit_prob: float,
    max_feature_nodes: int,
    batch_size: int,
    offload: Optional[str],
    verbose: bool,
) -> Graph:
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype not in dtype_map:
        raise ValueError(f"Unsupported dtype={dtype}. Choose from {sorted(dtype_map.keys())}")

    model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=dtype_map[dtype])
    graph = attribute(
        prompt=prompt,
        model=model,
        max_n_logits=max_n_logits,
        desired_logit_prob=desired_logit_prob,
        batch_size=batch_size,
        max_feature_nodes=max_feature_nodes,
        offload=offload,
        verbose=verbose,
    )
    graph.to(device)
    return graph


def _extract_top_feature_nodes(
    graph: Graph,
    *,
    k: int,
    device: str,
) -> List[Dict[str, Any]]:
    n_logits = int(len(graph.logit_tokens))
    n_features = int(len(graph.selected_features))

    logit_weights = torch.zeros(graph.adjacency_matrix.shape[0], device=device)
    logit_weights[-n_logits:] = graph.logit_probabilities.to(device)

    node_influence = compute_node_influence(graph.adjacency_matrix.to(device), logit_weights)
    feature_scores = node_influence[:n_features]

    top_scores, top_node_idx = torch.topk(feature_scores, k=min(int(k), n_features))

    rows: List[Dict[str, Any]] = []
    for score, node_idx in zip(top_scores.tolist(), top_node_idx.tolist()):
        layer, pos, feat_idx = graph.active_features[graph.selected_features[node_idx]].tolist()
        activation = float(graph.activation_values[graph.selected_features[node_idx]].item())
        ctx_tok = None
        try:
            ctx_tok = int(graph.input_tokens[int(pos)])
        except Exception:
            ctx_tok = None

        rows.append(
            {
                "node_idx": int(node_idx),
                "score": float(score),
                "layer": int(layer),
                "pos": int(pos),
                "feature_idx": int(feat_idx),
                "activation": float(activation),
                "ctx_token_id": None if ctx_tok is None else int(ctx_tok),
            }
        )

    rows.sort(key=lambda d: d["score"], reverse=True)
    return rows


def _layer_feature_mapping(top_features: List[Dict[str, Any]]) -> Dict[str, List[int]]:
    by_layer: Dict[str, List[int]] = {}
    for r in top_features:
        layer = str(int(r["layer"]))
        by_layer.setdefault(layer, []).append(int(r["feature_idx"]))
    for layer in list(by_layer.keys()):
        by_layer[layer] = sorted(list(set(by_layer[layer])))
    return by_layer


def main() -> None:
    """Command-line entry point: rank a graph's feature nodes and write JSON reports.

    The graph comes from exactly one of ``--graph`` (load a saved ``.pt``) or
    ``--prompt`` (build it by attribution). Output goes to ``--outdir``, which
    defaults to a ``top_features`` directory next to the graph file or, in
    prompt mode, inside the current working directory.

    Files written:
        * ``top{k}_features.json`` -- layer -> sorted unique feature indices.
        * ``top_features_meta.json`` -- source, ``k``, sizes and up to ten
          logit tokens and probabilities.
        * ``top{k}_features_detailed.json`` -- every row, only with ``--dump-all``.

    Up to 50 rows are also printed. Context tokens are decoded only in prompt
    mode; in graph mode ``ctx_token`` is always ``None``.
    """
    import sys  # function-scope import: only needed for warnings on stderr

    ap = argparse.ArgumentParser()

    src = ap.add_mutually_exclusive_group(required=True)
    src.add_argument("--graph", default=None, help="Load an existing circuit-tracer Graph .pt")
    src.add_argument("--prompt", default=None, help="Prompt to attribute and build a Graph")

    ap.add_argument("--k", type=int, default=50, help="How many top features to output")
    ap.add_argument("--device", default="cpu", help="cpu/cuda")
    ap.add_argument("--outdir", default=None, help="Output directory")

    ap.add_argument("--dump-all", action="store_true", help="Dump detailed JSON output")

    ap.add_argument("--model-name", default="google/gemma-2-2b")
    ap.add_argument("--transcoder-name", default="gemma")
    ap.add_argument("--dtype", default="bfloat16", help="float32/float16/bfloat16")

    ap.add_argument("--max-n-logits", type=int, default=10)
    ap.add_argument("--desired-logit-prob", type=float, default=0.95)
    ap.add_argument("--max-feature-nodes", type=int, default=8192)
    ap.add_argument("--batch-size", type=int, default=256)
    ap.add_argument("--offload", default=None, help="disk/cpu/None")
    ap.add_argument("--verbose", action="store_true")

    args = ap.parse_args()

    top_k = int(args.k)
    if top_k < 0:
        # Reject this before any model loading or attribution is attempted;
        # a negative k cannot be satisfied by the top-k selection later on.
        ap.error("--k must be non-negative")

    device = str(args.device)

    graph: Graph
    graph_path: Optional[str] = None
    if args.graph is not None:
        graph_path = os.path.abspath(str(args.graph))
        graph = _load_graph(graph_path, device=device)
        default_outdir = os.path.join(os.path.dirname(graph_path), "top_features")
    else:
        graph = _build_graph_from_prompt(
            prompt=str(args.prompt),
            model_name=str(args.model_name),
            transcoder_name=str(args.transcoder_name),
            dtype=str(args.dtype),
            device=device,
            max_n_logits=int(args.max_n_logits),
            desired_logit_prob=float(args.desired_logit_prob),
            max_feature_nodes=int(args.max_feature_nodes),
            batch_size=int(args.batch_size),
            # The literal strings "None" and "" on the command line mean "no offload".
            offload=None if args.offload in (None, "None", "") else str(args.offload),
            verbose=bool(args.verbose),
        )
        default_outdir = os.path.join(os.getcwd(), "top_features")

    outdir = args.outdir or default_outdir
    os.makedirs(outdir, exist_ok=True)

    # A model is loaded (again) only to get a tokenizer for decoding context tokens.
    model_for_decode = None
    if args.prompt is not None:
        # Follow --dtype rather than a fixed dtype; fall back to bfloat16 if
        # the name does not resolve to a torch dtype.
        decode_dtype = getattr(torch, str(args.dtype), None)
        if not isinstance(decode_dtype, torch.dtype):
            decode_dtype = torch.bfloat16
        try:
            model_for_decode = ReplacementModel.from_pretrained(
                str(args.model_name), str(args.transcoder_name), dtype=decode_dtype
            )
        except Exception as exc:
            # Decoding stays best-effort, but say why ctx_token will be null.
            print(f"Warning: could not load model for token decoding: {exc}", file=sys.stderr)
            model_for_decode = None

    top = _extract_top_feature_nodes(graph, k=top_k, device=device)

    for r in top:
        tid = r.get("ctx_token_id")
        if model_for_decode is None or tid is None:
            r["ctx_token"] = None
            continue
        try:
            r["ctx_token"] = str(model_for_decode.tokenizer.decode([int(tid)]))
        except Exception:
            # A single undecodable token must not abort the report.
            r["ctx_token"] = None

    mapping = _layer_feature_mapping(top)

    out_features_json = os.path.join(outdir, f"top{top_k}_features.json")
    with open(out_features_json, "w", encoding="utf-8") as f:
        json.dump(mapping, f, ensure_ascii=False, indent=2)

    out_meta = {
        "source": {"graph": graph_path, "prompt": args.prompt},
        "k": top_k,
        "n_pos": int(getattr(graph, "n_pos", len(graph.input_tokens))),
        "n_logits": int(len(graph.logit_tokens)),
        "scan": getattr(graph, "scan", None),
        # Only the first ten logits are recorded to keep the metadata small.
        "logit_tokens": [int(x) for x in graph.logit_tokens[:10]],
        "logit_probabilities": [float(x) for x in graph.logit_probabilities[:10]],
    }
    out_meta_json = os.path.join(outdir, "top_features_meta.json")
    with open(out_meta_json, "w", encoding="utf-8") as f:
        json.dump(out_meta, f, ensure_ascii=False, indent=2)

    print(f"Saved: {out_features_json}")
    print(f"Saved: {out_meta_json}")
    print("Top features (layer,pos,feature_idx,score,activation,ctx_token):")
    for r in top[:50]:
        ctx = r.get("ctx_token")
        print(
            f"  L{r['layer']:>2} P{r['pos']:>3} F{r['feature_idx']:>6}"
            f"  score={r['score']:.6g}  act={r['activation']:.6g}  ctx={repr(ctx)}"
        )

    if bool(args.dump_all):
        out_all = os.path.join(outdir, f"top{top_k}_features_detailed.json")
        with open(out_all, "w", encoding="utf-8") as f:
            json.dump(top, f, ensure_ascii=False, indent=2)
        print(f"Saved: {out_all}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
