import pdb


# -------------------- 第一步：定义接收feature的函数 ---------------------- #
# 这里定义了一个类，类有一个接收feature的函数hook_fun。定义类是为了方便提取多个中间层。
class HookTool:
    """Holder for the feature seen by a forward hook.

    One instance is meant to be attached to one layer, so that several
    intermediate layers can be tapped by creating several instances.
    """

    def __init__(self):
        # Output passed to the most recent ``hook_fun`` call; None until then.
        self.fea = None

    def hook_fun(self, module, fea_in, fea_out):
        """Record the input/output handed to the hook and return the output.

        The callback must accept exactly these three positional arguments
        (the names are free, the meaning is fixed):

        Args:
            module: The sub-module the hook is attached to (e.g. Linear,
                Conv2d).
            fea_in: The input of that module; per the original author's
                note this is a tuple.
            fea_out: The output of that module; per the original author's
                note this is a tensor.

        Returns:
            ``fea_out`` unchanged.
            NOTE(review): PyTorch treats a non-None return from a forward
            hook as the module's new output; returning the same object
            should be a no-op - confirm this is intended.
        """
        # The latest record is also published in the module-level ``feats``
        # dict, which is shared by every hook in this file.
        global feats
        self.fea = fea_out
        feats = dict(input=fea_in, output=self.fea)
        return fea_out

# ---------- 第二步：注册hook,告诉模型我将在哪些层提取feature -------- #
def get_feas_by_hook(module_key_name, model, model_input=None):
    """Register forward hooks on matching layers and collect their features.

    Every sub-module whose name (as yielded by ``model.named_modules()``)
    contains ``module_key_name`` gets its own forward hook. Each hook writes
    the input and output of *its* layer into the returned dictionaries, so
    layers no longer overwrite each other through a shared global and nothing
    is read before a forward pass has actually happened.

    Args:
        module_key_name: Substring that selects layers by name.
        model: Object exposing ``named_modules()`` whose sub-modules expose
            ``register_forward_hook`` (e.g. a ``torch.nn.Module``).
        model_input: Optional input for a single forward pass. When given,
            ``model(model_input)`` is run once and the hooks are removed
            afterwards, so the returned dictionaries are complete. When
            omitted (the legacy two-argument call), the hooks stay registered
            and the returned dictionaries are filled in - and refreshed - by
            every later forward pass the caller runs.

    Returns:
        A ``(hook_feas_in, hook_feas_out)`` tuple of dicts keyed by layer
        name. Values are exactly what the hook received: the layer's input
        (a tuple, per PyTorch's hook convention) and the layer's output.
    """
    hook_feas_in = {}
    hook_feas_out = {}

    def _make_hook(layer_name):
        # Bind the layer name now; a bare closure over the loop variable
        # would make every hook write under the last layer's name.
        def _capture(module, fea_in, fea_out):
            hook_feas_in[layer_name] = fea_in
            hook_feas_out[layer_name] = fea_out

        return _capture

    handles = [
        module.register_forward_hook(_make_hook(name))
        for name, module in model.named_modules()
        if module_key_name in name
    ]

    if model_input is None:
        # Legacy mode: no input available, features arrive on the caller's
        # next forward pass. The hooks intentionally remain attached.
        return hook_feas_in, hook_feas_out

    try:
        if handles:
            model(model_input)
    finally:
        # Always detach, even if the forward pass raised, so repeated calls
        # do not pile hooks up on the model.
        for handle in handles:
            handle.remove()

    return hook_feas_in, hook_feas_out


# Most recent {"input": ..., "output": ...} record written by a forward hook.
# Shared by every hook in this module, so it only ever reflects the last call.
feats = {}


def hook(module, input, output):
    """Forward-hook callback that publishes the call's input and output.

    Rebinds the module-level ``feats`` to a fresh dict holding the ``input``
    and ``output`` it was called with. ``module`` is accepted only because
    the hook signature requires it.

    Returns:
        None.
    """
    global feats
    feats = dict(input=input, output=output)



