''' Modified from https://github.com/alinlab/LfF/blob/master/module/mlp.py'''
import torch
import torch.nn as nn

class MLP_DISENTANGLE(nn.Module):
    """MLP feature extractor plus a 32-input linear classification head.

    ``feature`` maps a flattened ``3*28*28`` input to a 16-dim embedding,
    whereas ``fc`` expects 32 input features.
    NOTE(review): presumably the head is meant to receive two 16-dim
    embeddings concatenated along the last dimension (``extract`` ->
    ``torch.cat`` -> ``predict``) -- confirm against the training code.

    Args:
        num_classes: Number of outputs produced by ``fc``.
    """

    def __init__(self, num_classes = 10):
        super(MLP_DISENTANGLE, self).__init__()
        self.feature = nn.Sequential(
            nn.Linear(3*28*28, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 16),
            nn.ReLU()
        )
        self.fc = nn.Linear(32, num_classes)

    def reverse_grad(self, lambda_reverse):
        """Build a gradient hook that rescales and flips part of a gradient.

        Args:
            lambda_reverse: Factor applied, with a sign flip, to the first
                16 entries along the last dimension of the gradient.

        Returns:
            A function ``hook(grad)`` that multiplies ``grad`` by a
            32-element mask equal to ``-lambda_reverse`` at indices
            ``[0, 16)`` and ``1`` at indices ``[16, 32)``.
        """
        def hook(grad):
            # Allocate the mask directly on the gradient's device.
            mask = torch.ones(32, device=grad.device)
            mask[:16] = -1 * lambda_reverse
            # Leading axis of size 1 lets the mask broadcast over the batch.
            updated_grad = grad * mask[None,...]
            return updated_grad

        return hook

    def extract(self, x):
        """Flatten ``x`` per sample, divide by 255, and run ``feature``.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions flatten to ``3*28*28`` values per sample.

        Returns:
            The output of ``self.feature`` (16 values per sample).
        """
        x = x.view(x.size(0), -1) / 255
        feat = self.feature(x)
        return feat

    def predict(self, x):
        """Apply the linear head ``fc`` to already-extracted features.

        Args:
            x: Tensor whose last dimension matches the 32 input features
                of ``fc``.

        Returns:
            The output of ``self.fc(x)``.
        """
        # The head is registered as ``fc``; this class defines no
        # ``classifier`` attribute, so calling it raised AttributeError.
        prediction = self.fc(x)
        return prediction

    def forward(self, x, return_feat=False):
        """Extract features from ``x`` and classify them with ``fc``.

        Args:
            x: Input accepted by ``extract``.
            return_feat: When true, also return the extracted features.

        Returns:
            ``fc`` output, or ``(fc output, features)`` if ``return_feat``.

        Raises:
            RuntimeError: If the feature width differs from the number of
                input features ``fc`` expects. With the layers built in
                ``__init__`` this is always the case (16 vs. 32), so use
                ``extract`` and ``predict`` on a 32-wide input instead.
        """
        feat = self.extract(x)
        expected = getattr(self.fc, "in_features", None)
        if expected is not None and feat.size(-1) != expected:
            raise RuntimeError(
                "MLP_DISENTANGLE.forward: feature width {} does not match "
                "fc.in_features={}; build the {}-wide input explicitly and "
                "call predict()".format(feat.size(-1), expected, expected)
            )
        final_x = self.fc(feat)
        if return_feat:
            return final_x, feat
        else:
            return final_x
    
class MLP(nn.Module):
    """Three-stage MLP (3*28*28 -> 100 -> 100 -> 16) with a linear head.

    Args:
        num_classes: Number of outputs produced by ``fc``.
    """

    def __init__(self, num_classes = 10):
        super().__init__()
        in_dim, hidden_dim, embed_dim = 3 * 28 * 28, 100, 16

        # Each stage is a Linear layer followed by a ReLU.
        self.f1 = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.f2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        self.f3 = nn.Sequential(nn.Linear(hidden_dim, embed_dim), nn.ReLU())

        self.fc = nn.Linear(embed_dim, num_classes)

    def feature(self, x):
        """Flatten ``x`` per sample, divide by 255, and apply f1, f2, f3.

        Returns:
            The output of ``f3`` (16 values per sample).
        """
        batch = x.size(0)
        hidden = x.view(batch, -1) / 255
        for stage in (self.f1, self.f2, self.f3):
            hidden = stage(hidden)
        return hidden

    def classifier(self, x):
        """Apply the linear head ``fc`` to features ``x``."""
        return self.fc(x)

    def forward(self, x, return_feat=False):
        """Run ``feature`` then ``classifier``.

        Args:
            x: Input accepted by ``feature``.
            return_feat: When true, also return the features.

        Returns:
            Head output, or ``(head output, features)`` if ``return_feat``.
        """
        feat = self.feature(x)
        logits = self.classifier(feat)
        return (logits, feat) if return_feat else logits




class CONV(nn.Module):
    """Three-stage convolutional network with a linear head.

    Stages 1 and 2 run conv -> batch norm -> ReLU -> dropout -> 2x2 average
    pooling. Stage 3 runs conv -> ReLU -> batch norm -> dropout -> adaptive
    average pooling to 1x1. The 64 pooled channels are flattened and fed to
    ``fc``.

    NOTE(review): stage 3 applies ReLU before batch norm, unlike stages 1
    and 2 -- kept as-is; confirm whether that ordering is intentional.

    Args:
        num_classes: Number of outputs produced by ``fc``.
    """

    def __init__(self, num_classes = 10):
        super().__init__()
        # Stage 1: 3 -> 8 channels.
        self.conv1 = nn.Conv2d(3, 8, kernel_size=4, stride=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout()
        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Stage 2: 8 -> 32 channels.
        self.conv2 = nn.Conv2d(8, 32, kernel_size=4, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout()
        self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Stage 3: 32 -> 64 channels, pooled down to a single spatial cell.
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=1)
        self.relu3 = nn.ReLU()
        self.bn3 = nn.BatchNorm2d(64)
        self.dropout3 = nn.Dropout()
        self.avgpool3 = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # Linear head over the 64 pooled channels.
        self.fc = nn.Linear(64, num_classes)

    def feature(self, x):
        """Apply the three convolutional stages to ``x`` in order.

        Returns:
            The output of ``avgpool3`` (64 channels, 1x1 spatial size).
        """
        pipeline = (
            self.conv1, self.bn1, self.relu1, self.dropout1, self.avgpool1,
            self.conv2, self.bn2, self.relu2, self.dropout2, self.avgpool2,
            # ReLU precedes batch norm in this stage only.
            self.conv3, self.relu3, self.bn3, self.dropout3, self.avgpool3,
        )
        for layer in pipeline:
            x = layer(x)
        return x

    def classifier(self, x):
        """Flatten ``x`` per sample and apply the linear head ``fc``."""
        flat = x.view(x.size(0), -1)
        return self.fc(flat)

    def forward(self, x, return_feat=False):
        """Run ``feature`` then ``classifier``.

        Args:
            x: Input accepted by ``feature``.
            return_feat: When true, also return the (unflattened) features.

        Returns:
            Head output, or ``(head output, features)`` if ``return_feat``.
        """
        feat = self.feature(x)
        logits = self.classifier(feat)
        if return_feat:
            return logits, feat
        return logits
