import torch
import torch.nn as nn
import copy
from quant_modules import TensorQuantizer, Conv2dQuantizer, LinearQuantizer, ActivationQuantizer
from quant_utils import quant_args


def quantize_model(model):
    """
    Recursively build a quantized copy of a pretrained single-precision model.

    ``nn.Conv2d`` and ``nn.Linear`` modules (exact type match) are replaced by
    ``Conv2dQuantizer`` / ``LinearQuantizer`` instances constructed from
    ``quant_args`` and initialised from the original layer via ``set_param``.
    ``nn.Sequential`` containers are rebuilt child by child. Any other module
    is deep-copied and its registered children are quantized recursively.
    The exception is children whose name contains ``'norm'``: they are kept
    as deep-copied. The input ``model`` is never modified.

    model: pretrained single-precision ``nn.Module``
    returns: a new module tree with conv/linear layers replaced
    """
    # Exact type checks (not isinstance): subclasses of Conv2d/Linear are not
    # converted here and instead fall through to the recursive branch.
    if type(model) is nn.Conv2d:
        quant_mod = Conv2dQuantizer(**quant_args)
        quant_mod.set_param(model)
        return quant_mod
    if type(model) is nn.Linear:
        quant_mod = LinearQuantizer(**quant_args)
        quant_mod.set_param(model)
        return quant_mod

    if type(model) is nn.Sequential:
        # Re-register each child under its original name. Rebuilding
        # positionally would rename children of an OrderedDict-built
        # Sequential to "0", "1", ..., which changes the state_dict keys.
        q_seq = nn.Sequential()
        for name, child in model.named_children():
            q_seq.add_module(name, quantize_model(child))
        # Keep the container's train/eval flag, as the deepcopy branch does.
        q_seq.training = model.training
        return q_seq

    q_model = copy.deepcopy(model)
    # Walk the registered children rather than dir(model). nn.Module.__dir__
    # drops names starting with a digit, so nn.ModuleList entries would never
    # be quantized. dir() also lists properties that return submodules;
    # setattr-ing those would register extra duplicate children.
    for name, child in model.named_children():
        if 'norm' not in name:
            setattr(q_model, name, quantize_model(child))
    return q_model

def set_first_last_layer(model, *, last_bit=8):
    """
    Adjust input quantization of the first and last quantized layers, in place.

    Collects every ``Conv2dQuantizer`` / ``LinearQuantizer`` in
    ``model.modules()`` traversal order. It then disables input quantization
    on the first one (``quant_input.is_enable = False``) and sets the input
    bit width of the last one to ``last_bit``. If there is exactly one
    quantized layer, both adjustments apply to it. Returns None.

    model: a model already processed by ``quantize_model``
    last_bit: input bit width for the last quantized layer; defaults to 8,
        the value that was previously hard-coded
    raises: ValueError if ``model`` contains no quantized conv/linear layer
    """
    quant_layers = [
        m for m in model.modules()
        if isinstance(m, (Conv2dQuantizer, LinearQuantizer))
    ]
    if not quant_layers:
        raise ValueError(
            "set_first_last_layer: model contains no Conv2dQuantizer or "
            "LinearQuantizer; call quantize_model first"
        )
    quant_layers[0].quant_input.is_enable = False
    # NOTE(review): a fresh (CPU) tensor replaces the existing value; if `bit`
    # is a device-resident buffer, confirm no device mismatch arises.
    quant_layers[-1].quant_input.bit = torch.tensor(last_bit)