

from sacrebleu import metrics
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
from collections import Counter

from fairseq.pdb import distributed_set_trace
from pdb import set_trace as bp
from fairseq.utils import relu_squared
from fairseq.modules import gelu
from fairseq.modules.fused_bias_gelu import fused_bias_gelu, has_fused_bias_gelu, has_megatron_fused_kernels, load_megatron_fused_kernel
from fairseq.modules.fused_bias_relu_squared import fused_bias_relu_squared
from fairseq import utils

from time import time
import faiss
from typing import Optional, Union, List
import scann
import logging
logger = logging.getLogger(__name__)


def _linear(x, weight, bias=None):
    """Apply the affine map ``x @ weight.T (+ bias)`` via ``F.linear``.

    ``bias`` may be omitted, in which case only the matrix product is computed.
    """
    projected = F.linear(x, weight, bias)
    return projected

def _fc1(
    x,
    fc1,
    activation_fn,
    activation_dropout_module,
):
    """Run the first FFN projection, its activation and the activation dropout.

    When the activation is ``gelu`` (and the fused kernel is available) or
    ``relu_squared``, the bias add is folded into the fused activation kernel;
    otherwise a regular biased linear followed by ``activation_fn`` is used.

    x: input to project.
    fc1: module exposing ``weight`` and ``bias``.
    activation_fn: activation applied to the projection.
    activation_dropout_module: callable applied to the activated output.
    """
    # Note: at eval the fused gelu path has been reported to fail under TorchScript.
    use_fused_gelu = has_fused_bias_gelu and activation_fn == gelu
    if use_fused_gelu or activation_fn == relu_squared:
        fused_activation = fused_bias_gelu if use_fused_gelu else fused_bias_relu_squared
        # bias is applied inside the fused kernel, so project without it
        hidden = fused_activation(_linear(x, fc1.weight), fc1.bias)
    else:
        hidden = activation_fn(_linear(x, fc1.weight, fc1.bias))
    return activation_dropout_module(hidden)

def get_phi(xb):
    """Return the largest squared L2 norm among the rows of ``xb``."""
    squared_row_norms = (xb ** 2).sum(1)
    return squared_row_norms.max()

def get_pos(x):
    """Keep strictly positive entries of ``x`` and zero out everything else."""
    is_positive = x > 0
    return x * is_positive

def augment_xb(xb, phi=None):
    """Append the column ``sqrt(phi - ||row||^2)`` to the 2-D array ``xb``.

    ``phi`` defaults to the largest squared row norm of ``xb``, which makes
    every augmented row have squared norm ``phi``.
    """
    squared_norms = (xb ** 2).sum(1)
    phi = squared_norms.max() if phi is None else phi
    extra_column = np.sqrt(phi - squared_norms).reshape(-1, 1)
    return np.hstack((xb, extra_column))

def augment_xb_torch(xb, phi=None):
    """Torch counterpart of ``augment_xb``.

    Appends the column ``sqrt(phi - ||row||^2)`` to the 2-D tensor ``xb``;
    ``phi`` defaults to the largest squared row norm.
    """
    squared_norms = (xb ** 2).sum(1)
    phi = squared_norms.max() if phi is None else phi
    extra_column = torch.sqrt(phi - squared_norms).reshape(-1, 1)
    return torch.cat((xb, extra_column), dim=-1)

def augment_xq(xq):
    """Append a zero column (same dtype as ``xq``) to the 2-D array ``xq``."""
    zero_column = np.zeros((len(xq), 1), dtype=xq.dtype)
    return np.hstack((xq, zero_column))

def augment_xq_torch(xb):
    """Append a zero column to the 2-D tensor ``xb``.

    The zero column is allocated with the same dtype AND device as ``xb``.
    Allocating it on the default (CPU) device makes ``torch.cat`` fail as soon
    as ``xb`` lives on an accelerator.

    xb (Tensor): tensor of shape ``(n, d)``.
    Returns a tensor of shape ``(n, d + 1)`` whose last column is zero.
    """
    # new_zeros inherits both dtype and device from xb
    zero_column = xb.new_zeros((len(xb), 1))
    return torch.cat((xb, zero_column), dim=-1)

def mask_from_topk_index(shape, topk_indices: torch.Tensor, device=torch.device("cpu"), pad_idx=-1):
    """
    Build a 0/1 mask of size `shape` that is 1 at the positions listed in
    `topk_indices` (along the last dimension) and 0 everywhere else.

    topk_indices may contain the padding index `pad_idx` (default -1), which
    marks "no result" and must not switch on any position.
    """
    mask = torch.zeros(shape, device=device)
    is_pad = topk_indices == pad_idx
    if not is_pad.any():
        # fast path: every index is a real one
        return mask.scatter_(dim=-1, index=topk_indices, value=1)

    last_slot = mask.size(-1) - 1
    # rows that genuinely selected the last slot (checked before padding is remapped)
    row_selects_last = (topk_indices == last_slot).any(dim=-1)
    # a raw -1 index triggers a CUDA device-side assertion, so padding is
    # temporarily redirected to the last slot (on a copy of the indices)
    safe_indices = topk_indices.masked_fill(is_pad, last_slot)
    # 1 = keep, 0 = zero-out; at this point padded rows wrongly keep the last slot
    mask.scatter_(dim=-1, index=safe_indices, value=1)
    # undo that for rows where the last slot came ONLY from padding
    row_only_padded = is_pad.any(dim=-1) & ~row_selects_last
    mask[row_only_padded, last_slot] = 0
    return mask


class FFN(nn.Module):
    """
    Feed-forward block of a Transformer layer, split into overridable stages
    so that top-k wrappers can hook in between them. The forward pass runs
    fc1 (+ activation + activation dropout) -> optional LayerNorm ->
    sparsity hook -> fc2 (+ dropout).

    fc1 (nn.Module): Linear transformation(model_dim -> ffn_dim).
    activation_fn (Function): activation function for the FFN's hidden states
    activation_dropout_module (nn.Module or nn.functional):
        activation dropout module for the FFN's hidden states
    fc2 (nn.Module): Linear transformation(ffn_dim -> model_dim).
    dropout_module (nn.Module or nn.functional):
        dropout module for the FFN's fc2's output
    ffn_ln (nn.Module or nn.functional):
        optional LayerNorm for FFN's hidden states
    k (int): k for topk selection.
    """
    def __init__(
        self,
        args,
        fc1: nn.Module,
        activation_fn,
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k=None,
        **kwargs
    ):
        super().__init__()
        self.args = args
        self.fc1 = fc1
        # fc1.weight is laid out as (ffn_dim, model_dim)
        ffn_dim, model_dim = fc1.weight.shape
        self.ffn_dim = ffn_dim
        self.model_dim = model_dim
        self.activation_fn = activation_fn
        self.activation_dropout_module = activation_dropout_module
        self.fc2 = fc2
        self.dropout_module = dropout_module
        self.ffn_ln = ffn_ln
        self.k = k
        # metadata of the most recent `forward()` call; reset on every call
        self.transient_metadata = {}

    def _fc1(self, x):
        """
        fc1 + activation + activation dropout.

        x: input flattened by `forward` to `(seq_len * batch, embed_dim)`
        """
        return _fc1(x, self.fc1, self.activation_fn, self.activation_dropout_module)

    def _ffn_ln(self, x):
        """Apply the optional hidden-state LayerNorm; identity when it is unset."""
        if self.ffn_ln is None:
            return x
        return self.ffn_ln(x)

    def _fc2(self, x):
        """fc2 projection followed by the output dropout."""
        projected = _linear(x, self.fc2.weight, self.fc2.bias)
        return self.dropout_module(projected)

    def _optionally_sparsity_and_record_metadata(self, x, x_shape, k, **kwargs):
        """
        Sparsity hook. The dense FFN leaves `x` untouched and only records the
        fraction of positive hidden units; the remaining statistics are set to
        None here.
        """
        positive_fraction = (x > 0).float().mean(-1)
        self.transient_metadata["positve_percent"] = self._calculate_mean(positive_fraction, *x_shape[:-1])
        for unused_key in (
            "topk_positive_weight_percent",
            "topk_recall",
            "gold_positive_weight_recovered_percent",
            "topk_search_time",
            "gold_topk_record_time",
        ):
            self.transient_metadata[unused_key] = None
        return x

    @staticmethod
    def _calculate_mean(x, tokens_per_sample, num_samples):
        """
        Average a per-token statistic: first over the tokens of each sample,
        then over samples.

        x: 1-D tensor of length `tokens_per_sample * num_samples`
        """
        assert len(x.shape) == 1
        assert x.shape[0] == tokens_per_sample * num_samples
        per_token = x.detach().reshape(tokens_per_sample, num_samples)
        per_sample_mean = per_token.transpose(0, 1).mean(dim=-1)
        return per_sample_mean.mean()

    def forward(self, x, k: Optional[int] = None, **kwargs):
        """
        x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
        k (int): optional per-call override of `self.k`
            (handy for trying different k at eval time)

        Wall-clock time of each stage is stored in `self.transient_metadata`.
        """
        cur_topk = k or self.k
        self.transient_metadata = {}  # start a fresh record for this pass
        orig_shape = x.shape

        tick = time()
        hidden = self._fc1(x.reshape(-1, x.size(-1)))
        hidden = self._ffn_ln(hidden)
        self.transient_metadata["fc1_time"] = torch.tensor(time() - tick)

        tick = time()
        hidden = self._optionally_sparsity_and_record_metadata(hidden, orig_shape, cur_topk, **kwargs)
        self.transient_metadata["sparsity_time"] = torch.tensor(time() - tick)

        tick = time()
        out = self._fc2(hidden)
        self.transient_metadata["fc2_time"] = torch.tensor(time() - tick)
        return out.view(orig_shape)


