import torch.nn as nn
import torch
import torch.utils.model_zoo as model_zoo
import math
import torch.nn.functional as F
import copy

# Public names exported by ``from <this module> import *``: the two NTK layer
# classes, the re-initialization helper, and the norm-layer factories.
__all__ = [ 'NTKLinear', 'NTKConv2d', 'reset_parameters','norm2d','norm1d']


def norm2d(num_features, eps=1e-5, method='BN', affine=True, num_groups=8):
    """Build a normalization layer for 2-D (convolutional) feature maps.

    Args:
        num_features: number of channels to normalize.
        eps: value added to the denominator for numerical stability.
        method: ``'BN'`` -> ``nn.BatchNorm2d``, ``'GN'`` -> ``nn.GroupNorm``,
            ``'None'`` (the string) -> ``nn.Identity`` (no normalization).
        affine: whether the layer has learnable scale/shift parameters;
            ignored for ``'None'``.
        num_groups: number of groups used when ``method == 'GN'``;
            ``num_features`` must be divisible by it. Defaults to 8, the value
            previously hard-coded.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method == 'BN':
        return nn.BatchNorm2d(num_features, eps=eps, affine=affine)
    if method == 'GN':
        return nn.GroupNorm(num_groups, num_features, eps=eps, affine=affine)
    if method == 'None':
        return nn.Identity()
    # Fail fast: returning None here would only surface later, as an obscure
    # "'NoneType' object is not callable" error inside forward().
    raise ValueError(
        f"Unknown normalization method {method!r}; expected 'BN', 'GN' or 'None'"
    )
    
def norm1d(num_features, eps=1e-5, method='BN', affine=True, num_groups=8):
    """Build a normalization layer for 1-D features (BatchNorm1d-style inputs).

    Args:
        num_features: number of channels/features to normalize.
        eps: value added to the denominator for numerical stability.
        method: ``'BN'`` -> ``nn.BatchNorm1d``, ``'GN'`` -> ``nn.GroupNorm``,
            ``'None'`` (the string) -> ``nn.Identity`` (no normalization).
        affine: whether the layer has learnable scale/shift parameters;
            ignored for ``'None'``.
        num_groups: number of groups used when ``method == 'GN'``;
            ``num_features`` must be divisible by it. Defaults to 8, the value
            previously hard-coded.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method == 'BN':
        return nn.BatchNorm1d(num_features, eps=eps, affine=affine)
    if method == 'GN':
        return nn.GroupNorm(num_groups, num_features, eps=eps, affine=affine)
    if method == 'None':
        return nn.Identity()
    # Fail fast: returning None here would only surface later, as an obscure
    # "'NoneType' object is not callable" error inside forward().
    raise ValueError(
        f"Unknown normalization method {method!r}; expected 'BN', 'GN' or 'None'"
    )

def reset_parameters(model):
    """Re-initialize, in place, the parameters of ``model`` and all its sub-modules.

    ``model.modules()`` yields ``model`` itself too, so a single layer can be
    passed directly. Handled module types:

    * ``nn.Conv2d`` and subclasses (e.g. ``NTKConv2d``): weight ~ N(0, 1) if the
      module has a truthy ``ntk_init`` attribute, otherwise N(0, 2 / fan_in)
      with ``fan_in = (in_channels / groups) * kh * kw``; bias set to 0.
    * ``nn.Linear`` and subclasses (e.g. ``NTKLinear``): weight ~ N(0, 1) if
      ``ntk_init`` is truthy, otherwise N(0, 2 / in_features); bias set to 0.
    * BatchNorm and GroupNorm: affine weight set to 1, bias to 0 (when present).

    All other module types are left untouched.
    """
    for m in model.modules():
        # Only NTK layers define ``ntk_init``; plain torch layers use standard init.
        ntk_init = getattr(m, 'ntk_init', False)
        if isinstance(m, nn.Conv2d):
            # Each output unit sums over in_channels / groups input channels, so
            # that is the true fan-in (matters for grouped / depthwise convs).
            fan_in = (m.in_channels // m.groups) * m.kernel_size[0] * m.kernel_size[1]
            std = 1.0 if ntk_init else math.sqrt(2. / fan_in)
            nn.init.normal_(m.weight, 0, std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
            # affine=False norm layers have weight/bias set to None.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            std = 1.0 if ntk_init else math.sqrt(2. / m.in_features)
            nn.init.normal_(m.weight, 0, std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

                

class NTKConv2d(nn.Conv2d):
    """``nn.Conv2d`` with an optional NTK-style output scale.

    With ``ntk_init=True``, the module-level ``reset_parameters`` draws the
    weights from N(0, 1), and ``forward`` multiplies the convolution output by
    ``sqrt(2 / fan_in)``. With ``ntk_init=False``, ``scaler`` is 1 and the weights
    use the non-NTK branch of ``reset_parameters``.

    All arguments other than ``ntk_init`` are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, *args, ntk_init=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.ntk_init = ntk_init
        # Inputs feeding one output unit: (in_channels / groups) * kh * kw.
        # Ignoring ``groups`` would under-scale grouped / depthwise convolutions.
        fan_in = (self.in_channels // self.groups) * self.kernel_size[0] * self.kernel_size[1]
        self.scaler = 1
        if ntk_init:
            self.scaler = math.sqrt(2. / fan_in)
        # Replace nn.Conv2d's default init with this module's scheme
        # (reads self.ntk_init, so it must run after the attribute is set).
        reset_parameters(self)

    def forward(self, x):
        # The scale is applied to the whole conv output, bias term included.
        return super().forward(x) * self.scaler
    
    
class NTKLinear(nn.Linear):
    """``nn.Linear`` with an optional NTK-style output scale.

    When ``ntk_init`` is true, the module-level ``reset_parameters`` re-draws the
    weights from N(0, 1), and ``forward`` multiplies the layer output (bias
    included) by ``sqrt(2 / in_features)``. Otherwise the scale factor is 1.

    ``grad_hook`` is accepted but never read by this class.
    NOTE(review): confirm whether any caller expects it to have an effect.
    All remaining arguments are forwarded unchanged to ``nn.Linear``.
    """

    def __init__(self, *args, ntk_init=False, grad_hook=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.ntk_init = ntk_init
        # in_features is only consulted in the NTK case.
        self.scaler = math.sqrt(2. / self.in_features) if ntk_init else 1
        # Must follow the ntk_init assignment: reset_parameters reads it.
        reset_parameters(self)

    def forward(self, x):
        linear_out = super().forward(x)
        return linear_out * self.scaler
    
    
    
