import math
import copy
import warnings
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset


def _unwrap_batch_for_hf(batch, device: torch.device):
    """
    Split a batch into model inputs and labels, moving both to ``device``.

    Accepts:
      - dict: { ..., "labels" or "label": Tensor, ... }
      - tuple/list: (inputs_dict, labels_tensor)
      - list of (inputs_dict, labels_tensor) samples (rare; manual collate),
        including a list that happens to contain exactly two samples

    Returns:
      inputs_dict (tensors on device), labels (tensor on device; non-tensor
      labels are converted to a LongTensor)

    Raises:
      ValueError: a dict batch has neither "labels" nor "label".
      TypeError: the batch layout is not one of the supported forms.
    """
    def _labels_to_device(raw):
        # Tensors keep their dtype; anything else (list, int, ...) becomes long.
        if isinstance(raw, torch.Tensor):
            return raw.to(device)
        return torch.as_tensor(raw, dtype=torch.long, device=device)

    if isinstance(batch, dict):
        labels = batch.get("labels", batch.get("label", None))
        if labels is None:
            raise ValueError("Batch dict missing 'labels'/'label'.")
        inputs = {k: v.to(device) for k, v in batch.items() if k not in ("labels", "label")}
        return inputs, _labels_to_device(labels)

    if isinstance(batch, (list, tuple)) and len(batch) > 0:
        first = batch[0]

        # (inputs_dict, labels): only when the first element really is a dict,
        # so that a two-sample list is not mistaken for this form.
        if len(batch) == 2 and isinstance(first, dict):
            inputs = {k: v.to(device) for k, v in first.items()}
            return inputs, _labels_to_device(batch[1])

        # List of per-sample (inputs_dict, label) pairs: stack field by field.
        if isinstance(first, (list, tuple)) and len(first) == 2 and isinstance(first[0], dict):
            keys = first[0].keys()
            inputs = {k: torch.stack([sample[0][k] for sample in batch]).to(device) for k in keys}
            labels = torch.stack([
                sample[1] if isinstance(sample[1], torch.Tensor)
                else torch.tensor(sample[1], dtype=torch.long)
                for sample in batch
            ]).to(device)
            return inputs, labels

        if len(batch) == 2:
            raise TypeError(f"Expected inputs_dict in batch[0], got {type(first)}")

    raise TypeError(f"Unsupported batch type for HF unwrap: {type(batch)}")


@torch.no_grad()
def bernoulli_score_sampling(
    scores: torch.Tensor,
    n: int,
    seed: int,
    max_tries: int = 3000,
    temperature: float = 1e-6,
) -> torch.Tensor:
    """Sample n unique indices by Bernoulli draws proportional to softmax(scores/temperature).

    Each round draws one independent Bernoulli variable per remaining index,
    appends the drawn indices (in ascending position order) to the selection
    and removes them from the pool. Sampling stops after ``max_tries`` rounds,
    once ``n`` indices are selected, or when the pool is empty. If fewer than
    ``n`` indices were drawn, the gap is filled with the highest-scoring
    remaining indices.

    Args:
        scores: Scores to sample from; treated as a flat vector.
        n: Number of indices requested. At most ``scores.numel()`` are returned.
        seed: Seed for the private generator, so results are reproducible.
        max_tries: Upper bound on sampling rounds.
        temperature: Softmax temperature. With the tiny default the softmax is
            effectively one-hot, so each round picks the current maximum.

    Returns:
        LongTensor of selected indices on ``scores.device``.
    """
    device = scores.device
    rng = torch.Generator(device=device).manual_seed(seed)
    selected: List[int] = []

    remaining_scores = scores.detach().reshape(-1)
    remaining_idx = torch.arange(remaining_scores.numel(), device=device)

    tries = 0
    while tries < max_tries and len(selected) < n and remaining_scores.numel() > 0:
        tries += 1
        probs = torch.softmax(remaining_scores / temperature, dim=0)
        # When only -inf (or overflowing) scores remain the softmax is NaN and
        # torch.bernoulli rejects it; stop sampling and let the top-k fallback
        # below fill the remaining slots deterministically.
        if not bool(torch.isfinite(probs).all()):
            break
        draws = torch.bernoulli(probs, generator=rng).bool()
        if not bool(draws.any()):
            continue
        # Drawn indices come from the pool, which never contains an already
        # selected index, so no membership test is needed.
        drawn = remaining_idx[draws].tolist()
        selected.extend(drawn[: n - len(selected)])
        undrawn = ~draws
        remaining_idx = remaining_idx[undrawn]
        remaining_scores = remaining_scores[undrawn]

    missing = n - len(selected)
    if missing > 0 and remaining_scores.numel() > 0:
        top = torch.topk(remaining_scores, k=min(missing, remaining_scores.numel())).indices
        selected.extend(remaining_idx[top].tolist())

    return torch.tensor(selected[:n], device=device, dtype=torch.long)


