import gc
import argparse
import os
import os.path
import re  # MOD: added for parsing trainer id and name parts

from tqdm import tqdm
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import (
    get_features_by_layers,
    get_sae,
    cache_logit_lens,
    _swap_layer_number_in_path  # MOD: already imported to resolve per-layer SAE path for printing and cache auto path
)
from sae_utils import AmlifySAEHook


def get_output_score(
    layer,
    feature,
    logit_lens_indices,
    sentence,
    sae,
    tokenizer,
    model,
    device,
    amp_factor=10,
):
    """Score how strongly intervening on one SAE feature promotes its logit-lens tokens.

    An ``AmlifySAEHook`` built from ``(layer, sae, [feature], amp_factor, device)``
    is registered as a forward hook on ``model.model.layers[layer]``, the model is
    run once on ``sentence``, and the next-token distribution at the last position
    is inspected.

    Args:
        layer: Index into ``model.model.layers`` of the block to hook.
        feature: SAE feature index handed to the hook.
        logit_lens_indices: Non-empty sequence of vocabulary ids to look for in
            the post-intervention distribution.
        sentence: Prompt text fed to the tokenizer.
        sae: SAE module passed to the hook; moved to ``device`` for the call.
        tokenizer: Callable returning model inputs for ``sentence``.
        model: Causal LM exposing ``model.model.layers`` and ``.logits`` outputs.
        device: Device used for the forward pass.
        amp_factor: Amplification factor forwarded to the hook.

    Returns:
        ``(1 - best_rank / vocab_size) * best_prob`` as a Python float, where
        ``best_rank`` is the smallest rank (0 = most likely) and ``best_prob`` the
        largest probability among ``logit_lens_indices``.

    Raises:
        ValueError: If ``logit_lens_indices`` is empty.

    Side effects:
        ``model`` and ``sae`` are moved to ``device`` and back to CPU before
        returning (``Module.to`` is in place, so the caller's objects move too).
    """
    # Fail before the expensive forward pass instead of inside min()/max().
    if len(logit_lens_indices) == 0:
        raise ValueError("logit_lens_indices must contain at least one token id")

    model = model.to(device)
    sae = sae.to(device)
    sae_hook = AmlifySAEHook(layer, sae, [feature], amp_factor, device)
    hooked_block = model.model.layers[layer]
    handle = hooked_block.register_forward_hook(sae_hook, always_call=True)

    try:
        encoded = tokenizer(sentence, return_tensors="pt")
        inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
        # Only the logits are needed; skip building an autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits
    finally:
        # The handle is the supported way to detach the hook; doing it in
        # ``finally`` guarantees a failed forward pass cannot leave the
        # intervention attached to the model.
        handle.remove()

    intervention_probs = torch.softmax(logits[:, -1].squeeze(), dim=0).detach().cpu()
    # Drop the device-side tensors before asking CUDA to release cached memory.
    del logits, inputs

    vocab_size = intervention_probs.shape[0]
    tokens_argsort = torch.argsort(intervention_probs, dim=0, descending=True)
    # Invert the permutation: token_ranks[token_id] is that token's position in
    # the descending-probability order.
    token_ranks = torch.empty_like(tokens_argsort)
    token_ranks[tokens_argsort] = torch.arange(vocab_size, dtype=tokens_argsort.dtype)

    wanted = torch.as_tensor(
        logit_lens_indices, dtype=torch.long, device=intervention_probs.device
    )
    best_rank = token_ranks[wanted].min().item()
    top_token_score = intervention_probs[wanted].max().item()
    rank_output_score = 1 - (best_rank / vocab_size)

    model = model.cpu()
    sae = sae.cpu()
    torch.cuda.empty_cache()
    gc.collect()

    return rank_output_score * top_token_score


