import sys
from HAWQ.utils.models.q_resnet import Q_ResNet50_detr, Q_ResNet101_detr
# Q_input_proj, Q_query_embed,QuantMultiheadAttention,Q_class_embed,Q_bbox_embed
from HAWQ.utils.models.q_transformer import *
from HAWQ.utils.models.q_segmentation import *
from HAWQ.bit_config import *

# from detr.pyhessian_detr import Logger
import logging
from logging import handlers
import pdb


class ARGs:
    """Default quantizer settings, usable as a lightweight stand-in for parsed CLI args.

    Only the options used when configuring individual quantized modules are
    provided. The ``quant_*`` switches, ``masks`` and ``ILP`` that ``letsquant_``
    also reads are NOT set here and must be added by the caller.
    """

    def __init__(self):
        defaults = (
            ('backbone', 'resnet50'),
            ('quant_scheme', 'detr8w8a'),
            ('bias_bit', 32),
            ('channel_wise', True),
            ('act_percentile', 0),
            ('act_range_momentum', 0.99),
            ('weight_percentile', 0),
            ('fix_BN', True),
            ('fix_BN_threshold', None),
            ('checkpoint_iter', -1),
            ('fixed_point_quantization', False),
        )
        for option, value in defaults:
            setattr(self, option, value)


def letsquant(model, log, args=None):
    """Public entry point: quantize ``model`` according to ``args``.

    Forwards unchanged to :func:`letsquant_`, which does the work, returns the
    (in-place modified) model and raises ``NotImplementedError`` when ``args``
    is None.
    """
    quantized_model = letsquant_(model, log, args=args)
    return quantized_model


def _swap_in_quant_modules(model, args):
    """Replace the sub-modules selected by the ``args.quant_*`` switches with quantized wrappers, in place.

    When ``args.masks`` is true, ``model`` is the segmentation wrapper. The
    detector parts (class/bbox heads, transformer, backbone, input_proj) are
    reached through ``model.detr``, while ``bbox_attention`` and ``mask_head``
    live on ``model`` itself.
    """
    # Read every switch up front so a missing option fails before anything is modified.
    quant_class_embed = args.quant_class_embed
    quant_bbox_embed = args.quant_bbox_embed
    quant_input_proj = args.quant_input_proj
    quant_backbone = args.quant_backbone
    quant_encoder = args.quant_encoder
    quant_decoder = args.quant_decoder
    quant_bbox_atten = args.quant_bbox_atten
    quant_mask_head = args.quant_mask_head

    detr = getattr(model, 'detr') if args.masks else model
    transformer = detr.transformer

    # The heads belong to the detector, so in the segmentation case they must be
    # taken from (and put back on) ``model.detr``, not the wrapper.
    if quant_class_embed:
        detr.class_embed = Q_class_embed(detr.class_embed)
    if quant_bbox_embed:
        detr.bbox_embed = Q_bbox_embed(detr.bbox_embed)

    if quant_encoder:
        encoder_layers = transformer.encoder.layers
        # Iterate a snapshot because the container is written to inside the loop.
        for idx, layer in enumerate(list(encoder_layers)):
            setattr(encoder_layers, str(idx), Q_TransformerEncoderLayer(layer))

    if quant_decoder:
        decoder_layers = transformer.decoder.layers
        for idx, layer in enumerate(list(decoder_layers)):
            setattr(decoder_layers, str(idx), Q_TransformerDecoderLayer(layer))

    if quant_backbone:
        # backbone[0] is the ResNet body; backbone[1] (left untouched) is not quantized.
        backbone_without_pos = detr.backbone[0]
        q_resnet = Q_ResNet101_detr if args.backbone == 'resnet101' else Q_ResNet50_detr
        setattr(detr.backbone, '0', q_resnet(args, backbone_without_pos))

    if quant_input_proj:
        detr.input_proj = Q_input_proj(detr.input_proj)

    if quant_bbox_atten:
        bbox_atten = model.bbox_attention
        # NOTE(review): the quantized attention map is built from hyper-parameters
        # only. Unlike the other wrappers it never receives the original module,
        # so its weights are presumably freshly initialised. Confirm this is intended.
        model.bbox_attention = Q_bbox_atten(
            bbox_atten.query_dim, bbox_atten.hidden_dim, bbox_atten.num_heads,
            dropout=bbox_atten.dropout_p)

    if quant_mask_head:
        model.mask_head = Q_mask_head(model.mask_head)


