import torch
import torch.nn.intrinsic.qat as nniqat
from torch.fx import GraphModule, Node
from torch import fx, nn
from torch.nn import Module
from copy import deepcopy

# Flags recording which distributed backend (if any) is active. Both start
# False and at most one is flipped to True by the import probe below.
USE_LINK = False
USE_DDP = False

__all__ = ['ptq_reconstruction']

# Probe for a distributed backend at import time: spring.linklink is tried
# first (and initialized if it is not yet); if it is missing or raises an
# AssertionError, torch.distributed is used instead.
# Note that ``dist`` is only bound when the fallback path runs.
try:
    import spring.linklink as link
    if not link.is_initialized():
        link.initialize()
    USE_LINK = True
except (ModuleNotFoundError, AssertionError):
    import torch.distributed as dist
    # DDP only counts as active when the process group is already initialized.
    if torch.distributed.is_initialized():
        USE_DDP = True

import numpy as np
from typing import List

# from mqbench.utils.logger import logger
from mqbench.utils.hook import DataSaverHook, StopForwardException
from mqbench.utils import deepcopy_graphmodule, deepcopy_mixedmodule, topology_order, getitem2node
from mqbench.utils.utils import _fix_succ_recursivly
from mqbench.utils.state import enable_quantization, disable_all
import mqbench.nn.intrinsic.qat as qnniqat

# Module types eligible for reconstruction: checked by the loss functions
# (rounding regularizer) and by the layer/block extraction helpers below.
_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
# Fused conv+bn(+relu) QAT module types (not referenced in the portion of the
# file reviewed here).
_FUSED_TYPE = (nniqat.ConvBnReLU2d, nniqat.ConvBn2d, qnniqat.ConvFreezebn2d, qnniqat.ConvFreezebnReLU2d)
# Module types that layer_has_weights treats as weight-bearing.
_WEIGHTS_MODULE_TYPE = (torch.nn.Conv2d, torch.nn.Linear)

def node2modules(name2modules, nodes):
    """Map each graph node to the module registered under its ``target``.

    :param name2modules: mapping from module name to module object
    :param nodes: iterable of nodes exposing a ``target`` attribute
    :return: dict keyed by node; nodes whose ``target`` is not a key of
        ``name2modules`` are left out
    """
    return {
        graph_node: name2modules[graph_node.target]
        for graph_node in nodes
        if graph_node.target in name2modules
    }


def qnode2fpnode(quant_modules, fp32_modules):
    """Pair every quantized-graph node with the FP32 node of the same name.

    Nodes are matched by ``name`` rather than ``target``: matching by target
    would make entries overwrite each other, which breaks shared-head models.

    :param quant_modules: iterable of quantized-graph nodes (e.g. a dict keyed by node)
    :param fp32_modules: iterable of FP32-graph nodes (e.g. a dict keyed by node)
    :return: dict mapping quantized node -> FP32 node with the same name
    :raises KeyError: if a quantized node name has no FP32 counterpart
    """
    quant_by_name = {}
    for q_node in quant_modules:
        quant_by_name[q_node.name] = q_node
    fp32_by_name = {}
    for fp_node in fp32_modules:
        fp32_by_name[fp_node.name] = fp_node
    return {q_node: fp32_by_name[name] for name, q_node in quant_by_name.items()}

def layer_has_weights(nodes, modules):
    """Tell whether any node in ``nodes`` maps to a weight-bearing module.

    :param nodes: iterable of graph nodes
    :param modules: mapping from node to module
    :return: True as soon as one node is a key of ``modules`` and its module
        is an instance of ``_WEIGHTS_MODULE_TYPE``; False otherwise
    """
    return any(
        candidate in modules and isinstance(modules[candidate], _WEIGHTS_MODULE_TYPE)
        for candidate in nodes
    )


def lp_loss(pred, tgt, p=2.0):
    """
    Loss measured in the L_p norm: |pred - tgt|^p summed over dim 1, then
    averaged over every remaining element.
    """
    powered_error = torch.abs(pred - tgt).pow(p)
    per_sample = torch.sum(powered_error, dim=1)
    return per_sample.mean()


