import torch
import torch.nn as nn
import torch.nn.functional as F
import cm_feeder_layer_cuda

# backward() receives one gradient per tensor returned by forward(), in the same order as those outputs
class CMFeederOperator(torch.autograd.Function):
    """Autograd wrapper around the ``cm_feeder_layer_cuda`` extension.

    ``forward`` and ``backward`` delegate to the extension's ``forward`` and
    ``backward`` entry points. ``n_modules`` is a plain integer passed through
    to both calls; it receives no gradient.
    """

    @staticmethod
    def forward(ctx, n_modules: int, input: torch.Tensor) -> torch.Tensor:
        """Run the extension's forward pass.

        Args:
            ctx: Autograd context; ``n_modules`` is stored on it for backward.
            n_modules: Integer forwarded to the extension.
            input: Input tensor.

        Returns:
            The first element of the sequence returned by
            ``cm_feeder_layer_cuda.forward``.
        """
        # NOTE(review): the extension presumably indexes the tensor's raw
        # memory directly; ``.contiguous()`` guarantees a dense layout and is
        # a no-op when the tensor is already contiguous.
        outputs = cm_feeder_layer_cuda.forward(n_modules, input.contiguous())
        ctx.n_modules = n_modules
        return outputs[0]

    @staticmethod
    def backward(ctx, out_grad: torch.Tensor):
        """Run the extension's backward pass.

        Args:
            ctx: Autograd context carrying ``n_modules`` from ``forward``.
            out_grad: Gradient of the loss w.r.t. the tensor ``forward``
                returned.

        Returns:
            ``(None, grad_input)``: one entry per ``forward`` argument after
            ``ctx``. ``None`` for the integer ``n_modules``, then the first
            element of the sequence returned by
            ``cm_feeder_layer_cuda.backward``.
        """
        n_modules = ctx.n_modules
        # The incoming gradient can be non-contiguous (e.g. after
        # ``output.sum()`` autograd hands back an expanded, zero-stride
        # tensor), so densify it before passing it to the extension.
        input_grad = cm_feeder_layer_cuda.backward(
            n_modules, out_grad.contiguous()
        )
        return None, input_grad[0]


class CMFeederLayer(torch.nn.Module):
    """``nn.Module`` front-end for ``CMFeederOperator``.

    Args:
        n_modules: Converted with ``int()`` at construction and passed to
            ``CMFeederOperator.apply`` on every forward call.
    """

    def __init__(self, n_modules):
        super().__init__()
        self.n_modules = int(n_modules)

    def forward(self, input):
        """Apply ``CMFeederOperator`` to ``input`` with the stored ``n_modules``."""
        return CMFeederOperator.apply(self.n_modules, input)

    def extra_repr(self) -> str:
        # Included by ``repr(module)`` / ``print(model)`` so the layer's
        # configuration is visible when inspecting a network.
        return f"n_modules={self.n_modules}"
