from typing import cast, Dict, List, Union

import math
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        The newly constructed ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class DiagonalNetTorch(nn.Module):
    """Diagonal linear model whose weights are ``wp**n_layer - wm**n_layer``.

    ``wp`` and ``wm`` are registered as ``nn.Parameter`` column vectors of
    shape ``(d, 1)``.

    Args:
        scale: Value every entry of the base initialisation is set to.
        n_layer: Exponent applied to ``wp`` and ``wm`` when forming ``w``.
        d: Number of input features.
        delta: Optional imbalance factor. When given, ``wp`` starts at
            ``delta * scale`` and ``wm`` at ``scale / delta``; otherwise both
            start at ``scale``.
    """

    def __init__(self, scale, n_layer, d, delta=None) -> None:
        super().__init__()

        w0 = scale * torch.ones((d, 1))
        pos_init, neg_init = w0.clone(), w0.clone()
        if delta is not None:
            pos_init = delta * pos_init
            neg_init = neg_init / delta

        # Registration order (wp, then wm) fixes the parameter order.
        self.wp = nn.Parameter(pos_init)
        self.wm = nn.Parameter(neg_init)
        self.n_layer = n_layer

    @property
    def w(self):
        """Effective weight vector ``wp**n_layer - wm**n_layer``."""
        positive = torch.pow(self.wp, self.n_layer)
        negative = torch.pow(self.wm, self.n_layer)
        return positive - negative

    def forward(self, x, y=None):
        """Predict ``x @ w`` and optionally the residual-based gradient.

        Args:
            x: 2-dim tensor with one sample per row.
            y: Optional targets; reshaped to a column before use.

        Returns:
            ``y_hat`` alone when ``y`` is None, otherwise the pair
            ``(y_hat, nabla_w)`` with ``nabla_w = x.T @ (y_hat - y) / n``.
        """
        assert x.ndim == 2, "Reshape x to a 2-dim tensor first."
        n, _ = x.shape
        y_hat = x @ self.w
        if y is None:
            return y_hat

        residual = y_hat.reshape(-1, 1) - y.reshape(-1, 1)
        nabla_w = x.T @ residual / n
        return y_hat, nabla_w


class DiagonalNet(nn.Module):
    """Diagonal linear model holding ``wp`` and ``wm`` as plain tensors.

    The effective weights are ``wp**n_layer - wm**n_layer``. ``wp`` and ``wm``
    are ordinary tensor attributes, not ``nn.Parameter`` objects, so they do
    not appear in ``parameters()``.

    Args:
        scale: Value every entry of ``wp`` is initialised to.
        n_layer: Exponent applied to ``wp`` and ``wm`` when forming ``w``.
        d: Number of input features.
        delta: Optional additive offset; when given, ``wm`` starts at
            ``scale + delta`` instead of ``scale``.
    """

    def __init__(self, scale, n_layer, d, delta=None) -> None:
        super().__init__()

        base = scale * torch.ones((d, 1))
        self.wp = base.clone()

        shifted = base if delta is None else base.clone() + delta
        self.wm = shifted.clone()

        self.n_layer = n_layer

    @property
    def w(self):
        """Effective weight vector ``wp**n_layer - wm**n_layer``."""
        positive = torch.pow(self.wp, self.n_layer)
        negative = torch.pow(self.wm, self.n_layer)
        return positive - negative

    def forward(self, x, y=None):
        """Return ``(y_hat, nabla_w)``; ``nabla_w`` is None when ``y`` is None.

        Args:
            x: 2-dim tensor with one sample per row.
            y: Optional targets; reshaped to a column before use.

        Returns:
            A 2-tuple in both cases. With targets the second item is
            ``x.T @ (y_hat - y) / n``; without targets it is ``None``.
        """
        assert x.ndim == 2, "Reshape x to a 2-dim tensor first."
        n, _ = x.shape
        y_hat = x @ self.w
        if y is None:
            return y_hat, None

        residual = y_hat.reshape(-1, 1) - y.reshape(-1, 1)
        nabla_w = x.T @ residual / n
        return y_hat, nabla_w


