# from typing import Callable, List, Optional, Tuple, Union
import math
import os

# import sys
import warnings
from datetime import datetime
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# from einops import rearrange
from torch.nn.init import xavier_normal_, xavier_uniform_

# from torch._C import _add_docstr
# from torch._torch_docs import sparse_support_notes
# from torch.nn.modules.activation import MultiheadAttention
from torch.nn.parameter import Parameter
from torch.overrides import handle_torch_function, has_torch_function
from torch_geometric.graphgym.config import cfg

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (  # 追加
    BitLinear,
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)
from graphgps.tome.merge import (
    bipartite_soft_matching,
    merge_source,
    merge_wavg,
)

# from graphgps.tome.utils import parse_r

# Tensor = torch.Tensor


def percentile(t, q):
    """Return the ``q``-th percentile of all elements of ``t``.

    Nearest-rank selection via :meth:`torch.Tensor.kthvalue` with
    ``k = 1 + round(q / 100 * (numel - 1))``. ``round`` is Python's
    round-half-to-even.

    Args:
        t: tensor of any shape and memory layout; it is flattened first.
        q: percentile, expected in ``[0, 100]`` (values outside that range
            give an out-of-range ``k`` and ``kthvalue`` raises).

    Returns:
        The selected element as a Python number.
    """
    k = 1 + round(0.01 * float(q) * (t.numel() - 1))
    # reshape (not view): non-contiguous inputs such as transposes or
    # strided slices are accepted; contiguous inputs are still not copied.
    return t.reshape(-1).kthvalue(k).values.item()


def blockdiag_matmul(x, w):
    """Apply a block-diagonal matrix to the last dimension of ``x``.

    ``w`` stacks the diagonal blocks as ``(nblocks, n, m)``. The last
    dimension of ``x`` is viewed as ``nblocks`` chunks of size ``m``, and
    block ``b`` multiplies chunk ``b``. The result is reshaped back to
    ``x.shape``, which only succeeds when ``n == m`` (square blocks).
    ``x`` must be viewable (``Tensor.view``) into that chunked shape.
    """
    num_blocks = w.shape[0]
    block_in = w.shape[-1]
    chunked = x.view(*x.shape[:-1], num_blocks, block_in)
    product = torch.einsum("bnm,...bm->...bn", w, chunked)
    return product.reshape(*x.shape)


def save_attention_weights_histogram(
    attn_weights, step, epoch, x_min=0, x_max=0.1, num_bins=50, font_size=18
):
    """Save a histogram of attention weights as a PNG under ``./img``.

    The file name embeds ``step``, ``epoch`` and a microsecond timestamp,
    so repeated calls do not overwrite each other.

    Args:
        attn_weights: tensor of attention weights. A 4-D tensor is first
            averaged over dim 1 (presumably the head axis of an
            ``(N, num_heads, L, S)`` tensor).
        step: label embedded in the file name.
        epoch: label embedded in the file name.
        x_min: currently unused; kept for backward compatibility.
        x_max: currently unused; kept for backward compatibility.
        num_bins: number of histogram bins.
        font_size: font size used for this figure only.

    Returns:
        None. Writes one PNG file and prints its path.
    """
    weights = attn_weights.detach().cpu().numpy()
    if weights.ndim == 4:
        weights = weights.mean(axis=1)

    # Flatten and drop NaN/inf so the histogram range is well defined.
    flattened_weights = weights.flatten()
    flattened_weights = flattened_weights[np.isfinite(flattened_weights)]

    save_dir = "./img"
    os.makedirs(save_dir, exist_ok=True)
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = (
        f"attn_weights_histogram_step_{step}_epoch_{epoch}_{current_time}.png"
    )
    output_image_path = os.path.join(save_dir, filename)

    # Scope the font size to this figure instead of mutating the global
    # rcParams, which would silently restyle every later plot.
    with plt.rc_context({"font.size": font_size}):
        fig = plt.figure(figsize=(10, 6))
        try:
            # No range filtering: default binning over the finite values.
            fig.add_subplot(111).hist(
                flattened_weights,
                bins=num_bins,
                color="blue",
                alpha=0.7,
                edgecolor="black",
            )
            fig.savefig(output_image_path)
        finally:
            # Release the figure even if saving fails, so pyplot does not
            # accumulate open figures across training steps.
            plt.close(fig)

    print(f"Saved the attention weights histogram as {output_image_path}")