def to_device(data, device='cpu'):
    """Recursively move tensors inside ``data`` to ``device``.

    Dicts and lists are updated in place and the same container object is
    returned; a bare tensor yields the result of ``Tensor.to``. Anything else
    (including tuples) is returned untouched.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        for key in data:
            data[key] = to_device(data[key], device)
    elif isinstance(data, list):
        for position, item in enumerate(data):
            data[position] = to_device(item, device)
    return data


def tensor_detach(data):
    """Recursively detach tensors inside ``data`` from the autograd graph.

    :param data: a tensor, a dict, a list, or any other object
    :return: the detached tensor; the same dict with its values detached in
        place; a new list holding the detached items; or ``data`` unchanged
        for any other type (tuples included)
    """
    if isinstance(data, torch.Tensor):
        return data.detach()
    elif isinstance(data, dict):
        for key in data:
            data[key] = tensor_detach(data[key])
        return data
    elif isinstance(data, list):
        # The list branch used to build the result and then fall off the end
        # of the function, so callers received None for any list input.
        return [tensor_detach(item) for item in data]
    else:
        return data


def save_inp_oup_data(model: GraphModule, inp_module: Module, oup_module: Module, cali_data: dict, store_inp=True, store_oup=True,
                      keep_gpu: bool = True, max_count = 1):
    """
    Run one forward pass of ``model`` and capture the input of ``inp_module``
    and/or the output of ``oup_module`` through temporary forward hooks.

    :param model: model to run; called as
        ``model(return_loss=False, rescale=True, **cali_data)``
    :param inp_module: module whose input is recorded (required when ``store_inp``)
    :param oup_module: module whose output is recorded (required when ``store_oup``)
    :param cali_data: one calibration batch, unpacked as keyword arguments
    :param store_inp: record the input of ``inp_module``
    :param store_oup: record the output of ``oup_module``
    :param keep_gpu: keep captured data where it is; otherwise move it to CPU
    :param max_count: forwarded unchanged to ``DataSaverHook``
    :return: tuple ``(inps, oups)``; an entry is None when it was not requested
    """
    # Validate before registering anything so that a failed check cannot leave
    # a hook attached to the model.
    if store_inp:
        assert inp_module is not None
    if store_oup:
        assert oup_module is not None

    inp_saver = None
    oup_saver = None
    handles = []
    inps = None
    oups = None
    try:
        if store_inp:
            # The input hook only aborts the forward pass when no output is wanted.
            inp_saver = DataSaverHook(store_input=store_inp, store_output=False,
                                      stop_forward=(not store_oup), max_count=max_count)
            handles.append(inp_module.register_forward_hook(inp_saver))
        if store_oup:
            oup_saver = DataSaverHook(store_input=False, store_output=store_oup,
                                      stop_forward=True, max_count=max_count)
            handles.append(oup_module.register_forward_hook(oup_saver))

        with torch.no_grad():
            try:
                model(return_loss=False, rescale=True, **cali_data)
            except StopForwardException:
                # Expected: a hook cut the forward pass short once it had its data.
                pass
            if store_inp:
                inps = [tensor_detach(inp) for inp in inp_saver.input_store]
                if not keep_gpu:
                    inps = [to_device(inp, 'cpu') for inp in inps]
            if store_oup:
                oups = tensor_detach(oup_saver.output_store)
                if not keep_gpu:
                    oups = to_device(oups, 'cpu')
            # Reset the hook counters once the batch has been processed.
            if store_inp:
                inp_saver.reset_counter()
            if store_oup:
                oup_saver.reset_counter()
    finally:
        # Always detach the hooks, even when the forward pass fails with an
        # unexpected error; otherwise they would stay on the model for good.
        for handle in handles:
            handle.remove()
    torch.cuda.empty_cache()
    return (inps, oups)


class LinearTempDecay:
    """Linear schedule for the temperature ``b``.

    Returns ``start_b`` before ``warm_up * t_max`` steps, ``end_b`` after
    ``t_max`` steps, and a linear interpolation between the two in between.
    """

    def __init__(self, t_max=10000, warm_up=0.2, start_b=20, end_b=2):
        self.t_max, self.start_b, self.end_b = t_max, start_b, end_b
        # Step at which the decay begins.
        self.start_decay = warm_up * t_max

    def __call__(self, t):
        if t < self.start_decay:
            return self.start_b
        if t > self.t_max:
            return self.end_b
        progress = (t - self.start_decay) / (self.t_max - self.start_decay)
        remaining = max(0.0, (1 - progress))
        return self.end_b + (self.start_b - self.end_b) * remaining


class CosineTempDecay:
    """Cosine schedule for the temperature ``b``.

    Returns ``start_b`` before ``warm_up * t_max`` steps, ``end_b`` after
    ``t_max`` steps, and a half-cosine blend between the two in between.
    """

    def __init__(self, t_max=10000, warm_up=0.2, start_b=20, end_b=2):
        self.t_max, self.start_b, self.end_b = t_max, start_b, end_b
        # Step at which the decay begins.
        self.start_decay = warm_up * t_max

    def __call__(self, t):
        if t < self.start_decay:
            return self.start_b
        if t > self.t_max:
            return self.end_b
        progress = (t - self.start_decay) / (self.t_max - self.start_decay)
        half_span = 0.5 * (self.start_b - self.end_b)
        return self.end_b + half_span * (1 + np.cos(progress * np.pi))

class PD_LossFunction_draft:
    r'''Loss combining reconstruction, rounding-relaxation and
    prediction-difference (KL divergence) terms; a temperature decay on ``b``
    shapes the rounding term over the course of optimization.
    '''
    def __init__(self,
                 subgraph: Module,
                 round_loss: str = 'relaxation',
                 weight: float = 1.,
                 rec_loss: str = 'mse',
                 max_count: int = 10000,
                 b_range: tuple = (20, 2),
                 decay_start: float = 0.0,
                 warm_up: float = 0.0,
                 p: float = 2.,
                 lam: float = 1.0,
                 T: float = 7.0,
                 ):
        """
        :param subgraph: module whose sub-modules are scanned for the rounding loss
        :param round_loss: 'relaxation' or 'none'
        :param weight: multiplier applied to the rounding loss
        :param rec_loss: reconstruction loss type; only 'mse' is supported
        :param max_count: total number of steps, used by the schedules
        :param b_range: (start_b, end_b) for the temperature decay
        :param decay_start: extra fraction of the post-warm-up phase to wait
            before ``b`` starts decaying
        :param warm_up: fraction of ``max_count`` with no rounding loss
        :param p: exponent of the L_p reconstruction loss
        :param lam: divisor applied to the prediction-difference loss
        :param T: softmax temperature of the prediction-difference loss
        """
        self.subgraph = subgraph
        self.weight = weight
        self.loss_start = max_count * warm_up
        self.p = p

        self.round_loss = round_loss
        self.rec_loss = rec_loss
        self.lam = lam
        self.T = T

        self.temp_decay = LinearTempDecay(max_count, warm_up=warm_up + (1 - warm_up) * decay_start,
                                          start_b=b_range[0], end_b=b_range[1])
        self.count = 0
        self.pd_loss = torch.nn.KLDivLoss(reduction='batchmean')

    def __call__(self, pred, tgt, output, output_fp):
        """
        Compute the total loss for adaptive rounding:
        rec_loss is the quadratic output reconstruction loss, round_loss is
        a regularization term to optimize the rounding policy, pd_loss is the
        prediction difference loss.

        :param pred: output from quantized model
        :param tgt: output from FP model
        :param output: prediction from quantized model
        :param output_fp: prediction from FP model
        :return: total loss function
        :raises ValueError: if ``rec_loss`` is not 'mse'
        :raises NotImplementedError: if ``round_loss`` is not 'relaxation' or 'none'
        """
        from global_placeholder import logger
        # ``F`` was never imported in this file, so the functional API is
        # reached through the already-imported ``torch`` package instead.
        functional = torch.nn.functional

        self.count += 1
        if self.rec_loss == 'mse':
            rec_loss = lp_loss(pred, tgt, p=self.p)
        else:
            raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))

        # KL divergence between the temperature-softened predictions.
        pd_loss = self.pd_loss(functional.log_softmax(output / self.T, dim=1),
                               functional.softmax(output_fp / self.T, dim=1)) / self.lam

        b = self.temp_decay(self.count)
        if self.count < self.loss_start or self.round_loss == 'none':
            # Rounding term disabled; b is zeroed as well, which only affects the log line.
            b = round_loss = 0
        elif self.round_loss == 'relaxation':
            round_loss = 0
            for layer in self.subgraph.modules():
                if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
                    round_vals = layer.weight_fake_quant.rectified_sigmoid()
                    round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
        else:
            raise NotImplementedError

        total_loss = rec_loss + round_loss + pd_loss
        if self.count % 2500 == 0:
            # One placeholder per argument: the previous format string lacked a
            # pd field, so every later value was printed under the wrong label.
            logger.info('Total loss:\t{:.3f} (rec:{:.3f}, pd:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
                float(total_loss), float(rec_loss), float(pd_loss), float(round_loss), b, self.count))
        return total_loss


class LossFunction:
    r'''Reconstruction (L_p) loss plus a rounding-relaxation regularizer whose
    sharpness ``b`` follows a linear temperature decay.
    '''
    def __init__(self,
                 subgraph: Module,
                 weight: float = 1.,
                 max_count: int = 10000,
                 b_range: tuple = (20, 2),
                 warm_up: float = 0.0,
                 p: float = 2.):
        """
        :param subgraph: module whose sub-modules are scanned for the rounding loss
        :param weight: multiplier applied to the rounding loss
        :param max_count: total number of steps, used by the schedules
        :param b_range: (start_b, end_b) for the temperature decay
        :param warm_up: fraction of ``max_count`` with no rounding loss
        :param p: exponent of the L_p reconstruction loss
        """
        self.count = 0
        self.p = p
        self.weight = weight
        self.subgraph = subgraph
        # Step from which the rounding regularizer is switched on.
        self.loss_start = max_count * warm_up
        self.temp_decay = LinearTempDecay(max_count, warm_up=warm_up,
                                          start_b=b_range[0], end_b=b_range[1])

    def __call__(self, pred, tgt):
        """
        Total loss for adaptive rounding: the reconstruction loss between
        ``pred`` and ``tgt`` plus, once the warm-up is over, the rounding
        regularizer summed over every supported layer of the subgraph.

        :param pred: output from quantized model
        :param tgt: output from FP model
        :return: total loss function
        """
        from global_placeholder import logger
        self.count += 1
        step = self.count
        rec_loss = lp_loss(pred, tgt, p=self.p)
        b = self.temp_decay(step)

        round_loss = 0
        if step >= self.loss_start:
            for layer in self.subgraph.modules():
                if not isinstance(layer, _ADAROUND_SUPPORT_TYPE):
                    continue
                soft_round = layer.weight_fake_quant.rectified_sigmoid()
                round_loss += self.weight * (1 - ((soft_round - .5).abs() * 2).pow(b)).sum()

        total_loss = rec_loss + round_loss
        if step % 2500 == 0:
            message = 'Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
                float(total_loss), float(rec_loss), float(round_loss), b, step)
            logger.info(message)
        return total_loss

# class PD_LossFunction:
#     def __init__(self,
#                  block: BaseQuantBlock,  # ----
#                  round_loss: str = 'relaxation',  # ----？
#                  weight: float = 1.,
#                  rec_loss: str = 'mse',  # ----？
#                  max_count: int = 2000,
#                  b_range: tuple = (10, 2),
#                  decay_start: float = 0.0,  # ----？
#                  warmup: float = 0.0,
#                  p: float = 2.,
#                  lam: float = 1.0,  # ----？
#                  T: float = 7.0):  # ----？

#         self.block = block
#         self.round_loss = round_loss
#         self.weight = weight
#         self.rec_loss = rec_loss
#         self.loss_start = max_count * warmup
#         self.p = p
#         self.lam = lam
#         self.T = T

#         self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start,
#                                           start_b=b_range[0], end_b=b_range[1])
#         self.count = 0
#         self.pd_loss = torch.nn.KLDivLoss(reduction='batchmean')  # ----

#     def __call__(self, pred, tgt, output, output_fp):
#         """
#         Compute the total loss for adaptive rounding:
#         rec_loss is the quadratic output reconstruction loss, round_loss is
#         a regularization term to optimize the rounding policy, pd_loss is the 
#         prediction difference loss.

