import torch.nn as nn
import torch.nn.functional as F

# base architecture
        
class BaseModel(nn.Module):
    """Chain an ``encoder`` and a ``classifier`` into a single module.

    Args:
        encoder: Callable (typically an ``nn.Module``) mapping inputs to features.
        classifier: Callable applied to the encoder's output.
        name: Identifier stored on the instance; used for saving and loading
            the model.
    """

    def __init__(self, encoder, classifier, name):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier
        # Kept on the instance so save/load code can refer to the model by name.
        self.name = name

    def forward(self, x):
        """Return ``classifier(encoder(x))``."""
        return self.classifier(self.encoder(x))

    def forward_encoder(self, x):
        """Return a ``(features, classifier_output)`` pair for ``x``.

        The classifier output is whatever ``classifier`` returns for the
        features; no activation is applied here.
        """
        features = self.encoder(x)
        return features, self.classifier(features)

    def encode(self, x):
        """Return only the encoder features for ``x``."""
        return self.encoder(x)

def init_weights(m):
    """Initialise the parameters of module ``m`` in place, chosen by class name.

    Suitable for ``model.apply(init_weights)``. Matching is a substring test on
    the class name:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weight, zero bias.
    * ``BatchNorm``: weight ~ N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weight, zero bias.

    Any other module is left untouched. A weight or bias that is ``None``
    (e.g. ``bias=False`` layers, or BatchNorm with ``affine=False``) is skipped
    instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # getattr: tolerate matched-by-name modules that lack these attributes,
    # and layers built without them (attribute present but set to None).
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv2d" in classname or "ConvTranspose2d" in classname:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif "Linear" in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        return

    if bias is not None:
        nn.init.zeros_(bias)

class MLP_Classifier(nn.Module):
    """Two-layer perceptron head: dropout -> Linear -> ReLU -> Linear.

    Inputs that are not 2-D are reshaped to ``(x.size(0), -1)`` first, so the
    flattened feature size must equal ``input_dim``. The output is the raw
    ``fc2`` result (no softmax is applied).

    Args:
        input_dim: Number of input features after flattening.
        num_classes: Number of output units.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied to the input features in
            training mode. Defaults to 0.5 (the previously hard-coded value).
    """

    def __init__(self, input_dim, num_classes, hidden_dim=1024, dropout=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout
        self.fc1 = nn.Linear(input_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(x.size(0), num_classes)``."""
        if x.dim() != 2:
            # reshape (not view): accepts non-contiguous inputs such as permuted
            # feature maps, and still returns a view whenever view() would.
            x = x.reshape(x.size(0), -1)
        # Dropout is applied to the inputs, before fc1, only while training.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.fc1(x))
        return self.fc2(x)