def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_type``,
        ``logit_lens_top_k``, ``features_file``, ``device``, ``cache_path``
        and ``dl_local_dir``. ``features_file`` and ``dl_local_dir`` are
        required; ``device`` and ``cache_path`` default to ``None``.
    """
    cli = argparse.ArgumentParser()

    # Model / scoring configuration.
    cli.add_argument("--model_type", default="gemma2_9b", type=str)
    cli.add_argument("--logit_lens_top_k", default=20, type=int)
    cli.add_argument("--features_file", required=True, type=str)

    device_help = "Device: cpu | cuda | cuda:0 | cuda:1, etc.; cuda:0 is automatically selected by default (if available)"
    cli.add_argument("--device", default=None, type=str, help=device_help)

    # Optional: when left out, a per-layer cache location is derived automatically.
    cache_help = "Optional. If omitted, an auto path under .../saes-are-good-for-steering/cache will be generated per layer."
    cli.add_argument(
        "--cache_path",
        default=None,
        required=False,
        type=str,
        help=cache_help,
    )

    # A single local directory holding a dictionary-learning SAE
    # (config.json + ae.pt).
    dl_dir_help = (
        "Path to a local dictionary-learning SAE folder, e.g. "
        ".../trained_saes__google_gemma-2-9b_gated_top_k/resid_post_layer_20/trainer_1. "
        "The code will auto-swap the layer number when needed."
    )
    cli.add_argument("--dl_local_dir", required=True, type=str, help=dl_dir_help)

    return cli.parse_args()


# MOD: helper to find repo root ".../saes-are-good-for-steering" from any absolute features_file path
def _detect_repo_root_from_features(features_file_path: str) -> str:
    """Return the path prefix ending at the ``saes-are-good-for-steering`` folder.

    The path is normalised and split on ``os.sep``; everything up to and
    including the first component named ``saes-are-good-for-steering`` is
    returned. When no component has that name, the current working directory
    is returned instead.
    """
    repo_dir_name = "saes-are-good-for-steering"
    components = os.path.normpath(features_file_path).split(os.sep)
    try:
        cut = components.index(repo_dir_name)
    except ValueError:
        # Repo folder is not part of the path: fall back to the working directory.
        return os.getcwd()
    return os.sep.join(components[: cut + 1])


def _method_tag_from_trainer_class(trainer_class_name: str) -> str:
    """Map a trainer class name to a short, filesystem-friendly method tag.

    Matching is case-insensitive and substring based; the first rule that
    applies wins, in this order: ``gated``, ``batch_topk``, ``topk``,
    ``jump_relu``, ``standard_april_update``. Names matching no rule are
    lower-cased, every run of characters outside ``[a-z0-9_]`` is replaced by
    ``_``, surrounding underscores are stripped and the result is cut to 32
    characters. An empty or ``None`` name yields ``"unknown"``.
    """
    lowered = trainer_class_name.lower() if trainer_class_name else ""

    def has(*needles):
        # True when every needle occurs somewhere in the lower-cased name.
        return all(needle in lowered for needle in needles)

    # (matched?, tag) pairs, listed in priority order.
    rules = (
        (has("gated"), "gated"),
        (has("batchtopk") or has("batch_topk"), "batch_topk"),
        (has("topk") and not has("batch"), "topk"),
        (has("jumprelu") or has("jump", "relu"), "jump_relu"),
        (has("standard", "april", "update"), "standard_april_update"),
    )
    for matched, tag in rules:
        if matched:
            return tag

    slug = re.sub(r"[^a-z0-9_]+", "_", lowered).strip("_")
    return slug[:32] or "unknown"


def _trainer_id_from_path(path: str) -> str:
    """Extract a numeric trainer id from an SAE directory path.

    Two patterns are tried in order and the LAST match of the first pattern
    that matches anywhere is returned:

    1. a path component that starts with a known method keyword (``trainer``,
       ``gated``, ``batch_topk``, ``topk``, ``jump_relu``,
       ``standard_april_update``; ``_``/``-`` variants accepted) followed by
       an optional single separator and then only digits, e.g. ``trainer_1``;
    2. any ``_<digits>`` or ``-<digits>`` that ends a path component.

    Args:
        path: Directory path; ``/`` and ``\\`` are both treated as separators.

    Returns:
        The id as a string of digits, or ``"unk"`` when nothing matches or
        ``path`` is empty/``None``.
    """
    if not path:
        return "unk"

    # Treat Windows separators like "/" so the component anchors below apply.
    normalized = path.replace("\\", "/")

    # Component boundaries are lookarounds rather than consumed "/" characters:
    # consuming the trailing "/" would hide an immediately following component
    # (".../topk_5/trainer_1" would yield "5" instead of the last id, "1").
    patterns = (
        r"(?:^|(?<=/))(?:trainer|gated|batch[_-]?topk|top[_-]?k|jump[_-]?relu|jumprelu|standard[_-]?april[_-]?update)[_/ -]?(\d+)(?=/|$)",
        r"[_-](\d+)(?=/|$)",
    )
    for pattern in patterns:
        ids = re.findall(pattern, normalized, flags=re.IGNORECASE)
        if ids:
            return ids[-1]
    return "unk"


def _build_experiment_name(trainer_class: str, trainer_id: str) -> str:
    """Combine the method tag of ``trainer_class`` with ``trainer_id``.

    Returns ``"{method}_{trainer_id}"``, or ``trainer_id`` unchanged when it
    already begins with ``"{method}_"``.
    """
    prefix = _method_tag_from_trainer_class(trainer_class) + "_"
    if trainer_id.startswith(prefix):
        return trainer_id
    return f"{prefix}{trainer_id}"


# MOD: decide cache path for a given layer (used when --cache_path is omitted)
def _auto_cache_path(
    repo_root: str,
    model_type: str,
    layer: int,
    trainer_class: str,
    trainer_id: str,
    topk_for_filename: int,
) -> str:
    """Return the per-layer score cache file path, creating its directory.

    The path has the form::

        {repo_root}/cache/results_outputscore/{model_type}/layer{layer}/{experiment}/output_scores_top{K}.json

    where ``{experiment}`` is ``_build_experiment_name(trainer_class, trainer_id)``
    and ``{K}`` is ``int(topk_for_filename)``. The directory is created if it
    does not exist; the file itself is not touched.
    """
    layer_component = f"layer{int(layer)}"
    experiment_component = _build_experiment_name(trainer_class, trainer_id)
    target_dir = os.path.join(
        repo_root,
        "cache",
        "results_outputscore",
        model_type,
        layer_component,
        experiment_component,
    )
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, f"output_scores_top{int(topk_for_filename)}.json")


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``, if ``path`` names one."""
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _load_score_cache(path):
    """Load a score cache JSON file; return an empty dict when it does not exist.

    Legacy entries stored as a bare list are migrated to ``{"tokens": <list>}``.
    """
    if not os.path.exists(path):
        return {}
    # The cache is written with ensure_ascii=False, so read it back as UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        scores = json.load(f)
    for key, value in list(scores.items()):
        if isinstance(value, list):
            scores[key] = {"tokens": value}
    return scores