def _topk_or_sample(
    scores: torch.Tensor,
    k: int,
    stochastic: bool,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """Pick ``k`` indices from ``scores``.

    With ``stochastic=False`` this is a plain largest-first top-k. With
    ``stochastic=True`` the indices come from ``bernoulli_score_sampling``,
    which needs a seed.
    """
    if not stochastic:
        best = torch.topk(scores, k=k, largest=True)
        return best.indices

    assert seed is not None, "stochastic=True requires a seed"
    return bernoulli_score_sampling(scores, n=k, seed=seed)



def get_grads_nlp(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> None:
    """Fill the model's parameter gradients from one batch of cross-entropy loss.

    Only the first batch yielded by ``dataloader`` is used. The model is put
    in training mode and moved to ``device``, and it is left that way when the
    function returns. Existing gradients are cleared right before the backward
    pass, so afterwards ``.grad`` reflects this single batch only.
    """
    model.train()
    model.to(device)

    first_batch = next(iter(dataloader))
    model_inputs, targets = _unwrap_batch_for_hf(first_batch, device)

    result = model(**model_inputs)
    # HF models return an output object with .logits; plain modules return the tensor itself.
    logits = getattr(result, "logits", result)
    batch_loss = nn.CrossEntropyLoss()(logits, targets)

    model.zero_grad()
    batch_loss.backward()



@torch.no_grad()
def compute_attention_head_scores_bert(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> torch.Tensor:
    """Score every attention head by how sharply it attends, on one batch.

    The score of a head is the mean, over batch items and query positions, of
    its largest attention weight. Averaging the raw attention weights instead
    would be uninformative: each attention row is a softmax output that sums
    to 1, so its mean is 1 / key_length for every head.

    Only the first batch of ``dataloader`` is used. The model is switched to
    eval mode and moved to ``device``, and is left that way.

    Returns:
        Tensor of shape [num_layers, max_heads]. Layers with fewer heads than
        ``max_heads`` (e.g. already pruned layers) are right-padded with -inf
        so the rows can be stacked.

    Raises:
        RuntimeError: the model returned no attention maps.
    """
    model.eval()
    model.to(device)

    batch = next(iter(dataloader))
    inputs, _ = _unwrap_batch_for_hf(batch, device)

    outputs = model(**inputs, output_attentions=True)

    # Each attention map is indexed [batch, head, query, key]; reduce the key
    # axis with max, then average over batch and query to get one value per head.
    per_layer = [att.amax(dim=-1).mean(dim=(0, 2)) for att in outputs.attentions]
    if not per_layer:
        raise RuntimeError("Model returned no attention maps.")

    max_heads = max(row.numel() for row in per_layer)
    table = per_layer[0].new_full((len(per_layer), max_heads), -float("inf"))
    for li, row in enumerate(per_layer):
        table[li, : row.numel()] = row
    return table  # [L, max_heads]



def _chunk_indices_for_heads(keep_positions: torch.Tensor, head_dim: int) -> torch.Tensor:
    """Expand head positions into the flat column indices those heads occupy.

    Head ``h`` owns columns ``h * head_dim`` .. ``(h + 1) * head_dim - 1`` of the
    concatenated-heads axis. The result lists those columns head by head, in
    the order given by ``keep_positions``, as a LongTensor on the same device.
    """
    within_head = torch.arange(head_dim, device=keep_positions.device)
    head_starts = keep_positions.to(torch.long).reshape(-1, 1) * head_dim
    # Broadcasting [num_heads, 1] + [head_dim] gives one row of columns per head.
    return (head_starts + within_head).reshape(-1)


def _slice_linear_out_features(linear: nn.Linear, keep_rows: torch.Tensor) -> nn.Linear:
    """Build a new ``nn.Linear`` that keeps only the selected output rows.

    The input width is unchanged. Weight rows (and bias entries, when the
    source layer has a bias) are copied in the order given by ``keep_rows``;
    the source layer is not modified.
    """
    src_weight = linear.weight.data  # [out, in]
    has_bias = linear.bias is not None

    sliced = nn.Linear(
        src_weight.shape[1],
        keep_rows.numel(),
        bias=has_bias,
        device=src_weight.device,
    )
    sliced.weight.data = src_weight[keep_rows].clone()
    if has_bias:
        sliced.bias.data = linear.bias.data[keep_rows].clone()
    return sliced


def _slice_linear_in_features(linear: nn.Linear, keep_cols: torch.Tensor) -> nn.Linear:
    """Build a new ``nn.Linear`` that keeps only the selected input columns.

    The output width and the bias (if any) are carried over unchanged; weight
    columns are copied in the order given by ``keep_cols``. The source layer
    is not modified.
    """
    src_weight = linear.weight.data  # [out, in]
    has_bias = linear.bias is not None

    sliced = nn.Linear(
        keep_cols.numel(),
        src_weight.shape[0],
        bias=has_bias,
        device=src_weight.device,
    )
    sliced.weight.data = src_weight[:, keep_cols].clone()
    if has_bias:
        sliced.bias.data = linear.bias.data.clone()
    return sliced


def _bert_layer_num_heads(layer: nn.Module) -> int:
    """Return the head count recorded on the layer's self-attention module."""
    self_attention = layer.attention.self
    return self_attention.num_attention_heads


def _bert_head_dim(layer: nn.Module) -> int:
    """Return the per-head width recorded on the layer's self-attention module."""
    self_attention = layer.attention.self
    return self_attention.attention_head_size


def _bert_hidden_size(layer: nn.Module) -> int:
    """Return the input width of the layer's query projection."""
    query_proj = layer.attention.self.query
    return query_proj.in_features


def _get_qkv(linear_self: nn.Module) -> Tuple[nn.Linear, nn.Linear, nn.Linear]:
    """Return the (query, key, value) projections of a self-attention module."""
    return linear_self.query, linear_self.key, linear_self.value


def _score_heads_in_layer(
    layer: nn.Module,
    strategy: str,
    model: Optional[nn.Module],
    dataloader: Optional[DataLoader],
    device: Optional[torch.device],
    compute_grads: bool,
    use_bias: bool = False,
) -> torch.Tensor:
    """
    Return per-head scores [num_active_heads] for a single BertLayer.
    Strategies: magnitude | gradient | gradient_squared | attention_scores

    - magnitude: mean |W| per head, averaged over the query/key/value
      projections.
    - gradient / gradient_squared: mean |W * dL/dW| (resp. (W * dL/dW)^2) per
      head, averaged over query/key/value. Gradients are recomputed through
      ``get_grads_nlp`` when ``compute_grads`` is set or any needed gradient
      is missing.
    - attention_scores: the row for this layer (``layer._idx``) from
      ``compute_attention_head_scores_bert``, trimmed to this layer's heads.

    ``use_bias`` adds the matching per-head bias terms for the magnitude and
    gradient strategies; the score is then the average over all six terms.
    It is ignored when a projection has no bias.

    The returned scores are detached from autograd.

    Raises:
        AssertionError: unknown strategy, or missing model/dataloader/device
            for a strategy that needs them.
        RuntimeError: a gradient strategy was requested but a needed gradient
            is still missing after the backward pass.
    """
    assert strategy in ["magnitude", "gradient", "gradient_squared", "attention_scores"]
    sa = layer.attention.self
    projections = (sa.query, sa.key, sa.value)
    num_heads = sa.num_attention_heads
    head_dim = sa.attention_head_size

    if strategy == "attention_scores":
        assert model is not None and dataloader is not None and device is not None, \
            "attention_scores requires model, dataloader, device"
        all_scores = compute_attention_head_scores_bert(model, dataloader, device)
        layer_idx = layer._idx
        # Trim to this layer's heads in case the table is padded to a wider layer.
        return all_scores[layer_idx, :num_heads].to(device)

    include_bias = use_bias and all(p.bias is not None for p in projections)

    def per_head_mean(values: torch.Tensor) -> torch.Tensor:
        # Rows are grouped head by head: [num_heads * head_dim, ...].
        return values.reshape(num_heads, head_dim, -1).mean(dim=(1, 2))

    def average(terms: List[torch.Tensor]) -> torch.Tensor:
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total / len(terms)

    if strategy == "magnitude":
        terms = [per_head_mean(p.weight.detach().abs()) for p in projections]
        if include_bias:
            terms.extend(per_head_mean(p.bias.detach().abs()) for p in projections)
        return average(terms)

    assert dataloader is not None and device is not None and model is not None, \
        "Gradient-based pruning requires dataloader and device."

    params = [p.weight for p in projections]
    if include_bias:
        params.extend(p.bias for p in projections)

    if compute_grads or any(p.grad is None for p in params):
        get_grads_nlp(model, dataloader, device)
    if any(p.grad is None for p in params):
        raise RuntimeError(
            "Gradient-based scoring needs gradients on the query/key/value "
            "parameters, but some are still missing after the backward pass."
        )

    def saliency(value: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        product = value * grad
        return product.abs() if strategy == "gradient" else product ** 2

    return average([per_head_mean(saliency(p.detach(), p.grad)) for p in params])


# =========================
# Selection helpers (global/layer)
# =========================

def _get_layer_heads_pruning(
    heads_scores: torch.Tensor,
    original_num_heads: int,
    num_heads_to_prune: int,
    stochastic: bool,
    seed: Optional[int],
    random_select: bool = False,
) -> torch.Tensor:
    """Choose which head positions of one layer to keep.

    Returns ``original_num_heads - num_heads_to_prune`` positions:
      - all positions in order when nothing is pruned;
      - a seeded random subset when ``random_select`` is set;
      - otherwise the best-scoring positions (top-k, or Bernoulli sampling
        when ``stochastic`` is set).
    """
    target_device = heads_scores.device
    num_to_keep = original_num_heads - num_heads_to_prune

    if num_to_keep == original_num_heads:
        return torch.arange(original_num_heads, device=target_device)

    if not random_select:
        chosen = _topk_or_sample(heads_scores, k=num_to_keep, stochastic=stochastic, seed=seed)
        assert chosen.numel() == num_to_keep
        return chosen

    assert seed is not None, "Random pruning requires a seed"
    rng = torch.Generator(device=target_device).manual_seed(seed)
    shuffled = torch.randperm(original_num_heads, generator=rng, device=target_device)
    return shuffled[:num_to_keep]


def _get_global_heads_pruning(
    model_heads_scores: torch.Tensor,   # [L, max_heads], padded with -inf
    per_layer_heads: List[int],
    per_layer_prune: int,
    min_heads: int,
    stochastic: bool,
    seed: Optional[int],
) -> Dict[int, torch.Tensor]:
    """
    Global pick across all layers; enforce >= min_heads kept per layer.

    All head scores are flattened into one vector and the heads to keep are
    chosen from it in a single selection. The number kept is
    ``sum(per_layer_heads) - per_layer_prune * L``, raised if necessary to the
    sum of the per-layer floors ``min(min_heads, heads_in_layer)``.

    If a layer ends up below its floor, that layer's best ``min_heads`` heads
    are forced in (score set to 1e30), its other heads are forced out (score
    set to -inf), and the selection is redone from scratch. This repeats until
    every layer meets its floor.

    Returns:
        Dict mapping layer index -> tensor of kept head positions within that
        layer (column indices into ``model_heads_scores``).

    Raises:
        AssertionError: the number of kept heads differs from the target.
    """
    L = len(per_layer_heads)
    Hmax = model_heads_scores.shape[1]
    # Flat index of (layer li, head h) is li * Hmax + h.
    flat = model_heads_scores.flatten()
    total_heads = sum(per_layer_heads)
    global_keep = total_heads - per_layer_prune * L
    # Never keep fewer heads than the per-layer floors require in total.
    min_required = sum(min(min_heads, h) for h in per_layer_heads)
    global_keep = max(global_keep, min_required)

    layer_kept: Dict[int, torch.Tensor] = {i: torch.tensor([], device=model_heads_scores.device, dtype=torch.long) for i in range(L)}
    # Work on a copy: the forcing below overwrites scores in place.
    scores = flat.clone()
    while True:
        idx = _topk_or_sample(scores, k=global_keep, stochastic=stochastic, seed=seed)
        valid = True
        for li in range(L):
            # Picks that fall into this layer's slice of the flat vector.
            start, end = li * Hmax, (li + 1) * Hmax
            mask = (idx >= start) & (idx < end)
            kept = idx[mask] - start
            # Drop picks that landed on padding columns beyond the layer's real heads.
            # NOTE(review): dropped picks are not replaced, so the final count check
            # would fail if padding is ever selected — confirm this cannot happen.
            kept = kept[kept < per_layer_heads[li]]
            if kept.numel() < min(min_heads, per_layer_heads[li]):
                valid = False
                # Force this layer's top heads in and its remaining heads out,
                # then restart the whole selection with the edited scores.
                # NOTE(review): the forced-out heads appear unable to be re-selected
                # on merit afterwards — confirm that is intended.
                local = torch.topk(model_heads_scores[li, :per_layer_heads[li]], k=min(min_heads, per_layer_heads[li])).indices
                abs_idx = start + local
                all_local = torch.arange(start, start + per_layer_heads[li], device=scores.device)
                to_neg = all_local[~torch.isin(all_local, abs_idx)]
                scores[abs_idx] = 1e30
                scores[to_neg] = -float("inf")
                break
            layer_kept[li] = kept
        if valid:
            break
    # Sanity check: the per-layer picks must add up to the global target.
    kept_total = sum(int(v.numel()) for v in layer_kept.values())
    assert kept_total == global_keep, f"Kept {kept_total} vs expected {global_keep}"
    return layer_kept


# =========================
# Structural changes in a layer
# =========================

def _apply_head_pruning_in_layer(
    layer: nn.Module,
    keep_positions: torch.Tensor,
    current_active_heads: torch.Tensor,
) -> torch.Tensor:
    """Physically remove attention heads from one BertLayer, in place.

    The query/key/value projections keep only the output rows of the selected
    heads, and the attention output projection keeps only the matching input
    columns. ``num_attention_heads`` and ``all_head_size`` on the
    self-attention module are updated to the new sizes.

    Kept heads always stay in ascending position order, whatever order
    ``keep_positions`` is given in. Reordering heads does not change the
    layer's output, and keeping the order stable means a position -> original
    ID map that starts out ascending stays ascending across pruning rounds.

    Args:
        layer: BertLayer exposing ``attention.self`` and ``attention.output``.
        keep_positions: Positions (0 .. current_heads - 1) of the heads to keep.
        current_active_heads: Original head ID for each current position.

    Returns:
        Original IDs of the kept heads, in their new physical order. The same
        tensor is stored as ``layer.attention.self.active_heads``.

    Raises:
        ValueError: ``current_active_heads`` does not have one entry per
            current head.
    """
    sa = layer.attention.self
    so = layer.attention.output

    curr_heads = sa.num_attention_heads
    head_dim = sa.attention_head_size
    hidden = _bert_hidden_size(layer)

    if current_active_heads.numel() != curr_heads:
        raise ValueError(
            f"current_active_heads has {current_active_heads.numel()} entries "
            f"but the layer has {curr_heads} heads."
        )

    # Selection helpers return positions in score or random order; sort them
    # so the physical head order is preserved.
    keep_positions = torch.sort(keep_positions.reshape(-1)).values
    device = keep_positions.device
    new_heads = int(keep_positions.numel())
    new_all = new_heads * head_dim

    # --- Slice Q/K/V (rows) ---
    def slice_qkv(lin: nn.Linear) -> nn.Linear:
        W = lin.weight.data
        B = lin.bias.data if lin.bias is not None else None
        new = nn.Linear(hidden, new_all, bias=(B is not None), device=W.device)
        # View rows as [head, head_dim, hidden] so whole heads can be selected.
        kept_w = W.view(curr_heads, head_dim, hidden)[keep_positions]
        new.weight.data = kept_w.reshape(new_all, hidden).contiguous()
        if B is not None:
            kept_b = B.view(curr_heads, head_dim)[keep_positions]
            new.bias.data = kept_b.reshape(new_all).contiguous()
        return new

    q, k, v = _get_qkv(sa)
    sa.query = slice_qkv(q)
    sa.key = slice_qkv(k)
    sa.value = slice_qkv(v)

    # --- Slice the output projection (input columns of the kept heads) ---
    keep_cols = _chunk_indices_for_heads(keep_positions, head_dim)
    so.dense = _slice_linear_in_features(so.dense, keep_cols)

    sa.num_attention_heads = new_heads
    sa.all_head_size = new_all
    kept_original_ids = current_active_heads[keep_positions].detach().clone().to(device)
    sa.active_heads = kept_original_ids  # for get_active_heads()

    return kept_original_ids



def structurally_prune_attention_heads_bert(
    model: nn.Module,
    num_heads_to_prune: Union[int, Dict[int, List[int]]],
    strategy: str,
    context: str = "layer",
    min_heads: int = 4,
    current_active_heads: Optional[Dict[int, torch.Tensor]] = None,
    dataloader: Optional[DataLoader] = None,
    device: Optional[torch.device] = None,
    verbose: bool = False,
    stochastic: bool = False,
    seed: Optional[int] = None,
    use_bias: bool = False,
) -> nn.Module:
    """
    Structurally prune BERT attention heads by physically slicing Q/K/V and the output projection.

    The input model is not modified; pruning is applied to a deep copy.

    Parameters
    ----------
    model : nn.Module
        A HuggingFace BERT-like model with .bert.encoder.layer[*].attention.self{query,key,value}
        and .bert.encoder.layer[*].attention.output.dense.
    num_heads_to_prune : int | dict
        - If int (for non-predefined strategies): prune this many heads per layer.
        - If dict and strategy == "predefined":
            {layer_idx: [ORIGINAL head IDs to remove]}.
            Example: {7: [0,1,2,3,10], 11: [2,4,5,10], ...}
            IDs that are already pruned are ignored.
        - If dict and strategy != "predefined":
            {n_prune: [layer_idx, ...]} meaning prune `n_prune` heads at those layers.
    strategy : {"magnitude","gradient","gradient_squared","attention_scores","random","predefined"}
        Scoring/selection strategy. "predefined" removes exactly the provided original head IDs.
        With "random" in the "global" / "layer-global" contexts the global
        allocation is driven by seeded uniform random scores.
    context : {"layer","global","layer-global"}, default="layer"
        How to allocate pruning budgets. "predefined" requires "layer".
    min_heads : int, default=4
        Minimum heads to keep per layer (used by global / layer-global).
    current_active_heads : Optional[Dict[int, Tensor]]
        Mapping layer -> tensor of ORIGINAL head IDs that are currently active,
        listed in the layer's physical head order (entry p is the original ID
        of the head at position p). If None, or for layers missing from the
        mapping, assumed to be range(num_heads).
    dataloader, device : needed for gradient/attention_scores strategies.
    verbose : bool
        Print layer-wise info.
    stochastic : bool
        Use Bernoulli sampling (softmax-proportional) instead of strict top-k.
    seed : Optional[int]
        RNG seed for stochastic/random strategies.
    use_bias : bool
        Include q/k/v biases in magnitude/gradient scoring.

    Returns
    -------
    pruned : nn.Module
        Deep-copied model with heads structurally removed.

    Raises
    ------
    AssertionError
        Unknown strategy/context, missing dataloader/device for a strategy that
        needs them, a non-int budget in a global context, or a per-layer budget
        that would remove every head.
    TypeError
        ``num_heads_to_prune`` is neither int nor dict, or is not a dict while
        strategy == "predefined".
    ValueError
        ``current_active_heads`` does not match a layer's head count, a
        predefined head ID is out of range, or predefined pruning would remove
        every head of a layer.
    """
    allowed_strategies = ["magnitude", "gradient", "gradient_squared", "attention_scores", "random", "predefined"]
    allowed_contexts = ["layer", "global", "layer-global"]
    assert strategy in allowed_strategies, f"Unknown strategy: {strategy}"
    assert context in allowed_contexts, f"Unknown context: {context}"

    if strategy == "predefined":
        assert context == "layer", "strategy='predefined' only supports context='layer'."
        # An int budget carries no head IDs; fail early with a clear message.
        if not isinstance(num_heads_to_prune, dict):
            raise TypeError(
                "strategy='predefined' requires num_heads_to_prune to be a dict "
                "{layer_idx: [original head IDs to remove]}."
            )

    if strategy in ["gradient", "gradient_squared", "attention_scores"]:
        assert dataloader is not None and device is not None, "This strategy requires dataloader and device."
    if strategy in ["random", "predefined"]:
        if dataloader is not None:
            warnings.warn(f"{strategy} does not need dataloader; ignoring.", UserWarning)
        if device is not None and strategy == "random":
            warnings.warn(f"{strategy} does not need device; ignoring.", UserWarning)

    device = device or next(model.parameters()).device
    pruned = copy.deepcopy(model).to(device)
    layers: List[nn.Module] = list(pruned.bert.encoder.layer)
    num_layers = len(layers)

    # _score_heads_in_layer looks up its row of the attention-score table via _idx.
    for i, lyr in enumerate(layers):
        lyr._idx = i

    # Head counts of the model as passed in (it may already have been pruned).
    original_heads = [_bert_layer_num_heads(lyr) for lyr in layers]
    Hmax = max(original_heads)

    # --- Position -> original-ID map per layer ---
    if current_active_heads is None:
        current_active_heads = {i: torch.arange(original_heads[i], device=device) for i in range(num_layers)}
    else:
        provided = {int(k): v for k, v in current_active_heads.items()}
        current_active_heads = {}
        for i in range(num_layers):
            if i in provided:
                ids = torch.as_tensor(provided[i], dtype=torch.long, device=device)
            else:
                ids = torch.arange(original_heads[i], device=device)
            # A map of the wrong length would silently mislabel or drop heads.
            if ids.numel() != original_heads[i]:
                raise ValueError(
                    f"Layer {i}: current_active_heads has {ids.numel()} entries "
                    f"but the layer has {original_heads[i]} heads."
                )
            current_active_heads[i] = ids

    # --- Normalise the budget into a per-layer mapping ---
    if isinstance(num_heads_to_prune, int):
        layer_to_budget: Dict[int, Union[int, torch.Tensor]] = {
            i: num_heads_to_prune for i in range(num_layers)
        }
    elif isinstance(num_heads_to_prune, dict):
        if strategy == "predefined":
            layer_to_budget = {
                int(li): torch.tensor(sorted(set(map(int, prune_ids))), device=device, dtype=torch.long)
                for li, prune_ids in num_heads_to_prune.items()
            }
        else:
            tmp = {li: n for n, lay in num_heads_to_prune.items() for li in lay}
            layer_to_budget = {i: tmp.get(i, 0) for i in range(num_layers)}
    else:
        raise TypeError("num_heads_to_prune must be int or dict")

    def commit(li: int, layer: nn.Module, keep_positions: torch.Tensor, show_count: bool = False) -> None:
        # Slice the layer and record the surviving original IDs.
        new_map = _apply_head_pruning_in_layer(layer, keep_positions, current_active_heads[li])
        current_active_heads[li] = new_map
        if verbose:
            suffix = f" (count={new_map.numel()})" if show_count else ""
            print(f"[Layer {li}] keep heads {new_map.cpu().tolist()}{suffix}")

    def select_in_layer(li: int, layer: nn.Module, budget: int) -> torch.Tensor:
        # Score one layer on the current (partially pruned) model and pick survivors.
        curr_heads = current_active_heads[li].numel()
        if strategy != "random":
            scores = _score_heads_in_layer(
                layer, strategy, pruned, dataloader, device,
                compute_grads=True, use_bias=use_bias
            )
        else:
            scores = torch.ones(curr_heads, device=device)
        return _get_layer_heads_pruning(
            scores,
            original_num_heads=curr_heads,
            num_heads_to_prune=budget,
            stochastic=stochastic,
            # Random selection gets a distinct seed per layer.
            seed=(seed if strategy != "random" else (seed or 0) + 31 * (li + 1)),
            random_select=(strategy == "random"),
        )

    # ------------------------------------------------------------------
    # context == "layer": every layer is handled independently
    # ------------------------------------------------------------------
    if context == "layer":
        for li, layer in enumerate(layers):
            curr_active = current_active_heads[li]
            curr_heads = curr_active.numel()

            if strategy == "predefined":
                prune_ids_orig = layer_to_budget.get(li, None)
                if prune_ids_orig is None or prune_ids_orig.numel() == 0:
                    keep_positions = torch.arange(curr_heads, device=device)
                else:
                    # IDs are ORIGINAL IDs, so on an already-pruned layer they may
                    # exceed the current head count; bound them by the largest
                    # original ID still present instead.
                    id_bound = original_heads[li]
                    if curr_heads > 0:
                        id_bound = max(id_bound, int(curr_active.max().item()) + 1)
                    invalid = prune_ids_orig[(prune_ids_orig < 0) | (prune_ids_orig >= id_bound)]
                    if invalid.numel() > 0:
                        raise ValueError(
                            f"Layer {li}: invalid head indices to prune {invalid.tolist()} "
                            f"for layer with {id_bound} heads."
                        )
                    mask_keep = ~torch.isin(curr_active, prune_ids_orig)
                    keep_positions = torch.nonzero(mask_keep, as_tuple=False).squeeze(-1)
                    if keep_positions.numel() == 0:
                        raise ValueError(
                            f"Layer {li}: predefined pruning would remove all heads. "
                            f"Please leave at least one."
                        )
                keep_positions = keep_positions.to(device)
            else:
                budget = int(layer_to_budget[li])
                assert budget < curr_heads, f"Cannot prune {budget} >= {curr_heads} heads at layer {li}"
                keep_positions = select_in_layer(li, layer, budget)

            commit(li, layer, keep_positions, show_count=True)

        _attach_getters_heads(pruned)
        return pruned

    # ------------------------------------------------------------------
    # context in {"global", "layer-global"}: one selection across all layers
    # ------------------------------------------------------------------
    if context == "global":
        assert isinstance(num_heads_to_prune, int), "global context expects an int budget per layer"
    else:
        assert isinstance(num_heads_to_prune, int), "layer-global expects an int budget per layer"

    # [L, Hmax] score table; columns beyond a layer's head count stay at -inf.
    model_head_scores = torch.full((num_layers, Hmax), -float("inf"), device=device)
    if strategy == "random":
        # Without real scores the table would be all -inf and the global pick
        # arbitrary; seeded uniform scores make it an actual random choice.
        gen = torch.Generator(device=device).manual_seed(seed or 0)
        for li, n_heads in enumerate(original_heads):
            model_head_scores[li, :n_heads] = torch.rand(n_heads, generator=gen, device=device)
    else:
        for li, layer in enumerate(layers):
            sc = _score_heads_in_layer(
                layer, strategy, pruned, dataloader, device,
                compute_grads=False, use_bias=use_bias
            )
            model_head_scores[li, :sc.numel()] = sc.detach()

    keep_per_layer = _get_global_heads_pruning(
        model_heads_scores=model_head_scores,
        per_layer_heads=original_heads,
        per_layer_prune=num_heads_to_prune,
        min_heads=min_heads,
        # layer-global only derives budgets here, so it uses the strict top-k.
        stochastic=(stochastic if context == "global" else False),
        seed=seed,
    )

    if context == "global":
        # The global pick directly decides which heads survive.
        for li, layer in enumerate(layers):
            commit(li, layer, keep_per_layer[li].to(device))
    else:
        # layer-global: the global pick only sets how many heads each layer
        # loses; which ones is decided layer by layer with fresh scores.
        layer_budget_counts = {
            li: original_heads[li] - keep_per_layer[li].numel() for li in range(num_layers)
        }
        for li, layer in enumerate(layers):
            commit(li, layer, select_in_layer(li, layer, layer_budget_counts[li]))

    _attach_getters_heads(pruned)
    return pruned




def _attach_getters_heads(model: nn.Module):
    """Attach a ``get_active_heads()`` method to ``model``.

    ``model.get_active_heads()`` returns ``{layer_idx: tensor of ORIGINAL head
    IDs}`` for every layer in ``model.bert.encoder.layer``. IDs are listed in
    the layer's physical head order (entry p belongs to the head at position
    p), which is the layout ``current_active_heads`` expects when the model is
    pruned again. Layers without a recorded ``active_heads`` map report
    ``range(num_attention_heads)``.

    The getter is attached as a bound method rather than a closure over
    ``model``, so a deep copy of the model gets a getter bound to the copy
    instead of one that keeps reporting the source model.
    """
    import types

    def get_active_heads(self) -> Dict[int, torch.Tensor]:
        active = {}
        for i, layer in enumerate(self.bert.encoder.layer):
            sa = layer.attention.self
            if hasattr(sa, "active_heads"):
                # Return a copy so callers cannot alter the stored map.
                active[i] = sa.active_heads.detach().clone()
            else:
                active[i] = torch.arange(sa.num_attention_heads)
        return active

    model.get_active_heads = types.MethodType(get_active_heads, model)