"""Fix KernelBench L2 baselines that zero-initialize learnable parameters.

Context
-------
The 33 KernelBench L2 workloads under JAXBench/benchmark/*k_*/baseline.py were
auto-translated from KernelBench PyTorch into JAX. The translation preserved
the "uninitialized PyTorch module" convention of writing `weight = jnp.zeros(...)`
as a placeholder -- PyTorch's `nn.Linear`/`nn.Conv2d` overwrite these with
Kaiming-style random init when instantiated. The original JAX translator
added a `set_weights(...)` hook for a harness to inject real weights later,
but JAXBench's final per-workload `baseline.py` inlined the zero placeholder
as the operational input. Result: matmul-heavy baselines produce all-zero
outputs, and any generated kernel that happens to output zeros trivially
passes `np.allclose` and is marked "correct" -- inflating speedup and MXU
utilization arbitrarily.

What this script does
---------------------
Rewrites each baseline's `create_inputs(...)` so that *learnable* parameters
(matmul/conv/transpose/BMM weights, plus per-parameter `multiply_weight` /
`add_value` / broadcast-bias entries that appear in 40k/45k/47k/48k) are
initialized with small random normal values -- matching the Megablox fix we
already applied. Entries that *should* stay fixed (BN/LN/GN/IN `beta`/`bias`,
running means, `scale=jnp.ones((1,))`, identity gamma) are preserved.

Usage
-----
    python3 fix_kernelbench_baselines.py --dry-run    # preview diffs
    python3 fix_kernelbench_baselines.py              # apply
"""
from __future__ import annotations

import argparse
import difflib
import pathlib
import re
import sys


# Directory that `main` scans for KernelBench L2 workload sub-directories,
# each expected to contain a `baseline.py`. The value is a placeholder
# ("/path/to/..."): edit it to point at a real JAXBench checkout before use.
BENCHMARK_ROOT = pathlib.Path("/path/to/JAXBench/benchmark")

# Variable-name classifier. Case-sensitive exact-name matches against the
# LHS of an assignment such as `conv_weight = ...`.
#
# `RANDOMIZE` names MUST be drawn from a PRNG so that downstream benchmarks
# produce numerically non-trivial reference outputs. `_classify` checks this
# set first; names not listed here may still be randomized by its bias-name
# and suffix rules (`*_weight`, `*_value(s)`, `w1`/`b1`, ...).
RANDOMIZE = {
    "weight", "gemm_weight", "bmm_weight", "linear_weight",
    "conv_weight", "conv_transpose_weight",
    # Explicit `nn.Parameter` tensors constructed by KernelBench models.
    # Classification is purely by name: `_classify` never inspects the shape.
    "multiply_weight", "add_value",
}

# Standard PyTorch defaults we must NOT touch:
#   - BN/LN/GN/IN gamma: jnp.ones
#   - BN/LN/GN/IN beta:  jnp.zeros
#   - BN running_mean:   jnp.zeros
#   - BN running_var:    jnp.ones
#   - scalar `scale = jnp.ones((1,))`: an explicit `nn.Parameter` that the
#     model author initializes to 1.
# All of these are left as-is. Only `jnp.zeros(...)` assignments are ever
# candidates for rewriting, so the `jnp.ones` entries are never even matched.

# Heuristic for disambiguating a named `bias` variable:
#   - "bn_bias" / "ln_bias" / "gn_bias" / "in_bias" -> norm-layer beta;
#     PyTorch default is zero. Keep.
#   - otherwise (`bias`, `gemm_bias`, `conv_bias`, `linear_bias`, `bmm_bias`):
#     an nn.Linear / nn.Conv* / BMM bias, PyTorch initializes as
#     U(-bound, bound) with bound = 1/sqrt(fan_in). Randomize.
# NOTE: despite its name, this set also lists the norm-layer biases.
# `_classify` filters those out with `_is_norm_bias` before deciding, so
# membership here alone does not mean "randomize". Bias names not listed
# here (e.g. `fc_bias`) fall through to `_classify`'s remaining rules.
BIAS_NAMES_LINEAR = {
    "bias", "gemm_bias", "bmm_bias", "linear_bias", "conv_bias",
    "bn_bias", "ln_bias", "gn_bias", "in_bias",
}