def _write_score_cache(path, scores):
    """Write ``scores`` to ``path`` as indented UTF-8 JSON."""
    _ensure_parent_dir(path)
    # Explicit UTF-8: tokens may be non-ASCII and ensure_ascii=False writes them
    # verbatim, which fails under a non-UTF-8 default locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(scores, f, ensure_ascii=False, indent=2)


def _shape_or_none(tensor_like):
    """Return ``tuple(tensor_like.shape)``, or ``None`` if there is no usable shape."""
    try:
        return tuple(tensor_like.shape)
    except (AttributeError, TypeError):
        return None


def _log_sae_summary(sae, layer, resolved_dir, device):
    """Print the attributes of the SAE loaded for ``layer`` (missing ones show as None)."""
    W_dec = getattr(sae, "W_dec", None)
    W_enc = getattr(sae, "W_enc", None)
    sae_device = getattr(W_dec, "device", torch.device(device))
    dec_shape = _shape_or_none(W_dec)

    print(f"\n[SAE-LOAD] layer={layer}  path='{resolved_dir}'")
    print(f"  trainer_class='{getattr(sae, 'trainer_class_name', '')}'  device={sae_device}")
    print(f"  W_dec={dec_shape}  W_enc={_shape_or_none(W_enc)} ")
    # Only a 2-D decoder weight can be reported as (features, d_model).
    if dec_shape is not None and len(dec_shape) == 2:
        F, D = dec_shape
        print(f"  dims: F={F} (features), D={D} (d_model)")

    print(
        f"  biases: b_dec={_shape_or_none(getattr(sae, 'b_dec', None))}"
        f"  b_enc={_shape_or_none(getattr(sae, 'b_enc', None))}"
    )
    print(
        f"  thresholds: scalar={getattr(sae, 'threshold_scalar', None)}"
        f"  vector={_shape_or_none(getattr(sae, 'threshold_vector', None))}"
    )
    print(f"  topk: k={getattr(sae, 'k_topk', None)}")
    print(
        f"  gated params: gate_bias={_shape_or_none(getattr(sae, 'gate_bias', None))}"
        f"  r_mag={_shape_or_none(getattr(sae, 'r_mag', None))}"
        f"  mag_bias={_shape_or_none(getattr(sae, 'mag_bias', None))}"
    )
    print(f"  will hook: model.model.layers[{layer}]\n")


