# analysis/prune_Transformer.py
from __future__ import annotations
from typing import Any, Dict, List, Tuple
import re

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict


def _get_mlp_subtree(params: Dict[str, Any] | FrozenDict) -> Dict[str, Any]:
    
    p = params.unfreeze() if isinstance(params, FrozenDict) else params
    try:
        return p["blocks_0"]["mlp"]
    except KeyError:
        raise KeyError("Expected params['blocks_0']['mlp'] subtree not found.")

def _set_mlp_subtree(params: Dict[str, Any] | FrozenDict, new_mlp: Dict[str, Any]) -> Dict[str, Any]:
    base = params.unfreeze() if isinstance(params, FrozenDict) else params
    base["blocks_0"]["mlp"] = new_mlp
    return base

def _sorted_layer_indices(mlp: Dict[str, Any]) -> List[int]:
    idxs = set()
    for k in mlp.keys():
        m = re.fullmatch(r"W_(\d+)", k)
        if m: idxs.add(int(m.group(1)))
    return sorted(idxs)

def _layer_widths(mlp: Dict[str, Any], layer_idxs: List[int]) -> List[int]:
    return [int(mlp[f"b_{i}"].shape[0]) for i in layer_idxs]


def _dataset_accuracy(model, params, xs, ys, batch_size: int) -> jnp.ndarray:
    N  = int(xs.shape[0])
    nb = (N + batch_size - 1) // batch_size
    pad = nb * batch_size - N
    if pad:
        xs = jnp.concatenate([xs, jnp.repeat(xs[-1:], pad, axis=0)], axis=0)
        ys = jnp.concatenate([ys, jnp.repeat(ys[-1:], pad, axis=0)], axis=0)
        mask = jnp.concatenate([jnp.ones((N,), bool), jnp.zeros((pad,), bool)], axis=0)
    else:
        mask = jnp.ones((N,), bool)
    mask = mask.reshape(nb, batch_size)

    def body(i, acc):
        start = i * batch_size
        xb = jax.lax.dynamic_slice_in_dim(xs, start, batch_size, axis=0)
        yb = jax.lax.dynamic_slice_in_dim(ys, start, batch_size, axis=0)
        mb = mask[i]
        logits = model.apply({"params": params}, xb, training=False)
        if logits.ndim == 3:
            logits_last = logits[:, -1, :]     
        else:
            logits_last = logits               
        pred = jnp.argmax(logits_last, axis=-1)  
        correct = (pred == yb) & mb
        return acc + jnp.sum(correct.astype(jnp.int32))

    total_correct = jax.lax.fori_loop(0, nb, body, jnp.array(0, jnp.int32))
    return total_correct / jnp.array(N, jnp.float32)


def _find_by_suffix(d, suffix):
    if isinstance(d, dict):
        for k, v in d.items():
            if isinstance(k, str) and k.endswith(suffix):
                if isinstance(v, list):   
                    return v[0]
                if isinstance(v, dict):
                    return next(iter(v.values()))
                return v
            out = _find_by_suffix(v, suffix)
            if out is not None:
                return out
    elif isinstance(d, list):
        for x in d:
            out = _find_by_suffix(x, suffix)
            if out is not None:
                return out
    return None