_ZEROS_RE = re.compile(
    r"^(\s*)([A-Za-z_]\w*)\s*=\s*jnp\.zeros\(\s*(.+?)\s*,\s*dtype=dtype\s*\)\s*$",
    re.MULTILINE,
)


def _shape_is_1d_out_features(shape_src: str) -> bool:
    """Heuristic: 1-D shape like `(out_features,)`, `out_features`, `8192`,
    `out_channels`, `(C,)` -- not broadcast-y."""
    s = shape_src.strip()
    s = s.rstrip(",)").lstrip("(")
    # Multi-dim shapes have commas inside the tuple.
    return "," not in s


def _classify(name: str, shape_src: str) -> str:
    """Decide whether a zero-initialized parameter should be randomized.

    Returns the string 'randomize' or 'keep'. The decision depends only on
    *name*; *shape_src* is accepted for interface compatibility but unused.

    Randomized (PyTorch would initialize the equivalent parameter non-zero):
      - exact names in `RANDOMIZE` (Linear/Conv/ConvTranspose/BMM weights and
        explicit `nn.Parameter` tensors such as `multiply_weight`,
        `add_value`);
      - Linear/Conv/BMM biases listed in `BIAS_NAMES_LINEAR`;
      - names ending in `_weight`, `_value` or `_values`;
      - short multi-layer names `w1`, `w2`, ... / `b1`, `b2`, ...

    Kept: any name with a norm-layer prefix (BatchNorm/LayerNorm/GroupNorm/
    InstanceNorm betas and running statistics, whose PyTorch default already
    is the zero/one placeholder), and every name no rule above matches.
    """
    if name in RANDOMIZE:
        return "randomize"
    if _is_norm_bias(name):
        # Norm-layer parameters: the placeholder equals PyTorch's default.
        return "keep"
    looks_learnable = (
        name in BIAS_NAMES_LINEAR
        or name.endswith("_weight")
        or name.endswith("_value")
        or name.endswith("_values")
        # Translator's naming pattern for multi-layer MLP params.
        or re.fullmatch(r"w\d+", name) is not None
        or re.fullmatch(r"b\d+", name) is not None
    )
    return "randomize" if looks_learnable else "keep"


_NORM_PREFIXES = ("bn_", "ln_", "gn_", "in_", "group_norm_",
                  "layer_norm_", "batch_norm_", "instance_norm_")


def _is_norm_bias(name: str) -> bool:
    """Whether this parameter is a norm-layer bias/running-stat (keep as-is)."""
    return any(name.startswith(p) for p in _NORM_PREFIXES)


def _next_subkey_allocator(existing_keys: list[str]):
    """Return a function that, when called, returns a fresh subkey name
    (`ka`, `kb`, ... avoiding collisions with already-used names)."""
    used = set(existing_keys)
    pool = [f"k{chr(ord('a') + i)}" for i in range(26)]
    for cand in pool:
        if cand not in used:
            used.add(cand)
            yield cand


def _create_inputs_span(src: str) -> tuple[int, int] | None:
    """Locate the body of ``def create_inputs(...)`` in *src*.

    Returns ``(start, end)`` character offsets. *start* is the first line
    after the ``def`` header. *end* is the start of the next line that begins
    at column 0 with something other than whitespace or ``#`` (the next
    top-level def/class/decorator/statement), or ``len(src)``. Returns None
    when *src* has no ``create_inputs`` definition.
    """
    header = re.search(r"def create_inputs\([^)]*\)[^:]*:\s*\n", src)
    if header is None:
        return None
    start = header.end()
    # Stop at any dedent, not just "\ndef ", so a trailing class, decorator
    # or `if __name__ == ...` block is never treated as part of the body.
    dedent = re.compile(r"^[^\s#]", re.MULTILINE).search(src, start)
    return start, (dedent.start() if dedent else len(src))