def main():
    """Compute and cache an output score for every (layer, feature) in the features file.

    For each layer returned by ``get_features_by_layers`` the local SAE is
    loaded, the logit-lens top-k table is built with ``cache_logit_lens``, and
    each feature is scored with ``get_output_score`` on a fixed neutral prompt.
    Results are stored under the key ``"{layer}_{feature}"`` as
    ``{"score": float, "tokens": [str, ...]}``.

    Cache handling:
        * ``--cache_path`` given: one combined JSON file shared by all layers.
        * ``--cache_path`` omitted (or empty): one JSON file per layer at the
          location returned by ``_auto_cache_path``; a layer whose features are
          all scored already is skipped entirely.
        Features that already have a ``"score"`` are never recomputed. The cache
        file is written when a layer finishes and also when the layer is
        interrupted by an exception, so completed features are not lost.

    Raises:
        ValueError: If ``--model_type`` is not one of the supported names.
    """
    args = parse_args()
    model_type = args.model_type
    logit_lens_k = args.logit_lens_top_k
    if args.device is not None:
        device = args.device
    else:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(args)

    model_names = {
        "gemma2_2b": "google/gemma-2-2b",
        "gemma2_9b": "google/gemma-2-9b",
        "gemma2_9b_it_131": "google/gemma-2-9b-it",
    }
    if model_type not in model_names:
        raise ValueError(f"Model type not supported {model_type}")
    model_name = model_names[model_type]

    features_by_layers = get_features_by_layers(args.features_file)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    final_layer_norm = model.model.norm
    lm_head = model.lm_head
    print(final_layer_norm.weight.shape, lm_head.weight.shape)

    saes = dict()
    neutral_sentence = "From my experience,"

    # Root used for auto-generated per-layer cache paths.
    repo_root = _detect_repo_root_from_features(args.features_file)

    # Treat an empty --cache_path like an omitted one so every check below
    # agrees on which cache mode is active.
    explicit_cache_path = args.cache_path or None

    # Legacy mode: a single combined cache file, loaded once up front.
    output_scores_combined = None
    if explicit_cache_path is not None:
        _ensure_parent_dir(explicit_cache_path)
        output_scores_combined = _load_score_cache(explicit_cache_path)

    for layer in features_by_layers:
        features = features_by_layers[layer]

        # Load the SAE for this layer; the util adjusts the layer in the path.
        sae = get_sae(
            model_type=model_type,
            layer=layer,
            saes=saes,
            backend="dl_local",
            dl_local_dir=args.dl_local_dir,
            device=device,
        )

        # Best effort: the resolved path is only used for logging and for
        # deriving the trainer id, so fall back to the base dir on any failure.
        try:
            resolved_dir = _swap_layer_number_in_path(args.dl_local_dir, layer)
        except Exception:
            resolved_dir = args.dl_local_dir

        _log_sae_summary(sae, layer, resolved_dir, device)

        if explicit_cache_path is None:
            cache_path_this_layer = _auto_cache_path(
                repo_root=repo_root,
                model_type=model_type,
                layer=layer,
                trainer_class=getattr(sae, "trainer_class_name", "") or "",
                trainer_id=_trainer_id_from_path(resolved_dir),
                topk_for_filename=logit_lens_k,
            )
            output_scores = _load_score_cache(cache_path_this_layer)

            # Skip the whole layer when every requested feature has a score.
            needed_keys = {f"{layer}_{int(f)}" for f in features}
            if all(
                isinstance(output_scores.get(k), dict) and ("score" in output_scores[k])
                for k in needed_keys
            ):
                print(f"[SKIP LAYER] layer {layer} already done at {cache_path_this_layer} "
                    f"(features={len(features)}). Skipping this layer.")
                continue
        else:
            cache_path_this_layer = explicit_cache_path
            output_scores = output_scores_combined

        # Build the logit-lens table for this layer with the model on the CPU.
        model = model.cpu()
        logit_lens_topk, _logit_lens_confidence, _logit_lens_raw_logits = cache_logit_lens(
            layer=layer,
            saes=saes,
            model_type=model_type,
            final_layer_norm=final_layer_norm,
            lm_head=lm_head,
            k=logit_lens_k,
            backend="dl_local",
            dl_local_dir=args.dl_local_dir,
            device=device,
        )

        model = model.to(device)
        try:
            for feature in tqdm(features):
                # Normalise first so the key matches the skip-layer check above.
                feature = int(feature)
                layer_feature_key = f"{layer}_{feature}"

                existing = output_scores.get(layer_feature_key)
                if isinstance(existing, dict) and ("score" in existing):
                    print(f"[SKIP] already scored {layer_feature_key}")
                    continue

                # Per-feature top-k token ids from the logit lens, plus their text.
                logit_lens_tokens_indices = logit_lens_topk.indices[feature, :].tolist()
                tokens = tokenizer.convert_ids_to_tokens(logit_lens_tokens_indices)
                print(f"[Layer {layer} | Feature {feature}] top-{logit_lens_k} tokens: {tokens}")

                output_score = get_output_score(
                    layer,
                    feature,
                    logit_lens_tokens_indices,
                    neutral_sentence,
                    sae,
                    tokenizer,
                    model,
                    device,
                )

                # Merge score and tokens into one entry under the same key.
                entry = existing if isinstance(existing, dict) else {}
                entry["score"] = float(output_score)
                entry["tokens"] = tokens
                output_scores[layer_feature_key] = entry

                print(f"[Layer {layer} | Feature {feature}] score={output_score:.6f}  tokens={tokens[:5]} ...")

                torch.cuda.empty_cache()
                gc.collect()
        finally:
            # Written on success and on interruption alike, so features scored
            # before a failure are picked up by the skip logic on the next run.
            _write_score_cache(cache_path_this_layer, output_scores)


# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