# Example usage: detr_act_hook(log, samples, model, 'act')
def detr_act_hook(log, img_input, model, layer_key):
    """Capture the input and output of every layer whose name has ``layer_key``.

    A dedicated forward hook is attached to each matching layer, the model is
    run once on ``img_input``, and the hooks are removed again. Each hook
    keeps its own record, so one layer cannot overwrite another layer's
    features and no hooks are left behind on the model.

    Args:
        log: Object exposing ``log.logger.info(...)``.
        img_input: Value passed to ``model(...)`` for the forward pass.
        model: Object exposing ``named_modules()`` whose sub-modules expose
            ``register_forward_hook`` (e.g. a ``torch.nn.Module``).
        layer_key: Substring that selects layers by name.

    Returns:
        A ``(hook_feas_in, hook_feas_out)`` tuple of dicts keyed by layer
        name. Each value is element ``[0]`` of what the hook received for the
        layer's input / output, after ``.detach().cpu()``.
        NOTE(review): for the input this picks the first positional argument;
        for a tensor output ``[0]`` picks the first entry along dim 0 -
        confirm that is the intended slice.
    """
    captured = {}

    def _make_hook(layer_name):
        # Bind the name per layer so each hook writes to its own slot.
        def _capture(module, fea_in, fea_out):
            captured[layer_name] = (fea_in, fea_out)

        return _capture

    layer_names = []
    handles = []
    try:
        for name, module in model.named_modules():
            if layer_key in name:
                layer_names.append(name)
                handles.append(module.register_forward_hook(_make_hook(name)))

        # One forward pass feeds every hook; skip it when nothing matched.
        if handles:
            model(img_input)
    finally:
        # Detach even on failure so hooks never accumulate across calls.
        for handle in handles:
            handle.remove()

    hook_feas_in = {}
    hook_feas_out = {}
    for layer_num, name in enumerate(layer_names, start=1):
        log.logger.info(f'第{layer_num}层:{name}')
        if name not in captured:
            # The layer is registered on the model but was not executed by
            # this forward pass, so there is nothing to record for it.
            log.logger.info(f'{name}: no forward call captured, skipped')
            continue

        fea_in, fea_out = captured[name]
        log.logger.info({"input": fea_in, "output": fea_out})
        hook_feas_in[name] = fea_in[0].detach().cpu()
        hook_feas_out[name] = fea_out[0].detach().cpu()

    print(hook_feas_out)

    return hook_feas_in, hook_feas_out




#  train_one_epoch_teacher
        # ==================================================== hook ====================================================
        
        # act_layer_name = []
        # for name,module in model.named_modules():
        #     if 'act' in name:
        #         act_layer_name.append(name)


        # act_layer_num = 0
        # hook_feas_in = {}
        # hook_feas_out = {}
        # for name,module in model.named_modules():
        #     if act_layer_num >= len(act_layer_name):
        #         break
            
        #     if name == act_layer_name[act_layer_num]:
        #         act_layer_num += 1
        #         log.logger.info(f'第{act_layer_num}层:{name}')
        #         # pdb.set_trace()
        #         hh = module.register_forward_hook(hook)
        #         # hh.remove()
                
        #         outputs = model(samples)


        #     if name == act_layer_name[act_layer_num-1]:
        #         log.logger.info(feats)
        #         # module_input = feats["input"]
        #         # module_output = feats["output"]
        #         module_input = feats["input"][0].detach().cpu()
        #         module_output = feats["output"][0].detach().cpu()
        #         hook_feas_in[name] = module_input
        #         hook_feas_out[name] = module_output
            
        # print(hook_feas_out)
        # # pdb.set_trace()

        # module_input_act_dict,module_output_act_dict = get_feas_by_hook('act',model)
        # pdb.set_trace()

        # ==================================================== hook ====================================================





