import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.distributions.dirichlet import Dirichlet


class Linear(nn.Linear):
    """``nn.Linear`` whose weight and bias are gated by fixed mask buffers.

    ``weight_mask`` (and ``bias_mask`` when the layer has a bias) are created
    as all-ones tensors and registered as buffers rather than parameters, so
    they are not returned by ``parameters()``. The forward pass uses the
    element-wise products ``weight_mask * weight`` and ``bias_mask * bias``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        masked_weight = self.weight_mask * self.weight
        masked_bias = None if self.bias is None else self.bias_mask * self.bias
        return F.linear(input, masked_weight, masked_bias)

class Linear_simplex(nn.Linear):
    """Fully connected layer whose parameters live on a simplex of ``num_models`` vertices.

    Each vertex owns its own weight (and bias) tensor, stored in the
    ``weight`` / ``bias`` ParameterLists, plus a 0/1 mask stored as one slice
    of the stacked ``weight_mask`` / ``bias_mask`` buffers. ``forward``
    evaluates the layer at a single point of the simplex selected by its
    ``vertex`` / ``alpha`` / ``fixed`` arguments.

    Args:
        in_features, out_features, bias: as for :class:`torch.nn.Linear`.
        num_models: number of simplex vertices. Every vertex starts as a copy
            of the freshly initialised ``nn.Linear`` parameters.
    """

    def __init__(self, in_features, out_features, bias=True, num_models=3):
        super(Linear_simplex, self).__init__(in_features, out_features, bias)
        self.num_models = num_models

        # Move the single parameter created by nn.Linear into a ParameterList
        # and append num_models - 1 independent copies of it.
        self.weight = nn.ParameterList([self._parameters.pop("weight", None)])
        if self.bias is not None:
            self.bias = nn.ParameterList([self._parameters.pop("bias", None)])
        for _ in range(self.num_models - 1):
            self.weight.append(nn.Parameter(self.weight[-1].detach().clone()))
            if self.bias is not None:
                self.bias.append(nn.Parameter(self.bias[-1].detach().clone()))

        # One all-ones mask per vertex, stacked along a new leading dimension.
        self.register_buffer('weight_mask', torch.stack(
            [torch.ones(self.weight[0].shape) for _ in range(self.num_models)]))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.stack(
                [torch.ones(self.bias[0].shape) for _ in range(self.num_models)]))

    def _sample_parameters(self, alpha=None, center=False, fixed=False):
        """Return ``(w, b)`` at one point of the simplex.

        Args:
            alpha: one value per vertex; defaults to all ones. It is the
                Dirichlet concentration when sampling, or the literal mixing
                coefficients (not normalised) when ``fixed`` is True.
            center: if True, use the barycentre (every coefficient
                ``1 / num_models``); takes precedence over ``fixed``.
            fixed: if True, use ``alpha`` directly as the coefficients.

        Returns:
            ``w = sum_i a_i * weight[i] * weight_mask[i]`` and ``b`` built the
            same way from the bias (``None`` when the layer has no bias).
        """
        if alpha is None:
            # One concentration per vertex, so the default works for any num_models.
            alpha = [1.0] * self.num_models

        if center:
            a = torch.Tensor([1/self.num_models] * self.num_models)
        elif fixed:
            a = torch.Tensor(alpha)
        else:
            # Sample only when the sample is used: a fixed alpha may contain
            # zeros, which Dirichlet's argument validation rejects.
            a = Dirichlet(torch.Tensor(alpha)).sample()

        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None
        for i in range(self.num_models):
            w += a[i] * self.weight[i] * self.weight_mask[i]
            if b is not None:
                b += a[i] * self.bias[i] * self.bias_mask[i]
        return w, b

    def _weight_average(self):
        """Return the unmasked mean of all vertices, multiplied by the union of their masks."""
        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None

        for i in range(self.num_models):
            w += 1/self.num_models * self.weight[i]
            if b is not None:
                b += 1/self.num_models * self.bias[i]

        # An entry survives if at least one vertex keeps it.
        w *= self.weight_mask.sum(dim=0).bool().int()
        if b is not None:
            b *= self.bias_mask.sum(dim=0).bool().int()

        return w, b

    def freeze(self, vertex=0):
        """Stop gradient updates for one vertex (0-based index into the parameter lists)."""
        self.weight[vertex].requires_grad = False
        self.weight[vertex].grad = None
        if self.bias is not None:
            self.bias[vertex].requires_grad = False
            self.bias[vertex].grad = None

    def _vertex_parameters(self, index):
        """Masked ``(weight, bias)`` of the vertex at 0-based ``index``."""
        weight = self.weight[index] * self.weight_mask[index]
        bias = (self.bias[index] * self.bias_mask[index]
                if self.bias is not None else None)
        return weight, bias

    def forward(self, input, alpha=None, vertex=0, fixed=False):
        """Apply the layer at the simplex point selected by the arguments.

        Args:
            input: tensor passed to ``F.linear``.
            alpha: per-vertex values for ``_sample_parameters`` (defaults to
                all ones).
            vertex: ``-1`` -> averaged weights; ``0.5`` -> simplex centre;
                any other truthy value ``k`` -> the ``k``-th vertex (1-based);
                ``0`` -> point drawn (or fixed) via ``alpha``.
            fixed: use ``alpha`` as literal coefficients instead of sampling.
        """
        if vertex == -1:
            weight, bias = self._weight_average()
        elif vertex == 0.5:
            weight, bias = self._sample_parameters(center=True)
        elif vertex:
            weight, bias = self._vertex_parameters(vertex - 1)
        else:
            weight, bias = self._sample_parameters(alpha=alpha, fixed=fixed)
        return F.linear(input, weight, bias)

class Linear_bezier(nn.Linear):
    """Fully connected layer whose parameters lie on a Bezier curve.

    The ``num_models`` control points each own a weight (and bias) tensor in
    the ``weight`` / ``bias`` ParameterLists and a 0/1 mask slice in the
    stacked ``weight_mask`` / ``bias_mask`` buffers. A point on the curve at
    ``lambda_`` uses the Bernstein basis of degree ``num_models - 1``; with
    the default ``num_models=3`` this is the quadratic curve
    ``(1-t)^2 P0 + 2(1-t)t P1 + t^2 P2``.

    Args:
        in_features, out_features, bias: as for :class:`torch.nn.Linear`.
        num_models: number of control points. Every control point starts as
            a copy of the freshly initialised ``nn.Linear`` parameters.
    """

    def __init__(self, in_features, out_features, bias=True, num_models=3):
        super(Linear_bezier, self).__init__(in_features, out_features, bias)
        self.num_models = num_models

        # Move the single parameter created by nn.Linear into a ParameterList
        # and append num_models - 1 independent copies of it.
        self.weight = nn.ParameterList([self._parameters.pop("weight", None)])
        if self.bias is not None:
            self.bias = nn.ParameterList([self._parameters.pop("bias", None)])
        for _ in range(self.num_models - 1):
            self.weight.append(nn.Parameter(self.weight[-1].detach().clone()))
            if self.bias is not None:
                self.bias.append(nn.Parameter(self.bias[-1].detach().clone()))

        # One all-ones mask per control point, stacked along a new leading dimension.
        self.register_buffer('weight_mask', torch.stack(
            [torch.ones(self.weight[0].shape) for _ in range(self.num_models)]))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.stack(
                [torch.ones(self.bias[0].shape) for _ in range(self.num_models)]))

    def _bezier_coefficients(self, lambda_):
        """Bernstein basis values ``C(n, k) (1-t)^(n-k) t^k`` for k = 0..n, n = num_models - 1."""
        n = self.num_models - 1
        return [
            math.factorial(n) // (math.factorial(k) * math.factorial(n - k))
            * (1 - lambda_) ** (n - k) * lambda_ ** k
            for k in range(self.num_models)
        ]

    def _sample_parameters(self, lambda_ = 0.5, center=False, fixed=False):
        """Return ``(w, b)`` at curve position ``lambda_``.

        Unless ``center`` or ``fixed`` is set, ``lambda_`` is replaced by a
        draw from ``torch.rand`` (uniform on [0, 1)). Every control point's
        weight/bias is multiplied by its mask before being combined. ``b`` is
        ``None`` when the layer has no bias.
        """
        if (not center) and (not fixed):
            lambda_ = torch.rand(1).item()

        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None
        for k, coeff in enumerate(self._bezier_coefficients(lambda_)):
            w += coeff * self.weight[k] * self.weight_mask[k]
            if b is not None:
                b += coeff * self.bias[k] * self.bias_mask[k]
        return w, b

    def _weight_average(self):
        """Return the unmasked mean of all control points, multiplied by the union of their masks."""
        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None

        for i in range(self.num_models):
            w += 1/self.num_models * self.weight[i]
            if b is not None:
                b += 1/self.num_models * self.bias[i]

        # An entry survives if at least one control point keeps it.
        w *= self.weight_mask.sum(dim=0).bool().int()
        if b is not None:
            b *= self.bias_mask.sum(dim=0).bool().int()

        return w, b

    def freeze(self, vertex=0):
        """Stop gradient updates for one control point (0-based index into the parameter lists)."""
        self.weight[vertex].requires_grad = False
        self.weight[vertex].grad = None
        if self.bias is not None:
            self.bias[vertex].requires_grad = False
            self.bias[vertex].grad = None

    def forward(self, input, lambda_=0.5, vertex=0, fixed=False):
        """Apply the layer at the curve point selected by the arguments.

        Args:
            input: tensor passed to ``F.linear``.
            lambda_: curve position used when ``fixed`` is True.
            vertex: ``-1`` -> averaged weights; ``0.5`` -> curve point at
                ``lambda_ = 0.5``; any other truthy value ``k`` -> the
                ``k``-th control point (1-based); ``0`` -> random or fixed
                ``lambda_``.
            fixed: use ``lambda_`` as given instead of drawing it.
        """
        if vertex == -1:
            weight, bias = self._weight_average()
        elif vertex == 0.5:
            weight, bias = self._sample_parameters(center=True)
        elif vertex:
            weight = self.weight[vertex-1] * self.weight_mask[vertex-1]
            bias = (self.bias[vertex-1] * self.bias_mask[vertex-1]
                    if self.bias is not None else None)
        else:
            weight, bias = self._sample_parameters(lambda_=lambda_, fixed=fixed)
        return F.linear(input, weight, bias)

class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight and bias are gated by fixed mask buffers.

    ``weight_mask`` (and ``bias_mask`` when the layer has a bias) are created
    as all-ones tensors and registered as buffers. The forward pass convolves
    with ``weight_mask * weight`` and ``bias_mask * bias``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def _conv_forward(self, input, weight, bias):
        """Convolve ``input`` with explicit ``weight``/``bias``, honouring ``padding_mode``.

        For non-``'zeros'`` modes the input is padded with ``F.pad`` first
        and the convolution itself runs with zero padding.
        """
        if self.padding_mode != 'zeros':
            # F.pad lists the last dimension first, so the (H, W) padding is
            # reversed and each entry repeated. Current PyTorch precomputes
            # this as _reversed_padding_repeated_twice; the previously used
            # _padding_repeated_twice is not defined there.
            pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if pad is None:
                pad = tuple(p for p in reversed(self.padding) for _ in range(2))
            return F.conv2d(F.pad(input, pad, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        W = self.weight_mask * self.weight
        b = self.bias_mask * self.bias if self.bias is not None else None
        return self._conv_forward(input, W, b)


class Conv2d_kernel(nn.Conv2d):
    """``nn.Conv2d`` whose weight and bias are gated by fixed mask buffers.

    ``weight_mask`` (and ``bias_mask`` when the layer has a bias) are created
    as all-ones tensors and registered as buffers. The forward pass convolves
    with ``weight_mask * weight`` and ``bias_mask * bias``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        super(Conv2d_kernel, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def _conv_forward(self, input, weight, bias):
        """Convolve ``input`` with explicit ``weight``/``bias``, honouring ``padding_mode``.

        For non-``'zeros'`` modes the input is padded with ``F.pad`` first
        and the convolution itself runs with zero padding.
        """
        if self.padding_mode != 'zeros':
            # F.pad lists the last dimension first, so the (H, W) padding is
            # reversed and each entry repeated. Current PyTorch precomputes
            # this as _reversed_padding_repeated_twice; the previously used
            # _padding_repeated_twice is not defined there.
            pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if pad is None:
                pad = tuple(p for p in reversed(self.padding) for _ in range(2))
            return F.conv2d(F.pad(input, pad, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        W = self.weight_mask * self.weight
        b = self.bias_mask * self.bias if self.bias is not None else None
        return self._conv_forward(input, W, b)


### 2D simplex ###
class Conv2d_simplex(nn.Conv2d):
    """2-D convolution whose parameters live on a simplex of ``num_models`` vertices.

    Each vertex owns its own weight (and bias) tensor, stored in the
    ``weight`` / ``bias`` ParameterLists, plus a 0/1 mask stored as one slice
    of the stacked ``weight_mask`` / ``bias_mask`` buffers. ``forward``
    evaluates the convolution at a single point of the simplex selected by
    its ``vertex`` / ``alpha`` / ``fixed`` arguments.

    Args:
        in_channels ... padding_mode: as for :class:`torch.nn.Conv2d`.
        num_models: number of simplex vertices. Every vertex starts as a copy
            of the freshly initialised ``nn.Conv2d`` parameters.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_models=3):
        super(Conv2d_simplex, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.num_models = num_models

        # Move the single parameter created by nn.Conv2d into a ParameterList
        # and append num_models - 1 independent copies of it.
        self.weight = nn.ParameterList([self._parameters.pop("weight", None)])
        if self.bias is not None:
            self.bias = nn.ParameterList([self._parameters.pop("bias", None)])
        for _ in range(self.num_models - 1):
            self.weight.append(nn.Parameter(self.weight[-1].detach().clone()))
            if self.bias is not None:
                self.bias.append(nn.Parameter(self.bias[-1].detach().clone()))

        # One all-ones mask per vertex, stacked along a new leading dimension.
        self.register_buffer('weight_mask', torch.stack(
            [torch.ones(self.weight[0].shape) for _ in range(self.num_models)]))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.stack(
                [torch.ones(self.bias[0].shape) for _ in range(self.num_models)]))

    def _conv_forward(self, input, weight, bias):
        """Convolve ``input`` with explicit ``weight``/``bias``, honouring ``padding_mode``."""
        if self.padding_mode != 'zeros':
            # F.pad lists the last dimension first, so the (H, W) padding is
            # reversed and each entry repeated. Current PyTorch precomputes
            # this as _reversed_padding_repeated_twice; the previously used
            # _padding_repeated_twice is not defined there.
            pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if pad is None:
                pad = tuple(p for p in reversed(self.padding) for _ in range(2))
            return F.conv2d(F.pad(input, pad, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def _sample_parameters(self, alpha=None, center=False, fixed=False):
        """Return ``(w, b)`` at one point of the simplex.

        Args:
            alpha: one value per vertex; defaults to all ones. It is the
                Dirichlet concentration when sampling, or the literal mixing
                coefficients (not normalised) when ``fixed`` is True.
            center: if True, use the barycentre (every coefficient
                ``1 / num_models``); takes precedence over ``fixed``.
            fixed: if True, use ``alpha`` directly as the coefficients.

        Returns:
            ``w = sum_i a_i * weight[i] * weight_mask[i]`` and ``b`` built the
            same way from the bias (``None`` when the layer has no bias).
        """
        if alpha is None:
            # One concentration per vertex, so the default works for any num_models.
            alpha = [1.0] * self.num_models

        if center:
            a = torch.Tensor([1/self.num_models] * self.num_models)
        elif fixed:
            a = torch.Tensor(alpha)
        else:
            # Sample only when the sample is used: a fixed alpha may contain
            # zeros, which Dirichlet's argument validation rejects.
            a = Dirichlet(torch.Tensor(alpha)).sample()

        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None
        for i in range(self.num_models):
            w += a[i] * self.weight[i] * self.weight_mask[i]
            if b is not None:
                b += a[i] * self.bias[i] * self.bias_mask[i]
        return w, b

    def _weight_average(self):
        """Return the unmasked mean of all vertices, multiplied by the union of their masks."""
        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None

        for i in range(self.num_models):
            w += 1/self.num_models * self.weight[i]
            if b is not None:
                b += 1/self.num_models * self.bias[i]

        # An entry survives if at least one vertex keeps it.
        w *= self.weight_mask.sum(dim=0).bool().int()
        if b is not None:
            b *= self.bias_mask.sum(dim=0).bool().int()

        return w, b

    def freeze(self, vertex=0):
        """Stop gradient updates for one vertex (0-based index into the parameter lists)."""
        self.weight[vertex].requires_grad = False
        self.weight[vertex].grad = None
        if self.bias is not None:
            self.bias[vertex].requires_grad = False
            self.bias[vertex].grad = None

    def forward(self, input, alpha=None, vertex=0, fixed=False):
        """Apply the convolution at the simplex point selected by the arguments.

        Args:
            input: tensor passed to ``F.conv2d``.
            alpha: per-vertex values for ``_sample_parameters`` (defaults to
                all ones).
            vertex: ``-1`` -> averaged weights; ``0.5`` -> simplex centre;
                any other truthy value ``k`` -> the ``k``-th vertex (1-based);
                ``0`` -> point drawn (or fixed) via ``alpha``.
            fixed: use ``alpha`` as literal coefficients instead of sampling.
        """
        if vertex == -1:
            weight, bias = self._weight_average()
        elif vertex == 0.5:
            weight, bias = self._sample_parameters(center=True)
        elif vertex:
            weight = self.weight[vertex-1] * self.weight_mask[vertex-1]
            bias = (self.bias[vertex-1] * self.bias_mask[vertex-1]
                    if self.bias is not None else None)
        else:
            weight, bias = self._sample_parameters(alpha=alpha, fixed=fixed)
        return self._conv_forward(input, weight, bias)
            

### Quadratic Bezier curve ###
class Conv2d_bezier(nn.Conv2d):
    """2-D convolution whose parameters lie on a Bezier curve.

    The ``num_models`` control points each own a weight (and bias) tensor in
    the ``weight`` / ``bias`` ParameterLists and a 0/1 mask slice in the
    stacked ``weight_mask`` / ``bias_mask`` buffers. A point on the curve at
    ``lambda_`` uses the Bernstein basis of degree ``num_models - 1``; with
    the default ``num_models=3`` this is the quadratic curve
    ``(1-t)^2 P0 + 2(1-t)t P1 + t^2 P2``.

    Args:
        in_channels ... padding_mode: as for :class:`torch.nn.Conv2d`.
        num_models: number of control points. Every control point starts as
            a copy of the freshly initialised ``nn.Conv2d`` parameters.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_models=3):
        super(Conv2d_bezier, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.num_models = num_models

        # Move the single parameter created by nn.Conv2d into a ParameterList
        # and append num_models - 1 independent copies of it.
        self.weight = nn.ParameterList([self._parameters.pop("weight", None)])
        if self.bias is not None:
            self.bias = nn.ParameterList([self._parameters.pop("bias", None)])
        for _ in range(self.num_models - 1):
            self.weight.append(nn.Parameter(self.weight[-1].detach().clone()))
            if self.bias is not None:
                self.bias.append(nn.Parameter(self.bias[-1].detach().clone()))

        # One all-ones mask per control point, stacked along a new leading dimension.
        self.register_buffer('weight_mask', torch.stack(
            [torch.ones(self.weight[0].shape) for _ in range(self.num_models)]))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.stack(
                [torch.ones(self.bias[0].shape) for _ in range(self.num_models)]))

    def _conv_forward(self, input, weight, bias):
        """Convolve ``input`` with explicit ``weight``/``bias``, honouring ``padding_mode``."""
        if self.padding_mode != 'zeros':
            # F.pad lists the last dimension first, so the (H, W) padding is
            # reversed and each entry repeated. Current PyTorch precomputes
            # this as _reversed_padding_repeated_twice; the previously used
            # _padding_repeated_twice is not defined there.
            pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if pad is None:
                pad = tuple(p for p in reversed(self.padding) for _ in range(2))
            return F.conv2d(F.pad(input, pad, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def _bezier_coefficients(self, lambda_):
        """Bernstein basis values ``C(n, k) (1-t)^(n-k) t^k`` for k = 0..n, n = num_models - 1."""
        n = self.num_models - 1
        return [
            math.factorial(n) // (math.factorial(k) * math.factorial(n - k))
            * (1 - lambda_) ** (n - k) * lambda_ ** k
            for k in range(self.num_models)
        ]

    def _sample_parameters(self, lambda_ = 0.5, center=False, fixed=False):
        """Return ``(w, b)`` at curve position ``lambda_``.

        Unless ``center`` or ``fixed`` is set, ``lambda_`` is replaced by a
        draw from ``torch.rand`` (uniform on [0, 1)). Every control point's
        weight/bias is multiplied by its mask before being combined. ``b`` is
        ``None`` when the layer has no bias.
        """
        if (not center) and (not fixed):
            lambda_ = torch.rand(1).item()

        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None
        for k, coeff in enumerate(self._bezier_coefficients(lambda_)):
            w += coeff * self.weight[k] * self.weight_mask[k]
            if b is not None:
                b += coeff * self.bias[k] * self.bias_mask[k]
        return w, b

    def _weight_average(self):
        """Return the unmasked mean of all control points, multiplied by the union of their masks."""
        w = torch.zeros_like(self.weight[0])
        b = torch.zeros_like(self.bias[0]) if self.bias is not None else None

        for i in range(self.num_models):
            w += 1/self.num_models * self.weight[i]
            if b is not None:
                b += 1/self.num_models * self.bias[i]

        # An entry survives if at least one control point keeps it.
        w *= self.weight_mask.sum(dim=0).bool().int()
        if b is not None:
            b *= self.bias_mask.sum(dim=0).bool().int()

        return w, b

    def freeze(self, vertex=0):
        """Stop gradient updates for one control point (0-based index into the parameter lists)."""
        self.weight[vertex].requires_grad = False
        self.weight[vertex].grad = None
        if self.bias is not None:
            self.bias[vertex].requires_grad = False
            self.bias[vertex].grad = None

    def forward(self, input, lambda_=0.5, vertex=0, fixed=False):
        """Apply the convolution at the curve point selected by the arguments.

        Args:
            input: tensor passed to ``F.conv2d``.
            lambda_: curve position used when ``fixed`` is True.
            vertex: ``-1`` -> averaged weights; ``0.5`` -> curve point at
                ``lambda_ = 0.5``; any other truthy value ``k`` -> the
                ``k``-th control point (1-based); ``0`` -> random or fixed
                ``lambda_``.
            fixed: use ``lambda_`` as given instead of drawing it.
        """
        if vertex == -1:
            weight, bias = self._weight_average()
        elif vertex == 0.5:
            weight, bias = self._sample_parameters(center=True)
        elif vertex:
            weight = self.weight[vertex-1] * self.weight_mask[vertex-1]
            bias = (self.bias[vertex-1] * self.bias_mask[vertex-1]
                    if self.bias is not None else None)
        else:
            weight, bias = self._sample_parameters(lambda_=lambda_, fixed=fixed)
        return self._conv_forward(input, weight, bias)
            

class BatchNorm1d(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` whose affine weight and bias are gated by mask buffers.

    When ``affine`` is True, all-ones ``weight_mask`` / ``bias_mask`` buffers
    are registered and the normalisation uses ``weight_mask * weight`` and
    ``bias_mask * bias``. Batch statistics are used while training and
    whenever running statistics are not tracked.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # Kept as self.momentum when available so it can be updated in an
        # exported ONNX graph.
        factor = 0.0 if self.momentum is None else self.momentum

        tracking = self.training and self.track_running_stats
        if tracking and self.num_batches_tracked is not None:
            self.num_batches_tracked = self.num_batches_tracked + 1
            # momentum=None selects a cumulative moving average.
            if self.momentum is None:
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum

        if not self.affine:
            scale, shift = self.weight, self.bias
        else:
            scale = self.weight_mask * self.weight
            shift = self.bias_mask * self.bias

        use_batch_stats = self.training or not self.track_running_stats
        return F.batch_norm(input, self.running_mean, self.running_var,
                            scale, shift, use_batch_stats, factor, self.eps)


class BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` whose affine weight and bias are gated by mask buffers.

    When ``affine`` is True, all-ones ``weight_mask`` / ``bias_mask`` buffers
    are registered and the normalisation uses ``weight_mask * weight`` and
    ``bias_mask * bias``. Batch statistics are used while training and
    whenever running statistics are not tracked.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            # Without tracked statistics running_mean/var are None, so batch
            # statistics must be used in eval mode too (as in BatchNorm1d).
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)


class Identity1d(nn.Module):
    """Learnable per-feature scale, initialised to 1 and gated by a mask buffer.

    ``weight`` has shape ``(num_features,)``; ``forward`` returns
    ``input * (weight_mask * weight)`` using ordinary broadcasting.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset every scale factor to 1."""
        init.ones_(self.weight)

    def forward(self, input):
        return input * (self.weight_mask * self.weight)


class Identity2d(nn.Module):
    """Learnable per-channel scale, initialised to 1 and gated by a mask buffer.

    ``weight`` has shape ``(num_features, 1, 1)``. It therefore broadcasts
    over the last two dimensions of ``input`` when ``forward`` returns
    ``input * (weight_mask * weight)``.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features, 1, 1))
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset every scale factor to 1."""
        init.ones_(self.weight)

    def freeze(self):
        """Disable gradients for ``weight`` and drop any accumulated gradient."""
        self.weight.requires_grad = False
        self.weight.grad = None

    def forward(self, input):
        return input * (self.weight_mask * self.weight)



