import torch

import fast_multihead_attn


class FastEncdecAttnFunc(torch.autograd.Function):
    """Autograd wrapper around the ``fast_multihead_attn`` encoder-decoder kernels.

    ``forward`` delegates to ``encdec_multihead_attn_forward`` and
    ``backward`` to ``encdec_multihead_attn_backward``. This class only
    shuttles arguments and saved state between those two extension calls;
    all numerical work happens inside the extension.
    """

    @staticmethod
    def forward(
        ctx,
        use_time_mask,
        is_training,
        heads,
        inputs_q,
        inputs_kv,
        input_weights_q,
        input_weights_kv,
        output_weights,
        pad_mask,
        dropout_prob,
    ):
        """Run the extension's forward kernel and stash state for backward.

        Args:
            ctx: Autograd context used to carry state to ``backward``.
            use_time_mask: Forwarded to the extension unchanged.
            is_training: Forwarded to the extension unchanged.
            heads: Forwarded to both the forward and backward kernels
                (presumably the number of attention heads -- confirm
                against the extension).
            inputs_q: Tensor forwarded to the kernel; receives a gradient.
            inputs_kv: Tensor forwarded to the kernel; receives a gradient.
            input_weights_q: Tensor forwarded to the kernel; receives a
                gradient.
            input_weights_kv: Tensor forwarded to the kernel; receives a
                gradient.
            output_weights: Tensor forwarded to the kernel; receives a
                gradient.
            pad_mask: Optional mask tensor. When ``None`` the kernel is
                called with ``use_mask=False`` and an empty placeholder
                tensor in the mask slot.
            dropout_prob: Forwarded to both the forward and backward
                kernels.

        Returns:
            The ``outputs`` tensor produced by the kernel, detached.
        """
        use_mask = pad_mask is not None
        # The extension always takes a tensor in the mask position; only
        # build the empty placeholder when the caller supplied no mask.
        mask_arg = pad_mask if use_mask else torch.tensor([])

        (
            input_lin_q_results,
            input_lin_kv_results,
            softmax_results,
            dropout_results,
            dropout_mask,
            matmul2_results,
            outputs,
        ) = fast_multihead_attn.encdec_multihead_attn_forward(
            use_mask,
            use_time_mask,
            is_training,
            heads,
            inputs_q,
            inputs_kv,
            input_weights_q,
            input_weights_kv,
            output_weights,
            mask_arg,
            dropout_prob,
        )

        # Non-tensor state lives directly on ctx so backward hands the
        # extension the very same scalar values forward did, instead of
        # 0-d tensors sliced out of throwaway CPU tensors.
        ctx.heads = heads
        ctx.dropout_prob = dropout_prob
        ctx.save_for_backward(
            matmul2_results,
            dropout_results,
            softmax_results,
            input_lin_q_results,
            input_lin_kv_results,
            inputs_q,
            inputs_kv,
            input_weights_q,
            input_weights_kv,
            output_weights,
            dropout_mask,
        )

        return outputs.detach()

    @staticmethod
    def backward(ctx, output_grads):
        """Run the extension's backward kernel.

        Args:
            ctx: Autograd context populated by ``forward``.
            output_grads: Gradient with respect to ``forward``'s output.

        Returns:
            A 10-tuple aligned with ``forward``'s arguments after ``ctx``:
            ``None`` for ``use_time_mask``, ``is_training``, ``heads``,
            ``pad_mask`` and ``dropout_prob``; the kernel's gradients for
            ``inputs_q``, ``inputs_kv``, ``input_weights_q``,
            ``input_weights_kv`` and ``output_weights``.
        """
        # Unpack in exactly the order forward saved them.
        (
            matmul2_results,
            dropout_results,
            softmax_results,
            input_lin_q_results,
            input_lin_kv_results,
            inputs_q,
            inputs_kv,
            input_weights_q,
            input_weights_kv,
            output_weights,
            dropout_mask,
        ) = ctx.saved_tensors

        (
            input_q_grads,
            input_kv_grads,
            input_weight_q_grads,
            input_weight_kv_grads,
            output_weight_grads,
        ) = fast_multihead_attn.encdec_multihead_attn_backward(
            ctx.heads,
            output_grads,
            matmul2_results,
            dropout_results,
            softmax_results,
            input_lin_q_results,
            input_lin_kv_results,
            inputs_q,
            inputs_kv,
            input_weights_q,
            input_weights_kv,
            output_weights,
            dropout_mask,
            ctx.dropout_prob,
        )

        # One slot per forward argument; non-differentiable ones get None.
        return (
            None,  # use_time_mask
            None,  # is_training
            None,  # heads
            input_q_grads,
            input_kv_grads,
            input_weight_q_grads,
            input_weight_kv_grads,
            output_weight_grads,
            None,  # pad_mask
            None,  # dropout_prob
        )


# Functional entry point: the bound ``apply`` of FastEncdecAttnFunc, so
# callers invoke the op like a plain function. Arguments follow
# ``FastEncdecAttnFunc.forward`` without the leading ``ctx``.
fast_encdec_attn_func = FastEncdecAttnFunc.apply
