import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class LossWeightingModule(nn.Module):
    """Small MLP that turns each scalar input element into a weight.

    Every element of the input is treated as an independent one-feature
    sample and passed through
    ``Linear(1, hdim) -> ReLU -> Linear(hdim, 1) -> Sigmoid``.
    The sigmoid output is then floored at 0.1.

    Args:
        hdim: Width of the hidden layer.
    """

    def __init__(self, hdim=200):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, hdim),
            nn.ReLU(inplace=True),
            nn.Linear(hdim, 1),
            nn.Sigmoid()
        )
        # Small-variance weights and zero biases: the pre-sigmoid activation
        # starts close to 0, so the initial weights are close to 0.5.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return one weight per element of ``x`` as a flat 1-D tensor.

        Args:
            x: Tensor of any shape; it is flattened to ``(numel, 1)``.

        Returns:
            1-D tensor with ``x.numel()`` entries, each at least 0.1.
        """
        # reshape (not view) so non-contiguous inputs, e.g. a transposed or
        # strided tensor, are accepted instead of raising.
        weights = self.layers(x.reshape(-1, 1)).view(-1)
        # Lower-bound the weights without allocating a constant tensor.
        return torch.clamp(weights, min=0.1)

class ConvBlock(nn.Module):
    """3x3 convolution -> batch norm -> ReLU, followed by 2x2 max pooling.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        track_bn_stats: Forwarded to ``BatchNorm2d`` as
            ``track_running_stats``.
    """

    def __init__(self, in_channels, out_channels, track_bn_stats):
        super().__init__()

        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(out_channels, track_running_stats=track_bn_stats)
        activation = nn.ReLU(inplace=True)
        # Order inside the Sequential fixes the sub-module names (0, 1, 2).
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply conv/BN/ReLU, then halve height and width by max pooling."""
        features = self.block(x)
        return F.max_pool2d(features, 2)

class ConvNet(nn.Module):
    """Four-block convolutional classifier bundled with a loss-weighting module.

    The classifier is four ``ConvBlock`` stages followed by a single linear
    layer. A ``LossWeightingModule`` is stored as ``self.lwm``; it is not
    used by ``forward`` and its parameters are exposed separately through
    ``get_hyper_params``.

    Args:
        num_classes: Number of output logits.
        conv_channels: Number of channels produced by every conv block.
        img_size: Side length of the (square) input images; used to size the
            final linear layer.
        track_bn_stats: Forwarded to every ``ConvBlock``.
    """

    def __init__(self, num_classes, conv_channels=32, img_size=32, track_bn_stats=True):
        super().__init__()
        self.adapt_layers = nn.ModuleList(
            [
                ConvBlock(3, conv_channels, track_bn_stats),
                ConvBlock(conv_channels, conv_channels, track_bn_stats),
                ConvBlock(conv_channels, conv_channels, track_bn_stats),
                ConvBlock(conv_channels, conv_channels, track_bn_stats),
            ]
        )
        self.lwm = LossWeightingModule()
        # Spatial side after four successive floor-halvings of img_size.
        sz = (((img_size // 2) // 2) // 2) // 2
        self.fc = nn.Linear(sz * sz * conv_channels, num_classes)

        # self.modules() walks every sub-module, including those inside
        # self.lwm, so all Conv2d/Linear/BatchNorm2d layers are initialised here.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the conv stack and the linear head; return the logits.

        Args:
            x: Batch of images with the batch dimension first.

        Returns:
            Tensor of shape ``(x.size(0), num_classes)``.
        """
        for layer in self.adapt_layers:
            x = layer(x)
        # flatten (not view): view raises when the feature map is not laid out
        # contiguously (e.g. channels_last inputs) or when the batch is empty.
        x = torch.flatten(x, 1)
        return self.fc(x)

    def get_hyper_params(self):
        """Yield the parameters of the loss-weighting module ``self.lwm``."""
        yield from self.lwm.parameters()

    def get_adapt_params(self):
        """Yield the conv-block parameters in layer order, then those of ``fc``."""
        for layer in self.adapt_layers:
            yield from layer.parameters()
        yield from self.fc.parameters()
