from torch import nn
import torch.nn.functional as F
from mup import MuReadout
from torch.nn import init
from architectures.utils import get_norm_layer

 
    
class ResidualBranch(nn.Module):
    """Convolution followed by a normalisation layer.

    The module computes ``norm(conv(x))``; the caller is responsible for the
    non-linearity and for adding the result back onto the skip path.

    Args:
        fan_in: Number of input channels of the convolution.
        fan_out: Number of output channels of the convolution.
        norm: Normalisation selector, passed unchanged to ``get_norm_layer``
            together with ``fan_out``.
        kernel_size: Convolution kernel size (int or tuple of ints).
        stride: Convolution stride.
    """

    def __init__(self, fan_in, fan_out, norm, kernel_size=3, stride=1):
        super().__init__()
        # Derive the padding from the kernel instead of hard-coding 1, so that
        # (for odd kernels and stride 1) the output keeps the input's spatial
        # size and can be summed with the skip connection. For the default
        # kernel_size=3 this is still padding=1.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv = nn.Conv2d(
            fan_in,
            fan_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = get_norm_layer(fan_out, norm)

    def forward(self, x):
        """Return ``norm(conv(x))``."""
        return self.norm(self.conv(x))
        
class ConvNet(nn.Module):
    """Three-stage residual CNN with a ``MuReadout`` classification head.

    Each stage is a 3x3 "stem" convolution (``conv01``/``conv02``/``conv03``)
    followed by ``depth_mult - 1`` residual branches and a 2x2 max-pool. The
    channel count is ``width``, ``2 * width`` and ``4 * width`` in the three
    stages, where ``width = int(wm * init_width)``.

    Args:
        init_width: Base channel count before applying ``wm``.
        depth_mult: Each stage holds ``depth_mult - 1`` residual branches.
        wm: Width multiplier applied to ``init_width``.
        gamma: The readout output is divided by ``gamma * gamma_zero``.
        res_scaling: Multiplier on every residual branch output.
        depth_scale_first: Multiplier on the output of the first convolution;
            also divides the std used by ``init_parameters_depth``.
        skip_scaling: Multiplier on the skip path of every residual update.
        beta: Additional multiplier on every residual branch output.
        gamma_zero: See ``gamma``.
        num_classes: Output dimension of the readout layer.
        img_dim: Side length of the input images; the readout is sized for an
            ``img_dim x img_dim`` input (assumes square images).
        norm: Normalisation selector forwarded to each ``ResidualBranch``.
    """

    def __init__(self, init_width, depth_mult, wm=1, gamma=1, res_scaling=1, depth_scale_first=1, skip_scaling=1, beta=1, gamma_zero=1, num_classes = 10, img_dim = 32,
                 norm=None):
        super().__init__()

        self.gamma = gamma
        self.wm = wm
        self.res_scaling = res_scaling
        self.depth_scale_first = depth_scale_first
        self.skip_scaling = skip_scaling
        self.beta = beta
        self.gamma_zero = gamma_zero
        self.img_dim = img_dim

        width = int(wm * init_width)
        n_branches = depth_mult - 1

        self.conv01 = nn.Conv2d(3, width, 3, 1, padding=1)
        self.conv1 = nn.ModuleList([ResidualBranch(width, width, norm) for _ in range(n_branches)])
        self.conv02 = nn.Conv2d(width, 2 * width, 3, 1, padding=1)
        self.conv2 = nn.ModuleList([ResidualBranch(2 * width, 2 * width, norm) for _ in range(n_branches)])
        self.conv03 = nn.Conv2d(2 * width, 4 * width, 3, 1, padding=1)
        self.conv3 = nn.ModuleList([ResidualBranch(4 * width, 4 * width, norm) for _ in range(n_branches)])

        # Three 2x2 max-pools shrink each spatial side by a factor of 8.
        final_size = self.img_dim // 8
        # Size the readout from the real channel count of the last stage
        # (4 * width) rather than the constant 64 * wm, which is only correct
        # when init_width == 16 and wm * 16 is an integer.
        self.fc = MuReadout(int(final_size ** 2 * 4 * width), num_classes, readout_zero_init=False)
        self.reset_parameters()

    @staticmethod
    def _init_conv(conv) -> None:
        """Kaiming-normal init for ``conv.weight`` and zeros for its bias."""
        # a=1 with the default 'leaky_relu' nonlinearity gives a gain of 1,
        # i.e. std = 1 / sqrt(fan_in).
        init.kaiming_normal_(conv.weight, a=1)
        if conv.bias is not None:
            init.zeros_(conv.bias)

    def reset_parameters(self) -> None:
        """Re-initialise every convolution of the network.

        The readout layer ``fc`` is left untouched.
        """
        for stem in (self.conv01, self.conv02, self.conv03):
            self._init_conv(stem)
        # All three stages are covered; previously the third loop visited
        # conv2 a second time and never initialised conv3.
        for stage in (self.conv1, self.conv2, self.conv3):
            for branch in stage:
                self._init_conv(branch.conv)

    def init_parameters_depth(self):
        """Re-draw the first convolution with std ``1 / (9 * depth_scale_first)``.

        Only ``conv01`` is modified; its bias (if any) is zeroed. This method
        is not called by the constructor.
        """
        self.conv01.weight.data.normal_(mean=0.0, std=(1/self.depth_scale_first) * 1/3**2)
        if self.conv01.bias is not None:
            init.zeros_(self.conv01.bias)

    def _residual_stage(self, x, branches):
        """Apply ``x <- skip_scaling * x + beta * res_scaling * relu(branch(x))`` per branch."""
        for branch in branches:
            x = self.skip_scaling * x + self.beta * self.res_scaling * F.relu(branch(x))
        return x

    def forward(self, x):
        """Run the three stages and return the scaled readout output."""
        x = self.depth_scale_first * self.conv01(x)
        x = self._residual_stage(x, self.conv1)
        x = F.max_pool2d(x, 2, 2)  # 2x2 max-pool, stride 2

        x = self.conv02(x)
        x = self._residual_stage(x, self.conv2)
        x = F.max_pool2d(x, 2, 2)

        x = self.conv03(x)
        x = self._residual_stage(x, self.conv3)
        x = F.max_pool2d(x, 2, 2)

        # Flatten to exactly the width the readout was built with (MuReadout
        # subclasses nn.Linear, hence in_features), so the two can never
        # disagree.
        x = x.view(-1, self.fc.in_features)
        x = 1/(self.gamma*self.gamma_zero) * self.fc(x)
        return x