from transformers import PreTrainedModel
from typing import Optional, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from math import sqrt

# Toy demo of a product-key style top-k sparse encoding.
#
# The encoder emits 2 * num_keys activations per sample, viewed as two halves
# of num_keys "sub-key" scores each. The top `topk1` entries of each half are
# combined via a Cartesian product (score = sum of the two sub-key scores,
# dictionary index = idx_x * num_keys + idx_y). Finally the best `topk2 * B`
# candidates across the *whole flattened batch* are kept and scattered into a
# dense (B, dict_size) activation matrix.
B, H, dict_size = 4, 16, 64
topk1 = 3  # sub-key candidates kept per half
topk2 = 3  # final budget is topk2 * B entries shared across the batch
num_keys = int(sqrt(dict_size))  # dictionary is addressed as num_keys x num_keys

# Fall back to CPU so the script also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(B, H).to(device)
W_enc = torch.randn(H, 2 * num_keys).to(device)
b_enc = torch.zeros(2 * num_keys).to(device)

acts = F.relu(x @ W_enc + b_enc).view(B, 2, -1)  # B x 2 x num_keys
scores, indices = acts.topk(
    k=topk1, dim=-1, sorted=False
)  # B x 2 x topk1

scores_x, scores_y = scores[..., 0, :], scores[..., 1, :]
# outer sum of scores (k1 + k2) for all (k1, k2) pairs from Cartesian prod
all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).view(-1)

indices_x, indices_y = indices[..., 0, :], indices[..., 1, :]
# flat dictionary index of every Cartesian-product pair, per sample
pair_indices = (
    (indices_x * num_keys).unsqueeze(-1) + indices_y.unsqueeze(-2)
).view(B, -1)
# Shift each sample's indices into its own dict_size-wide slot of the flattened
# (B * dict_size) output. The offsets must stay integer (int64): a float
# arange would promote the indices to float and scatter() rejects them.
batch_offsets = torch.arange(B, device=x.device).unsqueeze(-1) * dict_size
all_indices = (pair_indices + batch_offsets).view(-1)

# top-k over the flattened batch to choose the final topk2 * B candidates
scores, pk_indices = all_scores.topk(k=topk2 * B, dim=-1, sorted=False)
indices = all_indices.gather(dim=-1, index=pk_indices)

# The destination buffer must live on the same device and share the dtype of
# the scattered values.
acts_topk = (
    torch.zeros(B * dict_size, device=scores.device, dtype=scores.dtype)
    .scatter(-1, indices, scores)
    .reshape(B, dict_size)
)