class LeNet_BN(nn.Module):
    """LeNet-style CNN with batch normalisation and PReLU activations.

    Two conv stages (conv -> BN -> PReLU -> 2x2 max-pool) feed two dense
    stages (linear -> BN -> PReLU) and a final linear layer. The first
    convolution takes 1-channel input and the flatten step expects
    ``16 * 5 * 5`` features per sample, which is what 28x28 inputs produce.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_stage(conv):
            # conv -> batch norm -> PReLU -> 2x2 max pooling
            return nn.Sequential(
                conv,
                nn.BatchNorm2d(conv.out_channels),
                nn.PReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        def dense_stage(n_in, n_out):
            # linear -> batch norm -> PReLU
            return nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.PReLU(),
            )

        self.conv_1 = conv_stage(nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2))
        self.conv_2 = conv_stage(nn.Conv2d(6, 16, kernel_size=5))
        self.fc1 = dense_stage(16 * 5 * 5, 120)
        self.fc2 = dense_stage(120, 84)
        self.logists = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        feature_maps = self.conv_2(self.conv_1(x))
        flat = feature_maps.view(-1, 16 * 5 * 5)
        hidden = self.fc2(self.fc1(flat))
        return self.logists(hidden)
    

class LeNet_BN_MNIST(nn.Module):
    """LeNet-style CNN with batch normalisation and ReLU activations.

    Two conv stages (conv -> BN -> ReLU -> 2x2 max-pool) feed two dense
    stages (linear -> BN -> ReLU) and a final linear layer. The first
    convolution takes 1-channel input and the flatten step expects
    ``16 * 5 * 5`` features per sample, which is what 28x28 inputs produce.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        # The zero-argument form binds to this class. Naming a different
        # class in super(...) raises TypeError because ``self`` is not an
        # instance of it, which made the model impossible to construct.
        super().__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv_2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU()
        )

        self.logists = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.conv_1(x)
        x = self.conv_2(x)
        # Flatten the 16 feature maps of size 5x5 into one vector per sample.
        x = x.view(-1, 16*5*5)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.logists(x)

        return x


class LeNet_PLUS(nn.Module):
    """LeNet++ as described in the Center Loss paper.

    Three stages of two 5x5 convolutions (each followed by its own PReLU)
    with 32, 64 and 128 channels; every stage ends in a 2x2 max-pool. The
    flattened ``128 * 3 * 3`` features go straight into the final linear
    layer.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # (input channels, output channels) of stages 1..3. Modules are
        # registered in the order conv{s}_1, prelu{s}_1, conv{s}_2, prelu{s}_2.
        stage_channels = ((1, 32), (32, 64), (64, 128))
        for stage, (c_in, c_out) in enumerate(stage_channels, start=1):
            setattr(self, "conv{}_1".format(stage),
                    nn.Conv2d(c_in, c_out, 5, stride=1, padding=2))
            setattr(self, "prelu{}_1".format(stage), nn.PReLU())
            setattr(self, "conv{}_2".format(stage),
                    nn.Conv2d(c_out, c_out, 5, stride=1, padding=2))
            setattr(self, "prelu{}_2".format(stage), nn.PReLU())

        self.logists = nn.Linear(128*3*3, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.prelu1_2(self.conv1_2(self.prelu1_1(self.conv1_1(x))))
        x = F.max_pool2d(x, 2)

        x = self.prelu2_2(self.conv2_2(self.prelu2_1(self.conv2_1(x))))
        x = F.max_pool2d(x, 2)

        x = self.prelu3_2(self.conv3_2(self.prelu3_1(self.conv3_1(x))))
        x = F.max_pool2d(x, 2)

        flat = x.view(-1, 128*3*3)
        return self.logists(flat)


class BasicBlock(nn.Module):
    """Two-convolution residual block that reuses one PReLU module.

    The main path is conv3x3 -> BN -> PReLU -> conv3x3 -> BN. Its result is
    added in place to the shortcut and, unless ``last`` is set, passed
    through the same PReLU module a second time.

    Args:
        inplanes: Channels of the incoming tensor.
        planes: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the
            shortcut; when ``None`` the input itself is the shortcut.
        last: When true, the sum is returned without the final PReLU.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
        super().__init__()
        # Plain configuration attributes.
        self.stride = stride
        self.last = last

        # Sub-modules, registered in the order conv1, bn1, prelu, conv2, bn2,
        # downsample.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x``."""
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        if self.last:
            return out
        return self.prelu(out)


