import math
import numbers
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter


class RMSNorm(nn.Module):
    """Root-mean-square normalisation with an optional learnable gain.

    The input is divided by ``sqrt(mean(x ** 2) + eps)``, where the mean is
    taken over the trailing ``len(normalized_shape)`` dimensions, and then
    (when ``elementwise_affine`` is true) multiplied by ``weight``.

    Args:
        normalized_shape: Int or iterable of ints giving the shape of the
            trailing dimensions that ``weight`` covers.
        eps: Value added under the square root. ``None`` selects the float32
            machine epsilon.
        elementwise_affine: When true, a ``weight`` parameter of shape
            ``normalized_shape`` is created and initialised to ones.
        device: Device on which ``weight`` is allocated.
        dtype: Dtype with which ``weight`` is allocated.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape,
        eps: Optional[float] = None,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps if eps is not None else torch.finfo(torch.float32).eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
        self.pair_reset_parameters()

    def pair_reset_parameters(self) -> None:
        """Fill ``weight`` with ones (no-op without an affine weight)."""
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit RMS and multiplied by ``weight``."""
        # Reduce over every axis covered by ``weight`` so that the statistic
        # and the gain agree; fall back to the last axis for an empty shape.
        reduce_dims = tuple(range(-len(self.normalized_shape), 0)) or (-1,)
        rms = torch.sqrt(
            torch.mean(x**2, dim=reduce_dims, keepdim=True) + self.eps
        )
        x = x / rms
        if self.elementwise_affine:
            # Use a local, device-matched view of the gain. Re-assigning
            # ``self.weight`` with the result of ``.to()`` would hand a plain
            # tensor to a registered parameter slot, which nn.Module rejects.
            x = x * self.weight.to(x.device)
        return x

    def extra_repr(self) -> str:
        """Return the constructor arguments shown in the module repr."""
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )


class PairNorm(torch.nn.Module):
    """Centre the input along dim 0 and rescale it by a single global factor.

    The factor is the square root of ``1e-6`` plus the mean, over dim 0, of
    the per-row sum of squares taken along dim 1.
    """

    def __init__(self, device=None, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the centred and rescaled version of ``x``."""
        target_device = self.device
        if x.device != target_device:
            x = x.to(target_device)

        centered = x - x.mean(dim=0)
        mean_row_sq_norm = centered.pow(2).sum(dim=1).mean()
        denom = torch.sqrt(1e-6 + mean_row_sq_norm)
        return centered / denom

    def extra_repr(self) -> str:
        """Return the stored device and dtype for the module repr."""
        return "device={}, dtype={}".format(self.device, self.dtype)


# Machine epsilon of the Python ``float`` type as reported by NumPy (float64).
EPSILON = np.finfo(float).eps


class GetSubnet(torch.autograd.Function):
    """Threshold scores into a two-valued mask; gradients pass through as-is."""

    @staticmethod
    def forward(ctx, scores, threshold, zeros, ones):
        """Pick ``zeros`` where ``scores < threshold`` and ``ones`` elsewhere.

        ``zeros`` and ``ones`` are moved to the device of ``scores`` first.
        """
        is_below = scores < threshold
        low_values = zeros.to(scores.device)
        high_values = ones.to(scores.device)
        return torch.where(is_below, low_values, high_values)

    @staticmethod
    def backward(ctx, g):
        """Hand the incoming gradient to ``scores`` unchanged; others get None."""
        grad_scores = g
        return (grad_scores, None, None, None)


class SparseModule(nn.Module):
    """Base module providing the in-place initialiser used by sparse layers."""

    @staticmethod
    def _fan_in_std(param, nonlinearity, keep_fraction=1.0):
        """Return ``gain / sqrt(fan_in * keep_fraction)`` for ``param``."""
        fan = nn.init._calculate_correct_fan(param, "fan_in")
        gain_value = nn.init.calculate_gain(nonlinearity)
        return gain_value / math.sqrt(fan * keep_fraction)

    @staticmethod
    def _keep_sign_(param, magnitude, scale):
        """Replace every entry of ``param`` by ``sign(entry) * magnitude * scale``."""
        param.data = param.data.sign() * magnitude
        param.data *= scale

    @classmethod
    def _resample_signed_from_std_(cls, param, sparse_value, scale):
        """Turn ``param`` into a signed constant derived from its current std.

        The magnitude is ``std(param) * sqrt(1 / (1 - sparse_value))``; the
        signs come from a fresh ``kaiming_normal_`` draw.
        """
        std = torch.std(param)
        scaled_std = std * math.sqrt(1 / (1 - sparse_value))
        nn.init.kaiming_normal_(param)  # only the signs of this draw are kept
        cls._keep_sign_(param, scaled_std, scale)

    def init_param_(
        self,
        param,
        init_mode=None,
        scale=None,
        sparse_value=None,
        gain="relu",
        args=None,
    ):
        """Initialise ``param`` in place according to ``init_mode``.

        Args:
            param: Tensor/parameter to overwrite.
            init_mode: Name of the scheme; see the branches below.
            scale: Multiplier applied by the modes that scale their result.
                ``None`` is treated as ``1.0``.
            sparse_value: Sparsity used by the ``*_SF`` modes, which enlarge
                the magnitude by ``sqrt(1 / (1 - sparse_value))``.
            gain: Nonlinearity name forwarded to ``nn.init``.
            args: Optional config object. Only ``learnable_weight_scaling`` is
                read; a missing object or attribute counts as ``False``.

        Raises:
            NotImplementedError: If ``init_mode`` is not recognised.
        """
        if scale is None:
            # The signature defaults to None, which cannot be multiplied in.
            scale = 1.0
        learnable_scaling = getattr(args, "learnable_weight_scaling", False)

        if init_mode == "kaiming_normal":
            nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "uniform":
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == "kaiming_uniform":
            nn.init.kaiming_uniform_(param, mode="fan_in", nonlinearity=gain)
            param.data *= scale
        elif init_mode == "kaiming_normal_SF":
            std = self._fan_in_std(param, gain, 1 - sparse_value)
            param.data.normal_(0, std)
        elif init_mode == "signed_constant":
            std = self._fan_in_std(param, gain)
            nn.init.kaiming_normal_(param)  # use only its sign
            self._keep_sign_(param, std, scale)
        elif init_mode in ("signed_constant_SF", "signed_kaiming_constant_SF"):
            std = self._fan_in_std(param, gain, 1 - sparse_value)
            nn.init.kaiming_normal_(param)  # use only its sign
            if not learnable_scaling:
                self._keep_sign_(param, std, scale)
        elif init_mode == "signed_xavier_uniform_constant_SF":
            nn.init.xavier_uniform_(param, nn.init.calculate_gain(gain))
            if not learnable_scaling:
                self._resample_signed_from_std_(param, sparse_value, scale)
        elif init_mode == "signed_xavier_normal_constant_SF":
            nn.init.xavier_normal_(param, nn.init.calculate_gain(gain))
            if not learnable_scaling:
                self._resample_signed_from_std_(param, sparse_value, scale)
        elif init_mode == "signed_trunc_constant_SF":
            nn.init.trunc_normal_(param)
            if not learnable_scaling:
                param_std = torch.std(param)
                with torch.no_grad():
                    param.copy_(torch.where(param >= 0, param_std, -param_std))
        elif init_mode == "signed_kaiming_uniform_constant_SF":
            # NOTE(review): the second positional argument of
            # ``kaiming_uniform_`` is ``a`` (the leaky-ReLU negative slope),
            # not a gain, so the gain value is used as a slope here. Kept
            # unchanged to preserve existing numerics - confirm the intent.
            nn.init.kaiming_uniform_(param, nn.init.calculate_gain(gain))
            if not learnable_scaling:
                self._resample_signed_from_std_(param, sparse_value, scale)
        elif init_mode == "signed_one":
            nn.init.kaiming_normal_(param)
            param.data = param.data.sign() * 1.0
        else:
            raise NotImplementedError


class SparseLinear(SparseModule):
    """Bias-free linear layer whose frozen weight is masked by scored pruning.

    ``weight`` has ``requires_grad`` disabled; ``weight_score`` is a parameter
    of the same shape that is thresholded in ``forward`` to build the mask.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        args: Config object. Reads ``linear_sparsity`` (indexable; element 0
            is used for initialisation), ``init_mode_weight``,
            ``init_mode_score``, ``init_scale_weight``, ``init_scale_score``,
            ``learnable_weight_scaling``, ``enable_abs_pruning`` and
            ``sign_mask``.
    """

    def __init__(
        self,
        in_ch,
        out_ch,
        args,
    ):
        super().__init__()

        self.args = args
        self.sparsity = args.linear_sparsity
        self.init_mode_weight = args.init_mode_weight
        self.init_mode_score = args.init_mode_score
        self.init_scale_weight = args.init_scale_weight
        self.init_scale_score = args.init_scale_score
        self.weight = nn.Parameter(torch.ones(out_ch, in_ch))
        self.weight.requires_grad = False
        self.weight_score = nn.Parameter(torch.ones(self.weight.size()))
        # Tags read by code that collects score parameters.
        self.weight_score.is_score = True
        self.weight_score.sparsity = self.sparsity

        if args.learnable_weight_scaling:
            # Weights become pure signs; the scores take over the weight init.
            self.init_mode_weight = "signed_one"
            self.init_mode_score = args.init_mode_weight

        self.reset_parameters()

        # Constant mask values. Registered as non-persistent buffers so they
        # follow the module across devices (avoiding a copy on every forward)
        # while staying out of the state dict.
        self.register_buffer(
            "weight_zeros", torch.zeros(self.weight.size()), persistent=False
        )
        self.register_buffer(
            "weight_ones", torch.ones(self.weight.size()), persistent=False
        )

    def reset_parameters(self):
        """Re-initialise ``weight`` and then ``weight_score`` in place."""
        self.init_param_(
            self.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity[0],
            args=self.args,
        )
        self.init_param_(
            self.weight_score,
            init_mode=self.init_mode_score,
            scale=self.init_scale_score,
            sparse_value=self.sparsity[0],
            args=self.args,
        )

    def forward(
        self, x, threshold, manual_mask=None, index_mask=0, q_lin=False
    ):
        """Apply the masked linear map to ``x``.

        Args:
            x: Input passed to ``F.linear``.
            threshold: Scores below this value are masked out.
            manual_mask: Accepted for interface compatibility; unused here.
            index_mask: Accepted for interface compatibility; unused here.
            q_lin: Accepted for interface compatibility; unused here.

        Returns:
            ``F.linear(x, weight * mask)`` without bias.
        """
        use_abs = self.args.enable_abs_pruning or self.args.sign_mask
        subnet = GetSubnet.apply(
            torch.abs(self.weight_score) if use_abs else self.weight_score,
            threshold,
            self.weight_zeros,
            self.weight_ones,
        )
        if self.args.sign_mask:
            # Scores below the median flip the sign of the surviving weights.
            sign_subnet = GetSubnet.apply(
                self.weight_score,
                torch.median(self.weight_score),
                self.weight_ones * -1,
                self.weight_ones,
            )
            subnet = subnet * sign_subnet
        pruned_weight = self.weight * subnet
        ret = F.linear(x, pruned_weight, None)
        return ret


class SparseLinearMulti_mask(SparseModule):
    """Bias-free linear layer masked by the sum of several threshold masks.

    ``weight`` has ``requires_grad`` disabled; ``weight_score`` is a parameter
    of the same shape. ``forward`` builds one 0/1 mask per threshold and
    multiplies the weight by their element-wise sum.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        args: Config object. Reads ``linear_sparsity`` (indexable; element 0
            is used for initialisation), ``init_mode_weight``,
            ``init_mode_score``, ``init_scale_weight``, ``init_scale_score``,
            ``enable_abs_pruning`` and ``sign_mask``.
    """

    def __init__(
        self,
        in_ch,
        out_ch,
        args,
    ):
        super().__init__()

        self.args = args
        self.sparsity = args.linear_sparsity
        self.init_mode_weight = args.init_mode_weight
        self.init_mode_score = args.init_mode_score
        self.init_scale_weight = args.init_scale_weight
        self.init_scale_score = args.init_scale_score
        self.weight = nn.Parameter(torch.ones(out_ch, in_ch))
        self.weight.requires_grad = False
        self.weight_score = nn.Parameter(torch.ones(self.weight.size()))
        # Tags read by code that collects score parameters.
        self.weight_score.is_score = True
        self.weight_score.sparsity = self.sparsity

        self.reset_parameters()

        # Constant mask values. Registered as non-persistent buffers so they
        # follow the module across devices (avoiding a copy on every forward)
        # while staying out of the state dict.
        self.register_buffer(
            "weight_zeros", torch.zeros(self.weight.size()), persistent=False
        )
        self.register_buffer(
            "weight_ones", torch.ones(self.weight.size()), persistent=False
        )

    def reset_parameters(self):
        """Re-initialise ``weight`` and then ``weight_score`` in place."""
        self.init_param_(
            self.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity[0],
            args=self.args,
        )
        self.init_param_(
            self.weight_score,
            init_mode=self.init_mode_score,
            scale=self.init_scale_score,
            sparse_value=self.sparsity[0],
            args=self.args,
        )

    def forward(
        self, x, threshold, manual_mask=None, index_mask=0, q_lin=False
    ):
        """Apply the multi-mask linear map to ``x``.

        Args:
            x: Input passed to ``F.linear``.
            threshold: Iterable of thresholds, one mask is built per entry.
            manual_mask: Accepted for interface compatibility; unused here.
            index_mask: Accepted for interface compatibility; unused here.
            q_lin: Accepted for interface compatibility; unused here.

        Returns:
            ``F.linear(x, weight * sum_of_masks)`` without bias.
        """
        use_abs = self.args.enable_abs_pruning or self.args.sign_mask
        subnets = [
            GetSubnet.apply(
                torch.abs(self.weight_score) if use_abs else self.weight_score,
                threshold_v,
                self.weight_zeros,
                self.weight_ones,
            )
            for threshold_v in threshold
        ]
        # Keep the summed mask in floating point: its entries are already
        # whole numbers, and casting to an integer dtype would detach it from
        # autograd so that ``weight_score`` never receives a gradient.
        combined_subnet = torch.stack(subnets).sum(dim=0)
        pruned_weight = self.weight * combined_subnet
        ret = F.linear(x, pruned_weight, None)
        return ret


class NMSparseLinear(SparseModule):
    """Bias-free linear layer pruned group-wise through ``GetNMSubnet``.

    ``weight`` has shape ``(out_ch, in_ch)`` with ``requires_grad`` disabled;
    ``weight_score`` is stored transposed, with shape ``(in_ch, out_ch)``.

    Args:
        in_ch: Number of input features.
        out_ch: Number of output features.
        args: Config object. Reads ``linear_sparsity`` (indexable; element 0
            is used for initialisation), ``init_mode_weight``,
            ``init_mode_score``, ``init_scale_weight``, ``init_scale_score``,
            ``learnable_weight_scaling``, ``M_size`` and ``nm_decay``.
    """

    def __init__(self, in_ch, out_ch, args):
        super().__init__()

        self.sparsity = args.linear_sparsity

        self.init_mode_weight = args.init_mode_weight
        self.init_mode_score = args.init_mode_score

        self.init_scale_weight = args.init_scale_weight
        self.init_scale_score = args.init_scale_score

        self.weight = nn.Parameter(torch.ones(out_ch, in_ch))
        self.weight.requires_grad = False
        self.weight_score = nn.Parameter(torch.ones(in_ch, out_ch))
        # Tags read by code that collects score parameters.
        self.weight_score.is_score = True
        self.weight_score.sparsity = self.sparsity

        if args.learnable_weight_scaling:
            # Weights become pure signs; the scores take over the weight init.
            self.init_mode_weight = "signed_one"
            self.init_mode_score = args.init_mode_weight

        self.init_param_(
            self.weight,
            init_mode=self.init_mode_weight,
            scale=self.init_scale_weight,
            sparse_value=self.sparsity[0],
            args=args,
        )
        self.init_param_(
            self.weight_score,
            init_mode=self.init_mode_score,
            scale=self.init_scale_score,
            sparse_value=self.sparsity[0],
            args=args,
        )
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.M_size = args.M_size
        self.nm_decay = args.nm_decay

    def forward(self, x, threshold=None, sparsity=None):
        """Apply the group-pruned linear map to ``x``.

        Args:
            x: Input passed to ``F.linear``.
            threshold: Accepted for interface compatibility; unused here.
            sparsity: Fraction of each group to prune. ``None`` falls back to
                the first configured sparsity instead of failing inside
                ``GetNMSubnet``.

        Returns:
            ``F.linear(x, pruned_weight)`` without bias.
        """
        if sparsity is None:
            sparsity = self.sparsity[0]
        pruned_weight = GetNMSubnet.apply(
            self.weight,
            self.weight_score,
            self.M_size,
            sparsity,
            self.nm_decay,
        )
        ret = F.linear(x, pruned_weight, None)
        return ret


class GetNMSubnet(torch.autograd.Function):
    """Group-wise magnitude pruning of ``weight`` driven by ``weight_score``.

    The columns of ``weight_score`` are split into consecutive groups of
    ``M_size``. Within each group and row, the entries with the smallest
    absolute score are masked out so that ``floor(M_size * (1 - sparsity))``
    remain. A trailing partial group always keeps at least one entry. The
    transposed mask is multiplied into ``weight``.
    """

    @staticmethod
    def _num_kept(group_size, sparsity):
        """Return ``floor(group_size * (1 - sparsity))`` robust to float error.

        ``1 - 0.9`` is slightly below ``0.1`` in binary floating point, so a
        bare ``int(10 * (1 - 0.9))`` yields 0 instead of 1. The small offset
        absorbs that representation error before truncation.
        """
        return int(group_size * (1 - sparsity) + 1e-9)

    @staticmethod
    def forward(ctx, weight, weight_score, M_size, sparsity, decay):
        """Return ``weight`` with the group-wise pruning mask applied.

        Args:
            weight: Tensor whose shape is the transpose of ``weight_score``.
            weight_score: 2-D score tensor that decides what is pruned.
            M_size: Number of consecutive score columns per group.
            sparsity: Fraction of each group to prune.
            decay: Coefficient used by ``backward`` on pruned entries.
        """
        _, num_cols = weight_score.shape
        num_full_groups, last_group_size = divmod(num_cols, M_size)

        magnitudes = weight_score.detach().abs()
        mask = torch.ones_like(weight_score)

        def prune_group(start_idx, end_idx, num_pruned):
            # Zero the ``num_pruned`` smallest-magnitude entries of every row.
            _, indices = torch.topk(
                magnitudes[:, start_idx:end_idx],
                num_pruned,
                largest=False,
                dim=1,
            )
            mask[:, start_idx:end_idx].scatter_(1, indices, 0)

        kept_per_group = GetNMSubnet._num_kept(M_size, sparsity)
        pruned_per_group = min(int(M_size - kept_per_group), M_size)
        for group in range(num_full_groups):
            start_idx = group * M_size
            prune_group(start_idx, start_idx + M_size, pruned_per_group)

        if last_group_size > 0:
            last_kept = max(1, GetNMSubnet._num_kept(last_group_size, sparsity))
            prune_group(
                num_full_groups * M_size,
                num_cols,
                int(last_group_size - last_kept),
            )

        # ``weight`` is an input, so it goes through save_for_backward; the
        # mask is an intermediate and is kept directly on the context.
        ctx.save_for_backward(weight)
        ctx.mask = mask
        ctx.decay = decay

        return weight * mask.T

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``weight_score``; other inputs get None.

        The score gradient is the transposed upstream gradient minus
        ``decay * weight.T`` on the entries that were pruned.
        """
        (weight,) = ctx.saved_tensors
        grad_score = grad_output.T - ctx.decay * (1 - ctx.mask) * weight.T
        return (
            None,
            grad_score,
            None,
            None,
            None,
        )


def percentile(t, q):
    """Return the ``q``-th percentile of ``t`` as a Python number.

    The value is the ``k``-th smallest element of the flattened tensor, with
    ``k = 1 + round(q / 100 * (numel - 1))``; no interpolation is performed.

    Args:
        t: Tensor of any shape or memory layout.
        q: Percentile in the range 0-100.
    """
    # ``reshape`` also handles non-contiguous inputs, where ``view`` raises.
    flat = t.reshape(-1)
    k = 1 + round(0.01 * float(q) * (flat.numel() - 1))
    return flat.kthvalue(k).values.item()


def calculate_sparsities(sparsity, epoch, max_epoch_half):
    """Scale every entry of ``sparsity`` by ``epoch / max_epoch_half``."""
    scaled = []
    for target in sparsity:
        scaled.append(target * (epoch / max_epoch_half))
    return scaled


def calculate_single_sparsity(sparsity, epoch, max_epoch_half):
    """Scale the scalar ``sparsity`` by ``epoch / max_epoch_half``."""
    progress = epoch / max_epoch_half
    return sparsity * progress


def calculate_nm_sparsity(sparsity, epoch, max_epoch_half):
    """Scale the first entry of ``sparsity`` by ``epoch / max_epoch_half``."""
    first_target = sparsity[0]
    progress = epoch / max_epoch_half
    return first_target * progress


def _collect_score_values(model, args):
    """Flatten all parameters tagged ``is_score`` into one detached tensor.

    Absolute values are returned when ``args.enable_abs_pruning`` or
    ``args.sign_mask`` is set.
    """
    flat = torch.cat(
        [
            p.detach().flatten()
            for _, p in model.named_parameters()
            if getattr(p, "is_score", False)
        ]
    )
    if args.enable_abs_pruning or args.sign_mask:
        return flat.abs()
    return flat


def get_threshold(model, epoch=None, args=None):
    """Compute the global score threshold(s) for the current epoch.

    While the model is in training mode and ``epoch + 1`` does not exceed
    ``args.epochs // 2``, the target sparsity is scaled linearly by
    ``(epoch + 1) / (args.epochs // 2)``; otherwise the configured sparsity
    is used as is. Each threshold is the matching percentile of all score
    values in the model.

    Args:
        model: Module whose ``is_score`` parameters are inspected.
        epoch: Zero-based epoch index. Despite the ``None`` default a number
            is required, since it is incremented immediately.
        args: Config object. Reads ``epochs``, ``mm``, ``sm``,
            ``linear_sparsity``, ``sparsity_scheduling`` (``sm`` only),
            ``enable_abs_pruning`` and ``sign_mask``.

    Returns:
        A list of thresholds when ``args.mm`` is set, a single threshold when
        ``args.sm`` is set, and ``None`` when neither is set.
    """
    epoch += 1
    max_epoch_half = args.epochs // 2

    if args.mm:
        if model.training and epoch <= max_epoch_half:
            sparsities = calculate_sparsities(
                args.linear_sparsity, epoch, max_epoch_half
            )
        else:
            sparsities = args.linear_sparsity

        sparsities = list(sparsities)
        if not sparsities:
            return []
        # The score vector does not depend on the sparsity value, so it is
        # gathered once instead of once per threshold.
        scores = _collect_score_values(model, args)
        return [percentile(scores, value * 100) for value in sparsities]

    if args.sm:
        if (
            args.sparsity_scheduling != "fixed"
            and model.training
            and epoch <= max_epoch_half
        ):
            sparsity_value = calculate_single_sparsity(
                args.linear_sparsity[0], epoch, max_epoch_half
            )
        else:
            sparsity_value = args.linear_sparsity[0]

        scores = _collect_score_values(model, args)
        return percentile(scores, sparsity_value * 100)

    return None
