# Modified from: github.com/kuangliu/pytorch-cifar (models/vgg.py)

from functools import partial
import math
from typing import Callable, Tuple

from flax import linen as nn
from jax import numpy as jnp

# Type alias for a module constructor: anything callable that returns an nn.Module,
# e.g. ``nn.BatchNorm`` (the default ``norm`` below) or a ``functools.partial`` of one.
ModuleDef = Callable[..., nn.Module]

# Kernel init for the Dense layers (MLP hidden layers and the VGG classifier).
# With distribution='uniform', variance_scaling samples from
# U(-sqrt(3 * scale / fan_in), +sqrt(3 * scale / fan_in)), so scale=1/3 gives bounds
# of +/- 1/sqrt(fan_in). NOTE(review): per its name this is meant to mirror PyTorch's
# default Linear weight init — confirm against the PyTorch version being matched.
pytorch_kernel_init = nn.initializers.variance_scaling(scale=1/3, mode='fan_in', distribution='uniform')
# Kernel init for the 3x3 conv layers: zero-mean normal with variance 2 / fan_out
# (fan_out = kernel_h * kernel_w * out_channels), following the reference repo:
# https://github.com/chengyangfu/pytorch-vgg-cifar10/blob/9c5da95/vgg.py#L35
conv_kernel_init = nn.initializers.variance_scaling(scale=2.0, mode='fan_out', distribution='normal')


class Backbone(nn.Module):
    """Convolutional feature extractor of a VGG network.

    For every stage, each width in that stage adds a 3x3 conv (padding 1), an
    optional normalization layer and a ReLU. The stage then ends with a 2x2,
    stride-2 max-pool.

    Attributes:
        stages: per-stage tuples of conv output widths.
        norm: constructor for the normalization layer used after each conv.
        no_bn: when True, skip normalization and give the convs a bias instead.
    """
    stages: Tuple[Tuple[int, ...], ...]
    norm: ModuleDef = nn.BatchNorm
    no_bn: bool = False

    @nn.compact
    def __call__(self, x, norm_kwargs=None):
        """Run all conv stages on ``x``.

        Args:
            x: input feature map; presumably channels-last (Flax's conv layout) —
                confirm with callers.
            norm_kwargs: optional dict of extra keyword arguments passed to every
                normalization layer call.

        Returns:
            The feature map after the last stage's max-pool.
        """
        kwargs = {} if not norm_kwargs else norm_kwargs
        for stage_idx, widths in enumerate(self.stages, start=1):
            for layer_idx, width in enumerate(widths, start=1):
                # 1-based "stage_layer" tag, e.g. conv1_1 / norm1_1.
                tag = f'{stage_idx:d}_{layer_idx:d}'
                # A conv bias is only kept when no normalization layer follows it.
                conv = nn.Conv(
                    width,
                    (3, 3),
                    padding=1,
                    use_bias=self.no_bn,
                    kernel_init=conv_kernel_init,
                    bias_init=nn.initializers.zeros,
                    name=f'conv{tag}',
                )
                x = conv(x)
                if not self.no_bn:
                    x = self.norm(name=f'norm{tag}')(x, **kwargs)
                x = nn.relu(x)
            x = nn.max_pool(x, (2, 2), strides=(2, 2))
        return x


class MLP(nn.Module):
    """Stack of Dense -> (optional norm) -> ReLU hidden layers.

    Attributes:
        dims: output width of each hidden layer, applied in order.
        norm: constructor for the normalization layer used after each Dense.
        no_bn: when True, skip normalization and give the Dense layers a bias.
    """
    dims: Tuple[int, ...]
    norm: ModuleDef = nn.BatchNorm
    no_bn: bool = False

    @nn.compact
    def __call__(self, x, norm_kwargs=None):
        """Apply every hidden layer to ``x``.

        Args:
            x: input features; the last axis is the feature axis for ``nn.Dense``.
            norm_kwargs: optional dict of extra keyword arguments passed to every
                normalization layer call.

        Returns:
            Activations of the last hidden layer, or ``x`` unchanged if ``dims``
            is empty.
        """
        kwargs = {} if not norm_kwargs else norm_kwargs
        for width in self.dims:
            # Bias only when there is no normalization layer after the Dense.
            dense = nn.Dense(width, use_bias=self.no_bn, kernel_init=pytorch_kernel_init)
            hidden = dense(x)
            if self.no_bn:
                x = nn.relu(hidden)
            else:
                x = nn.relu(self.norm()(hidden, **kwargs))
        return x