def _bit_config_key(module_name):
    """Map a ``named_modules()`` path to its ``bit_config`` key and a force-quantize flag.

    Backbone paths have their ``backbone.0.`` / ``detr.backbone.0.`` prefix
    removed, so they are looked up relative to the ResNet body. Modules whose
    path contains ``'quant'`` are quantized even without a config entry when
    they sit under the transformer encoder/decoder (with or without the
    ``detr.`` prefix), ``bbox_attention`` or ``mask_head``.

    Returns:
        ``(key, need_quant)``.
    """
    need_quant = "quant" in module_name and module_name.startswith((
        "transformer.encoder", "detr.transformer.encoder",
        "transformer.decoder", "detr.transformer.decoder"))

    if module_name.startswith("backbone"):
        # without segmentation
        key = module_name.replace("backbone.0.", "")
    elif module_name.startswith("detr.backbone"):
        # with segmentation
        key = module_name.replace("detr.backbone.0.", "")
    else:
        key = module_name

    if "quant" in key and key.startswith(("bbox_attention", "mask_head")):
        need_quant = True
    return key, need_quant


def _collect_ilp_act_bits(model, bit_config, share_one_bit):
    """Look up activation bit-widths of named activation layers in an ILP ``bit_config``.

    Only modules whose path contains ``'act'`` and that carry both
    ``activation_bit`` and ``full_precision_flag`` are considered.

    Returns:
        ``(shared_bit, per_module_bits)``.

        - With ``share_one_bit``: ``shared_bit`` is the configured width of the
          first such layer, or None if there is no such layer or it has no
          entry. ``per_module_bits`` is empty.
        - Otherwise: ``shared_bit`` is None and ``per_module_bits`` maps each
          layer path to its width. A layer without an exact entry falls back to
          the value of a config key contained in its path (the last match wins).
    """
    per_module_bits = {}
    for name, m in model.named_modules():
        if 'act' not in name:
            continue
        if not (hasattr(m, 'activation_bit') and hasattr(m, 'full_precision_flag')):
            continue
        if share_one_bit:
            # The first standalone activation layer defines the width for all of them.
            return bit_config.get(name), per_module_bits
        bit = bit_config.get(name)
        per_module_bits[name] = bit
        if bit is None:
            for key in bit_config:
                if key in name:
                    per_module_bits[name] = bit_config.get(key)
    return None, per_module_bits


