#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Feb  4 20:23:01 2021

@author: sayan
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
import collections
from torch_scatter import scatter, scatter_add

class joint_network(nn.Module):
    """Joint model fusing a gene-ontology module with two extra input branches.

    The two extra branches (suffixes ``N`` and ``S``) each have an encoder into
    a shared latent space of size ``latent_dim`` and a decoder back to their
    input width. During training, every input feature of an active N/S branch
    is multiplied by a hard Gumbel-softmax sample drawn from learnable
    per-feature two-way logits (``bias_n`` / ``bias_s``); column 1 of those
    logits is treated as the "keep" class.

    Args:
        gont: Module called as ``gont(x_g, T)``; it must return a 3-tuple
            ``(latent, surrogate_output, _)`` whose ``latent`` can be added to
            the encoder outputs (``latent_dim`` features), and it must expose
            ``classification(latent)``.
        d_i_N: Number of input features of the N branch.
        dd_i_n: Hidden width of the N encoder/decoder.
        d_i_S: Number of input features of the S branch.
        dd_i_s: Hidden width of the S encoder/decoder.
        latent_dim: Width of the shared latent representation.
    """

    def __init__(self, gont, d_i_N, dd_i_n, d_i_S, dd_i_s, latent_dim):
        super().__init__()

        self.gene_ont = gont

        # Layer order and attribute names are kept fixed so that existing
        # state_dict keys (e.g. "encoder_i_N.3.weight") still load.
        self.encoder_i_N = nn.Sequential(
            nn.Linear(d_i_N, dd_i_n, bias=False),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(dd_i_n, latent_dim, bias=False),
        )

        self.encoder_i_S = nn.Sequential(
            nn.Linear(d_i_S, dd_i_s, bias=False),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(dd_i_s, latent_dim, bias=False),
        )

        self.decoder_i_N = nn.Sequential(
            nn.BatchNorm1d(latent_dim),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(latent_dim, dd_i_n, bias=False),
            nn.BatchNorm1d(dd_i_n),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(dd_i_n, d_i_N, bias=False),
        )

        self.decoder_i_S = nn.Sequential(
            nn.BatchNorm1d(latent_dim),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(latent_dim, dd_i_s, bias=False),
            nn.BatchNorm1d(dd_i_s),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(dd_i_s, d_i_S, bias=False),
        )

        # One (n_features, 2) logit table per branch, initialised uniformly in
        # [-0.1, 0.1). Wrapped in a ParameterList, so the key is "bias_n.0".
        self.bias_n = nn.ParameterList([nn.Parameter(0.1 * (2 * torch.rand(d_i_N, 2) - 1))])
        self.bias_s = nn.ParameterList([nn.Parameter(0.1 * (2 * torch.rand(d_i_S, 2) - 1))])
        # Detached keep-probabilities from the latest forward(); [0, 0] before any call.
        self.prob = [0, 0]

    @staticmethod
    def _gumbel_gate(log_imp, x, T):
        """Mask each feature of ``x`` with a hard Gumbel-softmax sample.

        Args:
            log_imp: (n_features, 2) log-probabilities; column 1 means "keep".
            x: Branch input. The mask of ``batch * n_features`` elements is
                reshaped to ``x.size()``, so ``x`` is expected to be
                (batch, n_features).
            T: Gumbel-softmax temperature (``tau``).

        Returns:
            ``x`` multiplied elementwise by a 0/1 straight-through mask.
        """
        # Tile the table once per sample -> (batch * n_features, 2), batch-major,
        # so reshaping column 1 to x.size() lines each mask entry up with its feature.
        logits = log_imp.repeat(x.size(0), 1)
        sample = F.gumbel_softmax(logits, tau=T, hard=True)
        return x * sample[:, 1].reshape(x.size())

    def forward(self, x_g, x_n, x_s, T, mode):
        """Encode the active inputs, fuse them, then decode and classify.

        Args:
            x_g: Input passed unchanged to ``self.gene_ont(x_g, T)``.
            x_n: N-branch input; only read when ``mode`` is 0 or 1.
            x_s: S-branch input; only read when ``mode`` is 0 or 2.
            T: Temperature, forwarded to ``gene_ont`` and used as the
                Gumbel-softmax ``tau`` during training.
            mode: 0 uses both N and S, 1 uses only N, 2 uses only S; the gene
                branch is always used. Inactive branches contribute the
                placeholder ``gpu_ts(0)`` to the latent sum.

        Returns:
            ``(surrogate_ig, y_hat, prob)``:
            ``surrogate_ig`` is ``[second output of gene_ont, N decoder output,
            S decoder output]``; ``y_hat`` is
            ``self.gene_ont.classification(latent)``; ``prob`` is
            ``[keep-probabilities of N features, keep-probabilities of S
            features]`` (softmax column 1 of ``bias_n`` / ``bias_s``, not
            detached).
        """
        use_n = mode == 0 or mode == 1
        use_s = mode == 0 or mode == 2

        # Per-feature softmax over the two logit columns; column 1 is "keep".
        imp_N = gpu(F.softmax(self.bias_n[0], dim=1))
        imp_S = gpu(F.softmax(self.bias_s[0], dim=1))

        # Inactive-branch inputs are never touched (in training or eval), so a
        # caller may pass a placeholder for them. At eval time there is no
        # masking: the encoders' first layer is a Linear that does not modify
        # its input, so no defensive copy is needed.
        x_n_in = x_n
        x_s_in = x_s
        if self.training:
            # log_softmax rather than log(softmax(.)): identical logits, but no
            # -inf / NaN gradients if a probability underflows to 0.
            # N is gated before S to keep the original random-draw order.
            if use_n:
                x_n_in = self._gumbel_gate(gpu(F.log_softmax(self.bias_n[0], dim=1)), x_n, T)
            if use_s:
                x_s_in = self._gumbel_gate(gpu(F.log_softmax(self.bias_s[0], dim=1)), x_s, T)

        latent_g, gene_surrogate, _ = self.gene_ont(x_g, T)
        latent_n = self.encoder_i_N(x_n_in) if use_n else gpu_ts(0)
        latent_s = self.encoder_i_S(x_s_in) if use_s else gpu_ts(0)

        # Mean over active branches: three in mode 0, otherwise two.
        # NOTE(review): a mode outside {0, 1, 2} leaves only latent_g but still
        # divides by 2 -- confirm whether such a mode is ever used.
        latent = (latent_g + latent_n + latent_s) / (3 if mode == 0 else 2)

        # Both decoders always run on the fused latent, even for an inactive branch.
        surrogate_ig = [
            gene_surrogate,
            self.decoder_i_N(latent),
            self.decoder_i_S(latent),
        ]
        y_hat = self.gene_ont.classification(latent)

        prob = [imp_N[:, 1], imp_S[:, 1]]
        self.prob = [p.detach() for p in prob]
        return surrogate_ig, y_hat, prob
            
            
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
