
# This code is implemented based on https://github.com/dchiji/sparse_module/blob/main/sparse_module.py

import torch
import torch.nn as nn
import math
import pprint


class SparseModule(nn.Module):
    """Wrap a module so that selected parameters are masked on every forward.

    For every parameter of ``module`` whose qualified name
    (``<module name>.<parameter name>``) contains none of the substrings in
    ``ignore_params_keywords``, the constructor

    * re-initializes the parameter with ``init_mode`` and removes it from the
      owning sub-module's ``_parameters``;
    * stores the dense values as the buffer ``<name>_before_pruned`` (wrapped
      in ``nn.Parameter`` when ``meta_ticket_mode`` is true);
    * registers a ``<name>_score`` parameter of the same shape and a
      ``<name>_threshold`` buffer holding the ``init_sparsity * 100``
      percentile of the initial scores;
    * registers a one-element ``<name>_scale_delta`` parameter when
      ``learnable_scale`` is truthy.

    ``forward`` assigns ``mask * <name>_before_pruned`` back as ``<name>`` on
    the owning sub-module and then runs the wrapped module.

    ``parameters()`` / ``named_parameters()`` are intentionally disabled; use
    ``meta_parameters()`` (scores, scale deltas) or ``inner_parameters()``
    (dense weights and ignored parameters) instead.

    Args:
        module: The ``nn.Module`` to wrap.
        init_sparsity: Fraction used to pick the score threshold; scores
            below the threshold are masked out. Must not be ``None``.
        init_mode: Initialization name understood by ``init_param_``.
        rerand_freq: ``rerandomize()`` acts on every ``rerand_freq``-th call.
            ``None`` (or ``rerand_rate=None``) disables it.
        rerand_rate: Rate ``r`` passed to ``rerandomize_``.
        meta_ticket_mode: Store dense weights as ``nn.Parameter`` objects
            instead of plain tensors.
        ignore_params_keywords: List of substrings; parameters whose
            qualified name contains any of them are left untouched. Required:
            passing ``None`` prints the available names and raises.
        learnable_scale: ``False``, ``'naive'`` or ``'normalized'``.
        scale_delta_coeff: Multiplier applied to the scale delta.
        debug: Print which parameters are sparsified / ignored.

    Raises:
        Exception: If ``ignore_params_keywords`` is ``None``.
    """

    def __init__(self,
                 module,
                 init_sparsity,
                 init_mode="kaiming_uniform",
                 rerand_freq=None,
                 rerand_rate=None,
                 meta_ticket_mode=False,
                 ignore_params_keywords=None,
                 learnable_scale=False,
                 scale_delta_coeff=1.0,
                 debug=True):
        super().__init__()

        self._module = module
        self.init_mode = init_mode
        self.ignore_params_keywords = ignore_params_keywords
        self.meta_ticket_mode = meta_ticket_mode
        assert init_sparsity is not None
        self.init_sparsity = init_sparsity
        self.learnable_scale = learnable_scale
        self.scale_delta_coeff = scale_delta_coeff

        self.rerand_freq = rerand_freq
        self.rerand_rate = rerand_rate
        self.rerand_called_count = 0

        # Helper tensors keyed by `_get_param_name(m_name, p_name)`. They are
        # plain dict entries (not registered buffers), so `to` moves them
        # explicitly.
        self.ones = dict()
        self.zeros = dict()
        self.param_twins = dict()

        pp = pprint.PrettyPrinter()

        if self.ignore_params_keywords is None:
            # Help the caller pick keywords by listing every parameter name.
            param_names = []
            for m_name, m in self._module.named_modules():
                param_names += [m_name + '.' + p_name
                                for p_name, _ in m.named_parameters(recurse=False)]
            pp.pprint('self._module.named_modules():')
            pp.pprint(param_names)
            raise Exception('Specify ignore_params_keywords list from the above parameters.')

        # Partition all parameters into (module name, parameter name) pairs.
        self.all_sparse_params = []
        self.all_ignore_params = []
        for m_name, m in self._module.named_modules():
            for p_name, _ in m.named_parameters(recurse=False):
                full_name = m_name + '.' + p_name
                if any(pat in full_name for pat in self.ignore_params_keywords):
                    self.all_ignore_params.append((m_name, p_name))
                else:
                    self.all_sparse_params.append((m_name, p_name))
        if debug:
            print('sparse params:')
            pp.pprint([m_name + '.' + p_name for m_name, p_name in self.all_sparse_params])
            print('ignored params:')
            pp.pprint([m_name + '.' + p_name for m_name, p_name in self.all_ignore_params])

        for m_name, p_name in self.all_sparse_params:
            m = self._get_module(m_name)
            weight = getattr(m, p_name)
            # NOTE(review): the kaiming-based modes compute fan-in, which
            # presumably requires tensors with at least 2 dimensions, so 1-D
            # parameters (e.g. biases) likely need to be ignored - confirm.
            self.init_param_(weight, init_mode=self.init_mode)
            # From now on `<p_name>` is recomputed in `forward`.
            del m._parameters[p_name]

            if self.meta_ticket_mode:
                m.register_buffer(p_name + '_before_pruned', nn.Parameter(weight))
                # Kept from the original code. retain_grad() is documented as
                # a no-op for leaf tensors, so this only matters if the
                # registered tensor turns out to be a non-leaf.
                getattr(m, p_name + '_before_pruned').retain_grad()
            else:
                m.register_buffer(p_name + '_before_pruned', weight.data)

            score = nn.Parameter(torch.ones(weight.size()))
            self.init_param_(score, init_mode='kaiming_uniform')
            m.register_parameter(p_name + '_score', score)
            threshold = torch.tensor(
                [percentile(score, self.init_sparsity * 100)],
                dtype=torch.float32)
            m.register_buffer(p_name + '_threshold', threshold)
            if self.learnable_scale:
                scale_delta = nn.Parameter(torch.zeros([1]))
                m.register_parameter(p_name + '_scale_delta', scale_delta)

            param_name = self._get_param_name(m_name, p_name)
            self.ones[param_name] = torch.ones(weight.size())
            self.zeros[param_name] = torch.zeros(weight.size())
            self.param_twins[param_name] = torch.zeros(weight.size())

    def _get_module(self, name):
        """Return the sub-module of the wrapped module addressed by ``name``.

        Args:
            name: Dotted path as produced by ``named_modules()``. The empty
                string addresses the wrapped module itself.

        Returns:
            The addressed ``nn.Module``.
        """
        m = self._module
        if name == '':
            # `named_modules()` reports the root module under the name '';
            # `getattr(m, '')` would raise, so resolve it directly.
            return m

        for a in name.split('.'):
            if isinstance(m, nn.Sequential):
                assert a.isdigit()
                m = m[int(a)]
            else:
                m = getattr(m, a)
        return m

    def to(self, device):
        """Move the module and its helper tensors to ``device``.

        Args:
            device: Target passed to ``Tensor.to`` / ``nn.Module.to``.

        Returns:
            The result of ``nn.Module.to(device)``.
        """
        for m_name, p_name in self.all_sparse_params:
            m = self._get_module(m_name)
            param = getattr(m, p_name + '_before_pruned')
            # If `.to` has to copy, the result is a new tensor produced by an
            # autograd op rather than the original leaf.
            param = param.to(device)
            if param.requires_grad:
                # Ask autograd to keep `.grad` populated on that tensor.
                param.retain_grad()
            setattr(m, p_name + '_before_pruned', param)

            param_name = self._get_param_name(m_name, p_name)
            self.zeros[param_name] = self.zeros[param_name].to(device)
            self.ones[param_name] = self.ones[param_name].to(device)
            self.param_twins[param_name] = self.param_twins[param_name].to(device)

        return super().to(device)

    def _get_param_name(self, m_name, p_name):
        """Return the key ``'_module[.<m_name>].<p_name>'`` for a parameter."""
        if m_name == '':
            return '_module.' + p_name
        return '_module.' + m_name + '.' + p_name

    def named_inner_parameters(self):
        """List ``(name, tensor)`` for dense weights and ignored parameters.

        Sparse parameters are reported through their ``*_before_pruned``
        tensor; ignored parameters are reported as they are.
        """
        ret = []
        for m_name, p_name in self.all_sparse_params:
            m = self._get_module(m_name)
            p_name_bp = p_name + '_before_pruned'
            ret.append((self._get_param_name(m_name, p_name_bp), getattr(m, p_name_bp)))
        for m_name, p_name in self.all_ignore_params:
            m = self._get_module(m_name)
            ret.append((self._get_param_name(m_name, p_name), getattr(m, p_name)))
        return ret

    def inner_parameters(self):
        """Return the tensors of ``named_inner_parameters()`` without names."""
        return [p for _, p in self.named_inner_parameters()]

    def named_meta_parameters(self):
        """List ``(name, parameter)`` for scores and, if enabled, scale deltas."""
        ret = []
        for m_name, p_name in self.all_sparse_params:
            m = self._get_module(m_name)
            p_name_score = p_name + '_score'
            ret.append((self._get_param_name(m_name, p_name_score),
                        getattr(m, p_name_score)))
            if self.learnable_scale:
                p_name_scale_delta = p_name + '_scale_delta'
                ret.append((self._get_param_name(m_name, p_name_scale_delta),
                            getattr(m, p_name_scale_delta)))
        return ret

    def meta_parameters(self):
        """Return the parameters of ``named_meta_parameters()`` without names."""
        return [p for _, p in self.named_meta_parameters()]

    def named_parameters(self, *args, **kwargs):
        """Disabled on purpose; always raises.

        Accepts (and ignores) the usual ``nn.Module.named_parameters``
        arguments so that such calls get the guidance message below instead
        of a ``TypeError``.

        Raises:
            Exception: Always.
        """
        raise Exception('[Error:SparseModule] call named_meta_parameters() or named_inner_parameters() method.')

    def parameters(self, *args, **kwargs):
        """Disabled on purpose; always raises.

        Accepts (and ignores) the usual ``nn.Module.parameters`` arguments so
        that such calls get the guidance message below instead of a
        ``TypeError``.

        Raises:
            Exception: Always.
        """
        raise Exception('[Error:SparseModule] call meta_parameters() or inner_parameters() method.')

    def _get_mask(self, m, m_name, p_name):
        """Build the mask for parameter ``p_name`` of sub-module ``m``.

        The mask comes from ``GetBinaryMask.apply`` on the parameter's score
        and threshold. With ``learnable_scale`` set it is additionally
        multiplied by ``1 + scale_delta * scale_delta_coeff`` (``'naive'``)
        or by ``1 + scale_delta * scale_delta_coeff / numel``
        (``'normalized'``).

        Raises:
            NotImplementedError: For any other truthy ``learnable_scale``.
        """
        score = m._parameters[p_name + '_score']
        param_name = self._get_param_name(m_name, p_name)
        zeros = self.zeros[param_name]
        ones = self.ones[param_name]
        threshold = m._buffers[p_name + '_threshold'].item()

        mask = GetBinaryMask.apply(score, threshold, zeros, ones)
        if self.learnable_scale:
            scale_delta = m._parameters[p_name + '_scale_delta']
            assert isinstance(self.learnable_scale, (str, float))
            if self.learnable_scale == 'naive':
                mask = mask * (1.0 + (scale_delta * self.scale_delta_coeff))
            elif self.learnable_scale == 'normalized':
                mask = mask * (1.0 + (scale_delta * self.scale_delta_coeff / mask.numel()))
            else:
                raise NotImplementedError()
        return mask

    def forward(self, *args, **kwargs):
        """Refresh every masked weight, then run the wrapped module's forward."""
        for m_name, p_name in self.all_sparse_params:
            m = self._get_module(m_name)
            weight = getattr(m, p_name + '_before_pruned')
            mask = self._get_mask(m, m_name, p_name)
            setattr(m, p_name, mask * weight)

        if isinstance(self._module, torch.nn.RNNBase):
            self._module.flatten_parameters()

        return self._module.forward(*args, **kwargs)

    def init_param_(self, param, init_mode=None, scale=1.0, sign=None):
        """Initialize ``param`` in place according to ``init_mode``.

        Supported modes:

        * ``'kaiming_normal'`` / ``'asymptotic_kn'``: Kaiming normal
          (fan-in, ReLU) times ``scale``.
        * ``'positive_kaiming_normal'``: as above, absolute values.
        * ``'uniform(-1,1)'``: uniform in ``[-1, 1]`` times ``scale``.
        * ``'kaiming_uniform'``: Kaiming uniform (fan-in, ReLU) times
          ``scale``.
        * ``'positive_kaiming_uniform'``: as above, absolute values.
        * ``'keep_sign'``: Kaiming-uniform magnitudes carrying ``sign``
          (defaults to the sign of the current values).
        * ``'positive_kaiming_constant'``: constant ``gain / sqrt(fan_in)``;
          ``scale`` is not applied in this mode.
        * ``'signed_constant'``: ``+/- gain / sqrt(fan_in)`` with random
          signs, times ``scale``.

        Args:
            param: Tensor or parameter to overwrite.
            init_mode: One of the names above.
            scale: Multiplier applied after initialization.
            sign: Sign tensor used by ``'keep_sign'`` only.

        Raises:
            NotImplementedError: For an unknown ``init_mode``.
        """
        if init_mode == 'kaiming_normal' or init_mode == 'asymptotic_kn':
            nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="relu")
            param.data *= scale
        elif init_mode == 'positive_kaiming_normal':
            nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="relu")
            # x * sign(x) == |x|
            param.data *= scale * param.data.sign()
        elif init_mode == 'uniform(-1,1)':
            nn.init.uniform_(param, a=-1, b=1)
            param.data *= scale
        elif init_mode == 'kaiming_uniform':
            nn.init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu')
            param.data *= scale
        elif init_mode == 'positive_kaiming_uniform':
            nn.init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu')
            param.data *= scale * param.data.sign()
        elif init_mode == 'keep_sign':
            if sign is None:
                # Must be captured before the values are overwritten below.
                sign = param.data.sign()
            nn.init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu')
            param.data *= scale * param.data.sign() * sign
        elif init_mode == 'positive_kaiming_constant':
            fan = nn.init._calculate_correct_fan(param, 'fan_in')
            gain = nn.init.calculate_gain('relu')
            std = gain / math.sqrt(fan)
            nn.init.constant_(param, std)
        elif init_mode == 'signed_constant':
            fan = nn.init._calculate_correct_fan(param, 'fan_in')
            gain = nn.init.calculate_gain('relu')
            std = gain / math.sqrt(fan)
            nn.init.kaiming_normal_(param)    # use only its sign
            param.data = param.data.sign() * std
            param.data *= scale
        else:
            raise NotImplementedError

    def reset(self, const=None, kaiming=False, rate=1.0, init_mode=None):
        """Reset the dense (``*_before_pruned``) weights of all sparse params.

        Args:
            const: If given, every weight is set to this value.
            kaiming: With ``const``, additionally multiply by
                ``gain / sqrt(fan_in)`` (ReLU gain).
            rate: Without ``const``, each element is kept with probability
                ``1 - rate`` and re-drawn otherwise.
            init_mode: Initialization used for re-drawn elements; defaults to
                the mode given to the constructor.
        """
        if init_mode is None:
            init_mode = self.init_mode

        for m_name, p_name in self.all_sparse_params:
            param_name = self._get_param_name(m_name, p_name)

            m = self._get_module(m_name)
            weight = getattr(m, p_name + '_before_pruned')

            if const is not None:
                # Overwrite rather than multiply by zero: `nan * 0` and
                # `inf * 0` are NaN, so diverged weights would never reset.
                weight.data.copy_(self.ones[param_name].data * const)
                if kaiming:
                    fan = nn.init._calculate_correct_fan(weight, 'fan_in')
                    gain = nn.init.calculate_gain('relu')
                    std = gain / math.sqrt(fan)
                    weight.data *= std
            else:
                weight_twin = self.param_twins[param_name]
                ones = self.ones[param_name]
                # Keep-mask: 1 keeps the element, 0 lets it be re-drawn.
                b = torch.bernoulli(ones * (1.-rate))
                self.rerandomize_(weight, b,
                                  r=1.0,
                                  init_mode=init_mode,
                                  param_twin=weight_twin,
                                  ones=ones)

    def rerandomize(self):
        """Count a call and, every ``rerand_freq`` calls, rerandomize weights.

        Does nothing beyond counting unless both ``rerand_freq`` and
        ``rerand_rate`` were given. When it fires, each sparse parameter's
        dense weight is passed to ``rerandomize_`` together with its current
        mask and ``r=rerand_rate``.
        """
        self.rerand_called_count += 1
        if self.rerand_freq is None or self.rerand_rate is None:
            return
        if self.rerand_called_count % self.rerand_freq != 0:
            return

        for m_name, p_name in self.all_sparse_params:
            m = self._get_module(m_name)
            weight = getattr(m, p_name + '_before_pruned')
            mask = self._get_mask(m, m_name, p_name)

            param_name = self._get_param_name(m_name, p_name)
            weight_twin = self.param_twins[param_name]
            ones = self.ones[param_name]

            self.rerandomize_(weight, mask,
                              r=self.rerand_rate,
                              init_mode=self.init_mode,
                              param_twin=weight_twin,
                              ones=ones)
            setattr(m, p_name + '_before_pruned', weight)

    def rerandomize_(self, param, mask,
                     r=None, init_mode=None, param_twin=None, ones=None,
                     scale=1.0):
        """Re-draw the part of ``param`` that ``mask`` does not cover.

        ``param_twin`` is re-initialized with ``init_mode`` and used as the
        source of fresh values. ``mask * param`` is always kept. For the
        remaining ``(1 - mask)`` part:

        * ``'asymptotic_kn'``: ``sqrt(1 - r^2) * param + r * fresh``;
        * otherwise: each element takes the fresh value with probability
          ``r`` and keeps its old value otherwise.

        Args:
            param: Tensor whose ``.data`` is replaced.
            mask: Tensor broadcastable against ``param``.
            r: Rerandomization rate; must not be ``None``.
            init_mode: Initialization name for the fresh values.
            param_twin: Scratch tensor with the shape of ``param``.
            ones: Tensor of ones with the shape of ``param``.
            scale: Multiplier forwarded to ``init_param_``.

        Raises:
            NotImplementedError: If ``param_twin`` or ``ones`` is ``None``.
        """
        if param_twin is None or ones is None:
            raise NotImplementedError

        with torch.no_grad():
            assert r is not None
            rnd = param_twin
            self.init_param_(rnd, init_mode=init_mode, scale=scale, sign=param.data.sign())
            kept = param.data * mask
            if init_mode == 'asymptotic_kn':
                old_part = param.data * (1 - mask) * math.sqrt(1 - (r*r))
                new_part = rnd.data * (1 - mask) * r
            else:
                b = torch.bernoulli(ones * r)
                old_part = param.data * (1 - mask) * (1 - b)
                new_part = rnd.data * (1 - mask) * b
            param.data = kept + old_part + new_part


class GetBinaryMask(torch.autograd.Function):
    """Turn scores into a mask by thresholding; gradients pass through as-is."""

    @staticmethod
    def forward(ctx, scores, threshold, zeros, ones):
        """Pick ``zeros`` where ``scores < threshold`` and ``ones`` elsewhere."""
        below_threshold = torch.lt(scores, threshold)
        return torch.where(below_threshold, zeros, ones)

    @staticmethod
    def backward(ctx, g):
        """Hand the incoming gradient to ``scores`` unchanged.

        The other three forward inputs receive no gradient.
        """
        return (g,) + (None,) * 3

def percentile(t, q):
    """Return the ``q``-th percentile of ``t`` as a Python number.

    The tensor is flattened and the ``k``-th smallest element is returned,
    where ``k = 1 + round(q / 100 * (numel - 1))`` (Python's ``round``, so
    exact ties go to the even integer). No interpolation is performed.

    Args:
        t: Non-empty tensor of any shape; it does not need to be contiguous.
        q: Percentile, expected in ``[0, 100]``.

    Returns:
        The selected element, converted with ``.item()``.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed weights.
    flat = t.reshape(-1)
    k = 1 + round(.01 * float(q) * (flat.numel() - 1))
    return flat.kthvalue(k).values.item()


