from torch import nn
from .QuantConv2d import QuantConv2d
from .WaveletCompressedConvolution import DwtQuantConv2d1x1


def quantize_deeplabmobilev2(model, bits, act_bit):
    """Quantize ``model`` while leaving its first conv and last layer as they were.

    The two layers are captured before ``quantize_module`` walks the model
    and are written back to the same positions afterwards, so whatever the
    conversion put there is discarded. The result of ``quantize_module`` is
    not used; the model is expected to be updated in place.

    Args:
        model: model exposing ``backbone.low_level_features[0][0]`` and
            ``classifier.classifier[3]``.
        bits: forwarded to ``quantize_module`` as its ``bit`` argument.
        act_bit: forwarded to ``quantize_module`` unchanged.
    """
    kept_first, kept_last = (
        model.backbone.low_level_features[0][0],
        model.classifier.classifier[3],
    )

    quantize_module(model, bits, act_bit)

    # Re-resolve both paths after the conversion before restoring the layers.
    model.backbone.low_level_features[0][0] = kept_first
    model.classifier.classifier[3] = kept_last


def wavelet_deeplabmobilev2(model, level, compression, weight_bit, act_bit):
    """Apply ``wavelet_module`` to ``model`` but keep its first conv and last layer.

    The two layers are captured before the conversion and written back to
    the same positions afterwards, so they end up unconverted. The result of
    ``wavelet_module`` is not used; the model is expected to be updated in
    place.

    Args:
        model: model exposing ``backbone.low_level_features[0][0]`` and
            ``classifier.classifier[3]``.
        level: forwarded to ``wavelet_module`` unchanged.
        compression: forwarded to ``wavelet_module`` unchanged.
        weight_bit: forwarded to ``wavelet_module`` unchanged.
        act_bit: forwarded to ``wavelet_module`` unchanged.
    """
    kept_first, kept_last = (
        model.backbone.low_level_features[0][0],
        model.classifier.classifier[3],
    )

    wavelet_module(model, level, compression, weight_bit, act_bit)

    # Re-resolve both paths after the conversion before restoring the layers.
    model.backbone.low_level_features[0][0] = kept_first
    model.classifier.classifier[3] = kept_last


def quantize_module(module, bit, act_bit):
    """Recursively replace ``nn.Conv2d`` layers with ``QuantConv2d``.

    A ``nn.Conv2d`` is rebuilt as a ``QuantConv2d`` with the same channel
    counts, kernel size, stride, padding, dilation and groups, and it reuses
    the original ``weight`` and ``bias`` objects. Any other module is kept
    as-is and its children are re-registered on it after conversion.

    Args:
        module: root of the module tree to convert.
        bit: passed to the ``QuantConv2d`` constructor (presumably the
            weight bit-width -- confirm against ``QuantConv2d``).
        act_bit: passed to the ``QuantConv2d`` constructor (presumably the
            activation bit-width -- confirm against ``QuantConv2d``).

    Returns:
        The replacement for ``module``; ``module`` itself when it is not a
        ``nn.Conv2d``.
    """
    if isinstance(module, nn.Conv2d):
        has_bias = module.bias is not None
        replacement = QuantConv2d(
            module.in_channels, module.out_channels, module.kernel_size,
            module.stride, module.padding, module.dilation, module.groups,
            has_bias, bit, act_bit,
        )
        # Share the existing parameters rather than the freshly created ones.
        replacement.weight = module.weight
        replacement.bias = module.bias
    else:
        replacement = module

    for child_name, child in module.named_children():
        converted_child = quantize_module(child, bit, act_bit)
        replacement.add_module(child_name, converted_child)
    return replacement


def change_module_bits(module, bit, act_bit):
    """Call ``change_bit(bit, act_bit)`` on every quantized layer under ``module``.

    The walk is depth-first and includes ``module`` itself. Only instances
    of ``QuantConv2d`` or ``DwtQuantConv2d1x1`` are touched; everything else
    is merely traversed. Nothing is returned.
    """
    quantized_types = (QuantConv2d, DwtQuantConv2d1x1)
    if isinstance(module, quantized_types):
        module.change_bit(bit, act_bit)

    for _, submodule in module.named_children():
        change_module_bits(submodule, bit, act_bit)


def wavelet_module(module, level, compression, weight_bit, act_bit):
    """Recursively swap ``nn.Conv2d`` layers for quantized / wavelet variants.

    A convolution whose kernel is larger than 1 along any axis becomes a
    ``QuantConv2d``; a true 1x1 convolution becomes a ``DwtQuantConv2d1x1``.
    Both replacements reuse the original weight and bias, and ``act_alpha``
    is carried over when the source layer is already a ``QuantConv2d``.
    Non-convolution modules are kept and their children are re-registered
    on them after conversion.

    Args:
        module: root of the module tree to convert.
        level: passed to the ``DwtQuantConv2d1x1`` constructor.
        compression: passed to the ``DwtQuantConv2d1x1`` constructor.
        weight_bit: passed to both replacement constructors (presumably the
            weight bit-width -- confirm against those classes).
        act_bit: passed to both replacement constructors (presumably the
            activation bit-width -- confirm against those classes).

    Returns:
        The replacement for ``module``; ``module`` itself when it is not a
        ``nn.Conv2d``.
    """
    new_module = module
    if isinstance(module, nn.Conv2d):
        # Inspect every kernel axis, not just the first: a (1, k) kernel is
        # not a 1x1 convolution and its weight cannot be squeezed below.
        if any(size > 1 for size in module.kernel_size):
            new_module = QuantConv2d(module.in_channels,
                                     module.out_channels,
                                     module.kernel_size,
                                     module.stride,
                                     module.padding,
                                     module.dilation,
                                     module.groups,
                                     module.bias is not None,
                                     weight_bit,
                                     act_bit)
            new_module.weight = module.weight
            new_module.bias = module.bias
        else:
            # NOTE(review): only the first component of stride, padding and
            # dilation is forwarded, so they are assumed symmetric here.
            new_module = DwtQuantConv2d1x1(module.in_channels,
                                           module.out_channels,
                                           level,
                                           compression,
                                           weight_bit,
                                           act_bit,
                                           module.stride[0],
                                           module.padding[0],
                                           module.dilation[0],
                                           module.groups,
                                           module.bias is not None)
            # squeeze() drops the trailing size-1 kernel axis but returns a
            # plain tensor, so it is wrapped in a Parameter again.
            new_module.weight = nn.Parameter(module.weight.squeeze(-1))
            new_module.bias = module.bias
        if isinstance(module, QuantConv2d):
            # Keep the activation parameter of an already-quantized layer.
            new_module.act_alpha = module.act_alpha
    for name, child in module.named_children():
        new_module.add_module(name, wavelet_module(child, level, compression, weight_bit, act_bit))
    return new_module


def change_module_wt_params(module, compression, level):
    """Call ``change_wt_params(compression, level)`` on every wavelet layer.

    The walk is depth-first and includes ``module`` itself. Only
    ``DwtQuantConv2d1x1`` instances are updated; all other modules are
    merely traversed. Nothing is returned.
    """
    is_wavelet_layer = isinstance(module, DwtQuantConv2d1x1)
    if is_wavelet_layer:
        module.change_wt_params(compression, level)

    for _, submodule in module.named_children():
        change_module_wt_params(submodule, compression, level)