class NaiveTopK(FFN):
    """
    Regular FFN with a naive top-k applied to the hidden states to enforce
    row-wise sparsity: for every token only the k largest hidden activations
    are kept, the rest are zeroed out.

    record_p (float): cumulative usage mass used for the "top p" / "bottom p"
        memory-usage statistics.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k = None,
        record_p=0.25, # this is for recording how many indices compose top/bottom p usage of memory
        **kwargs
    ):
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, **kwargs
        )
        assert self.k is not None
        self.record_p = record_p

    @staticmethod
    def _entr(count_tensor):
        """Entropy (natural log, via `torch.special.entr`) of the normalized counts."""
        prob_distribution = count_tensor / count_tensor.sum()
        return torch.sum(torch.special.entr(prob_distribution))
    
    @staticmethod
    def _top_p(count_tensor, p, descending=True):
        """
        Return the 0-based position, in the sorted normalized counts, of the
        first entry at which the cumulative mass exceeds `p`.

        If the cumulative mass never exceeds `p` (e.g. `p >= 1`), the number of
        entries is returned instead of raising an IndexError.
        """
        prob_distribution = count_tensor / count_tensor.sum()
        sorted_p, _ = torch.sort(prob_distribution, descending=descending)
        cumulative_probs = torch.cumsum(sorted_p, dim=-1)
        exceeding = (cumulative_probs > p).nonzero()
        if exceeding.numel() == 0:
            return torch.tensor(count_tensor.numel(), device=count_tensor.device)
        return exceeding[0, 0]
    
    @torch.no_grad()
    def _record_mem_usage(self, index_to_keep):
        """
        Record how evenly the `ffn_dim` memory cells are used by the selected
        indices, aggregated over all tokens of the current batch.
        """
        # NOTE(review): this reference value is in bits (log2) while `_entr`
        # uses the natural log, so the two entropies are not directly
        # comparable -- confirm which unit downstream logging expects.
        self.transient_metadata["uniform_mem_usage_entropy"] = torch.log2(torch.tensor(self.ffn_dim))
        # Note: we could check each sentence/token, but that needs a for-loop;
        # the aggregate still tells whether some cell becomes stale (unused at all).
        # `minlength` guarantees one bin per memory cell: without it, trailing
        # cells that were never selected are dropped and the "bottom p"
        # statistic (which counts the unused cells first) is wrong.
        cur_mem_usage = torch.bincount(index_to_keep.reshape(-1), minlength=self.ffn_dim)
        
        self.transient_metadata["cur_mem_usage_entropy"] = self._entr(cur_mem_usage)
        self.transient_metadata["cur_mem_unused"] = self.ffn_dim - torch.count_nonzero(cur_mem_usage).float()
        self.transient_metadata["cur_top_half"] = self._top_p(cur_mem_usage, p=0.5, descending=True) / self.ffn_dim
        self.transient_metadata["cur_top_p"] = self._top_p(cur_mem_usage, p=self.record_p, descending=True) / self.ffn_dim
        self.transient_metadata["cur_bottom_p"] = self._top_p(cur_mem_usage, p=self.record_p, descending=False) / self.ffn_dim

    def _optionally_sparsity_and_record_metadata(self, x, x_shape, k, **kwargs):
        """
        Enforce naive topk by masking: 
            1. x -- ffn's hidden state
            2. find topk of the hidden state 
            3. only keep topk elements in the vector

        Returns a new tensor; `x` itself is left untouched because it may be an
        activation output that autograd saved for the backward pass (ReLU, for
        instance, saves its output), and mutating it in place would make
        backward fail.
        """
        self.transient_metadata["positve_percent"] = self._calculate_mean((x > 0).float().mean(-1), *x_shape[:-1])
        x_detached = x.detach()
        # NOTE: rows without any positive entry give 0 here, so the ratio below is NaN for them
        pos_x_sum = torch.sum(get_pos(x_detached), dim=-1)
        self.transient_metadata["topk_recall"] = None
        self.transient_metadata["gold_positive_weight_recovered_percent"] = None
        # obtain a mask to zero out entries in ffn's hidden states
        if k < self.ffn_dim // 2:
            # if k < 1/2 ffn_dim, we take the largest k and switch them on
            mask = torch.zeros_like(x)
            value_to_keep, index_to_keep = torch.topk(x_detached, k=k, dim=-1, largest=True)
            nonneg_value_to_keep_sum = torch.sum(get_pos(value_to_keep), dim=-1)
            del value_to_keep
            self.transient_metadata["topk_positive_weight_percent"] = self._calculate_mean(nonneg_value_to_keep_sum / pos_x_sum, *x_shape[:-1])
            self._record_mem_usage(index_to_keep)
            mask.scatter_(dim=-1, index=index_to_keep, value=1)
        else:
            # if k >= 1/2 ffn_dim, it is cheaper to find the smallest
            # (ffn_dim - k) entries and switch those off
            mask = torch.ones_like(x)
            value_to_drop, index_to_drop = torch.topk(x_detached, k=self.ffn_dim - k, dim=-1, largest=False)
            nonneg_value_to_drop_sum = torch.sum(get_pos(value_to_drop), dim=-1)
            del value_to_drop
            self.transient_metadata["topk_positive_weight_percent"] = self._calculate_mean(1 - nonneg_value_to_drop_sum / pos_x_sum, *x_shape[:-1])
            mask.scatter_(dim=-1, index=index_to_drop, value=0)
        # out-of-place on purpose (see docstring)
        return x * mask

class NaiveTopKBlock(FFN):
    """
    FFN with block-wise top-k sparsity on the hidden states.

    The `ffn_dim` hidden units are split into contiguous blocks and one of two
    selection modes is used:
      * one_per_block=True: the hidden units are split into `k` blocks and the
        single largest unit of every block is kept.
      * one_per_block=False: blocks of `block_size` units are scored with
        `agg_method` ("avg", "max", "min" or "abs+avg") and the
        `k // block_size` best-scoring blocks are kept as a whole.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k = None,
        block_size = 2,
        agg_method="avg",
        one_per_block=False,
        **kwargs
    ):
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, **kwargs
        )
        assert self.k is not None
        self.one_per_block = one_per_block
        if one_per_block:
            assert self.ffn_dim % k == 0
            self.block_size = self.ffn_dim // k
            self.B = k
            self.offset = torch.arange(self.B) # index of each block (multiplied by block_size at use)
        else:
            self.block_size = block_size
            assert block_size > 1, "Block size needs to be larger than 1, otherwise use NaiveTopK"
            assert k % block_size == 0
            assert self.ffn_dim % block_size == 0
            self.b = k // block_size # number of block need choosing
            self.B = self.ffn_dim // block_size
            self.offset = torch.arange(block_size) # all idxs of each block (all block share this)
            if agg_method == "avg":
                self.agg_method = lambda x: torch.mean(x, dim=-1)
            elif agg_method == "max":
                self.agg_method = lambda x: torch.max(x, dim=-1)[0]
            elif agg_method == "min":
                self.agg_method = lambda x: torch.min(x, dim=-1)[0]
            elif agg_method == "abs+avg":
                self.agg_method = lambda x: torch.mean(torch.abs(x), dim=-1)
            else:
                raise ValueError(f"Aggregation method -- {agg_method} -- not supported")

    def _optionally_sparsity_and_record_metadata(self, x, x_shape, k, **kwargs):
        """
        Enforce block-wise topk by masking: 
            1. x -- ffn's hidden state, viewed as (n_toks, B, block_size)
            2. pick the indices to keep (one per block, or whole top blocks)
            3. only keep those elements in the vector

        The `k` argument is not used here: the number of kept units is fixed by
        the value given at construction time.

        Returns a new tensor; `x` is not modified in place because it may be an
        activation output that autograd saved for the backward pass.
        """
        n_toks = x.shape[0]
        x_in_blocks = x.reshape(n_toks, self.B, -1)
        # `offset` is a plain attribute (not a buffer), so it is moved lazily
        if self.offset.device != x.device:
            self.offset = self.offset.to(x.device)
        mask = torch.zeros_like(x)
        if self.one_per_block:
            _, max_idx_per_block = torch.max(x_in_blocks, dim=-1)
            # block start + position of the winner inside the block
            topk_idx = self.offset * self.block_size + max_idx_per_block
        else:
            aggregated_score = self.agg_method(x_in_blocks)
            _, topk_block_indices = torch.topk(aggregated_score, k=self.b, dim=-1)
            # expand every chosen block into the indices of all of its units
            topk_idx = topk_block_indices.unsqueeze(-1) * self.block_size + self.offset
            topk_idx = topk_idx.flatten(-2,-1)
        mask.scatter_(dim=-1, index=topk_idx, value=1)
        # out-of-place on purpose (see docstring)
        return x * mask
    