def save_query_distribution_histogram(query, epoch, num_bins=50, font_size=14):
    """Save a histogram of all finite values in ``query`` as a PNG under ``./img``.

    Args:
        query: tensor of any shape; it is detached, moved to CPU and
            flattened.
        epoch: label embedded in the file name.
        num_bins: number of histogram bins.
        font_size: font size used for this figure only.

    Returns:
        None. Writes one PNG whose name ends in a microsecond timestamp and
        prints its path.
    """
    flattened_query = query.detach().cpu().numpy().flatten()
    # Drop NaN/inf so the histogram range is well defined.
    flattened_query = flattened_query[np.isfinite(flattened_query)]

    save_dir = "./img"
    os.makedirs(save_dir, exist_ok=True)
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"query_distribution_histogram_epoch_{epoch}_{current_time}.png"
    output_image_path = os.path.join(save_dir, filename)

    # Scope the font size to this figure instead of mutating the global
    # rcParams, which would silently restyle every later plot.
    with plt.rc_context({"font.size": font_size}):
        fig = plt.figure(figsize=(10, 6))
        try:
            fig.add_subplot(111).hist(
                flattened_query,
                bins=num_bins,
                color="blue",
                alpha=0.7,
                edgecolor="black",
            )
            fig.savefig(output_image_path)
        finally:
            # Release the figure even if saving fails.
            plt.close(fig)

    print(f"Saved the query distribution histogram as {output_image_path}")


class CustomMultiheadAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        """Create the attention module and its projection layers.

        The signature mirrors ``torch.nn.MultiheadAttention``. However, the
        q/k/v and output projections are layers chosen from the global
        ``cfg``, not dense ``in_proj``/``out_proj`` weights. Nothing is built
        unless ``cfg.slt.msa is True``; then the first matching option wins:

        * ``cfg.slt.bitlinear``: packed ``in_lin``/``out_lin`` (``BitLinear``).
        * ``cfg.slt.slt_weight_scaling`` or ``learnable_weight_scaling``:
          separate ``q_lin``/``k_lin``/``v_lin``/``out_lin``, each a
          ``SparseLinear`` (``cfg.slt.sm``) or ``SparseLinearMulti_mask``
          (``cfg.slt.mm``); neither flag set builds nothing.
        * ``cfg.monarch.msa``: packed ``MonarchLinear`` layers.
        * ``cfg.slt.srste``: packed ``NMSparseMultiLinear`` layers.
        * ``cfg.slt.sm`` / ``cfg.slt.mm``: packed sparse layers.

        If no projection layer gets built, the forward pass fails on the
        missing attribute.

        ``device``/``dtype`` apply only to the dense parameters created
        here (``q/k/v_proj_weight`` and ``bias_k``/``bias_v``), not to the
        projection layers. ``custom_reset_parameters`` is defined elsewhere
        in the class.
        """
        super().__init__()
        tensor_opts = {"device": device, "dtype": dtype}

        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Dense per-input projection weights exist only when key/value
        # dims differ from embed_dim; otherwise every slot is None.
        if self._qkv_same_embed_dim:
            for unused_name in (
                "in_proj_weight",
                "in_proj_bias",
                "q_proj_weight",
                "k_proj_weight",
                "v_proj_weight",
            ):
                self.register_parameter(unused_name, None)
        else:
            self.q_proj_weight = Parameter(
                torch.empty((embed_dim, embed_dim), **tensor_opts)
            )
            self.k_proj_weight = Parameter(
                torch.empty((embed_dim, self.kdim), **tensor_opts)
            )
            self.v_proj_weight = Parameter(
                torch.empty((embed_dim, self.vdim), **tensor_opts)
            )
            self.register_parameter("in_proj_weight", None)

        # No packed in-projection bias is ever created; with bias=False the
        # slot is explicitly registered as None.
        if not bias:
            self.register_parameter("in_proj_bias", None)

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **tensor_opts)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **tensor_opts)
            )
        else:
            self.bias_k = None
            self.bias_v = None

        self.add_zero_attn = add_zero_attn

        def _slt_layer(layer_cls, out_features):
            # Shared constructor arguments of every SLT sparse projection.
            return layer_cls(
                embed_dim,
                out_features,
                bias=False,
                gain="linear",
                init_mode_weight="signed_xavier_uniform_constant_SF",
            )

        slt_cfg = cfg.slt
        if slt_cfg.msa is True:
            if slt_cfg.bitlinear:
                self.in_lin = BitLinear(embed_dim, 3 * embed_dim, bias=False)
                self.out_lin = BitLinear(embed_dim, embed_dim, bias=False)
            elif slt_cfg.slt_weight_scaling or slt_cfg.learnable_weight_scaling:
                # Separate q/k/v/out projections, built in that order.
                if slt_cfg.sm is True:
                    layer_cls = SparseLinear
                elif slt_cfg.mm is True:
                    layer_cls = SparseLinearMulti_mask
                else:
                    layer_cls = None
                if layer_cls is not None:
                    for attr_name in ("q_lin", "k_lin", "v_lin", "out_lin"):
                        setattr(self, attr_name, _slt_layer(layer_cls, embed_dim))
            elif cfg.monarch.msa is True:
                self.in_lin = MonarchLinear(
                    embed_dim, 3 * embed_dim, bias=False
                )
                self.out_lin = MonarchLinear(embed_dim, embed_dim, bias=False)
            elif slt_cfg.srste is True:
                self.in_lin = NMSparseMultiLinear(embed_dim, 3 * embed_dim)
                self.out_lin = NMSparseMultiLinear(embed_dim, embed_dim)
            elif slt_cfg.sm is True:
                self.in_lin = _slt_layer(SparseLinear, 3 * embed_dim)
                self.out_lin = _slt_layer(SparseLinear, embed_dim)
            elif slt_cfg.mm is True:
                self.in_lin = _slt_layer(SparseLinearMulti_mask, 3 * embed_dim)
                self.out_lin = _slt_layer(SparseLinearMulti_mask, embed_dim)

        self.custom_reset_parameters()

    def custom_multi_head_attention_forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        embed_dim_to_check: int,
        num_heads: int,
        bias_k: Optional[Tensor],
        bias_v: Optional[Tensor],
        add_zero_attn: bool,
        dropout_p: float,
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        use_separate_proj_weight: bool = True,
        q_proj_weight: Optional[Tensor] = None,
        k_proj_weight: Optional[Tensor] = None,
        v_proj_weight: Optional[Tensor] = None,
        static_k: Optional[Tensor] = None,
        static_v: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cur_epoch: Optional[int] = None,
        msa_threshold: Optional[float] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Scaled dot-product multi-head attention with custom projections.

        Adapted from ``torch.nn.functional.multi_head_attention_forward``.
        The in/out projections are this module's own layers, called with a
        pruning ``threshold``: ``self.q_lin``/``k_lin``/``v_lin`` or the
        packed ``self.in_lin``, then ``self.out_lin``.

        Inputs are sequence-first: ``query`` is ``(tgt_len, bsz, embed_dim)``
        and ``key``/``value`` are ``(src_len, bsz, *)``. Unbatched 2-D inputs
        get a temporary batch dim of size 1.

        Args:
            embed_dim_to_check: expected last dim of ``query``.
            num_heads: number of attention heads.
            bias_k, bias_v: optional biases appended along the source axis.
            add_zero_attn: append an all-zero key/value position.
            dropout_p: attention dropout probability (forced to 0 when
                ``training`` is False).
            key_padding_mask: optional ``(bsz, src_len)`` bool/float mask.
            attn_mask: optional ``(tgt_len, src_len)`` or
                ``(bsz * num_heads, tgt_len, src_len)`` bool/float mask.
            use_separate_proj_weight: only selects which key/value shape
                check is applied.
            q_proj_weight, k_proj_weight, v_proj_weight: accepted for
                signature compatibility; not used by the projections.
            static_k, static_v: optional precomputed per-head keys/values.
            cur_epoch: passed to ``self.get_threshold`` when
                ``cfg.slt.pruning == "layerwise"``.
            msa_threshold: projection threshold for all other pruning modes.

        Returns:
            ``(attn_output, attn_output_weights, metric)``.
            ``attn_output`` is ``(tgt_len, bsz, out)``, with the batch dim
            squeezed for unbatched input. ``attn_output_weights`` is ``None``
            unless ``need_weights``. ``metric`` is ``k.mean(0)`` only when
            ``need_weights`` is False and ``cfg.slt.tome`` is set; otherwise
            it is ``None``.
        """
        tens_ops = (query, key, value, bias_k, bias_v)
        if has_torch_function(tens_ops):
            # NOTE(review): this dispatch still targets
            # F.multi_head_attention_forward, but the in/out projection
            # arguments have been removed, so the positional arguments no
            # longer line up with that function's signature. It is only
            # reached for Tensor subclasses that override __torch_function__.
            return handle_torch_function(
                F.multi_head_attention_forward,
                tens_ops,
                query,
                key,
                value,
                embed_dim_to_check,
                num_heads,
                bias_k,
                bias_v,
                add_zero_attn,
                dropout_p,
                training=training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight,
                k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight,
                static_k=static_k,
                static_v=static_v,
                average_attn_weights=average_attn_weights,
            )

        is_batched = F._mha_shape_check(
            query, key, value, key_padding_mask, attn_mask, num_heads
        )

        # For unbatched input, insert a batch dim of size 1 (dim 1 for the
        # sequence-first q/k/v, dim 0 for the padding mask). It is squeezed
        # out again before returning.
        if not is_batched:
            query = query.unsqueeze(1)
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)
            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask.unsqueeze(0)

        # set up shape vars
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        assert (
            embed_dim == embed_dim_to_check
        ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
        if isinstance(embed_dim, torch.Tensor):
            # embed_dim can be a tensor when JIT tracing
            head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
        else:
            head_dim = embed_dim // num_heads
        assert (
            head_dim * num_heads == embed_dim
        ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
        if use_separate_proj_weight:
            assert (
                key.shape[:2] == value.shape[:2]
            ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"  # noqa
        else:
            assert (
                key.shape == value.shape
            ), f"key shape {key.shape} does not match value shape {value.shape}"

        #
        # compute in-projection
        #
        if cfg.slt.pruning == "layerwise":
            threshold = self.get_threshold(
                cfg.slt.linear_sparsity, cur_epoch, "lin"
            )
        else:
            threshold = msa_threshold

        if (
            cfg.slt.attention_scaling != 1.0
            or cfg.slt.slt_weight_scaling
            or cfg.slt.learnable_weight_scaling
        ):
            # Separate q/k/v projections.
            # NOTE(review): __init__ only builds q_lin/k_lin/v_lin under
            # slt_weight_scaling / learnable_weight_scaling, so
            # attention_scaling != 1.0 on its own lands here without them.
            q = self.q_lin(query, threshold, q_lin=True)
            k = self.k_lin(key, threshold)
            v = self.v_lin(value, threshold)
        else:
            # Packed projection of `query` only; key and value are ignored,
            # so this path assumes self-attention.
            q, k, v = self.in_lin(query, threshold).chunk(3, dim=-1)
        # Debug hook: save_query_distribution_histogram(query, epoch) here.

        # prep attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."  # noqa
                )
                attn_mask = attn_mask.to(torch.bool)
            else:
                assert (
                    attn_mask.is_floating_point()
                    or attn_mask.dtype == torch.bool
                ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"  # noqa
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."  # noqa
                    )
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."  # noqa
                    )
            else:
                raise RuntimeError(
                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
                )

        # add bias along batch dimension (currently second)
        if bias_k is not None and bias_v is not None:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            # One extra source position: widen the masks to match.
            # (F.pad: a bare `pad` is not defined in this module.)
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert bias_k is None
            assert bias_v is None

        #
        # reshape q, k, v for multihead attention and make em batch first
        #
        q = (
            q.contiguous()
            .view(tgt_len, bsz * num_heads, head_dim)
            .transpose(0, 1)
        )
        if static_k is None:
            k = (
                k.contiguous()
                .view(k.shape[0], bsz * num_heads, head_dim)
                .transpose(0, 1)
            )
        else:
            assert (
                static_k.size(0) == bsz * num_heads
            ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
            assert (
                static_k.size(2) == head_dim
            ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
            k = static_k
        if static_v is None:
            v = (
                v.contiguous()
                .view(v.shape[0], bsz * num_heads, head_dim)
                .transpose(0, 1)
            )
        else:
            assert (
                static_v.size(0) == bsz * num_heads
            ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
            assert (
                static_v.size(2) == head_dim
            ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
            v = static_v

        # add zero attention along batch dimension (now first)
        if add_zero_attn:
            zero_attn_shape = (bsz * num_heads, 1, head_dim)
            k = torch.cat(
                [
                    k,
                    torch.zeros(
                        zero_attn_shape, dtype=k.dtype, device=k.device
                    ),
                ],
                dim=1,
            )
            v = torch.cat(
                [
                    v,
                    torch.zeros(
                        zero_attn_shape, dtype=v.dtype, device=v.device
                    ),
                ],
                dim=1,
            )
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        # update source sequence length after adjustments
        src_len = k.size(1)

        # merge key padding and attention masks
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                bsz,
                src_len,
            ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"  # noqa
            key_padding_mask = (
                key_padding_mask.view(bsz, 1, 1, src_len)
                .expand(-1, num_heads, -1, -1)
                .reshape(bsz * num_heads, 1, src_len)
            )
            if attn_mask is None:
                attn_mask = key_padding_mask
            elif attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.logical_or(key_padding_mask)
            else:
                attn_mask = attn_mask.masked_fill(
                    key_padding_mask, float("-inf")
                )

        # convert mask to float (True -> -inf additive mask)
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask

        # adjust dropout probability
        if not training:
            dropout_p = 0.0

        # Scaled dot product; q is (bsz * num_heads, tgt_len, head_dim) here.
        q_scaled = q / math.sqrt(q.size(-1))
        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))

        # Debug hook: save_attention_weights_histogram(...) before/after
        # the softmax.
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)

        if dropout_p > 0.0:
            attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1)
            .contiguous()
            .view(tgt_len * bsz, embed_dim)
        )

        # MonarchLinear-only configs call out_lin without a threshold.
        if cfg.monarch.msa and (
            (cfg.slt.sm is False) and (cfg.slt.mm is False)
        ):
            attn_output = self.out_lin(attn_output)
        else:
            attn_output = self.out_lin(attn_output, threshold)

        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        if need_weights:
            # optionally average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, num_heads, tgt_len, src_len
            )
            if average_attn_weights:
                attn_output_weights = (
                    attn_output_weights.sum(dim=1) / num_heads
                )

            if not is_batched:
                # squeeze the output if input was unbatched
                attn_output = attn_output.squeeze(1)
                attn_output_weights = attn_output_weights.squeeze(0)
            return attn_output, attn_output_weights, None
        else:
            if not is_batched:
                # squeeze the output if input was unbatched
                attn_output = attn_output.squeeze(1)
            if cfg.slt.tome:
                # ToMe merge metric: keys averaged over dim 0
                # (bsz * num_heads).
                return attn_output, None, k.mean(0)
            return attn_output, None, None

    def forward(
        self,
        query,
        key,
        value,
        attn_mask=None,
        key_padding_mask=None,
        cur_epoch=None,
        msa_threshold=None,
        need_weights=True,
        average_attn_weights=True,
    ):
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
            # elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            #     why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"  # noqa
            # elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
            #     # this case will fail anyway, but at least they'll get a useful error message.
            #     why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"  # noqa
            # elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                # self.in_proj_weight,
                # self.in_proj_bias,
                # self.out_proj.weight,
                # self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    # self.in_proj_weight,
                    # self.in_proj_bias,
                    # self.out_proj.weight,
                    # self.out_proj.bias,
                    (
                        key_padding_mask
                        if key_padding_mask is not None
                        else attn_mask
                    ),
                    need_weights,
                    average_attn_weights,
                    (
                        1
                        if key_padding_mask is not None
                        else 0 if attn_mask is not None else None
                    ),
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = (
                self.custom_multi_head_attention_forward(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    # self.in_proj_weight,
                    # self.in_proj_bias,
                    self.bias_k,
                    self.bias_v,
                    self.add_zero_attn,
                    self.dropout,
                    # self.out_proj.weight,
                    # self.out_proj.bias,
                    training=self.training,
                    key_padding_mask=key_padding_mask,
                    need_weights=need_weights,
                    attn_mask=attn_mask,
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight,
                    average_attn_weights=average_attn_weights,
                )
            )
        else:
            attn_output, attn_output_weights, metric = (
                self.custom_multi_head_attention_forward(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    # self.in_proj_weight,
                    # self.in_proj_bias,
                    self.bias_k,
                    self.bias_v,
                    self.add_zero_attn,
                    self.dropout,
                    # self.out_proj.weight,
                    # self.out_proj.bias,
                    training=self.training,
                    key_padding_mask=key_padding_mask,
                    need_weights=need_weights,
                    attn_mask=attn_mask,
                    average_attn_weights=average_attn_weights,
                    cur_epoch=cur_epoch,
                    msa_threshold=msa_threshold,
                )
            )

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights, metric
        else:
            return attn_output, attn_output_weights, metric

    @staticmethod
    def percentile(t, q):
        """Return the ``q``-th percentile of all elements of ``t`` as a float.

        The result is the ``k``-th smallest element of the flattened tensor,
        with ``k = 1 + round(0.01 * q * (numel - 1))``. Python's ``round``
        rounds halves to even.

        Declared a ``staticmethod`` so that calling it through an instance
        (``self.percentile(t, q)``) does not bind the module itself to ``t``.

        Args:
            t: Tensor of any shape. It is flattened with ``reshape`` rather
                than ``view`` so non-contiguous inputs are accepted.
            q: Percentile, expected in ``[0, 100]``.

        Returns:
            The selected element as a Python number.
        """
        k = 1 + round(0.01 * float(q) * (t.numel() - 1))
        return t.reshape(-1).kthvalue(k).values.item()

    def calculate_sparsity(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Calculate the warm-up sparsity for ``epoch`` based on the schedule.

        The value is ramped towards ``sparsity`` over ``max_epoch_half``
        epochs:

        * ``"linear"``: ``sparsity * epoch / max_epoch_half``.
        * ``"curve"``: ``sparsity * sqrt(1 - p ** 2)`` with
          ``p = (max_epoch_half - epoch) / max_epoch_half``. This rises from
          0 at epoch 0 to ``sparsity`` at ``max_epoch_half``.
        * ``"step"``: ten equal increments of ``sparsity / 10``, starting at
          ``sparsity / 10`` and never exceeding ``sparsity``.
        * ``"reverse"``: ramps *down* from 1 at epoch 0 to ``sparsity``.
        * any other value: ``sparsity`` unchanged.

        Args:
            sparsity: Target sparsity.
            epoch: Current epoch (expected ``0 <= epoch < max_epoch_half``).
            max_epoch_half: Warm-up length in epochs; must be > 0.
            schedule_type: Name of the schedule.

        Returns:
            The scheduled sparsity for ``epoch``.
        """
        if schedule_type == "linear":
            return sparsity * (epoch / max_epoch_half)
        if schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            return sparsity * scale_factor
        if schedule_type == "step":
            # Use at least one epoch per step, so warm-ups shorter than ten
            # epochs do not divide by zero. Cap at ten steps, so the result
            # never overshoots the target (and never exceeds 1) when
            # max_epoch_half is not a multiple of ten.
            step_size = max(1, max_epoch_half // 10)
            current_step = min(10, epoch // step_size + 1)
            return (sparsity / 10) * current_step
        if schedule_type == "reverse":
            return 1 - ((1 - sparsity) * (epoch / max_epoch_half))
        return sparsity

    def calculate_sparsities(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Calculate per-entry warm-up sparsities for multi-mode (mm).

        Each entry of ``sparsity`` is ramped towards its own target over
        ``max_epoch_half`` epochs:

        * ``"linear"``: ``value * epoch / max_epoch_half``.
        * ``"curve"``: uses ``scale = sqrt(1 - p ** 2)`` with
          ``p = (max_epoch_half - epoch) / max_epoch_half``.
          - The first entry is ``sparsity[0] * scale``.
          - Every later entry moves from
            ``start = (value - sparsity[0]) / (1 - sparsity[0])`` towards
            ``value`` as ``start + (value - start) * scale``.
        * ``"step"``: ten equal increments of ``value / 10``, never
          exceeding ``value``.
        * ``"reverse"``: ramps *down* from 1 at epoch 0 to ``value``.
        * any other value: ``sparsity`` is returned unchanged.

        Args:
            sparsity: Sequence of target sparsities, one per entry.
            epoch: Current epoch (expected ``0 <= epoch < max_epoch_half``).
            max_epoch_half: Warm-up length in epochs; must be > 0.
            schedule_type: Name of the schedule.

        Returns:
            A list with one scheduled sparsity per entry, or ``sparsity``
            itself for an unknown schedule.
        """
        if schedule_type == "linear":
            return [value * (epoch / max_epoch_half) for value in sparsity]
        if schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            sparsities = []
            for i, value in enumerate(sparsity):
                if i == 0:
                    sparsities.append(value * scale_factor)
                else:
                    start_point = (value - sparsity[0]) / (1 - sparsity[0])
                    sparsities.append(
                        start_point + (value - start_point) * scale_factor
                    )
            return sparsities
        if schedule_type == "step":
            # Use at least one epoch per step (avoids dividing by zero for
            # warm-ups shorter than ten epochs) and cap at ten steps (no
            # overshoot when max_epoch_half is not a multiple of ten). Each
            # entry reaches exactly its own target at step ten; previously
            # the value was multiplied in twice, ending at value ** 2.
            step_size = max(1, max_epoch_half // 10)
            current_step = min(10, epoch // step_size + 1)
            return [(value / 10) * current_step for value in sparsity]
        if schedule_type == "reverse":
            return [
                1 - ((1 - value) * (epoch / max_epoch_half))
                for value in sparsity
            ]
        return sparsity

    def get_threshold(self, sparsity, epoch=None, keyword=""):
        """Return the pruning threshold(s) for this module's score parameters.

        Scores are all parameters with a truthy ``is_score`` attribute whose
        name contains ``keyword``. A threshold is the percentile of those
        scores at ``sparsity * 100``. When ``cfg.slt.enable_abs_pruning`` is
        set, the percentile is taken over their absolute values.

        The sparsity is ramped with ``cfg.slt.sparse_scheduling`` when both
        hold:
        * the module is training, and
        * ``epoch < cfg.optim.max_epoch // 2``.
        Otherwise (eval mode, later epochs, or ``epoch`` is None) the given
        sparsity is used as-is.

        Args:
            sparsity: Sequence of target sparsities in ``[0, 1]``.
                - With ``cfg.slt.mm`` each entry gets its own threshold.
                - With ``cfg.slt.sm`` only ``sparsity[0]`` is used.
            epoch: Current epoch; only used for the warm-up schedule.
            keyword: Substring that score-parameter names must contain.

        Returns:
            - A list of thresholds when ``cfg.slt.mm`` is set.
            - A single threshold when ``cfg.slt.sm`` is set.
            - ``None`` when neither mode is enabled.

        Raises:
            RuntimeError: if thresholds are needed but no score parameter
                matches ``keyword``.
        """
        max_epoch_half = cfg.optim.max_epoch // 2
        # Only ramp during the first half of training. A missing epoch
        # disables the ramp instead of failing on ``None < int``.
        in_warmup = (
            self.training and epoch is not None and epoch < max_epoch_half
        )

        if cfg.slt.mm:
            if in_warmup:
                sparsities = self.calculate_sparsities(
                    sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            else:
                sparsities = sparsity
            if len(sparsities) == 0:
                return []
            # The score values do not depend on the sparsity level, so they
            # are gathered once rather than once per threshold.
            scores = self._collect_score_values(keyword)
            return [
                percentile(scores, sparsity_value * 100)
                for sparsity_value in sparsities
            ]

        if cfg.slt.sm:
            if in_warmup:
                sparsity_value = self.calculate_sparsity(
                    sparsity[0],
                    epoch,
                    max_epoch_half,
                    cfg.slt.sparse_scheduling,
                )
            else:
                sparsity_value = sparsity[0]
            scores = self._collect_score_values(keyword)
            return percentile(scores, sparsity_value * 100)

        # Neither mode is enabled: there is nothing to threshold.
        return None

    def _collect_score_values(self, keyword):
        """Concatenate the detached, flattened score parameters matching ``keyword``.

        A parameter is a score when it has a truthy ``is_score`` attribute.
        Absolute values are returned when ``cfg.slt.enable_abs_pruning`` is
        set.

        Raises:
            RuntimeError: if no score parameter name contains ``keyword``.
        """
        score_chunks = [
            p.detach().flatten()
            for name, p in self.named_parameters()
            if getattr(p, "is_score", False) and keyword in name
        ]
        if not score_chunks:
            raise RuntimeError(
                f"No score parameters found matching keyword {keyword!r}"
            )
        scores = torch.cat(score_chunks)
        return scores.abs() if cfg.slt.enable_abs_pruning else scores

    def custom_reset_parameters(self):
        """Re-initialise the separate q/k/v projection weights and k/v biases.

        When ``_qkv_same_embed_dim`` is true, no projection weight is
        touched here; the packed ``in_proj_weight`` initialisation is
        disabled in this module. Otherwise ``q_proj_weight``,
        ``k_proj_weight`` and ``v_proj_weight`` are each re-drawn in place
        with Xavier-uniform init, in that order.

        ``bias_k`` and ``bias_v`` are re-drawn with Xavier-normal init when
        they are not None. The input/output projection biases are not reset
        here.
        """
        if not self._qkv_same_embed_dim:
            for proj_weight in (
                self.q_proj_weight,
                self.k_proj_weight,
                self.v_proj_weight,
            ):
                xavier_uniform_(proj_weight)

        for extra_bias in (self.bias_k, self.bias_v):
            if extra_bias is not None:
                xavier_normal_(extra_bias)


# Bind the C++ kernels behind F.pad and F.linear directly.
pad, linear = torch._C._nn.pad, torch._C._nn.linear
# Report F.pad's public home instead of the private torch._C._nn module,
# mirroring what torch.nn.functional itself does.
pad.__module__ = "torch.nn.functional"
