from models.resnet import resnet
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F

def Model_Construct(args):
    """Build the network selected by ``args.arch``.

    Args:
        args: options object. ``args.arch`` (str) selects the model; for the
            digit models ``args.pretrained`` is also read and must be falsy.

    Returns:
        ``resnet(args)`` when ``args.arch`` contains ``'resnet'``; otherwise a
        freshly initialised ``DigitSVHN`` or ``DigitUSPS``.

    Raises:
        ValueError: if ``args.arch`` names no known model, or if a digit model
            is requested with ``args.pretrained`` set (no weights exist).
    """
    digit_archs = ('digit_svhn', 'digit_usps')
    # Reject the impossible request *before* constructing anything: building a
    # digit network only to discard it wastes time and memory.
    if args.arch in digit_archs and args.pretrained:
        raise ValueError('There is no pretrained model for digit_svhn and digit_usps. use --no_pretrain')

    if 'resnet' in args.arch:
        return resnet(args)
    if args.arch == 'digit_svhn':
        return DigitSVHN()
    if args.arch == 'digit_usps':
        return DigitUSPS()
    raise ValueError('The required model does not exist!')


class DigitSVHN(nn.Module):
    """Convolutional digit classifier for 3-channel images.

    The network is split into a feature extractor (``g_*`` layers) and a head
    (``phi_*`` layers plus ``fc3``). ``forward`` returns ``(z, u, logits)``:

    * ``z``      -- 3072-d output of ``g`` after ReLU and dropout,
    * ``u``      -- 2048-d output of the ``phi`` layer after ReLU,
    * ``logits`` -- 10 unnormalised scores produced by ``fc3`` from ``u``.

    ``self.p`` records the width of ``u`` (2048).
    """

    def __init__(self):
        super().__init__()

        # g: three 5x5 conv -> BN stages; padding=2 keeps the spatial size.
        # Creation order is significant: it fixes the order in which weights
        # are randomly initialised and the order of state_dict entries.
        self.g_conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.g_bn1 = nn.BatchNorm2d(num_features=64)
        self.g_conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.g_bn2 = nn.BatchNorm2d(num_features=64)
        self.g_conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.g_bn3 = nn.BatchNorm2d(num_features=128)
        # 8192 = 128 * 8 * 8: conv3 maps must be 8x8, which is what a 32x32
        # input yields after the two stride-2 pools in forward().
        self.g_fc1 = nn.Linear(in_features=8192, out_features=3072)
        self.g_bn1_fc = nn.BatchNorm1d(num_features=3072)

        # phi projection followed by the 10-way output layer.
        self.phi_fc2 = nn.Linear(in_features=3072, out_features=2048)
        self.phi_bn2_fc = nn.BatchNorm1d(num_features=2048)
        self.fc3 = nn.Linear(in_features=2048, out_features=10)
        self.p = self.phi_fc2.out_features

    def forward(self, x):
        # The first two conv blocks are each followed by 3x3 / stride-2 max-pooling.
        pooled_stages = ((self.g_conv1, self.g_bn1), (self.g_conv2, self.g_bn2))
        for conv, norm in pooled_stages:
            activated = F.relu(norm(conv(x)))
            x = F.max_pool2d(activated, kernel_size=3, stride=2, padding=1)
        # The third conv block is not pooled.
        x = F.relu(self.g_bn3(self.g_conv3(x)))

        flat = x.view(x.size(0), 8192)
        hidden = F.relu(self.g_bn1_fc(self.g_fc1(flat)))
        # Functional dropout (default p=0.5), active only in training mode.
        z = F.dropout(hidden, training=self.training)

        u = F.relu(self.phi_bn2_fc(self.phi_fc2(z)))
        return z, u, self.fc3(u)
    
    
class DigitUSPS(nn.Module):
    """Small convolutional digit classifier for 1-channel images.

    Mirrors the structure of ``DigitSVHN``: a feature extractor (``g_*``
    layers), a ``phi`` projection and a 10-way output layer ``fc3``.
    ``forward`` returns ``(z, u, logits)``:

    * ``z``      -- 100-d output of ``g`` after ReLU and dropout,
    * ``u``      -- 100-d output of the ``phi`` layer after ReLU,
    * ``logits`` -- 10 unnormalised scores produced by ``fc3`` from ``u``.

    ``self.p`` records the width of ``u`` (100), matching the attribute that
    ``DigitSVHN`` exposes for its own ``u`` width.
    """

    def __init__(self):
        super(DigitUSPS, self).__init__()

        # g: two unpadded 5x5 conv -> BN stages. Creation order is kept as-is
        # because it fixes weight-initialisation and state_dict order.
        self.g_conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1)
        self.g_bn1 = nn.BatchNorm2d(32)
        self.g_conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1)
        self.g_bn2 = nn.BatchNorm2d(48)
        # 768 = 48 * 4 * 4: the flattened size after the second pool.
        self.g_fc1 = nn.Linear(768, 100)
        self.g_bn1_fc = nn.BatchNorm1d(100)

        # phi projection followed by the 10-way output layer.
        self.phi_fc2 = nn.Linear(100, 100)
        self.phi_bn2_fc = nn.BatchNorm1d(100)
        self.fc3 = nn.Linear(100, 10)
        # Width of u. DigitSVHN defines the same attribute, so code reading
        # `model.p` now works for both digit models.
        self.p = 100

    def forward(self, x):
        # Each conv block is followed by 2x2 / stride-2 max-pooling with padding=1.
        x = F.max_pool2d(F.relu(self.g_bn1(self.g_conv1(x))),
                         stride=2, kernel_size=2, padding=1)
        x = F.max_pool2d(F.relu(self.g_bn2(self.g_conv2(x))),
                         stride=2, kernel_size=2, padding=1)
        # NOTE(review): this view requires 48x4x4 maps here. With the layers
        # above that holds for 22x22 or 23x23 inputs, but a 28x28 input gives
        # 48x5x5 (= 1200) and a 16x16 input gives 48x2x2 (= 192), both of
        # which fail. Confirm the input size the data pipeline feeds in.
        x = x.view(x.size(0), 768)
        x = F.relu(self.g_bn1_fc(self.g_fc1(x)))
        # Functional dropout (default p=0.5), active only in training mode.
        z = F.dropout(x, training=self.training)

        u = F.relu(self.phi_bn2_fc(self.phi_fc2(z)))

        logits = self.fc3(u)

        return z, u, logits