#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Nov 15 13:07:38 2020

@author: sayan
"""
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import math
import random
import matplotlib.pyplot as plt
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_cpu import cpu
from convert_to_gpu_scalar import gpu_ts

def _set_trainable_params(net, mode):
    """Switch ``requires_grad`` on the parameters of ``net`` for ``mode``.

    Every parameter is first made trainable, so a branch frozen by an earlier
    call with a different mode is released again. Then:

    * mode ``'S'`` (SDMT + SNP input) freezes ``encoder_i_N``/``decoder_i_N``
      and the parameters of ``bias_n``;
    * mode ``'N'`` (Nback + SNP input) freezes ``encoder_i_S``/``decoder_i_S``
      and the parameters of ``bias_s``;
    * any other mode leaves everything trainable.
    """
    for param in net.parameters():
        param.requires_grad = True

    # Kept explicit as in the original code (it may matter if ``gene_ont`` is
    # not reachable through ``net.parameters()`` -- not visible from here).
    for param in net.gene_ont.parameters():
        param.requires_grad = True

    # SDMT + SNP input: freeze the Nback encoder/decoder branches.
    if mode == 'S':
        for param in net.encoder_i_N.parameters():
            param.requires_grad = False
        for param in net.decoder_i_N.parameters():
            param.requires_grad = False

    # Nback + SNP input: freeze the SDMT encoder/decoder branches.
    if mode == 'N':
        for param in net.encoder_i_S.parameters():
            param.requires_grad = False
        for param in net.decoder_i_S.parameters():
            param.requires_grad = False

    # bias_n is frozen only for SDMT + SNP input.
    for param in net.bias_n.parameters():
        param.requires_grad = mode != 'S'

    # bias_s is frozen only for Nback + SNP input.
    for param in net.bias_s.parameters():
        param.requires_grad = mode != 'N'


def _draw_and_penalise(draw_counts, n_draws):
    """Draw ``n_draws`` distinct positions with probability ~ 1/draw_counts.

    The counts of the drawn positions are incremented in place, so positions
    that were already used become less likely to be drawn in later batches.
    """
    drawn = list(torch.utils.data.sampler.WeightedRandomSampler(
        1. / draw_counts, n_draws, replacement=False))
    draw_counts[drawn] = draw_counts[drawn] + 1
    return drawn


def _kl_sparsity_penalty(prob, prob_ref, eps):
    """Sum, over the entries of ``prob``, of the KL-style sparsity terms.

    For entry ``i`` the reference ``prob_ref[i]`` is repeated once per row of
    ``prob[i]`` and compared against ``prob[i]`` with
    ``mean(p*(log p - log r)) + mean((1-p)*(log(1-p) - log(1-r)))``;
    ``eps`` keeps the logarithms finite.
    """
    total = gpu_ts(0)
    for i, rho_hat in enumerate(prob):
        n_rows = rho_hat.size()[0]
        rho = gpu(torch.FloatTensor([prob_ref[i] for _ in range(n_rows)]))

        kl_on = torch.mean(
            rho_hat * (torch.log(rho_hat + eps) - torch.log(rho + eps)))
        kl_off = torch.mean(
            (1 - rho_hat)
            * (torch.log(1 - rho_hat + eps) - torch.log(1 - rho + eps)))
        total += kl_off + kl_on
    return total


def _as_numpy(value):
    """Detach a tensor, move it to the CPU and return it as a numpy array."""
    return cpu(value.detach()).numpy()


def net_train_joint(net, x_g, x_n, x_s, Y, opt, batch_size, temperature, lambda0, prob_ref, criterion_class, criterion_recon, e, iter_n, mode, iter_ths):
    """Run ``iter_n`` optimisation steps of the joint model on class-balanced batches.

    The data are shuffled once, then each step draws a batch whose share of
    non-zero-label rows (named "disease" in this code) follows their share in
    ``Y``; the remaining rows are zero-label ("control"). Rows are drawn
    without replacement inside a batch and rows already used are
    down-weighted for later batches.

    Parameters
    ----------
    net : model called as ``net(X_G, X_Nback, X_SDMT, temperature, mode_in)``
        and returning ``(surrogate_ig, y_hat, prob)``. It must expose
        ``gene_ont``, ``bias_n`` and ``bias_s``, plus ``encoder_i_N``/
        ``decoder_i_N`` (mode ``'S'``) or ``encoder_i_S``/``decoder_i_S``
        (mode ``'N'``).
    x_g, x_n, x_s : inputs indexed as ``x[rows, :]``. ``x_n`` is only read in
        modes ``'j'``/``'N'`` and ``x_s`` only in modes ``'j'``/``'S'``; an
        unused modality is replaced by a scalar zero tensor.
    Y : labels indexed as ``Y[rows, :]``; non-zero entries select the
        "disease" rows, zero entries the "control" rows.
    opt : optimizer, used through ``zero_grad()`` and ``step()``.
    batch_size : number of rows per batch.
    temperature : forwarded unchanged to ``net``.
    lambda0 : weights of, in order, the gene reconstruction, Nback
        reconstruction, SDMT reconstruction, classification and sparsity
        terms.
    prob_ref : reference values, one per entry of ``prob``, for the sparsity
        penalty.
    criterion_class, criterion_recon : loss callables. ``criterion_class`` is
        evaluated once per batch and its result is also multiplied by
        per-sample masks, so it is assumed to be a pure function returning
        unreduced per-sample losses -- TODO confirm against the caller.
    e, iter_ths : accepted for interface compatibility; not used here.
    iter_n : number of batches / optimisation steps.
    mode : ``'j'`` (all inputs), ``'N'`` (Nback + SNP) or ``'S'``
        (SDMT + SNP). Any value other than ``'j'``/``'N'`` is passed to the
        network as code 2.

    Returns
    -------
    tuple
        ``(ll, ll_g, ll_n, ll_s, ll_class, ll_reg, ll_c, ll_d, y_pred,
        y_true)``: batch means of the total, gene, Nback, SDMT,
        classification, sparsity, control-only and disease-only losses,
        followed by the concatenated predictions and labels of all batches.
    """
    net.train()

    eps = gpu_ts(1e-10)
    use_nback = mode == 'j' or mode == 'N'
    use_sdmt = mode == 'j' or mode == 'S'

    # Shuffle all modalities and the labels with one shared permutation.
    n_train = np.shape(x_g)[0]
    id_epoch = random.sample(range(n_train), n_train)
    X_G = x_g[id_epoch, :]
    X_Nback = x_n[id_epoch, :] if use_nback else torch.tensor(0).float()
    X_SDMT = x_s[id_epoch, :] if use_sdmt else torch.tensor(0).float()
    Y = Y[id_epoch, :]

    # Per-batch loss records.
    losses = []
    losses_sparsity = []
    losses_class = []
    losses_gene_recon = []
    losses_nback_recon = []
    losses_sdmt_recon = []
    losses_class_control = []
    losses_class_disease = []

    # Sampling strategy: one draw counter per row, kept separately per class.
    idx_d = np.nonzero(cpu(Y).data.numpy())
    idx_c = np.nonzero(cpu(Y == 0).data.numpy())
    num_d = len(idx_d[0])
    num_c = len(idx_c[0])
    samples_weight_d = torch.ones(num_d, dtype=torch.double)
    samples_weight_c = torch.ones(num_c, dtype=torch.double)
    r_d = num_d / (num_c + num_d)

    # Training metric inputs, collected per batch and concatenated once at
    # the end instead of re-copying the growing tensors on every step.
    pred_chunks = []
    true_chunks = []

    # The trainable/frozen split depends only on ``mode``, so it is applied
    # once rather than on every step (and not at all when no step runs).
    if iter_n > 0:
        _set_trainable_params(net, mode)

    if mode == 'j':
        mode_code = 0
    elif mode == 'N':
        mode_code = 1
    else:
        mode_code = 2

    for _ in range(0, iter_n):

        # Round the disease share down or up at random so that its expected
        # value matches batch_size * r_d.
        rounding = np.floor if random.random() > 0.5 else np.ceil
        sampler_d = _draw_and_penalise(
            samples_weight_d, int(rounding(batch_size * r_d)))
        sampler_c = _draw_and_penalise(
            samples_weight_c, batch_size - len(sampler_d))
        sampler = np.concatenate(
            (idx_d[0][sampler_d], idx_c[0][sampler_c])).tolist()

        # Batch inputs.
        X_G_batch = X_G[sampler, :]
        X_Nback_batch = X_Nback[sampler, :] if use_nback else torch.tensor(0).float()
        X_SDMT_batch = X_SDMT[sampler, :] if use_sdmt else torch.tensor(0).float()
        y_batch = Y[sampler, :]

        # 0/1 masks selecting the disease and the control rows of the batch.
        disease_pos = np.nonzero(cpu(y_batch).data.numpy())
        disease_mask = torch.zeros(y_batch.size())
        disease_mask[disease_pos] = 1
        weight_d = gpu(disease_mask)
        weight_c = gpu(1 - disease_mask)

        opt.zero_grad()

        mode_in = gpu_ts(mode_code)
        surrogate_ig, y_hat, prob = net(
            X_G_batch, X_Nback_batch, X_SDMT_batch, temperature, mode_in)

        # KL divergence (sparsity) loss.
        s2 = _kl_sparsity_penalty(prob, prob_ref, eps)

        gene_recon_loss = lambda0[0] * torch.sum(
            criterion_recon(surrogate_ig[0], X_G_batch))

        if use_nback:
            nback_recon_loss = lambda0[1] * torch.sum(
                criterion_recon(surrogate_ig[1], X_Nback_batch))
        else:
            nback_recon_loss = torch.tensor(0).float()

        if use_sdmt:
            sdmt_recon_loss = lambda0[2] * torch.sum(
                criterion_recon(surrogate_ig[2], X_SDMT_batch))
        else:
            sdmt_recon_loss = torch.tensor(0).float()

        # Evaluated once; the detached values are reused below for the
        # per-class monitoring losses.
        per_sample_class = criterion_class(y_hat, y_batch)
        class_loss = lambda0[3] * torch.sum(per_sample_class)

        sparsity_loss = lambda0[4] * s2

        loss = (gene_recon_loss + nback_recon_loss + sdmt_recon_loss
                + class_loss + sparsity_loss)

        # Compute gradients, then update the weights.
        loss.backward()
        opt.step()

        # Monitoring only: unweighted classification loss split by class.
        per_sample_class = per_sample_class.detach()
        class_loss_control = torch.sum(per_sample_class * weight_c)
        class_loss_disease = torch.sum(per_sample_class * weight_d)

        # Compile losses.
        losses.append(_as_numpy(loss))
        losses_gene_recon.append(_as_numpy(gene_recon_loss))
        losses_nback_recon.append(_as_numpy(nback_recon_loss))
        losses_sdmt_recon.append(_as_numpy(sdmt_recon_loss))
        losses_sparsity.append(_as_numpy(sparsity_loss))
        losses_class.append(_as_numpy(class_loss))
        losses_class_control.append(_as_numpy(class_loss_control))
        losses_class_disease.append(_as_numpy(class_loss_disease))

        # Compile predictions.
        pred_chunks.append(y_hat.detach())
        true_chunks.append(y_batch.detach())

    # The empty seed tensor keeps the result (including its dtype and the
    # zero-iteration case) identical to growing the tensors step by step.
    y_true = torch.cat([gpu(torch.tensor([]))] + true_chunks)
    y_pred = torch.cat([gpu(torch.tensor([]))] + pred_chunks)

    # Mean across batches.
    ll = np.mean(losses)
    ll_g = np.mean(losses_gene_recon)
    ll_n = np.mean(losses_nback_recon)
    ll_s = np.mean(losses_sdmt_recon)
    ll_c = np.mean(losses_class_control)
    ll_d = np.mean(losses_class_disease)
    ll_class = np.mean(losses_class)
    ll_reg = np.mean(losses_sparsity)

    return ll, ll_g, ll_n, ll_s, ll_class, ll_reg, ll_c, ll_d, y_pred, y_true
     