def _shape_tuple_src(shape_src: str) -> str:
    """Return shape source that `jax.random.normal` will accept.

    `jnp.zeros` accepts a bare int (`jnp.zeros(8192)`), but the
    `jax.random` samplers require a shape sequence. Tuple and list literals
    pass through unchanged. Anything else (an int literal, a name such as
    `out_features` or `bias_shape`, an expression) is wrapped in
    `jnp.broadcast_shapes(...)`, which turns an int ``n`` into ``(n,)`` and
    returns a tuple unchanged.
    """
    s = shape_src.strip()
    is_literal = (
        s == "()"
        or (s.startswith("(") and s.endswith(")") and "," in s)
        or (s.startswith("[") and s.endswith("]"))
    )
    return s if is_literal else f"jnp.broadcast_shapes({s})"


def _body_insert_offset(body: str) -> int:
    """Return the offset in *body* just past a leading docstring, else 0.

    New statements go at the very top of the function, so they run before
    any rewritten line uses them. They go after the docstring so it remains
    the function's docstring.
    """
    m = re.match(
        r"[ \t]*[rRuU]?"
        r"(?:\"\"\"[\s\S]*?\"\"\"|'''[\s\S]*?'''|\"[^\"\n]*\"|'[^'\n]*')"
        r"[ \t]*\n",
        body,
    )
    return m.end() if m else 0


def _fresh_name(preferred: tuple[str, ...], taken: set[str]) -> str:
    """Return the first of *preferred* not in *taken*, else `<last>N` (N >= 2)."""
    for name in preferred:
        if name not in taken:
            return name
    stem = preferred[-1]
    n = 2
    while f"{stem}{n}" in taken:
        n += 1
    return f"{stem}{n}"


def _ensure_jax_import(src: str) -> tuple[str, str | None]:
    """Make sure a module-level name ``jax`` is bound in *src*.

    ``import jax.numpy as jnp`` binds only ``jnp``, so rewritten lines that
    call ``jax.random`` would raise NameError without a plain ``import jax``.
    The import is inserted before the first non-``__future__`` import line.

    Returns:
        ``(src, note)``. *note* is None when ``jax`` was already imported;
        otherwise it is a change-log line describing what was done, or a
        warning if no import block was found.
    """
    if re.search(r"^import[ \t]+jax(?:\.\w+)*[ \t]*(?:,|#|$)", src, re.MULTILINE):
        return src, None
    anchor = re.search(r"^(?:import|from)[ \t]+(?!__future__\b)", src, re.MULTILINE)
    if anchor is None:
        return src, "  WARNING: could not locate imports; add `import jax` manually"
    at = anchor.start()
    return src[:at] + "import jax\n" + src[at:], "  + import jax"