class StaticTopK(FFN):
    """
    FFN with a static (input-independent) sparsity mask.

    At construction time every hidden unit is scored by the mean absolute
    value of its fc1 weight row; the `k` best-scoring units are kept for every
    token and all other hidden units are always zeroed out. The mask is stored
    as the buffer `mask` of shape `(1, ffn_dim)`.

    Only `block_size == 1` and `agg_method == "avg"` are supported.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k = None,
        block_size = 1,
        agg_method="avg",
        **kwargs
    ):
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, **kwargs
        )
        assert self.k is not None
        
        self.block_size = block_size
        
        assert block_size == 1
        assert k % block_size == 0
        assert self.ffn_dim % block_size == 0
        self.b = k // block_size # number of block need choosing
        self.B = self.ffn_dim // block_size
        self.offset = torch.arange(block_size) # all idxs of each block (all block share this)
        
        fc1_weight_magnitude = torch.abs(self.fc1.weight.data)
         
        if agg_method == "avg":
            self.agg_method = lambda x: torch.mean(x, dim=-1)
        else:
            raise ValueError(f"Aggregation method -- {agg_method} -- not supported")
        
        # one importance score per hidden unit (row of fc1.weight)
        fc1_weight_importance = self.agg_method(fc1_weight_magnitude)
        
        _, topk_block_indices = torch.topk(fc1_weight_importance, k=self.b, dim=-1)
        # allocate the mask on the same device as the indices: `scatter_`
        # requires both on one device, and fc1 may already live on an accelerator
        mask = torch.zeros(self.ffn_dim, device=topk_block_indices.device)
        mask.scatter_(dim=-1, index=topk_block_indices, value=1)
        mask = mask.unsqueeze(0)
        self.register_buffer('mask', mask)

    def _optionally_sparsity_and_record_metadata(self, x, x_shape, k, **kwargs):
        """
        Apply the precomputed static mask to the ffn's hidden state `x`
        (broadcast over tokens). `x_shape` and `k` are not used.

        Returns a new tensor; `x` is not modified in place because it may be an
        activation output that autograd saved for the backward pass.
        """
        mask = self.mask if self.mask.dtype == x.dtype else self.mask.to(dtype=x.dtype)
        return x * mask


class NaiveTopKLowRank(NaiveTopK):
    """
    NaiveTopK variant whose first projection is factorized through a
    low-rank bottleneck: model_dim -> rank -> ffn_dim.

    The `fc1` passed in is only used to read the layer dimensions; it is
    dropped afterwards and replaced by `low_rank_proj` and `low_rank_key`.

    rank (int): width of the bottleneck.
    """
    def __init__(
        self,
        args,
        fc1: nn.Module,
        activation_fn,
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k=None,
        record_p=0.25,  # usage mass for the top/bottom-p memory statistics
        rank=128,
        **kwargs
    ):
        super().__init__(
            args=args,
            fc1=fc1,
            activation_fn=activation_fn,
            activation_dropout_module=activation_dropout_module,
            fc2=fc2,
            dropout_module=dropout_module,
            ffn_ln=ffn_ln,
            k=k,
            record_p=record_p,
            **kwargs
        )
        assert self.k is not None
        self.record_p = record_p
        # model_dim / ffn_dim were already read from fc1 by the base class
        del self.fc1
        self.rank = rank
        self.low_rank_proj = nn.Linear(self.model_dim, self.rank)
        self.low_rank_key = nn.Linear(self.rank, self.ffn_dim)

    def _fc1(self, x):
        """Low-rank projection, then the key projection + activation + activation dropout."""
        bottleneck = F.linear(x, self.low_rank_proj.weight, self.low_rank_proj.bias)
        return _fc1(
            bottleneck,
            self.low_rank_key,
            self.activation_fn,
            self.activation_dropout_module,
        )


class NaiveANNTopK(NaiveTopK):
    """
    NaiveTopK with simulated approximate-nearest-neighbour (ANN) search.

    The exact top-k of the hidden states is computed first, then
    `num_wrong = floor((1 - acc_level) * k)` randomly chosen members of it are
    swapped for randomly chosen non-top-k entries, emulating a search whose
    accuracy is `acc_level`.

    acc_level (float): fraction of the exact top-k that is retained.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        acc_level: float,
        ffn_ln=None,
        k = None,
        record_p=0.25, # this is for recording how many indices compose top/bottom p usage of memory
        **kwargs
    ):
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, record_p, **kwargs
        )
        self.acc_level = acc_level
        # The small epsilon guards against float round-off: (1 - 0.9) * 10
        # evaluates to 0.9999999999999998, which plain int() would truncate
        # to 0 swaps instead of the intended 1.
        self.num_wrong = int((1 - self.acc_level) * self.k + 1e-9)
        assert self.num_wrong <= (self.ffn_dim - k)
        assert self.num_wrong <= self.k

    def _optionally_sparsity_and_record_metadata(self, x, x_shape, k, **kwargs):
        """
        Enforce noisy topk by masking: 
            1. x -- ffn's hidden state
            2. find the exact topk of the hidden state 
            3. swap `num_wrong` of them for random non-topk entries
            4. only keep the resulting elements in the vector

        Returns a new tensor; `x` is not modified in place because it may be an
        activation output that autograd saved for the backward pass.
        """
        if self.num_wrong == 0:
            return super()._optionally_sparsity_and_record_metadata(x, x_shape, k, **kwargs)
        # make sure we have enough distractor
        assert self.num_wrong <= (self.ffn_dim - k)
        mask = torch.zeros_like(x)
        # Below, we use random noise to simulate ANN for a given accuracy.
        # Everything here only produces indices/statistics, so it runs without
        # building an autograd graph (this includes the sort).
        with torch.no_grad():
            sorted_vals, indices = torch.sort(x, dim=-1, descending=True)
            self.transient_metadata["positve_percent"] = self._calculate_mean((x > 0).float().mean(-1), *x_shape[:-1])
            self.transient_metadata["topk_recall"] = None
            self.transient_metadata["gold_positive_weight_recovered_percent"] = None
            # divide hidden states into the topk and the non-topk section
            gold_topk_vals = sorted_vals[:, :k]
            gold_topk_indices = indices[:, :k]
            distractor_vals = sorted_vals[:, k:]
            distractor_indices = indices[:, k:]
            # adding noise by randomly swapping out topk with nontopk:
            # draw a random permutation per token for each section and keep its
            # first `num_wrong` positions
            rand_topk_ids_pos = torch.rand(gold_topk_indices.shape, device=x.device).argsort(dim=-1)[:, :self.num_wrong]
            rand_nontopk_ids_pos = torch.rand(distractor_indices.shape, device=x.device).argsort(dim=-1)[:, :self.num_wrong]
            # get random values and indices from non-topk sections
            rand_nontopk_vals = torch.gather(distractor_vals, dim=-1, index=rand_nontopk_ids_pos)
            rand_nontopk_ids = torch.gather(distractor_indices, dim=-1, index=rand_nontopk_ids_pos)

            # do the random mapping to swap out topk indices/values and replace with non-topk ones.
            index_to_keep = torch.scatter(gold_topk_indices, dim=-1, index=rand_topk_ids_pos, src=rand_nontopk_ids)
            value_to_keep = torch.scatter(gold_topk_vals, dim=-1, index=rand_topk_ids_pos, src=rand_nontopk_vals)
            # relative amount of top-k activation mass lost through the swaps
            ann_lost_value = gold_topk_vals.sum(-1) - value_to_keep.sum(-1)
            ann_lost_value = ann_lost_value / gold_topk_vals.sum(-1)
            self.transient_metadata["ann_lost_value_avg"] = ann_lost_value.mean()
            self.transient_metadata["ann_lost_value_std"] = ann_lost_value.std()

        mask.scatter_(dim=-1, index=index_to_keep, value=1)
        # out-of-place on purpose (see docstring)
        return x * mask


