
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# TODO: autograd anomaly detection was enabled here for debugging ("added for check").
# It is switched on globally as soon as this module is imported. In this mode
# autograd records extra information during forward and raises an error when a
# backward computation produces NaN. That slows training noticeably, so remove
# this line (or guard it behind a debug flag) once the check is no longer needed.
torch.autograd.set_detect_anomaly(True)

class Standout(nn.Module):
    """Standout (adaptive dropout) layer driven by another layer's weights.

    Keep-probabilities are computed from the *input* of ``last_layer`` using
    that layer's weight matrix ``W``::

        p = sigmoid(alpha * previous @ W.T + beta)

    NOTE(review): the original comment attributes this formula to "page 3 of
    paper: Variational Dropout"; standout is usually credited to Ba & Frey
    (2013), "Adaptive dropout for training deep neural networks" -- confirm.

    Args:
        last_layer: module whose ``weight`` is shared (not copied) as ``self.pi``.
        alpha: multiplier applied to ``previous @ W.T``.
        beta: offset added before the sigmoid.
    """

    def __init__(self, last_layer, alpha, beta):
        print("<<<<<<<<< THIS IS DEFINETLY A STANDOUT TRAINING >>>>>>>>>>>>>>>")
        super().__init__()
        # Assigning an nn.Parameter on a Module registers it here too, so the
        # same weight tensor is also reported by this module's parameters().
        self.pi = last_layer.weight
        self.alpha, self.beta = alpha, beta
        self.nonlinearity = nn.Sigmoid()

    def forward(self, previous, current, p=0.5, deterministic=False):
        """Scale ``current`` by standout probabilities computed from ``previous``.

        Args:
            previous: the input that was fed to ``last_layer``.
            current: the output of ``last_layer`` to be masked.
            p: unused; kept only so existing call sites keep working.
            deterministic: if True, scale by the probabilities themselves
                instead of by a sampled binary mask.

        Returns:
            ``self.p * current`` when ``deterministic`` is set or the mean
            probability is exactly zero, otherwise ``self.mask * current``.
            Both ``self.p`` and ``self.mask`` are stored on the module.
        """
        pre_activation = self.alpha * previous.matmul(self.pi.t()) + self.beta
        self.p = self.nonlinearity(pre_activation)
        # The mask is drawn on every call, even in deterministic mode, which
        # keeps the RNG stream and ``self.mask`` consistent across modes.
        self.mask = sample_mask(self.p)

        if not deterministic and torch.mean(self.p).detach().cpu().numpy() != 0:
            return self.mask * current
        # Deterministic variant: use the expected mask (the probabilities).
        return self.p * current

def sample_mask(p):
    """Sample a binary mask whose entries are 1 with probability ``p``.

    Args:
        p: tensor of probabilities (in this file, sigmoid outputs in [0, 1]).

    Returns:
        A float32 tensor with the same shape and device as ``p``. Each entry is
        1.0 where an independent uniform draw in [0, 1) fell below the
        matching entry of ``p``, and 0.0 elsewhere. The mask carries no gradient.
    """
    # Draw the noise on p's own device rather than choosing by
    # torch.cuda.is_available(): otherwise a CPU tensor on a CUDA machine (or a
    # tensor on a non-default GPU) hits a device mismatch in the comparison.
    uniform = torch.rand(p.shape, device=p.device)
    return (uniform < p).float()



class ConvNet(nn.Module):
    ''' Classification and projection heads on top of an external image encoder.

        The original author describes the structure as consistent with the
        M-ADA method of CVPR 2020 (https://github.com/joffery/M-ADA). The
        backbone is the ``encoder`` passed in, which must provide an
        ``encode_image(x)`` method (presumably a CLIP-style model -- confirm).
     '''
    def __init__(self, encoder, projection_dim, imdim=3, dropout=False,
                 feature_dim=2048, hidden_dim=1024, num_classes=10):
        """Build the feature wrapper and the output heads.

        Args:
            encoder: object with ``encode_image(x)`` returning features whose
                last dimension is ``feature_dim``.
            projection_dim: output width of the projection head.
            imdim: unused; kept for signature compatibility.
            dropout: unused; kept for signature compatibility.
            feature_dim: width of the encoder output (default 2048).
            hidden_dim: width of the shared hidden layer (default 1024).
            num_classes: number of classifier outputs (default 10).
        """
        super(ConvNet, self).__init__()

        self.encoder = encoder
        # Layers are created in the original order so seeded initialization
        # produces the same weights as before.
        self.encoder_wrapper = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(inplace=False))
        self.cls_head_src = nn.Linear(hidden_dim, num_classes)
        # Not used inside this class; kept because callers or saved state_dicts
        # may reference it.
        self.cls_head_tgt = nn.Linear(hidden_dim, num_classes)
        self.pro_head = nn.Linear(hidden_dim, projection_dim)

    def forward_encoder(self, x, dropout=False):
        """Return the post-ReLU hidden features for ``x`` (``dropout`` is unused)."""
        return self.encoder_wrapper(self.encoder.encode_image(x))

    def forward(self, x, mode='test'):
        """Run the encoder and heads.

        Args:
            x: input batch passed to ``encoder.encode_image``.
            mode: ``'test'`` -> logits ``p``;
                ``'train'`` -> ``(p, z)`` with ``z`` the L2-normalized projection;
                ``'prof'`` -> ``(p, z, h, None)`` with ``h`` the hidden features.

        Raises:
            ValueError: if ``mode`` is not one of the values above (previously
                such calls silently returned None).
        """
        if mode not in ('test', 'train', 'prof'):
            raise ValueError(
                "unknown mode {!r}; expected 'test', 'train' or 'prof'".format(mode))

        h = self.forward_encoder(x)
        p = self.cls_head_src(h)
        if mode == 'test':
            return p

        z = F.normalize(self.pro_head(h))
        if mode == 'train':
            return p, z
        # 'prof': also expose the un-normalized hidden features.
        return p, z, h, None
        
    
    

