import torch
import torch.nn as nn
import torch.nn.functional as F
import facs_layer_cuda

# backward() receives one incoming gradient per output of forward(), in the same order the outputs
# were returned, and must return one gradient (or None) per forward() argument, in argument order.
class FACSOperator(torch.autograd.Function):
    """Autograd wrapper around the ``facs_layer_cuda`` extension's forward/backward kernels.

    Only ``input`` and ``fusion_prob`` receive gradients. ``k_size`` and
    ``facs_thresh`` are non-tensor hyperparameters, so ``backward`` returns
    ``None`` for them.

    NOTE(review): the expected shapes, dtypes and devices of ``input`` and
    ``fusion_prob`` are defined by the CUDA extension and are not visible
    here — confirm against the kernel sources.
    """

    @staticmethod
    def forward(
        ctx,
        k_size: int,
        facs_thresh: float,
        input: torch.Tensor,
        fusion_prob: torch.Tensor,
    ) -> torch.Tensor:
        """Run the extension's forward kernel and return its first output.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            k_size: Integer hyperparameter forwarded unchanged to the kernel.
            facs_thresh: Float hyperparameter forwarded unchanged to the kernel.
            input: Tensor argument of the kernel; receives a gradient.
            fusion_prob: Tensor argument of the kernel; receives a gradient.

        Returns:
            The first element of the sequence returned by
            ``facs_layer_cuda.forward``. Any further elements are discarded.
        """
        outputs = facs_layer_cuda.forward(k_size, facs_thresh, input, fusion_prob)
        # The extension's backward kernel needs the original tensor arguments again.
        ctx.save_for_backward(input, fusion_prob)
        # Non-tensor hyperparameters are stored as plain ctx attributes.
        ctx.k_size = k_size
        ctx.facs_thresh = facs_thresh
        return outputs[0]

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, out_grad: torch.Tensor):
        """Compute gradients for ``input`` and ``fusion_prob`` via the extension.

        ``once_differentiable`` is applied because the native kernel's results
        carry no autograd history. Attempting a double backward through this op
        therefore raises an error instead of silently returning wrong
        higher-order gradients.

        Args:
            ctx: Autograd context populated by ``forward``.
            out_grad: Gradient of the loss w.r.t. the tensor ``forward`` returned.

        Returns:
            A 4-tuple aligned with ``forward``'s arguments:
            ``(None, None, input_grad, fusion_prob_grad)``.
        """
        input, fusion_prob = ctx.saved_tensors
        # Autograd does not guarantee a contiguous incoming gradient (e.g. when the
        # output was transposed or expanded downstream). NOTE(review): the kernel
        # presumably assumes a dense layout; .contiguous() is a no-op when it
        # already is.
        grad = out_grad.contiguous()
        input_grad, fusion_prob_grad = facs_layer_cuda.backward(
            ctx.k_size, ctx.facs_thresh, input, fusion_prob, grad
        )
        return None, None, input_grad, fusion_prob_grad


class FACSLayer(torch.nn.Module):
    """``nn.Module`` front-end for ``FACSOperator``.

    Stores the operator's two hyperparameters and applies the operator to the
    tensors passed to ``forward``.

    Args:
        k_size: Converted with ``int()``, so a float is truncated.
        facs_thresh: Forwarded unchanged to the operator (default ``0.5``).
    """

    def __init__(self, k_size, facs_thresh=0.5):
        super().__init__()
        self.k_size = int(k_size)
        self.facs_thresh = facs_thresh

    def extra_repr(self) -> str:
        """Expose the hyperparameters in ``repr(module)`` / ``print(model)``."""
        return f"k_size={self.k_size}, facs_thresh={self.facs_thresh}"

    def forward(self, input, fusion_prob):
        """Apply ``FACSOperator`` with this layer's stored hyperparameters.

        Args:
            input: Tensor passed as the operator's ``input`` argument.
            fusion_prob: Tensor passed as the operator's ``fusion_prob`` argument.

        Returns:
            Whatever ``FACSOperator.apply`` returns.
        """
        return FACSOperator.apply(self.k_size, self.facs_thresh, input, fusion_prob)
