import torch 
from torch.utils.data import DataLoader
import torch.nn.functional as F

# --- Gradient for Taylor ---
def get_grads(
        model: torch.nn.Module, 
        dataloader: DataLoader, 
        device: torch.device,
        text_logits = None
        ) -> None:
    """
    Populate ``.grad`` on the model's parameters from one batch.

    The model is switched to training mode and moved to ``device``; the first
    batch yielded by ``dataloader`` is pushed through a forward pass, scored
    with Cross-Entropy, and back-propagated. Only that first batch is used.
    If the dataloader yields nothing, no gradients are computed.

    Args:
        model (torch.nn.Module): Model whose gradients are to be computed.
            When ``text_logits`` is given it must expose
            ``encode_image(images, normalize=True)`` and ``logit_scale``.
        dataloader (DataLoader): Iterable yielding ``(images, labels)`` pairs.
        device (torch.device): Device to run the forward and backward pass on.
        text_logits (torch.Tensor, optional): Matrix that the scaled image
            features are right-multiplied with to obtain class logits
            (presumably the text-side embeddings of a CLIP model — confirm
            against the caller). When ``None`` the logits are ``model(images)``.

    Returns:
        None. Results are left in the parameters' ``.grad`` attributes.
    """
    model.train()
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()

    # Fetch only the first batch; a private sentinel marks an empty loader.
    nothing = object()
    first_batch = next(iter(dataloader), nothing)
    if first_batch is nothing:
        return

    images, labels = first_batch
    images = images.to(device)
    labels = labels.to(device)

    if text_logits is None:
        logits = model(images)
    else:
        image_features = model.encode_image(images, normalize=True)
        # `*` and `@` share precedence and associate left-to-right, so the
        # features are scaled first and then projected.
        logits = model.logit_scale.exp() * image_features @ text_logits

    batch_loss = criterion(logits, labels)
    # Drop any stale gradients so .grad reflects this batch only.
    model.zero_grad()
    batch_loss.backward(retain_graph=False)

def compute_attention_head_scores(model, images, device):
    """
    Score every attention head of every encoder layer on one batch of images.

    The model is put in eval mode, moved to ``device`` and run without
    gradient tracking. For each layer in ``model.encoder.layers`` the query
    and key projections are computed by hand from the layer's
    ``self_attention.in_proj_weight`` / ``in_proj_bias``; a head's score is
    the mean absolute value of its scaled, pre-softmax ``q @ k^T`` logits,
    averaged over the batch and both token axes. The hidden state is then
    advanced through the layer's full residual path (attention, then MLP) so
    that the next layer sees its regular input.

    Args:
        model: Module exposing ``_process_input``, ``class_token``,
            ``encoder.pos_embedding`` and ``encoder.layers``, where each layer
            has ``ln_1``, ``self_attention``, ``ln_2`` and ``mlp``. This looks
            like the torchvision VisionTransformer layout — confirm against
            the caller.
        images: Batch of inputs accepted by ``model._process_input``.
        device: Device on which the computation is carried out.

    Returns:
        torch.Tensor: CPU tensor of shape ``[num_layers, num_heads]``.
    """
    model.eval()
    model.to(device)

    per_layer_scores = []
    with torch.no_grad():
        patch_tokens = model._process_input(images.to(device))  # [B, patch, D]
        batch_size, _, _ = patch_tokens.shape

        # Prepend the class token, then add the positional embedding.
        cls_tokens = model.class_token.expand(batch_size, -1, -1)
        hidden = torch.cat([cls_tokens, patch_tokens], dim=1)
        hidden = hidden + model.encoder.pos_embedding

        for block in model.encoder.layers:
            normed = block.ln_1(hidden)
            attention = block.self_attention
            num_heads, head_dim = attention.num_heads, attention.head_dim
            batch_size, seq_len, _ = normed.shape

            # Fused projection: the last axis is laid out as [q | k | v].
            projected = F.linear(normed, attention.in_proj_weight, attention.in_proj_bias)
            projected = projected.view(batch_size, seq_len, 3, num_heads, head_dim)
            query, key, _ = projected.unbind(dim=2)  # each [B, T, heads, head_dim]
            query = query.transpose(1, 2)  # [B, heads, T, head_dim]
            key = key.transpose(1, 2)  # [B, heads, T, head_dim]

            # Scaled dot-product logits, before any softmax: [B, heads, T, T]
            raw_attention = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
            per_layer_scores.append(raw_attention.abs().mean(dim=(0, 2, 3)).cpu())

            # Advance through the block exactly as a forward pass would:
            # attention residual first, then the MLP residual.
            attended, _ = attention(normed, normed, normed)  # [B, T, D]
            hidden = hidden + attended
            hidden = hidden + block.mlp(block.ln_2(hidden))

        return torch.stack(per_layer_scores)  # [num_layers, num_heads]
    
