# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from operator import attrgetter

import torch.nn as nn
import torch.distributed as dist

from ..pq.utils import get_layers, attrsetter
from .modules import IntConv2d, IntLinear, IntEmbedding, ActivationQuantizer


# Lookup table from each supported torch.nn module class to the scalar
# quantized class that replaces it in quantize_model_.
MAPPING = {
    nn.Linear: IntLinear,
    nn.Embedding: IntEmbedding,
    nn.Conv2d: IntConv2d,
}


def quantize_model_(model, p=0.2, bits=8, update_step=3000):
    """
    Replaces all supported modules with their scalar quantized counterpart and
    registers hooks to quantize the post-activations of those modules.

    Only modules whose exact class is a key of MAPPING are replaced. Any other
    module (including subclasses of the supported classes, which may override
    behavior the quantized counterpart does not reproduce) is logged as
    unsupported and left untouched.

    Args:
        - model: a nn.Module, modified in place
        - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations)
        - bits: number of bits
        - update_step: update quantization parameters every update_step steps

    Returns:
        The list of layer names produced by get_layers(model, "(.*?)"), i.e.
        every candidate layer that was examined. Note that this list also
        contains the layers that were skipped as unsupported.
    """

    # candidate layers: everything matched by the catch-all pattern
    quantized_layers = get_layers(model, "(.*?)")

    # book-keeping: only log from the master process (no distributed setup,
    # or rank 0). This does not depend on the layer, so compute it once.
    is_master_process = (not dist.is_initialized()) or dist.get_rank() == 0

    for layer in quantized_layers:

        # recover module
        module = attrgetter(layer)(model)

        # Exact-type lookup: an isinstance() test would also accept subclasses,
        # for which MAPPING has no entry (KeyError on direct indexing).
        QuantizedModule = MAPPING.get(type(module))
        if QuantizedModule is None:
            if is_master_process:
                logging.info(f"Module {module} not yet supported for quantization")
            continue

        if is_master_process:
            logging.info(f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}")

        # quantization params
        q_params = {"p": p, "update_step": update_step, "bits": bits, "method": "histogram", "counter": 0}

        # Instantiate the quantized counterpart without running its __init__,
        # then copy over the state of the original module followed by the
        # quantization params. The original module's __dict__ is only read,
        # never mutated.
        quantized_module = QuantizedModule.__new__(QuantizedModule)
        quantized_module.__dict__.update(module.__dict__)
        quantized_module.__dict__.update(q_params)

        # activation quantization: the instance is created for its side effect
        # (presumably it attaches a hook to quantized_module -- see
        # ActivationQuantizer), so the return value is not kept.
        ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram")

        # replace layer by its quantized counterpart
        attrsetter(layer)(model, quantized_module)

    # return the names of all candidate layers (see docstring)
    return quantized_layers