class RandomTopK(FFN):
    """
    When queried, choose K elements from the table using a randomly
    initialized scorer instead of the real hidden states.

    Implementation-wise, a randomly initialized `random_fc1` projects the fc1
    input; the top-k positions of that projection decide which entries of the
    real hidden state are kept.

    pad_idx: stored for interface parity with the other wrappers; not used by
        the masking path of this class.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k = None,
        pad_idx=-1,
        **kwargs
    ):
        # `pad_idx` must not be forwarded positionally: FFN.__init__ accepts no
        # positional argument after `k`, so doing so raises a TypeError.
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, **kwargs
        )
        self.pad_idx = pad_idx
        self.random_fc1 = nn.Linear(self.model_dim, self.ffn_dim)

    def forward(self, x: torch.Tensor, k: Optional[int] = None, **kwargs):
        """
        Route the module call through `forward_by_masking`. The inherited
        `FFN.forward` cannot be used because the sparsity hook of this class
        additionally needs the fc1 input.
        """
        return self.forward_by_masking(x, k, **kwargs)
    
    def forward_by_masking(self, x: torch.Tensor, k: Optional[int] = None, **kwargs):
        """
        x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
        k (int): allow user to override k on the fly; falls back to `self.k`.
            Might be useful for try different k during eval

        Raises ValueError when neither `k` nor `self.k` is set.
        """
        cur_topk = k or self.k
        if cur_topk is None:
            raise ValueError("k must be given either at construction time or to forward")
        self.transient_metadata = {}  # reset to store metadata of current pass
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        ffn_hidden = self._fc1(x)
        ffn_hidden = self._ffn_ln(ffn_hidden)
        x = self._optionally_sparsity_and_record_metadata(ffn_hidden, x, x_shape, cur_topk, **kwargs)
        x = self._fc2(x)
        x = x.view(x_shape)
        return x

    def _optionally_sparsity_and_record_metadata(self, ffn_hidden, fc1_x, fc1_x_shape, k, **kwargs):
        """
        Enforce random topk by masking: 
            1. project-up with `self.random_fc1`
            2. find topk of the vector from `self.random_fc1`
            3. Use topk from previous step, only keep topk elements in vector from ffn_hidden

        ffn_hidden: real hidden state, flattened to `(seq_len * batch, ffn_dim)`
        fc1_x: flattened fc1 input of shape `(seq_len * batch, embed_dim)`
        fc1_x_shape: original `(seq_len, batch, embed_dim)` shape

        Returns a new tensor; `ffn_hidden` is not modified in place because it
        may be an activation output that autograd saved for the backward pass.
        """
        # the random scorer only provides indices, so no graph is needed
        with torch.no_grad():
            random_hidden = _fc1(fc1_x, self.random_fc1, self.activation_fn, self.activation_dropout_module)

        hidden_detached = ffn_hidden.detach()
        self.transient_metadata["positve_percent"] = self._calculate_mean((ffn_hidden > 0).float().mean(-1), *fc1_x_shape[:-1])
        # kept 1-D (one entry per token) so that it lines up with the per-token
        # sums below and with what `_calculate_mean` expects
        pos_x_sum = torch.sum(get_pos(hidden_detached), dim=-1)
        self.transient_metadata["topk_recall"] = None
        self.transient_metadata["gold_positive_weight_recovered_percent"] = None
        # obtain a mask to zero out entries in ffn's hidden states USING random fc1
        if k < self.ffn_dim // 2:
            # if k < 1/2 ffn_dim, we take the largest k and switch them on
            mask = torch.zeros_like(ffn_hidden)
            # find topk from random hidden state
            _, random_index_to_keep = torch.topk(random_hidden, k=k, dim=-1, largest=True)
            # values of the real hidden state at the "random topk" positions
            value_to_keep_from_random = torch.gather(hidden_detached, dim=-1, index=random_index_to_keep)
            nonneg_value_to_keep_sum = torch.sum(get_pos(value_to_keep_from_random), dim=-1)
            del value_to_keep_from_random
            self.transient_metadata["topk_positive_weight_percent"] = self._calculate_mean(
                nonneg_value_to_keep_sum / pos_x_sum,
                *fc1_x_shape[:-1]
            )
            mask.scatter_(dim=-1, index=random_index_to_keep, value=1)
        else:
            # if k >= 1/2 ffn_dim, we take the smallest (ffn_dim - k) and switch them off
            mask = torch.ones_like(ffn_hidden)
            _, random_index_to_drop = torch.topk(random_hidden, k=self.ffn_dim - k, dim=-1, largest=False)
            # values of the real hidden state at the positions being dropped
            value_to_drop_from_random = torch.gather(hidden_detached, dim=-1, index=random_index_to_drop)
            nonneg_value_to_drop_sum = torch.sum(get_pos(value_to_drop_from_random), dim=-1)
            del value_to_drop_from_random
            self.transient_metadata["topk_positive_weight_percent"] = self._calculate_mean(
                1 - nonneg_value_to_drop_sum / pos_x_sum,
                *fc1_x_shape[:-1]
            )
            mask.scatter_(dim=-1, index=random_index_to_drop, value=0)
        # out-of-place on purpose (see docstring)
        return ffn_hidden * mask


class XFFN(FFN):
    """
    This is the subclass of FFN whose sparsity requires access to fc1 input. 
    User needs to inherit this class and implement `topk_search` and `build_index` methods 
    to actually use it.
    * no bias term bc the key-value view
    This class is essentially a wrapper around fc1 and fc2 to enforce row-wise sparsity.

    Given the search results from `topk_search`, there are two implementations that enforce the row sparsity:
        1. forward_by_masking:
            x = fc1(x)
            x = _optionally_sparsity_and_record_metadata(x)
            x = fc2(x)
            
            `_optionally_sparsity_and_record_metadata` is where the sparsity happens; I use results from 
            `topk_search` to create a big mask same size as `fc1(x)` then apply by `mask * fc1(x)`.
            
        2. forward_by_fully_indexing
            Use `topk_search` results and `torch.index_select` to select row from fc1 and fc2 to calculate
            y = sum_i non-linear(<x, k_i>) * v_i
    
    pad_idx: some search methods will use padding for no results (default to -1)
    record_p: this is for recording how many indices compose top/bottom p usage of memory
    """
    def __init__(
        self,
        args,
        fc1: nn.Module,
        activation_fn,
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k=None,
        record_p=0.25,
        pad_idx=-1,
        **kwargs
    ):
        """Set up the sparse FFN wrapper.

        All FFN-related arguments are forwarded unchanged to the parent class.

        record_p: probability mass threshold used when recording how many
            memory cells make up the top/bottom share of memory usage.
        pad_idx: value a searcher uses to pad missing results; only -1 is
            accepted.
        """
        super().__init__(
            args,
            fc1,
            activation_fn,
            activation_dropout_module,
            fc2,
            dropout_module,
            ffn_ln,
            k,
            **kwargs
        )
        self.record_p = record_p
        # a concrete k is mandatory for every sparse variant
        assert self.k is not None
        # The bias is kept (shared with fc1): replacing it with zeros was
        # reported to trigger a CUDA error, and fc2's bias is left untouched.
        self.fc1_bias = self.fc1.bias
        self.pad_idx = pad_idx
        assert self.pad_idx == -1
        # update counter, refreshed through `set_num_updates`
        self._num_updates = 0
    
    def set_num_updates(self, num_updates: int):
        """Remember the current number of training updates.

        Intended to be invoked by the trainer before each forward pass; the
        stored value drives the periodic index rebuild.
        """
        self._num_updates = num_updates
    
    @torch.no_grad()
    def _record_topk_quality(self, ffn_hidden: torch.Tensor, topk_index: torch.Tensor, tokens_per_sample: int, num_samples: int) -> None:
        """
        Compare the searcher's indices against the exact top-k of `ffn_hidden`
        and store two metrics in `self.transient_metadata`:

        * "topk_recall": fraction of the exact top-k indices that the searcher found
        * "gold_positive_weight_recovered_percent": positive activation mass at the
          recovered indices divided by the positive mass of the exact top-k

        ffn_hidden: 2-D scores, one row per query (rows are indexed by `i` below)
        topk_index: 2-D indices returned by the searcher, one row per query
        tokens_per_sample, num_samples: forwarded untouched to `self._calculate_mean`
            (defined outside this class -- semantics to be confirmed there)

        Note: the per-row Python loops with CPU copies make this expensive.
        """
        _, k = topk_index.shape
        ffn_hidden = ffn_hidden.detach()
        bsz = ffn_hidden.shape[0]
        # get exact topk results
        gold_values, gold_topk = torch.topk(ffn_hidden, k=k, dim=-1, largest=True)
        # leave only positive entries in the ffn_hidden
        pos_ffn_hidden = get_pos(ffn_hidden)
        # recall: per-row set intersection between exact and retrieved indices (done on CPU)
        overlaps = [np.intersect1d(gold_topk[i].cpu(), topk_index[i].cpu()) for i in range(bsz)]
        self.transient_metadata["topk_recall"] = self._calculate_mean(torch.tensor([len(o) / k for o in overlaps]), tokens_per_sample, num_samples)
        # how much gold weight recovered by topk
        positive_weight_recovered = torch.tensor([pos_ffn_hidden[i, overlaps[i]].sum() for i in range(bsz)], device=gold_values.device)
        # NOTE(review): a row whose exact top-k has no positive entry divides by zero here
        positive_weight_recovered_percent = positive_weight_recovered / torch.sum(get_pos(gold_values), dim=-1)
        self.transient_metadata["gold_positive_weight_recovered_percent"] = self._calculate_mean(
                                                                                positive_weight_recovered_percent, 
                                                                                tokens_per_sample, num_samples
                                                                            )
    
    @staticmethod
    def _entr(count_tensor):
        prob_distribution = count_tensor / count_tensor.sum()
        return torch.sum(torch.special.entr(prob_distribution))
    
    @staticmethod
    def _top_p(count_tensor, p, descending=True):
        prob_distribution = count_tensor / count_tensor.sum()
        sorted_p, _ = torch.sort(prob_distribution, descending=descending)
        cumulative_probs = torch.cumsum(sorted_p, dim=-1)
        sorted_indices_to_remove = cumulative_probs > p
        return sorted_indices_to_remove.nonzero()[0,0]

    @torch.no_grad()
    def _record_mem_usage(self, index_to_keep):
        # assert len(self.mem_usage_counter) == self.ffn_dim
        # assert (self.mem_usage_counter >= 0).all()
        # self.mem_usage_counter = self.mem_usage_counter.to(index_to_keep.device)
        self.transient_metadata["uniform_mem_usage_entropy"] = torch.log2(torch.tensor(self.ffn_dim))
        # Note: we could have check each sentence/token; but that needs for-loop;
        # aggregate still tells whether some cell becomes stale(unused at all)
        cur_mem_usage = torch.bincount(index_to_keep.view(-1))
        
        self.transient_metadata["cur_mem_usage_entropy"] = self._entr(cur_mem_usage)
        self.transient_metadata["cur_mem_unused"] = self.ffn_dim - torch.count_nonzero(cur_mem_usage).float()
        self.transient_metadata["cur_top_half"] = self._top_p(cur_mem_usage, p=0.5, descending=True) / self.ffn_dim
        self.transient_metadata["cur_top_p"] = self._top_p(cur_mem_usage, p=self.record_p, descending=True) / self.ffn_dim
        self.transient_metadata["cur_bottom_p"] = self._top_p(cur_mem_usage, p=self.record_p, descending=False) / self.ffn_dim

    def forward_by_fully_indexing(self, x: torch.Tensor, k: Optional[int] = None, **kwargs):
        """
        Sparse forward pass that gathers only the selected rows:

            y = sum_i non-linear(<x, k_i>) * v_i

        where k_i are rows of `fc1.weight` and v_i are columns of `fc2.weight`
        picked by `topk_search`.

        x: input whose last dimension is the model dimension; all leading
            dimensions are flattened for the search and restored on return.
        k: number of rows to select, forwarded to `topk_search`.

        Returns a tensor with the same shape as `x`.

        Only the weight matrices of fc1/fc2 are used on this path: no bias is
        added and `ffn_ln` is not applied, so the result is not expected to
        match `forward_by_masking` exactly.
        """
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        t0 = time()
        topk_index = self.topk_search(x, k=k, **kwargs)
        self.transient_metadata["topk_search_time"] = torch.tensor(time() - t0)

        # optionally record the quality of topk searcher
        self.transient_metadata["topk_recall"] = None
        self.transient_metadata["gold_positive_weight_recovered_percent"] = None
        if self.args.record_topk_quality:
            # this is expensive to run: it computes the dense fc1 output
            t0 = time()
            with torch.no_grad():
                ffn_hidden = self.fc1(x)
                self._record_topk_quality(ffn_hidden, topk_index, *x_shape[:-1])
                del ffn_hidden
            self.transient_metadata["gold_topk_record_time"] = torch.tensor(time() - t0)

        # Padded slots get a valid (wrapped-around) row index so that
        # index_select works; their contribution is zeroed out further below.
        # Done out of place so the searcher's result is not modified.
        padding_mask = topk_index == self.pad_idx
        topk_index = topk_index.masked_fill(padding_mask, self.pad_idx % self.ffn_dim)
        selected_kv_shape = topk_index.shape + (-1,)
        index_flatten = topk_index.reshape(-1)

        # fc1 --- gather the selected keys k_i
        # Note: index_select COPIES the rows, which can lead to OOM where the
        # masking implementation would not.
        coeffs = torch.index_select(self.fc1.weight, 0, index_flatten).view(selected_kv_shape)
        # <x, k_i>
        coeffs = torch.einsum('bkd,bd->bk', coeffs, x)
        # non-linear(<x, k_i>)
        coeffs = self.activation_fn(coeffs)
        coeffs = self.activation_dropout_module(coeffs)
        # remove entries that were only padding (out of place: the activation
        # output may be needed unchanged by autograd)
        coeffs = coeffs.masked_fill(padding_mask, 0.0)
        del padding_mask

        # share of memory cells with a positive activation, measured on the
        # activations themselves (not on the layer input)
        pos_count = (coeffs > 0).float().sum(-1)
        self.transient_metadata["positve_percent"] = self._calculate_mean(pos_count / self.ffn_dim, *x_shape[:-1])
        # because all other non-topk value are 0, so this value is always 1
        self.transient_metadata["topk_positive_weight_percent"] = torch.tensor(1., device=x.device)

        # fc2 --- gather the selected values v_i
        x = torch.index_select(self.fc2.weight.T, 0, index_flatten).view(selected_kv_shape)
        # weighted sum of value vectors
        x = torch.einsum('bk,bkd->bd', coeffs, x)
        x = x.view(x_shape)
        return x

    def _optionally_sparsity_and_record_metadata(self, x, fc1_x, fc1_x_shape, k, **kwargs):
        """
        Zero out (in place) every entry of `x` that the searcher did not select.

        x: fc1 output, one row per query; modified in place and returned
        fc1_x: the flattened fc1 input, used as the search query
        fc1_x_shape: original (unflattened) shape of the fc1 input
        k: number of entries to keep per row
        """
        meta = self.transient_metadata
        n_rows = x.shape[0]

        # query the searcher and time it
        search_start = time()
        topk_index = self.topk_search(fc1_x, k=k, **kwargs)
        meta["topk_search_time"] = torch.tensor(time() - search_start)
        assert topk_index.shape == (n_rows, k)

        # searcher-quality metrics are only filled in when requested
        meta["topk_recall"] = None
        meta["gold_positive_weight_recovered_percent"] = None
        record_start = time()
        if self.args.record_topk_quality:
            self._record_topk_quality(x, topk_index, *fc1_x_shape[:-1])
        meta["gold_topk_record_time"] = torch.tensor(time() - record_start)

        # build the keep-mask from the indices (padding entries are handled by
        # the helper) and apply it in place
        keep_mask = mask_from_topk_index(
            (n_rows, self.ffn_dim),
            topk_index,
            device=fc1_x.device,
            pad_idx=self.pad_idx,
        )
        x *= keep_mask
        del keep_mask

        # this statistic is not computed on the masking path
        meta["positve_percent"] = None
        meta["topk_positive_weight_percent"] = torch.tensor(1., device=x.device)
        return x
    
    def _fc1(self, x):
        """
        Apply the first projection, the activation and the activation dropout.

        x: reshaped input to the layer of shape `(seq_len * batch, embed_dim)`

        Fused bias+activation kernels are used for gelu (when available) and
        for relu_squared; any other activation takes the plain linear path.
        """
        use_fused_gelu = has_fused_bias_gelu and self.activation_fn == gelu
        if use_fused_gelu:
            hidden = fused_bias_gelu(_linear(x, self.fc1.weight), self.fc1_bias)
        elif self.activation_fn == relu_squared:
            hidden = fused_bias_relu_squared(_linear(x, self.fc1.weight), self.fc1_bias)
        else:
            hidden = self.activation_fn(_linear(x, self.fc1.weight, self.fc1_bias))
        return self.activation_dropout_module(hidden)
        
    def forward_by_masking(self, x: torch.Tensor, k: Optional[int] = None, **kwargs):
        """
        Dense fc1, then masking of the hidden state, then fc2.

        x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
        k (int): number of hidden entries to keep; lets the caller override k
            on the fly, e.g. to try different values during eval

        The wall-clock time of each stage is stored in `self.transient_metadata`
        under "fc1_time", "sparsity_time" and "fc2_time".
        """
        meta = self.transient_metadata
        input_shape = x.shape

        # stage 1: first matrix multiplication with non-linearity (+ optional layer norm)
        stage_start = time()
        flat_input = x.reshape(-1, x.size(-1))
        hidden = self._ffn_ln(self._fc1(flat_input))
        meta["fc1_time"] = torch.tensor(time() - stage_start)

        # stage 2: mask the hidden state (after the non-linearity)
        stage_start = time()
        hidden = self._optionally_sparsity_and_record_metadata(
            hidden, flat_input, input_shape, k, **kwargs
        )
        meta["sparsity_time"] = torch.tensor(time() - stage_start)

        # stage 3: second matrix multiplication
        stage_start = time()
        output = self._fc2(hidden)
        meta["fc2_time"] = torch.tensor(time() - stage_start)

        return output.view(input_shape)

    def forward(self, x, k: Optional[int] = None, **kwargs):
        """
        Dispatch to the implementation chosen by `args.memory_sparsity_impl`.

        "mask": project up, then mask out everything but the indices found by
            the top-k searcher.
        "index": avoid the dense projection and compute with gathered rows
            only; indexing support is poor, so this may not be faster.

        The metadata dict is reset on every call. The two implementations are
        not expected to produce identical outputs.
        """
        effective_k = k or self.k
        self.transient_metadata = {}  # holds metadata of the current pass only
        impl = self.args.memory_sparsity_impl
        if impl == "mask":
            return self.forward_by_masking(x, k=effective_k, **kwargs)
        if impl == "index":
            return self.forward_by_fully_indexing(x, k=effective_k, **kwargs)
        raise ValueError(f"`{self.args.memory_sparsity_impl}` is not a supported forward function")
    
    def maybe_rebuild_index(self):
        """Rebuild the search index once every `args.index_rebuild_period` updates."""
        due = (self._num_updates + 1) % self.args.index_rebuild_period == 0
        if not due:
            return
        logger.info("(Re)Building index")
        self.build_index()
        logger.info("Finish (re)building index")

    def build_index(self):
        """Build the search index; must be provided by every subclass."""
        message = f"`{self}` needs `build_index` method to be implemented"
        raise NotImplementedError(message)

    def topk_search(self, queries, k=None, **kwargs) -> torch.Tensor:
        """Return the selected indices for each query; must be provided by every subclass."""
        message = f"`{self}` needs `topk_search` method to be implemented"
        raise NotImplementedError(message)
    

class PKMTopkAugmented(XFFN):
    def __init__(
        self,
        args,
        fc1: nn.Module,
        activation_fn,
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k=None,
        record_p=0.25,
        pad_idx=-1,
        k_dim=128,
        query_batchnorm=True,
        **kwargs
    ):
        """Set up a product-key scorer on top of the sparse FFN wrapper.

        record_p: probability mass threshold used when recording how many
            memory cells make up the top/bottom share of memory usage.
        k_dim: dimension the queries are projected down to; each half is
            scored against its own set of sub-keys.
        query_batchnorm: whether a BatchNorm1d follows the query projection.

        `ffn_dim` must be a perfect square: its root is the number of sub-keys
        per half.
        """
        super().__init__(
            args,
            fc1,
            activation_fn,
            activation_dropout_module,
            fc2,
            dropout_module,
            ffn_ln,
            k,
            record_p,
            pad_idx,
            **kwargs
        )
        self.k_dim = k_dim
        # the root must be integral, i.e. ffn_dim is a perfect square
        root = np.sqrt(self.ffn_dim)
        assert np.ceil(root) == np.floor(root)
        self.n_subkeys = int(root)

        self.query_batchnorm = query_batchnorm
        self.build_index()

    @staticmethod
    def get_uniform_keys(n_keys, dim, seed):
        """
        Generate random uniform keys (same initialization as nn.Linear).
        """
        rng = np.random.RandomState(seed)
        bound = 1 / np.sqrt(dim)
        keys = rng.uniform(-bound, bound, (n_keys, dim))
        return keys.astype(np.float32)    
    
    def build_index(self):
        """Create the query projection and the two factorised sub-key tables."""
        # down-projection of the query, optionally followed by batch norm
        proj_layers = [nn.Linear(self.model_dim, self.k_dim, bias=True)]
        if self.query_batchnorm:
            proj_layers.append(nn.BatchNorm1d(self.k_dim))
        self.query_proj = nn.Sequential(*proj_layers)

        # one sub-key table per half of the projected query
        sub_dim = self.k_dim // 2
        self.product_keys = nn.ModuleList(
            [nn.Linear(sub_dim, self.n_subkeys) for _ in range(2)]
        )
    
    def maybe_rebuild_index(self):
        """No-op: this class never rebuilds its index periodically."""
        return None

    def topk_search(self, queries, k=None, return_index=False, **kwargs) -> torch.Tensor:
        """
        Score every memory cell with the product-key scorer.

        queries: 2-D tensor, one row per query
        k: number of entries to return; only used when `return_index` is true
        return_index: when true, return `(values, indices)` of the `k` largest
            scores per row instead of the full score matrix

        The score of cell `(i, j)` is the product of the i-th sub-key score of
        the first query half and the j-th sub-key score of the second half.
        """
        n_rows = np.prod(queries.shape[:-1])
        split_at = self.k_dim // 2
        projected = self.query_proj(queries)

        # score each half of the projected query against its own sub-keys
        first_keys, second_keys = self.product_keys[0], self.product_keys[1]
        scores_a = _linear(projected[:, :split_at], first_keys.weight, first_keys.bias)   # (bs, n_subkeys)
        scores_b = _linear(projected[:, split_at:], second_keys.weight, second_keys.bias)  # (bs, n_subkeys)

        # outer product of the two score vectors, flattened to (bs, n_subkeys ** 2)
        pairwise = scores_a.view(n_rows, self.n_subkeys, 1) * scores_b.view(n_rows, 1, self.n_subkeys)
        scores = pairwise.view(n_rows, -1)

        if not return_index:
            return scores
        best_values, best_indices = torch.topk(scores, k=k, dim=-1, largest=True)
        return best_values, best_indices
    
    def _fc1(self, x):
        """
        Apply only the affine part of fc1; the activation and its dropout are
        deferred until after the product-key gating.

        x: reshaped input to the layer of shape `(seq_len * batch, embed_dim)`
        """
        return _linear(x, self.fc1.weight, self.fc1_bias)
    
    def _optionally_sparsity_and_record_metadata(self, x, fc1_x, fc1_x_shape, k, **kwargs):
        """
        Gate the pre-activation hidden state with the product-key scores, keep
        only the `k` highest-scoring cells per row, then apply the activation
        and its dropout.

        x: pre-activation fc1 output, one row per query; modified in place
        fc1_x: the flattened fc1 input, used as the search query
        fc1_x_shape: original (unflattened) shape of the fc1 input
        k: number of cells to keep per row; must be below `ffn_dim // 2`
        """
        bsz = x.shape[0]
        # use topk searcher to get indices
        t0 = time()
        # without `return_index` the searcher returns the full score matrix
        pk_hidden = self.topk_search(fc1_x, k=k, **kwargs)
        assert pk_hidden.shape == x.shape
        self.transient_metadata["topk_search_time"] = torch.tensor(time() - t0)
        
        assert k < self.ffn_dim // 2
        # create a mask using product key scores
        # have to use masking, in-place assignment will cause cuda error
        # if k < 1/2 ffn_dim, we take the largest k
        # t0 = time()
        mask = torch.zeros_like(x)
        # the cells to keep are chosen by product-key score, not by the value of x
        _, index_to_keep = torch.topk(pk_hidden, k=k, dim=-1, largest=True)
        self._record_mem_usage(index_to_keep)
        mask.scatter_(dim=-1, index=index_to_keep, value=1)
        # self.transient_metadata["get_mask_time"] = torch.tensor(time() - t0)
        
        # t0 = time()
        # gate the hidden state with the product-key scores (in place)
        x *= pk_hidden
        # record stats after applying pk_hidden
        # t1 = time()
        pos_x_sum = torch.sum(get_pos(x.detach()), dim=-1)
        # NOTE(review): this takes the k largest gated values, which are not
        # necessarily the cells in `index_to_keep` -- confirm the intended metric
        value_to_keep, _ = torch.topk(x, k=k, dim=-1, largest=True)
        nonneg_value_to_keep_sum = torch.sum(get_pos(value_to_keep.detach()), dim=-1)
        self.transient_metadata["positve_percent"] = self._calculate_mean((x > 0).float().mean(-1), *fc1_x_shape[:-1])
        self.transient_metadata["topk_positive_weight_percent"] = self._calculate_mean(nonneg_value_to_keep_sum / pos_x_sum, *fc1_x_shape[:-1])
        # self.transient_metadata["record_pos_time"] = torch.tensor(time() - t1)
        # t0 = time()
        # drop everything outside the selected cells (in place)
        x *= mask
        # self.transient_metadata["apply_mask_time"] = torch.tensor(time() - t0)
        # distributed_set_trace()
        del mask
        # save the non-linear activation to the end
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        return x
    
    def forward_by_fully_indexing(self, x: torch.Tensor, k: Optional[int] = None, **kwargs):
        """
        Sparse forward pass that gathers only the selected rows:

            y = sum_i non-linear(<x, k_i> * s_i) * v_i

        where k_i are rows of `fc1.weight`, v_i are columns of `fc2.weight` and
        s_i is the product-key score of the selected cell.

        x: input whose last dimension is the model dimension; leading
            dimensions are flattened for the search and restored on return.
        k: number of cells to select, forwarded to `topk_search`.
        """
        meta = self.transient_metadata
        input_shape = x.shape
        flat_x = x.reshape(-1, x.size(-1))

        search_start = time()
        topk_value, topk_index = self.topk_search(flat_x, k=k, return_index=True, **kwargs)
        meta["topk_search_time"] = torch.tensor(time() - search_start)

        # searcher-quality metrics are only filled in when requested
        meta["topk_recall"] = None
        meta["gold_positive_weight_recovered_percent"] = None
        if self.args.record_topk_quality:
            # expensive: needs the dense fc1 output to obtain the exact top-k
            record_start = time()
            with torch.no_grad():
                dense_hidden = self.fc1(flat_x)
                self._record_topk_quality(dense_hidden, topk_index, *input_shape[:-1])
                del dense_hidden
            meta["gold_topk_record_time"] = torch.tensor(time() - record_start)

        gathered_shape = topk_index.shape + (-1,)
        flat_index = topk_index.view(-1)

        # fc1: gather the selected keys (index_select copies them, which may
        # cause OOM where the masking implementation would not)
        selected_keys = torch.index_select(self.fc1.weight, 0, flat_index).view(gathered_shape)
        # <x, k_i>, gated by the product-key score
        coeffs = torch.einsum('bkd,bd->bk', selected_keys, flat_x) * topk_value
        coeffs = self.activation_dropout_module(self.activation_fn(coeffs))

        positive_count = (coeffs > 0).float().sum(-1)
        meta["positve_percent"] = self._calculate_mean(positive_count / self.ffn_dim, *input_shape[:-1])
        # every non-selected value is 0, so this ratio is always 1
        meta["topk_positive_weight_percent"] = torch.tensor(1., device=flat_x.device)

        # fc2: gather the selected values; layer norm (`ffn_ln`) cannot be
        # applied on this path and is skipped
        selected_values = torch.index_select(self.fc2.weight.T, 0, flat_index).view(gathered_shape)
        # weighted sum of the value vectors
        output = torch.einsum('bk,bkd->bd', coeffs, selected_values)
        return output.view(input_shape)

class ReformerLSH(XFFN):
    """
    Topk search with Reformer's LSH technique.

    num_buckets (`int` or `List[int]`, *optional*):
        Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme.
        Each query key vector is hashed into a hash in `1, ..., num_buckets`. The number of buckets can also be
        factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a
        hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is
        factorized into two factors. The number of buckets (or the product the factors) should approximately equal
        sequence length / lsh_chunk_length. If `num_buckets` not set, a good value is calculated on the fly.
    num_hashes (`int`, *optional*, defaults to 1):
        Number of hashing rounds (e.g., number of random rotations) in Local Sensitive Hashing scheme. The higher
        `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive
        the hashing becomes.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        num_buckets, 
        num_hashes,
        query_mini_batch=10,
        ffn_ln=None,
        k = None,
        pad_idx=-1,
        **kwargs
    ):
        """Set up the LSH-based searcher and hash the current fc1 rows.

        num_buckets: number of hash buckets (an even int, or a list of even
            factors for the factorised scheme).
        num_hashes: number of hashing rounds.
        query_mini_batch: number of queries hashed and compared at a time in
            `topk_search`.
        pad_idx: value used to pad missing search results.
        """
        # `pad_idx` has to be passed by keyword: the parent's next positional
        # parameter after `k` is `record_p`, not `pad_idx`.
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, pad_idx=pad_idx, **kwargs
        )
        assert self.k is not None
        self.num_buckets = num_buckets
        self.num_hashes = num_hashes
        self.query_mini_batch = query_mini_batch
        self.build_index()

    @staticmethod
    def _hash_vectors(vectors: torch.Tensor, random_rotations: torch.Tensor, self_num_buckets: Union[int, List]):
        # normalize the vector so that each vector falls on a unit sphere
        vectors = F.normalize(vectors.detach(), dim=-1)
        # do random projections for multiple hashes
        # rotated_vectors: (bsz, num_hash, num_bucket//2)
        # disk_io_counters()
        rotated_vectors = F.normalize(torch.einsum("bd,dhr->bhr", vectors, random_rotations), dim=-1)
        
        if isinstance(self_num_buckets, int) or len(self_num_buckets) == 1:
            rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1)
            buckets = torch.argmax(rotated_vectors, dim=-1)
        else:
            # Get the buckets for them and combine.
            buckets, cur_sum, cur_product = None, 0, 1
            for bucket_factor in self_num_buckets:
                rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)]
                cur_sum = cur_sum + bucket_factor // 2
                rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1)
                if buckets is None:
                    buckets = torch.argmax(rotated_vectors_factor, dim=-1)
                else:
                    buckets = buckets + (cur_product * torch.argmax(rotated_vectors_factor, dim=-1))
        return buckets

    @staticmethod
    def _topk_from_bucket_count(bucket_match_count, k):
        """
        Given a bucket_match_count (bsz, n_keys) matrix of hash collisions, find the topk
        """
        topk_match_count, topk_index = torch.topk(bucket_match_count, k=k, dim=-1, largest=True) 
        topk_index[topk_match_count == 0] = -1  # if not enough results, pad with -1
        return topk_index
    
    def _generate_random_rotation(self):
        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        if isinstance(self.num_buckets, int):
            assert (
                self.num_buckets % 2 == 0
            ), f"There should be an even number of buckets, but `self.num_buckets`: {self.num_buckets}"
            rotation_size = self.num_buckets
            num_buckets = self.num_buckets
        else:
            # Factorize the hash if self.num_buckets is a list or tuple
            rotation_size, num_buckets = 0, 1
            for bucket_factor in self.num_buckets:
                assert (
                    bucket_factor % 2 == 0
                ), f"The number of buckets should be even, but `num_bucket`: {bucket_factor}"
                rotation_size = rotation_size + bucket_factor
                num_buckets = num_buckets * bucket_factor
        # infer the shape for random rotation matrix, it has multiple hash(self.n_hash)
        rotations_shape = (self.model_dim, self.num_hashes, num_buckets // 2)
        # generate random rotation matrix, each elements is from N(0,1)
        return torch.randn(rotations_shape)
    
    def build_index(self):
        # create a random projection
        R = self._generate_random_rotation() 
        # (ffn_dim, num_hashes)
        key_buckets = self._hash_vectors(self.fc1.weight.detach(), R, self.num_buckets)
        self.register_buffer('R', R)
        self.register_buffer('key_buckets', key_buckets)


    def topk_search(self, queries, k=None, **kwargs):
        """
        Return topk indieces and if not found, pad with -1
        """
        # queries = queries.float()
        # self.key_buckets = self.key_buckets.to(queries.device)
        # self.R = self.R.to(queries.device)
        # distributed_set_trace()
        self.maybe_rebuild_index()
        cur_k = k or self.k
        assert queries.shape[-1] == self.model_dim
        # hash the query
        # x_buckets: (bsz, num_hash)
        tmp = []
        for i in range(0, queries.shape[0], self.query_mini_batch):
            # sum across num_hash, to record hash_collision
            x_buckets = self._hash_vectors(queries[i:i+self.query_mini_batch], self.R, self.num_buckets)
            bucket_match_count_i = torch.sum(x_buckets.unsqueeze(1) == self.key_buckets.unsqueeze(0), dim=-1)
            tmp.append(bucket_match_count_i)
        # (bsz, n_keys): this is the number of hash collision between queries and keys
        bucket_match_count = torch.cat(tmp).to(queries.device)
        return self._topk_from_bucket_count(bucket_match_count, cur_k)


class ScannTopK(XFFN):
    """
    Approximate top-k search over the rows of fc1 using a ScaNN index.
    Note: This is still too slow to practically use.

    This is a wrapper for using the ScaNN library
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        n_leaves=None,
        dims_per_block=2,
        avq_threshold=0.2,  # best in the paper
        leaves_to_search=100,
        reorder=None,
        ffn_ln=None,
        k = None,
        pad_idx=-1,
        **kwargs
    ):
        """Set up the ScaNN-backed searcher and build its index over fc1's rows.

        n_leaves: number of tree leaves; raised to sqrt(#rows of fc1) when
            missing or smaller than that.
        dims_per_block: dimensions per block passed to the `score_ah` stage.
        avq_threshold: anisotropic quantization threshold.
        leaves_to_search: number of leaves searched per query; should be tuned
            based on the recall target.
        reorder: number of candidates to re-rank; raised to `2 * k` when
            missing or smaller than `k`.
        pad_idx: value used to pad missing search results.
        """
        # `pad_idx` has to be passed by keyword: the parent's next positional
        # parameter after `k` is `record_p`, not `pad_idx`.
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, pad_idx=pad_idx, **kwargs
        )
        self.leaves_to_search = leaves_to_search  # should be tuned based on recall target.
        # recommended to set n_leaves at the same magnitude of sqrt(#data_points)
        recommended_leaves = int(np.sqrt(self.fc1.weight.shape[0]))
        self.n_leaves = n_leaves
        if n_leaves is None or n_leaves < recommended_leaves:
            self.n_leaves = recommended_leaves
        
        self.avq_threshold = avq_threshold
        self.dims_per_block = dims_per_block
        self.reorder = reorder
        if reorder is None or reorder < self.k:
            # reordering_num_neighbors should be greater than k
            # set `2 * ` bc, in a small k and small memory size, it gives 
            # almost perfect performance at small cost of time (+0.2s/local_batch)
            self.reorder = 2 * self.k 
        self.build_index()

    def build_index(self):
        """Build a ScaNN dot-product searcher over a CPU copy of `fc1.weight`."""
        keys = self.fc1.weight.detach().clone().cpu()
        use_spherical = False  # true if measure angular similarity

        index_builder = scann.scann_ops_pybind.builder(keys, self.k, 'dot_product')
        # partitioning tree
        index_builder = index_builder.tree(
            self.n_leaves,
            self.leaves_to_search,
            training_sample_size=len(keys),
            spherical=use_spherical,
            quantize_centroids=True,
        )
        # asymmetric hashing with anisotropic quantization
        index_builder = index_builder.score_ah(
            self.dims_per_block,
            anisotropic_quantization_threshold=self.avq_threshold,
        )
        # exact re-ranking of the best candidates
        index_builder = index_builder.reorder(self.reorder)
        self.searcher = index_builder.build()
    
    def topk_search(self, queries, k=None, **kwargs):
        """
        Return the top-k indices found by the ScaNN searcher; slots for which
        the searcher reports a NaN distance are replaced with `self.pad_idx`.

        queries: 2-D tensor of queries; copied to CPU as float32 for the search
        k: number of indices per query (defaults to `self.k`)

        Returns an int64 tensor on the device of `queries`.
        """
        self.maybe_rebuild_index()
        cur_k = k or self.k
        device = queries.device
        # preprocess queries
        queries = queries.float().detach().cpu().numpy()

        indices, distances = self.searcher.search_batched(queries, cur_k)
        # The searcher returns numpy.uint32, which torch does not support.
        # Padding is applied while still in NumPy, so no device tensor is ever
        # indexed with a NumPy mask.
        indices = indices.astype(np.int64)
        indices[np.isnan(distances)] = self.pad_idx
        return torch.from_numpy(indices).to(device)


class NaiveTopKBlockBeforeActivation(XFFN):
    """
    This is a normal FFN with naive block-wise topk to enforce row-wise sparsity.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k = None,
        block_size = 2,
        agg_method="avg",
        **kwargs
    ):
        """Set up block-wise top-k selection.

        block_size: number of consecutive fc1 rows per block; must be > 1 and
            divide both `k` and `ffn_dim`.
        agg_method: how the rows (and biases) of a block are pooled into one
            scoring vector: "avg", "max", "min" or "abs+avg".

        Raises ValueError for an unknown `agg_method`.
        """
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, **kwargs
        )
        assert self.k is not None
        self.block_size = block_size
        assert block_size > 1, "Block size needs to be larger than 1, otherwise use NaiveTopK"
        # Validate and derive from the resolved `self.k` (the value checked
        # above) rather than from the raw argument.
        assert self.k % block_size == 0
        assert self.ffn_dim % block_size == 0
        self.b = self.k // block_size  # number of blocks to choose
        self.B = self.ffn_dim // block_size  # total number of blocks
        self.offset = torch.arange(block_size)  # row offsets inside a block (shared by all blocks)
        if agg_method == "avg":
            self.agg_method = lambda x: torch.mean(x, dim=1)
        elif agg_method == "max":
            self.agg_method = lambda x: torch.max(x, dim=1)[0]
        elif agg_method == "min":
            self.agg_method = lambda x: torch.min(x, dim=1)[0]
        elif agg_method == "abs+avg":
            self.agg_method = lambda x: torch.mean(torch.abs(x), dim=1)
        else:
            raise ValueError(f"Aggregation method -- {agg_method} -- not supported")

    def maybe_rebuild_index(self):
        """No-op: this class keeps no index that would need rebuilding."""
        return None
    
    @torch.no_grad()
    def topk_search(self, queries, k=None, **kwargs):
        """
        Return the row indices of the `self.b` best-scoring blocks per query.

        queries: 2-D tensor, one row per query

        Each block of `block_size` consecutive fc1 rows (and biases) is pooled
        into one vector (and scalar) with `self.agg_method`; the blocks are
        scored with a linear map of the query, and every selected block is
        expanded into the indices of all of its rows.
        """
        _n_toks, _ = queries.shape  # queries must be 2-D

        in_features = self.fc1.weight.size(-1)
        # one pooled weight vector and one pooled bias per block
        pooled_weight = self.agg_method(self.fc1.weight.reshape(self.B, -1, in_features))
        pooled_bias = self.agg_method(self.fc1.bias.reshape(self.B, -1))

        # score every block with the pooled linear map
        block_score = _linear(queries, pooled_weight, pooled_bias)
        assert len(block_score.shape) == 2

        # the in-block offsets have to live on the same device as the queries
        if self.offset.device != queries.device:
            self.offset = self.offset.to(queries.device)

        best_blocks = torch.topk(block_score, k=self.b, dim=-1)[1]
        del block_score
        # first row of each chosen block plus the in-block offsets
        row_indices = best_blocks.unsqueeze(-1) * self.block_size + self.offset
        del best_blocks
        return row_indices.flatten(-2, -1)
    
    
class HashTopK(XFFN):
    """
    Top-k selection driven by a fixed random hash table.

    Every vocabulary id is assigned ``num_active_block`` distinct block ids
    (out of ``num_total_block``), each block covering ``block_size``
    consecutive indices. The table is sampled once at construction time from a
    seeded generator; ``topk_search`` only looks token ids up in it and never
    inspects the query vectors.
    """
    def __init__(self, 
        args, 
        fc1: nn.Module, 
        activation_fn, 
        activation_dropout_module,
        fc2: nn.Module,
        dropout_module,
        ffn_ln=None,
        k = None,
        pad_idx=-1,
        vocab_size=-1,
        seed=1,
        layer_idx=0,
        block_size=1,
        **kwargs
    ):
        """
        Args:
            args: configuration object; ``args.decoder_layers`` is read here
                to derive the per-layer seed.
            fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, pad_idx, **kwargs: forwarded
                unchanged to the parent constructor.
            vocab_size: number of rows of the hash table (one per token id).
            seed: base seed of the table generator.
            layer_idx: index of the layer owning this module; mixed into the
                seed so that layers get different tables.
            block_size: number of consecutive indices per block; must divide
                both ``self.ffn_dim`` and ``self.k``.
        """
        super().__init__(
            args, fc1, activation_fn, activation_dropout_module, fc2,
            dropout_module, ffn_ln, k, pad_idx, **kwargs
        )
        self.vocab_size = vocab_size
        self.seed = seed
        # make sure different layers hash tables are different.
        # A private generator is used on purpose: torch.random.manual_seed()
        # would reseed the process-wide default generator as a side effect of
        # merely constructing this module.
        self.rng = torch.Generator().manual_seed(self.seed + layer_idx * args.decoder_layers)
        self.layer_idx = layer_idx
        self.block_size = block_size
        # NOTE(review): self.ffn_dim and self.k are not assigned in this class;
        # presumably they are set by the parent constructor -- confirm.
        assert self.ffn_dim % self.block_size == 0
        assert self.k % self.block_size == 0
        self.num_total_block = self.ffn_dim // self.block_size
        self.num_active_block = self.k // self.block_size
        self.idx_per_block = torch.arange(self.block_size)
        # position of every index inside a block (shared by all blocks)
        self.offset = torch.arange(self.block_size)
        
        logger.info(f"{self.__class__}: Building hash table")
        self.build_index()
        logger.info(f"{self.__class__}: Finish Building hash table")
    
    def build_index(self):
        """
        Sample the hash table ``self.assn``.

        The result has ``vocab_size`` rows; each row holds the first
        ``num_active_block`` entries of an independent random permutation of
        ``range(num_total_block)``, i.e. distinct block ids.

        Raises:
            ValueError: if ``self.vocab_size`` is not positive.
        """
        if self.vocab_size <= 0:
            # An empty table cannot be stacked; fail with an explicit message.
            raise ValueError(
                f"{self.__class__.__name__} requires a positive vocab_size, got {self.vocab_size}"
            )

        # make the table have different precision (table could be too large when #hashFFN increases)
        # The dtype is sized from ffn_dim rather than from the number of
        # blocks, so block_id * block_size (computed in topk_search, always
        # below ffn_dim) also fits in it.
        if self.ffn_dim <= 2**15 - 1:
            self.assn_dtype = torch.int16
        elif self.ffn_dim <= 2**31 - 1:
            self.assn_dtype = torch.int32
        else:
            self.assn_dtype = torch.int64
        
        t0 = time()
        # this is the fastest creation methods I found
        assn = [
            torch.randperm(
                self.num_total_block, 
                generator=self.rng, 
                dtype=self.assn_dtype)[:self.num_active_block] for _ in range(self.vocab_size)
        ]
        sample_time = time() - t0
        # Integer table: it takes no part in autograd.
        self.assn = torch.stack(assn)
        assert self.assn.dtype == self.assn_dtype

        # Prefix the messages with the rank when running distributed.
        if torch.distributed.is_initialized():
            rank_prefix = f"{torch.distributed.get_rank()}: "
        else:
            rank_prefix = ""
        logger.info(f"{rank_prefix}sample_time(s)={sample_time:.3f}")
        logger.info(f"{rank_prefix}Layer{self.layer_idx}, check_sum(self.assn)={self.assn.long().sum()}")

    def maybe_rebuild_index(self):
        """No-op: the table is built once in ``build_index`` and never changed."""
        pass
    
    @torch.no_grad()
    def topk_search(self, queries, token_ids, k=None, **kwargs):
        """
        Look up the indices assigned to every token id.

        Args:
            queries: unused; the selection depends on ``token_ids`` only.
            token_ids: (batch, seq_len) tensor of token ids. It is transposed
                and flattened before the lookup.
            k: unused; the number of indices per token is fixed by the table.
            **kwargs: ignored.

        Returns:
            int64 tensor on the device of ``token_ids`` with one row per
            flattened token and ``num_active_block * block_size`` indices per
            row.

        Raises:
            ValueError: if ``token_ids`` is None.
        """
        # Validate before touching any attribute of token_ids.
        if token_ids is None:
            raise ValueError("token_ids must be provided for hash-based top-k search")
        dev = token_ids.device
        token_ids = token_ids.T.reshape(-1)
        if self.training: 
            # During training the table is kept on CPU and the lookup is done
            # there; only the gathered result is moved to `dev` below.
            self.assn = self.assn.cpu()
            token_ids = token_ids.cpu()
        else:
            self.assn = self.assn.to(dev)
        if self.offset.device != self.assn.device:
            self.offset = self.offset.to(self.assn.device)
        hash_idx = self.assn[token_ids]
        # block id -> the block_size consecutive indices of that block
        topk_idx = hash_idx.unsqueeze(-1) * self.block_size + self.offset
        topk_idx = topk_idx.flatten(-2,-1).to(dev).long() # guarantee the topk index are long type
        return topk_idx


#### Below are not yet supported


# Registry mapping the string name of a top-k selection strategy to the class
# that implements it. Entries that are commented out refer to unfinished
# classes and are therefore not selectable through this table.
TOPK2CLASS = {
    "none": FFN,
    "random": RandomTopK,
    "reformer": ReformerLSH,
    "naive": NaiveTopK,
    "static": StaticTopK,
    "naive-block": NaiveTopKBlock,
    "naive-block-before": NaiveTopKBlockBeforeActivation,
    "naive-low": NaiveTopKLowRank,
    "naive-ann": NaiveANNTopK,
    "scann": ScannTopK,
    "hash": HashTopK,
    "pkm-addon": PKMTopkAugmented,
    # unfinished classes
    # "faiss-ivfflat": FaissIVFFlatTopK,
    # "faiss": FaissTopK,
    # "faiss-lsh": FaissLSHTopK,
    # "faissip": FaissIPTopK,
} 