class VGG(nn.Module):
    """VGG classifier: conv backbone -> flatten -> MLP head -> linear classifier.

    Attributes:
        conv_stages: per-stage tuples of conv widths, forwarded to ``Backbone``.
        num_classes: output width of the final Dense layer (number of logits).
        mlp_width: width of each hidden MLP layer.
        mlp_depth: number of hidden MLP layers.
        norm: normalization constructor shared by the backbone and the MLP head.
        no_bn: disable normalization in the conv backbone.
        no_mlp_bn: disable normalization in the MLP head.
        assert_1x1: if True, require the backbone output to have 1x1 spatial
            extent and drop those axes; if False, flatten the last three axes.
    """
    conv_stages: Tuple[Tuple[int, ...], ...]
    num_classes: int
    mlp_width: int = 512  # For ImageNet, would use 4096.
    mlp_depth: int = 2
    norm: ModuleDef = nn.BatchNorm
    no_bn: bool = False
    no_mlp_bn: bool = True  # Disable bn in MLP.
    assert_1x1: bool = True

    def setup(self):
        self.backbone = Backbone(stages=self.conv_stages, norm=self.norm, no_bn=self.no_bn)
        mlp_dims = (self.mlp_width,) * self.mlp_depth
        self.mlp = MLP(mlp_dims, norm=self.norm, no_bn=self.no_mlp_bn)
        self.classifier = nn.Dense(self.num_classes, kernel_init=pytorch_kernel_init)

    def __call__(self, x, norm_kwargs=None):
        """Compute class logits for ``x``.

        Args:
            x: input images; presumably channels-last (Flax's conv layout) —
                confirm with callers.
            norm_kwargs: optional dict of extra keyword arguments forwarded to
                every normalization layer in the backbone and the MLP head.

        Returns:
            Logits whose last axis has size ``num_classes``.

        Raises:
            ValueError: if ``assert_1x1`` is set and the backbone output's two
                spatial axes are not both of size 1.
        """
        norm_kwargs = norm_kwargs or {}
        x = self.backbone(x, norm_kwargs=norm_kwargs)
        if self.assert_1x1:
            # Shapes are static under jit, so this check is safe when traced.
            spatial = tuple(x.shape[-3:-1])
            if spatial != (1, 1):
                raise ValueError(
                    'VGG expected a 1x1 spatial map after the backbone, got '
                    f'{spatial} (feature shape {tuple(x.shape)}); use a smaller '
                    'input or set assert_1x1=False to flatten instead.')
            x = jnp.squeeze(x, (-3, -2))
        else:
            # Flatten last 3 dimensions.
            x = jnp.reshape(x, (*x.shape[:-3], -1))
        x = self.mlp(x, norm_kwargs=norm_kwargs)
        return self.classifier(x)


# Preset depths. With the default mlp_depth=2, each name counts weight layers:
# conv layers (sum of stage lengths) + 2 hidden MLP layers + 1 classifier.
VGG11 = partial(VGG, ((64,), (128,), (256,) * 2, (512,) * 2, (512,) * 2))
VGG13 = partial(VGG, ((64,) * 2, (128,) * 2, (256,) * 2, (512,) * 2, (512,) * 2))
VGG16 = partial(VGG, ((64,) * 2, (128,) * 2, (256,) * 3, (512,) * 3, (512,) * 3))
VGG19 = partial(VGG, ((64,) * 2, (128,) * 2, (256,) * 4, (512,) * 4, (512,) * 4))