def _max_abs_preacts(
                    model, params, xs, batch_size: int, 
                    num_mlp_layers: int, last_token_index: int = 1
                    ) -> List[jnp.ndarray]:
    N  = int(xs.shape[0])
    nb = (N + batch_size - 1) // batch_size
    pad = nb * batch_size - N
    if pad:
        xs = jnp.concatenate([xs, jnp.repeat(xs[-1:], pad, axis=0)], axis=0)

    xb0 = xs[:1]
    _, inter0 = model.apply({"params": params}, xb0, training=False, mutable=["intermediates"])
    ints0 = inter0["intermediates"]
    widths = []
    for l in range(1, num_mlp_layers + 1):
        arr0 = _find_by_suffix(ints0, f"blocks_0/mlp/hook_pre{l}")
        if arr0 is None:
            arr0 = _find_by_suffix(ints0, f"hook_pre{l}")
        if arr0 is None:
            raise KeyError(f"Could not find hook_pre{l} in intermediates.")
        widths.append(int(arr0.shape[-1]))
    max_vecs = [jnp.zeros((w,), dtype=jnp.float32) for w in widths]

    def body(i, carry):
        max_vecs = carry
        start = i * batch_size
        xb = jax.lax.dynamic_slice_in_dim(xs, start, batch_size, axis=0)
        _, inter = model.apply({"params": params}, xb, training=False, mutable=["intermediates"])
        ints = inter["intermediates"]
        curr = []
        for l in range(1, num_mlp_layers + 1):
            arr = _find_by_suffix(ints, f"blocks_0/mlp/hook_pre{l}")
            if arr is None:
                arr = _find_by_suffix(ints, f"hook_pre{l}")
            if arr is None:
                raise KeyError(f"Could not find hook_pre{l} in batch intermediates.")
            pre = jnp.asarray(arr)[:, last_token_index, :]          
            curr.append(jnp.max(jnp.abs(pre), axis=0))             
        new_max = [jnp.maximum(mv, c) for mv, c in zip(max_vecs, curr)]
        return new_max

    max_vecs = jax.lax.fori_loop(0, nb, body, max_vecs)
    return max_vecs


def _max_relu_preacts(model, params, xs, batch_size: int, num_mlp_layers: int, last_token_index: int = 1) -> List[jnp.ndarray]:
    N  = int(xs.shape[0])
    nb = (N + batch_size - 1) // batch_size
    pad = nb * batch_size - N
    if pad:
        xs = jnp.concatenate([xs, jnp.repeat(xs[-1:], pad, axis=0)], axis=0)

    xb0 = xs[:1]
    _, inter0 = model.apply({"params": params}, xb0, training=False, mutable=["intermediates"])
    ints0 = inter0["intermediates"]
    widths = []
    for l in range(1, num_mlp_layers + 1):
        arr0 = _find_by_suffix(ints0, f"blocks_0/mlp/hook_pre{l}")
        if arr0 is None:
            arr0 = _find_by_suffix(ints0, f"hook_pre{l}")
        if arr0 is None:
            raise KeyError(f"Could not find hook_pre{l} in intermediates.")
        widths.append(int(arr0.shape[-1]))
    max_vecs = [jnp.zeros((w,), dtype=jnp.float32) for w in widths]

    def body(i, carry):
        max_vecs = carry
        start = i * batch_size
        xb = jax.lax.dynamic_slice_in_dim(xs, start, batch_size, axis=0)
        _, inter = model.apply({"params": params}, xb, training=False, mutable=["intermediates"])
        ints = inter["intermediates"]
        curr = []
        for l in range(1, num_mlp_layers + 1):
            arr = _find_by_suffix(ints, f"blocks_0/mlp/hook_pre{l}")
            if arr is None:
                arr = _find_by_suffix(ints, f"hook_pre{l}")
            if arr is None:
                raise KeyError(f"Could not find hook_pre{l} in batch intermediates.")
            pre = jnp.asarray(arr)[:, last_token_index, :]
            curr.append(jnp.max(jnp.maximum(pre, 0.0), axis=0))
        new_max = [jnp.maximum(mv, c) for mv, c in zip(max_vecs, curr)]
        return new_max

    max_vecs = jax.lax.fori_loop(0, nb, body, max_vecs)
    return max_vecs



