from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch
import torch.nn as nn
import pdb


class RBN(_BatchNorm):
    """Batch normalization with instance-statistic centering and scaling.

    Wraps ``target`` (a BatchNorm-style module exposing ``num_features``,
    ``running_mean`` and ``running_var``) and adds two per-channel
    calibrations driven by the per-sample spatial mean ``m(.)``:

    * centering, before normalization: ``x <- x + center_weight * m(x)``;
    * scaling, after normalization:
      ``y <- sigmoid(scale_weight * m(y) + scale_bias) * y``.

    The running statistics are the *same tensor objects* as ``target``'s,
    so updating them here (in training mode) also updates ``target``.
    The affine weight/bias are initialised to the identity (1 / 0); they
    are not copied from ``target``.  The calibration parameters start at
    ``center_weight = 0``, ``scale_weight = 0`` and ``scale_bias = 1``.

    Args:
        target: module whose ``num_features`` and running statistics are
            reused.
        eps: value added to the variance for numerical stability.
        momentum: running-statistics momentum; ``None`` selects a
            cumulative moving average (``1 / num_batches_tracked``).
        affine: if True, apply a learnable per-channel weight and bias.

    Input must be 4D (N, C, H, W); a ``ValueError`` is raised otherwise.
    """

    def __init__(self, target, eps=1e-5, momentum=0.1, affine=True):
        num_features = target.num_features
        # Running statistics are always tracked (they are shared with target).
        track_running_stats = True
        super(RBN, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        # Share (not copy) the target's running statistics.
        self.running_mean = target.running_mean
        self.running_var = target.running_var

        if self.affine:
            # Shaped (1, C, 1, 1) so they broadcast directly against NCHW.
            # Initialised to the identity transform rather than to
            # target.weight / target.bias.
            self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))

        # Kept for backward compatibility with code that reads ``self.N``.
        self.N = num_features

        # Calibration parameters; the initial values make centering a no-op.
        self.center_weight = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale_weight = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale_bias = nn.Parameter(torch.ones(1, num_features, 1, 1))
        # Global average pool: per-sample, per-channel spatial mean -> (N, C, 1, 1).
        self.stas = nn.AdaptiveAvgPool2d((1, 1))

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is a 4D (N, C, H, W) tensor."""
        if input.dim() != 4:
            raise ValueError(
                "expected 4D input (got {}D input)".format(input.dim()))

    def forward(self, input):
        """Center, batch-normalize, then scale ``input`` (N, C, H, W)."""
        self._check_input_dim(input)
        # Out-of-place on purpose: an in-place ``+=`` would mutate the
        # caller's tensor and breaks autograd on leaf / saved tensors.
        input = input + self.center_weight * self.stas(input)

        exponential_average_factor = 0.0 if self.momentum is None else self.momentum
        if self.training and self.track_running_stats and self.num_batches_tracked is not None:
            self.num_batches_tracked.add_(1)
            if self.momentum is None:
                # Cumulative moving average over all batches seen so far.
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)

        # Affine weight/bias are applied below together with the scale
        # calibration, so none are passed to batch_norm here.
        output = F.batch_norm(
            input, self.running_mean, self.running_var, None, None,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

        scale_factor = torch.sigmoid(self.scale_weight * self.stas(output) + self.scale_bias)
        if self.affine:
            return self.weight * scale_factor * output + self.bias
        return scale_factor * output

