#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Nov 15 13:07:38 2020

@author: sayan
"""
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import math
import random
import matplotlib.pyplot as plt
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_cpu import cpu
from convert_to_gpu_scalar import gpu_ts

def net_train(net, gene, Y, opt, batch_size, temperature, lambda0, prob_ref,
              criterion_class, criterion_recon, e, iter_n):
    """Run ``iter_n`` optimisation steps on class-balanced mini-batches.

    The samples are shuffled once.  Each step then draws a batch whose
    disease part (rows with a non-zero label) has ``floor`` or ``ceil`` of
    ``batch_size * 0.5`` members, chosen by a coin flip, and fills the rest
    with control rows (label equal to zero).  Rows are drawn without
    replacement inside a batch, with a probability inversely proportional
    to ``1 + number of times the row was already drawn`` in this call, so
    rarely used rows are favoured.

    The optimised loss is
    ``sum(criterion_class(y_hat, y)) + lambda0 * sum(criterion_recon(x_hat, x))``.

    Parameters
    ----------
    net
        Model called as ``net(x, temperature)`` and returning a 3-tuple
        ``(latent, x_hat, prob)``; it must also expose
        ``net.classification(latent)``.  It is switched to training mode.
    gene
        Input features, indexed as ``gene[rows, :]``.
        NOTE(review): assumed to be a torch tensor already on the device
        the model expects -- confirm against the caller.
    Y
        Labels, indexed as ``Y[rows, :]``.
        NOTE(review): the row bookkeeping assumes a single label column;
        with several columns a row could be listed more than once.
    opt
        Optimiser providing ``zero_grad()`` and ``step()``.
    batch_size
        Number of rows per mini-batch.
    temperature
        Forwarded unchanged to ``net``.
    lambda0
        Weight of the reconstruction term in the total loss.
    prob_ref, e
        Not used by this function; kept so existing callers keep working.
    criterion_class
        Classification criterion called as ``criterion_class(y_hat, y)``.
        NOTE(review): the control/disease breakdown multiplies its output
        by per-element masks, so it presumably returns un-reduced losses.
    criterion_recon
        Reconstruction criterion called as ``criterion_recon(x_hat, x)``.
    iter_n
        Number of mini-batches (optimiser steps) to run.

    Returns
    -------
    tuple
        ``(ll, ll_class, ll_reg, ll_c, ll_d, y_pred, y_true)``: the means
        over all steps of the total, classification, weighted
        reconstruction, control-only and disease-only losses, followed by
        the concatenated detached predictions and the matching labels.

    Notes
    -----
    Both samplers draw without replacement, so each class presumably has
    to contain at least as many rows as are requested from it per batch.
    """

    def _to_numpy(tensor):
        # Detach before converting so tensors that still carry autograd
        # history can be turned into NumPy arrays.
        return cpu(tensor.detach()).numpy()

    net.train()

    # Shuffle once per call; X and Y are reordered with the same permutation.
    n_train = np.shape(gene)[0]
    id_epoch = random.sample(range(n_train), n_train)
    X = gene[id_epoch, :]
    Y = Y[id_epoch, :]

    # Row positions (in the shuffled order) of disease and control samples.
    rows_d = np.nonzero(_to_numpy(Y))[0]
    rows_c = np.nonzero(_to_numpy(Y == 0))[0]

    # 1 + number of times each row has been drawn so far; the inverse is
    # used as the sampling weight.
    draw_count_d = torch.ones(len(rows_d), dtype=torch.float64)
    draw_count_c = torch.ones(len(rows_c), dtype=torch.float64)

    # Fraction of every batch reserved for disease samples.
    disease_fraction = 0.5

    losses = []
    losses_reg = []
    losses_class = []
    losses_class_control = []
    losses_class_disease = []

    # The empty seed tensors keep the result well defined when iter_n == 0;
    # everything is concatenated once after the loop instead of growing a
    # tensor at every step.
    true_chunks = [gpu(torch.tensor([]))]
    pred_chunks = [gpu(torch.tensor([]))]

    for _ in range(iter_n):
        # Coin flip between floor and ceil so that odd batch sizes are
        # split evenly on average.
        rounding = np.floor if random.random() > 0.5 else np.ceil
        n_disease = int(rounding(batch_size * disease_fraction))

        picked_d = list(torch.utils.data.sampler.WeightedRandomSampler(
            1. / draw_count_d, n_disease, replacement=False))
        draw_count_d[picked_d] += 1

        picked_c = list(torch.utils.data.sampler.WeightedRandomSampler(
            1. / draw_count_c, batch_size - len(picked_d), replacement=False))
        draw_count_c[picked_c] += 1

        sampler = np.concatenate((rows_d[picked_d], rows_c[picked_c])).tolist()

        # ``detach()`` is what the deprecated ``Variable(tensor)`` wrapper
        # did: same data, no autograd history.
        x = X[sampler, :].detach()
        y_batch = Y[sampler, :].detach()

        # 0/1 masks selecting disease and control entries of the batch.
        idx = np.nonzero(_to_numpy(y_batch))
        disease_mask = torch.zeros(y_batch.size())
        disease_mask[idx] = 1
        weight_d = gpu(disease_mask)
        weight_c = gpu(1 - disease_mask)

        opt.zero_grad()

        # Make sure every parameter is trainable before the forward pass.
        for param in net.parameters():
            param.requires_grad = True

        latent, x_hat, _ = net(x, temperature)
        y_hat = net.classification(latent)

        # Evaluated once; the detached values are reused below for the
        # per-class breakdown.
        sample_class_loss = criterion_class(y_hat, y_batch)
        class_loss = torch.sum(sample_class_loss)

        # Reconstruction loss, acting as a regulariser.
        recon_loss = lambda0 * torch.sum(criterion_recon(x_hat, x))

        loss = class_loss + recon_loss

        loss.backward()
        opt.step()

        # Monitoring only: classification loss split by class.
        detached_class_loss = sample_class_loss.detach()
        class_loss_control = torch.sum(detached_class_loss * weight_c)
        class_loss_disease = torch.sum(detached_class_loss * weight_d)

        losses.append(_to_numpy(loss))
        losses_reg.append(_to_numpy(recon_loss))
        losses_class.append(_to_numpy(class_loss))
        losses_class_control.append(_to_numpy(class_loss_control))
        losses_class_disease.append(_to_numpy(class_loss_disease))

        pred_chunks.append(y_hat.detach())
        true_chunks.append(y_batch)

    y_pred = torch.cat(pred_chunks)
    y_true = torch.cat(true_chunks)

    ll = np.mean(losses)
    ll_c = np.mean(losses_class_control)
    ll_d = np.mean(losses_class_disease)
    ll_class = np.mean(losses_class)
    ll_reg = np.mean(losses_reg)

    return ll, ll_class, ll_reg, ll_c, ll_d, y_pred, y_true