class ResNet(nn.Module):
    """Three-stage residual network with 16, 32 and 64 planes.

    The stem is a 3x3 convolution on 3-channel input followed by batch norm
    and PReLU. Stages 2 and 3 start with a stride-2 block. The very last
    block of stage 3 is built with ``last=True``. Features are average-pooled
    with an 8x8 window, flattened, and fed to a linear layer.

    NOTE(review): the fixed 8x8 pooling window presumably targets 32x32
    inputs (8x8 maps after the two stride-2 stages) -- confirm against the
    data pipeline.

    Args:
        num_classes: Number of outputs of the final linear layer.
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``
            positionally plus a ``last`` keyword.
        layers: Number of blocks in each of the three stages. The default is
            an immutable tuple so it cannot be shared and mutated across
            instances; any indexable of three ints is accepted.
    """

    def __init__(self, num_classes, block=BasicBlock, layers=(5, 5, 5)):
        super(ResNet, self).__init__()

        # Running channel count; _make_layer reads and updates it per stage.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.prelu1 = nn.PReLU()
        self.layer1 = self._make_layer(block, 16, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2, last_phase=True)
        self.avgpool = nn.AvgPool2d(8, stride=1)

        self.logists = nn.Linear(64 * block.expansion, num_classes)

        # Re-initialise every convolution and 2-d batch norm in the network.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
        """Build one stage of ``blocks`` residual blocks.

        Args:
            block: Residual block class (see the class docstring).
            planes: Base channel count of the stage.
            blocks: Number of blocks requested for the stage.
            stride: Stride of the first block.
            last_phase: When true, the final block of the stage is created
                with ``last=True``.

        Returns:
            An ``nn.Sequential`` holding the stage's blocks.
        """
        out_planes = planes * block.expansion

        # A 1x1 conv + BN projection is needed on the shortcut whenever the
        # first block changes the resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        # Index of the block that must receive last=True (None if no block
        # does). Using an index rather than "append one extra block" keeps a
        # single-block stage at exactly one block; previously it got two.
        final_idx = max(blocks - 1, 0) if last_phase else None

        def build(idx, *args):
            # Only pass ``last`` when it is needed, so block classes are
            # called exactly as before for every non-final block.
            if idx == final_idx:
                return block(*args, last=True)
            return block(*args)

        stage = [build(0, self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for idx in range(1, blocks):
            stage.append(build(idx, self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        # Flatten everything but the batch dimension.
        x = x.view(x.size(0), -1)

        x = self.logists(x)

        return x


class CnnNet_MNIST(nn.Module):
    """Deep 64-channel CNN for 1-channel images.

    ``conv_0`` batch-normalises the raw input and applies an unpadded 3x3
    convolution. ``conv_1``, ``conv_2`` and ``conv_3`` stack 3, 4 and 3
    padded (conv -> PReLU -> BN) units respectively, each ending in a 2x2
    max-pool. The flatten step expects 576 features per sample, which is what
    28x28 inputs produce. ``fc`` is linear -> BN and ``logists`` is the
    final linear layer.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_unit(c_in, c_out, padding):
            # One conv -> PReLU -> BN triplet as a flat list of modules.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=padding),
                nn.PReLU(),
                nn.BatchNorm2d(c_out),
            ]

        def pooled_stack(depth):
            # ``depth`` padded 64->64 triplets followed by a 2x2 max-pool.
            modules = []
            for _ in range(depth):
                modules.extend(conv_unit(64, 64, padding=1))
            modules.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*modules)

        self.conv_0 = nn.Sequential(nn.BatchNorm2d(1), *conv_unit(1, 64, padding=0))
        self.conv_1 = pooled_stack(3)
        self.conv_2 = pooled_stack(4)
        self.conv_3 = pooled_stack(3)
        self.fc = nn.Sequential(nn.Linear(576, 128), nn.BatchNorm1d(128))

        self.logists = nn.Linear(128, num_classes)

        # Re-initialise every convolution and 2-d batch norm in the network.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        for stage in (self.conv_0, self.conv_1, self.conv_2, self.conv_3):
            x = stage(x)
        flat = x.view(-1, 576)
        return self.logists(self.fc(flat))


class CnnNet_CIFAR10(nn.Module):
    """Deep CNN for 3-channel images with 64, 96 and 128 channel stages.

    ``conv_0`` batch-normalises the raw input and applies an unpadded 3x3
    convolution. ``conv_1``, ``conv_2`` and ``conv_3`` each stack four padded
    (conv -> PReLU -> BN) units and end in a 2x2 max-pool. The flatten step
    expects ``3 * 3 * 128`` features per sample, which is what 32x32 inputs
    produce. ``fc`` is linear -> BN and ``logists`` is the final linear
    layer.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        def pooled_stack(c_in, c_out, depth=4):
            # ``depth`` padded conv -> PReLU -> BN triplets, then a max-pool.
            # Only the first convolution changes the channel count.
            modules = []
            for _ in range(depth):
                modules.append(nn.Conv2d(c_in, c_out, 3, padding=1))
                modules.append(nn.PReLU())
                modules.append(nn.BatchNorm2d(c_out))
                c_in = c_out
            modules.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*modules)

        self.conv_0 = nn.Sequential(
            nn.BatchNorm2d(3),
            nn.Conv2d(3, 64, 3),
            nn.PReLU(),
            nn.BatchNorm2d(64),
        )
        self.conv_1 = pooled_stack(64, 64)
        self.conv_2 = pooled_stack(64, 96)
        self.conv_3 = pooled_stack(96, 128)
        self.fc = nn.Sequential(nn.Linear(3*3*128, 256), nn.BatchNorm1d(256))

        self.logists = nn.Linear(256, num_classes)

        # Re-initialise every convolution and 2-d batch norm in the network.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        for stage in (self.conv_0, self.conv_1, self.conv_2, self.conv_3):
            x = stage(x)
        flat = x.view(-1, 3*3*128)
        return self.logists(self.fc(flat))
    


class MLP(nn.Module):
    """Two-hidden-layer perceptron over flattened ``3 * 32 * 32`` inputs.

    ``features`` is linear -> ReLU -> linear -> ReLU with 200 units per
    hidden layer and no normalisation layers; ``logists`` maps the 200
    features to the class scores.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        hidden = 200
        self.features = nn.Sequential(
            nn.Linear(3*32*32, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.logists = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Flatten ``x`` to rows of 3*32*32 values and return the logits."""
        flat = x.view(-1, 3*32*32)
        return self.logists(self.features(flat))
    

class MLP_BN(nn.Module):
    """Perceptron over flattened ``1 * 28 * 28`` inputs with a final batch norm.

    ``features`` is linear -> SiLU -> linear -> BatchNorm1d; ``logists`` maps
    the normalised features to the class scores.

    Args:
        num_classes: Number of outputs of the final linear layer.
        units: Widths of the two hidden layers; only ``units[0]`` and
            ``units[1]`` are read. The default is an immutable tuple so it
            cannot be shared and mutated across instances; any indexable of
            two ints is accepted.
    """

    def __init__(self, num_classes=10, units=(128, 128)):
        super(MLP_BN, self).__init__()

        self.features = nn.Sequential(
            nn.Linear(1*28*28, units[0]),
            nn.SiLU(),
            nn.Linear(units[0], units[1]),
            # NOTE(review): in PyTorch ``momentum`` is the weight given to the
            # *new* batch statistic, so 0.99 makes the running stats track
            # almost only the latest batch. This looks like the Keras
            # convention (where 0.99 is the decay) -- confirm whether 0.01
            # was intended before changing it.
            nn.BatchNorm1d(units[1], eps=1e-7, momentum=0.99)
        )

        self.logists = nn.Linear(units[1], num_classes)

    def forward(self, x):
        """Flatten ``x`` to rows of 1*28*28 values and return the logits."""
        x = x.view(-1, 1*28*28)
        x = self.features(x)
        x = self.logists(x)

        return x

# Layer layouts for the VGG feature extractors, keyed by variant name.
# Each list is read left to right by _make_layers: an int adds a 3x3
# convolution with that many output channels, and "M" adds a 2x2 max-pool.
vgg_cfgs: Dict[str, List[Union[str, int]]] = {
    "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def _make_layers(vgg_cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: nn.Sequential[nn.Module] = nn.Sequential()
    in_channels = 3
    for v in vgg_cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers.append(conv2d)
                layers.append(nn.BatchNorm2d(v))
                layers.append(nn.ReLU(True))
            else:
                layers.append(conv2d)
                layers.append(nn.ReLU(True))
            in_channels = v

    return layers


class VGG(nn.Module):
    """VGG-style network: configurable conv features plus a 512-wide classifier.

    The classifier's first linear layer takes 512 flattened features, so the
    feature extractor has to end with 512 channels at 1x1 spatial size
    (NOTE(review): presumably 32x32 inputs with the five-pool layouts in
    ``vgg_cfgs`` -- confirm against the data pipeline).

    Args:
        vgg_cfg: Layout list passed to ``_make_layers``.
        batch_norm: Whether ``_make_layers`` inserts batch norm after each
            convolution.
        num_classes: Number of outputs of the final linear layer. This used
            to be ignored in favour of a hard-coded 10, so every model had 10
            outputs regardless of the requested class count.
    """

    def __init__(self, vgg_cfg: List[Union[str, int]], batch_norm: bool = False, num_classes: int = 1000) -> None:
        super(VGG, self).__init__()
        self.features = _make_layers(vgg_cfg, batch_norm)

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            # Honour the requested class count instead of a fixed 10.
            nn.Linear(512, num_classes),
        )

        # Re-initialise every convolution and 2-d batch norm in the network.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Return the class logits for a batch of images."""
        x = self.features(x)
        # Flatten everything but the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    

def vgg11(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg11" layout without batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg11"], batch_norm=False, num_classes=num_classes)


def vgg13(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg13" layout without batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg13"], batch_norm=False, num_classes=num_classes)


def vgg16(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg16" layout without batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg16"], batch_norm=False, num_classes=num_classes)


