import torch.nn as nn
import torch
from torch.nn.parameter import Parameter
from torch.nn import functional as F


class HN(nn.Module):
    """Normalization layer that picks the mean and variance statistics independently.

    The input is first centered with a mean chosen by ``ms`` and then divided
    by a standard deviation chosen by ``vs``; a learnable per-channel
    ``scale`` and ``shift`` are applied last.  The variance is always computed
    from the *original* input, not from the centered tensor.

    Supported values for both ``ms`` and ``vs`` (reduction dims refer to an
    input laid out as ``(n, c, h, w)``):

    * ``'BN'``  - reduce over ``(0, 2, 3)``; running statistic used in eval.
    * ``'IN'``  - reduce over ``(2, 3)``.
    * ``'LN'``  - reduce over ``(1, 2, 3)``.
    * ``'GN'``  - split channels into ``self.g`` (32) groups and reduce over
      each group's channels and spatial dims.
    * ``'EN'``  - reduce over every dim; running statistic used in eval.
    * ``'ENG'`` - split channels into groups and reduce over batch, group
      channels and spatial dims; running statistic used in eval.  The mean
      uses ``self.g_ms`` (= ``planes``) groups, the variance ``self.g_vs``
      (32) groups.

    Args:
        planes: number of channels; sizes ``scale``, ``shift`` and the
            running buffers.
        ms: statistic used for centering.
        vs: statistic used for scaling.

    Raises:
        ValueError: if ``ms`` or ``vs`` is not one of the supported values.

    Note:
        ``'GN'`` and ``vs='ENG'`` require the channel count of the input to be
        divisible by 32; otherwise the group reshape in ``forward`` fails.
    """

    _MODES = ('BN', 'IN', 'LN', 'GN', 'EN', 'ENG')

    def __init__(self,planes,ms='BN',vs='BN'):
        super(HN, self).__init__()
        if ms not in self._MODES or vs not in self._MODES:
            raise ValueError('ms or vs err')
        self.planes = planes
        self.eps = 0.00001   # added to the variance before the square root
        self.sigma = 0.1     # weight of the newest batch in the running averages
        self.g = 32          # group count for 'GN'
        self.g_ms = planes   # group count for the 'ENG' mean (one channel per group)
        self.g_vs = 32       # group count for the 'ENG' variance
        self.ms = ms
        self.vs = vs

        # Running buffers exist only for the modes that use them in eval.
        if ms in ('BN', 'EN'):
            self.register_buffer('running_mean', torch.zeros(1, planes, 1, 1))
        elif ms == 'ENG':
            self.register_buffer('running_mean', torch.zeros(1, self.g_ms, 1, 1, 1))
        if vs in ('BN', 'EN'):
            self.register_buffer('running_var', torch.ones(1, planes, 1, 1))
        elif vs == 'ENG':
            # Start at one (like the 'BN'/'EN' buffer): a zero variance would
            # make an untrained module divide by sqrt(eps) in eval mode and
            # bias the running estimate toward zero.
            self.register_buffer('running_var', torch.ones(1, self.g_vs, 1, 1, 1))

        self.scale = Parameter(torch.ones(1, planes, 1, 1))
        self.shift = Parameter(torch.zeros(1, planes, 1, 1))

    def _update_running(self, buf, stat):
        """Blend the batch statistic ``stat`` into the buffer ``buf`` in place."""
        with torch.no_grad():
            buf.mul_(1 - self.sigma)
            buf.add_(self.sigma * stat.detach())

    @staticmethod
    def _split_groups(t, groups):
        """Reshape ``(n, c, h, w)`` to ``(n, groups, c // groups, h, w)``."""
        n, c, h, w = t.size()
        return t.reshape(n, groups, c // groups, h, w)

    def _center(self, x):
        """Return ``x`` minus the mean selected by ``self.ms``."""
        mode = self.ms
        if mode in ('BN', 'EN'):
            if not self.training:
                return x - self.running_mean
            dims = (0, 2, 3) if mode == 'BN' else (0, 1, 2, 3)
            mean = torch.mean(x, dims, keepdim=True)
            # For 'EN' the scalar mean broadcasts over the per-channel buffer.
            self._update_running(self.running_mean, mean)
            return x - mean
        if mode == 'IN':
            return x - torch.mean(x, (2, 3), keepdim=True)
        if mode == 'LN':
            return x - torch.mean(x, (1, 2, 3), keepdim=True)
        if mode == 'GN':
            grouped = self._split_groups(x, self.g)
            mean = torch.mean(grouped, (2, 3, 4), keepdim=True)
            return (grouped - mean).reshape(x.size())
        # mode == 'ENG'
        grouped = self._split_groups(x, self.g_ms)
        if self.training:
            mean = torch.mean(grouped, (0, 2, 3, 4), keepdim=True)
            self._update_running(self.running_mean, mean)
        else:
            mean = self.running_mean
        return (grouped - mean).reshape(x.size())

    def _rescale(self, x, centered):
        """Divide ``centered`` by the std of ``x`` selected by ``self.vs``."""
        mode = self.vs
        if mode in ('BN', 'EN'):
            if self.training:
                dims = (0, 2, 3) if mode == 'BN' else (0, 1, 2, 3)
                var = torch.var(x, dims, keepdim=True)
                self._update_running(self.running_var, var)
            else:
                var = self.running_var
            return centered / torch.sqrt(var + self.eps)
        if mode == 'IN':
            var = torch.var(x, (2, 3), keepdim=True)
            return centered / torch.sqrt(var + self.eps)
        if mode == 'LN':
            var = torch.var(x, (1, 2, 3), keepdim=True)
            return centered / torch.sqrt(var + self.eps)
        if mode == 'GN':
            var = torch.var(self._split_groups(x, self.g), (2, 3, 4), keepdim=True)
            scaled = self._split_groups(centered, self.g) / torch.sqrt(var + self.eps)
            return scaled.reshape(x.size())
        # mode == 'ENG'
        if self.training:
            var = torch.var(self._split_groups(x, self.g_vs), (0, 2, 3, 4), keepdim=True)
            self._update_running(self.running_var, var)
        else:
            var = self.running_var
        scaled = self._split_groups(centered, self.g_vs) / torch.sqrt(var + self.eps)
        return scaled.reshape(x.size())

    def forward(self, x):
        """Normalize a 4-D ``(n, c, h, w)`` tensor and apply ``scale``/``shift``.

        In training mode the 'BN', 'EN' and 'ENG' statistics are computed from
        the batch and folded into the running buffers; in eval mode those
        modes read the buffers instead.  The other modes always use the
        statistics of the current input.
        """
        centered = self._center(x)
        normalized = self._rescale(x, centered)
        return normalized * self.scale + self.shift


