import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F


# The whole model
class Net(nn.Module):
    """Fully connected network with two hidden layers and a linear output.

    The forward pass returns both the class logits and the (post-dropout)
    activation of the first hidden layer, which callers can feed to a
    separate head.

    Args:
        input_size: number of input features expected by ``fc1``.
        hidden_size: width of both hidden layers.
        num_classes: number of logits produced by ``fc3``.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Sub-modules are registered in the order they are applied; the
        # single ReLU instance is shared by both hidden layers.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc1_drop = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc2_drop = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            A ``(logit, representation)`` tuple, where ``representation`` is
            the output of ``fc1 -> ReLU -> dropout`` and ``logit`` is the
            output of ``fc3`` applied after the second hidden layer.
        """
        # First hidden layer; its dropped-out activation is exposed to callers.
        representation = self.fc1_drop(self.relu(self.fc1(x)))

        # Second hidden layer followed by the output projection.
        hidden = self.fc2_drop(self.relu(self.fc2(representation)))
        return self.fc3(hidden), representation


# The classification head
class FC(nn.Module):
    """Classification head: ``fc2 -> ReLU -> dropout -> fc3``.

    Args:
        input_size: accepted for signature parity; not used by this module.
        hidden_size: number of input features and width of the hidden layer.
        num_classes: number of logits produced by ``fc3``.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc2_drop = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return the logits computed from a ``hidden_size``-wide input ``x``."""
        hidden = self.fc2_drop(self.relu(self.fc2(x)))
        return self.fc3(hidden)


# The TwoLayerFC classification head
class TwoLayerFC(nn.Module):
    """Classification head made of two stacked linear layers.

    The input is projected to 100 features by ``fc`` and then to
    ``num_classes`` logits by ``fc_new``; no activation is applied between
    the two layers.

    Args:
        input_size: accepted for signature parity; not used by this module.
        hidden_size: number of input features expected by ``fc``.
        num_classes: number of logits produced by ``fc_new``.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=hidden_size, out_features=100)
        self.fc_new = nn.Linear(in_features=100, out_features=num_classes)

    def forward(self, x):
        """Return the logits for a ``hidden_size``-wide input ``x``."""
        return self.fc_new(self.fc(x))



class Adversarial_head(nn.Module):
    """Adversary that maps a classifier's predictions to 2 adversarial logits.

    The classifier logits are sharpened by a learnable scale ``1 + c`` and
    turned into probabilities. The probability assigned to each sample's
    target class is prepended to the probability vector, and the resulting
    3-feature row is passed through a linear layer.

    Because ``fc`` takes 3 input features (1 target probability + the
    probability vector), ``pred_logits`` must have exactly 2 columns.

    Note:
        ``c`` is registered as an ``nn.Parameter`` so that it is returned by
        ``parameters()``, saved in ``state_dict()`` and moved by ``.to()``.
        Checkpoints saved before this change have no ``c`` entry and must be
        loaded with ``strict=False``.
    """

    def __init__(self):
        super(Adversarial_head, self).__init__()
        # Kept for attribute compatibility; not used in forward().
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(3, 2)
        # Learnable sharpening coefficient; must be a Parameter, otherwise an
        # optimizer built from parameters() never updates it.
        self.c = nn.Parameter(torch.tensor(1.0))
        # Explicit class dimension: implicit-dim softmax is deprecated.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, pred_logits, targets, attribute_labels):
        """Compute adversarial logits.

        Args:
            pred_logits: 2-D tensor of classifier logits with 2 columns.
            targets: integer class index per row of ``pred_logits``.
            attribute_labels: accepted for interface compatibility; unused.

        Returns:
            Tensor with one row per sample and 2 columns (raw logits, no
            softmax applied).
        """
        s = self.softmax((1 + self.c) * pred_logits)

        # Pick, for every row, the probability of that row's target class.
        # The row index is built on the same device as the probabilities.
        rows = torch.arange(len(targets), device=s.device)
        target_prob = s[rows, targets].view(-1, 1)

        adv_input = torch.cat((target_prob, s), dim=1)
        adv_logits = self.fc(adv_input)

        return adv_logits

