import copy
import logging
from typing import Optional

import torch.nn as nn

from quant.quant_block_ldm import get_specials, BaseQuantBlock
from quant.quant_block_ldm import QuantResBlock, QuantAttentionBlock
from quant.quant_block_ldm import QuantQKMatMul, QuantSMVMatMul, QuantBasicTransformerBlock
from quant.quant_layer import QuantModule, StraightThrough
from ldm.modules.attention import BasicTransformerBlock

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class QuantModel(nn.Module):
    """Wrap a model and swap its layers and blocks for quantized counterparts.

    Construction rewrites the wrapped model in place, in two passes:

    1. ``quant_module_refactor`` replaces ``nn.Conv2d`` / ``nn.Conv1d`` /
       ``nn.Linear`` children with ``QuantModule``.
    2. ``quant_block_refactor`` replaces every child whose exact type is a key
       of the mapping returned by ``get_specials`` with the mapped class.
    """

    def __init__(self, model: nn.Module, weight_quant_params: Optional[dict] = None,
                 act_quant_params: Optional[dict] = None,
                 quant_skip=True, quant_conv_inout=True, **kwargs):
        """
        :param model: model to wrap; it is modified in place. Must expose
            ``in_channels``; ``image_size`` is mirrored when present.
        :param weight_quant_params: parameters handed to each weight quantizer
        :param act_quant_params: parameters handed to each activation
            quantizer; must contain the key ``'leaf_param'``
        :param quant_skip: forwarded to ``quant_module_refactor``
        :param quant_conv_inout: forwarded to ``quant_module_refactor``
        :param kwargs: only ``sm_abit`` is read (default 8)
        :raises KeyError: if ``act_quant_params`` has no ``'leaf_param'`` key
        """
        super().__init__()
        # Fresh dicts per call instead of shared mutable default arguments.
        weight_quant_params = {} if weight_quant_params is None else weight_quant_params
        act_quant_params = {} if act_quant_params is None else act_quant_params

        self.model = model
        self.sm_abit = kwargs.get('sm_abit', 8)
        self.in_channels = model.in_channels
        if hasattr(model, 'image_size'):
            self.image_size = model.image_size
        self.specials = get_specials(act_quant_params['leaf_param'])
        self.quant_module_refactor(self.model, weight_quant_params, act_quant_params, quant_skip, quant_conv_inout)
        self.quant_block_refactor(self.model, weight_quant_params, act_quant_params)

    @staticmethod
    def _act_params_8bit(act_quant_params: dict) -> dict:
        """Return a shallow copy of ``act_quant_params`` with ``n_bits`` forced to 8."""
        fixed_params = act_quant_params.copy()
        fixed_params['n_bits'] = 8
        return fixed_params

    def quant_module_refactor(self, module: nn.Module, weight_quant_params: Optional[dict] = None,
                              act_quant_params: Optional[dict] = None,
                              quant_skip=True, quant_conv_inout=True):
        """
        Recursively replace the normal layers (conv2D, conv1D, Linear etc.) with QuantModule.

        Children are dispatched on their attribute name first, then on type:

        * ``skip_connection``: wrapped with 8-bit activations (``fix_abits=True``)
          when it is a conv/linear layer; otherwise left untouched and not
          descended into.
        * ``out``: only element ``[2]`` of the container is wrapped, with 8-bit
          activations; its other elements are not visited.
        * ``input_blocks``: element ``[0][0]`` is wrapped with 8-bit
          activations, then the whole container is processed recursively with
          the regular parameters.
        * any other conv/linear layer: wrapped with the regular parameters.
        * ``StraightThrough``: skipped.
        * anything else: processed recursively.

        :param module: nn.Module with nn.Conv2d, nn.Conv1d, or nn.Linear in its children
        :param weight_quant_params: quantization parameters like n_bits for weight quantizer
        :param act_quant_params: quantization parameters like n_bits for activation quantizer
        :param quant_skip: passed along the recursion; not otherwise read here
        :param quant_conv_inout: passed along the recursion; not otherwise read here
        """
        if weight_quant_params is None:
            weight_quant_params = {}
        if act_quant_params is None:
            act_quant_params = {}

        quantizable_layers = (nn.Conv2d, nn.Conv1d, nn.Linear)

        for name, child_module in module.named_children():
            if name == 'skip_connection':  # fix abits=8 for skip connection of ResBlock
                if isinstance(child_module, quantizable_layers):
                    fixed_act_params = self._act_params_8bit(act_quant_params)
                    setattr(module, name, QuantModule(
                        child_module, weight_quant_params, fixed_act_params, fix_abits=True))
            elif name == 'out':  # fix abits=8 for conv out
                fixed_act_params = self._act_params_8bit(act_quant_params)
                child_module[2] = QuantModule(
                    child_module[2], weight_quant_params, fixed_act_params, fix_abits=True)
            elif name == 'input_blocks':  # fix abits=8 for conv in
                fixed_act_params = self._act_params_8bit(act_quant_params)
                child_module[0][0] = QuantModule(
                    child_module[0][0], weight_quant_params, fixed_act_params, fix_abits=True)
                # The remaining layers of the container get the regular parameters.
                self.quant_module_refactor(
                    child_module, weight_quant_params, act_quant_params, quant_skip, quant_conv_inout)
            elif isinstance(child_module, quantizable_layers):
                setattr(module, name, QuantModule(
                    child_module, weight_quant_params, act_quant_params))
            elif isinstance(child_module, StraightThrough):
                continue
            else:
                self.quant_module_refactor(
                    child_module, weight_quant_params, act_quant_params, quant_skip, quant_conv_inout)

    def quant_block_refactor(self, module: nn.Module, weight_quant_params: Optional[dict] = None,
                             act_quant_params: Optional[dict] = None):
        """
        Recursively replace children whose exact type is a key of ``self.specials``.

        A matched child is replaced by ``self.specials[type(child)]`` and is not
        descended into; unmatched children are processed recursively.

        :param module: module whose children are inspected (modified in place)
        :param weight_quant_params: only passed along the recursion
        :param act_quant_params: parameters handed to the replacement blocks
        :raises KeyError: if an attention block's ``smv_matmul`` / ``qkv_matmul``
            type is not a key of ``self.specials``
        """
        if weight_quant_params is None:
            weight_quant_params = {}
        if act_quant_params is None:
            act_quant_params = {}

        for name, child_module in module.named_children():
            child_type = type(child_module)
            if child_type not in self.specials:
                self.quant_block_refactor(child_module, weight_quant_params, act_quant_params)
                continue

            quant_cls = self.specials[child_type]
            if quant_cls is QuantBasicTransformerBlock:
                setattr(module, name, quant_cls(child_module, act_quant_params, sm_abit=self.sm_abit))
            elif quant_cls is QuantAttentionBlock:
                setattr(module, name, quant_cls(child_module, act_quant_params))
                # The two matmul helpers get their own deep-copied parameter
                # set with the 'mse' scale method.
                mse_act_params = copy.deepcopy(act_quant_params)
                mse_act_params['scale_method'] = 'mse'
                # NOTE(review): the replacements are installed on the ORIGINAL
                # block's ``attention`` after the wrapper was built; this
                # presumably relies on the wrapper sharing that submodule --
                # confirm against QuantAttentionBlock.
                attention = child_module.attention
                # QuantSMVMatMul
                smv_cls = self.specials[type(attention.smv_matmul)]
                attention.smv_matmul = smv_cls(mse_act_params, sm_abit=self.sm_abit)
                # QuantQKMatMul
                qk_cls = self.specials[type(attention.qkv_matmul)]
                attention.qkv_matmul = qk_cls(mse_act_params)
            else:
                setattr(module, name, quant_cls(child_module, act_quant_params))

    def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
        """Call ``set_quant_state`` on every QuantModule / BaseQuantBlock in the wrapped model.

        :param weight_quant: forwarded as the first argument
        :param act_quant: forwarded as the second argument
        """
        for m in self.model.modules():
            if isinstance(m, (QuantModule, BaseQuantBlock)):
                m.set_quant_state(weight_quant, act_quant)

    def forward(self, x, timesteps=None, context=None):
        """Delegate to the wrapped model with the same three positional arguments."""
        return self.model(x, timesteps, context)

    def set_grad_ckpt(self, grad_ckpt: bool):
        """Set ``checkpoint`` on every (quantized or plain) basic transformer block.

        Only those two block types are touched; all other modules are left
        unchanged.

        :param grad_ckpt: value assigned to each block's ``checkpoint`` attribute
        """
        for m in self.model.modules():
            if isinstance(m, (QuantBasicTransformerBlock, BasicTransformerBlock)):
                m.checkpoint = grad_ckpt

