#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compute Explained Variance (EV) matching TRAIN LOG STATS (dictionary_learning.training.log_stats):

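    frac_variance_explained = 1 - Var(x - x_hat).sum(dim=channels) / Var(x).sum(dim=channels)
    (Var is the unbiased per-channel variance over sampled tokens, i.e. over dim=0 of an (N, D) batch.)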

IMPORTANT (Dream / Diffusion LMs):
- We must match ActivationBuffer behavior.
- With DLM_MASK_POLICY="mask", ActivationBuffer collects activations ONLY from masked positions.
- EV must be computed on those collected activations (NOT pad attention_mask, NOT hook output).

This script:
- Scans all trained SAE folders (trainer_*) to find max layer
- Loads model + tokenizer once
- Truncates model ONCE to (max_layer + 1) so all layers are in-range
- For each SAE:
    - loads dictionary via utils.load_dictionary()
    - builds ActivationBuffer (mask-policy applied internally via env vars)
    - computes per-batch train-log-style FVE
    - repeats this estimation (--n_repeats) and records:
        (1) maximum EV over repeats
        (2) mean EV over repeats
        (3) std EV over repeats (sample std; 0 if only one repeat)
        (4) list of per-repeat EVs
    - writes JSON next to each trainer folder

Fixes included (to match training.log_stats precisely):
1) Norm-factor consistency:
   - Add --ev_space {saved,train}
   - saved: assume ae.pt already absorbed norm_factor; DO NOT rescale act.
   - train: revert dictionary back to training-space by scaling biases with 1/norm_factor,
           and divide act by norm_factor.

2) Reconstruction path consistency:
   - Add --recon_mode {dictionary,trainer}
   - dictionary: use reconstruct(dictionary, act) (lightweight, may diverge from trainer.loss path)
   - trainer: instantiate StandardTrainer from config.json and use trainer.loss(...) to obtain act_hat
             exactly like training.
"""

import argparse
import json
import math
import os
import sys
import warnings
from typing import List, Dict, Any, Optional, Tuple

import torch as t
from tqdm import tqdm
from modelscope import AutoModel, AutoTokenizer

# ---- add project paths (match your style) ----
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
sys.path.append("./")
sys.path.append("../")
sys.path.append("../../")

warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
warnings.filterwarnings("ignore", module="pydantic._internal._generate_schema")

import dictionary_learning.dictionary_learning.utils as utils
from dictionary_learning.dictionary_learning.pytorch_buffer import ActivationBuffer
from dictionary_learning.dictionary_learning.utils import hf_sequence_packing_dataset_to_generator

# Most faithful reconstruction path (optional)
try:
    from dictionary_learning.dictionary_learning.trainers.standard import StandardTrainer
except Exception:
    StandardTrainer = None


# ---------------------- constants: YOUR REQUIRED POLICY ----------------------
# Diffusion-LM (DLM) masking configuration. main() exports these values as the
# DLM_MASK_POLICY / DLM_T_MIN / DLM_T_MAX environment variables (via
# _set_dlm_env) before any ActivationBuffer is built, and also records them in
# every output JSON.
# "mask": per the module docstring, ActivationBuffer then collects activations
# only from masked positions.
DLM_MASK_POLICY = "mask"
# NOTE(review): presumably the [min, max] range of the diffusion noise level t
# used for forward-noising inside ActivationBuffer — confirm there.
DLM_T_MIN = 0.05
DLM_T_MAX = 0.50
# ---------------------------------------------------------------------------


def _set_dlm_env(mask_policy: str, t_min: float, t_max: float) -> None:
    """
    Publish the DLM masking configuration through process environment variables.

    Sets DLM_MASK_POLICY to ``str(mask_policy)`` and DLM_T_MIN / DLM_T_MAX to the
    string form of the float-converted bounds (e.g. ``1`` -> ``"1.0"``).
    Variables are written in that order, so a bad bound leaves earlier ones set.
    """
    env = os.environ
    env["DLM_MASK_POLICY"] = str(mask_policy)
    for key, bound in (("DLM_T_MIN", t_min), ("DLM_T_MAX", t_max)):
        env[key] = str(float(bound))


def find_sae_trainer_dirs(root: str) -> List[str]:
    """
    Recursively collect every directory under ``root`` whose basename starts with
    ``"trainer_"`` (``root`` itself included if it matches).

    Nested matches are returned as well. The result is sorted lexicographically.
    """
    matches = [
        current
        for current, _subdirs, _files in os.walk(root)
        if os.path.basename(current).startswith("trainer_") and os.path.isdir(current)
    ]
    return sorted(matches)


def read_json_safe(path: str) -> Optional[Dict[str, Any]]:
    """
    Best-effort JSON loader: return the parsed content of ``path`` (UTF-8), or
    None if the file cannot be opened or parsed for any reason.
    """
    parsed: Optional[Dict[str, Any]]
    try:
        with open(path, encoding="utf-8") as handle:
            parsed = json.load(handle)
    except Exception:
        parsed = None
    return parsed


def _canonical_device(device_str: str) -> str:
    """
    Normalize a requested device string to one that exists on this machine.

    - Non-"cuda*" requests (or non-strings) are returned unchanged.
    - "cuda*" without CUDA available -> "cpu".
    - "cuda:<idx>" with an out-of-range index -> "cuda".
    - Any parsing error -> "cuda" if CUDA is available, else "cpu".
    """
    try:
        wants_cuda = isinstance(device_str, str) and device_str.startswith("cuda")
        if not wants_cuda:
            return device_str
        if not t.cuda.is_available():
            return "cpu"
        if ":" in device_str:
            ordinal = int(device_str.split(":")[1])
            if not 0 <= ordinal < t.cuda.device_count():
                return "cuda"
        return device_str
    except Exception:
        return "cuda" if t.cuda.is_available() else "cpu"


def _count_params(model) -> int:
    """Return the total element count of ``model.parameters()``, or -1 on any error."""
    try:
        total = 0
        for param in model.parameters():
            total += param.numel()
        return total
    except Exception:
        return -1


def _load_model_and_tokenizer(model_name: str, dtype: t.dtype):
    """
    Load ``(model, tokenizer)`` for ``model_name`` through modelscope, trusting remote code.

    - Tokenizer: fast tokenizer; when it has no pad token, eos is reused as pad.
    - Model: placed with ``device_map="auto"`` and switched to eval mode. The dtype
      is first passed as ``dtype=``; if that call raises TypeError it is retried as
      ``torch_dtype=`` (the AutoModel signature is not consistent across versions).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    shared_kwargs = {"trust_remote_code": True, "device_map": "auto"}
    try:
        model = AutoModel.from_pretrained(model_name, dtype=dtype, **shared_kwargs).eval()
    except TypeError:
        model = AutoModel.from_pretrained(model_name, torch_dtype=dtype, **shared_kwargs).eval()

    return model, tokenizer


def _infer_layer_from_cfg(cfg: Dict[str, Any]) -> Optional[int]:
    """
    Infer the layer index an SAE was trained on from its parsed config.json.

    Lookup order (both read from ``cfg["trainer"]``):
      1. ``"layer"``, cast with ``int()``;
      2. the suffix of ``"submodule_name"`` after ``"resid_post_layer_"``.

    Returns:
        The layer as an int, or None if neither source yields one. A missing,
        null, or non-dict ``cfg`` / ``cfg["trainer"]`` is treated as "no
        information" and returns None instead of raising.
    """
    trainer_cfg = cfg.get("trainer") if isinstance(cfg, dict) else None
    if not isinstance(trainer_cfg, dict):
        # e.g. cfg is None, or config.json contains "trainer": null
        return None

    layer = trainer_cfg.get("layer")
    if layer is not None:
        try:
            return int(layer)
        except (TypeError, ValueError, OverflowError):
            # Unusable value: fall back to parsing the submodule name.
            pass

    marker = "resid_post_layer_"
    subname = trainer_cfg.get("submodule_name", "")
    if isinstance(subname, str) and marker in subname:
        try:
            return int(subname.split(marker)[-1])
        except ValueError:
            return None
    return None


def _scan_max_layer(sae_dirs: List[str]) -> int:
    """
    Return the largest layer index inferred from ``<dir>/config.json`` across ``sae_dirs``.

    Folders with a missing/unreadable config or no inferable layer are skipped.
    The result is never below -1; -1 means no layer was found at all.
    """
    candidates = [-1]
    for sae_dir in sae_dirs:
        config = read_json_safe(os.path.join(sae_dir, "config.json")) or {}
        inferred = _infer_layer_from_cfg(config)
        if inferred is not None:
            candidates.append(int(inferred))
    return max(candidates)


@t.no_grad()
def reconstruct(dictionary: t.nn.Module, x: t.Tensor) -> t.Tensor:
    """
    Reconstruct x using dictionary (lightweight path, no trainer involved).

    Expect x shape: (N, D). Resolution order:
      1. ``dictionary.decode(dictionary.encode(x))`` when both methods exist;
      2. otherwise ``dictionary(x)``, taking from its output
         - the Tensor itself, or
         - for a list/tuple: the first Tensor with x's shape, else the first Tensor, or
         - for a dict: the "x_hat" / "reconstruction" / "decoded" entry, else the
           first Tensor value with x's shape, else the first Tensor value.

    The result is cast to x's dtype and device. Side effect: ``dictionary.eval()``.

    Raises:
        TypeError: if no Tensor can be extracted from the forward output.
    """
    dictionary.eval()

    def _like_x(y: t.Tensor) -> t.Tensor:
        return y.to(dtype=x.dtype, device=x.device)

    # Preferred API
    if hasattr(dictionary, "encode") and hasattr(dictionary, "decode"):
        return _like_x(dictionary.decode(dictionary.encode(x)))

    # fallback forward
    out = dictionary(x)
    if isinstance(out, t.Tensor):
        return _like_x(out)

    candidates = []
    if isinstance(out, (list, tuple)):
        candidates = [o for o in out if isinstance(o, t.Tensor)]
    elif isinstance(out, dict):
        for k in ("x_hat", "reconstruction", "decoded"):
            if k in out and isinstance(out[k], t.Tensor):
                return _like_x(out[k])
        candidates = [v for v in out.values() if isinstance(v, t.Tensor)]

    if candidates:
        # Prefer a tensor shaped like x (the reconstruction) over e.g. feature codes.
        same_shape = [c for c in candidates if c.shape == x.shape]
        return _like_x((same_shape or candidates)[0])

    # Returning x here would fake a perfect reconstruction (FVE == 1.0) and
    # silently corrupt the metric this script exists to measure.
    raise TypeError(
        f"reconstruct(): cannot extract a reconstruction tensor from "
        f"{type(dictionary).__name__} output of type {type(out).__name__}"
    )


@t.no_grad()
def frac_variance_explained_batch(act_BD: t.Tensor, act_hat_BD: t.Tensor) -> float:
    """
    Train-log-style fraction of variance explained for one (B, D) batch:

        1 - sum_d Var_B(act - act_hat)[d] / sum_d Var_B(act)[d]

    Variances are unbiased (ddof=1), taken over dim=0, computed in float32.

    Returns NaN when the inputs are not both 2-D, their shapes differ, there are
    fewer than 2 rows, or the total variance is <= 0.
    """
    nan = float("nan")
    valid = (
        act_BD.ndim == 2
        and act_hat_BD.ndim == 2
        and act_BD.shape == act_hat_BD.shape
        and act_BD.shape[0] >= 2
    )
    if not valid:
        return nan

    target = act_BD.to(t.float32)
    estimate = act_hat_BD.to(t.float32)

    total = float(t.var(target, dim=0, unbiased=True).sum().detach().cpu().item())
    residual = float(t.var(target - estimate, dim=0, unbiased=True).sum().detach().cpu().item())

    if total <= 0:
        return nan
    return float(1.0 - residual / total)


def build_activation_buffer(
    model_trunc,
    tokenizer,
    layer: int,
    device: str,
    ctx_len: int,
    refresh_batch_size: int,
    out_batch_size: int,
    add_special_tokens: bool,
    buffer_tokens: int = 250_000,
) -> ActivationBuffer:
    """
    Create an ActivationBuffer that reads the output (io="out") of ``layer``'s submodule.

    DLM forward-noising and masked-position selection are not configured here. Per
    this script's design, ActivationBuffer applies them from the DLM_* env vars.

    The buffer is sized with ``n_ctxs = max(1, buffer_tokens // ctx_len)`` contexts
    of ``ctx_len`` tokens. Text comes from ``hf_sequence_packing_dataset_to_generator``
    with ``min_chars = ctx_len * 4``.
    """
    hooked_module = utils.get_submodule(model_trunc, layer)
    hidden_width = model_trunc.config.hidden_size
    text_source = hf_sequence_packing_dataset_to_generator(tokenizer, min_chars=ctx_len * 4)
    context_count = max(1, int(buffer_tokens) // int(ctx_len))

    return ActivationBuffer(
        generator=text_source,
        model=model_trunc,
        submodule=hooked_module,
        n_ctxs=context_count,
        ctx_len=ctx_len,
        refresh_batch_size=refresh_batch_size,
        out_batch_size=out_batch_size,
        io="out",
        d_submodule=hidden_width,
        device=device,
        add_special_tokens=add_special_tokens,
    )


def _clean_trainer_cfg_for_ctor(trainer_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new dict with the saved trainer config minus keys that are not
    trainer-constructor arguments ("trainer", "wandb_name", "log_steps", "save_steps").

    Non-dict input yields ``{}``. The input is never mutated, and key order is preserved.
    """
    if not isinstance(trainer_cfg, dict):
        return {}
    excluded = {"trainer", "wandb_name", "log_steps", "save_steps"}
    return {key: value for key, value in trainer_cfg.items() if key not in excluded}


def build_standard_trainer_from_cfg(trainer_cfg: Dict[str, Any], device: str):
    """
    Instantiate a StandardTrainer from a saved ``config.json["trainer"]`` section.

    Non-constructor keys are stripped via _clean_trainer_cfg_for_ctor, ``device`` is
    normalized via _canonical_device, and the resulting trainer's ``ae`` is put in
    eval mode.

    Raises:
        RuntimeError: if StandardTrainer could not be imported.
    """
    if StandardTrainer is None:
        raise RuntimeError("StandardTrainer import failed; cannot use --recon_mode trainer")

    ctor_kwargs = {**_clean_trainer_cfg_for_ctor(trainer_cfg), "device": _canonical_device(device)}
    built = StandardTrainer(**ctor_kwargs)
    built.ae.eval()
    return built


def maybe_revert_dictionary_to_train_space(dictionary: t.nn.Module, norm_factor: float):
    """
    Move a dictionary whose biases absorbed ``norm_factor`` back to training space.

    Calls ``scale_biases(1 / norm_factor)`` on ``dictionary.ae`` when that object has
    the method, otherwise on ``dictionary`` itself. Does nothing when neither exposes
    it, or when ``norm_factor`` is None or <= 0. Mutates in place and returns None.
    """
    if norm_factor is None or norm_factor <= 0:
        return
    for holder in (getattr(dictionary, "ae", None), dictionary):
        if hasattr(holder, "scale_biases"):
            holder.scale_biases(1.0 / float(norm_factor))
            return


def maybe_revert_trainer_to_train_space(trainer, norm_factor: float):
    """
    Rescale ``trainer.ae`` biases by ``1 / norm_factor`` (back to training space).

    No-op when ``norm_factor`` is None or <= 0, or when ``trainer`` has no ``ae``
    exposing ``scale_biases``. Mutates in place and returns None.
    """
    if norm_factor is None or norm_factor <= 0:
        return
    autoencoder = getattr(trainer, "ae", None)
    if hasattr(autoencoder, "scale_biases"):
        autoencoder.scale_biases(1.0 / float(norm_factor))


def _mean_std(values: List[float]) -> Tuple[float, float]:
    """
    Return ``(mean, sample_std)`` of ``values`` with ddof=1.

    - empty input  -> (nan, nan)
    - single value -> (value, 0.0)

    NaNs are NOT filtered here; callers must drop them first (a NaN would
    propagate into both outputs).
    """
    n = len(values)
    if n == 0:
        return float("nan"), float("nan")
    mean = float(sum(values) / n)
    if n == 1:
        return mean, 0.0
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return mean, float(var ** 0.5)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by main()."""
    # The title goes in description=: ArgumentParser's first positional argument is
    # `prog`, which would otherwise replace the program name in the usage line.
    parser = argparse.ArgumentParser(description="Compute TRAIN-LOG-MATCHED EV (DLM mask policy).")
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--ae_root", type=str, required=True)

    # Only written to the output JSON ("heldout_dataset_arg"); build_activation_buffer
    # constructs its text generator without it.
    parser.add_argument(
        "--heldout_dataset",
        type=str,
        default="monology/pile-uncopyrighted",
        help="Recorded in the output JSON only; not currently forwarded to the data generator.",
    )

    parser.add_argument("--token_budget", type=int, default=20_000_000, help="Counted in sampled (masked) positions.")
    parser.add_argument("--n_batches", type=int, default=0, help="If >0, stop after this many ActivationBuffer batches.")

    # repeat N times; record max + mean (+ std)
    parser.add_argument(
        "--n_repeats",
        type=int,
        default=1,
        help="Repeat the EV estimation multiple times (each repeat rebuilds ActivationBuffer). "
             "Record max, mean, std, and the per-repeat EV list.",
    )

    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
    parser.add_argument("--device", type=str, default="cuda:0")

    parser.add_argument("--out_json_name", type=str, default="explained_variance_trainlog_mask.json")

    parser.add_argument("--ctx_len", type=int, default=2048)
    parser.add_argument("--refresh_batch_size", type=int, default=8)
    parser.add_argument("--out_batch_size", type=int, default=2048)
    parser.add_argument("--add_special_tokens", action="store_true")

    parser.add_argument("--buffer_tokens", type=int, default=250_000)

    parser.add_argument("--ev_space", type=str, default="saved", choices=["saved", "train"])
    parser.add_argument("--recon_mode", type=str, default="trainer", choices=["dictionary", "trainer"])
    parser.add_argument("--debug_once", action="store_true")
    return parser


def main():
    """
    Evaluate every SAE found under --ae_root and write one EV JSON per SAE folder.

    Steps:
      1. Parse CLI args, then export the DLM_* env vars consumed by ActivationBuffer.
      2. Discover SAE folders, infer the deepest layer, load the model once and
         truncate it to max_layer + 1 layers.
      3. For each SAE: load the dictionary (and, with --recon_mode trainer, a
         StandardTrainer holding the same ae.pt weights), optionally move it to
         training space (--ev_space train), run --n_repeats independent
         ActivationBuffer estimates of the train-log FVE, and dump max/mean/std
         plus the per-repeat list to <sae_dir>/<--out_json_name>.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.n_repeats < 1:
        parser.error("--n_repeats must be >= 1")
    if args.ctx_len < 1:
        # build_activation_buffer divides buffer_tokens by ctx_len.
        parser.error("--ctx_len must be >= 1")

    _set_dlm_env(DLM_MASK_POLICY, DLM_T_MIN, DLM_T_MAX)
    print(f"[DLM] policy={DLM_MASK_POLICY}  t_min={DLM_T_MIN}  t_max={DLM_T_MAX}", flush=True)

    device = _canonical_device(args.device)
    dtype_map = {"float32": t.float32, "bfloat16": t.bfloat16, "float16": t.float16}
    load_dtype = dtype_map[args.dtype]

    sae_dirs = find_sae_trainer_dirs(args.ae_root)
    if len(sae_dirs) == 0:
        print("[Scan] No trainer_* found via recursive scan; fallback to utils.get_nested_folders(...)", flush=True)
        sae_dirs = sorted(utils.get_nested_folders(args.ae_root))
    if len(sae_dirs) == 0:
        raise SystemExit(f"[Error] No SAE folders found under: {args.ae_root}")

    max_layer = _scan_max_layer(sae_dirs)
    if max_layer < 0:
        raise SystemExit("[Error] Could not infer any layer from SAE config.json files.")

    print(f"[Scan] Found {len(sae_dirs)} SAE folders. max_layer={max_layer}", flush=True)

    print(f"[Setup] Loading model/tokenizer: {args.model_name}", flush=True)
    model_full, tokenizer = _load_model_and_tokenizer(args.model_name, dtype=load_dtype)

    p_full = _count_params(model_full)
    if p_full > 0:
        print(f"Model parameters (full): {p_full:,}", flush=True)

    # Truncate once so every SAE's layer is in range for a single shared model.
    model_trunc = utils.truncate_model(model_full, int(max_layer) + 1)

    p_trunc = _count_params(model_trunc)
    if p_trunc > 0:
        print(f"Model parameters (trunc to {max_layer+1}): {p_trunc:,}", flush=True)

    try:
        n_layers = len(model_trunc.model.layers)
        print(f"[Sanity] num_layers(trunc)={n_layers} (must be >= {max_layer+1})", flush=True)
    except Exception:
        pass

    for idx, d in enumerate(sae_dirs, start=1):
        print(f"\n[Eval {idx}/{len(sae_dirs)}] SAE: {d}", flush=True)

        cfg_path = os.path.join(d, "config.json")
        cfg = read_json_safe(cfg_path) or {}
        # Valid JSON that is not config-shaped (a list, or "trainer": null) must not
        # crash the whole sweep; treat it as an empty config instead.
        if not isinstance(cfg, dict):
            cfg = {}
        trainer_cfg = cfg.get("trainer", {})
        if not isinstance(trainer_cfg, dict):
            trainer_cfg = {}
            cfg["trainer"] = trainer_cfg

        layer = _infer_layer_from_cfg(cfg)
        if layer is None:
            print("[Skip] Could not infer layer from config.json", flush=True)
            continue
        layer = int(layer)

        norm_factor = trainer_cfg.get("norm_factor", None)
        try:
            norm_factor = float(norm_factor) if norm_factor is not None else None
        except Exception:
            norm_factor = None

        try:
            dictionary, _cfg2 = utils.load_dictionary(d, device=device)
            dictionary.eval()
        except Exception as e:
            print(f"[Skip] Failed to load dictionary: {e}", flush=True)
            continue

        trainer = None
        if args.recon_mode == "trainer":
            try:
                trainer = build_standard_trainer_from_cfg(trainer_cfg, device=device)
                ae_path = os.path.join(d, "ae.pt")
                state = t.load(ae_path, map_location="cpu")
                trainer.ae.load_state_dict(state, strict=True)
                trainer.ae.to(device)
                trainer.ae.eval()
            except Exception as e:
                print(
                    f"[Warn] Failed to use trainer.loss path ({type(e).__name__}: {e}). "
                    f"Falling back to dictionary.",
                    flush=True,
                )
                trainer = None

        if args.ev_space == "train" and norm_factor is not None and norm_factor > 0:
            maybe_revert_dictionary_to_train_space(dictionary, norm_factor)
            if trainer is not None:
                maybe_revert_trainer_to_train_space(trainer, norm_factor)

        did_debug = False
        warned_trainer_fallback = False

        def run_one_estimate(repeat_idx: int) -> Dict[str, Any]:
            """Run one ActivationBuffer pass and return its token-weighted mean FVE."""
            nonlocal did_debug, warned_trainer_fallback

            buf = build_activation_buffer(
                model_trunc=model_trunc,
                tokenizer=tokenizer,
                layer=layer,
                device=device,
                ctx_len=int(args.ctx_len),
                refresh_batch_size=int(args.refresh_batch_size),
                out_batch_size=int(args.out_batch_size),
                add_special_tokens=bool(args.add_special_tokens),
                buffer_tokens=int(args.buffer_tokens),
            )

            tokens_total = 0
            batches_done = 0
            trainer_fallback_batches = 0
            fve_weighted_sum = 0.0
            fve_weight_total = 0

            pbar = tqdm(
                total=args.token_budget,
                unit="tok",
                dynamic_ncols=True,
                desc=f"EV(mask) L{layer} r{repeat_idx+1}/{args.n_repeats}",
            )

            try:
                for act in buf:
                    if not isinstance(act, t.Tensor) or act.ndim != 2:
                        continue

                    act = act.to(device=device)

                    if args.ev_space == "train" and norm_factor is not None and norm_factor > 0:
                        act_in = act / norm_factor
                    else:
                        act_in = act

                    # Evaluation only: trainer.ae parameters require grad, so without
                    # no_grad trainer.loss would build an autograd graph every batch.
                    with t.no_grad():
                        if trainer is not None and args.recon_mode == "trainer":
                            try:
                                _x_in, act_hat, _f, _losslog = trainer.loss(act_in, step=0, logging=True)
                            except Exception as e:
                                # Do not hide that the "trainer" path was not actually used.
                                trainer_fallback_batches += 1
                                if not warned_trainer_fallback:
                                    warned_trainer_fallback = True
                                    print(
                                        f"[Warn] trainer.loss failed ({type(e).__name__}: {e}); "
                                        f"using dictionary reconstruction for failing batches.",
                                        flush=True,
                                    )
                                act_hat = reconstruct(dictionary, act_in)
                        else:
                            act_hat = reconstruct(dictionary, act_in)

                    fve = frac_variance_explained_batch(act_in, act_hat)
                    B = int(act_in.shape[0])

                    if args.debug_once and (not did_debug):
                        did_debug = True
                        with t.no_grad():
                            x = act_in.to(t.float32)
                            xh = act_hat.to(t.float32)
                            msn = float((x.pow(2).sum(dim=1).mean()).cpu().item())
                            msn_hat = float((xh.pow(2).sum(dim=1).mean()).cpu().item())
                            l2_ratio = float((xh.norm(dim=1).mean() / (x.norm(dim=1).mean() + 1e-12)).cpu().item())
                        print(f"[Debug] buf.config={getattr(buf, 'config', None)}", flush=True)
                        print(
                            f"[Debug] act_in shape={tuple(act_in.shape)} dtype={act_in.dtype} "
                            f"norm_factor={norm_factor} ev_space={args.ev_space}",
                            flush=True,
                        )
                        print(
                            f"[Debug] mean_sq_norm(x)={msn:.6g}  mean_sq_norm(x_hat)={msn_hat:.6g}  "
                            f"l2_ratio≈{l2_ratio:.6g}",
                            flush=True,
                        )
                        print(f"[Debug] recon_mode={args.recon_mode} trainer_ok={trainer is not None}", flush=True)

                    # Batches with undefined FVE (NaN) do not contribute to the weighted mean.
                    if not math.isnan(fve):
                        fve_weighted_sum += float(fve) * B
                        fve_weight_total += B

                    tokens_total += B
                    batches_done += 1
                    pbar.update(B)

                    if fve_weight_total > 0:
                        pbar.set_postfix({"fve": f"{(fve_weighted_sum / fve_weight_total):.6f}"})

                    if args.n_batches > 0 and batches_done >= args.n_batches:
                        break
                    if tokens_total >= args.token_budget:
                        break
            finally:
                pbar.close()

            ev_val = float("nan") if fve_weight_total == 0 else float(fve_weighted_sum / fve_weight_total)
            return {
                "frac_variance_explained": ev_val,
                "tokens_evaluated": int(tokens_total),
                "batches_evaluated": int(batches_done),
                "trainer_fallback_batches": int(trainer_fallback_batches),
            }

        # ---------------- repeats: keep max + mean (+ std) ----------------
        best_ev = float("-inf")
        best_repeat_idx = None
        best_detail: Optional[Dict[str, Any]] = None

        all_evs: List[float] = []

        for r in range(int(args.n_repeats)):
            rr = run_one_estimate(r)
            ev_r = rr["frac_variance_explained"]
            print(
                f"[Repeat {r+1}/{args.n_repeats}] layer={layer}  fve={ev_r:.6f}  "
                f"tokens={rr['tokens_evaluated']:,}  batches={rr['batches_evaluated']}",
                flush=True,
            )

            # collect for mean/std (ignore NaNs)
            if not math.isnan(ev_r):
                all_evs.append(float(ev_r))

                # track best
                if ev_r > best_ev:
                    best_ev = float(ev_r)
                    best_repeat_idx = int(r)
                    best_detail = dict(rr)

        if len(all_evs) == 0:
            best_ev = float("nan")
            ev_mean = float("nan")
            ev_std = float("nan")
        else:
            ev_mean, ev_std = _mean_std(all_evs)

        out = {
            # Backward-compatible: keep this field as the MAX (same as before)
            "frac_variance_explained": float(best_ev),

            # New fields: max + mean (+ std) + list
            "frac_variance_explained_max": float(best_ev),
            "frac_variance_explained_mean": float(ev_mean),
            "frac_variance_explained_std": float(ev_std),
            "frac_variance_explained_all": all_evs,

            "best_repeat_idx": best_repeat_idx,
            "best_repeat_detail": best_detail,

            "definition": "train-log style: 1 - sum_c Var(x-xhat)_c / sum_c Var(x)_c ; Var over sampled tokens (dim=0), summed over channels",
            "dlm_mask_policy": DLM_MASK_POLICY,
            "dlm_t_min": float(DLM_T_MIN),
            "dlm_t_max": float(DLM_T_MAX),
            "heldout_dataset_arg": str(args.heldout_dataset),
            "ctx_len": int(args.ctx_len),
            "refresh_batch_size": int(args.refresh_batch_size),
            "out_batch_size": int(args.out_batch_size),
            "buffer_tokens": int(args.buffer_tokens),
            "add_special_tokens": bool(args.add_special_tokens),
            "model_name": str(args.model_name),
            "layer": int(layer),
            "norm_factor_used": float(norm_factor) if norm_factor is not None else None,
            "ev_space": str(args.ev_space),
            "recon_mode": str(args.recon_mode),
            "n_repeats": int(args.n_repeats),
            "n_batches_per_repeat": int(args.n_batches),
            "note": (
                "Activations come from ActivationBuffer with DLM mask policy; computed on masked-position activation vectors. "
                "We repeat the estimation across independent buffer refreshes. We record max/mean/std and the per-repeat list."
            ),
        }

        out_path = os.path.join(d, args.out_json_name)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(out, f, indent=2)

        print(
            f"[Done] layer={layer}  best_fve={float(best_ev):.6f}  mean_fve={float(ev_mean):.6f}  "
            f"std_fve={float(ev_std):.6f}  saved={out_path}",
            flush=True,
        )


if __name__ == "__main__":
    # Script entry point: run the full SAE EV sweep only when executed directly.
    main()
