import numbers
import os
from datetime import datetime
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pygnn
from performer_pytorch import SelfAttention
from torch.nn.parameter import Parameter
from torch_geometric.data import Batch
from torch_geometric.graphgym.config import cfg
from torch_geometric.nn import Linear as Linear_pyg
from torch_geometric.utils import to_dense_batch

from graphgps.layer.bigbird_layer import SingleBigBirdLayer
from graphgps.layer.gatedgcn_layer import GatedGCNLayer
from graphgps.layer.gine_conv_layer import GINEConvESLapPE
from graphgps.slt.custom_attn import CustomMultiheadAttention
from graphgps.slt.custom_conv import Custom_GatedGCNLayer, SLT_GINEConv
from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    BitLinear,
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)
from graphgps.tome.merge import (
    bipartite_soft_matching,
    merge_source,
    merge_wavg,
)


def percentile(t, q):
    """Return the ``q``-th percentile of ``t`` as a Python scalar.

    Uses nearest-rank selection on the flattened tensor: the rank is
    ``1 + round(q / 100 * (numel - 1))`` and the value at that rank is read
    with ``kthvalue``, so the result is always an element of ``t`` (no
    interpolation between neighbours).

    Args:
        t: Non-empty tensor of any shape. It does not need to be contiguous.
        q: Percentile, expected in ``[0, 100]``.

    Returns:
        The selected element as a Python number.
    """
    # `reshape` (unlike `view`) also accepts non-contiguous inputs such as
    # transposed or sliced tensors.
    flat = t.reshape(-1)
    k = 1 + round(0.01 * float(q) * (flat.numel() - 1))
    return flat.kthvalue(k).values.item()


