import torch
import torch.nn as nn
import torch.nn.init as ini
import math
import torch.nn.functional as F


def _sparse(tensor, sparsity):
    tensor_shape = tensor.shape
    if tensor.ndimension() != 2:
        tensor = tensor.view(tensor_shape[0],-1)
    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor.view(tensor_shape)


def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', gain=1.0):
    """Fill ``tensor`` in place with Kaiming (He) normal values times ``gain``.

    The standard deviation is
    ``gain * calculate_gain(nonlinearity, a) / sqrt(fan)``, i.e. PyTorch's
    ``nn.init.kaiming_normal_`` scaled by an extra factor ``gain``. With the
    defaults (``a=0``, ``nonlinearity='leaky_relu'``) the nonlinearity gain is
    exactly ``sqrt(2)``.

    Args:
        tensor: tensor to fill in place. The fan computation requires at
            least 2 dimensions.
        a: negative slope forwarded to ``calculate_gain`` (used by
            ``'leaky_relu'``).
        mode: ``'fan_in'`` or ``'fan_out'``.
        nonlinearity: any name accepted by ``torch.nn.init.calculate_gain``.
        gain: extra multiplicative factor on the standard deviation.

    Returns:
        ``tensor`` itself, modified in place.
    """
    if 0 in tensor.shape:
        # Nothing to fill; this also avoids dividing by a zero fan.
        return tensor
    fan = ini._calculate_correct_fan(tensor, mode)
    # ``a`` and ``nonlinearity`` used to be accepted but ignored (sqrt(2) was
    # hard-coded); calculate_gain honours them and yields sqrt(2.0) for the
    # defaults, so default callers get identical values.
    nonlinearity_gain = ini.calculate_gain(nonlinearity, a)
    std = gain * nonlinearity_gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)


def bias_init(m, **kargs):
    """Set ``m.bias`` to a constant, in place, when ``m`` has a bias.

    Objects without a ``bias`` attribute, or whose bias is ``None``, are left
    untouched. If ``kargs['bias_correction']`` is present and truthy, the bias
    is filled with ``(1 - k) * thres / (lam * 2)`` using ``kargs['k']``,
    ``kargs['thres']`` and ``kargs['lam']`` (``KeyError`` if any is missing).
    Otherwise it is filled with 0.
    """
    bias = getattr(m, 'bias', None)
    if bias is None:
        return

    if kargs.get('bias_correction'):
        fill_value = (1 - kargs['k']) * kargs['thres'] / (kargs['lam'] * 2)
    else:
        fill_value = 0
    nn.init.constant_(bias.data, fill_value)


class init(object):
    """Collection of weight/bias initialisers for layer modules.

    Each routine is a ``staticmethod`` called as ``init.<name>(m, **kargs)``.
    Except for ``none_init``, it fills ``m.weight`` in place and then passes
    ``m`` and ``kargs`` on to ``bias_init`` to set the bias. Routines that
    index ``kargs`` directly raise ``KeyError`` when the key is missing.
    """

    @staticmethod
    def none_init(m, **kargs):
        """Do nothing: keep whatever values ``m`` already holds."""
        return None

    # ---- Glorot (Xavier) initialisation ----
    @staticmethod
    def xavier_uniform(m, **kargs):
        """Glorot-uniform weights; bias handled by ``bias_init``."""
        weight = m.weight.data
        nn.init.xavier_uniform_(weight)
        bias_init(m, **kargs)

    @staticmethod
    def xavier_normal(m, **kargs):
        """Glorot-normal weights; bias handled by ``bias_init``."""
        weight = m.weight.data
        nn.init.xavier_normal_(weight)
        bias_init(m, **kargs)

    # ---- He (Kaiming) initialisation ----
    @staticmethod
    def he_normal(m, **kargs):
        """Kaiming-normal weights with ``a=0``; bias handled by ``bias_init``."""
        weight = m.weight.data
        nn.init.kaiming_normal_(weight, a=0)
        bias_init(m, **kargs)

    @staticmethod
    def he_uniform(m, **kargs):
        """Kaiming-uniform weights with ``a=0``; bias handled by ``bias_init``."""
        weight = m.weight.data
        nn.init.kaiming_uniform_(weight, a=0)
        bias_init(m, **kargs)

    # ---- SNN initialisation: no active routines ----

    # ---- Wu initialisation ----
    @staticmethod
    def wu_uniform(m, **kargs):
        """Draw weights from U(-1, 1), then L2-normalise them along dim 1.

        The normalised tensor replaces ``m.weight.data``; bias handled by
        ``bias_init``.
        """
        weight = m.weight.data
        nn.init.uniform_(weight, a=-1, b=1)
        m.weight.data = F.normalize(weight, dim=1)
        bias_init(m, **kargs)

    # ---- Lee initialisation ----
    @staticmethod
    def lee_uniform(m, **kargs):
        """Draw weights from U(-b, b) with ``b = sqrt(3 / fan_in)``.

        Bias handled by ``bias_init``.
        """
        weight = m.weight.data
        fan_in, _ = ini._calculate_fan_in_and_fan_out(weight)
        bound = math.sqrt(3.0 / float(fan_in))
        nn.init.uniform_(weight, a=-bound, b=bound)
        bias_init(m, **kargs)

    # ---- Asymptotic initialisation ----
    @staticmethod
    def asymptote_uniform(m, **kargs):
        """Kaiming-normal weights scaled by ``kargs['thres'] / kargs['lam']``.

        Uses the module-level ``kaiming_normal_`` with ``gain=thres / lam``.
        Bias handled by ``bias_init``.

        NOTE(review): despite its name this draws from a *normal*
        distribution and is identical to ``asymptote_normal``; confirm
        whether a uniform draw was intended.
        """
        weight = m.weight.data
        scale = kargs['thres'] / kargs['lam']
        kaiming_normal_(weight, gain=scale)
        bias_init(m, **kargs)

    @staticmethod
    def asymptote_normal(m, **kargs):
        """Kaiming-normal weights scaled by ``kargs['thres'] / kargs['lam']``.

        Uses the module-level ``kaiming_normal_`` with ``gain=thres / lam``.
        Bias handled by ``bias_init``.
        """
        weight = m.weight.data
        scale = kargs['thres'] / kargs['lam']
        kaiming_normal_(weight, gain=scale)
        bias_init(m, **kargs)