def letsquant_(model, log, args=None):
    """Swap DETR sub-modules for quantized versions and configure their bit-widths.

    Steps:

    1. Replace the modules enabled by the ``args.quant_*`` switches with their
       quantized wrappers.
    2. Load ``bit_config_<backbone>_<quant_scheme>``. The source is the ILP
       config read from ``args.ILP`` when that is set, otherwise the built-in
       ``bit_config_dict``.
    3. Give every module that has a config entry (or is forced, see
       ``_bit_config_key``) the quantizer settings from ``args``, plus its
       weight or activation bit-width.
    4. When ``args.ILP`` is set, additionally override ``QuantAct`` activation
       widths. If ``args.ILP`` contains ``'_Act'`` or ``'act'``, one shared width
       is applied. Otherwise widths are set per layer, and layers sharing their
       first four path components inherit the same width.

    Args:
        model: DETR model, or the segmentation wrapper exposing ``.detr`` when
            ``args.masks`` is true. Modified in place.
        log: object whose ``.logger`` receives progress messages.
        args: options namespace. It reads the ``quant_*`` switches, ``masks``,
            ``backbone``, ``quant_scheme``, ``ILP`` and the quantizer settings
            (``bias_bit``, ``channel_wise``, ``act_percentile``, ...).

    Returns:
        The same ``model`` object.

    Raises:
        NotImplementedError: if ``args`` is None.
    """
    if args is None:
        raise NotImplementedError

    _swap_in_quant_modules(model, args)

    # After setting up the quantized model, load the bit configuration.
    print("-" * 80)
    config_key = "bit_config_" + args.backbone + "_" + args.quant_scheme
    if args.ILP:
        bit_config = read_ILPbit_config_dict_json(args.ILP)[config_key]
    else:
        bit_config = bit_config_dict[config_key]
    print("Load configuration:", config_key)
    print("bit_config:", bit_config)
    log.logger.info("Load configuration: {}".format(config_key))
    log.logger.info("bit_config:{}".format(bit_config))
    print("-" * 80)

    all_same_act_bit_flag = False
    share_activation_bit = None
    module_act_bit_dict = {}
    if args.ILP:
        # Only inspected when an ILP config is in use, so a None/'' ILP option
        # can no longer raise TypeError in the substring tests.
        ilp_name = str(args.ILP)
        # ILP schemes marked '_Act' / 'act' use one activation bit-width everywhere.
        all_same_act_bit_flag = '_Act' in ilp_name or 'act' in ilp_name
        share_activation_bit, module_act_bit_dict = _collect_ilp_act_bits(
            model, bit_config, all_same_act_bit_flag)
        log.logger.info(
            f'module_act_bit_dict:{len(module_act_bit_dict.keys())}\n{module_act_bit_dict}')
        if all_same_act_bit_flag and share_activation_bit is None:
            log.logger.warning(
                'No shared activation bit-width found in the ILP bit_config; '
                'QuantAct activation bit-widths are left unchanged.')

    # Identical for every quantized module, so built once.
    quant_settings = {
        'quant_mode': 'symmetric',
        'bias_bit': args.bias_bit,
        'quantize_bias': args.bias_bit != 0,
        'per_channel': args.channel_wise,
        'act_percentile': args.act_percentile,
        'act_range_momentum': args.act_range_momentum,
        'weight_percentile': args.weight_percentile,
        'fix_flag': False,
        'fix_BN': args.fix_BN,
        'fix_BN_threshold': args.fix_BN_threshold,
        'training_BN_mode': args.fix_BN,
        'checkpoint_iter_threshold': args.checkpoint_iter,
        'fixed_point_quantization': args.fixed_point_quantization,
    }

    for full_name, m in model.named_modules():
        setattr(m, 'ownname', full_name)
        name, need_quant = _bit_config_key(full_name)

        if name in bit_config or need_quant:
            print("[QUANT]", full_name)
            log.logger.info("[QUANT]  {}".format(full_name))
            for attr, value in quant_settings.items():
                setattr(m, attr, value)

            entry = bit_config.get(name)
            if isinstance(entry, tuple):
                # (bitwidth, 'hook') entries additionally get a forward hook.
                # hook_fn_forward / hook_keys are presumably provided by the star imports.
                bitwidth = entry[0]
                if entry[1] == 'hook':
                    m.register_forward_hook(hook_fn_forward)
                    hook_keys.append(name)
            else:
                # Modules forced by need_quant without an entry default to 8 bits.
                bitwidth = bit_config.get(name, 8)

            if hasattr(m, 'activation_bit'):
                m.activation_bit = bitwidth
                if bitwidth == 4:
                    m.quant_mode = 'asymmetric'
            else:
                m.weight_bit = bitwidth
        else:
            print('[FP==>]', full_name)
            log.logger.info('[FP==>]: {}'.format(full_name))

        if not args.ILP:
            continue
        # Cheap attribute checks first. str(m) renders the whole sub-tree.
        if not (hasattr(m, 'activation_bit') and hasattr(m, 'full_precision_flag')):
            continue
        # Only the built-in QuantAct quantizer layers themselves.
        if 'QuantAct' not in str(m).split('(', 1)[0]:
            continue

        if all_same_act_bit_flag:
            if share_activation_bit is None:
                continue
            m.activation_bit = share_activation_bit
            if share_activation_bit == 0:
                # 0 bits means full precision.
                m.full_precision_flag = True
        else:
            # NOTE(review): ``name`` has any backbone prefix stripped while the
            # dict keys are full module paths. Confirm backbone layers match as intended.
            if name in module_act_bit_dict:
                m.activation_bit = module_act_bit_dict[name]
            else:
                # Give every activation layer in the same sub-module (first four
                # path components) the same width; the last match wins.
                scope = name.split(".")[0:4]
                for act_name, act_bit in module_act_bit_dict.items():
                    if act_name.split(".")[0:4] == scope:
                        m.activation_bit = act_bit
            if m.activation_bit == 0:
                # 0 bits means full precision.
                m.full_precision_flag = True

    print(model)
    log.logger.info(model)
    print("-" * 80)
    return model


