import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class ConvBlock(nn.Module):
    """Three conv -> batch-norm -> ReLU stages followed by 2x2 max pooling.

    The stages are registered as ``adapt_block``, ``hyper_block1`` and
    ``hyper_block2`` (in that order). Every convolution is 3x3 with stride 1
    and padding 1, so only the final pooling step changes the spatial size.

    Args:
        in_channels: Channel count of the tensor fed to the first stage.
        out_channels: Channel count produced by every stage.
        track_bn_stats: Forwarded to each ``BatchNorm2d`` as
            ``track_running_stats``.
    """

    def __init__(self, in_channels, out_channels, track_bn_stats):
        super().__init__()

        def make_stage(stage_in):
            # One conv/BN/ReLU unit; only the input width differs per stage.
            conv = nn.Conv2d(
                stage_in, out_channels, kernel_size=3, stride=1, padding=1
            )
            norm = nn.BatchNorm2d(
                out_channels, track_running_stats=track_bn_stats
            )
            return nn.Sequential(conv, norm, nn.ReLU(inplace=True))

        # Built in the same order they are applied in forward().
        self.adapt_block = make_stage(in_channels)
        self.hyper_block1 = make_stage(out_channels)
        self.hyper_block2 = make_stage(out_channels)

    def forward(self, x):
        """Run the three stages in sequence, then max-pool with kernel size 2."""
        for stage in (self.adapt_block, self.hyper_block1, self.hyper_block2):
            x = stage(x)
        return F.max_pool2d(x, kernel_size=2)

class ConvNet(nn.Module):
    """Three stacked ``ConvBlock``s followed by a single linear classifier.

    The first block takes 3 input channels; all blocks output
    ``conv_channels`` channels. The linear layer's input width is computed
    from ``img_size`` floor-halved three times, i.e. it assumes square
    inputs of side ``img_size``.

    Args:
        num_classes: Output width of the final linear layer.
        conv_channels: Channel count used by every ``ConvBlock``.
        img_size: Side length of the (square) input images.
        track_bn_stats: Passed through to every ``ConvBlock``.
    """

    def __init__(self, num_classes, conv_channels=32, img_size=32, track_bn_stats=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                ConvBlock(3, conv_channels, track_bn_stats),
                ConvBlock(conv_channels, conv_channels, track_bn_stats),
                ConvBlock(conv_channels, conv_channels, track_bn_stats),
            ]
        )
        # Spatial side after three floor-halvings, one per block.
        sz = ((img_size // 2) // 2) // 2
        self.fc = nn.Linear(sz * sz * conv_channels, num_classes)

        # Override the default initialisation: small Gaussian weights and
        # zero biases for conv/linear layers, identity affine for batch norm.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply the blocks in order, flatten per sample, and classify."""
        for layer in self.layers:
            x = layer(x)
        # flatten() instead of view(): view() raises when the feature map is
        # not contiguous in NCHW order (e.g. channels_last inputs), while
        # flatten() still returns a view whenever one is possible.
        x = x.flatten(1)
        return self.fc(x)

    def get_adapt_params(self):
        """Yield the parameters of every ``adapt_block``, then those of ``fc``.

        NOTE(review): the name suggests these are the parameters updated
        during adaptation -- confirm against the training code.
        """
        for layer in self.layers:
            yield from layer.adapt_block.parameters()
        yield from self.fc.parameters()

    def get_hyper_params(self):
        """Yield, per block, ``hyper_block1`` parameters then ``hyper_block2``'s."""
        for layer in self.layers:
            yield from layer.hyper_block1.parameters()
            yield from layer.hyper_block2.parameters()
