
def add_hooks(model, args):
    """Register forward hooks that record each decoder layer's last-position output.

    A forward hook is attached to every element of ``model.layers`` that is a
    ``CustomLlamaDecoderLayer``. Each time such a layer runs, the hook appends
    a CPU copy of ``output[0][:, -1:, :]`` to a shared list.

    Args:
        model: Object exposing an iterable ``layers`` attribute.
        args: Not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        A ``(layer_outputs, hooks)`` tuple. ``layer_outputs`` is the list the
        hooks append to; it starts empty and grows by one entry per hooked
        layer per forward call, so callers that run several forward passes
        need to clear or slice it themselves. ``hooks`` holds the handles
        returned by ``register_forward_hook``, in layer order; pass it to
        ``remove_hooks`` to detach them.
    """
    # Imported lazily inside the function, as in the original code.
    from llava.custom_model.custom_decoder import CustomLlamaDecoderLayer

    layer_outputs = []

    def capture_last_position(module, inputs, output):
        # output: (hidden_states, attentions, kv_cache)
        # Keep only the final index along dim 1 of the first output element
        # (assumes hidden_states is (batch, seq, hidden) -- TODO confirm).
        # .cpu() returns the same tensor when it already lives on the CPU, so
        # the explicit .clone() guarantees the stored slice never aliases the
        # layer's live output buffer.
        # NOTE(review): the copy is not detached; presumably callers run the
        # model under torch.no_grad() -- confirm.
        layer_outputs.append(output[0][:, -1:, :].cpu().clone())

    # The hook keeps no per-layer state, so one function serves every layer.
    hooks = [
        layer.register_forward_hook(capture_last_position)
        for layer in model.layers
        if isinstance(layer, CustomLlamaDecoderLayer)
    ]
    return layer_outputs, hooks


def remove_hooks(hooks):
    """Detach every registered hook.

    Args:
        hooks: Iterable of handle objects exposing a zero-argument
            ``remove()`` method, such as the handles returned by
            ``add_hooks``.

    Each handle's ``remove()`` is called once, in iteration order. The
    iterable itself is left unmodified and nothing is returned.
    """
    for handle in hooks:
        handle.remove()
