
import torch
import math
import torchvision
import torch.nn as nn
import torch.nn.functional as F
#import settings
from torch.autograd import Variable
import numpy as np
from torch.autograd import Function

class domain_classifier(nn.Module):
    """Two-layer MLP head mapping 512-d features to 2-way log-probabilities.

    Architecture: Dropout -> Linear(512, 100) -> ReLU -> Dropout ->
    Linear(100, 2) -> LogSoftmax over the last (class) dimension.
    """

    def __init__(self):
        super(domain_classifier, self).__init__()

        # NOTE: layer positions inside the Sequential are kept stable so that
        # existing checkpoints (keys "domain_classifier.1.*" and
        # "domain_classifier.4.*") still load.
        self.domain_classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 100),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(100, 2),
            # An explicit dim avoids the deprecated implicit-dimension path,
            # which warns on every call and picks dim 0 for 3-D input.
            # dim=-1 is always the class axis produced by the Linear above.
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x):
        """Return log-probabilities for ``x``.

        Args:
            x: tensor whose last dimension has size 512 (the in_features of
                the first Linear layer).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 2, holding log-softmax scores.
        """
        return self.domain_classifier(x)

    def set_alpha(self, epoch):
        """Store ``sqrt(epoch + 1)`` on ``self.alpha``.

        ``alpha`` does not exist until this method has been called, and it is
        not read anywhere inside this class.

        Args:
            epoch: number such that ``epoch + 1 >= 0``.

        Raises:
            ValueError: if ``epoch + 1`` is negative (from ``math.pow``).
        """
        self.alpha = math.pow((1.0 * epoch + 1.0), 0.5)


class domain_classifier1(nn.Module):
    """AlexNet backbone followed by a 3-layer MLP producing 2-way log-probs.

    The backbone is torchvision's pretrained AlexNet with its classifier
    truncated to the first six layers, so it emits 4096-d features. The head
    is Dropout -> Linear(4096, 2048) -> ReLU -> Dropout -> Linear(2048, 2048)
    -> ReLU -> Linear(2048, 2) -> LogSoftmax over the class dimension.
    """

    def __init__(self):
        super(domain_classifier1, self).__init__()
        # `pretrained=True` is kept (rather than the newer `weights=` keyword)
        # so the module keeps working with older torchvision releases.
        self.alexnet = torchvision.models.alexnet(pretrained=True)
        # Keep only the first six classifier layers; in torchvision's AlexNet
        # this drops the final 1000-way output layer.
        backbone_layers = list(self.alexnet.classifier.children())[:6]
        self.alexnet.classifier = nn.Sequential(*backbone_layers)

        # Layer positions are kept stable so existing checkpoints still load.
        self.domain_classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(4096, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 2),
            # The head always receives a 2-D (batch, 4096) tensor because of
            # the flatten in forward(), so dim=1 is the class axis. Making it
            # explicit silences the implicit-dimension deprecation warning.
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch accepted by ``alexnet.features`` (assumed to be
                (batch, 3, H, W) as for torchvision's AlexNet -- confirm
                against the caller).

        Returns:
            Tuple ``(log_probs, pred)`` where ``log_probs`` has shape
            (batch, 2) and ``pred`` has shape (batch, 1) holding the index of
            the larger log-probability per row.
        """
        x = self.alexnet.features(x)
        # torchvision's reference forward pools the feature map to 6x6 before
        # flattening. This is the identity for 6x6 maps and prevents a shape
        # error in the classifier for other input resolutions. Very old
        # torchvision versions have no `avgpool`, hence the guard.
        avgpool = getattr(self.alexnet, "avgpool", None)
        if avgpool is not None:
            x = avgpool(x)
        x = x.view(x.size(0), -1)
        feat = self.alexnet.classifier(x)
        log_probs = self.domain_classifier(feat)
        # max(...)[1] is the argmax index along the class dimension.
        pred = log_probs.max(1, keepdim=True)[1]

        return log_probs, pred

    def set_alpha(self, epoch):
        """Store ``sqrt(epoch + 1)`` on ``self.alpha``.

        ``alpha`` does not exist until this method has been called, and it is
        not read anywhere inside this class.

        Args:
            epoch: number such that ``epoch + 1 >= 0``.

        Raises:
            ValueError: if ``epoch + 1`` is negative (from ``math.pow``).
        """
        self.alpha = math.pow((1.0 * epoch + 1.0), 0.5)