from __future__ import annotations

import copy
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

# ---------------------- saliency (SNIP-like) ----------------------

@torch.no_grad()
def _gate_rows_for_units(H: int, keep_idx: torch.Tensor) -> torch.Tensor:
    # GRU gate order in PyTorch is (r, z, n) stacked along dim 0
    # Make row indices [keep, keep+H, keep+2H]
    idx = keep_idx.to(dtype=torch.long)
    return torch.cat([idx, idx + H, idx + 2 * H], dim=0)

def _rowwise_abs_wg(W: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    # |W ⊙ dL/dW| summed over columns → per-row score
    return (W * G).abs().sum(dim=1)

@torch.no_grad()
def compute_snip_unit_scores_gru(
    model_with_gru: nn.Module,
    onebatch_or_loader,
    loss_fn: nn.Module,
    *,
    exit_id: int = -1,
    device: Union[str, torch.device] = "cpu",
    num_batches: int = 1,
    max_per_batch: int = 256,
    log_prefix: str = "[SNIP-GRU] ",
) -> torch.Tensor:
    """
    Return per-unit SNIP-style saliency scores (shape [H]) for ``model_with_gru.gru``.

    For every layer k of the uni-directional fused nn.GRU, the row scores
    ``sum_cols |W * dL/dW|`` of ``weight_ih_lk`` and ``weight_hh_lk`` are added,
    the three gate blocks (r, z, n) are summed per hidden unit, and the result is
    averaged over the layers whose weights received gradients.

    Args:
        model_with_gru: model exposing ``.gru`` (an nn.GRU). It is called as
            ``model(xb)`` and must return logits shaped (B, T, V), or a list/tuple
            of such logits (one per exit).
        onebatch_or_loader: a single ``(xb, yb)`` tuple, or any iterable of
            ``(xb, yb)`` pairs (e.g. a DataLoader). Both are cast to ``long``.
        loss_fn: called as ``loss_fn(logits[B*(T-1), V], targets[B*(T-1)])`` on
            the next-token-shifted logits/targets.
        exit_id: which exit's logits to use when the forward returns several.
        device: device the model and batches are moved to. ``model.to`` moves the
            caller's model in place.
        num_batches: maximum number of batches consumed (at least 1 is attempted).
        max_per_batch: each batch is truncated to this many rows.
        log_prefix: when non-empty, a one-line top/bottom-unit summary is printed.

    Returns:
        Detached tensor of shape [H]. All ones if no GRU weight received a gradient.

    Side effects:
        Existing gradients are cleared first. The model runs in eval mode during
        scoring; its original train/eval mode is restored and the SNIP gradients
        are cleared (set to None) before returning.
    """
    assert hasattr(model_with_gru, "gru") and isinstance(model_with_gru.gru, nn.GRU), \
        "Expected model.gru = nn.GRU"

    model = model_with_gru.to(device)
    was_training = model.training
    model.zero_grad(set_to_none=True)
    model.train(False)  # deterministic forward (no dropout); grads are enabled below

    # next() needs an iterator: a bare list would raise TypeError.
    if isinstance(onebatch_or_loader, tuple):
        data_iter = iter([onebatch_or_loader])
    else:
        data_iter = iter(onebatch_or_loader)

    try:
        used = 0
        while used < max(1, num_batches):
            try:
                xb, yb = next(data_iter)
            except StopIteration:
                break
            used += 1
            if xb.size(0) > max_per_batch:
                xb, yb = xb[:max_per_batch], yb[:max_per_batch]
            xb = xb.to(device, non_blocking=True).long()
            yb = yb.to(device, non_blocking=True).long()

            # cuDNN rejects RNN backward while the module is in eval mode, so this
            # pass runs with cuDNN disabled (no effect on CPU).
            with torch.enable_grad(), torch.backends.cudnn.flags(enabled=False):
                out = model(xb)
                logits = out[exit_id] if isinstance(out, (list, tuple)) else out
                if logits.size(1) < 2:
                    # No next-token pair in this batch; an empty loss only produces NaNs.
                    continue
                # next-token shift
                logits_use = logits[:, :-1, :].contiguous()
                targets = yb[:, 1:].contiguous()
                V = logits_use.size(-1)
                loss = loss_fn(logits_use.reshape(-1, V), targets.reshape(-1))
                loss.backward()

        gru: nn.GRU = model.gru
        H = int(gru.hidden_size)
        L = int(gru.num_layers)

        # Row scores per layer, collapsed over the three gates to one score per unit.
        per_layer_unit = []
        for k in range(L):
            Wih = getattr(gru, f"weight_ih_l{k}", None)
            Whh = getattr(gru, f"weight_hh_l{k}", None)
            if Wih is None or Whh is None or Wih.grad is None or Whh.grad is None:
                continue
            sih = _rowwise_abs_wg(Wih.detach(), Wih.grad.detach())  # [3H]
            shh = _rowwise_abs_wg(Whh.detach(), Whh.grad.detach())  # [3H]
            per_layer_unit.append((sih + shh).view(3, H).sum(dim=0))  # [H]

        if not per_layer_unit:
            # fallback: uniform if grads were not available
            return torch.ones(H, device=device)

        score = torch.stack(per_layer_unit, dim=0).mean(dim=0)  # [H]
        if log_prefix:
            top = torch.topk(score, k=min(5, H)).indices.tolist()
            bot = torch.topk(-score, k=min(5, H)).indices.tolist()
            print(f"{log_prefix}H={H} L={L}  top={top}  bottom={bot}")
        return score.detach()
    finally:
        # Do not leak eval mode or SNIP gradients into the caller's training loop.
        model.train(was_training)
        model.zero_grad(set_to_none=True)

# ---------------------- hard pruning (uniform H') ----------------------

def _slice_gate_rows(W: torch.Tensor, keep: torch.Tensor, H: int) -> torch.Tensor:
    """
    Select, along dim 0 of a fused GRU weight, the rows of the kept units in every gate.

    ``W`` is presumably laid out as (3*H, in_features) with gates (r, z, n) stacked
    on dim 0. The output has ``3 * len(keep)`` rows, gate blocks kept in order.
    """
    row_index = _gate_rows_for_units(H, keep.to(W.device))
    selected = torch.index_select(W, 0, row_index)
    return selected

def _slice_gate_bias(b: torch.Tensor, keep: torch.Tensor, H: int) -> torch.Tensor:
    """
    Select the entries of a fused GRU bias vector that belong to the kept units.

    ``b`` presumably has length 3*H with gates (r, z, n) stacked. The result has
    length ``3 * len(keep)``.
    """
    entry_index = _gate_rows_for_units(H, keep.to(b.device))
    selected = torch.index_select(b, 0, entry_index)
    return selected

def _slice_cols(W: torch.Tensor, keep_cols: Optional[torch.Tensor]) -> torch.Tensor:
    if keep_cols is None:
        return W
    return W.index_select(1, keep_cols.to(W.device))

def _find_head_linear(head: nn.Module) -> Optional[nn.Linear]:
    # find the final linear classifier inside a head
    if isinstance(head, nn.Linear):
        return head
    for m in head.modules():
        if isinstance(m, nn.Linear):
            return m
    return None

@torch.no_grad()
def hard_prune_gru_uniform(
    model_with_gru: nn.Module,
    keep_units: List[int],
    *,
    adjust_exit_heads: bool = True,
) -> nn.Module:
    """
    Return a pruned copy of ``model_with_gru`` whose GRU keeps only ``keep_units``.

    The model is deep-copied (it must support ``copy.deepcopy``), so the caller's
    model, its heads and its parameters are never modified. In the copy, ``gru``
    is replaced by a new nn.GRU with hidden_size = number of distinct in-range
    indices in ``keep_units``. For every layer:
      * rows of each gate (r, z, n) belonging to the kept units are copied,
      * hidden-state columns are sliced to the same units (``weight_hh_lk`` for all
        k, ``weight_ih_lk`` for k > 0; layer 0 keeps all input columns),
      * biases are sliced by gate rows.
    Indices outside [0, H) and duplicates in ``keep_units`` are ignored.

    If ``adjust_exit_heads`` is true and the model has ``exit_heads``, each head's
    first nn.Linear (per ``_find_head_linear``) whose ``in_features`` equals the old
    hidden size is replaced by a Linear over the kept units, with weight columns
    selected by unit index. Heads whose first Linear has another input size are
    left unchanged.

    NOTE: uni-directional nn.GRU with a single hidden_size across layers only.
    Other modules whose shapes depend on the hidden size are not adjusted.

    Raises:
        ValueError: the GRU is bidirectional, or ``keep_units`` has no valid index.
    """
    assert hasattr(model_with_gru, "gru") and isinstance(model_with_gru.gru, nn.GRU)
    old = model_with_gru
    gru_old: nn.GRU = old.gru
    if gru_old.bidirectional:
        raise ValueError("hard_prune_gru_uniform only supports uni-directional nn.GRU")
    H = int(gru_old.hidden_size)
    L = int(gru_old.num_layers)

    # Distinct, in-range, sorted units. K comes from this cleaned set (not from
    # len(keep_units)) so duplicates / bad indices cannot cause shape mismatches.
    keep = torch.tensor(
        sorted({int(i) for i in keep_units if 0 <= int(i) < H}), dtype=torch.long
    )
    K = int(keep.numel())
    if K == 0:
        raise ValueError(f"keep_units selects no valid hidden unit in [0, {H})")

    # Deep copy: sharing the _modules/_parameters dicts with `old` would make the
    # GRU/head swaps below rewrite the caller's model as well.
    new = copy.deepcopy(old)

    ref = gru_old.weight_hh_l0
    gru_new = nn.GRU(
        input_size=gru_old.input_size,
        hidden_size=K,
        num_layers=L,
        bias=gru_old.bias,
        batch_first=gru_old.batch_first,
        dropout=gru_old.dropout,
        bidirectional=False,
    ).to(ref.device, dtype=ref.dtype)
    gru_new.train(gru_old.training)

    for layer in range(L):
        w_ih = _slice_gate_rows(getattr(gru_old, f"weight_ih_l{layer}"), keep, H)
        if layer > 0:
            # Input of layer > 0 is the (pruned) hidden state of the layer below.
            w_ih = _slice_cols(w_ih, keep)
        w_hh = _slice_cols(
            _slice_gate_rows(getattr(gru_old, f"weight_hh_l{layer}"), keep, H), keep
        )
        sliced = {f"weight_ih_l{layer}": w_ih, f"weight_hh_l{layer}": w_hh}
        if gru_old.bias:
            for name in (f"bias_ih_l{layer}", f"bias_hh_l{layer}"):
                sliced[name] = _slice_gate_bias(getattr(gru_old, name), keep, H)
        for name, value in sliced.items():
            param = getattr(gru_new, name)
            param.copy_(value)
            param.requires_grad_(getattr(gru_old, name).requires_grad)

    new.gru = gru_new

    if adjust_exit_heads and hasattr(new, "exit_heads"):
        # Iterate over a snapshot because entries may be replaced in the loop.
        for idx, head in enumerate(list(new.exit_heads)):
            lin = _find_head_linear(head)
            # Only a Linear reading the raw GRU hidden state can be sliced per unit.
            if lin is None or lin.in_features != H:
                continue
            new_fc = nn.Linear(K, lin.out_features, bias=lin.bias is not None).to(
                lin.weight.device, dtype=lin.weight.dtype
            )
            # Columns of the kept units, not simply the first K columns.
            new_fc.weight.copy_(lin.weight.index_select(1, keep.to(lin.weight.device)))
            new_fc.weight.requires_grad_(lin.weight.requires_grad)
            if lin.bias is not None:
                new_fc.bias.copy_(lin.bias)
                new_fc.bias.requires_grad_(lin.bias.requires_grad)
            new_fc.train(lin.training)

            if head is lin:
                new.exit_heads[idx] = new_fc
            else:
                # Resolve the (possibly nested) owner of `lin` and swap the child there;
                # setattr with a dotted name would silently create a bogus attribute.
                path = next(name for name, mod in head.named_modules() if mod is lin)
                parent_path, _, leaf = path.rpartition(".")
                setattr(head.get_submodule(parent_path), leaf, new_fc)

    # Best-effort refresh of convenience attributes some models keep.
    try:
        new.all_state_dict_keys = list(new.state_dict().keys())
        new.trainable_state_dict_keys = [n for n, p in new.named_parameters() if p.requires_grad]
    except Exception:
        pass

    return new