from src.utils.ntk_computation import stax


def WideBasicBlock(channel_mismatch, planes, stride):
    """Build a two-branch residual block out of ``stax`` layers.

    The input is fanned out into two copies. One copy goes through
    ReLU -> 3x3 conv -> ReLU -> 3x3 conv; the other goes through a shortcut.
    The two branch outputs are then summed.

    Args:
        channel_mismatch: When truthy, the shortcut is a 1x1 convolution even
            if ``stride`` is 1.
        planes: Channel count passed to every convolution in the block.
        stride: Stride used (in both spatial directions) by the first 3x3
            convolution and by the 1x1 shortcut convolution.

    Returns:
        The result of ``stax.serial`` applied to the fan-out, the parallel
        branches, and the summing fan-in.
    """
    # Every convolution in the block shares these keyword arguments.
    conv_options = dict(padding='SAME', b_std=0.05)

    residual_branch = stax.serial(
        stax.Relu(),
        stax.Conv(planes, (3, 3), (stride, stride), **conv_options),
        stax.Relu(),
        stax.Conv(planes, (3, 3), (1, 1), **conv_options),
    )

    # An identity shortcut is only used when the block neither strides nor is
    # flagged as changing the channel count; otherwise a 1x1 conv is used.
    if stride != 1 or channel_mismatch:
        skip_branch = stax.Conv(planes, (1, 1), (stride, stride), **conv_options)
    else:
        skip_branch = stax.Identity()

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(residual_branch, skip_branch),
        stax.FanInSum(),
    )


def WideResNet(num_layers, depth, widen_factor, dropout_rate, num_classes, num_input_channels=3, norm_layer=None):
    """Assemble a wide-ResNet-style network from ``stax`` layers.

    The network is a 3x3 stem convolution, followed by up to four stages of
    ``WideBasicBlock``s, followed by ``AvgPool((1, 1))``, ``Flatten`` and a
    ``Dense`` layer with ``num_classes`` outputs.

    Args:
        num_layers: How many stages to include. Stage ``i`` (1-based, up to 4)
            is added when ``num_layers >= i``.
        depth: Must satisfy ``(depth - 4) % 6 == 0``; each stage is built with
            ``(depth - 4) / 6`` as its block count.
        widen_factor: The base width is ``16 * widen_factor``; the stages use
            1x, 2x, 4x and 8x that width.
        dropout_rate: Not read by this function.
        num_classes: Output size passed to the final ``stax.Dense``.
        num_input_channels: Not read by this function.
        norm_layer: Not read by this function.

    Returns:
        The result of ``stax.serial`` applied to all assembled layers.

    Raises:
        AssertionError: If ``depth`` is not of the form ``6n + 4``.
    """

    def _make_layer(block, planes, num_blocks, stride):
        # The leading block takes the stage stride and is always built with
        # channel_mismatch=True; the remaining num_blocks - 1 blocks use
        # stride 1 and channel_mismatch=False. A leading block is created even
        # when num_blocks is less than 1.
        leading = block(True, planes, stride)
        trailing = [block(False, planes, 1) for _ in range(num_blocks - 1)]
        return stax.serial(leading, *trailing)

    assert (depth - 4) % 6 == 0, 'wide-resnet depth should be 6n+4'
    blocks_per_stage = int((depth - 4) / 6)
    base_width = 16 * widen_factor

    # Stem convolution.
    network = [stax.Conv(base_width, (3, 3), (1, 1), padding='SAME', b_std=0.05)]

    # (width multiplier, stride) for stages 1 through 4.
    stage_table = ((1, 1), (2, 2), (4, 2), (8, 2))
    for stage_number, (width_multiplier, stage_stride) in enumerate(stage_table, start=1):
        if num_layers >= stage_number:
            network.append(
                _make_layer(
                    WideBasicBlock,
                    base_width * width_multiplier,
                    blocks_per_stage,
                    stride=stage_stride,
                )
            )

    # NOTE(review): the pooling window is (1, 1) -- confirm this is intended
    # rather than pooling over the full spatial extent.
    network.extend([
        stax.AvgPool((1, 1)),
        stax.Flatten(),
        stax.Dense(num_classes, 1., 0.),
    ])

    return stax.serial(*network)