def patch_file(path: pathlib.Path) -> tuple[str, list[str]]:
    """Rewrite learnable zero-inits in *path*'s ``create_inputs`` body.

    Some lines inside the body have the form
    ``<name> = jnp.zeros(<shape>, dtype=dtype)``. Each one whose name
    `_classify` marks as learnable becomes
    ``<name> = jax.random.normal(<subkey>, <shape>, dtype=dtype) * 0.02``.
    The subkeys are derived from a brand-new parent key (fixed seed
    0xBADC0DE) declared at the top of the body, so the baseline's existing
    key/subkey usage is left untouched. The file is only read, never written.

    Args:
        path: Baseline source file to inspect.

    Returns:
        ``(new_source, changes)``. ``changes`` holds one human-readable line
        per rewrite, plus a note if ``import jax`` was added or could not be.
        When ``changes`` is empty, ``new_source`` is the unmodified text.

    Raises:
        RuntimeError: More parameters need a subkey than names are available.
    """
    src = path.read_text()
    changes: list[str] = []

    span = _create_inputs_span(src)
    if span is None:
        return src, changes
    body_start, body_end = span
    body = src[body_start:body_end]

    # Only zero-inits inside create_inputs are candidates; the rest of the
    # file is config/scratch space and is left alone.
    targets = [
        m for m in _ZEROS_RE.finditer(body)
        if _classify(m.group(2), m.group(3)) == "randomize"
    ]
    if not targets:
        return src, changes

    # Names we introduce must not shadow anything the body already uses.
    # Every identifier-like token counts (comments/strings included), a
    # deliberate over-approximation.
    taken = set(re.findall(r"\b[A-Za-z_]\w*", body))

    allocator = _next_subkey_allocator(sorted(taken))
    subkeys: list[str] = []
    for _ in targets:
        try:
            subkeys.append(next(allocator))
        except StopIteration:
            raise RuntimeError(f"ran out of subkeys in {path.name}") from None
    parent = _fresh_name(("rand_key", "rk", "wkey"), taken | set(subkeys))

    # Splice in the replacements in order. Scale 0.02 is the magnitude used
    # by the Megablox fix; it keeps post-matmul activations O(1) for
    # in/out ~ 8192 with x ~ U(0, 1).
    pieces: list[str] = []
    cursor = 0
    for m, subkey in zip(targets, subkeys):
        indent, name, shape_src = m.group(1), m.group(2), m.group(3)
        pieces.append(body[cursor:m.start()])
        pieces.append(
            f"{indent}{name} = jax.random.normal({subkey}, "
            f"{_shape_tuple_src(shape_src)}, dtype=dtype) * 0.02"
        )
        cursor = m.end()
        changes.append(
            f"  {name} (shape={shape_src.strip()}) -> jax.random.normal * 0.02"
        )
    pieces.append(body[cursor:])
    new_body = "".join(pieces)

    # Declare the parent key and its subkeys at the top of the body. That
    # guarantees they exist before the first rewritten line, wherever any
    # pre-existing `key = jax.random.key(...)` happens to sit.
    body_indent = re.match(r"[ \t]*", body).group(0) or "    "
    alloc = [f"{body_indent}{parent} = jax.random.key(0xBADC0DE)"]
    if len(subkeys) == 1:
        # One child key: fold_in avoids destructuring a length-1 split().
        alloc.append(f"{body_indent}{subkeys[0]} = jax.random.fold_in({parent}, 0)")
    else:
        alloc.append(
            f"{body_indent}{', '.join(subkeys)} = "
            f"jax.random.split({parent}, {len(subkeys)})"
        )
    at = _body_insert_offset(new_body)
    new_body = new_body[:at] + "\n".join(alloc) + "\n" + new_body[at:]

    new_src, note = _ensure_jax_import(src[:body_start] + new_body + src[body_end:])
    if note is not None:
        changes.append(note)
    return new_src, changes


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: patch every KernelBench L2 baseline.

    Scans the benchmark root (``--root``, default ``BENCHMARK_ROOT``) for
    sub-directories whose names start with ``<digits>k_`` and runs
    `patch_file` on each ``baseline.py``. With ``--dry-run``, a unified diff
    of every proposed rewrite is printed and nothing is written.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true",
                    help="Print diffs without writing files.")
    ap.add_argument("--root", type=pathlib.Path, default=BENCHMARK_ROOT,
                    help="Benchmark directory to scan (default: %(default)s).")
    args = ap.parse_args(argv)

    root: pathlib.Path = args.root
    if not root.is_dir():
        # The built-in default is a placeholder; fail with a usage message
        # instead of a FileNotFoundError traceback from iterdir().
        ap.error(f"benchmark root {root} is not a directory")

    fixed = 0
    for bench_dir in sorted(root.iterdir()):
        # Only KernelBench L2 workloads (directory names like `12k_...`).
        if not bench_dir.is_dir() or not re.match(r"\d+k_", bench_dir.name):
            continue
        baseline = bench_dir / "baseline.py"
        if not baseline.is_file():
            continue
        new_src, changes = patch_file(baseline)
        if not changes:
            print(f"[skip] {bench_dir.name}: no learnable zero-inits to fix")
            continue
        print(f"[fix]  {bench_dir.name}:")
        for c in changes:
            print(c)
        fixed += 1
        if args.dry_run:
            diff = difflib.unified_diff(
                baseline.read_text().splitlines(keepends=True),
                new_src.splitlines(keepends=True),
                fromfile=f"{baseline} (original)",
                tofile=f"{baseline} (patched)",
            )
            for line in diff:
                # Keep output line-oriented even if the file lacks a final newline.
                sys.stdout.write(line if line.endswith("\n") else line + "\n")
        else:
            baseline.write_text(new_src)
    if args.dry_run:
        print(f"\n(dry run, no files written; {fixed} baseline(s) would change)")
    else:
        print(f"\nWrote {fixed} baseline(s).")


# Run the CLI only when executed as a script; importing this module (e.g. to
# reuse `patch_file`) only defines functions and constants.
if __name__ == "__main__":
    main()