def bernoulli_score_sampling(
        scores, 
        n, 
        seed,
        max_tries,
        temperature = 1e-6
        ):
    """
    Select up to `n` distinct indices from `scores` by repeated Bernoulli sampling.

    On every try the scores of the indices that are still unselected are turned
    into probabilities with `p = softmax(remaining_scores / temperature)`, and
    each remaining index is drawn independently with its own probability `p_i`.
    Drawn indices join the selection and leave the pool. Because the softmax
    probabilities sum to 1, roughly one index is drawn per try; a very small
    temperature concentrates almost all of the mass on the largest remaining
    score, so the most important heads/neurons are picked first.

    Sampling stops once `n` indices are selected, the pool is empty, the
    remaining scores sum to a non-positive value, or `max_tries` tries were
    made. Any shortfall is then filled deterministically with the
    highest-scoring indices still in the pool.

    Args:
        scores (torch.Tensor): 1-D tensor of importance scores.
        n (int): Number of indices to select.
        seed (int): Seed passed to `torch.manual_seed`. NOTE: this reseeds the
            global torch RNG as a side effect.
        max_tries (int): Maximum number of Bernoulli sampling rounds.
        temperature (float): Softmax temperature applied to the scores.

    Returns:
        torch.Tensor: 1-D int64 tensor on `scores.device` holding the selected
        indices in no particular order. It has `min(n, len(scores))` entries
        (none when `n <= 0`); an empty result is still an int64 tensor so it
        can be used for indexing.
    """
    device = scores.device
    torch.manual_seed(seed)

    selected_set = set()
    remaining_scores = scores
    remaining_indices = torch.arange(len(scores), device=device)

    for _ in range(max_tries):
        if (
            len(selected_set) >= n
            or remaining_scores.numel() == 0
            or remaining_scores.sum() <= 0
        ):
            break  # done, or nothing left worth sampling from

        probs = torch.softmax(remaining_scores / temperature, dim=0)
        sampled = torch.bernoulli(probs).bool()

        # The pool never contains an already-selected index, so every drawn
        # index is new.
        for idx in remaining_indices[sampled].tolist():
            selected_set.add(idx)
            if len(selected_set) >= n:
                return torch.tensor(list(selected_set), dtype=torch.long, device=device)

        # The indices to drop from the pool are exactly the ones just drawn.
        keep_mask = ~sampled
        remaining_scores = remaining_scores[keep_mask]
        remaining_indices = remaining_indices[keep_mask]

    # Fallback: fill the shortfall with the top-scoring unselected indices.
    shortfall = n - len(selected_set)
    if shortfall > 0 and remaining_scores.numel() > 0:
        k = min(shortfall, remaining_scores.numel())
        top_positions = torch.topk(remaining_scores, k=k).indices
        for idx in remaining_indices[top_positions].tolist():
            selected_set.add(idx)

    # An explicit dtype keeps an empty selection as int64 instead of the
    # float32 that torch.tensor([]) would infer.
    return torch.tensor(list(selected_set), dtype=torch.long, device=device)