def _apply_layer_mask_transformer(
    params: Dict[str, Any] | FrozenDict,
    li: int,
    colmask_for_next_1d: jnp.ndarray,
) -> Dict[str, Any]:
    
    base = params.unfreeze() if isinstance(params, FrozenDict) else params
    mlp = _get_mlp_subtree(base)

    H = int(mlp[f"b_{li}"].shape[0])
    cm = colmask_for_next_1d.astype(mlp[f"b_{li}"].dtype)    
    cm_row = cm[:, None]                                     
    cm_col = cm[None, :]                                     

    Wi = jnp.asarray(mlp[f"W_{li}"])
    bi = jnp.asarray(mlp[f"b_{li}"])
    mlp[f"W_{li}"] = (Wi * cm_row).astype(Wi.dtype)
    mlp[f"b_{li}"] = (bi * cm).astype(bi.dtype)

    next_key = f"W_{li+1}"
    if next_key in mlp:
        Wn = jnp.asarray(mlp[next_key])
        mlp[next_key] = (Wn * cm_col).astype(Wn.dtype)
    else:
        
        Wout = jnp.asarray(mlp["W_out"])
        mlp["W_out"] = (Wout * cm_col).astype(Wout.dtype)

    return _set_mlp_subtree(base, mlp)

def _apply_prunes_masked_transformer(
    params: Dict[str, Any] | FrozenDict,
    prunes: Dict[int, List[int]]
) -> Dict[str, Any]:
    out = params.unfreeze() if isinstance(params, FrozenDict) else params
    mlp = _get_mlp_subtree(out)
    layer_idxs = _sorted_layer_indices(mlp)
    widths = _layer_widths(mlp, layer_idxs)
    for li in layer_idxs:
        idxs = list(map(int, prunes.get(li, [])))
        if not idxs:
            continue
        H = widths[layer_idxs.index(li)]
        idxs_arr = jnp.array(idxs, dtype=jnp.int32)
        cm = jnp.ones((H,), dtype=mlp[f"b_{li}"].dtype)
        hot = jax.nn.one_hot(idxs_arr, H, dtype=cm.dtype).sum(0)
        cm = jnp.clip(1.0 - hot, 0.0, 1.0)
        out = _apply_layer_mask_transformer(out, li, cm)
        mlp = _get_mlp_subtree(out)  
    return out


def _per_neuron_accs_layer(
    model, params, xs, ys, batch_size: int, li: int
) -> jnp.ndarray:
    base = params.unfreeze() if isinstance(params, FrozenDict) else params
    mlp = _get_mlp_subtree(base)
    H = int(mlp[f"b_{li}"].shape[0])

    eye = jnp.eye(H, dtype=mlp[f"b_{li}"].dtype)   
    col_masks = 1.0 - eye                           

    def acc_for_mask(cm):
        p_masked = _apply_layer_mask_transformer(params, li, cm)
        return _dataset_accuracy(model, p_masked, xs, ys, batch_size)

    return jax.vmap(acc_for_mask)(col_masks)  # (H,)