class RMSNorm(nn.Module):
    """Root-mean-square normalisation with an optional learnable gain.

    The input is divided by ``sqrt(mean(x ** 2) + eps)``, where the mean is
    taken over the trailing ``len(normalized_shape)`` dimensions, and is then
    multiplied element-wise by ``weight`` when ``elementwise_affine`` is set.

    Args:
        normalized_shape: Size of the trailing dimension(s) to normalise
            over; a single integer is treated as a one-element shape.
        eps: Value added inside the square root. Defaults to the float32
            machine epsilon when ``None``.
        elementwise_affine: If ``True``, learn a gain of shape
            ``normalized_shape`` (initialised to ones).
        device: Device on which the gain is created.
        dtype: Dtype with which the gain is created.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape,
        eps: Optional[float] = None,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps if eps is not None else torch.finfo(torch.float32).eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the gain to ones (no-op without affine parameters)."""
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` by its RMS over the normalised dimensions.

        The gain is a registered parameter, so it follows the module through
        ``.to()`` / ``.cuda()``; nothing is moved or re-assigned here.
        """
        # Reduce over every dimension covered by `normalized_shape`; fall
        # back to the last dimension for a degenerate empty shape.
        reduce_dims = tuple(range(-len(self.normalized_shape), 0)) or (-1,)
        rms = torch.sqrt(
            torch.mean(x**2, dim=reduce_dims, keepdim=True) + self.eps
        )
        x = x / rms
        if self.elementwise_affine:
            x = x * self.weight
        return x

    def extra_repr(self) -> str:
        """Return the constructor settings shown in the module repr."""
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class PairNorm(torch.nn.Module):
    """Parameter-free centring and rescaling of a feature matrix.

    The column means (over dim 0) are subtracted, then the result is divided
    by the square root of the mean squared row norm (rows are dim 1).

    Args:
        device: Device the input is moved to in ``forward`` whenever the
            input's device compares unequal to it.
        dtype: Stored and reported by ``extra_repr``; not used in
            ``forward``.
    """

    def __init__(self, device=None, dtype=None):
        super().__init__()
        self.device = device
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` centred per column and scaled by the mean row norm."""
        needs_move = x.device != self.device
        if needs_move:
            x = x.to(self.device)

        centered = x - x.mean(dim=0)
        # The small constant keeps the division finite for an all-zero
        # centred input.
        mean_sq_row_norm = centered.pow(2).sum(dim=1).mean()
        scale = (1e-6 + mean_sq_row_norm).sqrt()
        return centered / scale

    def extra_repr(self) -> str:
        """Return the stored device and dtype for the module repr."""
        return f"device={self.device}, dtype={self.dtype}"


def save_query_distribution_histogram(query, epoch, num_bins=50):
    """Save a histogram of the finite values of ``query`` as a PNG file.

    The image is written to ``./img`` (created if missing) under a name that
    contains ``epoch`` and the current local timestamp, and the path is
    printed afterwards.

    Args:
        query: Tensor whose values are plotted; it is detached, copied to
            the CPU and flattened first.
        epoch: Value embedded in the output file name.
        num_bins: Number of histogram bins.
    """
    values = query.detach().cpu().numpy().flatten()
    # NaN and infinite entries cannot be binned, so they are dropped.
    values = values[np.isfinite(values)]

    save_dir = "./img"
    os.makedirs(save_dir, exist_ok=True)

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"query_distribution_histogram_epoch_{epoch}_{current_time}.png"
    output_image_path = os.path.join(save_dir, filename)

    # Work on an explicit figure and always close it, so a failure while
    # plotting or saving does not leave an open figure behind.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(
            values,
            bins=num_bins,
            color="blue",
            alpha=0.7,
            edgecolor="black",
        )
        ax.set_title("Distribution of Query Values")
        ax.set_xlabel("Query Value")
        ax.set_ylabel("Frequency")
        fig.savefig(output_image_path)
    finally:
        plt.close(fig)

    print(f"Saved the query distribution histogram as {output_image_path}")


class GPSLayer(nn.Module):
    """Local MPNN + full graph attention x-former layer."""

    def __init__(
        self,
        dim_h,
        local_gnn_type,
        global_model_type,
        num_heads,
        act="relu",
        pna_degrees=None,
        equivstable_pe=False,
        dropout=0.0,
        attn_dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        bigbird_cfg=None,
        log_attn_weights=False,
    ):
        """Build the local MPNN, the global attention module and the FFN.

        Besides the arguments, construction is driven by the global ``cfg``
        (``cfg.slt.*``, ``cfg.monarch.*`` and ``cfg.gt.dim_hidden``), which
        selects sparse / Monarch / BitLinear layer variants and the optional
        normalisation modules.

        All sub-modules are created on the default device; placing the layer
        on an accelerator is left to the caller (``model.to(device)``).

        Args:
            dim_h: Hidden feature width of the layer.
            local_gnn_type: One of "None", "GCN", "GIN", "GENConv", "GINE",
                "GAT", "PNA" or "CustomGatedGCN".
            global_model_type: One of "None", "Transformer",
                "BiasedTransformer", "Performer" or "BigBird".
            num_heads: Number of attention heads (also used by "GAT").
            act: Key into ``register.act_dict`` selecting the activation.
            pna_degrees: Degree data converted to a tensor for "PNA".
            equivstable_pe: Use the EquivStableLapPE-aware layer variants.
            dropout: Dropout probability of the local, attention-output and
                feed-forward dropout layers.
            attn_dropout: Dropout probability inside the attention module.
            layer_norm: Apply ``LayerNorm`` after the feed-forward block.
            batch_norm: Apply ``BatchNorm1d`` after the feed-forward block.
            bigbird_cfg: Config for "BigBird"; its ``dim_hidden``,
                ``n_heads`` and ``dropout`` fields are overwritten here.
            log_attn_weights: Keep per-head attention weights from the last
                forward pass (Transformer variants only).

        Raises:
            NotImplementedError: If attention-weight logging is requested
                for an unsupported global model.
            ValueError: For an unknown local or global model type, or when
                both ``layer_norm`` and ``batch_norm`` are enabled.
        """
        super().__init__()

        self.dim_h = dim_h
        self.num_heads = num_heads
        self.attn_dropout = attn_dropout
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.equivstable_pe = equivstable_pe
        self.activation = register.act_dict[act]

        self.log_attn_weights = log_attn_weights
        if log_attn_weights and global_model_type not in [
            "Transformer",
            "BiasedTransformer",
        ]:
            raise NotImplementedError(
                f"Logging of attention weights is not supported "
                f"for '{global_model_type}' global attention model."
            )

        # Local message-passing model.
        self.local_gnn_with_edge_attr = True
        if local_gnn_type == "None":
            self.local_model = None

        # MPNNs without edge attributes support.
        elif local_gnn_type == "GCN":
            self.local_gnn_with_edge_attr = False
            self.local_model = pygnn.GCNConv(dim_h, dim_h)
        elif local_gnn_type == "GIN":
            self.local_gnn_with_edge_attr = False
            gin_nn = nn.Sequential(
                Linear_pyg(dim_h, dim_h),
                self.activation(),
                Linear_pyg(dim_h, dim_h),
            )
            self.local_model = pygnn.GINConv(gin_nn)

        # MPNNs supporting also edge attributes.
        elif local_gnn_type == "GENConv":
            self.local_model = pygnn.GENConv(dim_h, dim_h)
        elif local_gnn_type == "GINE":
            # Choose the linear layer used inside the GINE MLP; the first
            # matching configuration wins.
            slt_mpnn = cfg.slt.mpnn is True
            if cfg.monarch.mpnn is True and slt_mpnn:
                linear_cls, linear_kwargs = MonarchLinear, {"bias": False}
            elif cfg.slt.srste is True and slt_mpnn:
                linear_cls, linear_kwargs = NMSparseMultiLinear, {}
            elif cfg.slt.sm is True and slt_mpnn:
                linear_cls, linear_kwargs = SparseLinear, {"bias": False}
            elif cfg.slt.mm is True and slt_mpnn:
                linear_cls = SparseLinearMulti_mask
                linear_kwargs = {"bias": False}
            else:
                linear_cls, linear_kwargs = Linear_pyg, {}
            gin_nn = nn.Sequential(
                linear_cls(dim_h, dim_h, **linear_kwargs),
                self.activation(),
                linear_cls(dim_h, dim_h, **linear_kwargs),
            )
            if self.equivstable_pe:
                # Use specialised GINE layer for EquivStableLapPE.
                self.local_model = GINEConvESLapPE(gin_nn)
            elif slt_mpnn:
                self.local_model = SLT_GINEConv(gin_nn)
            else:
                self.local_model = pygnn.GINEConv(gin_nn)
        elif local_gnn_type == "GAT":
            self.local_model = pygnn.GATConv(
                in_channels=dim_h,
                out_channels=dim_h // num_heads,
                heads=num_heads,
                edge_dim=dim_h,
            )
        elif local_gnn_type == "PNA":
            # Defaults from the paper.
            # aggregators = ['mean', 'min', 'max', 'std']
            # scalers = ['identity', 'amplification', 'attenuation']
            aggregators = ["mean", "max", "sum"]
            scalers = ["identity"]
            deg = torch.from_numpy(np.array(pna_degrees))
            self.local_model = pygnn.PNAConv(
                dim_h,
                dim_h,
                aggregators=aggregators,
                scalers=scalers,
                deg=deg,
                edge_dim=min(128, dim_h),
                towers=1,
                pre_layers=1,
                post_layers=1,
                divide_input=False,
            )
        elif local_gnn_type == "CustomGatedGCN":
            if cfg.slt.mpnn is True or cfg.monarch.mpnn is True:
                gated_cls = Custom_GatedGCNLayer
            else:
                gated_cls = GatedGCNLayer
            self.local_model = gated_cls(
                dim_h,
                dim_h,
                dropout=dropout,
                residual=True,
                act=act,
                equivstable_pe=equivstable_pe,
            )
        else:
            raise ValueError(f"Unsupported local GNN model: {local_gnn_type}")
        self.local_gnn_type = local_gnn_type

        # Global attention transformer-style model.
        if global_model_type == "None":
            self.self_attn = None
        elif global_model_type in ["Transformer", "BiasedTransformer"]:
            if cfg.slt.msa is True or cfg.monarch.msa is True:
                attn_cls = CustomMultiheadAttention
            else:
                attn_cls = torch.nn.MultiheadAttention
            self.self_attn = attn_cls(
                dim_h,
                num_heads,
                dropout=self.attn_dropout,
                batch_first=True,
            )
        elif global_model_type == "Performer":
            self.self_attn = SelfAttention(
                dim=dim_h,
                heads=num_heads,
                dropout=self.attn_dropout,
                causal=False,
            )
        elif global_model_type == "BigBird":
            # NOTE: the caller's config object is modified in place.
            bigbird_cfg.dim_hidden = dim_h
            bigbird_cfg.n_heads = num_heads
            bigbird_cfg.dropout = dropout
            self.self_attn = SingleBigBirdLayer(bigbird_cfg)
        else:
            raise ValueError(
                f"Unsupported global x-former model: " f"{global_model_type}"
            )
        self.global_model_type = global_model_type

        if self.layer_norm and self.batch_norm:
            raise ValueError(
                "Cannot apply two types of normalization together"
            )

        # Normalization for MPNN and Self-Attention representations.
        if cfg.slt.msa_layer_norm is True or cfg.slt.mpnn_layer_norm is True:
            self.layernorm1_local = pygnn.norm.LayerNorm(dim_h)
            self.layernorm1_attn = pygnn.norm.LayerNorm(dim_h)
        if cfg.slt.msa_batch_norm is True or cfg.slt.mpnn_batch_norm is True:
            self.batchnorm1_local = nn.BatchNorm1d(
                dim_h, affine=cfg.slt.batch_affine
            )
            self.batchnorm1_attn = nn.BatchNorm1d(
                dim_h, affine=cfg.slt.batch_affine
            )
        self.dropout_local = nn.Dropout(dropout)
        self.dropout_attn = nn.Dropout(dropout)

        # Feed-forward block: pick the linear implementation from config;
        # the first enabled option wins.
        layer_class = nn.Linear
        if cfg.slt.bitlinear:
            layer_class = BitLinear
        elif cfg.monarch.ffn:
            layer_class = MonarchLinear
        elif cfg.slt.srste and cfg.slt.ffn:
            layer_class = NMSparseMultiLinear
        elif cfg.slt.sm and cfg.slt.ffn:
            layer_class = SparseLinear
        elif cfg.slt.mm and cfg.slt.ffn:
            layer_class = SparseLinearMulti_mask

        # Only the score-based sparse layers receive the extra init options.
        extra_kwargs = {}
        if layer_class in (SparseLinearMulti_mask, SparseLinear):
            extra_kwargs = {
                "gain": "relu",
                "init_mode_weight": "signed_kaiming_uniform_constant",
            }

        self.ff_linear1 = layer_class(
            dim_h, dim_h * 2, bias=False, **extra_kwargs
        )
        self.ff_linear2 = layer_class(
            dim_h * 2, dim_h, bias=False, **extra_kwargs
        )

        self.act_fn_ff = self.activation()
        # `layer_norm` and `batch_norm` are mutually exclusive (checked
        # above), so at most one of these creates `norm2`.
        if self.layer_norm:
            self.norm2 = pygnn.norm.LayerNorm(dim_h)
        if self.batch_norm:
            self.norm2 = nn.BatchNorm1d(dim_h, affine=cfg.slt.batch_affine)
        self.ff_dropout1 = nn.Dropout(dropout)
        self.ff_dropout2 = nn.Dropout(dropout)

        if cfg.slt.learnable_sum:
            self.w_local = nn.Parameter(torch.tensor(1.0))
            self.w_attn = nn.Parameter(torch.tensor(1.0))

        # Optional normalisation applied to the input of each sub-block
        # ("mpnn", "msa", "ffn"); per block the first enabled kind wins.
        # The modules are deliberately NOT moved with `.cuda()` here: they
        # are registered sub-modules and follow the layer when the caller
        # moves the model, which also keeps CPU-only runs working.
        for block in ("mpnn", "msa", "ffn"):
            if getattr(cfg.slt, f"batchnorm_{block}"):
                norm = torch.nn.BatchNorm1d(cfg.gt.dim_hidden)
                setattr(self, f"batchnorm_{block}", norm)
            elif getattr(cfg.slt, f"layernorm_{block}"):
                norm = torch.nn.LayerNorm(cfg.gt.dim_hidden)
                setattr(self, f"layernorm_{block}", norm)
            elif getattr(cfg.slt, f"pairnorm_{block}"):
                setattr(self, f"pairnorm_{block}", PairNorm())
            elif getattr(cfg.slt, f"rmsnorm_{block}"):
                norm = RMSNorm(cfg.gt.dim_hidden)
                setattr(self, f"rmsnorm_{block}", norm)

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Apply local MPNN, global attention and the FFN to ``batch``.

        Both branches take a residual connection from the incoming
        ``batch.x``; their outputs are summed and passed through the
        feed-forward block (with residual) and the optional final norm.
        ``batch.x`` is overwritten with the result. ``batch.edge_attr`` is
        also overwritten for "CustomGatedGCN", and ``batch.batch`` is
        truncated when ToMe merging is applied.

        Args:
            batch: PyG batch providing ``x``, ``edge_index``, ``edge_attr``
                and ``batch`` (plus ``pe_EquivStableLapPE`` /
                ``attn_bias`` for the variants that read them).
            cur_epoch: Current epoch, forwarded to the sparse sub-modules.
            mpnn_th: Threshold for the local MPNN.
            msa_th: Threshold for the attention block.
            ffn_th: Threshold for the feed-forward block.
            encoder_th: Unused here; accepted for interface compatibility.
            pred_th: Unused here; accepted for interface compatibility.
            global_th: Shared threshold that replaces the three per-block
                thresholds when ``cfg.slt.pruning == "global"``.

        Returns:
            The same ``batch`` object with updated fields.
        """
        # Pick the thresholds once. "global" pruning shares one threshold;
        # every other mode uses the per-block ones. Previously any mode
        # other than "blockwise"/"global" silently skipped the FFN and left
        # the sparse MPNN/attention outputs unbound.
        if cfg.slt.pruning == "global":
            mpnn_threshold = msa_threshold = ffn_threshold = global_th
        else:
            mpnn_threshold = mpnn_th
            msa_threshold = msa_th
            ffn_threshold = ffn_th

        h = batch.x
        h_in1 = h  # Residual input shared by both branches.
        h_out_list = []

        # Local MPNN with edge attributes.
        if self.local_model is not None:
            # NOTE: the pre-normalised `h` also feeds the attention branch.
            h = self._apply_optional_norm(h, "mpnn")

            if self.local_gnn_type == "CustomGatedGCN":
                es_data = None
                if self.equivstable_pe:
                    es_data = batch.pe_EquivStableLapPE
                gated_in = Batch(
                    batch=batch,
                    x=h,
                    edge_index=batch.edge_index,
                    edge_attr=batch.edge_attr,
                    pe_EquivStableLapPE=es_data,
                )
                sparse_mpnn = (cfg.slt.sm is True or cfg.slt.mm is True) and (
                    cfg.slt.mpnn is True or cfg.monarch.mpnn is True
                )
                if sparse_mpnn:
                    local_out = self.local_model(
                        gated_in, cur_epoch, mpnn_threshold
                    )
                else:
                    local_out = self.local_model(gated_in)
                # GatedGCN does residual connection and dropout internally.
                h_local = local_out.x
                batch.edge_attr = local_out.edge_attr
            else:
                if not self.local_gnn_with_edge_attr:
                    h_local = self.local_model(h, batch.edge_index)
                elif self.equivstable_pe:
                    h_local = self.local_model(
                        h,
                        batch.edge_index,
                        batch.edge_attr,
                        batch.pe_EquivStableLapPE,
                    )
                elif cfg.slt.mpnn is True:
                    h_local = self.local_model(
                        h,
                        batch.edge_index,
                        batch.edge_attr,
                        cur_epoch,
                        mpnn_threshold,
                    )
                else:
                    h_local = self.local_model(
                        h, batch.edge_index, batch.edge_attr
                    )
                h_local = self.dropout_local(h_local)
                h_local = h_in1 + h_local  # Residual connection.

            if cfg.slt.mpnn_layer_norm is True:
                h_local = self.layernorm1_local(h_local, batch.batch)
            if cfg.slt.mpnn_batch_norm is True:
                h_local = self.batchnorm1_local(h_local)
            h_out_list.append(h_local)

        # Multi-head attention.
        if self.self_attn is not None:
            h = self._apply_optional_norm(h, "msa")

            h_dense, mask = to_dense_batch(h, batch.batch)
            metric = None  # Only produced by the ToMe attention path.

            if self.global_model_type == "Transformer":
                if cfg.slt.msa is not True:
                    h_attn = self._sa_block(h_dense, None, ~mask)[mask]
                elif cfg.slt.tome:
                    # With ToMe, `_sa_block` returns `(output, metric)` and
                    # the dense output is kept unmasked. This used to be
                    # handled only for "global" pruning; the "blockwise"
                    # branch indexed the tuple with `[mask]` and never
                    # bound `metric`.
                    h_attn, metric = self._sa_block(
                        h_dense, None, ~mask, cur_epoch, msa_threshold
                    )
                else:
                    h_attn = self._sa_block(
                        h_dense, None, ~mask, cur_epoch, msa_threshold
                    )[mask]
            elif self.global_model_type == "BiasedTransformer":
                # Use Graphormer-like conditioning, requires `batch.attn_bias`.
                h_attn = self._sa_block(h_dense, batch.attn_bias, ~mask)[mask]
            elif self.global_model_type == "Performer":
                h_attn = self.self_attn(h_dense, mask=mask)[mask]
            elif self.global_model_type == "BigBird":
                h_attn = self.self_attn(h_dense, attention_mask=mask)
            else:
                raise RuntimeError(f"Unexpected {self.global_model_type}")

            h_attn = self.dropout_attn(h_attn)
            h_attn = h_in1 + h_attn  # Residual connection.

            if cfg.slt.tome:
                # `_tome_info` is not created in `__init__`; presumably it
                # is attached by the ToMe patching code -- TODO confirm.
                r = self._tome_info["r"].pop(0)
                if r > 0:
                    if metric is None:
                        raise RuntimeError(
                            "Token merging needs the attention metric, "
                            "which is only produced by the 'Transformer' "
                            "global model with cfg.slt.msa enabled."
                        )
                    merge, _ = bipartite_soft_matching(metric, r)
                    if self._tome_info["trace_source"]:
                        self._tome_info["source"] = merge_source(
                            merge, h_attn, self._tome_info["source"]
                        )
                    h_attn, self._tome_info["size"] = merge_wavg(
                        merge, h_attn, self._tome_info["size"]
                    )
                    batch.batch = batch.batch[: -cfg.slt.tome_r]

            if cfg.slt.msa_layer_norm is True:
                h_attn = self.layernorm1_attn(h_attn, batch.batch)
            elif cfg.slt.msa_batch_norm is True:
                h_attn = self.batchnorm1_attn(h_attn)
            h_out_list.append(h_attn)

        # Combine local and global outputs.
        h = sum(h_out_list)

        h = self._apply_optional_norm(h, "ffn")

        # Feed Forward block (with residual). `_ff_block` derives its own
        # threshold for "layerwise" pruning and ignores the threshold when
        # the FFN is dense.
        h = h + self._ff_block(h, cur_epoch, ffn_threshold)
        if self.layer_norm:
            h = self.norm2(h, batch.batch)
        if self.batch_norm:
            h = self.norm2(h)

        batch.x = h
        return batch

    def _apply_optional_norm(self, h, block):
        """Apply the first enabled ``cfg.slt.<kind>_<block>`` norm to ``h``.

        ``block`` is "mpnn", "msa" or "ffn". Kinds are checked in the order
        batchnorm, layernorm, pairnorm, rmsnorm; ``h`` is returned unchanged
        when none is enabled.
        """
        for kind in ("batchnorm", "layernorm", "pairnorm", "rmsnorm"):
            name = f"{kind}_{block}"
            if getattr(cfg.slt, name):
                return getattr(self, name)(h)
        return h

    def calculate_sparsity(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Return the scheduled sparsity for ``epoch``.

        Args:
            sparsity: Target sparsity of the schedule.
            epoch: Current epoch.
            max_epoch_half: Number of epochs the schedule spans.
            schedule_type: "linear", "curve", "step" or "reverse"; any
                other value returns ``sparsity`` unchanged.

        Returns:
            * "linear": ``sparsity`` scaled by ``epoch / max_epoch_half``.
            * "curve": ``sparsity`` scaled by
              ``sqrt(1 - ((max_epoch_half - epoch) / max_epoch_half) ** 2)``.
            * "step": ``sparsity / 10`` times the current step (1..10).
            * "reverse": interpolates from 1 down towards ``sparsity``.
        """
        if schedule_type == "linear":
            return sparsity * (epoch / max_epoch_half)
        if schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            return sparsity * scale_factor
        if schedule_type == "step":
            # At least one epoch per step, so short schedules
            # (max_epoch_half < 10) do not divide by zero.
            step_size = max(1, max_epoch_half // 10)
            # Cap at 10 steps so the result never exceeds the target when
            # `max_epoch_half` is not a multiple of 10.
            current_step = min(10, (epoch // step_size) + 1)
            return (sparsity / 10) * current_step
        if schedule_type == "reverse":
            return 1 - ((1 - sparsity) * (epoch / max_epoch_half))
        return sparsity

    def calculate_sparsities(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Return the scheduled sparsities for multi-mask (mm) mode.

        Args:
            sparsity: Sequence of target sparsities.
            epoch: Current epoch.
            max_epoch_half: Number of epochs the schedule spans.
            schedule_type: "linear", "curve", "step" or "reverse"; any
                other value returns ``sparsity`` unchanged.

        Returns:
            A list with one scheduled value per entry of ``sparsity`` (or
            ``sparsity`` itself for an unknown schedule type).
        """
        if schedule_type == "linear":
            return [value * (epoch / max_epoch_half) for value in sparsity]
        if schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            first = sparsity[0]
            scheduled = []
            for i, value in enumerate(sparsity):
                if i == 0:
                    scheduled.append(value * scale_factor)
                else:
                    # Later entries move from this start point to `value`.
                    start_point = (value - first) / (1 - first)
                    scheduled.append(
                        start_point + (value - start_point) * scale_factor
                    )
            return scheduled
        if schedule_type == "step":
            # At least one epoch per step and at most 10 steps, as in
            # `calculate_sparsity`.
            step_size = max(1, max_epoch_half // 10)
            current_step = min(10, (epoch // step_size) + 1)
            # Each target grows in tenths of itself. The previous formula
            # multiplied by `value` a second time and so ended the schedule
            # at value**2 instead of value.
            return [(value / 10) * current_step for value in sparsity]
        if schedule_type == "reverse":
            return [
                1 - ((1 - value) * (epoch / max_epoch_half))
                for value in sparsity
            ]
        return sparsity

    def get_threshold(self, sparsity, epoch=None, keyword=""):
        """Compute score threshold(s) for this layer at the given sparsity.

        The scores are all parameters of this layer that carry a truthy
        ``is_score`` attribute and whose name contains ``keyword``. A
        threshold is the percentile of those scores (of their absolute
        values when ``cfg.slt.enable_abs_pruning`` is set) at the requested
        sparsity.

        While the module is in training mode and ``epoch`` is below half of
        ``cfg.optim.max_epoch``, the sparsity is first passed through the
        ``cfg.slt.sparse_scheduling`` schedule.

        Args:
            sparsity: Sequence of target sparsities (fractions). Every entry
                is used in multi-mask mode (``cfg.slt.mm``); only the first
                in single-mask mode (``cfg.slt.sm``).
            epoch: Current epoch. ``None`` disables the schedule.
            keyword: Substring that parameter names must contain.

        Returns:
            A list of thresholds for ``cfg.slt.mm``, a single threshold for
            ``cfg.slt.sm``, and ``None`` when neither mode is enabled.

        Raises:
            RuntimeError: If no score parameter matches ``keyword``.
        """
        if cfg.slt.mm:
            multi_mask = True
        elif cfg.slt.sm:
            multi_mask = False
        else:
            return None

        max_epoch_half = cfg.optim.max_epoch // 2
        # An unknown epoch cannot be scheduled, so fall back to the target
        # sparsity instead of failing on `None < int`.
        scheduled = (
            self.training and epoch is not None and epoch < max_epoch_half
        )

        if multi_mask:
            targets = sparsity
            if scheduled:
                targets = self.calculate_sparsities(
                    sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            if not targets:
                return []
        else:
            target = sparsity[0]
            if scheduled:
                target = self.calculate_sparsity(
                    sparsity[0],
                    epoch,
                    max_epoch_half,
                    cfg.slt.sparse_scheduling,
                )

        # Gather the scores once; they do not depend on the sparsity value.
        score_tensors = [
            p.detach().flatten()
            for name, p in self.named_parameters()
            if getattr(p, "is_score", False) and keyword in name
        ]
        if not score_tensors:
            raise RuntimeError(
                f"No score parameters matching keyword '{keyword}' found."
            )
        scores = torch.cat(score_tensors)
        if cfg.slt.enable_abs_pruning:
            scores = scores.abs()

        if multi_mask:
            return [percentile(scores, value * 100) for value in targets]
        return percentile(scores, target * 100)

    def _sa_block(
        self,
        x,
        attn_mask,
        key_padding_mask,
        cur_epoch=None,
        msa_threshold=None,
    ):
        """Run self-attention with ``x`` as query, key and value.

        Args:
            x: Input passed three times to ``self.self_attn``.
            attn_mask: Forwarded as ``attn_mask``.
            key_padding_mask: Forwarded as ``key_padding_mask``.
            cur_epoch: Forwarded only when ``cfg.slt.msa`` is enabled.
            msa_threshold: Forwarded only when ``cfg.slt.msa`` is enabled.

        Returns:
            The attention output, or ``(output, metric)`` when both
            ``cfg.slt.msa`` and ``cfg.slt.tome`` are enabled and attention
            weights are not being logged.
        """
        if self.log_attn_weights:
            # Requires PyTorch v1.11+ to support `average_attn_weights=False`
            # option to return attention weights of individual heads.
            out, per_head_weights = self.self_attn(
                x,
                x,
                x,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            self.attn_weights = per_head_weights.detach().cpu()
            return out

        if cfg.slt.msa is not True:
            # Plain attention module: no epoch/threshold arguments.
            return self.self_attn(
                x,
                x,
                x,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                need_weights=False,
            )[0]

        merge_tokens = cfg.slt.tome
        result = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            cur_epoch=cur_epoch,
            msa_threshold=msa_threshold,
            need_weights=False,
        )
        if merge_tokens:
            # The module returns three values here; the third is handed
            # back to the caller as the merge metric.
            out, _, metric = result
            return out, metric
        return result[0]

    def _ff_block(self, x, cur_epoch=None, ffn_threshold=None):
        """Apply linear -> activation -> dropout -> linear -> dropout.

        When ``cfg.slt.ffn`` is enabled, both linear layers also receive a
        threshold: one computed by ``get_threshold`` for "layerwise"
        pruning, otherwise ``ffn_threshold``. When it is disabled, the
        linear layers are called with the input only.
        """
        if cfg.slt.ffn is not True:
            hidden = self.act_fn_ff(self.ff_linear1(x))
            hidden = self.ff_dropout1(hidden)
            return self.ff_dropout2(self.ff_linear2(hidden))

        threshold = (
            self.get_threshold(cfg.slt.linear_sparsity, cur_epoch, "ff_linear")
            if cfg.slt.pruning == "layerwise"
            else ffn_threshold
        )
        hidden = self.act_fn_ff(self.ff_linear1(x, threshold))
        hidden = self.ff_dropout1(hidden)
        return self.ff_dropout2(self.ff_linear2(hidden, threshold))

    def extra_repr(self):
        """Return a one-line summary of the layer for the module repr."""
        fields = ", ".join(
            [
                f"dim_h={self.dim_h}",
                f"local_gnn_type={self.local_gnn_type}",
                f"global_model_type={self.global_model_type}",
                f"heads={self.num_heads}",
            ]
        )
        return f"summary: {fields}"
