
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import math
import random
import matplotlib.pyplot as plt
from convert_to_cpu import cpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu import gpu


def net_test(net, gene_test, Y_test, temperature, criterion, criterion_recon, lambda0):
    """Evaluate ``net`` on a held-out set and return losses split by class.

    The network is switched to eval mode and run under ``torch.no_grad()``.
    Samples whose label in ``Y_test`` is nonzero are treated as "disease".
    Samples whose label is zero are treated as "control". The classification
    and reconstruction losses are reported in total and for each group.

    Args:
        net: Model called as ``net(gene_test, temperature)`` that returns
            ``(latent, x_hat, prob)``. It must also expose
            ``net.classification(latent)``.
        gene_test: Input tensor, which is also the reconstruction target.
        Y_test: Label tensor. Nonzero entries mark disease samples.
        temperature: Forwarded unchanged to ``net``.
        criterion: Classification loss applied as ``criterion(y_hat, Y_test)``.
            Presumably built with ``reduction='none'`` so that per-element
            losses can be masked. TODO confirm with the caller.
        criterion_recon: Reconstruction loss applied as
            ``criterion_recon(x_hat, gene_test)``. Presumably also uses
            ``reduction='none'``. TODO confirm.
        lambda0: Scalar weight applied to every reconstruction loss term.

    Returns:
        A 9-tuple, with every tensor passed through ``cpu``:
        ``(class_loss + recon_loss, y_hat, [pp], class_loss,
        class_loss_control, class_loss_disease, recon_loss,
        recon_loss_control, recon_loss_disease)``.
        Here ``pp`` is the entries of ``prob`` concatenated along dim 0.
    """
    net.eval()

    with torch.no_grad():
        # Run model evaluation over the test data.
        latent, x_hat, prob = net(gene_test, temperature)
        y_hat = net.classification(latent)

        # 0/1 masks: weight_d selects samples with a nonzero label ("disease"),
        # and weight_c is its exact complement ("control").
        disease_idx = np.nonzero(cpu(Y_test).data.numpy())
        weight_d = torch.zeros(Y_test.size())
        weight_d[disease_idx] = 1
        weight_c = 1 - weight_d
        weight_d = gpu(weight_d)
        weight_c = gpu(weight_c)

        # Compute each per-element loss once, then reuse it for the total
        # and the masked sums instead of re-running the criterion.
        class_elem = criterion(y_hat, Y_test)
        recon_elem = criterion_recon(x_hat, gene_test)

        class_loss = torch.sum(class_elem)
        class_loss_control = torch.sum(class_elem * weight_c)
        class_loss_disease = torch.sum(class_elem * weight_d)

        # The masks have Y_test's shape. Applying them to the reconstruction
        # loss relies on broadcasting against gene_test's shape (presumably
        # a per-sample column mask). NOTE(review): confirm the shapes line up.
        recon_loss = lambda0 * torch.sum(recon_elem)
        recon_loss_control = lambda0 * torch.sum(recon_elem * weight_c)
        recon_loss_disease = lambda0 * torch.sum(recon_elem * weight_d)

        # One concatenation instead of repeatedly growing a tensor in a loop.
        # A single entry is returned as-is, matching the old loop's behavior.
        pp = prob[0] if len(prob) == 1 else torch.cat(tuple(prob))

        return (
            cpu(class_loss + recon_loss),
            cpu(y_hat),
            [cpu(pp)],
            cpu(class_loss),
            cpu(class_loss_control),
            cpu(class_loss_disease),
            cpu(recon_loss),
            cpu(recon_loss_control),
            cpu(recon_loss_disease),
        )