#         :param pred: output from quantized model
#         :param tgt: output from FP model
#         :param output: prediction from quantized model
#         :param output_fp: prediction from FP model
#         :return: total loss function
#         """
#         self.count += 1
#         if self.rec_loss == 'mse':
#             rec_loss = lp_loss(pred, tgt, p=self.p)
#         else:
#             raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))

#         pd_loss = self.pd_loss(F.log_softmax(output / self.T, dim=1), F.softmax(output_fp / self.T, dim=1)) / self.lam

#         b = self.temp_decay(self.count)
#         if self.count < self.loss_start or self.round_loss == 'none':
#             b = round_loss = 0
#         elif self.round_loss == 'relaxation':
#             round_loss = 0
#             for name, module in self.block.named_modules():
#                 if isinstance(module, QuantModule):
#                     round_vals = module.weight_quantizer.get_soft_targets()
#                     round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
#         else:
#             raise NotImplementedError

#         total_loss = rec_loss + round_loss + pd_loss
#         if self.count % 500 == 0:
#             print('Total loss:\t{:.3f} (rec:{:.3f}, pd:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
#                 float(total_loss), float(rec_loss), float(pd_loss), float(round_loss), b, self.count))
#         return total_loss


def _flatten_args(node):
    flattned_args = []
    if isinstance(node, dict):
        for v in node.values():
            flattned_args.extend(_flatten_args(v))
    elif isinstance(node, tuple) or isinstance(node, list):
        for n in node:
            flattned_args.extend(_flatten_args(n))
    else:
        flattned_args.extend([node])
    return flattned_args


