import logging
import torch
import torch.nn as nn
from quant.quant_layer import QuantModule, UniformAffineQuantizer
from quant.quant_block_ldm import BaseQuantBlock
from quant.adaptive_rounding import AdaRoundQuantizer

logger = logging.getLogger(__name__)

def _convert_quant_module_to_adaround(qm):
    """Swap the weight quantizer(s) of a single QuantModule for AdaRoundQuantizers.

    Each new AdaRoundQuantizer wraps the quantizer it replaces (``uaq=``) and is
    built from ``qm.org_weight.data`` with ``round_mode='learned_hard_sigmoid'``.
    When ``qm.split != 0`` the weight is sliced along dim 1 at ``qm.split``:
    ``weight_quantizer`` receives ``[:, :split]`` and ``weight_quantizer_0``
    receives ``[:, split:]``.
    """
    weight = qm.org_weight.data
    if qm.split != 0:
        qm.weight_quantizer = AdaRoundQuantizer(uaq=qm.weight_quantizer, round_mode='learned_hard_sigmoid',
                                                weight_tensor=weight[:, :qm.split, ...])
        qm.weight_quantizer_0 = AdaRoundQuantizer(uaq=qm.weight_quantizer_0, round_mode='learned_hard_sigmoid',
                                                  weight_tensor=weight[:, qm.split:, ...])
    else:
        qm.weight_quantizer = AdaRoundQuantizer(uaq=qm.weight_quantizer, round_mode='learned_hard_sigmoid',
                                                weight_tensor=weight)


def convert_adaround(q_model):
    """Recursively replace weight quantizers in ``q_model`` with AdaRoundQuantizers.

    Each direct child of ``q_model`` is handled as follows:
      * ``QuantModule``: converted, unless ``ignore_reconstruction is True``.
      * ``BaseQuantBlock``: unless ``ignore_reconstruction is True``, every
        QuantModule nested anywhere inside it is converted. The nested modules'
        own ``ignore_reconstruction`` flags are not consulted.
      * anything else: recursed into.

    Modifies ``q_model`` in place and returns None.
    """
    for child in q_model.children():
        if isinstance(child, QuantModule):
            if child.ignore_reconstruction is not True:
                _convert_quant_module_to_adaround(child)
        elif isinstance(child, BaseQuantBlock):
            if child.ignore_reconstruction is not True:
                # Take a snapshot first. The conversion reassigns submodules, so
                # the tree must not be mutated while a generator walks it.
                # A split QuantModule is an option for the ResBlock skip connection.
                for sub_module in list(child.modules()):
                    if isinstance(sub_module, QuantModule):
                        _convert_quant_module_to_adaround(sub_module)
        else:
            convert_adaround(child)


def dequantize_weight_model(q_model):
    """Call ``dequantize_weight()`` on every QuantModule found in ``q_model.model``.

    Prints a progress message first. Modifies the modules in place and returns None.
    """
    print("dequantize weight")
    quant_layers = (layer for layer in q_model.model.modules() if isinstance(layer, QuantModule))
    for layer in quant_layers:
        layer.dequantize_weight()

def check_quant_skip(q_model):
    """Report whether any skip-connection layer of ``q_model.model`` is a QuantModule.

    A layer counts as a skip connection when its qualified module name contains
    ``'nin_shortcut'`` or ``'skip_connection'``.

    Returns:
        True if at least one such layer is a QuantModule, otherwise False.
    """
    skip_tags = ('nin_shortcut', 'skip_connection')
    return any(
        isinstance(layer, QuantModule)
        for layer_name, layer in q_model.model.named_modules()
        if any(tag in layer_name for tag in skip_tags)
    )


def _run_calibration_forward(q_model, cali_data, cond):
    """Run one forward pass of ``q_model`` on calibration data moved to CUDA.

    ``cali_data`` is unpacked as ``(xs, ts)`` when ``cond`` is false and as
    ``(xs, ts, cs)`` when it is true. Returns the model output; callers here discard it.
    """
    if cond:
        cali_xs, cali_ts, cali_cs = cali_data
        return q_model(cali_xs.cuda(), cali_ts.cuda(), cali_cs.cuda())
    cali_xs, cali_ts = cali_data
    return q_model(cali_xs.cuda(), cali_ts.cuda())


def _wrap_quantizer_params(model, include_act=False):
    """Temporarily turn quantizer ``zero_point``/``delta`` values into ``nn.Parameter``s.

    ``zero_point`` and ``delta`` are always wrapped for every AdaRoundQuantizer.
    When ``include_act`` is true, every UniformAffineQuantizer with a non-None
    ``zero_point`` has it wrapped too. A non-tensor (scalar) zero point is first
    converted with ``torch.tensor(float(...))``.
    Presumably this lets ``load_state_dict`` match these values against
    checkpoint keys; ``_unwrap_quantizer_params`` reverses it.
    """
    for m in model.modules():
        if isinstance(m, AdaRoundQuantizer):
            m.zero_point = nn.Parameter(m.zero_point)
            m.delta = nn.Parameter(m.delta)
        elif include_act and isinstance(m, UniformAffineQuantizer) and m.zero_point is not None:
            zero_point = m.zero_point
            if not torch.is_tensor(zero_point):
                zero_point = torch.tensor(float(zero_point))
            m.zero_point = nn.Parameter(zero_point)


