import torch
import torch.nn as nn
import torch.nn.functional as F

import facsfbs_layer_cuda


class FACSFBSOperator(torch.autograd.Function):
    """Autograd bridge to the compiled ``facsfbs_layer_cuda`` extension.

    Both passes delegate the numerical work to the extension. This class only
    stashes the tensors needed by the backward kernel and routes the gradients
    the kernel returns back to the matching ``forward`` arguments.
    """

    @staticmethod
    def forward(ctx, n_topk: int, input: torch.Tensor, topk_score: torch.Tensor, indices: torch.Tensor):
        """Run the extension's forward kernel.

        Args:
            ctx: autograd context used to pass state to ``backward``.
            n_topk: integer forwarded unchanged to the extension (forward and
                backward); it is not differentiable.
            input: tensor forwarded to the extension; receives a gradient.
            topk_score: tensor forwarded to the extension; receives a gradient.
            indices: tensor forwarded to the extension; receives no gradient
                (presumably an integer index tensor — TODO confirm against
                the kernel).

        Returns:
            The first element of the extension's forward output.
        """
        outputs = facsfbs_layer_cuda.forward(n_topk, input, topk_score, indices)
        # The backward kernel needs the original inputs, not the output.
        ctx.save_for_backward(input, topk_score, indices)
        ctx.n_topk = n_topk
        return outputs[0]

    @staticmethod
    def backward(ctx, out_grad: torch.Tensor):
        """Run the extension's backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            out_grad: gradient of the loss w.r.t. ``forward``'s output.

        Returns:
            One entry per ``forward`` argument, in the same order
            ``(n_topk, input, topk_score, indices)``:
            ``(None, input_grad, topk_score_grad, None)``.
        """
        input_, topk_score, indices = ctx.saved_tensors
        # NOTE(review): hand-written CUDA kernels commonly index raw memory and
        # assume a contiguous layout; .contiguous() is a no-op when the
        # incoming gradient already is contiguous.
        input_grad, topk_score_grad = facsfbs_layer_cuda.backward(
            ctx.n_topk, input_, topk_score, indices, out_grad.contiguous()
        )
        # Autograd matches gradients to forward() arguments positionally.
        # n_topk (an int) and indices are not differentiable, so they get None.
        # The kernel's second gradient (named fusion_prob_grad previously)
        # belongs to topk_score.
        return None, input_grad, topk_score_grad, None


class FACSFBSLayer(torch.nn.Module):
    """Stateless ``nn.Module`` front-end for ``FACSFBSOperator``.

    The layer owns no parameters or buffers; calling it simply applies the
    custom autograd operator to the given arguments.
    """

    def __init__(self):
        super().__init__()

    def forward(self, n_topk, input, topk_score, indices):
        """Apply ``FACSFBSOperator`` to the arguments, passed through in order.

        Args:
            n_topk: integer hyper-parameter forwarded to the operator.
            input: tensor forwarded to the operator.
            topk_score: tensor forwarded to the operator.
            indices: tensor forwarded to the operator.

        Returns:
            Whatever ``FACSFBSOperator.apply`` produces.
        """
        operator_args = (n_topk, input, topk_score, indices)
        return FACSFBSOperator.apply(*operator_args)