def find_used_times(nodes, target):
    """Count how many users of ``target`` are members of ``nodes``."""
    return sum(1 for consumer in target.users if consumer in nodes)




def find_cur_node(layer_node_list):
    """Trim ``layer_node_list`` to the nodes leading up to one chosen end node.

    The end node is the first node in the list that is neither a
    ``call_function`` nor a ``call_method`` and is not consumed by any later
    node of the list. Branches that feed only other "not used later" nodes are
    dropped, and the remaining list is returned up to and including the end node.

    :param layer_node_list: ordered list of fx nodes
    :return: the trimmed node list, or None (implicitly) when the end node is
        not found in the filtered list
    """
    node_list = []
    # Nodes that appear in the args of at least one later node of the list.
    used_later = []
    for idx, node in enumerate(layer_node_list):
        for _node in layer_node_list[idx + 1:]:
            if node in _flatten_args(_node.args):
                used_later.append(node)
                break
    not_used_later = [node for node in layer_node_list if node not in used_later]
    # For every "not used later" node, walk backwards through args and collect
    # the fx nodes that have exactly one user inside layer_node_list.
    single_branch = dict()
    for node in not_used_later:
        single_branch[node] = set([node])
        q = [node]
        while True:
            now_args = sum([_flatten_args(_node.args) for _node in q], [])
            p = [_node for _node in now_args if isinstance(_node, torch.fx.Node) and find_used_times(layer_node_list, _node) == 1]
            single_branch[node] = single_branch[node].union(set(p))
            if len(p) == 0:
                break
            else:
                q = p
    # Pick the end node. The loop variable ``node`` is deliberately read after
    # the loop: it is the node that triggered ``break``, or the last node of
    # the list when no node qualified.
    for node in layer_node_list:
        if node.op == 'call_function' or node.op == 'call_method':
            continue
        if node not in used_later:
            break
    # Everything collected for the other "not used later" nodes is discarded.
    unwanted = set()
    for key in single_branch:
        if key is node:
            continue 
        else:
            unwanted = unwanted.union(single_branch[key])
    layer_node_list = [_node for _node in layer_node_list if _node not in unwanted]
    # Return the prefix ending at the chosen node. If that node was filtered
    # out above, the loop finishes without returning and the result is None.
    for _node in layer_node_list:
        node_list.append(_node)
        if _node is node:
            return node_list


def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config, a_opt, w_opt, loss_func, w_para, w_scheduler, a_scheduler):
    """Run one optimization step of ``subgraph`` against the cached target.

    When ``config.prob < 1.0`` ``cached_inps`` is indexed as a pair of input
    lists and each input is an element-wise random mix of the two (an element
    of the first list is kept with probability ``config.prob``); otherwise
    ``cached_inps`` is the list of inputs itself. The loss is divided by the
    world size and, in distributed runs, the gradients of ``w_para`` are
    all-reduced before the optimizers and schedulers step.

    :param subgraph: callable sub-module being tuned
    :param cached_inps: cached subgraph inputs (layout described above)
    :param cached_oups: cached target output passed to ``loss_func``
    :param config: object exposing ``prob``
    :param a_opt: optional second optimizer (skipped when falsy)
    :param w_opt: optimizer that is always stepped
    :param loss_func: callable ``loss_func(out_quant, target)``
    :param w_para: parameters whose gradients are all-reduced
    :param w_scheduler: optional scheduler stepped after ``w_opt``
    :param a_scheduler: optional scheduler stepped after ``a_opt``
    """
    if USE_DDP or USE_LINK:
        world_size = link.get_world_size() if USE_LINK else dist.get_world_size()
    else:
        world_size = 1

    def _mixed_input(position):
        # Element-wise choice between the two cached variants of one input.
        primary = cached_inps[0][position]
        alternate = cached_inps[1][position]
        keep_mask = torch.rand_like(primary) < config.prob
        return torch.where(keep_mask, primary, alternate)

    if config.prob < 1.0:
        # Number of input ports of the subgraph.
        cur_args = tuple(_mixed_input(k) for k in range(len(cached_inps[0])))
    else:
        cur_args = tuple(cached_inps[k] for k in range(len(cached_inps)))
    cur_out = cached_oups

    if a_opt:
        a_opt.zero_grad()
    w_opt.zero_grad()

    out_quant = subgraph(*cur_args)
    err = loss_func(out_quant, cur_out)
    err /= world_size
    err.backward()

    if world_size > 1:
        for param in w_para:
            if USE_LINK:
                link.allreduce(param.grad.data)
            elif USE_DDP:
                dist.all_reduce(param.grad.data)

    w_opt.step()
    if a_opt:
        a_opt.step()
    for scheduler in (w_scheduler, a_scheduler):
        if scheduler:
            scheduler.step()

    # Drop the references held by this step before releasing cached memory.
    del err, out_quant, cur_args
    torch.cuda.empty_cache()
    
    

