import torch.nn as nn

from .quant_layer import DwtQuantConv2d1x1, QuantConv2d, DwtDenseQuantConv2d1x1

# Settings for layers that operate on low spatial resolution feature maps.
# NOTE(review): the first two constants are not referenced in this part of the file; presumably callers
# pass them as `compression` / `level` to modify_resnext_low_spatial_sized_compression - confirm.
LOW_SPATIAL_SIZE_COMPRESSION = 0.5
LOW_SPATIAL_SIZE_LEVELS = 2
# Activation bit-width that quantize_model_convolutions substitutes for convolutions with kernel size > 1
# or stride > 1 whenever the requested activation bit-width is less than or equal to this value.
LOW_SPATIAL_SIZE_QUANT_ACT_BITS = 4


def quantize_model_convolutions(model: nn.Module, weight_bits: int, act_bits: int, signed: bool = False,
                                skip_first: bool = True):
    """
    Replaces all of the model's convolutions with quantized convolutions. Modifies the model in-place.

    The sub-modules are traversed recursively in registration order. The ``skip_first`` state is threaded
    through the recursion via the return value, so that only the very first convolution of the whole network
    is left untouched, even when that convolution lives inside a nested sub-module.

    :param model: A pytorch model
    :param weight_bits: Number of bits for weights
    :param act_bits: Number of bits for activations
    :param signed: Whether or not to use signed quantization
    :param skip_first: Whether or not to skip the first convolution in the network
    :return: The updated ``skip_first`` flag: True only if the first convolution still has to be skipped
             (i.e. ``skip_first`` was True and no convolution was encountered under ``model``)
    """
    # Iterate over a snapshot because entries of `_modules` are replaced inside the loop
    for name, layer in list(model._modules.items()):
        if layer is None:
            # A sub-module slot may be registered as None (e.g. add_module(name, None)); nothing to quantize
            continue

        if not isinstance(layer, nn.Conv2d):
            # Propagate the flag back so sibling sub-modules know the first convolution was already consumed
            skip_first = quantize_model_convolutions(layer, weight_bits, act_bits, signed, skip_first)
            continue

        if skip_first:
            skip_first = False
            continue

        layer_act_bits = act_bits
        if act_bits <= LOW_SPATIAL_SIZE_QUANT_ACT_BITS and (layer.kernel_size[0] > 1 or layer.stride[0] > 1):
            # In order to fairly compare standard quantization to our wavelet compression scheme,
            # in non-waveletable layers (for example 3x3 convs), we don't use a very low-precision quantization
            layer_act_bits = LOW_SPATIAL_SIZE_QUANT_ACT_BITS

        quantized_layer = QuantConv2d(layer.in_channels, layer.out_channels, layer.kernel_size,
                                      layer.stride, layer.padding, layer.dilation, layer.groups,
                                      layer.bias is not None, weight_bits, layer_act_bits, signed)
        if isinstance(layer, QuantConv2d):
            # Re-quantizing an already quantized layer: carry over its activation clipping parameter
            quantized_layer.act_alpha = layer.act_alpha

        # Share the original parameters with the replacement layer
        quantized_layer.weight = layer.weight
        quantized_layer.bias = layer.bias
        model._modules[name] = quantized_layer

    return skip_first