# teacher model loss

        #------- model_teacher -------
        # loss_dict_teacher_reduced = utils.reduce_dict(loss_dict_teacher)
        # loss_dict_teacher_reduced_unscaled = {f'{k}_unscaled': v
        #                               for k, v in loss_dict_teacher_reduced.items()}
        # loss_dict_teacher_reduced_scaled = {k: v * weight_dict_teacher[k]
        #                             for k, v in loss_dict_teacher_reduced.items() if k in weight_dict_teacher}
        # losses_teacher_reduced_scaled = sum(loss_dict_teacher_reduced_scaled.values())

        # loss_value_teacher = losses_teacher_reduced_scaled.item()


            #------- model_teacher -------
            # print("Loss is {}, stopping training".format(loss_value_teacher))
            # print(loss_dict_teacher_reduced)



        # losses.backward(retain_graph=True)
        # losses_teacher.backward(retain_graph=True)


        # losses_teacher.backward()


        #------- model_teacher -------
        # metric_logger.update(loss=loss_value, **loss_dict_teacher_reduced_scaled, **loss_dict_teacher_reduced_unscaled)
        # metric_logger.update(class_error=loss_dict_teacher_reduced['class_error_teacher'])
        # metric_logger.update(lr=optimizer.param_groups[0]["lr"])



# tensorboard =====>> Weight&Grad
        # ================================== Weight&Grad分析 ==================================
        # folder_path = f"./distribution_result/Act/t_{args.teacher_Qmodel_scheme}&s_{args.quant_scheme}"
        
        # if args.teacher_Qmodel_scheme:
        #     writer_float = SummaryWriter(log_dir=f"{folder_path}/teacher_{args.teacher_Qmodel_scheme}")
        # else:
        #     writer_float = SummaryWriter(log_dir=f"./distribution_result/Act/Float")
        
        # # if args.teacher_Qmodel_scheme:
        # #     writer_quant = SummaryWriter(log_dir=f"{folder_path}/Quant_{args.quant_scheme}")
        # # else:
        # #     writer_quant = SummaryWriter(log_dir=f"./distribution_result/Act/Quant_{args.quant_scheme}")

        # # if args.teacher_Qmodel_scheme:
        # #     writer_diff = SummaryWriter(log_dir=f"{folder_path}/difference")
        # # else:
        # #     writer_diff = SummaryWriter(log_dir=f"./distribution_result/Act/float&{args.quant_scheme}/difference")


        # ind = 0
        # for (name,module),(Qname,Qmodule) in zip(model_teacher.named_modules(),model.named_modules()):
        #     ind += 1
        #     # ================= float =================
        #     # if 'act' in name:
            
        #     for layer_name,layer_output in hook_feas_out.items():
        #         print(layer_output)
        #         writer_float.add_histogram(f'DETR_Act_distribution_第{ind}层：{layer_name}', layer_output[0], 1)
        #     pdb.set_trace()
            
        #     # grad = param.grad
        #     # if grad is None:
        #     #     print(f'第{ind}层：{name} 梯度属性: {param.requires_grad}, 梯度值为: {grad}')
        #     # if param.requires_grad and not grad is None:
        #     #     writer_float.add_histogram(f'DETR_grad_distribution_第{ind}层：{name}', grad, 1)
        #     # writer_float.add_histogram(f'DETR_weight_distribution_第{ind}层：{name}', param, 1)
            
        #     # # ================= Quant =================
        #     # Qgrad = Qparam.grad
        #     # if Qgrad is None:
        #     #     print(f'第{ind}层：{Qname} 梯度属性: {Qparam.requires_grad}, 梯度值为: {Qgrad}')
        #     # if Qparam.requires_grad and not Qgrad is None:
        #     #     writer_quant.add_histogram(f'DETR_grad_distribution_第{ind}层：{name}', Qgrad, 1)
        #     # writer_quant.add_histogram(f'DETR_weight_distribution_第{ind}层：{name}', Qparam, 1)

        #     # diff = torch.abs(param - Qparam)
        #     # writer_diff.add_histogram(f'DETR_weight_distribution_第{ind}层：{name}', diff, 1)

        # writer_float.close()
        # # writer_quant.close()
        # # writer_diff.close()
        # pdb.set_trace()
 