def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], output: fx.Node, g2node: dict, bare_struct_inp_flag: bool):
    """
    Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.

    :param orig_module: module passed as the root of the returned GraphModule
    :param nodes: nodes of the subgraph
    :param output: node whose copy becomes the output of the new graph
    :param g2node: mapping from a node to its substitute; mapped nodes are
        replaced by their substitute when named or copied
    :param bare_struct_inp_flag: when true, every fx.Node argument of
        ``nodes[0]`` is prepended to ``nodes`` before the graph is built
    :return: fx.GraphModule built from the new graph
    """
    new_graph = fx.Graph()
    env = dict()
    
    if bare_struct_inp_flag:
        node = nodes[0]
        for arg in _flatten_args(node.args):
            if isinstance(arg, torch.fx.Node):
                nodes  = [arg] + nodes
                
                
    inp_lst = []
    # Create the new graph inputs. A node that consumes at least one fx.Node
    # from outside ``nodes`` is itself turned into a placeholder, named after
    # that node (or its g2node substitute) rather than after the outside arg.
    # NOTE(review): ``inp_lst`` is tested against ``arg`` but filled with
    # ``node`` — confirm this is intended.
    for node in nodes:
        for arg in _flatten_args(node.args):
            if isinstance(arg, torch.fx.Node):
                if arg not in nodes and arg not in inp_lst:
                    inp_lst.append(node)
                    if node in g2node:
                        arg_name = g2node[node].name
                    else:
                        arg_name = node.name
                    new_node = new_graph.placeholder(arg_name)
                    env[node] = new_node
                    break
    # Copy every remaining node into the new graph, remapping its arguments
    # through ``env``.
    for node in nodes:
        if node in inp_lst:
            continue
        if node in g2node:
            node = g2node[node]
        new_node = new_graph.node_copy(node, lambda x: env[x])
        env[node] = new_node
    # create this or there will not be return value
    new_graph.output(env[output])
    new_graph.lint()
    return fx.GraphModule(orig_module, new_graph)

def find_num_nodes(nodes):
    """Count the entries of ``nodes`` that are ``torch.fx.Node`` instances."""
    return sum(1 for entry in nodes if isinstance(entry, Node))


# Recommend: log this to check if the layer is right. You can define your own layer manually or automatically like this
# extract the linked-list/single-chain
def extract_layer(node, fp32_modules):
    """Collect the single chain of nodes that starts at ``node``.

    Simply gathers the single-branch layers that follow this conv and stops
    before the next conv; "block" here refers to a branching point.

    :param node: node the chain starts from
    :param fp32_modules: mapping from node to module, used to recognize
        supported (conv/linear) users
    :return: ``(layer_node_list, is_next_block)`` where ``is_next_block`` is
        True when the chain ended at a graph output or the last node has more
        than one fx.Node user
    """
    from global_placeholder import logger
    layer_node_list = []
    cur_node = node
    is_next_block = False  # check whether stoped by a block
    while True:
        logger.debug('cur_node in layer is {}'.format(cur_node))
        layer_node_list.append(cur_node)  # valid node here
        # A node without users ends the chain.
        stop = (len(cur_node.users) == 0)
        for user in cur_node.users:
            if user.target == 'update':
                continue
            # Stop in front of the next supported (conv/linear) module.
            if user.op == 'call_module' and isinstance(
                    fp32_modules[user], _ADAROUND_SUPPORT_TYPE):
                stop = True
            # TODO: only short-cut here, consider more here
            # TODO: can also use un/completed to check here.
            if ('add' in user.name
                    and user.op in ['call_function', 'call_method']):
                stop = True
            if user.op == 'output':
                is_next_block, stop = True, True
        if stop:
            break
        # Follow the first user only.
        cur_node = list(cur_node.users.keys())[0]
    if find_num_nodes(cur_node.users) > 1:
        is_next_block = True
    return layer_node_list, is_next_block


# Recommend: log this to check if the block is right. You can define your own block manually or automatically like this
# extract the block one such as short-cut
def extract_block(input_nodes, fp32_modules, depth=0):
    """Collect the nodes of a block (e.g. a short-cut structure) fed by ``input_nodes``.

    ``is_block`` becomes true in two cases: a user of the inputs already takes
    more than one fx.Node argument, or such a node is met while walking
    downwards. When neither this walk nor the trailing ``extract_layer`` call
    hits a block, the function recurses from the last collected node to stack
    further layers.

    :param input_nodes: nodes whose users seed the traversal
    :param fp32_modules: mapping from node to module, forwarded to ``extract_layer``
    :param depth: recursion depth; an empty list is returned once it exceeds 2
    :return: list of nodes forming the block
    """
    from global_placeholder import logger
    if depth > 2:
        # stack 2 or 3 layers for no short-cut structure
        return []
    layer_node_list = []
    is_block = False
    # Per-user count of fx.Node args that have not been reached yet.
    cnt = dict()
    q, p = [], []  # q records the completed node, p records the uncompleted nodes
    cur_node = None
    for input in input_nodes:  # TODO: outputs are produced here
        for user in input.users:
            if len(input.users) > 1 and user.target == 'output':
                # NOTE: skip intermediate outputs, i.e. avoid "outs"
                continue
            if user not in cnt:
                cnt[user] = find_num_nodes(user.args)
                if cnt[user] > 1:
                    is_block = True  # multi-input user: treated as a block
                p.append(user)
            cnt[user] -= 1
            if cnt[user] == 0:
                q.append(user)
                p.remove(user)
    # TODO: the problem lies here and is caused by outputs, which must be
    # avoided; confirm whether other blocks also carry a downsample in q.
    while len(q) != 0:
        cur_node = q.pop(0)  # valid node here
        logger.debug('cur node is {}'.format(cur_node))
        if cur_node.target == 'update':
            continue
        if len(p) == 0 and len(q) == 0:
            break  # TODO: the only exit point; problematic. Also: what about the quantizer?
        layer_node_list.append(cur_node)  # only layers are stored here; nodes such as add are not, and do not take part in the decision
        for user in cur_node.users:
            if user not in cnt:
                cnt[user] = find_num_nodes(user.args)
                if cnt[user] > 1:
                    is_block = True
                p.append(user)  # add nodes can show up in p / q
            cnt[user] -= 1
            if cnt[user] == 0:
                q.append(user)
                p.remove(user)
        logger.debug('uncompleted nodes are {}'.format(p))
    if not cur_node:
        return layer_node_list
    exp_nodes, is_next_block = extract_layer(cur_node, fp32_modules)  # effectively finishes walking the current block
    if is_block or is_next_block:
        return layer_node_list + exp_nodes
    else:
        return layer_node_list + exp_nodes + extract_block(
            [exp_nodes[-1]], fp32_modules, depth + 1)

