import torch
from torch import nn

from .linear import *
from .conv2d import *


def replace_layers(model):
    """Recursively replace ``nn.Linear`` / ``nn.Conv2d`` submodules with compact layers.

    Walks the children of ``model`` depth-first. Every ``nn.Linear`` child is
    swapped for a ``CompactLinear`` and every ``nn.Conv2d`` child for a
    ``CompactConv2d``. The original weight (and bias, if any) is copied into
    the new layer's ``compact_weight`` parameter. ``model`` is modified in
    place; nothing is returned.

    Each replacement is moved to the replaced layer's device and dtype, and
    given its train/eval mode. This keeps a model that is already on a GPU,
    in reduced precision, or in eval mode consistent after the swap.

    Args:
        model: the ``nn.Module`` whose submodules are replaced.

    Raises:
        ValueError: if an ``nn.Conv2d`` has a non-square kernel, which
            ``CompactConv2d`` cannot represent (it is only given one side).
    """
    for name, module in model.named_children():
        # Recurse first so layers inside nested containers are rewritten too.
        # For a leaf module this is a no-op (it has no children to visit).
        replace_layers(module)

        if isinstance(module, nn.Linear):
            setattr(model, name, _compact_from_linear(module))
        elif isinstance(module, nn.Conv2d):
            setattr(model, name, _compact_from_conv2d(module))


def _match_source_state(compact, source):
    """Put ``compact`` on ``source``'s device/dtype and copy its train/eval mode."""
    compact.to(device=source.weight.device, dtype=source.weight.dtype)
    compact.train(source.training)
    return compact


def _compact_from_linear(linear):
    """Return a ``CompactLinear`` initialised with a copy of ``linear``'s parameters."""
    has_bias = linear.bias is not None
    compact = CompactLinear(linear.in_features, linear.out_features, has_bias)
    _match_source_state(compact, linear)

    in_features = compact.in_features
    with torch.no_grad():
        if has_bias:
            # The weight occupies the first ``in_features`` columns of
            # ``compact_weight``; the bias is the column right after them.
            compact.compact_weight[:, :in_features].copy_(linear.weight)
            compact.compact_weight[:, in_features].copy_(linear.bias)
        else:
            compact.compact_weight.copy_(linear.weight)
    return compact


def _compact_from_conv2d(conv):
    """Return a ``CompactConv2d`` initialised with a copy of ``conv``'s parameters.

    Raises:
        ValueError: if ``conv`` has a non-square kernel. ``CompactConv2d`` is
            only given ``kernel_size[0]``. Copying e.g. a ``(3, 1)`` kernel into
            a ``(3, 3)`` view would otherwise broadcast silently instead of failing.
    """
    kernel_h, kernel_w = conv.kernel_size
    if kernel_h != kernel_w:
        raise ValueError(
            f"CompactConv2d requires a square kernel, got kernel_size={conv.kernel_size}"
        )

    has_bias = conv.bias is not None
    # NOTE(review): conv.padding_mode is not forwarded to CompactConv2d, so a
    # non-'zeros' padding mode is presumably not preserved — confirm.
    compact = CompactConv2d(conv.in_channels, conv.out_channels, kernel_h,
                            conv.stride, conv.padding, conv.dilation, conv.groups,
                            has_bias)
    _match_source_state(compact, conv)

    # View the flat weight storage with the source weight's own shape
    # (out, in/groups, kH, kW). The copy is then an exact element-for-element
    # transfer: a layout mismatch raises instead of broadcasting.
    weight_shape = conv.weight.shape
    with torch.no_grad():
        if has_bias:
            # Bias is stored as the last column of ``compact_weight``.
            compact.compact_weight[:, :-1].view(weight_shape).copy_(conv.weight)
            compact.compact_weight[:, -1].copy_(conv.bias)
        else:
            compact.compact_weight.view(weight_shape).copy_(conv.weight)
    return compact
