import torch
from torch.nn import Parameter
from functools import wraps

# class WeightDrop(torch.nn.Module):
#     def __init__(self, module, weights, dropout=0, variational=False):
#         super(WeightDrop, self).__init__()
#         self.module = module
#         self.weights = weights
#         self.dropout = dropout
#         self.variational = variational
#         self._setup()

#     def _setup(self):
#         for name_w in self.weights:
#             print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
#             w = getattr(self.module, name_w)
#             del self.module._parameters[name_w]
#             self.module.register_parameter(name_w + '_raw', Parameter(w.data))

#     def _setweights(self):
#         for name_w in self.weights:
#             raw_w = getattr(self.module, name_w + '_raw')
#             w = None
#             if self.variational:
#                 mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
#                 if raw_w.is_cuda: mask = mask.cuda()
#                 mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
#                 w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
#             else:
#                 w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training))

#             setattr(self.module, name_w, w)

#     def forward(self, *args):
#         self._setweights()
#         if issubclass(type(self.module), torch.nn.RNNBase):
#             self.module.flatten_parameters()
#         return self.module.forward(*args)

class WeightDrop(torch.nn.Module):
    """Wrap a module and apply dropout to selected weights on every forward.

    Each weight named in ``weights`` is removed from the wrapped module's
    parameters and re-registered there as ``<name>_raw``. Before every forward
    call a dropped-out version of the raw weight is computed and assigned back
    onto the wrapped module under the original name as a plain tensor.

    Args:
        module: The module to wrap.
        weights: Names of parameters of ``module`` to apply dropout to,
            e.g. ``['weight_hh_l0']``.
        dropout: Dropout probability passed to
            ``torch.nn.functional.dropout``.
        variational: If true, one mask value is sampled per index of the
            weight's first dimension and expanded over the rest of the
            weight; otherwise every element is dropped independently.
    """

    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        """Do nothing; stand-in for the wrapped RNN's ``flatten_parameters``.

        It must be a function rather than a lambda as otherwise pickling
        explodes. The signature is deliberately ``*args`` only so the bound
        instance is swallowed together with any other arguments.
        """
        return

    def _setup(self):
        """Rename every targeted parameter of the wrapped module to ``<name>_raw``."""
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if isinstance(self.module, torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            # Drop the original registration so the name can later hold a
            # plain (non-Parameter) tensor; the learnable copy lives on as
            # ``<name>_raw`` and shares the original data.
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        """Recompute the dropped-out weights and assign them onto the wrapped module."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                # Create the mask directly with the weight's dtype and device
                # rather than on CPU followed by ``.cuda()``, which targets
                # the *current* CUDA device and ignores the dtype.
                mask = raw_w.new_ones((raw_w.size(0), 1))
                # Follow the module's train/eval mode so that evaluation is
                # deterministic, matching the non-variational branch.
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=self.training)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)

            # When dropout is a no-op (eval mode or p == 0) the raw Parameter
            # itself can come back. Assigning a Parameter would re-register it
            # under ``name_w`` and make the next plain-tensor assignment raise
            # a TypeError, so hand the module a non-Parameter view instead.
            if isinstance(w, Parameter):
                w = w.view_as(w)
            setattr(self.module, name_w, w)

    def forward(self, *args, **kwargs):
        """Refresh the dropped weights, then run the wrapped module's ``forward``.

        All positional and keyword arguments are passed through unchanged and
        the wrapped module's result is returned as is.
        """
        self._setweights()
        return self.module.forward(*args, **kwargs)


if __name__ == '__main__':
    # Smoke test for WeightDrop. Uses the class defined in this file directly
    # (no re-import under a fixed module name) and runs on the GPU when one is
    # available, falling back to the CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Input is (seq, batch, input)
    x = torch.randn(2, 1, 10, device=device)
    h0 = None

    ###

    print('Testing WeightDrop')
    print('=-=-=-=-=-=-=-=-=-=')

    ###

    print('Testing WeightDrop with Linear')

    lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
    lin.to(device)
    # Two forward passes in training mode draw two different dropout masks.
    run1 = [step.sum() for step in lin(x).detach()]
    run2 = [step.sum() for step in lin(x).detach()]

    print('All items should be different')
    print('Run 1:', run1)
    print('Run 2:', run2)

    assert run1[0] != run2[0]
    assert run1[1] != run2[1]

    print('---')

    ###

    print('Testing WeightDrop with LSTM')
    lstm = torch.nn.LSTM(10, 10)
    wdrnn = WeightDrop(lstm, ['weight_hh_l0'], dropout=0.9)
    wdrnn.to(device)

    run1 = [step.sum() for step in wdrnn(x, h0)[0].detach()]
    run2 = [step.sum() for step in wdrnn(x, h0)[0].detach()]

    print('First timesteps should be equal, all others should differ')
    print('Run 1:', run1)
    print('Run 2:', run2)

    # First time step, not influenced by hidden to hidden weights, should be equal
    assert run1[0] == run2[0]
    # Second step should not
    assert run1[1] != run2[1]

    print('---')