def _unwrap_quantizer_params(model, include_act=False):
    """Store parameters wrapped by ``_wrap_quantizer_params`` back as plain attributes.

    AdaRoundQuantizer ``zero_point``/``delta`` become their ``.data`` tensors.
    When ``include_act`` is true, a non-None UniformAffineQuantizer ``zero_point``
    becomes a Python int.

    Raises:
        ValueError: if a UniformAffineQuantizer zero point is not integral.
    """
    for m in model.modules():
        if isinstance(m, AdaRoundQuantizer):
            for attr in ("zero_point", "delta"):
                data = getattr(m, attr).data
                delattr(m, attr)
                setattr(m, attr, data)
        elif include_act and isinstance(m, UniformAffineQuantizer) and m.zero_point is not None:
            zero_data = m.zero_point.item()
            # Explicit check rather than `assert`, so it still runs under `python -O`.
            if int(zero_data) != zero_data:
                raise ValueError(f"UniformAffineQuantizer zero_point is not an integer: {zero_data}")
            delattr(m, "zero_point")
            m.zero_point = int(zero_data)


def resume_cali_model(q_model, ckpt_path, cali_data, quant_act=False, act_quant_mode='dynamic', cond=False, split=False):
    """Restore a calibrated (AdaRound) quantized model from a checkpoint, in place.

    Steps:
      1. Load the state dict at ``ckpt_path`` onto the CPU.
      2. Initialize the weight quantizers. When skip-connection layers are
         quantized and ``split`` is true, this runs a calibration forward pass.
         Otherwise it calls ``gen_delta_before_load_calibrated_model`` on each
         weight-quantized QuantModule.
      3. Convert the weight quantizers to AdaRound and load the checkpoint.
         Before loading, activation-quantizer keys are dropped, and so are
         skip-connection weight-quantizer keys when skip layers are not quantized.
      4. If ``quant_act`` is true and ``act_quant_mode == 'qdiff'``, initialize
         the activation quantizers with a calibration forward pass, then reload
         the full checkpoint.

    Args:
        q_model: Quantized model wrapper. Must expose ``.model``,
            ``set_quant_state`` and ``load_state_dict``.
        ckpt_path: Path of the checkpoint file passed to ``torch.load``.
        cali_data: ``(xs, ts)``, or ``(xs, ts, cs)`` when ``cond`` is true.
            The tensors are moved to CUDA, so a GPU is required when a
            calibration pass runs.
        quant_act: Whether to also restore activation quantization.
        act_quant_mode: Activation quantizers are only restored for ``'qdiff'``.
        cond: Whether the model takes a conditioning input ``cs``.
        split: Enables the calibration-forward initialization path for
            quantized skip connections.
    """
    print("Loading quantized model checkpoint")
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # True when any nin_shortcut / skip_connection layer is a QuantModule.
    quant_skip = check_quant_skip(q_model)

    print("Initializing weight quantization parameters")
    q_model.set_quant_state(True, False)
    unet = q_model.model
    if quant_skip and split:
        print('run init')
        _run_calibration_forward(q_model, cali_data, cond)
    else:
        print('gen init')
        for module in unet.modules():
            if isinstance(module, QuantModule) and module.use_weight_quant:
                if module.split != 0:
                    module.weight_quantizer.gen_delta_before_load_calibrated_model(module.weight[:, :module.split, ...])
                    # weight_quantizer_0 covers the channels after the split,
                    # the same slice convert_adaround gives it.
                    module.weight_quantizer_0.gen_delta_before_load_calibrated_model(module.weight[:, module.split:, ...])
                else:
                    module.weight_quantizer.gen_delta_before_load_calibrated_model(module.weight)

    print("adaround conversion")
    # Change the weight quantizers from uniform to AdaRound.
    convert_adaround(q_model)
    _wrap_quantizer_params(q_model.model)

    # Skip layers are not quantized in this model, so drop their weight-quantizer states.
    # A single pass means a key matching both substrings is only deleted once.
    if not quant_skip:
        skip_keys = [
            key for key in ckpt
            if ('nin_shortcut' in key or 'skip_connection' in key) and 'weight_quantizer' in key
        ]
        for key in skip_keys:
            del ckpt[key]

    # Drop activation-quantizer states for now; they are restored below for 'qdiff'.
    # NOTE(review): plain substring match, so any key containing "act" is removed.
    for key in [key for key in ckpt if "act" in key]:
        del ckpt[key]

    print("load weight ckpt")
    q_model.load_state_dict(ckpt)
    q_model.set_quant_state(weight_quant=True, act_quant=False)
    _unwrap_quantizer_params(q_model.model)

    if quant_act and act_quant_mode == 'qdiff':
        print("Initializing act quantization parameters")
        q_model.set_quant_state(True, True)
        _run_calibration_forward(q_model, cali_data, cond)
        print("Loading quantized model checkpoint again")

        _wrap_quantizer_params(q_model.model, include_act=True)
        # Reload, because the copy above had keys deleted from it.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        q_model.load_state_dict(ckpt)
        q_model.set_quant_state(weight_quant=True, act_quant=True)
        _unwrap_quantizer_params(q_model.model, include_act=True)
