import haiku as hk
import math

from jax import nn


class ConvBlock(hk.Module):
    """A single conv -> batch-norm -> ReLU -> 2x2 max-pool stage.

    Args:
        channels: number of output channels of the 3x3 convolution.
        name: optional Haiku module name.
    """

    def __init__(self, channels, name=None):
        super().__init__(name=name)
        self.channels = channels

    def __call__(self, inputs, is_training):
        """Run the block on `inputs`.

        Args:
            inputs: input feature map; presumably batched NHWC
                (batch, height, width, channels), Haiku's default layout
                -- TODO confirm against callers.
            is_training: forwarded unchanged as BatchNorm's `is_training`
                flag.

        Returns:
            The activated, max-pooled feature map with `self.channels`
            channels.
        """
        conv = hk.Conv2D(
            output_channels=self.channels,
            kernel_shape=3,
            stride=1,
            with_bias=True,
            name='conv',
        )
        norm = hk.BatchNorm(
            create_scale=True,
            create_offset=True,
            decay_rate=0.9,
            name='norm',
        )
        activated = nn.relu(norm(conv(inputs), is_training))
        # Window 2, stride 2, VALID padding: for batched NHWC input this
        # halves height and width (odd sizes are floored).
        return hk.max_pool(activated, window_shape=2, strides=2, padding='VALID')


class Conv4(hk.Module):
    """Four stacked ConvBlocks, flattened into a scaled feature vector.

    Args:
        name: optional Haiku module name (same position as before, so
            existing positional/keyword callers keep working).
        channels: keyword-only; output channels of every ConvBlock.
            The default of 64 reproduces the original fixed-width network
            exactly.
    """

    def __init__(self, name=None, *, channels=64):
        super().__init__(name=name)
        self.channels = channels

    def __call__(self, inputs, is_training):
        """Embed `inputs` into a flat feature vector.

        Args:
            inputs: image batch whose last three axes are presumably
                (height, width, channels) -- TODO confirm; any axes before
                those are preserved in the output.
            is_training: forwarded to every ConvBlock's BatchNorm.

        Returns:
            Array of shape `inputs.shape[:-3] + (D,)`, where D is the
            flattened size of the last block's output, divided by sqrt(D).
        """
        outputs = inputs
        # The explicit names layer1..layer4 keep the parameter tree identical
        # to the original hand-unrolled version, so checkpoints still load.
        for index in range(1, 5):
            block = ConvBlock(self.channels, name=f'layer{index}')
            outputs = block(outputs, is_training)
        # Flatten the trailing three axes of the final feature map into one.
        outputs = outputs.reshape(inputs.shape[:-3] + (-1,))
        # Dividing by sqrt(D) makes the dot product of two embeddings equal
        # the unscaled dot product / D. Presumably this keeps similarity
        # scores O(1) for a downstream metric head -- TODO confirm.
        return outputs / math.sqrt(outputs.shape[-1])