def letsquant_teacher(model, log, args=None):
    """Quantize a plain (non-segmentation) DETR teacher model in place.

    Modules enabled by the ``args.quant_*`` switches are replaced with quantized
    wrappers. Bit-widths come from
    ``bit_config_<backbone>_<teacher_Qmodel_scheme>``: the ILP config read from
    ``args.teacher_Qmodel_ILP`` when that is set, otherwise the built-in
    ``bit_config_dict``.

    Args:
        model: DETR model exposing ``transformer``, ``backbone``, ``input_proj``,
            ``class_embed`` and ``bbox_embed``. Modified in place.
        log: object whose ``.logger`` receives progress messages.
        args: options namespace. It reads the ``quant_*`` switches, ``use_dn``,
            ``backbone``, ``teacher_Qmodel_ILP``, ``teacher_Qmodel_scheme`` and
            the quantizer settings (``bias_bit``, ``channel_wise``, ...).

    Returns:
        The same ``model`` object.

    Raises:
        NotImplementedError: if ``args`` is None.
    """
    if args is None:
        raise NotImplementedError

    quant_class_embed = args.quant_class_embed
    quant_bbox_embed = args.quant_bbox_embed
    quant_input_proj = args.quant_input_proj
    quant_backbone = args.quant_backbone
    quant_encoder = args.quant_encoder
    quant_decoder = args.quant_decoder
    transformer = model.transformer

    if quant_class_embed:
        model.class_embed = Q_class_embed(model.class_embed)
    if quant_bbox_embed:
        model.bbox_embed = Q_bbox_embed(model.bbox_embed)

    if quant_encoder:
        encoder_layers = transformer.encoder.layers
        # Iterate a snapshot because the container is written to inside the loop.
        for idx, layer in enumerate(list(encoder_layers)):
            setattr(encoder_layers, str(idx), Q_TransformerEncoderLayer(layer))

    if quant_decoder:
        decoder_layers = transformer.decoder.layers
        for idx, layer in enumerate(list(decoder_layers)):
            if args.use_dn:
                # The DN decoder wrapper also receives the layer index.
                q_layer = Q_DN_TransformerDecoderLayer(layer, idx)
            else:
                q_layer = Q_TransformerDecoderLayer(layer)
            setattr(decoder_layers, str(idx), q_layer)

    if quant_backbone:
        backbone_without_pos = model.backbone[0]
        q_resnet = Q_ResNet101_detr if args.backbone == 'resnet101' else Q_ResNet50_detr
        # The ResNet wrappers take ``args`` first (as letsquant_ calls them).
        # Passing only the backbone put it in the ``args`` slot.
        setattr(model.backbone, '0', q_resnet(args, backbone_without_pos))

    if quant_input_proj:
        model.input_proj = Q_input_proj(model.input_proj)

    # After setting up the quantized model, load the bit configuration.
    print("-" * 80)
    config_key = "bit_config_" + args.backbone + "_" + args.teacher_Qmodel_scheme
    if args.teacher_Qmodel_ILP:
        bit_config = read_ILPbit_config_dict_json(args.teacher_Qmodel_ILP)[config_key]
    else:
        bit_config = bit_config_dict[config_key]
    print("Load configuration:", config_key)
    print("bit_config:", bit_config)
    log.logger.info("Load configuration: {}".format(config_key))
    log.logger.info("bit_config:{}".format(bit_config))
    print("-" * 80)

    # Identical for every quantized module, so built once.
    quant_settings = {
        'quant_mode': 'symmetric',
        'bias_bit': args.bias_bit,
        'quantize_bias': args.bias_bit != 0,
        'per_channel': args.channel_wise,
        'act_percentile': args.act_percentile,
        'act_range_momentum': args.act_range_momentum,
        'weight_percentile': args.weight_percentile,
        'fix_flag': False,
        'fix_BN': args.fix_BN,
        'fix_BN_threshold': args.fix_BN_threshold,
        'training_BN_mode': args.fix_BN,
        'checkpoint_iter_threshold': args.checkpoint_iter,
        'fixed_point_quantization': args.fixed_point_quantization,
    }

    for real_name, m in model.named_modules():
        # Quantizer sub-modules inside the transformer are quantized even
        # without an explicit config entry.
        need_quant = "quant" in real_name and real_name.startswith(
            ("transformer.encoder", "transformer.decoder"))
        # Backbone entries are looked up relative to the ResNet body.
        if real_name.startswith("backbone"):
            name = real_name.replace("backbone.0.", "")
        else:
            name = real_name

        if name in bit_config or need_quant:
            print("[QUANT]", real_name)
            log.logger.info("[QUANT]  {}".format(real_name))
            for attr, value in quant_settings.items():
                setattr(m, attr, value)

            entry = bit_config.get(name)
            if isinstance(entry, tuple):
                # (bitwidth, 'hook') entries additionally get a forward hook.
                # hook_fn_forward / hook_keys are presumably provided by the star imports.
                bitwidth = entry[0]
                if entry[1] == 'hook':
                    m.register_forward_hook(hook_fn_forward)
                    hook_keys.append(name)
            else:
                # Modules forced by need_quant without an entry default to 8 bits.
                bitwidth = bit_config.get(name, 8)

            if hasattr(m, 'activation_bit'):
                m.activation_bit = bitwidth
                if bitwidth == 4:
                    m.quant_mode = 'asymmetric'
            else:
                m.weight_bit = bitwidth
        else:
            print('[FP==>]', real_name)
            log.logger.info('[FP==>]: {}'.format(real_name))

    print(model)
    log.logger.info(f'teacher_Quan_model:\n{model}')
    print("-" * 80)
    return model