def prune_two_stage_by_accuracy_batched_transformer(
    *,
    model,
    params: Dict[str, Any] | FrozenDict,
    full_x: jnp.ndarray,                 
    full_y: jnp.ndarray,
    num_mlp_layers: int,
    batch_size: int = 4096,
    abs_acc_th: float = 0.005,           
    hard_min_acc: float = 1.0,           
    last_token_index: int = 1,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    
    if isinstance(params, FrozenDict):
        params = params.unfreeze()

    mlp = _get_mlp_subtree(params)
    layer_idxs = _sorted_layer_indices(mlp)
    widths = _layer_widths(mlp, layer_idxs)
    L = len(layer_idxs)

    baseline = _dataset_accuracy(model, params, full_x, full_y, batch_size)

    report: Dict[str, Any] = {
        "baseline_acc": float(baseline),
        "stageA": {i: [] for i in range(L)},
        "stageA_alive": {i: list(range(widths[i])) for i in range(L)},
        "stageB": {i: [] for i in range(L)},
        "stageB_alive": {i: list(range(widths[i])) for i in range(L)},
    }

    maxabs_per_layer = _max_abs_preacts(model, params, full_x, batch_size, num_mlp_layers, last_token_index)
    candidatesA: Dict[int, List[int]] = {}
    for li in range(L):
        H = widths[li]
        m = maxabs_per_layer[li]  # (H,)
        cand = jnp.where(m < abs_acc_th, jnp.arange(H), -1)
        cand = [int(x) for x in cand.tolist() if x >= 0]
        candidatesA[li] = cand

    trialA = _apply_prunes_masked_transformer(params, candidatesA)
    accA = _dataset_accuracy(model, trialA, full_x, full_y, batch_size)

    if accA >= hard_min_acc:
        paramsA = trialA
        for li in range(L):
            report["stageA"][li] = candidatesA[li]
            alive = [i for i in range(widths[li]) if i not in set(candidatesA[li])]
            report["stageA_alive"][li] = alive
        baseline = accA
    else:
        paramsA = params
        for li in range(L):
            if not candidatesA[li]:
                continue
            trial_layer = _apply_prunes_masked_transformer(paramsA, {li: candidatesA[li]})
            acc_layer = _dataset_accuracy(model, trial_layer, full_x, full_y, batch_size)
            if acc_layer >= hard_min_acc:
                paramsA = trial_layer
                report["stageA"][li] = candidatesA[li]
                baseline = acc_layer
            alive = [i for i in range(widths[li]) if i not in set(report["stageA"][li])]
            report["stageA_alive"][li] = alive

    maxrelu_per_layer = _max_relu_preacts(
        model, paramsA, full_x, batch_size, num_mlp_layers, last_token_index
    )
    global_max = max(float(jnp.max(m)) for m in maxrelu_per_layer) if maxrelu_per_layer else 0.0
    report["global_activation_max"] = float(global_max)

    rel_schedule = [0.3, 0.025, 0.02, 0.01, 0.005, 0.0025, 0.0012]  
    report["stageB_schedule"] = rel_schedule
    report["stageB_attempts"] = []

    paramsB = paramsA
    accepted = False

    for rel in rel_schedule:
        thresh = rel * global_max
        report["activation_threshold"] = float(thresh)  

        candidatesB: Dict[int, List[int]] = {}
        pruned_any = False
        for li in range(L):
            H = widths[li]
            if H == 0:
                candidatesB[li] = []
                continue
            m = maxrelu_per_layer[li]  # (H,)
            cand = jnp.where(m < thresh, jnp.arange(H), -1).tolist()
            cand = [int(i) for i in cand if i >= 0 and i not in set(report["stageA"][li])]
            candidatesB[li] = cand
            if cand:
                pruned_any = True

        trialB = _apply_prunes_masked_transformer(paramsA, candidatesB)
        accB = _dataset_accuracy(model, trialB, full_x, full_y, batch_size)

        report["stageB_attempts"].append({
            "rel_thresh": float(rel),
            "abs_thresh": float(thresh),
            "acc": float(accB),
            "candidates": {li: list(map(int, v)) for li, v in candidatesB.items()},
        })

        if pruned_any and accB >= hard_min_acc:
            paramsB = trialB
            for li in range(L):
                report["stageB"][li] = candidatesB[li]
                alive = [i for i in report["stageA_alive"][li] if i not in set(candidatesB[li])]
                report["stageB_alive"][li] = alive
            baseline = accB
            report["stageB_rel_threshold_accepted"] = float(rel)
            accepted = True
            break

    if not accepted:
        for li in range(L):
            report["stageB"][li] = []
            report["stageB_alive"][li] = report["stageA_alive"][li]
        report["stageB_rel_threshold_accepted"] = None

    report["final_acc"]   = float(baseline)
    report["alive_final"] = {li: report["stageB_alive"][li] for li in range(L)}
    report["dead_final"]  = {
        li: sorted(set(range(widths[li])) - set(report["alive_final"][li]))
        for li in range(L)
    }
    report["alive_counts"] = {li: len(report["alive_final"][li]) for li in range(L)}
    report["dead_counts"]  = {li: len(report["dead_final"][li])  for li in range(L)}
    return paramsB, report