# History record: an earlier version of this module, kept commented out for reference only (never executed).
# import torch
# import torch.nn as nn
# import torch.nn.init as ini
# import math
# import torch.nn.functional as F
#
# def _sparse(tensor, sparsity):
#     tensor_shape = tensor.shape
#     if tensor.ndimension() != 2:
#         tensor = tensor.view(tensor_shape[0],-1)
#     rows, cols = tensor.shape
#     num_zeros = int(math.ceil(sparsity * rows))
#
#     with torch.no_grad():
#         for col_idx in range(cols):
#             row_indices = torch.randperm(rows)
#             zero_indices = row_indices[:num_zeros]
#             tensor[zero_indices, col_idx] = 0
#     return tensor.view(tensor_shape)
#
#
# def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', gain=1.0):
#     fan = ini._calculate_correct_fan(tensor, mode)
#     std = gain * math.sqrt(2.0) / math.sqrt(fan)
#     with torch.no_grad():
#         return tensor.normal_(0, std)
#
# class init(object):
#     @staticmethod
#     def none_init(m, **kargs):
#         pass
#
#     @staticmethod
#     def asymptote_uniform(m, **kargs):
#         kaiming_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, (1 - kargs['k']) * kargs['thres'] / (kargs['lam'] * 2))
#
#
#     @staticmethod
#     def asymptote_normal(m, **kargs):
#         kaiming_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, (1 - kargs['k']) * kargs['thres'] / (kargs['lam'] * 2))
#
#     @staticmethod
#     def asymptote_sparse_normal1(m, **kargs):
#         kaiming_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         m.weight.data = _sparse(m.weight.data, sparsity=kargs['sparsity'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, (1 - kargs['k']) * kargs['thres'] / (kargs['lam'] * 2))
#
#     @staticmethod
#     def asymptote_sparse_normal2(m, **kargs):
#         kaiming_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         m.weight.data = _sparse(m.weight.data, sparsity=kargs['sparsity'])
#         m.weight.data /= kargs['sparsity']
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, (1 - kargs['k']) * kargs['thres'] / (kargs['lam'] * 2))
#
#     @staticmethod
#     def xavier_normal(m, **kargs):
#         nn.init.xavier_normal_(m.weight.data)
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def xavier_uniform(m, **kargs):
#         nn.init.xavier_uniform_(m.weight.data)
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def he_uniform(m, **kargs):
#         nn.init.kaiming_uniform_(m.weight.data, a=0)  # a=0 means relu, a!=0 means prelu
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def wu_uniform(m, **kargs):
#         nn.init.uniform_(m.weight.data, a=-1,b=1)
#         m.weight.data = F.normalize(m.weight.data,dim=1)
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def lee_uniform(m, **kargs):
#         fan_in, fan_out = ini._calculate_fan_in_and_fan_out(m.weight.data)
#         std = math.sqrt(3.0 / float(fan_in))
#         nn.init.uniform_(m.weight.data, a=-std, b=std)
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def he_normal(m, **kargs):
#         nn.init.kaiming_normal_(m.weight.data, a=0)  # a=0 means relu, a!=0 means prelu
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def snn_init_1_normal(m, **kargs):
#         nn.init.xavier_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, kargs['thres'])
#
#     @staticmethod
#     def snn_init_1pb_normal_he(m, **kargs):
#         kaiming_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, (1 - kargs['k']) / kargs['lam'] * kargs['thres'] * kargs['p'])
#
#     @staticmethod
#     def snn_init_1_uniform(m, **kargs):
#         nn.init.xavier_uniform_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, kargs['thres'])
#
#     @staticmethod
#     def snn_init_1pb_uniform(m, **kargs):
#         nn.init.xavier_uniform_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, kargs['thres'] * kargs['p'])
#
#     @staticmethod
#     def asymptote_uniform(m, **kargs):
#         kaiming_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, (1 - kargs['k']) * kargs['thres'] / (kargs['lam'] * 2) )
#
#     @staticmethod
#     def snn_init_1pw_uniform(m, **kargs):
#         nn.init.xavier_uniform_(m.weight.data, gain=kargs['thres'] / kargs['lam'] * kargs['p'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, kargs['thres'])
#
#     @staticmethod
#     def snn_init_2_normal(m, **kargs):
#         nn.init.xavier_normal_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def snn_init_2_uniform(m, **kargs):
#         nn.init.xavier_uniform_(m.weight.data, gain=kargs['thres'] / kargs['lam'])
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#
#     @staticmethod
#     def snn_init_3_normal(m, **kargs):
#         nn.init.xavier_normal_(m.weight.data)
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, kargs['thres'])
#
#     @staticmethod
#     def snn_init_3_uniform(m, **kargs):
#         nn.init.xavier_uniform_(m.weight.data)
#         if hasattr(m, 'bias') and m.bias is not None:
#             nn.init.constant_(m.bias.data, kargs['thres'])