import torch as t
import torch.nn as nn

from algs.utils.penultimate import Penultimate


class GradMulConst(t.autograd.Function):
    """Identity on the forward pass, gradient scaling on the backward pass.

    The forward pass returns a view of ``x`` with identical values. The
    backward pass multiplies the incoming gradient by ``const`` before passing
    it on to ``x``.

    Intended for building an adversarial loss: a negative ``const`` reverses
    the gradient. ``const = 0.0`` gives ``x`` an all-zero gradient, provided
    the upstream gradient is finite.
    """

    @staticmethod
    def forward(ctx, x, const):
        # view_as keeps the values but yields a new tensor object, so autograd
        # records this Function as the producer of the output.
        out = x.view_as(x)
        # Stash the scale directly on ctx for use in backward.
        ctx.const = const
        return out

    @staticmethod
    def backward(ctx, grad_output):
        scaled = grad_output * ctx.const
        # One gradient per forward input: ``x`` gets the scaled gradient and
        # ``const`` gets none.
        return scaled, None


def grad_mul_const(x, const):
    """Pass ``x`` through unchanged while scaling its gradient by ``const``.

    This is a functional shortcut for ``GradMulConst.apply(x, const)``. The
    forward result has the same values as ``x``. During backward, the gradient
    reaching ``x`` is multiplied by ``const``.
    """
    scale_grad = GradMulConst.apply
    return scale_grad(x, const)


class new_fc(nn.Module):
    """Fresh linear head on top of a backbone's features, replacing its ``fc``.

    Features are obtained from ``Penultimate(model=model, break_layers='fc')``.
    Presumably this is the backbone's output just before its own ``fc`` layer;
    confirm against ``algs.utils.penultimate``. A new ``nn.Linear`` maps those
    features to ``num_label`` outputs.

    Args:
        model: backbone whose ``state_dict()`` contains an ``'fc.weight'``
            entry. The second dimension of that weight sets the input width of
            the new head.
        device: device on which the new linear layer is created. Only the head
            is moved here, not the backbone.
        num_label: number of outputs of the new head.
        fix_grad: when True, the gradient flowing from the head back into the
            backbone features is multiplied by 0.0. The backward pass still
            runs through the backbone, but with zeroed gradients.
    """

    def __init__(self, model, device, num_label, fix_grad=False):
        super().__init__()
        # Input width of the original fc layer. The weight is laid out as
        # (out_features, in_features), so index 1 gives the input width.
        in_features = model.state_dict()['fc.weight'].size()[1]
        # Registration order (fc, then pen) is kept as-is; it determines the
        # order of state_dict() keys and parameters().
        head = nn.Linear(in_features, num_label)
        self.fc = head.to(device)
        self.pen = Penultimate(model=model, break_layers='fc')
        self.fix_grad = fix_grad

    def forward(self, data):
        """Run ``data`` through the feature extractor and then the new head."""
        feats = self.pen(data)
        # Identity in the forward pass; with fix_grad the backbone receives a
        # zero-scaled gradient instead of the real one.
        head_input = grad_mul_const(feats, 0.0) if self.fix_grad else feats
        return self.fc(head_input)