import math
import numbers
from typing import Optional, Tuple

# import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
from torch.nn.parameter import Parameter
from torch_geometric.graphgym import cfg


class RMSNorm(nn.Module):
    """Root-mean-square normalization (no mean-centering).

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed over the
    trailing ``len(normalized_shape)`` dimensions and, if
    ``elementwise_affine`` is set, multiplied by a learnable weight of shape
    ``normalized_shape`` (initialised to ones).

    Args:
        normalized_shape: Size of the trailing dimension(s) to normalize
            over; an ``int`` is treated as a 1-tuple.
        eps: Added inside the square root for numerical stability. Defaults
            to ``torch.finfo(torch.float32).eps``.
        elementwise_affine: Whether to learn a per-element scale.
        device, dtype: Forwarded to the weight's constructor.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape,
        eps: Optional[float] = None,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps if eps is not None else torch.finfo(torch.float32).eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine weight (if any) to ones."""
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalized by its RMS over the normalized dims."""
        # Reduce over exactly the dimensions described by normalized_shape;
        # for the usual 1-D case this is just the last dimension.
        dims = tuple(range(-len(self.normalized_shape), 0))
        rms = torch.sqrt(torch.mean(x**2, dim=dims, keepdim=True) + self.eps)
        x = x / rms
        if self.elementwise_affine:
            # Move a local view of the weight instead of reassigning
            # self.weight: assigning a plain tensor to a registered Parameter
            # raises TypeError and would detach it from the optimizer.
            # ``.to`` is a no-op when the devices already match and is
            # differentiable otherwise, so gradients still reach the weight.
            x = x * self.weight.to(x.device)
        return x

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


# import random


# import random


# Machine epsilon of float64 (NumPy maps Python ``float`` to ``np.float64``).
EPSILON = np.finfo(np.float64).eps


class GetSubnet(torch.autograd.Function):
    """Threshold scores into a mask, with a straight-through gradient.

    Forward: where ``scores < threshold`` the result takes the value from
    ``zeros``, elsewhere the value from ``ones``; both fill tensors are first
    moved to the scores' device.

    Backward: the incoming gradient is passed unchanged to ``scores``; the
    threshold and the two fill tensors receive no gradient.
    """

    @staticmethod
    def forward(ctx, scores, threshold, zeros, ones):
        target = scores.device
        below = scores < threshold
        return torch.where(below, zeros.to(target), ones.to(target))

    @staticmethod
    def backward(ctx, g):
        # Identity for scores; nothing for threshold / zeros / ones.
        return (g,) + (None,) * 3


class SparseModule(nn.Module):
    """Base class supplying the shared ``init_param_`` used by sparse layers."""

    def init_param_(
        self, param, init_mode=None, scale=None, sparse_value=None, gain="relu"
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor or Parameter that is overwritten in place.
            init_mode: One of ``"kaiming_normal"``, ``"uniform"``,
                ``"kaiming_uniform"``, ``"kaiming_normal_SF"``,
                ``"signed_constant"``, ``"signed_constant_SF"``,
                ``"signed_kaiming_constant_SF"``,
                ``"signed_xavier_uniform_constant_SF"``,
                ``"signed_xavier_normal_constant_SF"``,
                ``"signed_trunc_constant_SF"``,
                ``"signed_kaiming_uniform_constant"`` or ``"signed_one"``.
            scale: Multiplier applied by most schemes; ``None`` means 1.0.
            sparse_value: Sparsity fraction in ``[0, 1)``. Where used
                (``kaiming_normal_SF``, ``signed_constant_SF`` and the
                xavier / kaiming-uniform signed schemes) it enlarges the std
                by ``1 / sqrt(1 - sparse_value)``.
            gain: Nonlinearity name passed to ``nn.init.calculate_gain`` and
                to the Kaiming initialisers.

        When ``cfg.slt.learnable_weight_scaling`` is set, the ``signed_*``
        schemes other than ``signed_constant`` and ``signed_one`` skip the
        sign/constant step and leave the raw random initialisation.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        if scale is None:
            scale = 1.0
        # ``gain`` is rebound to a float in several branches; keep the
        # nonlinearity name separately for the Kaiming initialisers.
        nonlinearity = gain

        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            scale_fan = fan * (1 - sparse_value)
            std = gain / math.sqrt(scale_fan)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            # From github.com/allenai/hidden-networks
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            std = gain / math.sqrt(fan)
            nn.init.kaiming_normal_(param)  # use only its sign
            param.data = param.data.sign() * std
            param.data *= scale
        elif init_mode in ("signed_constant_SF", "signed_kaiming_constant_SF"):
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            scale_fan = fan * (1 - sparse_value)
            std = gain / math.sqrt(scale_fan)
            nn.init.kaiming_normal_(param)  # use only its sign
            if not cfg.slt.learnable_weight_scaling:
                param.data = param.data.sign() * std
                param.data *= scale  # scale value is defined in default as 1.0
        elif init_mode == "signed_xavier_uniform_constant_SF":
            gain = nn.init.calculate_gain(nonlinearity)
            nn.init.xavier_uniform_(param, gain)
            if not cfg.slt.learnable_weight_scaling:
                std = torch.std(param)
                scaled_std = std * math.sqrt(1 / (1 - sparse_value))
                nn.init.kaiming_normal_(param)  # use only its sign
                param.data = param.data.sign() * scaled_std
                param.data *= scale
        elif init_mode == "signed_xavier_normal_constant_SF":
            gain = nn.init.calculate_gain(nonlinearity)
            nn.init.xavier_normal_(param, gain)
            if not cfg.slt.learnable_weight_scaling:
                std = torch.std(param)
                scaled_std = std * math.sqrt(1 / (1 - sparse_value))
                nn.init.kaiming_normal_(param)  # use only its sign
                param.data = param.data.sign() * scaled_std
                param.data *= scale
        elif init_mode == "signed_trunc_constant_SF":
            nn.init.trunc_normal_(param)
            if not cfg.slt.learnable_weight_scaling:
                param_std = torch.std(param)
                with torch.no_grad():
                    param.copy_(torch.where(param >= 0, param_std, -param_std))
        elif init_mode == "signed_kaiming_uniform_constant":
            # kaiming_uniform_'s second positional parameter is the
            # leaky-ReLU slope ``a``, not a gain; pass the nonlinearity by
            # keyword so the bound is gain * sqrt(3 / fan_in).
            nn.init.kaiming_uniform_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            if not cfg.slt.learnable_weight_scaling:
                std = torch.std(param)
                scaled_std = std * math.sqrt(1 / (1 - sparse_value))
                nn.init.kaiming_normal_(param)  # use only its sign
                param.data = param.data.sign() * scaled_std
                param.data *= scale
        elif init_mode == "signed_one":
            nn.init.kaiming_normal_(param)  # use only its sign
            param.data = param.data.sign() * 1.0
        else:
            raise NotImplementedError(f"Unknown init_mode: {init_mode!r}")


class SparseLinear(SparseModule):
    """Bias-free linear layer whose frozen weight is masked by trained scores.

    ``weight`` is initialised once and frozen (``requires_grad=False``);
    only ``weight_score`` is trained. ``forward`` thresholds the scores into
    a 0/1 mask with ``GetSubnet``, optionally multiplies in a +/-1 sign
    mask, and applies the result to ``weight`` before ``F.linear``.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        bias: Accepted for interface compatibility; ignored (``self.bias``
            is always None).
        attention_scaling: If not None, overrides
            ``cfg.slt.init_scale_weight``.
        gain: Nonlinearity name forwarded to the weight initialiser.
        init_mode_weight: Weight init scheme; ``None`` selects
            ``cfg.slt.init_mode_weight``.
    """

    def __init__(
        self,
        in_ch,
        out_ch,
        bias=False,
        attention_scaling=None,
        gain="relu",
        init_mode_weight=None,
    ):
        super().__init__()

        self.sparsity = cfg.slt.linear_sparsity
        if init_mode_weight is None:
            self.init_mode_weight = cfg.slt.init_mode_weight
        else:
            self.init_mode_weight = init_mode_weight
        # Resolved scheme (argument or config), captured before
        # learnable_weight_scaling may overwrite self.init_mode_weight.
        resolved_weight_mode = self.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        if attention_scaling is not None:
            self.init_scale_weight = attention_scaling
        self.init_scale_score = cfg.slt.init_scale_score
        self.weight = nn.Parameter(torch.ones(out_ch, in_ch))
        self.enable_unshared = cfg.slt.enable_unshared
        self.sparsity_value = cfg.slt.linear_sparsity
        # The weight is frozen; only the scores are trained.
        self.weight.requires_grad = False

        self.attention_scaling = self.init_scale_weight

        self.weight_score = nn.Parameter(torch.ones(self.weight.size()))
        # Tags presumably consumed by optimiser / pruning code elsewhere.
        self.weight_score.is_score = True
        self.weight_score.sparsity = self.sparsity

        self.bias = None

        if cfg.slt.learnable_weight_scaling:
            # +/-1 weight; forward rescales it by mean(|weight_score|).
            self.init_mode_weight = "signed_one"

        self.init_param_(
            self.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity_value[0],
            gain=gain,
        )

        if cfg.slt.learnable_weight_scaling:
            # Scores take the weight's configured scheme. Use the resolved
            # mode, not the raw argument: the argument defaults to None,
            # which would make init_param_ raise NotImplementedError.
            self.init_mode_score = resolved_weight_mode

        self.init_param_(
            self.weight_score,
            init_mode=self.init_mode_score,
            scale=self.init_scale_score,
            sparse_value=self.sparsity_value[0],
        )

        # Fill values for GetSubnet; plain tensors (not buffers), so
        # GetSubnet moves them to the scores' device on every call.
        self.weight_zeros = torch.zeros(self.weight.size())
        self.weight_ones = torch.ones(self.weight.size())
        self.weight_zeros.requires_grad = False
        self.weight_ones.requires_grad = False

    def forward(
        self, x, threshold, manual_mask=None, index_mask=0, q_lin=False
    ):
        """Return ``F.linear(x, weight * mask)``.

        Args:
            x: Input whose last dimension is ``in_ch``.
            threshold: Value compared with the scores; entries whose score
                is below it are pruned.
            manual_mask, index_mask, q_lin: Accepted for call-site
                compatibility; unused.
        """
        subnet = GetSubnet.apply(
            (
                torch.abs(self.weight_score)
                if cfg.slt.enable_abs_pruning or cfg.slt.sign_mask
                else self.weight_score
            ),
            threshold,
            self.weight_zeros,
            self.weight_ones,
        )
        if cfg.slt.sign_mask:
            # +/-1 sign mask. NOTE(review): thresholded at the score median
            # here, while SparseLinearMulti_mask thresholds at 0 - confirm
            # which is intended.
            sign_subnet = GetSubnet.apply(
                self.weight_score,
                torch.median(self.weight_score),
                self.weight_ones * -1,
                self.weight_ones,
            )
            subnet = subnet * sign_subnet

        pruned_weight = self.weight * subnet

        if cfg.slt.learnable_weight_scaling:
            # Scale the +/-1 weight by the mean score magnitude.
            weight_score_abs_mean = torch.mean(torch.abs(self.weight_score))
            pruned_weight = pruned_weight * weight_score_abs_mean

        return F.linear(x, pruned_weight, self.bias)


class SparseLinearMulti_mask(SparseModule):
    """Bias-free linear layer masked by the sum of several thresholded masks.

    Same construction as ``SparseLinear``. ``forward`` builds one 0/1 mask
    per threshold from the trained scores (optionally times a +/-1 sign
    mask), sums them, and scales the frozen ``weight`` by the sum, so an
    entry's multiplier is the number of thresholds its score reaches.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        bias: Accepted for interface compatibility; ignored (``self.bias``
            is always None).
        attention_scaling: If not None, overrides
            ``cfg.slt.init_scale_weight``.
        gain: Nonlinearity name forwarded to the weight initialiser.
        init_mode_weight: Weight init scheme; ``None`` selects
            ``cfg.slt.init_mode_weight``.
    """

    def __init__(
        self,
        in_ch,
        out_ch,
        bias=False,
        attention_scaling=None,
        gain="relu",
        init_mode_weight=None,
    ):
        super().__init__()

        self.sparsity = cfg.slt.linear_sparsity
        if init_mode_weight is None:
            self.init_mode_weight = cfg.slt.init_mode_weight
        else:
            self.init_mode_weight = init_mode_weight
        # Resolved scheme (argument or config), captured before
        # learnable_weight_scaling may overwrite self.init_mode_weight.
        resolved_weight_mode = self.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        if attention_scaling is not None:
            self.init_scale_weight = attention_scaling
        self.init_scale_score = cfg.slt.init_scale_score
        self.weight = nn.Parameter(torch.ones(out_ch, in_ch))
        self.enable_unshared = cfg.slt.enable_unshared
        self.sparsity_value = cfg.slt.linear_sparsity
        # The weight is frozen; only the scores are trained.
        self.weight.requires_grad = False

        self.attention_scaling = self.init_scale_weight

        self.weight_score = nn.Parameter(torch.ones(self.weight.size()))
        # Tags presumably consumed by optimiser / pruning code elsewhere.
        self.weight_score.is_score = True
        self.weight_score.sparsity = self.sparsity

        self.bias = None

        if cfg.slt.learnable_weight_scaling:
            # +/-1 weight; forward rescales it by mean(|weight_score|).
            self.init_mode_weight = "signed_one"

        self.init_param_(
            self.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity_value[0],
            gain=gain,
        )

        if cfg.slt.learnable_weight_scaling:
            # Scores take the weight's configured scheme. Use the resolved
            # mode, not the raw argument: the argument defaults to None,
            # which would make init_param_ raise NotImplementedError.
            self.init_mode_score = resolved_weight_mode

        self.init_param_(
            self.weight_score,
            init_mode=self.init_mode_score,
            scale=self.init_scale_score,
            sparse_value=self.sparsity_value[0],
        )

        # Fill values for GetSubnet; plain tensors (not buffers), so
        # GetSubnet moves them to the scores' device on every call.
        self.weight_zeros = torch.zeros(self.weight.size())
        self.weight_ones = torch.ones(self.weight.size())
        self.weight_zeros.requires_grad = False
        self.weight_ones.requires_grad = False

    def forward(
        self, x, threshold, manual_mask=None, index_mask=0, q_lin=False
    ):
        """Return ``F.linear(x, weight * sum_of_masks)``.

        Args:
            x: Input whose last dimension is ``in_ch``.
            threshold: Iterable of thresholds, one mask per entry. A single
                number or 0-d tensor is treated as a one-element list.
            manual_mask, index_mask, q_lin: Accepted for call-site
                compatibility; unused.
        """
        if isinstance(threshold, numbers.Real) or (
            torch.is_tensor(threshold) and threshold.dim() == 0
        ):
            threshold = [threshold]

        # Loop-invariant work: the score transform and the sign mask do not
        # depend on the threshold, so compute them once.
        scores = (
            torch.abs(self.weight_score)
            if cfg.slt.enable_abs_pruning or cfg.slt.sign_mask
            else self.weight_score
        )
        sign_subnet = None
        if cfg.slt.sign_mask:
            sign_subnet = GetSubnet.apply(
                self.weight_score,
                0,
                self.weight_ones * -1,
                self.weight_ones,
            )

        subnets = []
        for threshold_v in threshold:
            subnet = GetSubnet.apply(
                scores, threshold_v, self.weight_zeros, self.weight_ones
            )
            if sign_subnet is not None:
                subnet = subnet * sign_subnet
            subnets.append(subnet)

        combined_subnet = torch.stack(subnets).sum(dim=0)
        pruned_weight = self.weight * combined_subnet

        if cfg.slt.learnable_weight_scaling:
            # Scale the +/-1 weight by the mean score magnitude.
            weight_score_abs_mean = torch.mean(torch.abs(self.weight_score))
            pruned_weight = pruned_weight * weight_score_abs_mean

        return F.linear(x, pruned_weight, self.bias)


class BitLinear(nn.Linear):
    """Linear layer with fake-quantized int8 activations and ternary weights.

    The input is RMS-normalized, activations are rounded onto the int8 grid
    (scale = 127 / absmax along the last dimension) and weights onto
    ``{-s, 0, +s}`` with ``s = mean(|w|)``. A straight-through estimator
    (``v + (q(v) - v).detach()``) keeps gradients flowing to the
    full-precision values.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        # Registered as a submodule, so it follows the parent's .to()/.cuda();
        # no hard-coded device, so CPU-only setups work.
        self.rms_norm = RMSNorm(in_features)

    def activation_quant(self, x):
        """Round ``x`` to the int8 grid, scaled per vector on the last dim."""
        scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(
            min=1e-5
        )
        y = (x * scale).round().clamp_(-128, 127) / scale
        return y

    def weight_quant(self, w):
        """Round ``w`` to ``{-1, 0, 1} * mean(|w|)``."""
        scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
        u = (w * scale).round().clamp_(-1, 1) / scale
        return u

    def forward(self, x, threshold=None):
        """Quantized linear map; ``threshold`` is accepted and ignored."""
        w = self.weight
        x_norm = self.rms_norm(x)
        # Straight-through estimators for activation and weight rounding.
        x_quant = x_norm + (self.activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (self.weight_quant(w) - w).detach()
        # Pass the bias so that bias=True actually takes effect.
        return F.linear(x_quant, w_quant, self.bias)


class NMSparseMultiLinear(SparseModule):
    """Linear layer with N:M structured, multi-level masking of a frozen weight.

    ``weight`` (``out_ch x in_ch``) is frozen after initialisation. The
    trained ``weight_score`` is stored transposed (``in_ch x out_ch``);
    ``GetNMMultiSubnet`` prunes it in groups of ``self.M`` columns, relabels
    the survivors with levels 1/2/3, and the transposed mask scales
    ``weight``.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        bias: Accepted for interface compatibility; ignored.
    """

    def __init__(self, in_ch, out_ch, bias=None):
        super().__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.sparsity = cfg.slt.linear_sparsity[0]
        self.init_mode = cfg.slt.init_mode_weight
        self.init_mode_mask = cfg.slt.init_mode_score
        self.init_scale = cfg.slt.init_scale_weight
        self.init_scale_score = cfg.slt.init_scale_score

        self.weight = nn.Parameter(torch.ones(out_ch, in_ch))
        self.M = 16  # group size for N:M pruning
        # Stored transposed relative to ``weight``.
        self.weight_score = nn.Parameter(torch.ones(in_ch, out_ch))
        self.bias = None
        self.layer = 0
        self.init_param_(
            self.weight,
            init_mode=self.init_mode,
            scale=self.init_scale,
            sparse_value=self.sparsity,
        )
        self.weight.requires_grad = False
        self.init_param_(
            self.weight_score,
            init_mode=self.init_mode_mask,
            scale=self.init_scale_score,
            sparse_value=self.sparsity,
        )
        # Tags presumably consumed by optimiser / pruning code elsewhere.
        self.weight_score.is_score = True
        self.weight_score.is_weight_score = True
        self.weight_score.sparsity = self.sparsity

        # Coefficient of the decay term in GetNMMultiSubnet.backward.
        self.decay = cfg.slt.srste_decay

    def forward(self, x, threshold=None, sparsity=None):
        """Apply the N:M multi-level masked weight to ``x``.

        ``threshold`` and ``sparsity`` are accepted for call-site
        compatibility and ignored; ``self.sparsity`` is used instead.
        """
        pruned_weight, _ = GetNMMultiSubnet.apply(
            self.weight,
            self.weight_score,
            self.M,
            self.sparsity,
            self.decay,
        )
        return F.linear(x, pruned_weight, self.bias)

    def encode_array(self, input_array, block_size=16, slots=8):
        """Compress a column-wise blocked sparse array into position/value tables.

        Each column's rows are split into consecutive blocks of
        ``block_size``. For every block, the in-block indices of its
        non-zero entries (ascending) and the entries themselves are written
        to ``slots`` consecutive output rows; unused slots stay 0. Trailing
        rows that do not fill a whole block are ignored.

        Args:
            input_array: 2-D array indexed ``[row, col]``.
            block_size: Rows per block (default 16, the same as ``self.M``).
            slots: Output rows reserved per block (default 8).

        Returns:
            ``(position_array, value_array)``, both ``np.uint8`` of shape
            ``(num_blocks * slots, cols)``. Values are cast to ``uint8`` on
            assignment, so inputs should be small non-negative integers
            (presumably the 0-3 mask levels - confirm against callers).

        Raises:
            ValueError: If a block has more than ``slots`` non-zero entries,
                which would otherwise overwrite the next block's slots.
        """
        rows, cols = input_array.shape
        num_blocks = rows // block_size
        position_array = np.zeros((num_blocks * slots, cols), dtype=np.uint8)
        value_array = np.zeros((num_blocks * slots, cols), dtype=np.uint8)

        for col in range(cols):
            for block in range(num_blocks):
                start_idx = block * block_size
                block_data = input_array[start_idx : start_idx + block_size, col]
                base = block * slots

                pos_idx = 0
                for i, value in enumerate(block_data):
                    if value != 0:
                        if pos_idx >= slots:
                            raise ValueError(
                                f"block {block} of column {col} has more than "
                                f"{slots} non-zero entries"
                            )
                        position_array[base + pos_idx, col] = i
                        value_array[base + pos_idx, col] = value
                        pos_idx += 1

        return position_array, value_array


class GetNMMultiSubnet(torch.autograd.Function):
    """N:M multi-level mask over ``weight_score`` with a decayed STE backward.

    Forward, per row of ``weight_score`` with its columns split into
    consecutive groups of ``M``:
      1. In every full group the ``M - N`` lowest scores are zeroed, where
         ``N = int(M * (1 - sparsity))``. A trailing partial group of size
         ``L`` keeps ``max(1, int(L * (1 - sparsity)))`` entries.
      2. All surviving entries (across the whole tensor) are relabelled
         1, 2 or 3 depending on whether their score is below the 0.33333
         quantile, below the 0.66666 quantile, or above both.
    Scores are compared by absolute value when ``cfg.slt.enable_abs_pruning``
    is set. The mask is built from detached scores.

    Returns ``(weight * mask.T, mask.T)``; ``weight`` is therefore expected
    to have the transposed shape of ``weight_score``.
    """

    @staticmethod
    def forward(ctx, weight, weight_score, M, sparsity, decay):
        # if torch.is_grad_enabled():
        #     ctx.save_for_backward(weight_score)
        # NOTE(review): saved but never read in backward (which uses
        # ctx.mask / ctx.weight instead).
        ctx.save_for_backward(weight_score)

        num_rows, num_cols = weight_score.shape
        group_size = M
        num_full_groups = num_cols // group_size
        last_group_size = num_cols % group_size

        mask = torch.ones_like(weight_score)
        # Entries kept per full group.
        N = int(M * (1 - sparsity))

        for group in range(num_full_groups):

            start_idx = group * group_size
            end_idx = (group + 1) * group_size
            block = (
                weight_score[:, start_idx:end_idx].detach().abs()
                if cfg.slt.enable_abs_pruning
                else weight_score[:, start_idx:end_idx].detach()
            )
            # Indices of the (group_size - N) smallest scores in each row;
            # min() caps k at the group width.
            _, indices = torch.topk(
                block,
                min(int(group_size - N), block.size(1)),
                largest=False,
                dim=1,
            )
            mask[:, start_idx:end_idx].scatter_(1, indices, 0)

        if last_group_size > 0:
            # Partial trailing group: keep a proportional count, at least 1.
            start_idx = num_full_groups * group_size
            end_idx = num_cols
            block = (
                weight_score[:, start_idx:end_idx].detach().abs()
                if cfg.slt.enable_abs_pruning
                else weight_score[:, start_idx:end_idx].detach()
            )
            last_group_N = int(last_group_size * (1 - sparsity))
            last_group_N = max(1, last_group_N)
            _, indices = torch.topk(
                block,
                int(last_group_size - last_group_N),
                largest=False,
                dim=1,
            )
            mask[:, start_idx:end_idx].scatter_(1, indices, 0)

        # Split the surviving entries into three levels by score tertile.
        # NOTE(review): if every entry were pruned, scores_of_ones would be
        # empty and quantile() would presumably fail - not handled here.
        indices_of_ones = torch.nonzero(mask == 1, as_tuple=True)
        scores_of_ones = (
            weight_score[indices_of_ones].detach().abs()
            if cfg.slt.enable_abs_pruning
            else weight_score[indices_of_ones].detach()
        )
        threshold_low = scores_of_ones.quantile(0.33333)
        threshold_high = scores_of_ones.quantile(0.66666)

        # Lowest third -> 1, middle third -> 2, top third -> 3.
        mask[indices_of_ones] = torch.where(
            scores_of_ones < threshold_low,
            1.0,
            torch.where(scores_of_ones < threshold_high, 2, 3),
        )

        # Stored directly on ctx for use in backward.
        ctx.mask = mask
        ctx.decay = decay
        ctx.weight = weight

        return weight * mask.T, mask.T

    @staticmethod
    def backward(ctx, grad_output, grad_extra):
        """Straight-through gradient for ``weight_score`` plus a decay term.

        ``grad_output`` (w.r.t. the masked weight) is transposed onto
        ``weight_score``, and ``decay * (3 - mask) * weight.T`` is
        subtracted: the term is largest for pruned entries (mask 0) and zero
        for level-3 entries. ``weight``, ``M``, ``sparsity`` and ``decay``
        get no gradient; the gradient of the returned mask (``grad_extra``)
        is ignored.
        """
        return (
            None,
            grad_output.T - ctx.decay * (3 - ctx.mask) * ctx.weight.T,
            None,
            None,
            None,
        )


class SLT_AtomEncoder(nn.Module):
    """OGB atom encoder whose frozen embeddings are masked by trained scores.

    One ``nn.Embedding`` is built per atom feature (sizes from
    ``get_atom_feature_dims()``). The embedding weights are frozen; each
    feature has a single ``(1, emb_dim)`` score vector, so one mask is shared
    by every category of that feature. The masked embeddings of all features
    are summed.
    """

    def init_param_(
        self, param, init_mode=None, scale=None, sparse_value=None, gain="relu"
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor or Parameter that is overwritten in place.
            init_mode: One of ``"kaiming_normal"``, ``"uniform"``,
                ``"kaiming_uniform"``, ``"kaiming_normal_SF"``,
                ``"signed_constant"``, ``"signed_constant_sparse"`` or
                ``"signed_constant_SF"``.
            scale: Multiplier applied by most schemes; ``None`` means 1.0.
            sparse_value: Sparsity fraction in ``[0, 1)``; the ``*_SF`` /
                ``*_sparse`` schemes enlarge the std by
                ``1 / sqrt(1 - sparse_value)``.
            gain: Nonlinearity name for ``nn.init.calculate_gain`` and the
                Kaiming initialisers.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        if scale is None:
            scale = 1.0
        # ``gain`` is rebound to a float below; keep the nonlinearity name.
        nonlinearity = gain

        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            scale_fan = fan * (1 - sparse_value)
            std = gain / math.sqrt(scale_fan)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            std = gain / math.sqrt(fan)
            nn.init.kaiming_normal_(param)  # use only its sign
            param.data = param.data.sign() * std
            param.data *= scale
        elif init_mode in ("signed_constant_sparse", "signed_constant_SF"):
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            scale_fan = fan * (1 - sparse_value)
            std = gain / math.sqrt(scale_fan)
            nn.init.kaiming_normal_(param)  # use only its sign
            param.data = param.data.sign() * std
            param.data *= scale  # scale value is defined in default as 1.0
        else:
            raise NotImplementedError(f"Unknown init_mode: {init_mode!r}")

    def __init__(self, emb_dim):
        super(SLT_AtomEncoder, self).__init__()
        self.sparsity = cfg.slt.linear_sparsity

        self.atom_embedding_list = nn.ModuleList()
        full_atom_feature_dims = get_atom_feature_dims()
        self.weight_scores_list = nn.ParameterList()
        # Plain lists of plain tensors: not moved by .to(); GetSubnet moves
        # them to the scores' device on each call.
        self.weight_zeros_list = []
        self.weight_ones_list = []

        self.init_mode_weight = cfg.slt.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        self.init_scale_score = cfg.slt.init_scale_score

        for i, dim in enumerate(full_atom_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            # One score vector per feature, shared by all its categories.
            weight_score = nn.Parameter(torch.ones(1, emb.weight.size(1)))
            weight_score.is_score = True
            weight_score.sparsity = self.sparsity

            self.init_param_(
                emb.weight,
                init_mode=self.init_mode_weight,
                scale=self.init_scale_weight,
                sparse_value=self.sparsity[0],
            )

            # Embeddings are frozen; only the scores are trained.
            emb.weight.requires_grad = False

            self.init_param_(
                weight_score,
                init_mode=self.init_mode_score,
                scale=self.init_scale_score,
                sparse_value=self.sparsity[0],
            )

            self.weight_zeros = torch.zeros(1, emb.weight.size(1))
            self.weight_ones = torch.ones((1, emb.weight.size(1)))
            self.weight_zeros.requires_grad = False
            self.weight_ones.requires_grad = False

            self.atom_embedding_list.append(emb)
            self.weight_scores_list.append(weight_score)
            self.weight_zeros_list.append(self.weight_zeros)
            self.weight_ones_list.append(self.weight_ones)

    def forward(
        self,
        x,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Replace ``x.x`` with the summed masked embeddings and return ``x``.

        Args:
            x: Object whose ``x`` attribute holds integer atom features, one
                column per feature (presumably a PyG batch - confirm).
                It is modified in place.
            encoder_th, global_th: Threshold(s) used when
                ``cfg.slt.pruning`` is ``"blockwise"`` / ``"global"``; a
                single number or 0-d tensor, or an iterable of thresholds
                whose masks are summed.
            cur_epoch, mpnn_th, msa_th, ffn_th, pred_th: Unused here.

        Raises:
            ValueError: If ``cfg.slt.pruning`` is neither ``"global"`` nor
                ``"blockwise"``.
        """
        if cfg.slt.pruning == "global":
            threshold = global_th
        elif cfg.slt.pruning == "blockwise":
            threshold = encoder_th
        else:
            raise ValueError(
                f"Unsupported cfg.slt.pruning: {cfg.slt.pruning!r}"
            )

        if isinstance(threshold, numbers.Real) or (
            torch.is_tensor(threshold) and threshold.dim() == 0
        ):
            threshold = [threshold]

        x_embedding = 0
        for i in range(x.x.shape[1]):
            weight_score = self.weight_scores_list[i]
            weight_zeros = self.weight_zeros_list[i]
            weight_ones = self.weight_ones_list[i]

            # Threshold-independent work, done once per feature.
            scores = (
                weight_score.abs()
                if cfg.slt.enable_abs_pruning
                else weight_score
            )
            sign_subnet = None
            if cfg.slt.sign_mask:
                sign_subnet = GetSubnet.apply(
                    weight_score, 0, weight_ones * -1, weight_ones
                )

            subnets = []
            for threshold_v in threshold:
                subnet = GetSubnet.apply(
                    scores, threshold_v, weight_zeros, weight_ones
                )
                if sign_subnet is not None:
                    subnet = subnet * sign_subnet
                subnets.append(subnet)
            combined_subnet = torch.stack(subnets).sum(dim=0)

            x_embedding += (
                self.atom_embedding_list[i](x.x[:, i]) * combined_subnet
            )

        x.x = x_embedding
        return x


class SLT_AtomEncoder_new(nn.Module):
    """OGB atom encoder with a separate trained score vector per category.

    Like ``SLT_AtomEncoder`` but each embedding row (category) of each atom
    feature owns its own ``(1, emb_dim)`` score vector, so every node's
    embedding is masked according to its category.
    """

    def init_param_(
        self, param, init_mode=None, scale=None, sparse_value=None, gain="relu"
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor or Parameter that is overwritten in place.
            init_mode: One of ``"kaiming_normal"``, ``"uniform"``,
                ``"kaiming_uniform"``, ``"kaiming_normal_SF"``,
                ``"signed_constant"``, ``"signed_constant_sparse"`` or
                ``"signed_constant_SF"``.
            scale: Multiplier applied by most schemes; ``None`` means 1.0.
            sparse_value: Sparsity fraction in ``[0, 1)``; the ``*_SF`` /
                ``*_sparse`` schemes enlarge the std by
                ``1 / sqrt(1 - sparse_value)``.
            gain: Nonlinearity name for ``nn.init.calculate_gain`` and the
                Kaiming initialisers.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        if scale is None:
            scale = 1.0
        # ``gain`` is rebound to a float below; keep the nonlinearity name.
        nonlinearity = gain

        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(
                param, mode="fan_in", nonlinearity=nonlinearity
            )
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            scale_fan = fan * (1 - sparse_value)
            std = gain / math.sqrt(scale_fan)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            std = gain / math.sqrt(fan)
            nn.init.kaiming_normal_(param)  # use only its sign
            param.data = param.data.sign() * std
            param.data *= scale
        elif init_mode in ("signed_constant_sparse", "signed_constant_SF"):
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            gain = nn.init.calculate_gain(nonlinearity)
            scale_fan = fan * (1 - sparse_value)
            std = gain / math.sqrt(scale_fan)
            nn.init.kaiming_normal_(param)  # use only its sign
            param.data = param.data.sign() * std
            param.data *= scale  # scale value is defined in default as 1.0
        else:
            raise NotImplementedError(f"Unknown init_mode: {init_mode!r}")

    def __init__(self, emb_dim):
        super(SLT_AtomEncoder_new, self).__init__()
        self.sparsity = cfg.slt.linear_sparsity

        self.atom_embedding_list = nn.ModuleList()
        full_atom_feature_dims = get_atom_feature_dims()
        # Outer list: one inner ParameterList (per-category scores) per feature.
        self.weight_scores_list = nn.ParameterList()
        self.weight_zeros_list = []
        self.weight_ones_list = []

        self.init_mode_weight = cfg.slt.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        self.init_scale_score = cfg.slt.init_scale_score

        for i, dim in enumerate(full_atom_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            # Kept as an attribute (last one wins) so existing state_dict
            # keys are unchanged.
            self.weight_score_list = nn.ParameterList()

            # NOTE(review): the embedding weight uses init_mode_score here,
            # whereas SLT_AtomEncoder uses init_mode_weight - confirm.
            self.init_param_(
                emb.weight,
                init_mode=self.init_mode_score,
                scale=self.init_scale_weight,
                sparse_value=self.sparsity[0],
            )
            emb.weight.requires_grad = False

            # One score vector per embedding row (category).
            for _ in range(emb.weight.size(0)):

                weight_score = nn.Parameter(torch.ones(1, emb.weight.size(1)))
                weight_score.is_score = True
                weight_score.sparsity = self.sparsity
                self.init_param_(
                    weight_score,
                    init_mode=self.init_mode_score,
                    scale=self.init_scale_score,
                    sparse_value=self.sparsity[0],
                )
                self.weight_score_list.append(weight_score)

            self.weight_zeros = torch.zeros(1, emb.weight.size(1))
            self.weight_ones = torch.ones((1, emb.weight.size(1)))
            self.weight_zeros.requires_grad = False
            self.weight_ones.requires_grad = False

            self.atom_embedding_list.append(emb)
            self.weight_scores_list.append(self.weight_score_list)
            self.weight_zeros_list.append(self.weight_zeros)
            self.weight_ones_list.append(self.weight_ones)

    def sparse_subnet(
        self, x, i, threshold, weight_score_list, weight_zeros, weight_ones
    ):
        """Return the per-node mask for atom feature ``i``.

        Args:
            x: Object whose ``x[:, i]`` gives each node's category index.
            i: Feature (column) index.
            threshold: Iterable of thresholds; the per-threshold masks are
                summed.
            weight_score_list: Per-category ``(1, emb_dim)`` score params.
            weight_zeros, weight_ones: ``(1, emb_dim)`` fill tensors.

        Returns:
            Tensor of shape ``(num_nodes, emb_dim)``.
        """
        # Build the mask once per category, as a (num_categories, emb_dim)
        # table, instead of once per node; then gather each node's row.
        # Values and (straight-through) gradients match the per-node form:
        # indexing backprop accumulates into the rows that were used.
        score_table = torch.cat(list(weight_score_list), dim=0)
        scores = (
            score_table.abs() if cfg.slt.enable_abs_pruning else score_table
        )
        sign_subnet = None
        if cfg.slt.sign_mask:
            sign_subnet = GetSubnet.apply(
                score_table,
                0,
                self.weight_ones * -1,
                self.weight_ones,
            )

        subnets = []
        for threshold_v in threshold:
            subnet = GetSubnet.apply(
                scores, threshold_v, weight_zeros, weight_ones
            )
            if sign_subnet is not None:
                subnet = subnet * sign_subnet
            subnets.append(subnet)
        mask_table = torch.stack(subnets).sum(dim=0)

        return mask_table[(x.x)[:, i]]

    def forward(
        self,
        x,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Replace ``x.x`` with the summed masked embeddings and return ``x``.

        Args:
            x: Object whose ``x`` attribute holds integer atom features, one
                column per feature (presumably a PyG batch - confirm).
                It is modified in place.
            encoder_th, global_th: Threshold(s) used when
                ``cfg.slt.pruning`` is ``"blockwise"`` / ``"global"``; a
                single number or 0-d tensor, or an iterable of thresholds.
            cur_epoch, mpnn_th, msa_th, ffn_th, pred_th: Unused here.

        Raises:
            ValueError: If ``cfg.slt.pruning`` is neither ``"global"`` nor
                ``"blockwise"``.
        """
        if cfg.slt.pruning == "global":
            threshold = global_th
        elif cfg.slt.pruning == "blockwise":
            threshold = encoder_th
        else:
            raise ValueError(
                f"Unsupported cfg.slt.pruning: {cfg.slt.pruning!r}"
            )

        if isinstance(threshold, numbers.Real) or (
            torch.is_tensor(threshold) and threshold.dim() == 0
        ):
            threshold = [threshold]

        x_embedding = 0
        for i in range(x.x.shape[1]):
            emb_weight = self.atom_embedding_list[i](x.x[:, i])
            emb_mask = self.sparse_subnet(
                x,
                i,
                threshold,
                self.weight_scores_list[i],
                self.weight_zeros_list[i],
                self.weight_ones_list[i],
            )
            x_embedding += emb_weight * emb_mask

        x.x = x_embedding
        return x


class SLT_BondEncoder(nn.Module):
    """Bond encoder whose frozen embeddings are gated by learned score masks.

    Every categorical bond feature gets a frozen ``nn.Embedding`` plus a
    trainable score vector of shape ``(1, emb_dim)``. In ``forward`` each score
    is turned into a mask via ``GetSubnet.apply`` (one mask per threshold,
    summed) and the masked embeddings of all features are added together.
    """

    def init_param_(
        self,
        param,
        init_mode=None,
        scale=None,
        sparse_value=None,
        gain="relu",
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor or ``nn.Parameter`` to initialise (modified in place).
            init_mode: One of ``"kaiming_normal"``, ``"uniform"``,
                ``"kaiming_uniform"``, ``"kaiming_normal_SF"``,
                ``"signed_constant"``, ``"signed_constant_sparse"`` or
                ``"signed_constant_SF"``.
            scale: Multiplier applied after initialisation (not used by
                ``"kaiming_normal_SF"``).
            sparse_value: Expected pruned fraction; the ``*_SF``/``*_sparse``
                modes compute the std from ``fan_in * (1 - sparse_value)``.
            gain: Nonlinearity name given to ``nn.init.calculate_gain`` and to
                the ``nonlinearity`` argument of the Kaiming initialisers.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        # ``gain`` used to be read without ever being defined, so every mode
        # except "uniform" raised UnboundLocalError. "relu" yields sqrt(2),
        # the same gain as PyTorch's default Kaiming initialisers (a=0).
        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            scaled_fan = fan * (1 - sparse_value)
            std = nn.init.calculate_gain(gain) / math.sqrt(scaled_fan)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            std = nn.init.calculate_gain(gain) / math.sqrt(fan)
            nn.init.kaiming_normal_(param)  # only the sign is kept
            param.data = param.data.sign() * std
            param.data *= scale
        elif init_mode in ("signed_constant_sparse", "signed_constant_SF"):
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            scaled_fan = fan * (1 - sparse_value)
            std = nn.init.calculate_gain(gain) / math.sqrt(scaled_fan)
            nn.init.kaiming_normal_(param)  # only the sign is kept
            param.data = param.data.sign() * std
            param.data *= scale  # scale defaults to 1.0 in the config
        else:
            raise NotImplementedError

    def __init__(self, emb_dim):
        super(SLT_BondEncoder, self).__init__()
        self.sparsity = cfg.slt.linear_sparsity

        self.bond_embedding_list = nn.ModuleList()
        self.weight_scores_list = nn.ParameterList()
        # Plain tensors kept in Python lists (not registered buffers), so
        # ``Module.to()`` does not move them; ``forward`` moves them instead.
        self.weight_zeros_list = []
        self.weight_ones_list = []

        self.init_mode_weight = cfg.slt.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        self.init_scale_score = cfg.slt.init_scale_score

        for dim in get_bond_feature_dims():
            emb = nn.Embedding(dim, emb_dim)
            weight_score = nn.Parameter(torch.ones(1, emb.weight.size(1)))
            # Tags presumably read by the training code to find score params.
            weight_score.is_score = True
            weight_score.sparsity = self.sparsity

            self.init_param_(
                emb.weight,
                init_mode=self.init_mode_weight,
                scale=self.init_scale_weight,
                sparse_value=self.sparsity[0],
            )
            # The embedding weights stay frozen; only the scores get gradients.
            emb.weight.requires_grad = False

            self.init_param_(
                weight_score,
                init_mode=self.init_mode_score,
                scale=self.init_scale_score,
                sparse_value=self.sparsity[0],
            )

            self.weight_zeros = torch.zeros(1, emb.weight.size(1))
            self.weight_ones = torch.ones((1, emb.weight.size(1)))
            self.weight_zeros.requires_grad = False
            self.weight_ones.requires_grad = False

            self.bond_embedding_list.append(emb)
            self.weight_scores_list.append(weight_score)
            self.weight_zeros_list.append(self.weight_zeros)
            self.weight_ones_list.append(self.weight_ones)

    def forward(
        self,
        edge_attr,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Replace ``edge_attr.edge_attr`` with the sum of masked bond embeddings.

        Args:
            edge_attr: Graph batch object whose ``edge_attr`` attribute holds the
                integer bond features (one column per feature). Modified in place.
            cur_epoch, mpnn_th, msa_th, ffn_th, pred_th: Not used by this method.
            encoder_th: Threshold(s) used when ``cfg.slt.pruning == "blockwise"``.
            global_th: Threshold(s) used when ``cfg.slt.pruning == "global"``.

        Returns:
            The same ``edge_attr`` object with its ``edge_attr`` field replaced.

        Raises:
            ValueError: If ``cfg.slt.pruning`` is neither ``"global"`` nor
                ``"blockwise"``.
        """
        if cfg.slt.pruning == "global":
            threshold = global_th
        elif cfg.slt.pruning == "blockwise":
            threshold = encoder_th
        else:
            raise ValueError(
                f"Unsupported cfg.slt.pruning mode: {cfg.slt.pruning!r}"
            )

        # Wrap a scalar threshold so it can be iterated like a list.
        if isinstance(threshold, numbers.Real):
            threshold = [threshold]

        edge_embedding = 0
        for i in range(edge_attr.edge_attr.shape[1]):
            weight_score = self.weight_scores_list[i]
            # The constant tensors are not buffers, so bring them to the
            # score's device explicitly (no-op when already there).
            weight_zeros = self.weight_zeros_list[i].to(weight_score.device)
            weight_ones = self.weight_ones_list[i].to(weight_score.device)
            pruning_score = (
                torch.abs(weight_score)
                if cfg.slt.enable_abs_pruning
                else weight_score
            )

            subnets = []
            for threshold_v in threshold:
                subnet = GetSubnet.apply(
                    pruning_score, threshold_v, weight_zeros, weight_ones
                )
                if cfg.slt.sign_mask:
                    # Use this feature's own ones tensor rather than whichever
                    # one ``__init__`` happened to leave in ``self.weight_ones``.
                    sign_subnet = GetSubnet.apply(
                        weight_score, 0, weight_ones * -1, weight_ones
                    )
                    subnet = subnet * sign_subnet
                subnets.append(subnet)
            combined_subnet = torch.stack(subnets).sum(dim=0)

            edge_embedding += (
                self.bond_embedding_list[i](edge_attr.edge_attr[:, i])
                * combined_subnet
            )

        edge_attr.edge_attr = edge_embedding
        return edge_attr


class SparseEmbedding_previous(nn.Module):
    """Single embedding table gated by one shared learned score mask.

    The ``nn.Embedding`` weights are frozen; a trainable score vector of shape
    ``(1, emb_dim)`` is converted into a mask (summed over thresholds) which
    multiplies every looked-up embedding.
    """

    def __init__(self, emb_dim, feature_dims):
        super(SparseEmbedding_previous, self).__init__()
        self.sparsity = cfg.slt.linear_sparsity

        self.sparse_embedding_list = nn.ModuleList()
        self.weight_scores_list = nn.ParameterList()
        # Plain tensors (not registered buffers): ``Module.to()`` does not
        # move them, so ``forward`` moves them to the score's device.
        self.weight_zeros_list = []
        self.weight_ones_list = []

        self.init_mode_weight = cfg.slt.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        self.init_scale_score = cfg.slt.init_scale_score

        emb = nn.Embedding(feature_dims, emb_dim)
        weight_score = nn.Parameter(torch.ones(1, emb.weight.size(1)))
        # Tags presumably read by the training code to find score params.
        weight_score.is_score = True
        weight_score.sparsity = self.sparsity

        self.init_param_(
            emb.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity[0],
        )
        # The embedding weights stay frozen; only the score gets gradients.
        emb.weight.requires_grad = False

        self.init_param_(
            weight_score,
            init_mode=self.init_mode_score,
            scale=self.init_scale_score,
            sparse_value=self.sparsity[0],
        )

        self.weight_zeros = torch.zeros(1, emb.weight.size(1))
        self.weight_ones = torch.ones((1, emb.weight.size(1)))
        self.weight_zeros.requires_grad = False
        self.weight_ones.requires_grad = False

        self.sparse_embedding_list.append(emb)
        self.weight_scores_list.append(weight_score)
        self.weight_zeros_list.append(self.weight_zeros)
        self.weight_ones_list.append(self.weight_ones)

    def init_param_(
        self,
        param,
        init_mode=None,
        scale=None,
        sparse_value=None,
        gain="relu",
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor or ``nn.Parameter`` to initialise (modified in place).
            init_mode: One of ``"kaiming_normal"``, ``"uniform"``,
                ``"kaiming_uniform"``, ``"kaiming_normal_SF"``,
                ``"signed_constant"``, ``"signed_constant_sparse"`` or
                ``"signed_constant_SF"``.
            scale: Multiplier applied after initialisation (not used by
                ``"kaiming_normal_SF"``).
            sparse_value: Expected pruned fraction; the ``*_SF``/``*_sparse``
                modes compute the std from ``fan_in * (1 - sparse_value)``.
            gain: Nonlinearity name given to ``nn.init.calculate_gain`` and to
                the ``nonlinearity`` argument of the Kaiming initialisers.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        # ``gain`` used to be read without ever being defined, so every mode
        # except "uniform" raised UnboundLocalError. "relu" yields sqrt(2),
        # the same gain as PyTorch's default Kaiming initialisers (a=0).
        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            scaled_fan = fan * (1 - sparse_value)
            std = nn.init.calculate_gain(gain) / math.sqrt(scaled_fan)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            std = nn.init.calculate_gain(gain) / math.sqrt(fan)
            nn.init.kaiming_normal_(param)  # only the sign is kept
            param.data = param.data.sign() * std
            param.data *= scale
        elif init_mode in ("signed_constant_sparse", "signed_constant_SF"):
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            scaled_fan = fan * (1 - sparse_value)
            std = nn.init.calculate_gain(gain) / math.sqrt(scaled_fan)
            nn.init.kaiming_normal_(param)  # only the sign is kept
            param.data = param.data.sign() * std
            param.data *= scale  # scale defaults to 1.0 in the config
        else:
            raise NotImplementedError

    def forward(self, attr, threshold):
        """Look up ``attr`` and multiply the result by the score-derived mask.

        Args:
            attr: Index tensor passed to the underlying ``nn.Embedding``.
            threshold: A scalar threshold or an iterable of thresholds; the
                masks obtained for each threshold are summed.

        Returns:
            ``embedding(attr) * mask``.
        """
        if isinstance(threshold, numbers.Real):
            threshold = [threshold]

        weight_score = self.weight_scores_list[0]
        # The constant tensors are not buffers, so bring them to the score's
        # device explicitly (no-op when already there).
        weight_zeros = self.weight_zeros_list[0].to(weight_score.device)
        weight_ones = self.weight_ones_list[0].to(weight_score.device)
        pruning_score = (
            torch.abs(weight_score)
            if cfg.slt.enable_abs_pruning
            else weight_score
        )

        subnets = []
        for threshold_v in threshold:
            subnet = GetSubnet.apply(
                pruning_score, threshold_v, weight_zeros, weight_ones
            )
            if cfg.slt.sign_mask:
                sign_subnet = GetSubnet.apply(
                    weight_score, 0, weight_ones * -1, weight_ones
                )
                subnet = subnet * sign_subnet
            subnets.append(subnet)
        combined_subnet = torch.stack(subnets).sum(dim=0)

        return self.sparse_embedding_list[0](attr) * combined_subnet


class SparseEmbedding(nn.Module):
    """Embedding table whose frozen rows are each gated by their own score mask.

    The ``nn.Embedding`` weights are frozen. ``weight_scores_list`` holds one
    trainable score vector of shape ``(1, emb_dim)`` per embedding row; in
    ``forward`` each row is multiplied by the mask derived from its own score
    before ``attr`` is looked up.
    """

    def __init__(self, emb_dim, feature_dims):
        super(SparseEmbedding, self).__init__()
        self.sparsity = cfg.slt.linear_sparsity

        self.sparse_embedding_list = nn.ModuleList()
        self.weight_scores_list = nn.ParameterList()
        # Plain tensors (not registered buffers): ``Module.to()`` does not
        # move them, so ``forward`` moves them to the score's device.
        self.weight_zeros_list = []
        self.weight_ones_list = []

        self.init_mode_weight = cfg.slt.init_mode_weight
        self.init_mode_score = cfg.slt.init_mode_score
        self.init_scale_weight = cfg.slt.init_scale_weight
        self.init_scale_score = cfg.slt.init_scale_score

        # Frozen embedding weights. They use the *weight* init mode, as in the
        # sibling encoders (previously the score mode was used here by mistake).
        emb = nn.Embedding(feature_dims, emb_dim)
        self.init_param_(
            emb.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity[0],
        )
        emb.weight.requires_grad = False

        # One trainable score vector per embedding row.
        for _ in range(emb.weight.size(0)):
            weight_score = nn.Parameter(torch.ones(1, emb.weight.size(1)))
            # Tags presumably read by the training code to find score params.
            weight_score.is_score = True
            weight_score.sparsity = self.sparsity
            self.init_param_(
                weight_score,
                init_mode=self.init_mode_score,
                scale=self.init_scale_score,
                sparse_value=self.sparsity[0],
            )
            self.weight_scores_list.append(weight_score)

        self.weight_zeros = torch.zeros(1, emb.weight.size(1))
        self.weight_ones = torch.ones((1, emb.weight.size(1)))
        self.weight_zeros.requires_grad = False
        self.weight_ones.requires_grad = False

        self.sparse_embedding_list.append(emb)
        self.weight_zeros_list.append(self.weight_zeros)
        self.weight_ones_list.append(self.weight_ones)

    def init_param_(
        self,
        param,
        init_mode=None,
        scale=None,
        sparse_value=None,
        gain="relu",
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor or ``nn.Parameter`` to initialise (modified in place).
            init_mode: One of ``"kaiming_normal"``, ``"uniform"``,
                ``"kaiming_uniform"``, ``"kaiming_normal_SF"``,
                ``"signed_constant"``, ``"signed_constant_sparse"`` or
                ``"signed_constant_SF"``.
            scale: Multiplier applied after initialisation (not used by
                ``"kaiming_normal_SF"``).
            sparse_value: Expected pruned fraction; the ``*_SF``/``*_sparse``
                modes compute the std from ``fan_in * (1 - sparse_value)``.
            gain: Nonlinearity name given to ``nn.init.calculate_gain`` and to
                the ``nonlinearity`` argument of the Kaiming initialisers.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        # ``gain`` used to be read without ever being defined, so every mode
        # except "uniform" raised UnboundLocalError. "relu" yields sqrt(2),
        # the same gain as PyTorch's default Kaiming initialisers (a=0).
        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            scaled_fan = fan * (1 - sparse_value)
            std = nn.init.calculate_gain(gain) / math.sqrt(scaled_fan)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            std = nn.init.calculate_gain(gain) / math.sqrt(fan)
            nn.init.kaiming_normal_(param)  # only the sign is kept
            param.data = param.data.sign() * std
            param.data *= scale
        elif init_mode in ("signed_constant_sparse", "signed_constant_SF"):
            fan = nn.init._calculate_correct_fan(param, "fan_in")
            scaled_fan = fan * (1 - sparse_value)
            std = nn.init.calculate_gain(gain) / math.sqrt(scaled_fan)
            nn.init.kaiming_normal_(param)  # only the sign is kept
            param.data = param.data.sign() * std
            param.data *= scale  # scale defaults to 1.0 in the config
        else:
            raise NotImplementedError

    def forward(self, attr, threshold):
        """Mask every embedding row with its own score, then look up ``attr``.

        Previously this indexed rows of ``self.weight_scores_list[0]`` (a
        ``(1, emb_dim)`` tensor, so any row index >= 1 raised IndexError) and
        summed all masked rows while ignoring ``attr``.

        Args:
            attr: Integer index tensor to embed.
            threshold: A scalar threshold or an iterable of thresholds; the
                masks obtained for each threshold are summed per row.

        Returns:
            Tensor of shape ``attr.shape + (emb_dim,)`` holding the masked
            embeddings of ``attr``.
        """
        if isinstance(threshold, numbers.Real):
            threshold = [threshold]

        weight = self.sparse_embedding_list[0].weight
        score_device = self.weight_scores_list[0].device
        # The constant tensors are not buffers, so bring them to the score
        # device explicitly (no-op when already there).
        weight_zeros = self.weight_zeros_list[0].to(score_device)
        weight_ones = self.weight_ones_list[0].to(score_device)

        row_masks = []
        for weight_score in self.weight_scores_list:  # one score per row
            pruning_score = (
                torch.abs(weight_score)
                if cfg.slt.enable_abs_pruning
                else weight_score
            )
            subnets = []
            for threshold_v in threshold:
                subnet = GetSubnet.apply(
                    pruning_score, threshold_v, weight_zeros, weight_ones
                )
                if cfg.slt.sign_mask:
                    sign_subnet = GetSubnet.apply(
                        weight_score, 0, weight_ones * -1, weight_ones
                    )
                    subnet = subnet * sign_subnet
                subnets.append(subnet)
            row_masks.append(torch.stack(subnets).sum(dim=0))

        # Assumes GetSubnet returns a (1, emb_dim) mask like weight_zeros /
        # weight_ones, so concatenation gives (num_embeddings, emb_dim)
        # — TODO confirm against GetSubnet.
        mask = torch.cat(row_masks, dim=0).to(weight.device)
        return F.embedding(attr, weight * mask)