def ptq_reconstruction_for_subgraph(fp32_model, quant_model, cali_datas: list, config: dict, graph_module_list: list = None, pattern: str = 'layer'):
    
    # assert model is on cuda
    # if not config.keep_gpu:
    #     cali_data = [to_device(inp, 'cpu') for inp in cali_data]
    '''set state first'''
    from global_placeholder import logger
    device = next(quant_model.parameters()).device  # 换个更恰当的

    # fp32_model = model
    config.pattern = pattern
    fp32_model.eval()  # NOTE fp32_model不需要quant相关的？？？
    if graph_module_list is None:
        raise NotImplementedError
        assert isinstance(fp32_model, torch.fx.GraphModule)
        quant_model = deepcopy_graphmodule(model)
        nodes = list(quant_model.graph.nodes)
        g2node = getitem2node(quant_model)
        fp32_modules = node2modules(dict(fp32_model.named_modules()), fp32_model.graph.nodes)
        quant_modules = node2modules(dict(quant_model.named_modules()), quant_model.graph.nodes)
        topology_order_by_node = topology_order(quant_model)
    else:
        # quant_model = deepcopy_mixedmodule(model, graph_module_list)
        nodes = []
        g2node = dict()
        fp32_modules = dict()
        quant_modules = dict()
        module_name2c = dict()
        topology_order_by_node = {}
        topo_cnt = 0
        for mname in graph_module_list:
            child = getattr(quant_model, mname)
            assert isinstance(child, torch.fx.GraphModule)
            nodes += list(child.graph.nodes)
            g2node.update(getitem2node(child))
        for mname in graph_module_list:
            fp_child = getattr(fp32_model, mname)
            q_child = getattr(quant_model, mname)
            # note: the nodes we use is from the quant model, so build q_node2fp_module, rather than fp2fp.
            fp_modules = node2modules(dict(fp_child.named_modules()), q_child.graph.nodes)
            q_modules = node2modules(dict(q_child.named_modules()), q_child.graph.nodes)
            fp32_modules.update(fp_modules)
            quant_modules.update(q_modules)
            child_topo = topology_order(q_child)
            module_name2c.update(dict.fromkeys(dict(fp_child.named_modules()).keys(), 0))
            for k in child_topo:
                child_topo[k] += topo_cnt
            topology_order_by_node.update(child_topo)
            topo_cnt += len(topology_order_by_node)
    qnode2fpnode_dict = qnode2fpnode(quant_modules, fp32_modules)
    quant_model.eval()
    disable_all(fp32_model)
    enable_quantization(quant_model)
    torch.cuda.empty_cache()
    checked_nodes = dict()
    
    for node in nodes:
        if 'exclude_node_prefix' in config:  # TODO 关键是node.name == top_down_blocks_0_blocks_0_conv2_conv
            cont = False
            for prefix in config['exclude_node']:
                if node.name.startswith(prefix):
                    cont = True
                    break
            if cont:
                logger.info(f'Exclude node {node}')
                continue
        if node in checked_nodes:  # 上次被重建的层们都会属于checked_nodes
            continue
        if node.op == "call_module" and isinstance(quant_modules[node], _ADAROUND_SUPPORT_TYPE):
            logger.info('prepare {} reconstruction for {}'.format(config.pattern, node))
            if config.pattern == 'layer':
                layer_node_list, _ = extract_layer(node, quant_modules)
            elif config.pattern == 'block':
                layer_node_list = extract_block(node.all_input_nodes, quant_modules)
            else:
                raise ValueError('pattern错误')
            # if the update is not used in the block, remove it
            if not all([n.target != 'update' for n in layer_node_list]):
                remove_nodes = []
                for idx, n in enumerate(layer_node_list):
                    if n.target == 'update':
                        src = n.args[0]
                        remove = True
                        for _idx in range(idx + 1, len(layer_node_list)):
                            if src in _flatten_args(
                                    layer_node_list[_idx].args):
                                remove = False
                                break
                        if remove:
                            remove_nodes.append(n)
                layer_node_list = [n for n in layer_node_list if n not in remove_nodes]
            missing_inputs = []
            for _node in layer_node_list:
                for arg in _flatten_args(_node.args):
                    if isinstance(arg, torch.fx.Node) and arg.op == 'call_module':
                        if arg not in layer_node_list and arg not in missing_inputs:
                            # if module_name2c[arg.target] == 1:  # 因为这里会
                            #     module_name2c[arg.target] = module_name2c[arg.target] - 1
                            missing_inputs.append(arg)
            layer_node_list.extend(missing_inputs)  # 这里引进的getitem。大多情况应该是把上个node（quantizer搞进来？）
            # replace getitem nodes into its source node
            layer_node_list = [n if n not in g2node else g2node[n] for n in layer_node_list]
            for _node in layer_node_list:
                src = [arg for arg in _flatten_args(_node.args) if arg in g2node]
                for arg in src:
                    _node.args = _fix_succ_recursivly(_node.args, arg, g2node[arg])
            layer_node_list = sorted(layer_node_list, key=lambda x: topology_order_by_node[x])  # 这里才重新整理顺序
            layer_node_list = find_cur_node(layer_node_list)  # TODO 这是干啥用的
            if layer_has_weights(layer_node_list, quant_modules):
                pass
            else:
                continue
            logger.info('the node list is below!')
            logger.info(layer_node_list)
            
            
            only_one_support_layer_flag = None
            for _node in layer_node_list:
                # 实时统计更新该layer的使用次数
                if _node.target in module_name2c:
                    if ('backbone' in graph_module_list or 'neck' in graph_module_list) and module_name2c[_node.target] == 1:
                        # 这是为了制裁YOLOX的PAFPN结构, 所有的layer就只能标记1次
                        continue
                    module_name2c[_node.target] = module_name2c[_node.target] + 1
                    # 判断是否有且只有一个支持layer
                    if isinstance(quant_modules[_node], _ADAROUND_SUPPORT_TYPE):
                        if only_one_support_layer_flag == None:
                            only_one_support_layer_flag = True
                        else:
                            only_one_support_layer_flag = False
                    
            # 如果第一个layer就是conv，有且只有一个，屎定了,就得谨慎考虑截取的特征图信息。当然不是主动增加输入layer
            bare_struct_inp_flag = (isinstance(quant_modules[layer_node_list[0]], _ADAROUND_SUPPORT_TYPE) and only_one_support_layer_flag)

            
            # NOTE 这里要一直找到最后一个call_module-------
            while layer_node_list[-1].op != 'call_module':
                layer_node_list = layer_node_list[:-1]
            fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
            
            if any([USE_DDP, USE_LINK]):
                world_size = link.get_world_size() if USE_LINK else dist.get_world_size()
            else:
                world_size = 1
            logger.info('The world size is {}.'.format(world_size))
            '''start training'''
            logger.info('start tuning by adaround')
            
            
            sz = len(cali_datas)  # sz=cali_num
            
            for i in range(config.max_count):
                
                fp32_all_inps = []
                quant_all_inps = []
                fp32_final_oups = None
                out_is_cached = False
                
                idx = np.random.randint(0, sz)
                cali_data = cali_datas[idx]
                
                # TODO 其实都不需要遍历num_args，也就是说上面的可以简化关键是得到cur_args, cur_out
                
                # 这一步就是求该子结构的整体输入和输出。
                for _node in layer_node_list:  
                    if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
                        continue
                    else:
                        # 当成为子结构的整体输入时，应该撤档
                        if _node == layer_node_list[0] and module_name2c[_node.target] == 2: 
                            module_name2c[_node.target] = module_name2c[_node.target] - 1
                            
                        fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
                        quant_module = quant_modules[_node]
                        # fp32 inps: [out_b1, out_b2, ...]  得到的都是每一层的输出
                        if bare_struct_inp_flag:
                            fp32_inps, _ = save_inp_oup_data(fp32_model, fp32_inp_module, None, cali_data, 
                                                                store_inp=(config.prob < 1.0), store_oup=False, keep_gpu=config.keep_gpu, max_count=module_name2c[_node.target]) # 子结构的每一层输出。
                            _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
                                                                store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu, max_count=module_name2c[layer_node_list[-1].target])  # 子结构的最后一层，其实就是统计一次而已。
                            quant_inps, _ = save_inp_oup_data(quant_model, quant_module, None, cali_data,
                                                                store_inp=True, store_oup=False, keep_gpu=config.keep_gpu, max_count=module_name2c[_node.target]) # 子结构的每一层输出。
                            # for it, temp in enumerate(fp32_inps):
                            #     fp32_inps[it] = temp[0]
                            # for it, temp in enumerate(quant_inps):
                            #     quant_inps[it] = temp[0]
                            if fp32_inps is None:
                                # 这一步是为了避免非Qdrop的算法问题，因为config.prob = 1则不会产生inps
                                fp32_inps = [fp32_inps]
                            fp32_all_inps.extend(fp32_inps)
                            quant_all_inps.extend(quant_inps)
                        else:
                            _, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data, 
                                                            store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu, max_count=module_name2c[_node.target]) # 子结构的每一层输出。
                            _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
                                                            store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu, max_count=module_name2c[layer_node_list[-1].target])  # 子结构的最后一层，其实就是统计一次而已。
                            _, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
                                                            store_inp=False, store_oup=True, keep_gpu=config.keep_gpu, max_count=module_name2c[_node.target]) # 子结构的每一层输出。
                        
                            fp32_all_inps.append(fp32_inps)
                            quant_all_inps.append(quant_inps)
                            
                        # torch.cuda.empty_cache()  # 这个会动态清除cache，但是拖慢进度
                        if not out_is_cached:
                            fp32_final_oups = fp32_oups
                            out_is_cached = True
                
                
                cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
                cached_oups = fp32_final_oups
                quant_modules_by_name = dict()
                for node in layer_node_list:
                    if node.op == 'call_module':
                        quant_modules_by_name[node.target] = quant_modules[node]
                if i == 0:
                    # NOTE 这里的subgraph是不是可以一次就好?
                    subgraph = extract_subgraph(quant_modules_by_name, layer_node_list,  # NOTE 有意思的是，subgraph可以重写node的输入输出！
                                                layer_node_list[-1], g2node, bare_struct_inp_flag)
                    logger.info(subgraph.code)
                    
                    w_para, a_para = [], []
                    w_opt, w_scheduler = None, None
                    if hasattr(config, 'scale_lr'):
                        a_para = []
                    for name, layer in subgraph.named_modules():
                        if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
                            weight_quantizer = layer.weight_fake_quant
                            # assert isinstance(weight_quantizer, adaround_quantizer) is True
                            weight_quantizer.init(layer.weight.data, config.round_mode)
                            w_para += [weight_quantizer.alpha]
                        if isinstance(layer, torch.quantization.FakeQuantizeBase) and 'post_act_fake_quantize' in name:
                            if hasattr(config, 'scale_lr'):
                                logger.info('learn the scale for {}'.format(name))
                                a_para += [layer.scale]
                            layer.prob = config.prob
                    if len(a_para) != 0:
                        a_opt = torch.optim.Adam(a_para, lr=config.scale_lr)
                        a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=config.max_count, eta_min=0.)
                    else:
                        a_opt, a_scheduler = None, None
                    w_opt = torch.optim.Adam(w_para)

                    loss_func = LossFunction(subgraph=subgraph, weight=config.weight, max_count=config.max_count, b_range=config.b_range,
                                            warm_up=config.warm_up, p=config.p)
                
                subgraph_reconstruction(subgraph, cached_inps, cached_oups, config, a_opt, w_opt, loss_func, w_para, w_scheduler, a_scheduler)
            
            for name, layer in subgraph.named_modules():        
                if isinstance(layer, _FUSED_TYPE):
                    # We need to do bn fold simulation here.
                    weight_quantizer = layer.weight_fake_quant
                    scale_factor = layer.bn.weight / torch.sqrt(layer.bn.running_var + layer.bn.eps)
                    merged_rounded_weight = weight_quantizer.get_hard_value(
                        layer.weight.data * scale_factor.reshape([-1] + [1] * (len(layer.weight.shape) - 1)))
                    layer.weight.data = merged_rounded_weight / scale_factor.reshape([-1] + [1] * (len(merged_rounded_weight.shape) - 1))
                    weight_quantizer.adaround = False
                elif isinstance(layer, _ADAROUND_SUPPORT_TYPE):
                    assert not hasattr(layer, 'bn'), 'Layer {} with type {} has BN ! Should not reach here.'.format(name, type(layer))
                    weight_quantizer = layer.weight_fake_quant
                    layer.weight.data = weight_quantizer.get_hard_value(layer.weight.data)
                    weight_quantizer.adaround = False
                if isinstance(layer, torch.quantization.FakeQuantizeBase) and 'post_act_fake_quantize' in name:
                    layer.prob = 1.0   # recover to promise that drop activation quantization only occurs at reconstruction phase

            for x in layer_node_list:
                checked_nodes[x] = True
                
    return checked_nodes, quant_modules
    
    

