from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
# torch.nn.BatchNorm2d     
        
class MyBatchNorm2d(nn.BatchNorm2d):
    """Batch norm whose statistics are estimated from "clean" samples only.

    Modified version of
    https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py

    ``forward`` takes an optional per-sample ``mask``. Samples with
    ``mask == 0`` are clean; every other sample is treated as masked
    (poisoned). Masked samples are normalised with the same statistics and
    affine parameters as the clean ones, but they contribute neither to the
    batch/running statistics nor to the gradients of the statistics, the
    weight or the bias.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _update_running_stats(self, mean, var, n):
        """Fold one batch estimate into the running mean/variance in place.

        Args:
            mean: per-channel batch mean.
            var: per-channel *biased* batch variance.
            n: number of values per channel that ``mean``/``var`` were
                computed from.
        """
        factor = 0.0
        if self.num_batches_tracked is not None:
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative running average
                factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential running average
                factor = self.momentum
        with torch.no_grad():
            # running_var stores the unbiased estimate; with a single value
            # the correction n / (n - 1) is undefined, so it is skipped.
            unbiased_var = var * (n / (n - 1)) if n > 1 else var
            # In-place updates keep the registered tensors (buffers, or
            # Parameters in subclasses) instead of rebinding the attributes.
            self.running_mean.copy_(
                factor * mean + (1 - factor) * self.running_mean)
            self.running_var.copy_(
                factor * unbiased_var + (1 - factor) * self.running_var)

    def forward(self, input, mask=None):
        """Normalise ``input`` using statistics of the clean samples.

        Args:
            input: 4D tensor; statistics are reduced over dims 0, 2 and 3.
            mask: optional per-sample tensor of length ``input.shape[0]``.
                Activations of samples whose mask is non-zero do not take
                part in the batch-norm statistics. ``None`` means every
                sample is clean.

        Returns:
            A new tensor with the normalised (and, if ``affine``, scaled and
            shifted) activations. ``input`` itself is left untouched.

        Raises:
            RuntimeError: if there is no clean sample to estimate statistics
                from and the layer holds no running statistics either.
        """
        self._check_input_dim(input)
        if mask is None:
            clean = torch.ones(
                input.shape[0], dtype=torch.bool, device=input.device)
        else:
            clean = torch.as_tensor(mask, device=input.device) == 0
        self.clean_num = int(clean.sum().item())

        has_running = (self.running_mean is not None
                       and self.running_var is not None)
        if self.clean_num and (self.training or not has_running):
            # Training (or no running stats available) and the batch holds
            # clean samples: compute the statistics from the clean samples.
            if self.clean_num == input.shape[0]:
                clean_input = input  # avoid copying when nothing is masked
            else:
                clean_input = input[clean]
            mean = clean_input.mean([0, 2, 3])
            var = clean_input.var([0, 2, 3], unbiased=False)
            if self.training and self.track_running_stats and has_running:
                # The sample count must match the data the variance was
                # computed from, i.e. the clean samples only.
                self._update_running_stats(
                    mean, var, clean_input.numel() / clean_input.size(1))
        elif has_running:
            mean = self.running_mean
            var = self.running_var
        else:
            raise RuntimeError(
                "MyBatchNorm2d has no clean samples in the batch and no "
                "running statistics to normalise with")

        clean_rows = clean.view(-1, 1, 1, 1)

        def for_rows(param):
            # Broadcast a per-channel tensor over the batch; masked rows get
            # a detached copy so they send no gradient into ``param``.
            param = param[None, :, None, None]
            return torch.where(clean_rows, param, param.detach())

        output = (input - for_rows(mean)) / torch.sqrt(for_rows(var) + self.eps)
        if self.affine:
            output = output * for_rows(self.weight) + for_rows(self.bias)
        return output

class LinearBatchNorm2d(MyBatchNorm2d):
    """``MyBatchNorm2d`` whose running statistics are ``nn.Parameter`` objects.

    When ``track_running_stats`` is true, ``running_mean`` and ``running_var``
    are registered as parameters instead of buffers, so they are returned by
    ``parameters()`` and can receive gradients. ``num_batches_tracked`` stays
    a buffer. When ``track_running_stats`` is false the layer keeps the
    ``None`` placeholders created by the base-class constructor.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        # The base constructor already stores the hyper-parameters, creates
        # weight/bias (or None placeholders) and registers the running-stat
        # buffers. Re-registering the None placeholders as parameters would
        # raise KeyError, so only the tracked case is touched here.
        super(LinearBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.track_running_stats:
            # Assigning a Parameter replaces the buffer of the same name.
            self.running_mean = Parameter(torch.zeros(num_features))
            self.running_var = Parameter(torch.ones(num_features))

    def reset_running_stats(self) -> None:
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        if self.track_running_stats:
            # nn.init runs under no_grad, which is required for in-place
            # writes to Parameters that require grad.
            nn.init.zeros_(self.running_mean)
            nn.init.ones_(self.running_var)
            self.num_batches_tracked.zero_()

    def reset_parameters(self) -> None:
        """Reset the running statistics and, if affine, weight to 1 and bias to 0."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)
        
class PSBatchNorm2d(nn.BatchNorm2d):
    """Previous-Statistic batch norm.

    Modified version of
    https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py

    Every forward pass stores the mean/variance it normalised with in
    ``pre_mean``/``pre_var`` (detached copies). When ``use_pre`` is set to
    true, the next forward pass normalises with those stored statistics
    instead of computing new ones, and leaves the running statistics alone.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(PSBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.pre_mean = None
        self.pre_var = None
        self.use_pre = False

    def _update_running_stats(self, mean, var, n):
        """Fold one batch estimate (biased ``var`` over ``n`` values) into the running stats."""
        factor = 0.0
        if self.num_batches_tracked is not None:
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative running average
                factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential running average
                factor = self.momentum
        with torch.no_grad():
            # running_var stores the unbiased estimate; skip the undefined
            # n / (n - 1) correction when there is a single value.
            unbiased_var = var * (n / (n - 1)) if n > 1 else var
            self.running_mean.copy_(
                factor * mean + (1 - factor) * self.running_mean)
            self.running_var.copy_(
                factor * unbiased_var + (1 - factor) * self.running_var)

    def forward(self, input, mask=None):
        """Normalise ``input`` with batch, running or previously used statistics.

        Args:
            input: 4D tensor; statistics are reduced over dims 0, 2 and 3.
            mask: accepted but not used by this layer.

        Returns:
            The normalised (and, if ``affine``, scaled and shifted) tensor.

        Raises:
            RuntimeError: if ``use_pre`` is true but no earlier forward pass
                has stored statistics yet.
        """
        self._check_input_dim(input)

        has_running = (self.running_mean is not None
                       and self.running_var is not None)
        if self.use_pre:  # use previous batch stats
            if self.pre_mean is None or self.pre_var is None:
                raise RuntimeError(
                    "use_pre is set but no previous statistics are stored; "
                    "run a forward pass with use_pre=False first")
            mean = self.pre_mean
            var = self.pre_var
        elif self.training or not has_running:
            mean = input.mean([0, 2, 3])
            var = input.var([0, 2, 3], unbiased=False)
            if self.training and self.track_running_stats and has_running:
                self._update_running_stats(
                    mean, var, input.numel() / input.size(1))
        else:
            mean = self.running_mean
            var = self.running_var

        # Remember the statistics used for this pass; detached so a later
        # pass that reuses them does not backpropagate into this batch.
        self.pre_mean = mean.detach().clone()
        self.pre_var = var.detach().clone()

        output = (input - mean[None, :, None, None]) / (
            torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            output = (output * self.weight[None, :, None, None]
                      + self.bias[None, :, None, None])
        return output

class PSV2BatchNorm2d(nn.BatchNorm2d):
    """Previous-Statistic batch norm, variant that keeps the autograd graph.

    Modified version of
    https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py

    Every forward pass stores the mean/variance it normalised with in
    ``pre_mean``/``pre_var``. When ``use_pre`` is set to true, the next
    forward pass normalises with those stored statistics instead of computing
    new ones, and leaves the running statistics alone.

    The stored statistics are cloned but *not* detached, so a pass that
    reuses them can backpropagate into the batch they were computed from.
    NOTE(review): this presumably requires the earlier batch's graph to
    still be alive when the later loss is backpropagated - confirm against
    the training loop.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(PSV2BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.pre_mean = None
        self.pre_var = None
        self.use_pre = False

    def _update_running_stats(self, mean, var, n):
        """Fold one batch estimate (biased ``var`` over ``n`` values) into the running stats."""
        factor = 0.0
        if self.num_batches_tracked is not None:
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative running average
                factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential running average
                factor = self.momentum
        with torch.no_grad():
            # running_var stores the unbiased estimate; skip the undefined
            # n / (n - 1) correction when there is a single value.
            unbiased_var = var * (n / (n - 1)) if n > 1 else var
            self.running_mean.copy_(
                factor * mean + (1 - factor) * self.running_mean)
            self.running_var.copy_(
                factor * unbiased_var + (1 - factor) * self.running_var)

    def forward(self, input, mask=None):
        """Normalise ``input`` with batch, running or previously used statistics.

        Args:
            input: 4D tensor; statistics are reduced over dims 0, 2 and 3.
            mask: accepted but not used by this layer.

        Returns:
            The normalised (and, if ``affine``, scaled and shifted) tensor.

        Raises:
            RuntimeError: if ``use_pre`` is true but no earlier forward pass
                has stored statistics yet.
        """
        self._check_input_dim(input)

        has_running = (self.running_mean is not None
                       and self.running_var is not None)
        if self.use_pre:  # use previous batch stats
            if self.pre_mean is None or self.pre_var is None:
                raise RuntimeError(
                    "use_pre is set but no previous statistics are stored; "
                    "run a forward pass with use_pre=False first")
            mean = self.pre_mean
            var = self.pre_var
        elif self.training or not has_running:
            mean = input.mean([0, 2, 3])
            var = input.var([0, 2, 3], unbiased=False)
            if self.training and self.track_running_stats and has_running:
                self._update_running_stats(
                    mean, var, input.numel() / input.size(1))
        else:
            mean = self.running_mean
            var = self.running_var

        # Remember the statistics used for this pass. Deliberately not
        # detached: gradients may flow back through them on reuse.
        self.pre_mean = mean.clone()
        self.pre_var = var.clone()

        output = (input - mean[None, :, None, None]) / (
            torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            output = (output * self.weight[None, :, None, None]
                      + self.bias[None, :, None, None])
        return output


class MyBatchNorm2d_original(nn.BatchNorm2d):
    """Manual re-implementation of the ``nn.BatchNorm2d`` forward pass.

    Implementation taken from
    https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py
    and extended so that layers without running statistics
    (``track_running_stats=False``) normalise with batch statistics instead
    of failing on the missing buffers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d_original, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _update_running_stats(self, mean, var, n):
        """Fold one batch estimate (biased ``var`` over ``n`` values) into the running stats."""
        factor = 0.0
        if self.num_batches_tracked is not None:
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative running average
                factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential running average
                factor = self.momentum
        with torch.no_grad():
            # update running_var with the unbiased variance; skip the
            # undefined n / (n - 1) correction when there is a single value
            unbiased_var = var * (n / (n - 1)) if n > 1 else var
            self.running_mean.copy_(
                factor * mean + (1 - factor) * self.running_mean)
            self.running_var.copy_(
                factor * unbiased_var + (1 - factor) * self.running_var)

    def forward(self, input):
        """Normalise ``input`` per channel.

        Batch statistics are used in training mode, and also in eval mode
        when the layer holds no running statistics; otherwise the running
        statistics are used.

        Args:
            input: 4D tensor; statistics are reduced over dims 0, 2 and 3.

        Returns:
            The normalised (and, if ``affine``, scaled and shifted) tensor.
        """
        self._check_input_dim(input)

        has_running = (self.running_mean is not None
                       and self.running_var is not None)
        if self.training or not has_running:
            mean = input.mean([0, 2, 3])
            # use biased var for normalisation
            var = input.var([0, 2, 3], unbiased=False)
            if self.training and self.track_running_stats and has_running:
                self._update_running_stats(
                    mean, var, input.numel() / input.size(1))
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (
            torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = (input * self.weight[None, :, None, None]
                     + self.bias[None, :, None, None])

        return input

if __name__ == "__main__":
    # Manual comparison helpers and a scratch experiment; nothing below is
    # asserted, results are printed for inspection.
    def print_bn(bn):
        """Print running stats, weight/bias values and weight/bias gradients of ``bn``."""
        print(bn.running_mean, bn.running_var,bn.weight.data if bn.weight is not None else None, bn.bias.data if bn.bias is not None else None, bn.weight.grad if bn.weight is not None else None, bn.bias.grad if bn.bias is not None else None)


    def test_bn():
        """Run the three BatchNorm variants on the same unmasked random batches.

        Gradients are never zeroed, so the printed ``.grad`` values are
        accumulated over all 10 iterations.
        """
        # torch.backends.cudnn.enabled = False 
        bn1 = torch.nn.BatchNorm2d(3)
        bn2 = MyBatchNorm2d(3)
        bn3 = MyBatchNorm2d_original(3)
        bn1.train()
        bn2.train()
        bn3.train()
        # print(bn2.weight, bn2.bias)
        for i in range(10):
            data = torch.randn(5,3,32,32)
            loss1 = bn1(data).sum()
            loss2 = bn2(data).sum()
            loss3 = bn3(data).sum()
            loss1.backward()
            loss2.backward()
            loss3.backward()
        print_bn(bn1)
        print_bn(bn2)
        print_bn(bn3)
        print(loss1.item(), loss2.item(), loss3.item())

    def test_bn_v2():
        """Feed MyBatchNorm2d extra masked samples and sum the loss over the whole output."""
        # Demonstrates that poisoned samples do not affect the gradients
        bn1 = torch.nn.BatchNorm2d(3)
        bn2 = MyBatchNorm2d(3)
        bn3 = MyBatchNorm2d_original(3)
        bn1.train()
        bn2.train()
        bn3.train()
        # print(bn2.weight, bn2.bias)
        for i in range(10):
            # `data` is the clean half (mask 0), `data1` the poisoned half (mask 1)
            data = torch.randn(5,3,32,32)
            poison = torch.zeros(5, dtype=torch.long)
            data1 = torch.randn(5,3,32,32)
            poison1 = torch.ones(5, dtype=torch.long)
            loss1 = bn1(data).sum()
            loss2 = bn2(torch.cat([data, data1]), torch.cat([poison, poison1])).sum()
            loss3 = bn3(data).sum()
            loss1.backward()
            loss2.backward()
            loss3.backward()
        print_bn(bn1)
        print_bn(bn2)
        print_bn(bn3)
        print(loss1.item(), loss2.item(), loss3.item())

    def test_bn_v3():
        """Like ``test_bn_v2`` but the MyBatchNorm2d loss covers only the first 5 (clean) outputs."""
        # Demonstrates that poisoned samples do not affect the results for clean samples
        bn1 = torch.nn.BatchNorm2d(3)
        bn2 = MyBatchNorm2d(3)
        bn3 = MyBatchNorm2d_original(3)
        bn1.train()
        bn2.train()
        bn3.train()
        # print(bn2.weight, bn2.bias)
        for i in range(10):
            # `data` is the clean half (mask 0), `data1` the poisoned half (mask 1)
            data = torch.randn(5,3,32,32)
            poison = torch.zeros(5, dtype=torch.long)
            data1 = torch.randn(5,3,32,32)
            poison1 = torch.ones(5, dtype=torch.long)
            loss1 = bn1(data).sum()
            loss2 = bn2(torch.cat([data, data1]), torch.cat([poison, poison1]))[:5].sum()
            loss3 = bn3(data).sum()
            loss1.backward()
            loss2.backward()
            loss3.backward()
        print_bn(bn1)
        print_bn(bn2)
        print_bn(bn3)
        print(loss1.item(), loss2.item(), loss3.item())

    # test_bn_v3()
    # bn = nn.BatchNorm2d(32)
    # Scratch experiment: take one manual gradient step on running_mean.
    bn = MyBatchNorm2d(32)
    inputs = torch.randn(128,32,32,32)
    targets = torch.randn(128,32,32,32)
    # Forward pass with autograd disabled, before the layer is switched to eval mode.
    with torch.no_grad():
        bn(inputs)
    print(type(bn.running_mean))
    # Let running_mean track gradients so the eval-mode forward below can
    # backpropagate into it.
    bn.running_mean.requires_grad_(True)
    # NOTE(review): running_mean looks like a leaf tensor at this point, so
    # retain_grad() is presumably redundant - confirm.
    bn.running_mean.retain_grad()
    bn.eval()
    outputs = bn(inputs)
    loss = F.mse_loss(outputs, targets)
    loss.backward()
    # Subtract the raw gradient from running_mean (no learning-rate factor).
    with torch.no_grad():
        bn.running_mean-=bn.running_mean.grad
    # bn.running_mean.retain_grad()
    

    