import torch.nn as nn
import torch.nn.utils as utils


"""

Centered weight normalization in accelerating training of deep neural networks

ICCV 2017

Authors: Lei Huang
"""
import torch.nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable
from typing import List
from torch.autograd.function import once_differentiable


#  norm functions--------------------------------


class CWNorm(torch.nn.Module):
    """Centered weight normalization applied row-wise to a weight tensor.

    The weight is flattened to ``(weight.size(0), -1)``, so each row holds
    every weight feeding one output unit. Each row is shifted to zero mean
    and divided by its L2 norm plus ``1e-5``. The result is reshaped back to
    the size of the input.

    The module defines no parameters or buffers of its own.
    """

    def forward(self, weight):
        out_units = weight.size(0)
        # One row per output unit. ``view`` needs a contiguous tensor; ordinary
        # parameter weights are expected to satisfy this.
        rows = weight.view(out_units, -1)
        centered = rows - rows.mean(dim=1, keepdim=True)
        # The epsilon keeps an all-constant row (zero norm after centering)
        # from producing NaN; such a row normalizes to all zeros instead.
        denom = centered.norm(dim=1, keepdim=True) + 1e-5
        return (centered / denom).view(weight.size())

class CWN_Conv2d(torch.nn.Conv2d):
    """2-D convolution whose weight is centered-weight-normalized on every call.

    On each forward pass the raw ``self.weight`` goes through ``CWNorm``
    (per-output-channel zero mean, approximately unit L2 norm). It is then
    multiplied by the per-output-channel scale ``WNScale`` and used for the
    convolution.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded positionally to ``torch.nn.Conv2d``.
        NScale: initial value of every entry of the ``(out_channels, 1, 1, 1)``
            scale tensor.
        adjustScale: if True, ``WNScale`` is a learnable ``nn.Parameter``;
            otherwise it is a fixed buffer.
        *args: accepted for backward compatibility and ignored.
        **kwargs: ``padding_mode``, ``device`` and ``dtype`` are forwarded to
            ``torch.nn.Conv2d`` (previously they were silently dropped). Any
            other keyword is still ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 NScale=1.414, adjustScale=False, *args, **kwargs):
        # Only forward the keywords the caller actually passed, so the
        # default construction path calls nn.Conv2d exactly as before.
        conv_kwargs = {key: kwargs[key] for key in ('padding_mode', 'device', 'dtype') if key in kwargs}
        super(CWN_Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                         bias, **conv_kwargs)
        self.weight_normalization = CWNorm()
        # Build the scale on the weight's device/dtype so that constructing
        # with device=/dtype= does not leave it behind on the CPU.
        self.scale_ = torch.full((out_channels, 1, 1, 1), NScale,
                                 dtype=self.weight.dtype, device=self.weight.device)
        if adjustScale:
            self.WNScale = nn.Parameter(self.scale_)
        else:
            self.register_buffer('WNScale', self.scale_)

    def forward(self, input_f: torch.Tensor) -> torch.Tensor:
        weight_q = self.weight_normalization(self.weight) * self.WNScale
        if self.padding_mode != 'zeros':
            # Same strategy as nn.Conv2d: pad explicitly with the requested
            # mode, then convolve without implicit zero padding.
            padded = F.pad(input_f, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            return F.conv2d(padded, weight_q, self.bias, self.stride, (0, 0), self.dilation, self.groups)
        return F.conv2d(input_f, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)





# def replace_conv2d(module, norm_type):
#     if norm_type == 'wn':
#         for name, child in module.named_children():
#             if isinstance(child, (nn.Conv2d, nn.Linear)):
#                 # import pdb
#                 # pdb.set_trace()
#                 setattr(module, name, utils.weight_norm(child))
                
#             else:
#                 replace_conv2d(child,norm_type)
#     elif norm_type == 'cwn':
#         for name, child in module.named_children():
#             if isinstance(child, nn.Conv2d):
#                 child = CWN_Conv2d(child)
#             else:
#                 replace_conv2d(child, norm_type)
#     else:
#         raise ValueError(f"Unsupported normalization type: {norm_type}")

def add_norm(m, norm_type='wn'):
    """Recursively apply a weight re-parameterization to Conv2d/Linear layers.

    Children of ``m`` are processed depth-first and re-registered under
    their original names with the result of the recursive call.

    Args:
        m: module (or whole model) to process.
        norm_type: ``'wn'`` applies ``nn.utils.weight_norm``;
            ``'spectral_norm'`` applies ``nn.utils.spectral_norm``.

    Returns:
        The result of the selected ``nn.utils`` function when ``m`` is an
        ``nn.Conv2d`` / ``nn.Linear``; otherwise ``m`` itself.

    Raises:
        ValueError: if a Conv2d/Linear layer is encountered and ``norm_type``
            is not supported. Without this check the function returned
            ``None`` and the layer was silently replaced by ``None``.
    """
    # Snapshot the children so re-registering them cannot disturb iteration.
    for name, layer in list(m.named_children()):
        m.add_module(name, add_norm(layer, norm_type))
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return m
    if norm_type == 'spectral_norm':
        return nn.utils.spectral_norm(m)
    if norm_type == 'wn':
        return nn.utils.weight_norm(m)
    raise ValueError(f"Unsupported normalization type: {norm_type}")