def ptq_reconstruction(fp32_model, quant_model, cali_data: list, config: dict, graph_module_list: list = None):
    r"""
    Reconstruction for AdaRound, BRECQ, QDrop.

    Runs ``ptq_reconstruction_for_subgraph`` once per entry of
    ``graph_module_list``, merges the per-entry results, then enables
    quantization on ``quant_model`` and on every reconstructed
    ``call_module`` node.

    Basic optimization objective:

    .. math::

        \mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}),

        \tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)

    where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing on :math:`\beta`, the rounding mask can adapt freely in initial phase and converge to 0 or 1 in later phase.

    Args:
        fp32_model: the full-precision model, forwarded to
            ``ptq_reconstruction_for_subgraph`` as the reference model.
        quant_model: the prepared model to do PTQ on; quantization is enabled
            on it before it is returned.
        cali_data (List): a list of calibration data, forwarded unchanged.
        config: a config for PTQ reconstruction. It is read through attribute
            access (``config.pattern``), so a plain ``dict`` is not sufficient;
            a deep copy is handed to every per-submodule call.
        graph_module_list (list): the sub-modules to reconstruct, processed in
            order. The first entry is reconstructed with ``config.pattern``;
            every following entry is reconstructed with ``'layer'``.
            Despite the ``None`` default, a list must be supplied.

    Returns:
        ``quant_model`` (the same object that was passed in).

    Raises:
        TypeError: if ``graph_module_list`` is ``None``.

    >>> sample config : {
            pattern: block (str, Available options are [layer, block].)
            scale_lr: 4.0e-5 (learning rate for learning step size of activation)
            warm_up: 0.2 (0.2 * max_count iters without regularization to floor or ceil)
            weight: 0.01 (loss weight for regularization item)
            max_count: 20000 (optimization iteration)
            b_range: [20,2] (beta decaying range )
            keep_gpu: True (keep calibration data on gpu or cpu)
            round_mode: learned_hard_sigmoid (ways to reconstruct the weight, currently only support learned_hard_sigmoid)
            prob: 0.5 (dropping probability of QDROP)
        }

    """
    from global_placeholder import logger

    if graph_module_list is None:
        # Previously this failed inside zip() with an opaque "not iterable"
        # TypeError; keep the exception type but say what is actually wrong.
        raise TypeError('graph_module_list must be a list of sub-modules to reconstruct, got None')

    checked_nodes = dict()
    quant_modules = dict()

    # Only the first sub-module follows the configured pattern; all others are
    # reconstructed layer-wise. Computing the pattern per index (instead of
    # zipping against a fixed three-element list) keeps sub-modules beyond the
    # third from being silently skipped.
    leading_pattern = config.pattern
    for index, submodule in enumerate(graph_module_list):
        recons_type = leading_pattern if index == 0 else 'layer'
        # Each call gets its own config copy so one sub-module's run cannot
        # leak config mutations into the next.
        sub_checked_nodes, sub_quant_modules = ptq_reconstruction_for_subgraph(
            fp32_model, quant_model, cali_data, deepcopy(config), [submodule], recons_type)
        checked_nodes.update(sub_checked_nodes)
        quant_modules.update(sub_quant_modules)

    # TODO: confirm whether quantization should be enabled or disabled here.
    enable_quantization(quant_model)
    for node in checked_nodes:
        if node.op == 'call_module':
            enable_quantization(quant_modules[node])
            logger.info(f'set the node {node.target} in quant')
    return quant_model