def quantize_wavelet_model_convolutions(model: nn.Module, weight_bits: int, act_bits: int, quant_signed: bool = False,
                                        compression: float = 0.25, levels: int = 3, use_ste: bool = False,
                                        skip_first: bool = True):
    """
    Replaces all of the model's convolutions with wavelet quantized convolution. Modifies the model in-place.

    Convolutions with kernel size > 1 or stride > 1, and any convolution registered under the name
    ``plane_params``, are replaced by a standard ``QuantConv2d``. Every other convolution is replaced by a
    wavelet layer (``DwtDenseQuantConv2d1x1`` when ``use_ste`` is set, ``DwtQuantConv2d1x1`` otherwise) that
    uses at least 8 activation bits.

    The sub-modules are traversed recursively in registration order. The ``skip_first`` state is threaded
    through the recursion via the return value, so that only the very first convolution of the whole network
    is left untouched, even when that convolution lives inside a nested sub-module.

    :param model: A pytorch model
    :param weight_bits: Number of bits for weights
    :param act_bits: Number of bits for activations
    :param quant_signed: Whether or not to use signed quantization
    :param compression: Wavelet compression rate
    :param levels: Number of levels for the wavelet transform
    :param use_ste: Use wavelet STE using dense wavelet (to improve the training process)
    :param skip_first: Whether or not to skip the first convolution in the network
    :return: The updated ``skip_first`` flag: True only if the first convolution still has to be skipped
             (i.e. ``skip_first`` was True and no convolution was encountered under ``model``)
    """
    # Iterate over a snapshot because entries of `_modules` are replaced inside the loop
    for name, layer in list(model._modules.items()):
        if layer is None:
            # A sub-module slot may be registered as None (e.g. add_module(name, None)); nothing to quantize
            continue

        if not isinstance(layer, nn.Conv2d):
            # Propagate the flag back so sibling sub-modules know the first convolution was already consumed
            skip_first = quantize_wavelet_model_convolutions(layer, weight_bits, act_bits, quant_signed,
                                                             compression, levels, use_ste, skip_first)
            continue

        if skip_first:
            skip_first = False
            continue

        if layer.kernel_size[0] > 1 or layer.stride[0] > 1 or name == 'plane_params':
            # Non-waveletable layer: fall back to standard quantization
            quantized_layer = QuantConv2d(layer.in_channels, layer.out_channels, layer.kernel_size,
                                          layer.stride, layer.padding, layer.dilation, layer.groups,
                                          layer.bias is not None, weight_bits, act_bits, quant_signed)
            quantized_layer.weight = layer.weight
        else:
            if act_bits < 8:
                # The compression factor of the wavelet layer will reduce the number of BOPs,
                # so we use 8bit activation quantization in every layer where the wavelet is applied
                print(f"Applying Wavelet transform with 8 bits instead of {act_bits}")
            wavelet_act_bits = max(act_bits, 8)
            if use_ste:
                quantized_layer = DwtDenseQuantConv2d1x1(layer.in_channels, layer.out_channels, levels,
                                                         compression, weight_bits, wavelet_act_bits,
                                                         layer.stride, layer.padding, layer.dilation,
                                                         layer.groups, layer.bias is not None)
                quantized_layer.weight = layer.weight
            else:
                # This variant receives scalar stride/padding/dilation and a weight without the last
                # (size-1) kernel dimension
                quantized_layer = DwtQuantConv2d1x1(layer.in_channels, layer.out_channels, levels, compression,
                                                    weight_bits, wavelet_act_bits, layer.stride[0],
                                                    layer.padding[0],
                                                    layer.dilation[0], layer.groups, layer.bias is not None)
                quantized_layer.weight = nn.Parameter(layer.weight.squeeze(-1))

        if isinstance(layer, QuantConv2d):
            # Re-quantizing an already quantized layer: carry over its activation clipping parameter
            quantized_layer.act_alpha = layer.act_alpha

        quantized_layer.bias = layer.bias
        model._modules[name] = quantized_layer

    return skip_first


def modify_resnext_low_spatial_sized_compression(model: nn.Module, compression: float, level: int, quant_act_bits: int,
                                                 spatial_factor: int = 2):
    """
    Recursively walks a ResNeXt backbone and re-configures the layers that are reached once the tracked
    downsampling factor is at least 16.

    The factor is doubled every time a sub-module named "downsample" is visited. From that point on, wavelet
    layers get their compression/levels updated and standard quantized convolutions get their activation
    bit-width updated. Every other sub-module is descended into.

    :param model: The ResNeXt model
    :param compression: The wavelet compression to use
    :param level: The number of wavelet transform levels
    :param quant_act_bits: The number of bits, if using standard quantization
    :param spatial_factor: A helper parameter that keeps track of the downsampling factor
    :return: The downsampling factor after visiting all of the model's sub-modules
    """
    for child_name, child in model._modules.items():
        if child_name == "downsample":
            spatial_factor = spatial_factor * 2

        low_resolution = spatial_factor >= 16

        # Wavelet layers: switch to the requested compression rate and number of levels
        if low_resolution and isinstance(child, (DwtQuantConv2d1x1, DwtDenseQuantConv2d1x1)):
            child.update_compression(compression, level)
            continue

        # Standard quantized convolutions: switch to the requested activation bit-width
        if low_resolution and isinstance(child, QuantConv2d):
            child.update_act_bit(quant_act_bits)
            continue

        # Anything else is traversed; the factor accumulated inside is carried over to the next sibling
        spatial_factor = modify_resnext_low_spatial_sized_compression(
            child, compression, level, quant_act_bits, spatial_factor
        )

    return spatial_factor
