import torch
from torch.autograd import Function
from .. import src


class PSAMask(Function):
    @staticmethod
    def forward(ctx, input, psa_type=0, mask_H_=None, mask_W_=None):
        assert psa_type in [0, 1]  # 0-col, 1-dis
        assert (mask_H_ is None and mask_W_ is None) or (mask_H_ is not None and mask_W_ is not None)
        num_, channels_, feature_H_, feature_W_ = input.size()
        if mask_H_ is None and mask_W_ is None:
            mask_H_, mask_W_ = 2 * feature_H_ - 1, 2 * feature_W_ - 1
        assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1)
        assert channels_ == mask_H_ * mask_W_
        half_mask_H_, half_mask_W_ = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2
        output = torch.zeros([num_, feature_H_ * feature_W_, feature_H_, feature_W_], dtype=input.dtype, device=input.device)
        if not input.is_cuda:
            src.cpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
        else:
            output = output.cuda()
            src.gpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
        ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ = psa_type, num_, channels_, feature_H_, feature_W_
        ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_ = mask_H_, mask_W_, half_mask_H_, half_mask_W_
        return output

    @staticmethod
    def backward(ctx, grad_output):
        psa_type, num_, channels_, feature_H_, feature_W_ = ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_
        mask_H_, mask_W_, half_mask_H_, half_mask_W_ = ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_
        grad_input = torch.zeros([num_, channels_, feature_H_, feature_W_], dtype=grad_output.dtype, device=grad_output.device)
        if not grad_output.is_cuda:
            src.cpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
        else:
            src.gpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
        return grad_input, None, None, None


psa_mask = PSAMask.apply
