import torch
from transformers import BertTokenizer, BertConfig, BertForMaskedLM

def _make_head_split_hook(layer_id, head_outputs):
    """Build a forward hook that records one layer's self-attention output per head.

    The returned hook stores ``head_outputs[layer_id]`` as a tensor of shape
    ``(B, n_heads, T, head_dim)``. It is obtained by splitting the last
    dimension of the hooked module's ``(B, T, C)`` context output into
    ``module.num_attention_heads`` chunks and moving the head axis forward.
    """

    def hook(module, inputs, output):
        # The self-attention module may return a tuple whose first element is
        # the context tensor; only that tensor is recorded.
        if isinstance(output, tuple):
            output = output[0]
        batch, seq_len, hidden = output.size()
        n_heads = module.num_attention_heads
        head_dim = hidden // n_heads
        # reshape (rather than view) still works if the context tensor is not
        # contiguous; for a contiguous tensor it returns the same view.
        head_outputs[layer_id] = output.reshape(batch, seq_len, n_heads, head_dim).permute(0, 2, 1, 3)

    return hook


def _sample_random_embeddings(model, tokenizer, n, vocab_size, device, num_samples):
    """Embed ``num_samples`` random ``[CLS] ... [SEP]`` sequences of length ``n``.

    Interior tokens are drawn uniformly from the vocabulary with the padding
    token removed. No position is ever padding, so a single all-ones attention
    mask is valid for every sample.

    Returns:
        ``(samples, attention_mask)``. ``samples`` is a list of embedding
        tensors (one per ``(1, n)`` id tensor), each a fresh leaf with
        ``requires_grad=True``. ``attention_mask`` is a ``(1, n)`` long tensor
        of ones, or ``None`` when ``num_samples`` is 0.
    """
    pad_id = tokenizer.pad_token_id
    samples = []
    for _ in range(num_samples):
        if pad_id is None:
            input_ids = torch.randint(0, vocab_size, (1, n), device=device)
        else:
            # Draw from vocab_size - 1 ids and shift every id >= pad_id up by
            # one: this is uniform over the vocabulary minus [PAD].
            input_ids = torch.randint(0, vocab_size - 1, (1, n), device=device)
            input_ids += (input_ids >= pad_id).long()
        input_ids[0, 0] = tokenizer.cls_token_id
        # With n == 1 this overwrites the [CLS] just written.
        input_ids[0, -1] = tokenizer.sep_token_id
        # No graph through the embedding weights is needed: the detached
        # result is turned into a new leaf tensor that requires grad.
        with torch.no_grad():
            embeddings = model.bert.embeddings(input_ids)
        samples.append(embeddings.detach().requires_grad_())
    attention_mask = torch.ones((1, n), dtype=torch.long, device=device) if samples else None
    return samples, attention_mask


def bert_generate_hooks_and_model(n, device, head_outputs, sampling="uniform", num_samples=10):
    """Load bert-base-uncased, hook its self-attention layers, and build input samples.

    A forward hook is registered on every encoder layer's self-attention
    module. Each later forward pass of the model writes that layer's context
    output, split per head as ``(B, n_heads, T, head_dim)``, into
    ``head_outputs[layer_id]``.

    Args:
        n: Length of every sample, including [CLS] and [SEP]. Assumed >= 2;
            with n == 1 the single token ends up as [SEP].
        device: Torch device for the model and the samples.
        head_outputs: Mutable container indexed by layer id (e.g. a dict or a
            pre-sized list) that the hooks fill in. Presumably read by the
            caller after running the model; confirm against callers.
        sampling: "uniform" or "model". BERT has no natural text generation,
            so "model" prints a warning and falls back to uniform sampling.
        num_samples: Number of random samples to draw.

    Returns:
        ``(model, samples, attention_mask, params)``:
        - ``samples`` is a list of embedding tensors with ``requires_grad=True``.
        - ``attention_mask`` is a ``(1, n)`` all-ones long tensor valid for
          every sample, or ``None`` when ``num_samples`` is 0.
        - ``params`` is ``(n_heads, n_layers, vocab_size)``.

    Raises:
        ValueError: If ``sampling`` is not "uniform" or "model".
    """
    # Reject bad arguments before paying for loading the pretrained model.
    if sampling not in ("uniform", "model"):
        raise ValueError("Invalid sampling method. Use 'uniform' or 'model' (falls back to uniform for BERT).")

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True)
    model = BertForMaskedLM.from_pretrained("bert-base-uncased", config=config).to(device)
    model.eval()

    for layer_id, block in enumerate(model.bert.encoder.layer):
        block.attention.self.register_forward_hook(_make_head_split_hook(layer_id, head_outputs))

    n_heads = model.config.num_attention_heads
    n_layers = model.config.num_hidden_layers
    vocab_size = model.config.vocab_size
    params = (n_heads, n_layers, vocab_size)

    if sampling == "model":
        # BERT is a masked LM, not an autoregressive generator, so there is no
        # model-driven sampling here; fall back to uniform random tokens.
        print("Warning: BERT doesn't support natural text generation. Using uniform sampling instead.")
    samples, attention_mask = _sample_random_embeddings(model, tokenizer, n, vocab_size, device, num_samples)

    return model, samples, attention_mask, params