
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Debugging aid (original note: "[TODO]-added for check"): autograd anomaly
# detection is switched on globally as soon as this module is imported.
# NOTE(review): PyTorch recommends this mode only for debugging because its
# extra checks slow execution; consider disabling it for real training runs.
torch.autograd.set_detect_anomaly(mode=True)

class Lin_View(nn.Module):
    """Flatten every dimension after the first: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def __init__(self):
        super(Lin_View, self).__init__()

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs (e.g. after a transpose or
        # permute) are accepted too; for contiguous inputs it returns a view,
        # exactly like .view() did.
        return x.reshape(x.size(0), -1)

class Standout(nn.Module):
    """Standout (adaptive dropout) layer.

    Keep-probabilities for the units of ``current`` are computed from the
    ``previous`` activations and the weight matrix ``W`` of ``last_layer``::

        p = sigmoid(alpha * previous @ W.T + beta)

    In stochastic mode ``current`` is multiplied by a 0/1 mask sampled with
    these probabilities; in deterministic mode it is scaled by ``p`` itself.

    NOTE(review): the original comment cited "page 3 of paper: Variational
    Dropout"; standout is usually attributed to Ba & Frey, "Adaptive dropout
    for training deep neural networks" (NIPS 2013) -- confirm the reference.

    Args:
        last_layer: module whose ``weight`` tensor is shared (not copied);
            presumably an ``nn.Linear`` -- TODO confirm.
        alpha: multiplicative scale on the pre-activation.
        beta: additive offset on the pre-activation.
    """

    def __init__(self, last_layer, alpha, beta):
        super(Standout, self).__init__()
        # Same tensor object as last_layer.weight. If it is an nn.Parameter
        # (as for nn.Linear), nn.Module also registers it here as ``pi``.
        self.pi = last_layer.weight
        self.alpha = alpha
        self.beta = beta
        self.nonlinearity = nn.Sigmoid()

    def forward(self, previous, current, p=0.5, deterministic=False):
        """Gate ``current`` with probabilities derived from ``previous``.

        Args:
            previous: activations fed to ``last_layer``; its last dimension
                must equal ``self.pi.size(1)``.
            current: activations to gate; must broadcast against the
                probability tensor (whose last dimension is ``self.pi.size(0)``).
            p: unused; kept only for call-site compatibility.
            deterministic: if True, scale ``current`` by the probabilities
                instead of by a sampled mask.

        Returns:
            The gated ``current`` tensor.

        Side effects:
            Stores the probabilities in ``self.p`` and the sampled mask in
            ``self.mask`` on every call.
        """
        self.p = self.nonlinearity(self.alpha * previous.matmul(self.pi.t()) + self.beta)
        # Sampled unconditionally, so ``self.mask`` is refreshed (and the RNG
        # advanced) even when the deterministic branch is taken below.
        self.mask = sample_mask(self.p)

        # Deterministic version as in the paper; also used when the mean
        # probability is exactly zero. ``.item()`` performs the same
        # comparison as the old ``.data.cpu().numpy()`` round-trip without
        # going through NumPy or the deprecated ``.data`` attribute.
        if deterministic or torch.mean(self.p).item() == 0:
            return self.p * current
        return self.mask * current

def sample_mask(p):
    """Sample a 0/1 mask from a tensor of probabilities.

    Each element of the mask is 1.0 with probability ``p`` (for ``p`` in
    [0, 1]) and 0.0 otherwise. It is obtained by comparing ``p`` against
    uniform noise drawn from [0, 1).

    Args:
        p: tensor of probabilities.

    Returns:
        A float32 tensor with the same shape as ``p``, on the same device as
        ``p``. It carries no gradient because it comes from a comparison.
    """
    # Draw the noise on p's own device. The previous version picked the device
    # from torch.cuda.is_available(), which fails with a device mismatch when
    # p lives on the CPU (or on another GPU) on a CUDA-capable machine.
    uniform = torch.rand(p.size(), device=p.device)
    return (uniform < p).float()



class ConvNet(nn.Module):
    """CNN encoder with a classification head and a projection head.

    The network structure is consistent with the M-ADA method of CVPR 2020
    (https://github.com/joffery/M-ADA).

    Encoder: conv5x5(64) -> ReLU -> maxpool(2) -> conv5x5(128) -> ReLU ->
    maxpool(2) -> flatten -> fc(1024) -> ReLU -> fc(1024) -> ReLU.
    ``fc1`` expects 128*5*5 flattened features, which the unpadded 5x5
    convolutions and 2x2 pools produce from e.g. a 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        projection_dim: output size of the projection head.
        imdim: number of input channels.
        dropout: accepted but currently unused; kept for compatibility.
        num_classes: output size of the classification head (default 10).
    """

    def __init__(self, projection_dim, imdim=3, dropout=False, num_classes=10):
        super(ConvNet, self).__init__()

        self.conv1 = nn.Conv2d(imdim, 64, kernel_size=5, stride=1, padding=0)
        # One MaxPool2d instance is reused after both conv stages; it has no
        # parameters, so sharing it is harmless.
        self.mp = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU(inplace=False)  # TODO: inplace=True
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=False)  # TODO: inplace=True
        self.fc1 = nn.Linear(128 * 5 * 5, 1024)
        self.relu3 = nn.ReLU(inplace=False)  # TODO: inplace=True
        self.fc2 = nn.Linear(1024, 1024)
        self.relu4 = nn.ReLU(inplace=False)  # TODO: inplace=True
        self.view = Lin_View()

        # The layers are registered both as named attributes and inside
        # ``encoder``; both names appear in state_dict keys, so keep the
        # attribute names and order stable for checkpoint compatibility.
        self.encoder = nn.Sequential(
            self.conv1, self.relu1, self.mp,
            self.conv2, self.relu2, self.mp,
            self.view,
            self.fc1, self.relu3,
            self.fc2, self.relu4,
        )
        self.cls_head = nn.Sequential(
            nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, num_classes))
        self.pro_head = nn.Sequential(
            nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, projection_dim))

    def forward(self, x, mode='test'):
        """Run the encoder and the head(s) selected by ``mode``.

        Args:
            x: input batch; presumably (N, imdim, 32, 32) -- see class docstring.
            mode: one of
                ``'test'``  -> returns the class logits ``p``;
                ``'train'`` -> returns ``(p, z)`` where ``z`` is the projection
                               L2-normalised along dim 1 (``F.normalize``);
                ``'prof'``  -> returns ``(p, z, h, None)`` with the
                               unnormalised projection ``z`` and the encoder
                               features ``h``.

        Raises:
            ValueError: if ``mode`` is not one of the values above
                (previously such calls silently returned ``None``).
        """
        if mode not in ('test', 'train', 'prof'):
            raise ValueError(
                "mode must be 'test', 'train' or 'prof', got {!r}".format(mode))

        h = self.encoder(x)
        p = self.cls_head(h)
        if mode == 'test':
            return p

        z = self.pro_head(h)
        if mode == 'train':
            return p, F.normalize(z)
        # 'prof': raw projection plus encoder features; the trailing None
        # keeps the 4-tuple shape existing callers unpack.
        return p, z, h, None
        
    
    

