from src.utils.ntk_computation import stax


def BasicBlock(channel_mismatch, planes, stride):
    """Build a two-convolution residual block as a stax layer.

    The residual branch is ReLU -> 3x3 conv (``stride``) -> ReLU -> 3x3 conv
    (stride 1), all with 'SAME' padding; a ReLU is applied before each conv.
    The shortcut branch is an identity when ``stride == 1`` and
    ``channel_mismatch`` is falsy, otherwise a strided 1x1 conv.  The two
    branches are summed; no activation is applied after the sum.

    Args:
        channel_mismatch: When truthy, force a 1x1 conv shortcut even if
            ``stride == 1``.
        planes: Output channel count of every convolution in the block.
        stride: Stride of the first 3x3 conv and of the 1x1 shortcut conv.

    Returns:
        A stax layer combining both branches.
    """
    strides = (stride, stride)
    residual = stax.serial(
        stax.Relu(),
        stax.Conv(planes, (3, 3), strides, padding='SAME'),
        stax.Relu(),
        stax.Conv(planes, (3, 3), (1, 1), padding='SAME'),
    )

    # A strided first conv changes the spatial size of the residual branch,
    # and channel_mismatch signals the caller expects a channel change, so in
    # either case the shortcut must be projected to match before summing.
    if channel_mismatch or stride != 1:
        shortcut = stax.Conv(planes, (1, 1), strides, padding='SAME')
    else:
        shortcut = stax.Identity()

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(residual, shortcut),
        stax.FanInSum(),
    )


def ResNet(block, num_blocks, num_classes=10, num_input_channels=3, widen_factor=4):
    """Assemble a residual network from ``block`` factories as a stax layer.

    Layout: a 7x7 stride-1 'SAME' stem conv with 64 channels, then one stage
    per entry of ``num_blocks`` (at most four; extra entries are ignored),
    then ``AvgPool((1, 1))``, ``Flatten`` and a ``Dense(num_classes, 1., 0.)``
    readout.

    Args:
        block: Factory called as ``block(channel_mismatch, planes, stride)``
            that returns a stax layer (e.g. ``BasicBlock``).
        num_blocks: Sequence with the number of blocks in each stage.
        num_classes: Output dimension of the final dense layer.
        num_input_channels: Not used when building the layers.
        widen_factor: Stage ``i`` (0-based) has ``16 * widen_factor * 2**i``
            channels; with the default of 4 the first stage matches the
            64-channel stem.

    Returns:
        A stax layer for the whole network.
    """
    base_width = 16 * widen_factor

    def _stage(planes, count, first_stride):
        # Only the first block of a stage may change stride; it is always
        # passed channel_mismatch=True (with BasicBlock this yields a 1x1
        # conv shortcut). The remaining count - 1 blocks use stride 1.
        blocks = [block(True, planes, first_stride)]
        blocks.extend(block(False, planes, 1) for _ in range(count - 1))
        return stax.serial(*blocks)

    net = [stax.Conv(64, (7, 7), (1, 1), padding='SAME')]
    for stage, count in enumerate(num_blocks[:4]):
        # The first stage keeps the input resolution; later stages use stride 2.
        first_stride = 1 if stage == 0 else 2
        net.append(_stage(base_width * 2 ** stage, count, first_stride))

    # NOTE(review): a (1, 1) pooling window looks like a no-op if this stax
    # follows Neural Tangents semantics (strides default to the window shape);
    # confirm whether global average pooling was intended here.
    # NOTE(review): Dense's 1. and 0. are presumably W_std and b_std — confirm
    # against the stax signature.
    net.extend([stax.AvgPool((1, 1)), stax.Flatten(), stax.Dense(num_classes, 1., 0.)])
    return stax.serial(*net)


def ResNet18(num_layers, depth, widen_factor, dropout_rate, num_classes,
             num_input_channels=3, norm_layer=None):
    """Build a ResNet-18-style network: four ``BasicBlock`` stages of 2 blocks.

    ``num_layers``, ``depth``, ``dropout_rate`` and ``norm_layer`` are not used
    by this constructor (presumably kept so all model constructors share one
    signature — confirm with callers).
    """
    stage_sizes = [2, 2, 2, 2]
    return ResNet(
        BasicBlock,
        stage_sizes,
        num_classes=num_classes,
        num_input_channels=num_input_channels,
        widen_factor=widen_factor,
    )


def ResNet34(num_layers, depth, widen_factor, dropout_rate, num_classes,
             num_input_channels=3, norm_layer=None):
    """Build a ResNet-34-style network: ``BasicBlock`` stages of 3, 4, 6, 3.

    ``num_layers``, ``depth``, ``dropout_rate`` and ``norm_layer`` are not used
    by this constructor (presumably kept so all model constructors share one
    signature — confirm with callers).
    """
    stage_sizes = [3, 4, 6, 3]
    return ResNet(
        BasicBlock,
        stage_sizes,
        num_classes=num_classes,
        num_input_channels=num_input_channels,
        widen_factor=widen_factor,
    )