from functions import *
from torch.autograd import grad
import torch.nn.init as init

class DA_Critic(nn.Module):
    """Three-layer fully connected critic whose output passes through a sigmoid.

    The first layer accepts ``latent_size + 2 * (Nc + 1)`` input features,
    the hidden layers are ``latent_size`` wide, and the last layer emits one
    value per input row.
    """

    def __init__(self, latent_size=512, Nc=10):
        super(DA_Critic, self).__init__()
        # Input width as used by the first layer; presumably the latent code
        # concatenated with 2 * (Nc + 1) extra features -- confirm with caller.
        in_features = latent_size + 2 * (Nc + 1)
        self.fc1 = nn.Linear(in_features, latent_size)
        self.fc2 = nn.Linear(latent_size, latent_size)
        self.fc3 = nn.Linear(latent_size, 1)
        self.weight_init()

    def forward(self, x):
        """Return the sigmoid of the network output for input ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        score = self.fc3(hidden)
        return torch.sigmoid(score)

    def weight_init(self):
        """Apply ``kaiming_init`` to each directly registered sub-module."""
        for layer in self._modules.values():
            kaiming_init(layer)

def kaiming_init(m):
    """Initialise the parameters of a single module in place.

    * ``nn.Linear`` / ``nn.Conv2d``: the weight is drawn with
      ``init.kaiming_normal_`` and the bias (if present) is set to 0.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: the weight (if present) is set
      to 1 and the bias (if present) is set to 0.
    * Any other object is left untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # Use the in-place variant; the non-underscore name is deprecated.
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # BatchNorm layers built with affine=False have weight/bias == None.
        if m.weight is not None:
            init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)

def h(critic, h_s, h_t):
    """Compute a gradient penalty for ``critic`` between ``h_s`` and ``h_t``.

    For each row, a point is sampled uniformly on the segment joining the
    corresponding rows of ``h_s`` and ``h_t``. The critic is evaluated at
    those points, and the penalty is the mean squared deviation from 1 of the
    per-row L2 norm of the critic's gradient with respect to its input.

    Args:
        critic: Callable mapping a tensor to a tensor of predictions.
        h_s: First batch of features. Assumed 2-D ``(batch, features)``, as
            the mixing coefficient has shape ``(batch, 1)`` and the norm is
            taken over ``dim=1`` -- confirm with callers.
        h_t: Second batch of features, same shape as ``h_s``.

    Returns:
        Scalar tensor holding the penalty. The graph is kept
        (``create_graph=True``) so the result can be back-propagated.
    """
    # Sample the mixing coefficient on the inputs' own device/dtype rather
    # than forcing CUDA, so the function also works on CPU and multi-GPU.
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device, dtype=h_s.dtype)
    differences = h_t - h_s
    interpolates = h_s + (alpha * differences)
    interpolates.requires_grad_()
    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    gradient_norm = gradients.norm(2, dim=1)
    GP = ((gradient_norm - 1) ** 2).mean()

    return GP

def Make_DA_Critic_Model(args, Nc=10):
    """Build a ``DA_Critic`` with the given ``Nc`` and default latent size.

    ``args`` is accepted for interface compatibility but is not used here.
    """
    return DA_Critic(Nc=Nc)
