from torch.autograd import Function
import torch.nn as nn
import sys
sys.path.append('./')
from models.SSHead import *

class ReverseLayerF(Function):
    """Gradient reversal layer.

    Forward pass: returns a view of the input, so values are unchanged.
    Backward pass: the incoming gradient is negated and multiplied by
    ``alpha``. No gradient is produced for ``alpha`` itself.
    Use it through ``ReverseLayerF.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scaling coefficient; backward needs it.
        ctx.scale = alpha
        # view_as gives a new tensor object sharing x's data (identity map).
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output.neg()
        scaled = flipped * ctx.scale
        # Second slot corresponds to `alpha`, which is not differentiated.
        return scaled, None

class DANNWrapper(nn.Module):
    """DANN-style wrapper: shared feature extractor + label head + domain head.

    The backbone ``net`` is split by the ``*_layer2`` helpers (imported from
    ``models.SSHead``) into a feature extractor and a label classifier; a
    discriminator head is built on the same layer. The discriminator sees the
    features through ``ReverseLayerF``, so in the backward pass its gradient
    is negated and scaled by ``alpha`` before reaching the extractor.
    """

    def __init__(self, net, width=1):
        """Build the three sub-modules from ``net``.

        Args:
            net: Backbone model passed to the ``models.SSHead`` helpers.
            width: Width argument forwarded to ``head_on_layer2`` for the
                discriminator head. Defaults to 1, the value given in the
                default setup.
        """
        super().__init__()

        self.feature = extractor_from_layer2(net)
        self.classifier = classifier_from_layer2(net)
        # NOTE(review): the trailing 2 is presumably the number of domain
        # classes (source vs. target) — confirm against head_on_layer2.
        self.discriminator = head_on_layer2(net, width, 2)

    def forward(self, input_data, alpha):
        """Run the label branch and the (gradient-reversed) domain branch.

        Args:
            input_data: Batch fed to the feature extractor.
            alpha: Scale applied to the reversed gradient flowing from the
                discriminator back into the feature extractor.

        Returns:
            Tuple ``(class_output, domain_output)``.
        """
        # NOTE(review): an earlier note said features are 32x14x14 under the
        # default setting; the actual shape depends on net's layer2 — confirm.
        feature = self.feature(input_data)
        # Identity in forward; negates and scales the gradient by alpha in
        # backward so the extractor is trained adversarially to the domain head.
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.classifier(feature)
        domain_output = self.discriminator(reverse_feature)

        return class_output, domain_output