def vgg19(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg19" layout without batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg19"], batch_norm=False, num_classes=num_classes)


def vgg11_bn(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg11" layout with batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg11"], batch_norm=True, num_classes=num_classes)


def vgg13_bn(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg13" layout with batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg13"], batch_norm=True, num_classes=num_classes)


def vgg16_bn(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg16" layout with batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg16"], batch_norm=True, num_classes=num_classes)


def vgg19_bn(num_classes) -> VGG:
    """Build a ``VGG`` from the "vgg19" layout with batch normalisation.

    Args:
        num_classes: Forwarded to ``VGG`` as its ``num_classes`` argument.
    """
    return VGG(vgg_cfgs["vgg19"], batch_norm=True, num_classes=num_classes)


# Registry used by create(): maps a model name to a callable that is invoked
# with ``num_classes`` as its only argument and returns the model.
# NOTE(review): LeNet_BN_MNIST also takes only ``num_classes`` but has no
# entry here -- confirm whether that omission is intentional.
__factory = {
    'lenet_bn': LeNet_BN,
    'lenet_plus': LeNet_PLUS,
    'resnet': ResNet,
    'cnnnet_mnist': CnnNet_MNIST,
    'cnnnet_cifar10': CnnNet_CIFAR10,
    'vgg11': vgg11,
    'vgg11_bn': vgg11_bn,
    'vgg13': vgg13,
    'vgg13_bn': vgg13_bn,
    'vgg16': vgg16,
    'vgg16_bn': vgg16_bn,
    'vgg19': vgg19,
    'vgg19_bn': vgg19_bn,
    'mlp': MLP,
    'mlp_bn': MLP_BN
}

def create(name, num_classes):
    """Instantiate the registered model called ``name``.

    Args:
        name: Key of the model in the module-level registry.
        num_classes: Passed as the only argument to the registered builder.

    Returns:
        Whatever the registered builder returns.

    Raises:
        KeyError: If ``name`` is not a registered model name.
    """
    if name in __factory:
        builder = __factory[name]
        return builder(num_classes)
    raise KeyError("Unknown model: {}".format(name))

# Running this file as a script does nothing; the guard body is only `pass`.
if __name__ == '__main__':
    pass