import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    """Fully connected network mapping flattened inputs to embedding vectors.

    The trunk is ``Linear -> Tanh`` for each entry of ``hidden_sizes``,
    followed by a final ``Linear`` projection to ``embed_dim`` with no
    activation. When ``use_pred_loss`` is set, an auxiliary classification
    head (``PReLU -> Linear``) produces class logits from the embedding.

    Args:
        input_dim: Number of features per sample; inputs are reshaped to
            ``(-1, input_dim)`` before the first layer.
        embed_dim: Size of the output embedding.
        hidden_sizes: Widths of the hidden layers. May contain a single
            entry, or be empty, in which case the input is projected
            directly to ``embed_dim``.
        num_classes: Output size of the auxiliary head. Required when
            ``use_pred_loss`` is true.
        use_pred_loss: Whether to build the auxiliary classification head.

    Raises:
        ValueError: If ``use_pred_loss`` is true but ``num_classes`` is None.
    """

    def __init__(self, input_dim, embed_dim, hidden_sizes=(512, 512),
                 num_classes=None, use_pred_loss=False):
        super(EmbeddingNet, self).__init__()
        if use_pred_loss and num_classes is None:
            raise ValueError("num_classes is required when use_pred_loss=True")
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.use_pred_loss = use_pred_loss
        # Private list copy: the default is an immutable tuple, and later
        # mutation of a caller's list cannot desync this from the built layers.
        self.hidden_sizes = list(hidden_sizes)

        # One Linear/Tanh pair per consecutive size pair. Taking the final
        # in-features from `sizes[-1]` (not a loop variable) keeps this valid
        # for one or zero hidden layers. Layer order matches the historical
        # construction, so state_dict keys (fc.0, fc.2, ...) are unchanged.
        sizes = [input_dim, *self.hidden_sizes]
        layers = []
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.Tanh())
        layers.append(nn.Linear(sizes[-1], embed_dim))
        self.fc = nn.Sequential(*layers)

        if self.use_pred_loss:
            self.act = nn.PReLU()
            # Attribute name (sic) kept for state_dict/checkpoint compatibility.
            self.fc2lable = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Embed ``x`` and, if enabled, compute auxiliary logits.

        Returns:
            ``(embedding, logits)``. ``embedding`` has shape
            ``(N, embed_dim)`` where ``N = x.numel() // input_dim``.
            ``logits`` has shape ``(N, num_classes)`` when ``use_pred_loss``
            is set; otherwise it is the integer placeholder ``0``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        embedding = self.fc(x.reshape(-1, self.input_dim))
        logits = 0
        if self.use_pred_loss:
            logits = self.fc2lable(self.act(embedding))
        return embedding, logits

    def get_embedding(self, x):
        """Return the same ``(embedding, logits)`` pair as :meth:`forward`."""
        return self.forward(x)


class TripletNet(nn.Module):
    """Apply one shared embedding network to each member of a triplet.

    The wrapped ``embedding_net`` is called on each input and must return a
    two-element ``(embedding, logits)`` pair.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        """Embed ``x1``, ``x2`` and ``x3`` (in that order) with the shared net.

        Returns:
            ``((emb1, emb2, emb3), (logits1, logits2, logits3))``.
        """
        pairs = [self.embedding_net(sample) for sample in (x1, x2, x3)]
        embeddings = tuple(emb for emb, _ in pairs)
        logits = tuple(lg for _, lg in pairs)
        return embeddings, logits

    def get_embedding(self, x):
        """Return only the embedding part of ``embedding_net(x)``."""
        embedding, _logits = self.embedding_net(x)
        return embedding

    def get_logits(self, x):
        """Return only the logits part of ``embedding_net(x)``."""
        _embedding, logits = self.embedding_net(x)
        return logits

    def save(self, path):
        """Announce on stdout, then ``torch.save`` the wrapped module object.

        The whole ``embedding_net`` module is serialized, not just its
        ``state_dict``.
        """
        print(f"Saving model: {path}")
        torch.save(self.embedding_net, path)
